Add postTransition method to IrcCallback, called when irc.state.fsm changes.

This commit is contained in:
Valentin Lorentz 2020-05-01 20:19:53 +02:00
parent f7130f2629
commit 309fc1233b

View File

@ -122,6 +122,7 @@ class IrcCallback(IrcCommandDispatcher, log.Firewalled):
'__call__': None, '__call__': None,
'inFilter': lambda self, irc, msg: msg, 'inFilter': lambda self, irc, msg: msg,
'outFilter': lambda self, irc, msg: msg, 'outFilter': lambda self, irc, msg: msg,
'postTransition': None,
'name': lambda self: self.__class__.__name__, 'name': lambda self: self.__class__.__name__,
'callPrecedence': lambda self, irc: ([], []), 'callPrecedence': lambda self, irc: ([], []),
} }
@ -172,6 +173,12 @@ class IrcCallback(IrcCommandDispatcher, log.Firewalled):
""" """
return msg return msg
def postTransition(self, irc, msg, from_state, to_state):
"""Called when the state of the IRC connection changes.
`msg` is the message that triggered the transition, if any."""
pass
def __call__(self, irc, msg): def __call__(self, irc, msg):
"""Used for handling each message.""" """Used for handling each message."""
method = self.dispatchCommand(msg.command, msg.args) method = self.dispatchCommand(msg.command, msg.args)
@ -430,10 +437,24 @@ class IrcStateFsm(object):
self.state, self.States.UNINITIALIZED) self.state, self.States.UNINITIALIZED)
self.state = self.States.UNINITIALIZED self.state = self.States.UNINITIALIZED
def _transition(self, to_state, expected_from=None): def _transition(self, irc, msg, to_state, expected_from=None):
if expected_from is None or self.state in expected_from: """Transitions to state `to_state`.
If `expected_from` is not `None`, first checks the current state is
in the set.
After the transition, calls the
`postTransition(irc, msg, from_state, to_state)` method of all objects
in `irc.callbacks`.
`msg` may be None if the transition isn't triggered by a message, but
`irc` may not."""
from_state = self.state
if expected_from is None or from_state in expected_from:
log.debug('transition from %s to %s', self.state, to_state) log.debug('transition from %s to %s', self.state, to_state)
self.state = to_state self.state = to_state
for callback in reversed(irc.callbacks):
msg = callback.postTransition(irc, msg, from_state, to_state)
else: else:
raise ValueError('unexpected transition to %s while in state %s' % raise ValueError('unexpected transition to %s while in state %s' %
(to_state, self.state)) (to_state, self.state))
@ -443,53 +464,53 @@ class IrcStateFsm(object):
raise ValueError(('Connection in state %s, but expected to be ' raise ValueError(('Connection in state %s, but expected to be '
'in state %s') % (self.state, expected_states)) 'in state %s') % (self.state, expected_states))
def on_init_messages_sent(self): def on_init_messages_sent(self, irc):
'''As soon as USER/NICK/CAP LS are sent''' '''As soon as USER/NICK/CAP LS are sent'''
self._transition(self.States.INIT_CAP_NEGOTIATION, [ self._transition(irc, None, self.States.INIT_CAP_NEGOTIATION, [
self.States.UNINITIALIZED, self.States.UNINITIALIZED,
]) ])
def on_sasl_cap(self): def on_sasl_cap(self, irc, msg):
'''Whenever we see the 'sasl' capability in a CAP LS response''' '''Whenever we see the 'sasl' capability in a CAP LS response'''
if self.state == self.States.INIT_CAP_NEGOTIATION: if self.state == self.States.INIT_CAP_NEGOTIATION:
self._transition(self.States.INIT_SASL) self._transition(irc, msg, self.States.INIT_SASL)
elif self.state == self.States.CONNECTED: elif self.state == self.States.CONNECTED:
self._transition(self.States.CONNECTED_SASL) self._transition(irc, msg, self.States.CONNECTED_SASL)
else: else:
raise ValueError('Got sasl cap while in state %s' % self.state) raise ValueError('Got sasl cap while in state %s' % self.state)
def on_sasl_auth_finished(self): def on_sasl_auth_finished(self, irc, msg):
'''When sasl auth either succeeded or failed.''' '''When sasl auth either succeeded or failed.'''
if self.state == self.States.INIT_SASL: if self.state == self.States.INIT_SASL:
self._transition(self.States.INIT_CAP_NEGOTIATION) self._transition(irc, msg, self.States.INIT_CAP_NEGOTIATION)
elif self.state == self.States.CONNECTED_SASL: elif self.state == self.States.CONNECTED_SASL:
self._transition(self.States.CONNECTED) self._transition(irc, msg, self.States.CONNECTED)
else: else:
raise ValueError('Finished SASL auth while in state %s' % self.state) raise ValueError('Finished SASL auth while in state %s' % self.state)
def on_cap_end(self): def on_cap_end(self, irc, msg):
'''When we send CAP END''' '''When we send CAP END'''
self._transition(self.States.INIT_WAITING_MOTD, [ self._transition(irc, msg, self.States.INIT_WAITING_MOTD, [
self.States.INIT_CAP_NEGOTIATION, self.States.INIT_CAP_NEGOTIATION,
]) ])
def on_start_motd(self): def on_start_motd(self, irc, msg):
'''On 375 (RPL_MOTDSTART)''' '''On 375 (RPL_MOTDSTART)'''
self._transition(self.States.INIT_MOTD, [ self._transition(irc, msg, self.States.INIT_MOTD, [
self.States.INIT_CAP_NEGOTIATION, self.States.INIT_CAP_NEGOTIATION,
self.States.INIT_WAITING_MOTD, self.States.INIT_WAITING_MOTD,
]) ])
def on_end_motd(self): def on_end_motd(self, irc, msg):
'''On 376 (RPL_ENDOFMOTD) or 422 (ERR_NOMOTD)''' '''On 376 (RPL_ENDOFMOTD) or 422 (ERR_NOMOTD)'''
self._transition(self.States.CONNECTED, [ self._transition(irc, msg, self.States.CONNECTED, [
self.States.INIT_CAP_NEGOTIATION, self.States.INIT_CAP_NEGOTIATION,
self.States.INIT_WAITING_MOTD, self.States.INIT_WAITING_MOTD,
self.States.INIT_MOTD self.States.INIT_MOTD
]) ])
def on_shutdown(self): def on_shutdown(self, irc, msg):
self._transition(self.States.SHUTTING_DOWN) self._transition(irc, msg, self.States.SHUTTING_DOWN)
class IrcState(IrcCommandDispatcher, log.Firewalled): class IrcState(IrcCommandDispatcher, log.Firewalled):
"""Maintains state of the Irc connection. Should also become smarter. """Maintains state of the Irc connection. Should also become smarter.
@ -1222,7 +1243,7 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
self.sendAuthenticationMessages() self.sendAuthenticationMessages()
self.state.fsm.on_init_messages_sent() self.state.fsm.on_init_messages_sent(self)
def sendAuthenticationMessages(self): def sendAuthenticationMessages(self):
# Notes: # Notes:
@ -1244,7 +1265,13 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
self.sendMsg(ircmsgs.user(self.ident, self.user)) self.sendMsg(ircmsgs.user(self.ident, self.user))
def capUpkeep(self): def capUpkeep(self, msg):
"""
Called after getting a CAP ACK/NAK to check it's consistent with what
was requested, and to end the cap negotiation when we received all the
ACK/NAKs we were waiting for.
`msg` is the message that triggered this call."""
self.state.fsm.expect_state([ self.state.fsm.expect_state([
# Normal CAP ACK / CAP NAK during cap negotiation # Normal CAP ACK / CAP NAK during cap negotiation
IrcStateFsm.States.INIT_CAP_NEGOTIATION, IrcStateFsm.States.INIT_CAP_NEGOTIATION,
@ -1268,18 +1295,18 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
if self.state.fsm.state in [ if self.state.fsm.state in [
IrcStateFsm.States.INIT_CAP_NEGOTIATION, IrcStateFsm.States.INIT_CAP_NEGOTIATION,
IrcStateFsm.States.CONNECTED]: IrcStateFsm.States.CONNECTED]:
self._maybeStartSasl() self._maybeStartSasl(msg)
else: else:
pass # Already in the middle of a SASL auth pass # Already in the middle of a SASL auth
else: else:
self.endCapabilityNegociation() self.endCapabilityNegociation(msg)
else: else:
log.debug('Waiting for ACK/NAK of capabilities: %r', log.debug('Waiting for ACK/NAK of capabilities: %r',
self.state.capabilities_req - capabilities_responded) self.state.capabilities_req - capabilities_responded)
pass # Do nothing, we'll get more pass # Do nothing, we'll get more
def endCapabilityNegociation(self): def endCapabilityNegociation(self, msg):
self.state.fsm.on_cap_end() self.state.fsm.on_cap_end(self, msg)
self.sendMsg(ircmsgs.IrcMsg(command='CAP', args=('END',))) self.sendMsg(ircmsgs.IrcMsg(command='CAP', args=('END',)))
def sendSaslString(self, string): def sendSaslString(self, string):
@ -1287,7 +1314,7 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
self.sendMsg(ircmsgs.IrcMsg(command='AUTHENTICATE', self.sendMsg(ircmsgs.IrcMsg(command='AUTHENTICATE',
args=(chunk,))) args=(chunk,)))
def tryNextSaslMechanism(self): def tryNextSaslMechanism(self, msg):
self.state.fsm.expect_state([ self.state.fsm.expect_state([
IrcStateFsm.States.INIT_SASL, IrcStateFsm.States.INIT_SASL,
IrcStateFsm.States.CONNECTED_SASL, IrcStateFsm.States.CONNECTED_SASL,
@ -1301,14 +1328,14 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
'aborting connection.') 'aborting connection.')
else: else:
self.sasl_current_mechanism = None self.sasl_current_mechanism = None
self.state.fsm.on_sasl_auth_finished() self.state.fsm.on_sasl_auth_finished(self, msg)
if self.state.fsm.state == IrcStateFsm.States.INIT_CAP_NEGOTIATION: if self.state.fsm.state == IrcStateFsm.States.INIT_CAP_NEGOTIATION:
self.endCapabilityNegociation() self.endCapabilityNegociation(msg)
def _maybeStartSasl(self): def _maybeStartSasl(self, msg):
if not self.sasl_authenticated and \ if not self.sasl_authenticated and \
'sasl' in self.state.capabilities_ack: 'sasl' in self.state.capabilities_ack:
self.state.fsm.on_sasl_cap() self.state.fsm.on_sasl_cap(self, msg)
assert 'sasl' in self.state.capabilities_ls, ( assert 'sasl' in self.state.capabilities_ls, (
'Got "CAP ACK sasl" without receiving "CAP LS sasl" or ' 'Got "CAP ACK sasl" without receiving "CAP LS sasl" or '
'"CAP NEW sasl" first.') '"CAP NEW sasl" first.')
@ -1318,7 +1345,7 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
self.sasl_next_mechanisms = [ self.sasl_next_mechanisms = [
x for x in self.sasl_next_mechanisms x for x in self.sasl_next_mechanisms
if x.lower() in available] if x.lower() in available]
self.tryNextSaslMechanism() self.tryNextSaslMechanism(msg)
def doAuthenticate(self, msg): def doAuthenticate(self, msg):
self.state.fsm.expect_state([ self.state.fsm.expect_state([
@ -1426,28 +1453,28 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
def do903(self, msg): def do903(self, msg):
log.info('%s: SASL authentication successful', self.network) log.info('%s: SASL authentication successful', self.network)
self.sasl_authenticated = True self.sasl_authenticated = True
self.state.fsm.on_sasl_auth_finished() self.state.fsm.on_sasl_auth_finished(self, msg)
if self.state.fsm.state == IrcStateFsm.States.INIT_CAP_NEGOTIATION: if self.state.fsm.state == IrcStateFsm.States.INIT_CAP_NEGOTIATION:
self.endCapabilityNegociation() self.endCapabilityNegociation(msg)
def do904(self, msg): def do904(self, msg):
log.warning('%s: SASL authentication failed (mechanism: %s)', log.warning('%s: SASL authentication failed (mechanism: %s)',
self.network, self.sasl_current_mechanism) self.network, self.sasl_current_mechanism)
self.tryNextSaslMechanism() self.tryNextSaslMechanism(msg)
def do905(self, msg): def do905(self, msg):
log.warning('%s: SASL authentication failed because the username or ' log.warning('%s: SASL authentication failed because the username or '
'password is too long.', self.network) 'password is too long.', self.network)
self.tryNextSaslMechanism() self.tryNextSaslMechanism(msg)
def do906(self, msg): def do906(self, msg):
log.warning('%s: SASL authentication aborted', self.network) log.warning('%s: SASL authentication aborted', self.network)
self.tryNextSaslMechanism() self.tryNextSaslMechanism(msg)
def do907(self, msg): def do907(self, msg):
log.warning('%s: Attempted SASL authentication when we were already ' log.warning('%s: Attempted SASL authentication when we were already '
'authenticated.', self.network) 'authenticated.', self.network)
self.tryNextSaslMechanism() self.tryNextSaslMechanism(msg)
def do908(self, msg): def do908(self, msg):
log.info('%s: Supported SASL mechanisms: %s', log.info('%s: Supported SASL mechanisms: %s',
@ -1464,7 +1491,7 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
self.network, caps) self.network, caps)
self.state.capabilities_ack.update(caps) self.state.capabilities_ack.update(caps)
self.capUpkeep() self.capUpkeep(msg)
def doCapNak(self, msg): def doCapNak(self, msg):
if len(msg.args) != 3: if len(msg.args) != 3:
@ -1475,9 +1502,9 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
self.state.capabilities_nak.update(caps) self.state.capabilities_nak.update(caps)
log.warning('%s: Server refused capabilities: %L', log.warning('%s: Server refused capabilities: %L',
self.network, caps) self.network, caps)
self.capUpkeep() self.capUpkeep(msg)
def _onCapSts(self, policy): def _onCapSts(self, policy, msg):
secure_connection = self.driver.currentServer.force_tls_verification \ secure_connection = self.driver.currentServer.force_tls_verification \
or (self.driver.ssl and self.driver.anyCertValidationEnabled()) or (self.driver.ssl and self.driver.anyCertValidationEnabled())
@ -1504,19 +1531,19 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
self.driver.currentServer) self.driver.currentServer)
# Reconnect to the server, but with TLS *and* certificate # Reconnect to the server, but with TLS *and* certificate
# validation this time. # validation this time.
self.state.fsm.on_shutdown() self.state.fsm.on_shutdown(self, msg)
self.driver.reconnect( self.driver.reconnect(
server=Server(hostname, parsed_policy['port'], True), server=Server(hostname, parsed_policy['port'], True),
wait=True) wait=True)
def _addCapabilities(self, capstring): def _addCapabilities(self, capstring, msg):
for item in capstring.split(): for item in capstring.split():
while item.startswith(('=', '~')): while item.startswith(('=', '~')):
item = item[1:] item = item[1:]
if '=' in item: if '=' in item:
(cap, value) = item.split('=', 1) (cap, value) = item.split('=', 1)
if cap == 'sts': if cap == 'sts':
self._onCapSts(value) self._onCapSts(value, msg)
self.state.capabilities_ls[cap] = value self.state.capabilities_ls[cap] = value
else: else:
if item == 'sts': if item == 'sts':
@ -1532,9 +1559,9 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
if msg.args[2] != '*': if msg.args[2] != '*':
log.warning('Bad CAP LS from server: %r', msg) log.warning('Bad CAP LS from server: %r', msg)
return return
self._addCapabilities(msg.args[3]) self._addCapabilities(msg.args[3], msg)
elif len(msg.args) == 3: # End of LS elif len(msg.args) == 3: # End of LS
self._addCapabilities(msg.args[2]) self._addCapabilities(msg.args[2], msg)
if self.state.fsm.state == IrcStateFsm.States.SHUTTING_DOWN: if self.state.fsm.state == IrcStateFsm.States.SHUTTING_DOWN:
return return
self.state.fsm.expect_state([ self.state.fsm.expect_state([
@ -1558,7 +1585,7 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
if new_caps: if new_caps:
self._requestCaps(new_caps) self._requestCaps(new_caps)
else: else:
self.endCapabilityNegociation() self.endCapabilityNegociation(msg)
else: else:
log.warning('Bad CAP LS from server: %r', msg) log.warning('Bad CAP LS from server: %r', msg)
return return
@ -1590,7 +1617,7 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
return return
caps = msg.args[2].split() caps = msg.args[2].split()
assert caps, 'Empty list of capabilities' assert caps, 'Empty list of capabilities'
self._addCapabilities(msg.args[2]) self._addCapabilities(msg.args[2], msg)
if self.state.fsm.state == IrcStateFsm.States.SHUTTING_DOWN: if self.state.fsm.state == IrcStateFsm.States.SHUTTING_DOWN:
return return
common_supported_unrequested_capabilities = ( common_supported_unrequested_capabilities = (
@ -1687,7 +1714,7 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
log.info('Got start of MOTD from %s', self.server) log.info('Got start of MOTD from %s', self.server)
def do376(self, msg): def do376(self, msg):
self.state.fsm.on_end_motd() self.state.fsm.on_end_motd(self, msg)
log.info('Got end of MOTD from %s', self.server) log.info('Got end of MOTD from %s', self.server)
self.afterConnect = True self.afterConnect = True
# Let's reset nicks in case we had to use a weird one. # Let's reset nicks in case we had to use a weird one.