diff --git a/src/utils/structures.py b/src/utils/structures.py index 31e8f561f..1be386ffd 100644 --- a/src/utils/structures.py +++ b/src/utils/structures.py @@ -454,5 +454,40 @@ class CacheDict(collections.MutableMapping): def __len__(self): return len(self.d) +class TruncatableSet(collections.MutableSet): + """A set that keeps track of the order of inserted elements so + the oldest can be removed.""" + def __init__(self, iterable): + self._ordered_items = list(iterable) + self._items = set(self._ordered_items) + def __contains__(self, item): + return item in self._items + def __iter__(self): + return iter(self._items) + def __len__(self): + return len(self._items) + def add(self, item): + if item not in self._items: + self._items.add(item) + self._ordered_items.append(item) + def discard(self, item): + self._items.discard(item) + self._ordered_items.remove(item) + def truncate(self, size): + assert size >= 0 + removed_size = len(self)-size + # I make two different cases depending on removed_size