diff --git a/src/irclib.py b/src/irclib.py index 6f20a737a..084f04b6c 100644 --- a/src/irclib.py +++ b/src/irclib.py @@ -323,34 +323,40 @@ class ChannelState(object): class IrcState(IrcCommandDispatcher): """Maintains state of the Irc connection. Should also become smarter. """ - __slots__ = ('history', 'nicksToHostmasks', 'channels') __metaclass__ = log.MetaFirewall __firewalled__ = {'addMsg': None} - def __init__(self): - self.supported = utils.InsensitivePreservingDict() - self.history=RingBuffer(conf.supybot.protocols.irc.maxHistoryLength()) - self.reset() + def __init__(self, history=None, supported=None, + nicksToHostmasks=None, channels=None): + if history is None: + history = RingBuffer(conf.supybot.protocols.irc.maxHistoryLength()) + if supported is None: + supported = utils.InsensitivePreservingDict() + if nicksToHostmasks is None: + nicksToHostmasks = ircutils.IrcDict() + if channels is None: + channels = ircutils.IrcDict() + self.supported = supported + self.history = history + self.channels = channels + self.nicksToHostmasks = nicksToHostmasks def reset(self): """Resets the state to normal, unconnected state.""" self.history.reset() + self.channels.clear() self.supported.clear() + self.nicksToHostmasks.clear() self.history.resize(conf.supybot.protocols.irc.maxHistoryLength()) - self.channels = ircutils.IrcDict() - self.nicksToHostmasks = ircutils.IrcDict() - - def __getstate__(self): - return map(utils.curry(getattr, self), self.__slots__) - - def __setstate__(self, t): - for (name, value) in zip(self.__slots__, t): - setattr(self, name, value) + def __reduce__(self): + return (self.__class__, (self.history, self.supported, + self.nicksToHostmasks, self.channels)) + def __eq__(self, other): - ret = True - for name in self.__slots__: - ret = ret and getattr(self, name) == getattr(other, name) - return ret + return self.history == other.history and \ + self.channels == other.channels and \ + self.supported == other.supported and \ + self.nicksToHostmasks == other.nicksToHostmasks def __ne__(self, other): return not self == other diff --git a/test/test_irclib.py b/test/test_irclib.py index 98c357242..558ed9e71 100644 --- a/test/test_irclib.py +++ b/test/test_irclib.py @@ -254,7 +254,17 @@ class IrcStateTestCase(SupyTestCase): pass self.assertEqual(state, state.copy()) + def testCopyCopiesChannels(self): state = irclib.IrcState() + stateCopy = copy.copy(state) + state.channels['#foo'] = None + self.failIf('#foo' in stateCopy.channels) + + def testCopyCopiesChannels2(self): + state = irclib.IrcState() + stateCopy = state.copy() + state.channels['#foo'] = None + self.failIf('#foo' in stateCopy.channels) def testJoin(self): st = irclib.IrcState()