Services: Keep per-network state separate

Until now, only `waitingJoins` was stored separately per network, while
`channels`, `sentGhost` and `identified` had one common value per plugin
instance.  Instead of making everything a dictionary indexed by network
name like `waitingJoins`, let's bundle all the state together in a class
and store *its* instances in such a dictionary.

This fixes at least one race condition, for which a test case was added.
Even with `noJoinsUntilIdentified` set, the bot would let joins through
as long as *any* one network has already finished identifying.
This commit is contained in:
David Macek 2021-04-24 20:33:53 +02:00 committed by Valentin Lorentz
parent 177c20267c
commit 3d21c7cbcb
2 changed files with 105 additions and 25 deletions

View File

@ -42,6 +42,13 @@ import supybot.callbacks as callbacks
from supybot.i18n import PluginInternationalization, internationalizeDocstring
_ = PluginInternationalization('Services')
class State:
def __init__(self):
self.channels = []
self.sentGhost = None
self.identified = False
self.waitingJoins = []
class Services(callbacks.Plugin):
"""This plugin handles dealing with Services on networks that provide them.
Basically, you should use the "password" command to tell the bot a nick to
@ -66,10 +73,10 @@ class Services(callbacks.Plugin):
self.reset()
def reset(self):
self.channels = []
self.sentGhost = None
self.identified = False
self.waitingJoins = {}
self.state = {}
def _getState(self, irc):
return self.state.setdefault(irc.network, State())
def disabled(self, irc):
disabled = self.registryValue('disabledNetworks')
@ -79,13 +86,13 @@ class Services(callbacks.Plugin):
return False
def outFilter(self, irc, msg):
state = self._getState(irc)
if msg.command == 'JOIN' and not self.disabled(irc):
if not self.identified:
if not state.identified:
if self.registryValue('noJoinsUntilIdentified', network=irc.network):
self.log.info('Holding JOIN to %s @ %s until identified.',
msg.channel, irc.network)
self.waitingJoins.setdefault(irc.network, [])
self.waitingJoins[irc.network].append(msg)
state.waitingJoins.append(msg)
return None
return msg
@ -131,6 +138,7 @@ class Services(callbacks.Plugin):
def _doGhost(self, irc, nick=None):
if self.disabled(irc):
return
state = self._getState(irc)
if nick is None:
nick = self._getNick(irc.network)
if nick not in self.registryValue('nicks', network=irc.network):
@ -144,7 +152,7 @@ class Services(callbacks.Plugin):
s = 'Tried to ghost without a NickServ or password set.'
self.log.warning(s)
return
if self.sentGhost and time.time() < (self.sentGhost + ghostDelay):
if state.sentGhost and time.time() < (state.sentGhost + ghostDelay):
self.log.warning('Refusing to send GHOST more than once every '
'%s seconds.' % ghostDelay)
elif not password:
@ -156,12 +164,13 @@ class Services(callbacks.Plugin):
ghost = 'GHOST %s %s' % (nick, password)
# Ditto about the sendMsg (see _doIdentify).
irc.sendMsg(ircmsgs.privmsg(nickserv, ghost))
self.sentGhost = time.time()
state.sentGhost = time.time()
def __call__(self, irc, msg):
self.__parent.__call__(irc, msg)
if self.disabled(irc):
return
state = self._getState(irc)
nick = self._getNick(irc.network)
if nick not in self.registryValue('nicks', network=irc.network):
return
@ -172,16 +181,17 @@ class Services(callbacks.Plugin):
return
if nick and nickserv and password and \
not ircutils.strEqual(nick, irc.nick):
if irc.afterConnect and (self.sentGhost is None or
(self.sentGhost + ghostDelay) < time.time()):
if irc.afterConnect and (state.sentGhost is None or
(state.sentGhost + ghostDelay) < time.time()):
if nick in irc.state.nicksToHostmasks:
self._doGhost(irc)
else:
irc.sendMsg(ircmsgs.nick(nick)) # 433 is handled elsewhere.
def do001(self, irc, msg):
# New connection, make sure sentGhost is False.
self.sentGhost = None
# New connection, make sure sentGhost is None.
state = self._getState(irc)
state.sentGhost = None
def do376(self, irc, msg):
if self.disabled(irc):
@ -221,7 +231,8 @@ class Services(callbacks.Plugin):
def do515(self, irc, msg):
# Can't join this channel, it's +r (we must be identified).
self.channels.append(msg.args[1])
state = self._getState(irc)
state.channels.append(msg.args[1])
def doNick(self, irc, msg):
nick = self._getNick(irc.network)
@ -295,6 +306,7 @@ class Services(callbacks.Plugin):
def doNickservNotice(self, irc, msg):
if self.disabled(irc):
return
state = self._getState(irc)
nick = self._getNick(irc.network)
s = ircutils.stripFormatting(msg.args[1].lower())
on = 'on %s' % irc.network
@ -303,19 +315,19 @@ class Services(callbacks.Plugin):
log = 'Received "Password Incorrect" from NickServ %s. ' \
'Resetting password to empty.' % on
self.log.warning(log)
self.sentGhost = time.time()
state.sentGhost = time.time()
self._setNickServPassword(nick, '', irc.network)
elif self._ghosted(irc, s):
self.log.info('Received "GHOST succeeded" from NickServ %s.', on)
self.sentGhost = None
self.identified = False
state.sentGhost = None
state.identified = False
irc.queueMsg(ircmsgs.nick(nick))
elif 'is not registered' in s:
self.log.info('Received "Nick not registered" from NickServ %s.',
on)
elif 'currently' in s and 'isn\'t' in s or 'is not' in s:
# The nick isn't online, let's change our nick to it.
self.sentGhost = None
state.sentGhost = None
irc.queueMsg(ircmsgs.nick(nick))
elif ('owned by someone else' in s) or \
('nickname is registered and protected' in s) or \
@ -339,15 +351,15 @@ class Services(callbacks.Plugin):
# freenode, oftc, arstechnica, zirc, ....
# sorcery
self.log.info('Received "Password accepted" from NickServ %s.', on)
self.identified = True
state.identified = True
for channel in irc.state.channels.keys():
self.checkPrivileges(irc, channel)
for channel in self.channels:
for channel in state.channels:
irc.queueMsg(networkGroup.channels.join(channel))
waitingJoins = self.waitingJoins.pop(irc.network, None)
if waitingJoins:
for m in waitingJoins:
irc.sendMsg(m)
waitingJoins = state.waitingJoins
state.waitingJoins = []
for join in waitingJoins:
irc.sendMsg(join)
elif 'not yet authenticated' in s:
# zirc.org has this, it requires an auth code.
email = s.split()[-1]
@ -401,7 +413,8 @@ class Services(callbacks.Plugin):
channel, on)
def do366(self, irc, msg): # End of /NAMES list; finished joining a channel
if self.identified:
state = self._getState(irc)
if state.identified:
channel = msg.args[1] # nick is msg.args[0].
self.checkPrivileges(irc, channel)

View File

@ -30,6 +30,7 @@
from supybot.test import *
import supybot.conf as conf
from supybot.ircmsgs import IrcMsg
from copy import copy
class ServicesTestCase(PluginTestCase):
plugins = ('Services', 'Config')
@ -105,6 +106,72 @@ class ServicesTestCase(PluginTestCase):
self.assertIsNone(self.irc.takeMsg())
class JoinsBeforeIdentifiedTestCase(PluginTestCase):
plugins = ('Services',)
config = {
'plugins.Services.noJoinsUntilIdentified': False,
}
def testSingleNetwork(self):
queuedJoin = ircmsgs.join('#test', prefix=self.prefix)
self.irc.queueMsg(queuedJoin)
self.assertEqual(self.irc.takeMsg(), queuedJoin,
'Join request did not go through.')
class NoJoinsUntilIdentifiedTestCase(PluginTestCase):
plugins = ('Services',)
config = {
'plugins.Services.noJoinsUntilIdentified': True,
}
def _identify(self, irc):
irc.feedMsg(IrcMsg(command='376', args=(self.nick,)))
msg = irc.takeMsg()
self.assertEqual(msg.command, 'PRIVMSG')
self.assertEqual(msg.args[0], 'NickServ')
irc.feedMsg(ircmsgs.notice(self.nick, 'now identified', 'NickServ'))
def testSingleNetwork(self):
try:
self.assertNotError('services password %s secret' % self.nick)
queuedJoin = ircmsgs.join('#test', prefix=self.prefix)
self.irc.queueMsg(queuedJoin)
self.assertIsNone(self.irc.takeMsg(),
'Join request went through before identification.')
self._identify(self.irc)
self.assertEqual(self.irc.takeMsg(), queuedJoin,
'Join request did not go through after identification.')
finally:
self.assertNotError('services password %s ""' % self.nick)
def testMultipleNetworks(self):
try:
net1 = copy(self)
net1.irc = getTestIrc('testnet1')
net1.assertNotError('services password %s secret' % self.nick)
net2 = copy(self)
net2.irc = getTestIrc('testnet2')
net2.assertNotError('services password %s secret' % self.nick)
queuedJoin1 = ircmsgs.join('#testchan1', prefix=self.prefix)
net1.irc.queueMsg(queuedJoin1)
self.assertIsNone(net1.irc.takeMsg(),
'Join request 1 went through before identification.')
self._identify(net1.irc)
self.assertEqual(net1.irc.takeMsg(), queuedJoin1,
'Join request 1 did not go through after identification.')
queuedJoin2 = ircmsgs.join('#testchan2', prefix=self.prefix)
net2.irc.queueMsg(queuedJoin2)
self.assertIsNone(net2.irc.takeMsg(),
'Join request 2 went through before identification.')
self._identify(net2.irc)
self.assertEqual(net2.irc.takeMsg(), queuedJoin2,
'Join request 2 did not go through after identification.')
finally:
net1.assertNotError('services password %s ""' % self.nick)
net2.assertNotError('services password %s ""' % self.nick)
class ExperimentalServicesTestCase(PluginTestCase):
plugins = ["Services"]
timeout = 0.1