Services: Keep per-network state separate

Until now, only `waitingJoins` was stored separately per network, while
`channels`, `sentGhost` and `identified` had one common value per plugin
instance.  Instead of making everything a dictionary indexed by network
name like `waitingJoins`, let's bundle all the state together in a class
and store *its* instances in such a dictionary.

This fixes at least one race condition, for which a test case was added.
Even with `noJoinsUntilIdentified` set, the bot would let joins through
as long as *any* one network has already finished identifying.
This commit is contained in:
David Macek 2021-04-24 20:33:53 +02:00 committed by Valentin Lorentz
parent 177c20267c
commit 3d21c7cbcb
2 changed files with 105 additions and 25 deletions

View File

@ -42,6 +42,13 @@ import supybot.callbacks as callbacks
from supybot.i18n import PluginInternationalization, internationalizeDocstring from supybot.i18n import PluginInternationalization, internationalizeDocstring
_ = PluginInternationalization('Services') _ = PluginInternationalization('Services')
class State:
def __init__(self):
self.channels = []
self.sentGhost = None
self.identified = False
self.waitingJoins = []
class Services(callbacks.Plugin): class Services(callbacks.Plugin):
"""This plugin handles dealing with Services on networks that provide them. """This plugin handles dealing with Services on networks that provide them.
Basically, you should use the "password" command to tell the bot a nick to Basically, you should use the "password" command to tell the bot a nick to
@ -66,10 +73,10 @@ class Services(callbacks.Plugin):
self.reset() self.reset()
def reset(self): def reset(self):
self.channels = [] self.state = {}
self.sentGhost = None
self.identified = False def _getState(self, irc):
self.waitingJoins = {} return self.state.setdefault(irc.network, State())
def disabled(self, irc): def disabled(self, irc):
disabled = self.registryValue('disabledNetworks') disabled = self.registryValue('disabledNetworks')
@ -79,13 +86,13 @@ class Services(callbacks.Plugin):
return False return False
def outFilter(self, irc, msg): def outFilter(self, irc, msg):
state = self._getState(irc)
if msg.command == 'JOIN' and not self.disabled(irc): if msg.command == 'JOIN' and not self.disabled(irc):
if not self.identified: if not state.identified:
if self.registryValue('noJoinsUntilIdentified', network=irc.network): if self.registryValue('noJoinsUntilIdentified', network=irc.network):
self.log.info('Holding JOIN to %s @ %s until identified.', self.log.info('Holding JOIN to %s @ %s until identified.',
msg.channel, irc.network) msg.channel, irc.network)
self.waitingJoins.setdefault(irc.network, []) state.waitingJoins.append(msg)
self.waitingJoins[irc.network].append(msg)
return None return None
return msg return msg
@ -131,6 +138,7 @@ class Services(callbacks.Plugin):
def _doGhost(self, irc, nick=None): def _doGhost(self, irc, nick=None):
if self.disabled(irc): if self.disabled(irc):
return return
state = self._getState(irc)
if nick is None: if nick is None:
nick = self._getNick(irc.network) nick = self._getNick(irc.network)
if nick not in self.registryValue('nicks', network=irc.network): if nick not in self.registryValue('nicks', network=irc.network):
@ -144,7 +152,7 @@ class Services(callbacks.Plugin):
s = 'Tried to ghost without a NickServ or password set.' s = 'Tried to ghost without a NickServ or password set.'
self.log.warning(s) self.log.warning(s)
return return
if self.sentGhost and time.time() < (self.sentGhost + ghostDelay): if state.sentGhost and time.time() < (state.sentGhost + ghostDelay):
self.log.warning('Refusing to send GHOST more than once every ' self.log.warning('Refusing to send GHOST more than once every '
'%s seconds.' % ghostDelay) '%s seconds.' % ghostDelay)
elif not password: elif not password:
@ -156,12 +164,13 @@ class Services(callbacks.Plugin):
ghost = 'GHOST %s %s' % (nick, password) ghost = 'GHOST %s %s' % (nick, password)
# Ditto about the sendMsg (see _doIdentify). # Ditto about the sendMsg (see _doIdentify).
irc.sendMsg(ircmsgs.privmsg(nickserv, ghost)) irc.sendMsg(ircmsgs.privmsg(nickserv, ghost))
self.sentGhost = time.time() state.sentGhost = time.time()
def __call__(self, irc, msg): def __call__(self, irc, msg):
self.__parent.__call__(irc, msg) self.__parent.__call__(irc, msg)
if self.disabled(irc): if self.disabled(irc):
return return
state = self._getState(irc)
nick = self._getNick(irc.network) nick = self._getNick(irc.network)
if nick not in self.registryValue('nicks', network=irc.network): if nick not in self.registryValue('nicks', network=irc.network):
return return
@ -172,16 +181,17 @@ class Services(callbacks.Plugin):
return return
if nick and nickserv and password and \ if nick and nickserv and password and \
not ircutils.strEqual(nick, irc.nick): not ircutils.strEqual(nick, irc.nick):
if irc.afterConnect and (self.sentGhost is None or if irc.afterConnect and (state.sentGhost is None or
(self.sentGhost + ghostDelay) < time.time()): (state.sentGhost + ghostDelay) < time.time()):
if nick in irc.state.nicksToHostmasks: if nick in irc.state.nicksToHostmasks:
self._doGhost(irc) self._doGhost(irc)
else: else:
irc.sendMsg(ircmsgs.nick(nick)) # 433 is handled elsewhere. irc.sendMsg(ircmsgs.nick(nick)) # 433 is handled elsewhere.
def do001(self, irc, msg): def do001(self, irc, msg):
# New connection, make sure sentGhost is False. # New connection, make sure sentGhost is None.
self.sentGhost = None state = self._getState(irc)
state.sentGhost = None
def do376(self, irc, msg): def do376(self, irc, msg):
if self.disabled(irc): if self.disabled(irc):
@ -221,7 +231,8 @@ class Services(callbacks.Plugin):
def do515(self, irc, msg): def do515(self, irc, msg):
# Can't join this channel, it's +r (we must be identified). # Can't join this channel, it's +r (we must be identified).
self.channels.append(msg.args[1]) state = self._getState(irc)
state.channels.append(msg.args[1])
def doNick(self, irc, msg): def doNick(self, irc, msg):
nick = self._getNick(irc.network) nick = self._getNick(irc.network)
@ -295,6 +306,7 @@ class Services(callbacks.Plugin):
def doNickservNotice(self, irc, msg): def doNickservNotice(self, irc, msg):
if self.disabled(irc): if self.disabled(irc):
return return
state = self._getState(irc)
nick = self._getNick(irc.network) nick = self._getNick(irc.network)
s = ircutils.stripFormatting(msg.args[1].lower()) s = ircutils.stripFormatting(msg.args[1].lower())
on = 'on %s' % irc.network on = 'on %s' % irc.network
@ -303,19 +315,19 @@ class Services(callbacks.Plugin):
log = 'Received "Password Incorrect" from NickServ %s. ' \ log = 'Received "Password Incorrect" from NickServ %s. ' \
'Resetting password to empty.' % on 'Resetting password to empty.' % on
self.log.warning(log) self.log.warning(log)
self.sentGhost = time.time() state.sentGhost = time.time()
self._setNickServPassword(nick, '', irc.network) self._setNickServPassword(nick, '', irc.network)
elif self._ghosted(irc, s): elif self._ghosted(irc, s):
self.log.info('Received "GHOST succeeded" from NickServ %s.', on) self.log.info('Received "GHOST succeeded" from NickServ %s.', on)
self.sentGhost = None state.sentGhost = None
self.identified = False state.identified = False
irc.queueMsg(ircmsgs.nick(nick)) irc.queueMsg(ircmsgs.nick(nick))
elif 'is not registered' in s: elif 'is not registered' in s:
self.log.info('Received "Nick not registered" from NickServ %s.', self.log.info('Received "Nick not registered" from NickServ %s.',
on) on)
elif 'currently' in s and 'isn\'t' in s or 'is not' in s: elif 'currently' in s and 'isn\'t' in s or 'is not' in s:
# The nick isn't online, let's change our nick to it. # The nick isn't online, let's change our nick to it.
self.sentGhost = None state.sentGhost = None
irc.queueMsg(ircmsgs.nick(nick)) irc.queueMsg(ircmsgs.nick(nick))
elif ('owned by someone else' in s) or \ elif ('owned by someone else' in s) or \
('nickname is registered and protected' in s) or \ ('nickname is registered and protected' in s) or \
@ -339,15 +351,15 @@ class Services(callbacks.Plugin):
# freenode, oftc, arstechnica, zirc, .... # freenode, oftc, arstechnica, zirc, ....
# sorcery # sorcery
self.log.info('Received "Password accepted" from NickServ %s.', on) self.log.info('Received "Password accepted" from NickServ %s.', on)
self.identified = True state.identified = True
for channel in irc.state.channels.keys(): for channel in irc.state.channels.keys():
self.checkPrivileges(irc, channel) self.checkPrivileges(irc, channel)
for channel in self.channels: for channel in state.channels:
irc.queueMsg(networkGroup.channels.join(channel)) irc.queueMsg(networkGroup.channels.join(channel))
waitingJoins = self.waitingJoins.pop(irc.network, None) waitingJoins = state.waitingJoins
if waitingJoins: state.waitingJoins = []
for m in waitingJoins: for join in waitingJoins:
irc.sendMsg(m) irc.sendMsg(join)
elif 'not yet authenticated' in s: elif 'not yet authenticated' in s:
# zirc.org has this, it requires an auth code. # zirc.org has this, it requires an auth code.
email = s.split()[-1] email = s.split()[-1]
@ -401,7 +413,8 @@ class Services(callbacks.Plugin):
channel, on) channel, on)
def do366(self, irc, msg): # End of /NAMES list; finished joining a channel def do366(self, irc, msg): # End of /NAMES list; finished joining a channel
if self.identified: state = self._getState(irc)
if state.identified:
channel = msg.args[1] # nick is msg.args[0]. channel = msg.args[1] # nick is msg.args[0].
self.checkPrivileges(irc, channel) self.checkPrivileges(irc, channel)

View File

@ -30,6 +30,7 @@
from supybot.test import * from supybot.test import *
import supybot.conf as conf import supybot.conf as conf
from supybot.ircmsgs import IrcMsg from supybot.ircmsgs import IrcMsg
from copy import copy
class ServicesTestCase(PluginTestCase): class ServicesTestCase(PluginTestCase):
plugins = ('Services', 'Config') plugins = ('Services', 'Config')
@ -105,6 +106,72 @@ class ServicesTestCase(PluginTestCase):
self.assertIsNone(self.irc.takeMsg()) self.assertIsNone(self.irc.takeMsg())
class JoinsBeforeIdentifiedTestCase(PluginTestCase):
plugins = ('Services',)
config = {
'plugins.Services.noJoinsUntilIdentified': False,
}
def testSingleNetwork(self):
queuedJoin = ircmsgs.join('#test', prefix=self.prefix)
self.irc.queueMsg(queuedJoin)
self.assertEqual(self.irc.takeMsg(), queuedJoin,
'Join request did not go through.')
class NoJoinsUntilIdentifiedTestCase(PluginTestCase):
plugins = ('Services',)
config = {
'plugins.Services.noJoinsUntilIdentified': True,
}
def _identify(self, irc):
irc.feedMsg(IrcMsg(command='376', args=(self.nick,)))
msg = irc.takeMsg()
self.assertEqual(msg.command, 'PRIVMSG')
self.assertEqual(msg.args[0], 'NickServ')
irc.feedMsg(ircmsgs.notice(self.nick, 'now identified', 'NickServ'))
def testSingleNetwork(self):
try:
self.assertNotError('services password %s secret' % self.nick)
queuedJoin = ircmsgs.join('#test', prefix=self.prefix)
self.irc.queueMsg(queuedJoin)
self.assertIsNone(self.irc.takeMsg(),
'Join request went through before identification.')
self._identify(self.irc)
self.assertEqual(self.irc.takeMsg(), queuedJoin,
'Join request did not go through after identification.')
finally:
self.assertNotError('services password %s ""' % self.nick)
def testMultipleNetworks(self):
try:
net1 = copy(self)
net1.irc = getTestIrc('testnet1')
net1.assertNotError('services password %s secret' % self.nick)
net2 = copy(self)
net2.irc = getTestIrc('testnet2')
net2.assertNotError('services password %s secret' % self.nick)
queuedJoin1 = ircmsgs.join('#testchan1', prefix=self.prefix)
net1.irc.queueMsg(queuedJoin1)
self.assertIsNone(net1.irc.takeMsg(),
'Join request 1 went through before identification.')
self._identify(net1.irc)
self.assertEqual(net1.irc.takeMsg(), queuedJoin1,
'Join request 1 did not go through after identification.')
queuedJoin2 = ircmsgs.join('#testchan2', prefix=self.prefix)
net2.irc.queueMsg(queuedJoin2)
self.assertIsNone(net2.irc.takeMsg(),
'Join request 2 went through before identification.')
self._identify(net2.irc)
self.assertEqual(net2.irc.takeMsg(), queuedJoin2,
'Join request 2 did not go through after identification.')
finally:
net1.assertNotError('services password %s ""' % self.nick)
net2.assertNotError('services password %s ""' % self.nick)
class ExperimentalServicesTestCase(PluginTestCase): class ExperimentalServicesTestCase(PluginTestCase):
plugins = ["Services"] plugins = ["Services"]
timeout = 0.1 timeout = 0.1