mirror of
https://github.com/Mikaela/Limnoria.git
synced 2024-11-29 22:29:24 +01:00
[WIP] Start reworking Irc around a FSM.
To keep track of connection state instead of a complex implicit flow between handling functions.
This commit is contained in:
parent
3eb20adaf2
commit
45ff70907f
198
src/irclib.py
198
src/irclib.py
@ -30,6 +30,7 @@
|
|||||||
import re
|
import re
|
||||||
import copy
|
import copy
|
||||||
import time
|
import time
|
||||||
|
import enum
|
||||||
import random
|
import random
|
||||||
import base64
|
import base64
|
||||||
import textwrap
|
import textwrap
|
||||||
@ -389,14 +390,108 @@ class ChannelState(utils.python.Object):
|
|||||||
|
|
||||||
Batch = collections.namedtuple('Batch', 'type arguments messages')
|
Batch = collections.namedtuple('Batch', 'type arguments messages')
|
||||||
|
|
||||||
|
class IrcStateFsm(object):
|
||||||
|
'''Finite State Machine keeping track of what part of the connection
|
||||||
|
initialization we are in.'''
|
||||||
|
__slots__ = ('state',)
|
||||||
|
|
||||||
|
@enum.unique
|
||||||
|
class States(enum.Enum):
|
||||||
|
UNINITIALIZED = 10
|
||||||
|
'''Nothing received yet (except server notices)'''
|
||||||
|
|
||||||
|
INIT_CAP_NEGOTIATION = 20
|
||||||
|
'''Sent CAP LS, did not send CAP END yet'''
|
||||||
|
|
||||||
|
INIT_SASL = 30
|
||||||
|
'''In an AUTHENTICATE session'''
|
||||||
|
|
||||||
|
INIT_WAITING_MOTD = 50
|
||||||
|
'''Waiting for start of MOTD'''
|
||||||
|
|
||||||
|
INIT_MOTD = 60
|
||||||
|
'''Waiting for end of MOTD'''
|
||||||
|
|
||||||
|
CONNECTED = 70
|
||||||
|
'''Normal state of the connections'''
|
||||||
|
|
||||||
|
CONNECTED_SASL = 80
|
||||||
|
'''Doing SASL authentication in the middle of a connection.'''
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.reset()
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
self.state = self.States.UNINITIALIZED
|
||||||
|
|
||||||
|
def _transition(self, to_state, expected_from=None):
|
||||||
|
if expected_from is None or self.state in expected_from:
|
||||||
|
log.debug('transition from %s to %s', self.state, to_state)
|
||||||
|
self.state = to_state
|
||||||
|
else:
|
||||||
|
raise ValueError('unexpected transition to %s while in state %s' %
|
||||||
|
(to_state, self.state))
|
||||||
|
|
||||||
|
def expect_state(self, expected_states):
|
||||||
|
if self.state not in expected_states:
|
||||||
|
raise ValueError(('Connection in state %s, but expected to be '
|
||||||
|
'in state %s') % (self.state, expected_states))
|
||||||
|
|
||||||
|
def on_init_messages_sent(self):
|
||||||
|
'''As soon as USER/NICK/CAP LS are sent'''
|
||||||
|
self._transition(self.States.INIT_CAP_NEGOTIATION, [
|
||||||
|
self.States.UNINITIALIZED,
|
||||||
|
])
|
||||||
|
|
||||||
|
def on_sasl_cap(self):
|
||||||
|
'''Whenever we see the 'sasl' capability in a CAP LS response'''
|
||||||
|
if self.state == self.States.INIT_CAP_NEGOTIATION:
|
||||||
|
self._transition(self.States.INIT_SASL)
|
||||||
|
elif self.state == self.States.CONNECTED:
|
||||||
|
self._transition(self.States.CONNECTED_SASL)
|
||||||
|
else:
|
||||||
|
raise ValueError('Got sasl cap while in state %s' % self.state)
|
||||||
|
|
||||||
|
def on_sasl_auth_finished(self):
|
||||||
|
'''When sasl auth either succeeded or failed.'''
|
||||||
|
if self.state == self.States.INIT_SASL:
|
||||||
|
self._transition(self.States.INIT_CAP_NEGOTIATION)
|
||||||
|
elif self.state == self.States.CONNECTED_SASL:
|
||||||
|
self._transition(self.States.CONNECTED)
|
||||||
|
else:
|
||||||
|
raise ValueError('Finished SASL auth while in state %s' % self.state)
|
||||||
|
|
||||||
|
def on_cap_end(self):
|
||||||
|
'''When we send CAP END'''
|
||||||
|
self._transition(self.States.INIT_WAITING_MOTD, [
|
||||||
|
self.States.INIT_CAP_NEGOTIATION,
|
||||||
|
])
|
||||||
|
|
||||||
|
def on_start_motd(self):
|
||||||
|
'''On 375 (RPL_MOTDSTART)'''
|
||||||
|
self._transition(self.States.INIT_MOTD, [
|
||||||
|
self.States.INIT_CAP_NEGOTIATION,
|
||||||
|
self.States.INIT_WAITING_MOTD,
|
||||||
|
])
|
||||||
|
|
||||||
|
def on_end_motd(self):
|
||||||
|
'''On 376 (RPL_ENDOFMOTD) or 422 (ERR_NOMOTD)'''
|
||||||
|
self._transition(self.States.CONNECTED, [
|
||||||
|
self.States.INIT_CAP_NEGOTIATION,
|
||||||
|
self.States.INIT_WAITING_MOTD,
|
||||||
|
self.States.INIT_MOTD
|
||||||
|
])
|
||||||
|
|
||||||
class IrcState(IrcCommandDispatcher, log.Firewalled):
|
class IrcState(IrcCommandDispatcher, log.Firewalled):
|
||||||
"""Maintains state of the Irc connection. Should also become smarter.
|
"""Maintains state of the Irc connection. Should also become smarter.
|
||||||
"""
|
"""
|
||||||
__firewalled__ = {'addMsg': None}
|
__firewalled__ = {'addMsg': None}
|
||||||
def __init__(self, history=None, supported=None,
|
def __init__(self, history=None, supported=None,
|
||||||
nicksToHostmasks=None, channels=None,
|
nicksToHostmasks=None, channels=None,
|
||||||
|
capabilities_req=None,
|
||||||
capabilities_ack=None, capabilities_nak=None,
|
capabilities_ack=None, capabilities_nak=None,
|
||||||
capabilities_ls=None):
|
capabilities_ls=None):
|
||||||
|
self.fsm = IrcStateFsm()
|
||||||
if history is None:
|
if history is None:
|
||||||
history = RingBuffer(conf.supybot.protocols.irc.maxHistoryLength())
|
history = RingBuffer(conf.supybot.protocols.irc.maxHistoryLength())
|
||||||
if supported is None:
|
if supported is None:
|
||||||
@ -405,6 +500,7 @@ class IrcState(IrcCommandDispatcher, log.Firewalled):
|
|||||||
nicksToHostmasks = ircutils.IrcDict()
|
nicksToHostmasks = ircutils.IrcDict()
|
||||||
if channels is None:
|
if channels is None:
|
||||||
channels = ircutils.IrcDict()
|
channels = ircutils.IrcDict()
|
||||||
|
self.capabilities_req = capabilities_req or set()
|
||||||
self.capabilities_ack = capabilities_ack or set()
|
self.capabilities_ack = capabilities_ack or set()
|
||||||
self.capabilities_nak = capabilities_nak or set()
|
self.capabilities_nak = capabilities_nak or set()
|
||||||
self.capabilities_ls = capabilities_ls or {}
|
self.capabilities_ls = capabilities_ls or {}
|
||||||
@ -417,6 +513,7 @@ class IrcState(IrcCommandDispatcher, log.Firewalled):
|
|||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
"""Resets the state to normal, unconnected state."""
|
"""Resets the state to normal, unconnected state."""
|
||||||
|
self.fsm.reset()
|
||||||
self.history.reset()
|
self.history.reset()
|
||||||
self.history.resize(conf.supybot.protocols.irc.maxHistoryLength())
|
self.history.resize(conf.supybot.protocols.irc.maxHistoryLength())
|
||||||
self.ircd = None
|
self.ircd = None
|
||||||
@ -424,6 +521,7 @@ class IrcState(IrcCommandDispatcher, log.Firewalled):
|
|||||||
self.supported.clear()
|
self.supported.clear()
|
||||||
self.nicksToHostmasks.clear()
|
self.nicksToHostmasks.clear()
|
||||||
self.batches = {}
|
self.batches = {}
|
||||||
|
self.capabilities_req = set()
|
||||||
self.capabilities_ack = set()
|
self.capabilities_ack = set()
|
||||||
self.capabilities_nak = set()
|
self.capabilities_nak = set()
|
||||||
self.capabilities_ls = {}
|
self.capabilities_ls = {}
|
||||||
@ -1115,6 +1213,8 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
|
|||||||
|
|
||||||
self.sendAuthenticationMessages()
|
self.sendAuthenticationMessages()
|
||||||
|
|
||||||
|
self.state.fsm.on_init_messages_sent()
|
||||||
|
|
||||||
def sendAuthenticationMessages(self):
|
def sendAuthenticationMessages(self):
|
||||||
# Notes:
|
# Notes:
|
||||||
# * using sendMsg instead of queueMsg because these messages cannot
|
# * using sendMsg instead of queueMsg because these messages cannot
|
||||||
@ -1135,9 +1235,42 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
|
|||||||
|
|
||||||
self.sendMsg(ircmsgs.user(self.ident, self.user))
|
self.sendMsg(ircmsgs.user(self.ident, self.user))
|
||||||
|
|
||||||
|
def capUpkeep(self):
|
||||||
|
self.state.fsm.expect_state([
|
||||||
|
# Normal CAP ACK / CAP NAK during cap negotiation
|
||||||
|
IrcStateFsm.States.INIT_CAP_NEGOTIATION,
|
||||||
|
# CAP ACK / CAP NAK after a CAP NEW (probably)
|
||||||
|
IrcStateFsm.States.CONNECTED,
|
||||||
|
])
|
||||||
|
|
||||||
|
capabilities_responded = (self.state.capabilities_ack |
|
||||||
|
self.state.capabilities_nak)
|
||||||
|
if not capabilities_responded <= self.state.capabilities_req:
|
||||||
|
log.error('Server responded with unrequested ACK/NAK '
|
||||||
|
'capabilities: req=%r, ack=%r, nak=%r',
|
||||||
|
self.state.capabilities_req,
|
||||||
|
self.state.capabilities_ack,
|
||||||
|
self.state.capabilities_nak)
|
||||||
|
self.driver.reconnect()
|
||||||
|
elif capabilities_responded == self.state.capabilities_req:
|
||||||
|
log.debug('Got all capabilities ACKed/NAKed')
|
||||||
|
# We got all the capabilities we asked for
|
||||||
|
if 'sasl' in self.state.capabilities_ack:
|
||||||
|
if self.state.fsm.state in [
|
||||||
|
IrcStateFsm.States.INIT_CAP_NEGOTIATION,
|
||||||
|
IrcStateFsm.States.CONNECTED]:
|
||||||
|
self._maybeStartSasl()
|
||||||
|
else:
|
||||||
|
pass # Already in the middle of a SASL auth
|
||||||
|
else:
|
||||||
|
self.endCapabilityNegociation()
|
||||||
|
else:
|
||||||
|
log.debug('Waiting for ACK/NAK of capabilities: %r',
|
||||||
|
self.state.capabilities_req - capabilities_responded)
|
||||||
|
pass # Do nothing, we'll get more
|
||||||
|
|
||||||
def endCapabilityNegociation(self):
|
def endCapabilityNegociation(self):
|
||||||
if not self.capNegociationEnded:
|
self.state.fsm.on_cap_end()
|
||||||
self.capNegociationEnded = True
|
|
||||||
self.sendMsg(ircmsgs.IrcMsg(command='CAP', args=('END',)))
|
self.sendMsg(ircmsgs.IrcMsg(command='CAP', args=('END',)))
|
||||||
|
|
||||||
def sendSaslString(self, string):
|
def sendSaslString(self, string):
|
||||||
@ -1146,6 +1279,10 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
|
|||||||
args=(chunk,)))
|
args=(chunk,)))
|
||||||
|
|
||||||
def tryNextSaslMechanism(self):
|
def tryNextSaslMechanism(self):
|
||||||
|
self.state.fsm.expect_state([
|
||||||
|
IrcStateFsm.States.INIT_SASL,
|
||||||
|
IrcStateFsm.States.CONNECTED_SASL,
|
||||||
|
])
|
||||||
if self.sasl_next_mechanisms:
|
if self.sasl_next_mechanisms:
|
||||||
self.sasl_current_mechanism = self.sasl_next_mechanisms.pop(0)
|
self.sasl_current_mechanism = self.sasl_next_mechanisms.pop(0)
|
||||||
self.sendMsg(ircmsgs.IrcMsg(command='AUTHENTICATE',
|
self.sendMsg(ircmsgs.IrcMsg(command='AUTHENTICATE',
|
||||||
@ -1155,15 +1292,30 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
|
|||||||
'aborting connection.')
|
'aborting connection.')
|
||||||
else:
|
else:
|
||||||
self.sasl_current_mechanism = None
|
self.sasl_current_mechanism = None
|
||||||
|
self.state.fsm.on_sasl_auth_finished()
|
||||||
|
if self.state.fsm.state == IrcStateFsm.States.INIT_CAP_NEGOTIATION:
|
||||||
self.endCapabilityNegociation()
|
self.endCapabilityNegociation()
|
||||||
|
|
||||||
def filterSaslMechanisms(self, available):
|
def _maybeStartSasl(self):
|
||||||
available = set(map(str.lower, available))
|
if not self.sasl_authenticated and \
|
||||||
|
'sasl' in self.state.capabilities_ack:
|
||||||
|
self.state.fsm.on_sasl_cap()
|
||||||
|
assert 'sasl' in self.state.capabilities_ls, (
|
||||||
|
'Got "CAP ACK sasl" without receiving "CAP LS sasl" or '
|
||||||
|
'"CAP NEW sasl" first.')
|
||||||
|
s = self.state.capabilities_ls['sasl']
|
||||||
|
if s is not None:
|
||||||
|
available = set(map(str.lower, s.split(',')))
|
||||||
self.sasl_next_mechanisms = [
|
self.sasl_next_mechanisms = [
|
||||||
x for x in self.sasl_next_mechanisms
|
x for x in self.sasl_next_mechanisms
|
||||||
if x.lower() in available]
|
if x.lower() in available]
|
||||||
|
self.tryNextSaslMechanism()
|
||||||
|
|
||||||
def doAuthenticate(self, msg):
|
def doAuthenticate(self, msg):
|
||||||
|
self.state.fsm.expect_state([
|
||||||
|
IrcStateFsm.States.INIT_SASL,
|
||||||
|
IrcStateFsm.States.CONNECTED_SASL,
|
||||||
|
])
|
||||||
if not self.authenticate_decoder:
|
if not self.authenticate_decoder:
|
||||||
self.authenticate_decoder = ircutils.AuthenticateDecoder()
|
self.authenticate_decoder = ircutils.AuthenticateDecoder()
|
||||||
self.authenticate_decoder.feed(msg)
|
self.authenticate_decoder.feed(msg)
|
||||||
@ -1265,6 +1417,8 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
|
|||||||
def do903(self, msg):
|
def do903(self, msg):
|
||||||
log.info('%s: SASL authentication successful', self.network)
|
log.info('%s: SASL authentication successful', self.network)
|
||||||
self.sasl_authenticated = True
|
self.sasl_authenticated = True
|
||||||
|
self.state.fsm.on_sasl_auth_finished()
|
||||||
|
if self.state.fsm.state == IrcStateFsm.States.INIT_CAP_NEGOTIATION:
|
||||||
self.endCapabilityNegociation()
|
self.endCapabilityNegociation()
|
||||||
|
|
||||||
def do904(self, msg):
|
def do904(self, msg):
|
||||||
@ -1301,10 +1455,8 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
|
|||||||
self.network, caps)
|
self.network, caps)
|
||||||
self.state.capabilities_ack.update(caps)
|
self.state.capabilities_ack.update(caps)
|
||||||
|
|
||||||
if 'sasl' in caps:
|
self.capUpkeep()
|
||||||
self.tryNextSaslMechanism()
|
|
||||||
else:
|
|
||||||
self.endCapabilityNegociation()
|
|
||||||
def doCapNak(self, msg):
|
def doCapNak(self, msg):
|
||||||
if len(msg.args) != 3:
|
if len(msg.args) != 3:
|
||||||
log.warning('Bad CAP NAK from server: %r', msg)
|
log.warning('Bad CAP NAK from server: %r', msg)
|
||||||
@ -1314,7 +1466,8 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
|
|||||||
self.state.capabilities_nak.update(caps)
|
self.state.capabilities_nak.update(caps)
|
||||||
log.warning('%s: Server refused capabilities: %L',
|
log.warning('%s: Server refused capabilities: %L',
|
||||||
self.network, caps)
|
self.network, caps)
|
||||||
self.endCapabilityNegociation()
|
self.capUpkeep()
|
||||||
|
|
||||||
def _addCapabilities(self, capstring):
|
def _addCapabilities(self, capstring):
|
||||||
for item in capstring.split():
|
for item in capstring.split():
|
||||||
while item.startswith(('=', '~')):
|
while item.startswith(('=', '~')):
|
||||||
@ -1324,6 +1477,7 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
|
|||||||
self.state.capabilities_ls[cap] = value
|
self.state.capabilities_ls[cap] = value
|
||||||
else:
|
else:
|
||||||
self.state.capabilities_ls[item] = None
|
self.state.capabilities_ls[item] = None
|
||||||
|
|
||||||
def doCapLs(self, msg):
|
def doCapLs(self, msg):
|
||||||
if len(msg.args) == 4:
|
if len(msg.args) == 4:
|
||||||
# Multi-line LS
|
# Multi-line LS
|
||||||
@ -1333,12 +1487,14 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
|
|||||||
self._addCapabilities(msg.args[3])
|
self._addCapabilities(msg.args[3])
|
||||||
elif len(msg.args) == 3: # End of LS
|
elif len(msg.args) == 3: # End of LS
|
||||||
self._addCapabilities(msg.args[2])
|
self._addCapabilities(msg.args[2])
|
||||||
|
self.state.fsm.expect_state([
|
||||||
if 'sasl' in self.state.capabilities_ls:
|
# Normal case:
|
||||||
s = self.state.capabilities_ls['sasl']
|
IrcStateFsm.States.INIT_CAP_NEGOTIATION,
|
||||||
if s is not None:
|
# Should only happen if a plugin sends a CAP LS (which they
|
||||||
self.filterSaslMechanisms(set(s.split(',')))
|
# shouldn't do):
|
||||||
|
IrcStateFsm.States.CONNECTED,
|
||||||
|
IrcStateFsm.States.CONNECTED_SASL,
|
||||||
|
])
|
||||||
# Normally at this point, self.state.capabilities_ack should be
|
# Normally at this point, self.state.capabilities_ack should be
|
||||||
# empty; but let's just make sure we're not requesting the same
|
# empty; but let's just make sure we're not requesting the same
|
||||||
# caps twice for no reason.
|
# caps twice for no reason.
|
||||||
@ -1356,6 +1512,7 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
|
|||||||
else:
|
else:
|
||||||
log.warning('Bad CAP LS from server: %r', msg)
|
log.warning('Bad CAP LS from server: %r', msg)
|
||||||
return
|
return
|
||||||
|
|
||||||
def doCapDel(self, msg):
|
def doCapDel(self, msg):
|
||||||
if len(msg.args) != 3:
|
if len(msg.args) != 3:
|
||||||
log.warning('Bad CAP DEL from server: %r', msg)
|
log.warning('Bad CAP DEL from server: %r', msg)
|
||||||
@ -1374,18 +1531,16 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
|
|||||||
self.state.capabilities_ack.remove(cap)
|
self.state.capabilities_ack.remove(cap)
|
||||||
except KeyError:
|
except KeyError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def doCapNew(self, msg):
|
def doCapNew(self, msg):
|
||||||
|
# Note that in theory, this method may be called at any time, even
|
||||||
|
# before CAP END (or even before the initial CAP LS).
|
||||||
if len(msg.args) != 3:
|
if len(msg.args) != 3:
|
||||||
log.warning('Bad CAP NEW from server: %r', msg)
|
log.warning('Bad CAP NEW from server: %r', msg)
|
||||||
return
|
return
|
||||||
caps = msg.args[2].split()
|
caps = msg.args[2].split()
|
||||||
assert caps, 'Empty list of capabilities'
|
assert caps, 'Empty list of capabilities'
|
||||||
self._addCapabilities(msg.args[2])
|
self._addCapabilities(msg.args[2])
|
||||||
if not self.sasl_authenticated and 'sasl' in self.state.capabilities_ls:
|
|
||||||
self.resetSasl()
|
|
||||||
s = self.state.capabilities_ls['sasl']
|
|
||||||
if s is not None:
|
|
||||||
self.filterSaslMechanisms(set(s.split(',')))
|
|
||||||
common_supported_unrequested_capabilities = (
|
common_supported_unrequested_capabilities = (
|
||||||
set(self.state.capabilities_ls) &
|
set(self.state.capabilities_ls) &
|
||||||
self.REQUEST_CAPABILITIES -
|
self.REQUEST_CAPABILITIES -
|
||||||
@ -1394,6 +1549,8 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
|
|||||||
self._requestCaps(common_supported_unrequested_capabilities)
|
self._requestCaps(common_supported_unrequested_capabilities)
|
||||||
|
|
||||||
def _requestCaps(self, caps):
|
def _requestCaps(self, caps):
|
||||||
|
self.state.capabilities_req |= caps
|
||||||
|
|
||||||
caps = ' '.join(sorted(caps))
|
caps = ' '.join(sorted(caps))
|
||||||
# textwrap works here because in ASCII, all chars are 1 bytes:
|
# textwrap works here because in ASCII, all chars are 1 bytes:
|
||||||
cap_lines = textwrap.wrap(caps, MAX_LINE_SIZE-len('CAP REQ :'))
|
cap_lines = textwrap.wrap(caps, MAX_LINE_SIZE-len('CAP REQ :'))
|
||||||
@ -1474,6 +1631,7 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
|
|||||||
self.outstandingPing = False
|
self.outstandingPing = False
|
||||||
|
|
||||||
def do376(self, msg):
|
def do376(self, msg):
|
||||||
|
self.state.fsm.on_end_motd()
|
||||||
log.info('Got end of MOTD from %s', self.server)
|
log.info('Got end of MOTD from %s', self.server)
|
||||||
self.afterConnect = True
|
self.afterConnect = True
|
||||||
# Let's reset nicks in case we had to use a weird one.
|
# Let's reset nicks in case we had to use a weird one.
|
||||||
|
@ -832,6 +832,8 @@ class SaslTestCase(SupyTestCase):
|
|||||||
while self.irc.takeMsg():
|
while self.irc.takeMsg():
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
self.irc.feedMsg(ircmsgs.IrcMsg(command='422')) # ERR_NOMOTD
|
||||||
|
|
||||||
self.irc.feedMsg(ircmsgs.IrcMsg(command='CAP',
|
self.irc.feedMsg(ircmsgs.IrcMsg(command='CAP',
|
||||||
args=('*', 'NEW', 'sasl=EXTERNAL')))
|
args=('*', 'NEW', 'sasl=EXTERNAL')))
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user