Expire batches that never ended to avoid leaking memory.

This commit is contained in:
Valentin Lorentz 2020-05-06 20:39:21 +02:00
parent cc0af4e790
commit da328b4985
3 changed files with 125 additions and 4 deletions

View File

@ -58,7 +58,7 @@ from . import conf, ircdb, ircmsgs, ircutils, log, utils, world
from .drivers import Server from .drivers import Server
from .utils.str import rsplit from .utils.str import rsplit
from .utils.iter import chain from .utils.iter import chain
from .utils.structures import smallqueue, RingBuffer from .utils.structures import smallqueue, RingBuffer, TimeoutDict
MAX_LINE_SIZE = 512 # Including \r\n MAX_LINE_SIZE = 512 # Including \r\n
@ -539,7 +539,10 @@ class IrcState(IrcCommandDispatcher, log.Firewalled):
self.history = history self.history = history
self.channels = channels self.channels = channels
self.nicksToHostmasks = nicksToHostmasks self.nicksToHostmasks = nicksToHostmasks
self.batches = {}
# Batches should always finish and be way shorter than 3600s, but
# let's just make sure to avoid leaking memory.
self.batches = TimeoutDict(timeout=3600)
def reset(self): def reset(self):
"""Resets the state to normal, unconnected state.""" """Resets the state to normal, unconnected state."""
@ -550,7 +553,7 @@ class IrcState(IrcCommandDispatcher, log.Firewalled):
self.channels.clear() self.channels.clear()
self.supported.clear() self.supported.clear()
self.nicksToHostmasks.clear() self.nicksToHostmasks.clear()
self.batches = {} self.batches.clear()
self.capabilities_req = set() self.capabilities_req = set()
self.capabilities_ack = set() self.capabilities_ack = set()
self.capabilities_nak = set() self.capabilities_nak = set()

View File

@ -32,6 +32,7 @@ Data structures for Python.
""" """
import time import time
import threading
import collections.abc import collections.abc
@ -446,7 +447,7 @@ class CacheDict(collections.abc.MutableMapping):
def keys(self): def keys(self):
return self.d.keys() return self.d.keys()
def items(self): def items(self):
return self.d.items() return self.d.items()
@ -456,6 +457,82 @@ class CacheDict(collections.abc.MutableMapping):
def __len__(self): def __len__(self):
return len(self.d) return len(self.d)
class TimeoutDict(collections.abc.MutableMapping):
"""A dictionary that may drop its items when they are too old.
Currently, this is implemented by internally alternating two "generation"
dicts, which are dropped after a certain time."""
__slots__ = ('_lock', 'old_gen', 'new_gen', 'timeout', '_last_switch')
__synchronized__ = ('_expire_generations',)
def __init__(self, timeout, items=None):
self._lock = threading.Lock()
self.old_gen = {}
self.new_gen = {} if items is None else items
self.timeout = timeout
self._last_switch = time.time()
def __reduce__(self):
return (self.__class__, (self.timeout, dict(self)))
def __repr__(self):
return 'TimeoutDict(%s, %r)' % (self.timeout, dict(self))
def __getitem__(self, key):
try:
# Check the new_gen first, as it contains the most recent
# insertion.
# We must also check them in this order to be thread-safe when
# _expire_generations() runs.
return self.new_gen[key]
except KeyError:
try:
return self.old_gen[key]
except KeyError:
raise KeyError(key) from None
def __contains__(self, key):
# the two clauses must be in this order to be thread-safe when
# _expire_generations() runs.
return key in self.new_gen or key in self.old_gen
def __setitem__(self, key, value):
self._expireGenerations()
self.new_gen[key] = value
def _expireGenerations(self):
with self._lock:
now = time.time()
if self._last_switch + self.timeout < now:
# We last wrote to self.old_gen a long time ago
# (ie. more than self.timeout); let's drop the old_gen and
# make new_gen become the old_gen
# self.old_gen must be written before self.new_gen for
# __getitem__ and __contains__ to be able to run concurrently
# to this function.
self.old_gen = self.new_gen
self.new_gen = {}
self._last_switch = now
def clear(self):
self.old_gen.clear()
self.new_gen.clear()
def __delitem__(self, key):
self.old_gen.pop(key, None)
self.new_gen.pop(key, None)
def __iter__(self):
# order matters
keys = set(self.new_gen.keys()) | set(self.old_gen.keys())
return iter(keys)
def __len__(self):
# order matters
return len(set(self.new_gen.keys()) | set(self.old_gen.keys()))
class TruncatableSet(collections.abc.MutableSet): class TruncatableSet(collections.abc.MutableSet):
"""A set that keeps track of the order of inserted elements so """A set that keeps track of the order of inserted elements so
the oldest can be removed.""" the oldest can be removed."""

View File

@ -1166,6 +1166,47 @@ class TestCacheDict(SupyTestCase):
self.assertTrue(i in d) self.assertTrue(i in d)
self.assertTrue(d[i] == i) self.assertTrue(d[i] == i)
class TestTimeoutDict(SupyTestCase):
def testInit(self):
d = TimeoutDict(10)
self.assertEqual(dict(d), {})
d['foo'] = 'bar'
d['baz'] = 'qux'
self.assertEqual(dict(d), {'foo': 'bar', 'baz': 'qux'})
def testExpire(self):
d = TimeoutDict(10)
self.assertEqual(dict(d), {})
d['foo'] = 'bar'
timeFastForward(11)
d['baz'] = 'qux' # Moves 'foo' to the old gen
self.assertEqual(dict(d), {'foo': 'bar', 'baz': 'qux'})
timeFastForward(11)
self.assertEqual(dict(d), {'foo': 'bar', 'baz': 'qux'})
d['quux'] = 42 # removes the old gen and moves 'baz' to the old gen
self.assertEqual(dict(d), {'baz': 'qux', 'quux': 42})
def testEquality(self):
d1 = TimeoutDict(10)
d2 = TimeoutDict(10)
self.assertEqual(d1, d2)
d1['foo'] = 'bar'
self.assertNotEqual(d1, d2)
timeFastForward(5) # check they are equal despite the time difference
d2['foo'] = 'bar'
self.assertEqual(d1, d2)
timeFastForward(7)
d1['baz'] = 'qux' # moves 'foo' to the old gen (12 seconds old)
d2['baz'] = 'qux' # does not move it (7 seconds old)
self.assertEqual(d1, d2)
class TestTruncatableSet(SupyTestCase): class TestTruncatableSet(SupyTestCase):
def testBasics(self): def testBasics(self):
s = TruncatableSet(['foo', 'bar', 'baz', 'qux']) s = TruncatableSet(['foo', 'bar', 'baz', 'qux'])