mirror of
https://github.com/Mikaela/Limnoria.git
synced 2025-01-12 05:02:32 +01:00
Services: Keep per-network state separate
Until now, only `waitingJoins` was stored separately per network, while `channels`, `sentGhost` and `identified` had one common value per plugin instance. Instead of making everything a dictionary indexed by network name like `waitingJoins`, let's bundle all the state together in a class and store *its* instances in such a dictionary. This fixes at least one race condition, for which a test case was added. Even with `noJoinsUntilIdentified` set, the bot would let joins through as long as *any* one network has already finished identifying.
This commit is contained in:
parent
177c20267c
commit
3d21c7cbcb
@ -42,6 +42,13 @@ import supybot.callbacks as callbacks
|
|||||||
from supybot.i18n import PluginInternationalization, internationalizeDocstring
|
from supybot.i18n import PluginInternationalization, internationalizeDocstring
|
||||||
_ = PluginInternationalization('Services')
|
_ = PluginInternationalization('Services')
|
||||||
|
|
||||||
|
class State:
|
||||||
|
def __init__(self):
|
||||||
|
self.channels = []
|
||||||
|
self.sentGhost = None
|
||||||
|
self.identified = False
|
||||||
|
self.waitingJoins = []
|
||||||
|
|
||||||
class Services(callbacks.Plugin):
|
class Services(callbacks.Plugin):
|
||||||
"""This plugin handles dealing with Services on networks that provide them.
|
"""This plugin handles dealing with Services on networks that provide them.
|
||||||
Basically, you should use the "password" command to tell the bot a nick to
|
Basically, you should use the "password" command to tell the bot a nick to
|
||||||
@ -66,10 +73,10 @@ class Services(callbacks.Plugin):
|
|||||||
self.reset()
|
self.reset()
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
self.channels = []
|
self.state = {}
|
||||||
self.sentGhost = None
|
|
||||||
self.identified = False
|
def _getState(self, irc):
|
||||||
self.waitingJoins = {}
|
return self.state.setdefault(irc.network, State())
|
||||||
|
|
||||||
def disabled(self, irc):
|
def disabled(self, irc):
|
||||||
disabled = self.registryValue('disabledNetworks')
|
disabled = self.registryValue('disabledNetworks')
|
||||||
@ -79,13 +86,13 @@ class Services(callbacks.Plugin):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
def outFilter(self, irc, msg):
|
def outFilter(self, irc, msg):
|
||||||
|
state = self._getState(irc)
|
||||||
if msg.command == 'JOIN' and not self.disabled(irc):
|
if msg.command == 'JOIN' and not self.disabled(irc):
|
||||||
if not self.identified:
|
if not state.identified:
|
||||||
if self.registryValue('noJoinsUntilIdentified', network=irc.network):
|
if self.registryValue('noJoinsUntilIdentified', network=irc.network):
|
||||||
self.log.info('Holding JOIN to %s @ %s until identified.',
|
self.log.info('Holding JOIN to %s @ %s until identified.',
|
||||||
msg.channel, irc.network)
|
msg.channel, irc.network)
|
||||||
self.waitingJoins.setdefault(irc.network, [])
|
state.waitingJoins.append(msg)
|
||||||
self.waitingJoins[irc.network].append(msg)
|
|
||||||
return None
|
return None
|
||||||
return msg
|
return msg
|
||||||
|
|
||||||
@ -131,6 +138,7 @@ class Services(callbacks.Plugin):
|
|||||||
def _doGhost(self, irc, nick=None):
|
def _doGhost(self, irc, nick=None):
|
||||||
if self.disabled(irc):
|
if self.disabled(irc):
|
||||||
return
|
return
|
||||||
|
state = self._getState(irc)
|
||||||
if nick is None:
|
if nick is None:
|
||||||
nick = self._getNick(irc.network)
|
nick = self._getNick(irc.network)
|
||||||
if nick not in self.registryValue('nicks', network=irc.network):
|
if nick not in self.registryValue('nicks', network=irc.network):
|
||||||
@ -144,7 +152,7 @@ class Services(callbacks.Plugin):
|
|||||||
s = 'Tried to ghost without a NickServ or password set.'
|
s = 'Tried to ghost without a NickServ or password set.'
|
||||||
self.log.warning(s)
|
self.log.warning(s)
|
||||||
return
|
return
|
||||||
if self.sentGhost and time.time() < (self.sentGhost + ghostDelay):
|
if state.sentGhost and time.time() < (state.sentGhost + ghostDelay):
|
||||||
self.log.warning('Refusing to send GHOST more than once every '
|
self.log.warning('Refusing to send GHOST more than once every '
|
||||||
'%s seconds.' % ghostDelay)
|
'%s seconds.' % ghostDelay)
|
||||||
elif not password:
|
elif not password:
|
||||||
@ -156,12 +164,13 @@ class Services(callbacks.Plugin):
|
|||||||
ghost = 'GHOST %s %s' % (nick, password)
|
ghost = 'GHOST %s %s' % (nick, password)
|
||||||
# Ditto about the sendMsg (see _doIdentify).
|
# Ditto about the sendMsg (see _doIdentify).
|
||||||
irc.sendMsg(ircmsgs.privmsg(nickserv, ghost))
|
irc.sendMsg(ircmsgs.privmsg(nickserv, ghost))
|
||||||
self.sentGhost = time.time()
|
state.sentGhost = time.time()
|
||||||
|
|
||||||
def __call__(self, irc, msg):
|
def __call__(self, irc, msg):
|
||||||
self.__parent.__call__(irc, msg)
|
self.__parent.__call__(irc, msg)
|
||||||
if self.disabled(irc):
|
if self.disabled(irc):
|
||||||
return
|
return
|
||||||
|
state = self._getState(irc)
|
||||||
nick = self._getNick(irc.network)
|
nick = self._getNick(irc.network)
|
||||||
if nick not in self.registryValue('nicks', network=irc.network):
|
if nick not in self.registryValue('nicks', network=irc.network):
|
||||||
return
|
return
|
||||||
@ -172,16 +181,17 @@ class Services(callbacks.Plugin):
|
|||||||
return
|
return
|
||||||
if nick and nickserv and password and \
|
if nick and nickserv and password and \
|
||||||
not ircutils.strEqual(nick, irc.nick):
|
not ircutils.strEqual(nick, irc.nick):
|
||||||
if irc.afterConnect and (self.sentGhost is None or
|
if irc.afterConnect and (state.sentGhost is None or
|
||||||
(self.sentGhost + ghostDelay) < time.time()):
|
(state.sentGhost + ghostDelay) < time.time()):
|
||||||
if nick in irc.state.nicksToHostmasks:
|
if nick in irc.state.nicksToHostmasks:
|
||||||
self._doGhost(irc)
|
self._doGhost(irc)
|
||||||
else:
|
else:
|
||||||
irc.sendMsg(ircmsgs.nick(nick)) # 433 is handled elsewhere.
|
irc.sendMsg(ircmsgs.nick(nick)) # 433 is handled elsewhere.
|
||||||
|
|
||||||
def do001(self, irc, msg):
|
def do001(self, irc, msg):
|
||||||
# New connection, make sure sentGhost is False.
|
# New connection, make sure sentGhost is None.
|
||||||
self.sentGhost = None
|
state = self._getState(irc)
|
||||||
|
state.sentGhost = None
|
||||||
|
|
||||||
def do376(self, irc, msg):
|
def do376(self, irc, msg):
|
||||||
if self.disabled(irc):
|
if self.disabled(irc):
|
||||||
@ -221,7 +231,8 @@ class Services(callbacks.Plugin):
|
|||||||
|
|
||||||
def do515(self, irc, msg):
|
def do515(self, irc, msg):
|
||||||
# Can't join this channel, it's +r (we must be identified).
|
# Can't join this channel, it's +r (we must be identified).
|
||||||
self.channels.append(msg.args[1])
|
state = self._getState(irc)
|
||||||
|
state.channels.append(msg.args[1])
|
||||||
|
|
||||||
def doNick(self, irc, msg):
|
def doNick(self, irc, msg):
|
||||||
nick = self._getNick(irc.network)
|
nick = self._getNick(irc.network)
|
||||||
@ -295,6 +306,7 @@ class Services(callbacks.Plugin):
|
|||||||
def doNickservNotice(self, irc, msg):
|
def doNickservNotice(self, irc, msg):
|
||||||
if self.disabled(irc):
|
if self.disabled(irc):
|
||||||
return
|
return
|
||||||
|
state = self._getState(irc)
|
||||||
nick = self._getNick(irc.network)
|
nick = self._getNick(irc.network)
|
||||||
s = ircutils.stripFormatting(msg.args[1].lower())
|
s = ircutils.stripFormatting(msg.args[1].lower())
|
||||||
on = 'on %s' % irc.network
|
on = 'on %s' % irc.network
|
||||||
@ -303,19 +315,19 @@ class Services(callbacks.Plugin):
|
|||||||
log = 'Received "Password Incorrect" from NickServ %s. ' \
|
log = 'Received "Password Incorrect" from NickServ %s. ' \
|
||||||
'Resetting password to empty.' % on
|
'Resetting password to empty.' % on
|
||||||
self.log.warning(log)
|
self.log.warning(log)
|
||||||
self.sentGhost = time.time()
|
state.sentGhost = time.time()
|
||||||
self._setNickServPassword(nick, '', irc.network)
|
self._setNickServPassword(nick, '', irc.network)
|
||||||
elif self._ghosted(irc, s):
|
elif self._ghosted(irc, s):
|
||||||
self.log.info('Received "GHOST succeeded" from NickServ %s.', on)
|
self.log.info('Received "GHOST succeeded" from NickServ %s.', on)
|
||||||
self.sentGhost = None
|
state.sentGhost = None
|
||||||
self.identified = False
|
state.identified = False
|
||||||
irc.queueMsg(ircmsgs.nick(nick))
|
irc.queueMsg(ircmsgs.nick(nick))
|
||||||
elif 'is not registered' in s:
|
elif 'is not registered' in s:
|
||||||
self.log.info('Received "Nick not registered" from NickServ %s.',
|
self.log.info('Received "Nick not registered" from NickServ %s.',
|
||||||
on)
|
on)
|
||||||
elif 'currently' in s and 'isn\'t' in s or 'is not' in s:
|
elif 'currently' in s and 'isn\'t' in s or 'is not' in s:
|
||||||
# The nick isn't online, let's change our nick to it.
|
# The nick isn't online, let's change our nick to it.
|
||||||
self.sentGhost = None
|
state.sentGhost = None
|
||||||
irc.queueMsg(ircmsgs.nick(nick))
|
irc.queueMsg(ircmsgs.nick(nick))
|
||||||
elif ('owned by someone else' in s) or \
|
elif ('owned by someone else' in s) or \
|
||||||
('nickname is registered and protected' in s) or \
|
('nickname is registered and protected' in s) or \
|
||||||
@ -339,15 +351,15 @@ class Services(callbacks.Plugin):
|
|||||||
# freenode, oftc, arstechnica, zirc, ....
|
# freenode, oftc, arstechnica, zirc, ....
|
||||||
# sorcery
|
# sorcery
|
||||||
self.log.info('Received "Password accepted" from NickServ %s.', on)
|
self.log.info('Received "Password accepted" from NickServ %s.', on)
|
||||||
self.identified = True
|
state.identified = True
|
||||||
for channel in irc.state.channels.keys():
|
for channel in irc.state.channels.keys():
|
||||||
self.checkPrivileges(irc, channel)
|
self.checkPrivileges(irc, channel)
|
||||||
for channel in self.channels:
|
for channel in state.channels:
|
||||||
irc.queueMsg(networkGroup.channels.join(channel))
|
irc.queueMsg(networkGroup.channels.join(channel))
|
||||||
waitingJoins = self.waitingJoins.pop(irc.network, None)
|
waitingJoins = state.waitingJoins
|
||||||
if waitingJoins:
|
state.waitingJoins = []
|
||||||
for m in waitingJoins:
|
for join in waitingJoins:
|
||||||
irc.sendMsg(m)
|
irc.sendMsg(join)
|
||||||
elif 'not yet authenticated' in s:
|
elif 'not yet authenticated' in s:
|
||||||
# zirc.org has this, it requires an auth code.
|
# zirc.org has this, it requires an auth code.
|
||||||
email = s.split()[-1]
|
email = s.split()[-1]
|
||||||
@ -401,7 +413,8 @@ class Services(callbacks.Plugin):
|
|||||||
channel, on)
|
channel, on)
|
||||||
|
|
||||||
def do366(self, irc, msg): # End of /NAMES list; finished joining a channel
|
def do366(self, irc, msg): # End of /NAMES list; finished joining a channel
|
||||||
if self.identified:
|
state = self._getState(irc)
|
||||||
|
if state.identified:
|
||||||
channel = msg.args[1] # nick is msg.args[0].
|
channel = msg.args[1] # nick is msg.args[0].
|
||||||
self.checkPrivileges(irc, channel)
|
self.checkPrivileges(irc, channel)
|
||||||
|
|
||||||
|
@ -30,6 +30,7 @@
|
|||||||
from supybot.test import *
|
from supybot.test import *
|
||||||
import supybot.conf as conf
|
import supybot.conf as conf
|
||||||
from supybot.ircmsgs import IrcMsg
|
from supybot.ircmsgs import IrcMsg
|
||||||
|
from copy import copy
|
||||||
|
|
||||||
class ServicesTestCase(PluginTestCase):
|
class ServicesTestCase(PluginTestCase):
|
||||||
plugins = ('Services', 'Config')
|
plugins = ('Services', 'Config')
|
||||||
@ -105,6 +106,72 @@ class ServicesTestCase(PluginTestCase):
|
|||||||
self.assertIsNone(self.irc.takeMsg())
|
self.assertIsNone(self.irc.takeMsg())
|
||||||
|
|
||||||
|
|
||||||
|
class JoinsBeforeIdentifiedTestCase(PluginTestCase):
|
||||||
|
plugins = ('Services',)
|
||||||
|
config = {
|
||||||
|
'plugins.Services.noJoinsUntilIdentified': False,
|
||||||
|
}
|
||||||
|
|
||||||
|
def testSingleNetwork(self):
|
||||||
|
queuedJoin = ircmsgs.join('#test', prefix=self.prefix)
|
||||||
|
self.irc.queueMsg(queuedJoin)
|
||||||
|
self.assertEqual(self.irc.takeMsg(), queuedJoin,
|
||||||
|
'Join request did not go through.')
|
||||||
|
|
||||||
|
|
||||||
|
class NoJoinsUntilIdentifiedTestCase(PluginTestCase):
|
||||||
|
plugins = ('Services',)
|
||||||
|
config = {
|
||||||
|
'plugins.Services.noJoinsUntilIdentified': True,
|
||||||
|
}
|
||||||
|
|
||||||
|
def _identify(self, irc):
|
||||||
|
irc.feedMsg(IrcMsg(command='376', args=(self.nick,)))
|
||||||
|
msg = irc.takeMsg()
|
||||||
|
self.assertEqual(msg.command, 'PRIVMSG')
|
||||||
|
self.assertEqual(msg.args[0], 'NickServ')
|
||||||
|
irc.feedMsg(ircmsgs.notice(self.nick, 'now identified', 'NickServ'))
|
||||||
|
|
||||||
|
def testSingleNetwork(self):
|
||||||
|
try:
|
||||||
|
self.assertNotError('services password %s secret' % self.nick)
|
||||||
|
queuedJoin = ircmsgs.join('#test', prefix=self.prefix)
|
||||||
|
self.irc.queueMsg(queuedJoin)
|
||||||
|
self.assertIsNone(self.irc.takeMsg(),
|
||||||
|
'Join request went through before identification.')
|
||||||
|
self._identify(self.irc)
|
||||||
|
self.assertEqual(self.irc.takeMsg(), queuedJoin,
|
||||||
|
'Join request did not go through after identification.')
|
||||||
|
finally:
|
||||||
|
self.assertNotError('services password %s ""' % self.nick)
|
||||||
|
|
||||||
|
def testMultipleNetworks(self):
|
||||||
|
try:
|
||||||
|
net1 = copy(self)
|
||||||
|
net1.irc = getTestIrc('testnet1')
|
||||||
|
net1.assertNotError('services password %s secret' % self.nick)
|
||||||
|
net2 = copy(self)
|
||||||
|
net2.irc = getTestIrc('testnet2')
|
||||||
|
net2.assertNotError('services password %s secret' % self.nick)
|
||||||
|
queuedJoin1 = ircmsgs.join('#testchan1', prefix=self.prefix)
|
||||||
|
net1.irc.queueMsg(queuedJoin1)
|
||||||
|
self.assertIsNone(net1.irc.takeMsg(),
|
||||||
|
'Join request 1 went through before identification.')
|
||||||
|
self._identify(net1.irc)
|
||||||
|
self.assertEqual(net1.irc.takeMsg(), queuedJoin1,
|
||||||
|
'Join request 1 did not go through after identification.')
|
||||||
|
queuedJoin2 = ircmsgs.join('#testchan2', prefix=self.prefix)
|
||||||
|
net2.irc.queueMsg(queuedJoin2)
|
||||||
|
self.assertIsNone(net2.irc.takeMsg(),
|
||||||
|
'Join request 2 went through before identification.')
|
||||||
|
self._identify(net2.irc)
|
||||||
|
self.assertEqual(net2.irc.takeMsg(), queuedJoin2,
|
||||||
|
'Join request 2 did not go through after identification.')
|
||||||
|
finally:
|
||||||
|
net1.assertNotError('services password %s ""' % self.nick)
|
||||||
|
net2.assertNotError('services password %s ""' % self.nick)
|
||||||
|
|
||||||
|
|
||||||
class ExperimentalServicesTestCase(PluginTestCase):
|
class ExperimentalServicesTestCase(PluginTestCase):
|
||||||
plugins = ["Services"]
|
plugins = ["Services"]
|
||||||
timeout = 0.1
|
timeout = 0.1
|
||||||
|
Loading…
Reference in New Issue
Block a user