From 3d21c7cbcb1e2afe1a7a55a2dcae553801b0f4e3 Mon Sep 17 00:00:00 2001 From: David Macek Date: Sat, 24 Apr 2021 20:33:53 +0200 Subject: [PATCH] Services: Keep per-network state separate Until now, only `waitingJoins` was stored separately per network, while `channels`, `sentGhost` and `identified` had one common value per plugin instance. Instead of making everything a dictionary indexed by network name like `waitingJoins`, let's bundle all the state together in a class and store *its* instances in such a dictionary. This fixes at least one race condition, for which a test case was added. Even with `noJoinsUntilIdentified` set, the bot would let joins through as long as *any* one network has already finished identifying. --- plugins/Services/plugin.py | 63 +++++++++++++++++++++-------------- plugins/Services/test.py | 67 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 105 insertions(+), 25 deletions(-) diff --git a/plugins/Services/plugin.py b/plugins/Services/plugin.py index 6853d8988..4d901a1ba 100644 --- a/plugins/Services/plugin.py +++ b/plugins/Services/plugin.py @@ -42,6 +42,13 @@ import supybot.callbacks as callbacks from supybot.i18n import PluginInternationalization, internationalizeDocstring _ = PluginInternationalization('Services') +class State: + def __init__(self): + self.channels = [] + self.sentGhost = None + self.identified = False + self.waitingJoins = [] + class Services(callbacks.Plugin): """This plugin handles dealing with Services on networks that provide them. Basically, you should use the "password" command to tell the bot a nick to @@ -66,10 +73,10 @@ class Services(callbacks.Plugin): self.reset() def reset(self): - self.channels = [] - self.sentGhost = None - self.identified = False - self.waitingJoins = {} + self.state = {} + + def _getState(self, irc): + return self.state.setdefault(irc.network, State()) def disabled(self, irc): disabled = self.registryValue('disabledNetworks') @@ -79,13 +86,13 @@ class Services(callbacks.Plugin): return False def outFilter(self, irc, msg): + state = self._getState(irc) if msg.command == 'JOIN' and not self.disabled(irc): - if not self.identified: + if not state.identified: if self.registryValue('noJoinsUntilIdentified', network=irc.network): self.log.info('Holding JOIN to %s @ %s until identified.', msg.channel, irc.network) - self.waitingJoins.setdefault(irc.network, []) - self.waitingJoins[irc.network].append(msg) + state.waitingJoins.append(msg) return None return msg @@ -131,6 +138,7 @@ class Services(callbacks.Plugin): def _doGhost(self, irc, nick=None): if self.disabled(irc): return + state = self._getState(irc) if nick is None: nick = self._getNick(irc.network) if nick not in self.registryValue('nicks', network=irc.network): @@ -144,7 +152,7 @@ class Services(callbacks.Plugin): s = 'Tried to ghost without a NickServ or password set.' self.log.warning(s) return - if self.sentGhost and time.time() < (self.sentGhost + ghostDelay): + if state.sentGhost and time.time() < (state.sentGhost + ghostDelay): self.log.warning('Refusing to send GHOST more than once every ' '%s seconds.' % ghostDelay) elif not password: @@ -156,12 +164,13 @@ class Services(callbacks.Plugin): ghost = 'GHOST %s %s' % (nick, password) # Ditto about the sendMsg (see _doIdentify). irc.sendMsg(ircmsgs.privmsg(nickserv, ghost)) - self.sentGhost = time.time() + state.sentGhost = time.time() def __call__(self, irc, msg): self.__parent.__call__(irc, msg) if self.disabled(irc): return + state = self._getState(irc) nick = self._getNick(irc.network) if nick not in self.registryValue('nicks', network=irc.network): return @@ -172,16 +181,17 @@ class Services(callbacks.Plugin): return if nick and nickserv and password and \ not ircutils.strEqual(nick, irc.nick): - if irc.afterConnect and (self.sentGhost is None or - (self.sentGhost + ghostDelay) < time.time()): + if irc.afterConnect and (state.sentGhost is None or + (state.sentGhost + ghostDelay) < time.time()): if nick in irc.state.nicksToHostmasks: self._doGhost(irc) else: irc.sendMsg(ircmsgs.nick(nick)) # 433 is handled elsewhere. def do001(self, irc, msg): - # New connection, make sure sentGhost is False. - self.sentGhost = None + # New connection, make sure sentGhost is None. + state = self._getState(irc) + state.sentGhost = None def do376(self, irc, msg): if self.disabled(irc): @@ -221,7 +231,8 @@ class Services(callbacks.Plugin): def do515(self, irc, msg): # Can't join this channel, it's +r (we must be identified). - self.channels.append(msg.args[1]) + state = self._getState(irc) + state.channels.append(msg.args[1]) def doNick(self, irc, msg): nick = self._getNick(irc.network) @@ -295,6 +306,7 @@ class Services(callbacks.Plugin): def doNickservNotice(self, irc, msg): if self.disabled(irc): return + state = self._getState(irc) nick = self._getNick(irc.network) s = ircutils.stripFormatting(msg.args[1].lower()) on = 'on %s' % irc.network @@ -303,19 +315,19 @@ class Services(callbacks.Plugin): log = 'Received "Password Incorrect" from NickServ %s. ' \ 'Resetting password to empty.' % on self.log.warning(log) - self.sentGhost = time.time() + state.sentGhost = time.time() self._setNickServPassword(nick, '', irc.network) elif self._ghosted(irc, s): self.log.info('Received "GHOST succeeded" from NickServ %s.', on) - self.sentGhost = None - self.identified = False + state.sentGhost = None + state.identified = False irc.queueMsg(ircmsgs.nick(nick)) elif 'is not registered' in s: self.log.info('Received "Nick not registered" from NickServ %s.', on) elif 'currently' in s and 'isn\'t' in s or 'is not' in s: # The nick isn't online, let's change our nick to it. - self.sentGhost = None + state.sentGhost = None irc.queueMsg(ircmsgs.nick(nick)) elif ('owned by someone else' in s) or \ ('nickname is registered and protected' in s) or \ @@ -339,15 +351,15 @@ class Services(callbacks.Plugin): # freenode, oftc, arstechnica, zirc, .... # sorcery self.log.info('Received "Password accepted" from NickServ %s.', on) - self.identified = True + state.identified = True for channel in irc.state.channels.keys(): self.checkPrivileges(irc, channel) - for channel in self.channels: + for channel in state.channels: irc.queueMsg(networkGroup.channels.join(channel)) - waitingJoins = self.waitingJoins.pop(irc.network, None) - if waitingJoins: - for m in waitingJoins: - irc.sendMsg(m) + waitingJoins = state.waitingJoins + state.waitingJoins = [] + for join in waitingJoins: + irc.sendMsg(join) elif 'not yet authenticated' in s: # zirc.org has this, it requires an auth code. email = s.split()[-1] @@ -401,7 +413,8 @@ class Services(callbacks.Plugin): channel, on) def do366(self, irc, msg): # End of /NAMES list; finished joining a channel - if self.identified: + state = self._getState(irc) + if state.identified: channel = msg.args[1] # nick is msg.args[0]. self.checkPrivileges(irc, channel) diff --git a/plugins/Services/test.py b/plugins/Services/test.py index fe7d53314..66250f554 100644 --- a/plugins/Services/test.py +++ b/plugins/Services/test.py @@ -30,6 +30,7 @@ from supybot.test import * import supybot.conf as conf from supybot.ircmsgs import IrcMsg +from copy import copy class ServicesTestCase(PluginTestCase): plugins = ('Services', 'Config') @@ -105,6 +106,72 @@ class ServicesTestCase(PluginTestCase): self.assertIsNone(self.irc.takeMsg()) +class JoinsBeforeIdentifiedTestCase(PluginTestCase): + plugins = ('Services',) + config = { + 'plugins.Services.noJoinsUntilIdentified': False, + } + + def testSingleNetwork(self): + queuedJoin = ircmsgs.join('#test', prefix=self.prefix) + self.irc.queueMsg(queuedJoin) + self.assertEqual(self.irc.takeMsg(), queuedJoin, + 'Join request did not go through.') + + +class NoJoinsUntilIdentifiedTestCase(PluginTestCase): + plugins = ('Services',) + config = { + 'plugins.Services.noJoinsUntilIdentified': True, + } + + def _identify(self, irc): + irc.feedMsg(IrcMsg(command='376', args=(self.nick,))) + msg = irc.takeMsg() + self.assertEqual(msg.command, 'PRIVMSG') + self.assertEqual(msg.args[0], 'NickServ') + irc.feedMsg(ircmsgs.notice(self.nick, 'now identified', 'NickServ')) + + def testSingleNetwork(self): + try: + self.assertNotError('services password %s secret' % self.nick) + queuedJoin = ircmsgs.join('#test', prefix=self.prefix) + self.irc.queueMsg(queuedJoin) + self.assertIsNone(self.irc.takeMsg(), + 'Join request went through before identification.') + self._identify(self.irc) + self.assertEqual(self.irc.takeMsg(), queuedJoin, + 'Join request did not go through after identification.') + finally: + self.assertNotError('services password %s ""' % self.nick) + + def testMultipleNetworks(self): + try: + net1 = copy(self) + net1.irc = getTestIrc('testnet1') + net1.assertNotError('services password %s secret' % self.nick) + net2 = copy(self) + net2.irc = getTestIrc('testnet2') + net2.assertNotError('services password %s secret' % self.nick) + queuedJoin1 = ircmsgs.join('#testchan1', prefix=self.prefix) + net1.irc.queueMsg(queuedJoin1) + self.assertIsNone(net1.irc.takeMsg(), + 'Join request 1 went through before identification.') + self._identify(net1.irc) + self.assertEqual(net1.irc.takeMsg(), queuedJoin1, + 'Join request 1 did not go through after identification.') + queuedJoin2 = ircmsgs.join('#testchan2', prefix=self.prefix) + net2.irc.queueMsg(queuedJoin2) + self.assertIsNone(net2.irc.takeMsg(), + 'Join request 2 went through before identification.') + self._identify(net2.irc) + self.assertEqual(net2.irc.takeMsg(), queuedJoin2, + 'Join request 2 did not go through after identification.') + finally: + net1.assertNotError('services password %s ""' % self.nick) + net2.assertNotError('services password %s ""' % self.nick) + + class ExperimentalServicesTestCase(PluginTestCase): plugins = ["Services"] timeout = 0.1