mirror of
https://github.com/Mikaela/Limnoria.git
synced 2024-11-05 18:49:23 +01:00
Services: Keep per-network state separate
Until now, only `waitingJoins` was stored separately per network, while `channels`, `sentGhost` and `identified` had one common value per plugin instance. Instead of making everything a dictionary indexed by network name like `waitingJoins`, let's bundle all the state together in a class and store *its* instances in such a dictionary. This fixes at least one race condition, for which a test case was added. Even with `noJoinsUntilIdentified` set, the bot would let joins through as long as *any* one network has already finished identifying.
This commit is contained in:
parent
177c20267c
commit
3d21c7cbcb
@ -42,6 +42,13 @@ import supybot.callbacks as callbacks
|
||||
from supybot.i18n import PluginInternationalization, internationalizeDocstring
|
||||
_ = PluginInternationalization('Services')
|
||||
|
||||
class State:
|
||||
def __init__(self):
|
||||
self.channels = []
|
||||
self.sentGhost = None
|
||||
self.identified = False
|
||||
self.waitingJoins = []
|
||||
|
||||
class Services(callbacks.Plugin):
|
||||
"""This plugin handles dealing with Services on networks that provide them.
|
||||
Basically, you should use the "password" command to tell the bot a nick to
|
||||
@ -66,10 +73,10 @@ class Services(callbacks.Plugin):
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
self.channels = []
|
||||
self.sentGhost = None
|
||||
self.identified = False
|
||||
self.waitingJoins = {}
|
||||
self.state = {}
|
||||
|
||||
def _getState(self, irc):
|
||||
return self.state.setdefault(irc.network, State())
|
||||
|
||||
def disabled(self, irc):
|
||||
disabled = self.registryValue('disabledNetworks')
|
||||
@ -79,13 +86,13 @@ class Services(callbacks.Plugin):
|
||||
return False
|
||||
|
||||
def outFilter(self, irc, msg):
|
||||
state = self._getState(irc)
|
||||
if msg.command == 'JOIN' and not self.disabled(irc):
|
||||
if not self.identified:
|
||||
if not state.identified:
|
||||
if self.registryValue('noJoinsUntilIdentified', network=irc.network):
|
||||
self.log.info('Holding JOIN to %s @ %s until identified.',
|
||||
msg.channel, irc.network)
|
||||
self.waitingJoins.setdefault(irc.network, [])
|
||||
self.waitingJoins[irc.network].append(msg)
|
||||
state.waitingJoins.append(msg)
|
||||
return None
|
||||
return msg
|
||||
|
||||
@ -131,6 +138,7 @@ class Services(callbacks.Plugin):
|
||||
def _doGhost(self, irc, nick=None):
|
||||
if self.disabled(irc):
|
||||
return
|
||||
state = self._getState(irc)
|
||||
if nick is None:
|
||||
nick = self._getNick(irc.network)
|
||||
if nick not in self.registryValue('nicks', network=irc.network):
|
||||
@ -144,7 +152,7 @@ class Services(callbacks.Plugin):
|
||||
s = 'Tried to ghost without a NickServ or password set.'
|
||||
self.log.warning(s)
|
||||
return
|
||||
if self.sentGhost and time.time() < (self.sentGhost + ghostDelay):
|
||||
if state.sentGhost and time.time() < (state.sentGhost + ghostDelay):
|
||||
self.log.warning('Refusing to send GHOST more than once every '
|
||||
'%s seconds.' % ghostDelay)
|
||||
elif not password:
|
||||
@ -156,12 +164,13 @@ class Services(callbacks.Plugin):
|
||||
ghost = 'GHOST %s %s' % (nick, password)
|
||||
# Ditto about the sendMsg (see _doIdentify).
|
||||
irc.sendMsg(ircmsgs.privmsg(nickserv, ghost))
|
||||
self.sentGhost = time.time()
|
||||
state.sentGhost = time.time()
|
||||
|
||||
def __call__(self, irc, msg):
|
||||
self.__parent.__call__(irc, msg)
|
||||
if self.disabled(irc):
|
||||
return
|
||||
state = self._getState(irc)
|
||||
nick = self._getNick(irc.network)
|
||||
if nick not in self.registryValue('nicks', network=irc.network):
|
||||
return
|
||||
@ -172,16 +181,17 @@ class Services(callbacks.Plugin):
|
||||
return
|
||||
if nick and nickserv and password and \
|
||||
not ircutils.strEqual(nick, irc.nick):
|
||||
if irc.afterConnect and (self.sentGhost is None or
|
||||
(self.sentGhost + ghostDelay) < time.time()):
|
||||
if irc.afterConnect and (state.sentGhost is None or
|
||||
(state.sentGhost + ghostDelay) < time.time()):
|
||||
if nick in irc.state.nicksToHostmasks:
|
||||
self._doGhost(irc)
|
||||
else:
|
||||
irc.sendMsg(ircmsgs.nick(nick)) # 433 is handled elsewhere.
|
||||
|
||||
def do001(self, irc, msg):
|
||||
# New connection, make sure sentGhost is False.
|
||||
self.sentGhost = None
|
||||
# New connection, make sure sentGhost is None.
|
||||
state = self._getState(irc)
|
||||
state.sentGhost = None
|
||||
|
||||
def do376(self, irc, msg):
|
||||
if self.disabled(irc):
|
||||
@ -221,7 +231,8 @@ class Services(callbacks.Plugin):
|
||||
|
||||
def do515(self, irc, msg):
|
||||
# Can't join this channel, it's +r (we must be identified).
|
||||
self.channels.append(msg.args[1])
|
||||
state = self._getState(irc)
|
||||
state.channels.append(msg.args[1])
|
||||
|
||||
def doNick(self, irc, msg):
|
||||
nick = self._getNick(irc.network)
|
||||
@ -295,6 +306,7 @@ class Services(callbacks.Plugin):
|
||||
def doNickservNotice(self, irc, msg):
|
||||
if self.disabled(irc):
|
||||
return
|
||||
state = self._getState(irc)
|
||||
nick = self._getNick(irc.network)
|
||||
s = ircutils.stripFormatting(msg.args[1].lower())
|
||||
on = 'on %s' % irc.network
|
||||
@ -303,19 +315,19 @@ class Services(callbacks.Plugin):
|
||||
log = 'Received "Password Incorrect" from NickServ %s. ' \
|
||||
'Resetting password to empty.' % on
|
||||
self.log.warning(log)
|
||||
self.sentGhost = time.time()
|
||||
state.sentGhost = time.time()
|
||||
self._setNickServPassword(nick, '', irc.network)
|
||||
elif self._ghosted(irc, s):
|
||||
self.log.info('Received "GHOST succeeded" from NickServ %s.', on)
|
||||
self.sentGhost = None
|
||||
self.identified = False
|
||||
state.sentGhost = None
|
||||
state.identified = False
|
||||
irc.queueMsg(ircmsgs.nick(nick))
|
||||
elif 'is not registered' in s:
|
||||
self.log.info('Received "Nick not registered" from NickServ %s.',
|
||||
on)
|
||||
elif 'currently' in s and 'isn\'t' in s or 'is not' in s:
|
||||
# The nick isn't online, let's change our nick to it.
|
||||
self.sentGhost = None
|
||||
state.sentGhost = None
|
||||
irc.queueMsg(ircmsgs.nick(nick))
|
||||
elif ('owned by someone else' in s) or \
|
||||
('nickname is registered and protected' in s) or \
|
||||
@ -339,15 +351,15 @@ class Services(callbacks.Plugin):
|
||||
# freenode, oftc, arstechnica, zirc, ....
|
||||
# sorcery
|
||||
self.log.info('Received "Password accepted" from NickServ %s.', on)
|
||||
self.identified = True
|
||||
state.identified = True
|
||||
for channel in irc.state.channels.keys():
|
||||
self.checkPrivileges(irc, channel)
|
||||
for channel in self.channels:
|
||||
for channel in state.channels:
|
||||
irc.queueMsg(networkGroup.channels.join(channel))
|
||||
waitingJoins = self.waitingJoins.pop(irc.network, None)
|
||||
if waitingJoins:
|
||||
for m in waitingJoins:
|
||||
irc.sendMsg(m)
|
||||
waitingJoins = state.waitingJoins
|
||||
state.waitingJoins = []
|
||||
for join in waitingJoins:
|
||||
irc.sendMsg(join)
|
||||
elif 'not yet authenticated' in s:
|
||||
# zirc.org has this, it requires an auth code.
|
||||
email = s.split()[-1]
|
||||
@ -401,7 +413,8 @@ class Services(callbacks.Plugin):
|
||||
channel, on)
|
||||
|
||||
def do366(self, irc, msg): # End of /NAMES list; finished joining a channel
|
||||
if self.identified:
|
||||
state = self._getState(irc)
|
||||
if state.identified:
|
||||
channel = msg.args[1] # nick is msg.args[0].
|
||||
self.checkPrivileges(irc, channel)
|
||||
|
||||
|
@ -30,6 +30,7 @@
|
||||
from supybot.test import *
|
||||
import supybot.conf as conf
|
||||
from supybot.ircmsgs import IrcMsg
|
||||
from copy import copy
|
||||
|
||||
class ServicesTestCase(PluginTestCase):
|
||||
plugins = ('Services', 'Config')
|
||||
@ -105,6 +106,72 @@ class ServicesTestCase(PluginTestCase):
|
||||
self.assertIsNone(self.irc.takeMsg())
|
||||
|
||||
|
||||
class JoinsBeforeIdentifiedTestCase(PluginTestCase):
|
||||
plugins = ('Services',)
|
||||
config = {
|
||||
'plugins.Services.noJoinsUntilIdentified': False,
|
||||
}
|
||||
|
||||
def testSingleNetwork(self):
|
||||
queuedJoin = ircmsgs.join('#test', prefix=self.prefix)
|
||||
self.irc.queueMsg(queuedJoin)
|
||||
self.assertEqual(self.irc.takeMsg(), queuedJoin,
|
||||
'Join request did not go through.')
|
||||
|
||||
|
||||
class NoJoinsUntilIdentifiedTestCase(PluginTestCase):
|
||||
plugins = ('Services',)
|
||||
config = {
|
||||
'plugins.Services.noJoinsUntilIdentified': True,
|
||||
}
|
||||
|
||||
def _identify(self, irc):
|
||||
irc.feedMsg(IrcMsg(command='376', args=(self.nick,)))
|
||||
msg = irc.takeMsg()
|
||||
self.assertEqual(msg.command, 'PRIVMSG')
|
||||
self.assertEqual(msg.args[0], 'NickServ')
|
||||
irc.feedMsg(ircmsgs.notice(self.nick, 'now identified', 'NickServ'))
|
||||
|
||||
def testSingleNetwork(self):
|
||||
try:
|
||||
self.assertNotError('services password %s secret' % self.nick)
|
||||
queuedJoin = ircmsgs.join('#test', prefix=self.prefix)
|
||||
self.irc.queueMsg(queuedJoin)
|
||||
self.assertIsNone(self.irc.takeMsg(),
|
||||
'Join request went through before identification.')
|
||||
self._identify(self.irc)
|
||||
self.assertEqual(self.irc.takeMsg(), queuedJoin,
|
||||
'Join request did not go through after identification.')
|
||||
finally:
|
||||
self.assertNotError('services password %s ""' % self.nick)
|
||||
|
||||
def testMultipleNetworks(self):
|
||||
try:
|
||||
net1 = copy(self)
|
||||
net1.irc = getTestIrc('testnet1')
|
||||
net1.assertNotError('services password %s secret' % self.nick)
|
||||
net2 = copy(self)
|
||||
net2.irc = getTestIrc('testnet2')
|
||||
net2.assertNotError('services password %s secret' % self.nick)
|
||||
queuedJoin1 = ircmsgs.join('#testchan1', prefix=self.prefix)
|
||||
net1.irc.queueMsg(queuedJoin1)
|
||||
self.assertIsNone(net1.irc.takeMsg(),
|
||||
'Join request 1 went through before identification.')
|
||||
self._identify(net1.irc)
|
||||
self.assertEqual(net1.irc.takeMsg(), queuedJoin1,
|
||||
'Join request 1 did not go through after identification.')
|
||||
queuedJoin2 = ircmsgs.join('#testchan2', prefix=self.prefix)
|
||||
net2.irc.queueMsg(queuedJoin2)
|
||||
self.assertIsNone(net2.irc.takeMsg(),
|
||||
'Join request 2 went through before identification.')
|
||||
self._identify(net2.irc)
|
||||
self.assertEqual(net2.irc.takeMsg(), queuedJoin2,
|
||||
'Join request 2 did not go through after identification.')
|
||||
finally:
|
||||
net1.assertNotError('services password %s ""' % self.nick)
|
||||
net2.assertNotError('services password %s ""' % self.nick)
|
||||
|
||||
|
||||
class ExperimentalServicesTestCase(PluginTestCase):
|
||||
plugins = ["Services"]
|
||||
timeout = 0.1
|
||||
|
Loading…
Reference in New Issue
Block a user