diff --git a/plugins/Fediverse/plugin.py b/plugins/Fediverse/plugin.py index a438cf189..2bfcfa583 100644 --- a/plugins/Fediverse/plugin.py +++ b/plugins/Fediverse/plugin.py @@ -122,7 +122,7 @@ class Fediverse(callbacks.PluginRegexp): def __init__(self, irc): super().__init__(irc) self._startHttp() - self._actor_cache = utils.structures.TimeoutDict(timeout=600) + self._actor_cache = utils.structures.ExpiringDict(timeout=600) def _startHttp(self): callback = FediverseHttp() diff --git a/src/utils/structures.py b/src/utils/structures.py index 55aa7ee6f..31c810355 100644 --- a/src/utils/structures.py +++ b/src/utils/structures.py @@ -459,7 +459,8 @@ class CacheDict(collections.abc.MutableMapping): class TimeoutDict(collections.abc.MutableMapping): - """A dictionary that may drop its items when they are too old. + """An efficient dictionary that MAY drop its items when they are too old. + For guaranteed expiry, use ExpiringDict. Currently, this is implemented by internally alternating two "generation" dicts, which are dropped after a certain time.""" @@ -533,6 +534,81 @@ class TimeoutDict(collections.abc.MutableMapping): return len(set(self.new_gen.keys()) | set(self.old_gen.keys())) +class ExpiringDict: # Don't inherit from MutableMapping: not thread-safe + """A dictionary that drops its items after they have been in the dict + for a certain time. + + Use TimeoutDict for a more efficient implementation that doesn't require + guaranteed timeout. + """ + __slots__ = ('_lock', 'd', 'timeout') + __synchronized__ = ('_expire_generations',) + + def __init__(self, timeout, items=None): + expiry = time.time() + timeout + self._lock = threading.Lock() + self.d = {k: (expiry, v) for (k, v) in (items or {}).items()} + self.timeout = timeout + + def __reduce__(self): + return (self.__class__, (self.timeout, dict(self))) + + def __repr__(self): + return 'ExpiringDict(%s, %r)' % (self.timeout, dict(self)) + + def __getitem__(self, key): + with self._lock: + try: + (expiry, value) = self.d[key] + if expiry < time.time(): + del self.d[key] + raise KeyError + except KeyError: + raise KeyError(key) from None + + return value + + def __setitem__(self, key, value): + with self._lock: + self.d[key] = (time.time() + self.timeout, value) + + def clear(self): + with self._lock: + self.d.clear() + + def __delitem__(self, key): + with self._lock: + del self.d[key] + + def _items(self): + now = time.time() + with self._lock: + return [ + (k, v) for (k, (expiry, v)) in self.d.items() + if expiry >= now] + + def keys(self): + return [k for (k, v) in self._items()] + + def values(self): + return [v for (k, v) in self._items()] + + def items(self): + return self._items() + + def __iter__(self): + return (k for (k, v) in self._items()) + + def __len__(self): + return len(self._items()) + + def __eq__(self, other): + return self._items() == list(other.items()) + + def __ne__(self, other): + return not (self == other) + + class TruncatableSet(collections.abc.MutableSet): """A set that keeps track of the order of inserted elements so the oldest can be removed.""" diff --git a/test/test_utils.py b/test/test_utils.py index b45623152..f4ba2c0a8 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -1207,6 +1207,57 @@ class TestTimeoutDict(SupyTestCase): d2['baz'] = 'qux' # does not move it (7 seconds old) self.assertEqual(d1, d2) + +class TestExpiringDict(SupyTestCase): + def testInit(self): + d = ExpiringDict(10) + self.assertEqual(dict(d), {}) + d['foo'] = 'bar' + d['baz'] = 'qux' + self.assertEqual(dict(d), {'foo': 'bar', 'baz': 'qux'}) + + def testExpire(self): + d = ExpiringDict(10) + self.assertEqual(dict(d), {}) + d['foo'] = 'bar' + timeFastForward(11) + d['baz'] = 'qux' + self.assertEqual(dict(d), {'baz': 'qux'}) + + timeFastForward(11) + self.assertEqual(dict(d), {}) + + d['quux'] = 42 + self.assertEqual(dict(d), {'quux': 42}) + + def testEquality(self): + d1 = ExpiringDict(10) + d2 = ExpiringDict(10) + self.assertEqual(d1, d2) + + d1['foo'] = 'bar' + self.assertNotEqual(d1, d2) + + timeFastForward(5) # check they are equal despite the time difference + + d2['foo'] = 'bar' + self.assertEqual(d1, d2) + + timeFastForward(7) + self.assertNotEqual(d1, d2) + self.assertEqual(d1, {}) + self.assertEqual(d2, {'foo': 'bar'}) + + timeFastForward(7) + self.assertEqual(d1, d2) + self.assertEqual(d1, {}) + self.assertEqual(d2, {}) + + d1['baz'] = 'qux' + d2['baz'] = 'qux' + self.assertEqual(d1, d2) + + class TestTruncatableSet(SupyTestCase): def testBasics(self): s = TruncatableSet(['foo', 'bar', 'baz', 'qux'])