Add postTransition method to IrcCallback, called when irc.state.fsm changes.

This commit is contained in:
Valentin Lorentz 2020-05-01 20:19:53 +02:00
parent f7130f2629
commit 309fc1233b

View File

@ -122,6 +122,7 @@ class IrcCallback(IrcCommandDispatcher, log.Firewalled):
'__call__': None,
'inFilter': lambda self, irc, msg: msg,
'outFilter': lambda self, irc, msg: msg,
'postTransition': None,
'name': lambda self: self.__class__.__name__,
'callPrecedence': lambda self, irc: ([], []),
}
@ -172,6 +173,12 @@ class IrcCallback(IrcCommandDispatcher, log.Firewalled):
"""
return msg
def postTransition(self, irc, msg, from_state, to_state):
"""Called when the state of the IRC connection changes.
`msg` is the message that triggered the transition, if any."""
pass
def __call__(self, irc, msg):
"""Used for handling each message."""
method = self.dispatchCommand(msg.command, msg.args)
@ -430,10 +437,24 @@ class IrcStateFsm(object):
self.state, self.States.UNINITIALIZED)
self.state = self.States.UNINITIALIZED
def _transition(self, to_state, expected_from=None):
if expected_from is None or self.state in expected_from:
def _transition(self, irc, msg, to_state, expected_from=None):
"""Transitions to state `to_state`.
If `expected_from` is not `None`, first checks the current state is
in the set.
After the transition, calls the
`postTransition(irc, msg, from_state, to_state)` method of all objects
in `irc.callbacks`.
`msg` may be None if the transition isn't triggered by a message, but
`irc` may not."""
from_state = self.state
if expected_from is None or from_state in expected_from:
log.debug('transition from %s to %s', self.state, to_state)
self.state = to_state
for callback in reversed(irc.callbacks):
msg = callback.postTransition(irc, msg, from_state, to_state)
else:
raise ValueError('unexpected transition to %s while in state %s' %
(to_state, self.state))
@ -443,53 +464,53 @@ class IrcStateFsm(object):
raise ValueError(('Connection in state %s, but expected to be '
'in state %s') % (self.state, expected_states))
def on_init_messages_sent(self):
def on_init_messages_sent(self, irc):
'''As soon as USER/NICK/CAP LS are sent'''
self._transition(self.States.INIT_CAP_NEGOTIATION, [
self._transition(irc, None, self.States.INIT_CAP_NEGOTIATION, [
self.States.UNINITIALIZED,
])
def on_sasl_cap(self):
def on_sasl_cap(self, irc, msg):
'''Whenever we see the 'sasl' capability in a CAP LS response'''
if self.state == self.States.INIT_CAP_NEGOTIATION:
self._transition(self.States.INIT_SASL)
self._transition(irc, msg, self.States.INIT_SASL)
elif self.state == self.States.CONNECTED:
self._transition(self.States.CONNECTED_SASL)
self._transition(irc, msg, self.States.CONNECTED_SASL)
else:
raise ValueError('Got sasl cap while in state %s' % self.state)
def on_sasl_auth_finished(self):
def on_sasl_auth_finished(self, irc, msg):
'''When sasl auth either succeeded or failed.'''
if self.state == self.States.INIT_SASL:
self._transition(self.States.INIT_CAP_NEGOTIATION)
self._transition(irc, msg, self.States.INIT_CAP_NEGOTIATION)
elif self.state == self.States.CONNECTED_SASL:
self._transition(self.States.CONNECTED)
self._transition(irc, msg, self.States.CONNECTED)
else:
raise ValueError('Finished SASL auth while in state %s' % self.state)
def on_cap_end(self):
def on_cap_end(self, irc, msg):
'''When we send CAP END'''
self._transition(self.States.INIT_WAITING_MOTD, [
self._transition(irc, msg, self.States.INIT_WAITING_MOTD, [
self.States.INIT_CAP_NEGOTIATION,
])
def on_start_motd(self):
def on_start_motd(self, irc, msg):
'''On 375 (RPL_MOTDSTART)'''
self._transition(self.States.INIT_MOTD, [
self._transition(irc, msg, self.States.INIT_MOTD, [
self.States.INIT_CAP_NEGOTIATION,
self.States.INIT_WAITING_MOTD,
])
def on_end_motd(self):
def on_end_motd(self, irc, msg):
'''On 376 (RPL_ENDOFMOTD) or 422 (ERR_NOMOTD)'''
self._transition(self.States.CONNECTED, [
self._transition(irc, msg, self.States.CONNECTED, [
self.States.INIT_CAP_NEGOTIATION,
self.States.INIT_WAITING_MOTD,
self.States.INIT_MOTD
])
def on_shutdown(self):
self._transition(self.States.SHUTTING_DOWN)
def on_shutdown(self, irc, msg):
self._transition(irc, msg, self.States.SHUTTING_DOWN)
class IrcState(IrcCommandDispatcher, log.Firewalled):
"""Maintains state of the Irc connection. Should also become smarter.
@ -1222,7 +1243,7 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
self.sendAuthenticationMessages()
self.state.fsm.on_init_messages_sent()
self.state.fsm.on_init_messages_sent(self)
def sendAuthenticationMessages(self):
# Notes:
@ -1244,7 +1265,13 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
self.sendMsg(ircmsgs.user(self.ident, self.user))
def capUpkeep(self):
def capUpkeep(self, msg):
"""
Called after getting a CAP ACK/NAK to check it's consistent with what
was requested, and to end the cap negotiation when we received all the
ACK/NAKs we were waiting for.
`msg` is the message that triggered this call."""
self.state.fsm.expect_state([
# Normal CAP ACK / CAP NAK during cap negotiation
IrcStateFsm.States.INIT_CAP_NEGOTIATION,
@ -1268,18 +1295,18 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
if self.state.fsm.state in [
IrcStateFsm.States.INIT_CAP_NEGOTIATION,
IrcStateFsm.States.CONNECTED]:
self._maybeStartSasl()
self._maybeStartSasl(msg)
else:
pass # Already in the middle of a SASL auth
else:
self.endCapabilityNegociation()
self.endCapabilityNegociation(msg)
else:
log.debug('Waiting for ACK/NAK of capabilities: %r',
self.state.capabilities_req - capabilities_responded)
pass # Do nothing, we'll get more
def endCapabilityNegociation(self):
self.state.fsm.on_cap_end()
def endCapabilityNegociation(self, msg):
self.state.fsm.on_cap_end(self, msg)
self.sendMsg(ircmsgs.IrcMsg(command='CAP', args=('END',)))
def sendSaslString(self, string):
@ -1287,7 +1314,7 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
self.sendMsg(ircmsgs.IrcMsg(command='AUTHENTICATE',
args=(chunk,)))
def tryNextSaslMechanism(self):
def tryNextSaslMechanism(self, msg):
self.state.fsm.expect_state([
IrcStateFsm.States.INIT_SASL,
IrcStateFsm.States.CONNECTED_SASL,
@ -1301,14 +1328,14 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
'aborting connection.')
else:
self.sasl_current_mechanism = None
self.state.fsm.on_sasl_auth_finished()
self.state.fsm.on_sasl_auth_finished(self, msg)
if self.state.fsm.state == IrcStateFsm.States.INIT_CAP_NEGOTIATION:
self.endCapabilityNegociation()
self.endCapabilityNegociation(msg)
def _maybeStartSasl(self):
def _maybeStartSasl(self, msg):
if not self.sasl_authenticated and \
'sasl' in self.state.capabilities_ack:
self.state.fsm.on_sasl_cap()
self.state.fsm.on_sasl_cap(self, msg)
assert 'sasl' in self.state.capabilities_ls, (
'Got "CAP ACK sasl" without receiving "CAP LS sasl" or '
'"CAP NEW sasl" first.')
@ -1318,7 +1345,7 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
self.sasl_next_mechanisms = [
x for x in self.sasl_next_mechanisms
if x.lower() in available]
self.tryNextSaslMechanism()
self.tryNextSaslMechanism(msg)
def doAuthenticate(self, msg):
self.state.fsm.expect_state([
@ -1426,28 +1453,28 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
def do903(self, msg):
log.info('%s: SASL authentication successful', self.network)
self.sasl_authenticated = True
self.state.fsm.on_sasl_auth_finished()
self.state.fsm.on_sasl_auth_finished(self, msg)
if self.state.fsm.state == IrcStateFsm.States.INIT_CAP_NEGOTIATION:
self.endCapabilityNegociation()
self.endCapabilityNegociation(msg)
def do904(self, msg):
log.warning('%s: SASL authentication failed (mechanism: %s)',
self.network, self.sasl_current_mechanism)
self.tryNextSaslMechanism()
self.tryNextSaslMechanism(msg)
def do905(self, msg):
log.warning('%s: SASL authentication failed because the username or '
'password is too long.', self.network)
self.tryNextSaslMechanism()
self.tryNextSaslMechanism(msg)
def do906(self, msg):
log.warning('%s: SASL authentication aborted', self.network)
self.tryNextSaslMechanism()
self.tryNextSaslMechanism(msg)
def do907(self, msg):
log.warning('%s: Attempted SASL authentication when we were already '
'authenticated.', self.network)
self.tryNextSaslMechanism()
self.tryNextSaslMechanism(msg)
def do908(self, msg):
log.info('%s: Supported SASL mechanisms: %s',
@ -1464,7 +1491,7 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
self.network, caps)
self.state.capabilities_ack.update(caps)
self.capUpkeep()
self.capUpkeep(msg)
def doCapNak(self, msg):
if len(msg.args) != 3:
@ -1475,9 +1502,9 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
self.state.capabilities_nak.update(caps)
log.warning('%s: Server refused capabilities: %L',
self.network, caps)
self.capUpkeep()
self.capUpkeep(msg)
def _onCapSts(self, policy):
def _onCapSts(self, policy, msg):
secure_connection = self.driver.currentServer.force_tls_verification \
or (self.driver.ssl and self.driver.anyCertValidationEnabled())
@ -1504,19 +1531,19 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
self.driver.currentServer)
# Reconnect to the server, but with TLS *and* certificate
# validation this time.
self.state.fsm.on_shutdown()
self.state.fsm.on_shutdown(self, msg)
self.driver.reconnect(
server=Server(hostname, parsed_policy['port'], True),
wait=True)
def _addCapabilities(self, capstring):
def _addCapabilities(self, capstring, msg):
for item in capstring.split():
while item.startswith(('=', '~')):
item = item[1:]
if '=' in item:
(cap, value) = item.split('=', 1)
if cap == 'sts':
self._onCapSts(value)
self._onCapSts(value, msg)
self.state.capabilities_ls[cap] = value
else:
if item == 'sts':
@ -1532,9 +1559,9 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
if msg.args[2] != '*':
log.warning('Bad CAP LS from server: %r', msg)
return
self._addCapabilities(msg.args[3])
self._addCapabilities(msg.args[3], msg)
elif len(msg.args) == 3: # End of LS
self._addCapabilities(msg.args[2])
self._addCapabilities(msg.args[2], msg)
if self.state.fsm.state == IrcStateFsm.States.SHUTTING_DOWN:
return
self.state.fsm.expect_state([
@ -1558,7 +1585,7 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
if new_caps:
self._requestCaps(new_caps)
else:
self.endCapabilityNegociation()
self.endCapabilityNegociation(msg)
else:
log.warning('Bad CAP LS from server: %r', msg)
return
@ -1590,7 +1617,7 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
return
caps = msg.args[2].split()
assert caps, 'Empty list of capabilities'
self._addCapabilities(msg.args[2])
self._addCapabilities(msg.args[2], msg)
if self.state.fsm.state == IrcStateFsm.States.SHUTTING_DOWN:
return
common_supported_unrequested_capabilities = (
@ -1687,7 +1714,7 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
log.info('Got start of MOTD from %s', self.server)
def do376(self, msg):
self.state.fsm.on_end_motd()
self.state.fsm.on_end_motd(self, msg)
log.info('Got end of MOTD from %s', self.server)
self.afterConnect = True
# Let's reset nicks in case we had to use a weird one.