From da328b4985aab1a1956832ae4b372fb4b1ebd875 Mon Sep 17 00:00:00 2001
From: Valentin Lorentz
Date: Wed, 6 May 2020 20:39:21 +0200
Subject: [PATCH] Expire batches that never ended to avoid leaking memory.

---
 src/irclib.py           |  9 +++--
 src/utils/structures.py | 79 ++++++++++++++++++++++++++++++++++++++++-
 test/test_utils.py      | 41 +++++++++++++++++++++
 3 files changed, 125 insertions(+), 4 deletions(-)

diff --git a/src/irclib.py b/src/irclib.py
index 85c01f3b4..2c756eca3 100644
--- a/src/irclib.py
+++ b/src/irclib.py
@@ -58,7 +58,7 @@ from . import conf, ircdb, ircmsgs, ircutils, log, utils, world
 from .drivers import Server
 from .utils.str import rsplit
 from .utils.iter import chain
-from .utils.structures import smallqueue, RingBuffer
+from .utils.structures import smallqueue, RingBuffer, TimeoutDict
 
 MAX_LINE_SIZE = 512 # Including \r\n
 
@@ -539,7 +539,10 @@ class IrcState(IrcCommandDispatcher, log.Firewalled):
         self.history = history
         self.channels = channels
         self.nicksToHostmasks = nicksToHostmasks
-        self.batches = {}
+
+        # Batches should always finish and be way shorter than 3600s, but
+        # let's just make sure to avoid leaking memory.
+        self.batches = TimeoutDict(timeout=3600)
 
     def reset(self):
         """Resets the state to normal, unconnected state."""
@@ -550,7 +553,7 @@ class IrcState(IrcCommandDispatcher, log.Firewalled):
         self.channels.clear()
         self.supported.clear()
         self.nicksToHostmasks.clear()
-        self.batches = {}
+        self.batches.clear()
         self.capabilities_req = set()
         self.capabilities_ack = set()
         self.capabilities_nak = set()
diff --git a/src/utils/structures.py b/src/utils/structures.py
index f24b9b8b1..55aa7ee6f 100644
--- a/src/utils/structures.py
+++ b/src/utils/structures.py
@@ -32,6 +32,7 @@ Data structures for Python.
 """
 
 import time
+import threading
 import collections.abc
 
 
@@ -446,7 +447,7 @@ class CacheDict(collections.abc.MutableMapping):
 
     def keys(self):
         return self.d.keys()
-    
+
     def items(self):
         return self.d.items()
 
@@ -456,6 +457,82 @@ class CacheDict(collections.abc.MutableMapping):
     def __len__(self):
         return len(self.d)
 
+
+class TimeoutDict(collections.abc.MutableMapping):
+    """A dictionary that may drop its items when they are too old.
+
+    Currently, this is implemented by internally alternating two "generation"
+    dicts, which are dropped after a certain time."""
+    __slots__ = ('_lock', 'old_gen', 'new_gen', 'timeout', '_last_switch')
+    __synchronized__ = ('_expire_generations',)
+
+    def __init__(self, timeout, items=None):
+        self._lock = threading.Lock()
+        self.old_gen = {}
+        self.new_gen = {} if items is None else items
+        self.timeout = timeout
+        self._last_switch = time.time()
+
+    def __reduce__(self):
+        return (self.__class__, (self.timeout, dict(self)))
+
+    def __repr__(self):
+        return 'TimeoutDict(%s, %r)' % (self.timeout, dict(self))
+
+    def __getitem__(self, key):
+        try:
+            # Check the new_gen first, as it contains the most recent
+            # insertion.
+            # We must also check them in this order to be thread-safe when
+            # _expire_generations() runs.
+            return self.new_gen[key]
+        except KeyError:
+            try:
+                return self.old_gen[key]
+            except KeyError:
+                raise KeyError(key) from None
+
+    def __contains__(self, key):
+        # the two clauses must be in this order to be thread-safe when
+        # _expire_generations() runs.
+        return key in self.new_gen or key in self.old_gen
+
+    def __setitem__(self, key, value):
+        self._expireGenerations()
+        self.new_gen[key] = value
+
+    def _expireGenerations(self):
+        with self._lock:
+            now = time.time()
+            if self._last_switch + self.timeout < now:
+                # We last wrote to self.old_gen a long time ago
+                # (ie. more than self.timeout); let's drop the old_gen and
+                # make new_gen become the old_gen
+                # self.old_gen must be written before self.new_gen for
+                # __getitem__ and __contains__ to be able to run concurrently
+                # to this function.
+                self.old_gen = self.new_gen
+                self.new_gen = {}
+                self._last_switch = now
+
+    def clear(self):
+        self.old_gen.clear()
+        self.new_gen.clear()
+
+    def __delitem__(self, key):
+        self.old_gen.pop(key, None)
+        self.new_gen.pop(key, None)
+
+    def __iter__(self):
+        # order matters
+        keys = set(self.new_gen.keys()) | set(self.old_gen.keys())
+        return iter(keys)
+
+    def __len__(self):
+        # order matters
+        return len(set(self.new_gen.keys()) | set(self.old_gen.keys()))
+
+
 class TruncatableSet(collections.abc.MutableSet):
     """A set that keeps track of the order of inserted elements so
     the oldest can be removed."""
diff --git a/test/test_utils.py b/test/test_utils.py
index e6275b67b..b45623152 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -1166,6 +1166,47 @@ class TestCacheDict(SupyTestCase):
             self.assertTrue(i in d)
             self.assertTrue(d[i] == i)
 
+class TestTimeoutDict(SupyTestCase):
+    def testInit(self):
+        d = TimeoutDict(10)
+        self.assertEqual(dict(d), {})
+        d['foo'] = 'bar'
+        d['baz'] = 'qux'
+        self.assertEqual(dict(d), {'foo': 'bar', 'baz': 'qux'})
+
+    def testExpire(self):
+        d = TimeoutDict(10)
+        self.assertEqual(dict(d), {})
+        d['foo'] = 'bar'
+        timeFastForward(11)
+        d['baz'] = 'qux' # Moves 'foo' to the old gen
+        self.assertEqual(dict(d), {'foo': 'bar', 'baz': 'qux'})
+
+        timeFastForward(11)
+        self.assertEqual(dict(d), {'foo': 'bar', 'baz': 'qux'})
+
+        d['quux'] = 42 # removes the old gen and moves 'baz' to the old gen
+        self.assertEqual(dict(d), {'baz': 'qux', 'quux': 42})
+
+    def testEquality(self):
+        d1 = TimeoutDict(10)
+        d2 = TimeoutDict(10)
+        self.assertEqual(d1, d2)
+
+        d1['foo'] = 'bar'
+        self.assertNotEqual(d1, d2)
+
+        timeFastForward(5) # check they are equal despite the time difference
+
+        d2['foo'] = 'bar'
+        self.assertEqual(d1, d2)
+
+        timeFastForward(7)
+
+        d1['baz'] = 'qux' # moves 'foo' to the old gen (12 seconds old)
+        d2['baz'] = 'qux' # does not move it (7 seconds old)
+        self.assertEqual(d1, d2)
+
 class TestTruncatableSet(SupyTestCase):
     def testBasics(self):
         s = TruncatableSet(['foo', 'bar', 'baz', 'qux'])