[WIP] Start reworking Irc around a FSM.

To keep track of connection state instead of a complex implicit flow
between handling functions.
This commit is contained in:
Valentin Lorentz 2019-10-25 23:17:10 +02:00
parent 3eb20adaf2
commit 45ff70907f
2 changed files with 186 additions and 26 deletions

View File

@ -30,6 +30,7 @@
import re import re
import copy import copy
import time import time
import enum
import random import random
import base64 import base64
import textwrap import textwrap
@ -389,14 +390,108 @@ class ChannelState(utils.python.Object):
Batch = collections.namedtuple('Batch', 'type arguments messages') Batch = collections.namedtuple('Batch', 'type arguments messages')
class IrcStateFsm(object):
'''Finite State Machine keeping track of what part of the connection
initialization we are in.'''
__slots__ = ('state',)
@enum.unique
class States(enum.Enum):
UNINITIALIZED = 10
'''Nothing received yet (except server notices)'''
INIT_CAP_NEGOTIATION = 20
'''Sent CAP LS, did not send CAP END yet'''
INIT_SASL = 30
'''In an AUTHENTICATE session'''
INIT_WAITING_MOTD = 50
'''Waiting for start of MOTD'''
INIT_MOTD = 60
'''Waiting for end of MOTD'''
CONNECTED = 70
'''Normal state of the connections'''
CONNECTED_SASL = 80
'''Doing SASL authentication in the middle of a connection.'''
def __init__(self):
self.reset()
def reset(self):
self.state = self.States.UNINITIALIZED
def _transition(self, to_state, expected_from=None):
if expected_from is None or self.state in expected_from:
log.debug('transition from %s to %s', self.state, to_state)
self.state = to_state
else:
raise ValueError('unexpected transition to %s while in state %s' %
(to_state, self.state))
def expect_state(self, expected_states):
if self.state not in expected_states:
raise ValueError(('Connection in state %s, but expected to be '
'in state %s') % (self.state, expected_states))
def on_init_messages_sent(self):
'''As soon as USER/NICK/CAP LS are sent'''
self._transition(self.States.INIT_CAP_NEGOTIATION, [
self.States.UNINITIALIZED,
])
def on_sasl_cap(self):
'''Whenever we see the 'sasl' capability in a CAP LS response'''
if self.state == self.States.INIT_CAP_NEGOTIATION:
self._transition(self.States.INIT_SASL)
elif self.state == self.States.CONNECTED:
self._transition(self.States.CONNECTED_SASL)
else:
raise ValueError('Got sasl cap while in state %s' % self.state)
def on_sasl_auth_finished(self):
'''When sasl auth either succeeded or failed.'''
if self.state == self.States.INIT_SASL:
self._transition(self.States.INIT_CAP_NEGOTIATION)
elif self.state == self.States.CONNECTED_SASL:
self._transition(self.States.CONNECTED)
else:
raise ValueError('Finished SASL auth while in state %s' % self.state)
def on_cap_end(self):
'''When we send CAP END'''
self._transition(self.States.INIT_WAITING_MOTD, [
self.States.INIT_CAP_NEGOTIATION,
])
def on_start_motd(self):
'''On 375 (RPL_MOTDSTART)'''
self._transition(self.States.INIT_MOTD, [
self.States.INIT_CAP_NEGOTIATION,
self.States.INIT_WAITING_MOTD,
])
def on_end_motd(self):
'''On 376 (RPL_ENDOFMOTD) or 422 (ERR_NOMOTD)'''
self._transition(self.States.CONNECTED, [
self.States.INIT_CAP_NEGOTIATION,
self.States.INIT_WAITING_MOTD,
self.States.INIT_MOTD
])
class IrcState(IrcCommandDispatcher, log.Firewalled): class IrcState(IrcCommandDispatcher, log.Firewalled):
"""Maintains state of the Irc connection. Should also become smarter. """Maintains state of the Irc connection. Should also become smarter.
""" """
__firewalled__ = {'addMsg': None} __firewalled__ = {'addMsg': None}
def __init__(self, history=None, supported=None, def __init__(self, history=None, supported=None,
nicksToHostmasks=None, channels=None, nicksToHostmasks=None, channels=None,
capabilities_req=None,
capabilities_ack=None, capabilities_nak=None, capabilities_ack=None, capabilities_nak=None,
capabilities_ls=None): capabilities_ls=None):
self.fsm = IrcStateFsm()
if history is None: if history is None:
history = RingBuffer(conf.supybot.protocols.irc.maxHistoryLength()) history = RingBuffer(conf.supybot.protocols.irc.maxHistoryLength())
if supported is None: if supported is None:
@ -405,6 +500,7 @@ class IrcState(IrcCommandDispatcher, log.Firewalled):
nicksToHostmasks = ircutils.IrcDict() nicksToHostmasks = ircutils.IrcDict()
if channels is None: if channels is None:
channels = ircutils.IrcDict() channels = ircutils.IrcDict()
self.capabilities_req = capabilities_req or set()
self.capabilities_ack = capabilities_ack or set() self.capabilities_ack = capabilities_ack or set()
self.capabilities_nak = capabilities_nak or set() self.capabilities_nak = capabilities_nak or set()
self.capabilities_ls = capabilities_ls or {} self.capabilities_ls = capabilities_ls or {}
@ -417,6 +513,7 @@ class IrcState(IrcCommandDispatcher, log.Firewalled):
def reset(self): def reset(self):
"""Resets the state to normal, unconnected state.""" """Resets the state to normal, unconnected state."""
self.fsm.reset()
self.history.reset() self.history.reset()
self.history.resize(conf.supybot.protocols.irc.maxHistoryLength()) self.history.resize(conf.supybot.protocols.irc.maxHistoryLength())
self.ircd = None self.ircd = None
@ -424,6 +521,7 @@ class IrcState(IrcCommandDispatcher, log.Firewalled):
self.supported.clear() self.supported.clear()
self.nicksToHostmasks.clear() self.nicksToHostmasks.clear()
self.batches = {} self.batches = {}
self.capabilities_req = set()
self.capabilities_ack = set() self.capabilities_ack = set()
self.capabilities_nak = set() self.capabilities_nak = set()
self.capabilities_ls = {} self.capabilities_ls = {}
@ -1115,6 +1213,8 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
self.sendAuthenticationMessages() self.sendAuthenticationMessages()
self.state.fsm.on_init_messages_sent()
def sendAuthenticationMessages(self): def sendAuthenticationMessages(self):
# Notes: # Notes:
# * using sendMsg instead of queueMsg because these messages cannot # * using sendMsg instead of queueMsg because these messages cannot
@ -1135,10 +1235,43 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
self.sendMsg(ircmsgs.user(self.ident, self.user)) self.sendMsg(ircmsgs.user(self.ident, self.user))
def capUpkeep(self):
self.state.fsm.expect_state([
# Normal CAP ACK / CAP NAK during cap negotiation
IrcStateFsm.States.INIT_CAP_NEGOTIATION,
# CAP ACK / CAP NAK after a CAP NEW (probably)
IrcStateFsm.States.CONNECTED,
])
capabilities_responded = (self.state.capabilities_ack |
self.state.capabilities_nak)
if not capabilities_responded <= self.state.capabilities_req:
log.error('Server responded with unrequested ACK/NAK '
'capabilities: req=%r, ack=%r, nak=%r',
self.state.capabilities_req,
self.state.capabilities_ack,
self.state.capabilities_nak)
self.driver.reconnect()
elif capabilities_responded == self.state.capabilities_req:
log.debug('Got all capabilities ACKed/NAKed')
# We got all the capabilities we asked for
if 'sasl' in self.state.capabilities_ack:
if self.state.fsm.state in [
IrcStateFsm.States.INIT_CAP_NEGOTIATION,
IrcStateFsm.States.CONNECTED]:
self._maybeStartSasl()
else:
pass # Already in the middle of a SASL auth
else:
self.endCapabilityNegociation()
else:
log.debug('Waiting for ACK/NAK of capabilities: %r',
self.state.capabilities_req - capabilities_responded)
pass # Do nothing, we'll get more
def endCapabilityNegociation(self): def endCapabilityNegociation(self):
if not self.capNegociationEnded: self.state.fsm.on_cap_end()
self.capNegociationEnded = True self.sendMsg(ircmsgs.IrcMsg(command='CAP', args=('END',)))
self.sendMsg(ircmsgs.IrcMsg(command='CAP', args=('END',)))
def sendSaslString(self, string): def sendSaslString(self, string):
for chunk in ircutils.authenticate_generator(string): for chunk in ircutils.authenticate_generator(string):
@ -1146,6 +1279,10 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
args=(chunk,))) args=(chunk,)))
def tryNextSaslMechanism(self): def tryNextSaslMechanism(self):
self.state.fsm.expect_state([
IrcStateFsm.States.INIT_SASL,
IrcStateFsm.States.CONNECTED_SASL,
])
if self.sasl_next_mechanisms: if self.sasl_next_mechanisms:
self.sasl_current_mechanism = self.sasl_next_mechanisms.pop(0) self.sasl_current_mechanism = self.sasl_next_mechanisms.pop(0)
self.sendMsg(ircmsgs.IrcMsg(command='AUTHENTICATE', self.sendMsg(ircmsgs.IrcMsg(command='AUTHENTICATE',
@ -1155,15 +1292,30 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
'aborting connection.') 'aborting connection.')
else: else:
self.sasl_current_mechanism = None self.sasl_current_mechanism = None
self.endCapabilityNegociation() self.state.fsm.on_sasl_auth_finished()
if self.state.fsm.state == IrcStateFsm.States.INIT_CAP_NEGOTIATION:
self.endCapabilityNegociation()
def filterSaslMechanisms(self, available): def _maybeStartSasl(self):
available = set(map(str.lower, available)) if not self.sasl_authenticated and \
self.sasl_next_mechanisms = [ 'sasl' in self.state.capabilities_ack:
x for x in self.sasl_next_mechanisms self.state.fsm.on_sasl_cap()
if x.lower() in available] assert 'sasl' in self.state.capabilities_ls, (
'Got "CAP ACK sasl" without receiving "CAP LS sasl" or '
'"CAP NEW sasl" first.')
s = self.state.capabilities_ls['sasl']
if s is not None:
available = set(map(str.lower, s.split(',')))
self.sasl_next_mechanisms = [
x for x in self.sasl_next_mechanisms
if x.lower() in available]
self.tryNextSaslMechanism()
def doAuthenticate(self, msg): def doAuthenticate(self, msg):
self.state.fsm.expect_state([
IrcStateFsm.States.INIT_SASL,
IrcStateFsm.States.CONNECTED_SASL,
])
if not self.authenticate_decoder: if not self.authenticate_decoder:
self.authenticate_decoder = ircutils.AuthenticateDecoder() self.authenticate_decoder = ircutils.AuthenticateDecoder()
self.authenticate_decoder.feed(msg) self.authenticate_decoder.feed(msg)
@ -1265,7 +1417,9 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
def do903(self, msg): def do903(self, msg):
log.info('%s: SASL authentication successful', self.network) log.info('%s: SASL authentication successful', self.network)
self.sasl_authenticated = True self.sasl_authenticated = True
self.endCapabilityNegociation() self.state.fsm.on_sasl_auth_finished()
if self.state.fsm.state == IrcStateFsm.States.INIT_CAP_NEGOTIATION:
self.endCapabilityNegociation()
def do904(self, msg): def do904(self, msg):
log.warning('%s: SASL authentication failed (mechanism: %s)', log.warning('%s: SASL authentication failed (mechanism: %s)',
@ -1301,10 +1455,8 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
self.network, caps) self.network, caps)
self.state.capabilities_ack.update(caps) self.state.capabilities_ack.update(caps)
if 'sasl' in caps: self.capUpkeep()
self.tryNextSaslMechanism()
else:
self.endCapabilityNegociation()
def doCapNak(self, msg): def doCapNak(self, msg):
if len(msg.args) != 3: if len(msg.args) != 3:
log.warning('Bad CAP NAK from server: %r', msg) log.warning('Bad CAP NAK from server: %r', msg)
@ -1314,7 +1466,8 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
self.state.capabilities_nak.update(caps) self.state.capabilities_nak.update(caps)
log.warning('%s: Server refused capabilities: %L', log.warning('%s: Server refused capabilities: %L',
self.network, caps) self.network, caps)
self.endCapabilityNegociation() self.capUpkeep()
def _addCapabilities(self, capstring): def _addCapabilities(self, capstring):
for item in capstring.split(): for item in capstring.split():
while item.startswith(('=', '~')): while item.startswith(('=', '~')):
@ -1324,6 +1477,7 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
self.state.capabilities_ls[cap] = value self.state.capabilities_ls[cap] = value
else: else:
self.state.capabilities_ls[item] = None self.state.capabilities_ls[item] = None
def doCapLs(self, msg): def doCapLs(self, msg):
if len(msg.args) == 4: if len(msg.args) == 4:
# Multi-line LS # Multi-line LS
@ -1333,12 +1487,14 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
self._addCapabilities(msg.args[3]) self._addCapabilities(msg.args[3])
elif len(msg.args) == 3: # End of LS elif len(msg.args) == 3: # End of LS
self._addCapabilities(msg.args[2]) self._addCapabilities(msg.args[2])
self.state.fsm.expect_state([
if 'sasl' in self.state.capabilities_ls: # Normal case:
s = self.state.capabilities_ls['sasl'] IrcStateFsm.States.INIT_CAP_NEGOTIATION,
if s is not None: # Should only happen if a plugin sends a CAP LS (which they
self.filterSaslMechanisms(set(s.split(','))) # shouldn't do):
IrcStateFsm.States.CONNECTED,
IrcStateFsm.States.CONNECTED_SASL,
])
# Normally at this point, self.state.capabilities_ack should be # Normally at this point, self.state.capabilities_ack should be
# empty; but let's just make sure we're not requesting the same # empty; but let's just make sure we're not requesting the same
# caps twice for no reason. # caps twice for no reason.
@ -1356,6 +1512,7 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
else: else:
log.warning('Bad CAP LS from server: %r', msg) log.warning('Bad CAP LS from server: %r', msg)
return return
def doCapDel(self, msg): def doCapDel(self, msg):
if len(msg.args) != 3: if len(msg.args) != 3:
log.warning('Bad CAP DEL from server: %r', msg) log.warning('Bad CAP DEL from server: %r', msg)
@ -1374,18 +1531,16 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
self.state.capabilities_ack.remove(cap) self.state.capabilities_ack.remove(cap)
except KeyError: except KeyError:
pass pass
def doCapNew(self, msg): def doCapNew(self, msg):
# Note that in theory, this method may be called at any time, even
# before CAP END (or even before the initial CAP LS).
if len(msg.args) != 3: if len(msg.args) != 3:
log.warning('Bad CAP NEW from server: %r', msg) log.warning('Bad CAP NEW from server: %r', msg)
return return
caps = msg.args[2].split() caps = msg.args[2].split()
assert caps, 'Empty list of capabilities' assert caps, 'Empty list of capabilities'
self._addCapabilities(msg.args[2]) self._addCapabilities(msg.args[2])
if not self.sasl_authenticated and 'sasl' in self.state.capabilities_ls:
self.resetSasl()
s = self.state.capabilities_ls['sasl']
if s is not None:
self.filterSaslMechanisms(set(s.split(',')))
common_supported_unrequested_capabilities = ( common_supported_unrequested_capabilities = (
set(self.state.capabilities_ls) & set(self.state.capabilities_ls) &
self.REQUEST_CAPABILITIES - self.REQUEST_CAPABILITIES -
@ -1394,6 +1549,8 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
self._requestCaps(common_supported_unrequested_capabilities) self._requestCaps(common_supported_unrequested_capabilities)
def _requestCaps(self, caps): def _requestCaps(self, caps):
self.state.capabilities_req |= caps
caps = ' '.join(sorted(caps)) caps = ' '.join(sorted(caps))
# textwrap works here because in ASCII, all chars are 1 bytes: # textwrap works here because in ASCII, all chars are 1 bytes:
cap_lines = textwrap.wrap(caps, MAX_LINE_SIZE-len('CAP REQ :')) cap_lines = textwrap.wrap(caps, MAX_LINE_SIZE-len('CAP REQ :'))
@ -1474,6 +1631,7 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
self.outstandingPing = False self.outstandingPing = False
def do376(self, msg): def do376(self, msg):
self.state.fsm.on_end_motd()
log.info('Got end of MOTD from %s', self.server) log.info('Got end of MOTD from %s', self.server)
self.afterConnect = True self.afterConnect = True
# Let's reset nicks in case we had to use a weird one. # Let's reset nicks in case we had to use a weird one.

View File

@ -832,6 +832,8 @@ class SaslTestCase(SupyTestCase):
while self.irc.takeMsg(): while self.irc.takeMsg():
pass pass
self.irc.feedMsg(ircmsgs.IrcMsg(command='422')) # ERR_NOMOTD
self.irc.feedMsg(ircmsgs.IrcMsg(command='CAP', self.irc.feedMsg(ircmsgs.IrcMsg(command='CAP',
args=('*', 'NEW', 'sasl=EXTERNAL'))) args=('*', 'NEW', 'sasl=EXTERNAL')))