mirror of
https://github.com/Mikaela/Limnoria.git
synced 2024-12-25 04:02:46 +01:00
Add postTransition method to IrcCallback, called when irc.state.fsm changes.
This commit is contained in:
parent
f7130f2629
commit
309fc1233b
121
src/irclib.py
121
src/irclib.py
@ -122,6 +122,7 @@ class IrcCallback(IrcCommandDispatcher, log.Firewalled):
|
||||
'__call__': None,
|
||||
'inFilter': lambda self, irc, msg: msg,
|
||||
'outFilter': lambda self, irc, msg: msg,
|
||||
'postTransition': None,
|
||||
'name': lambda self: self.__class__.__name__,
|
||||
'callPrecedence': lambda self, irc: ([], []),
|
||||
}
|
||||
@ -172,6 +173,12 @@ class IrcCallback(IrcCommandDispatcher, log.Firewalled):
|
||||
"""
|
||||
return msg
|
||||
|
||||
def postTransition(self, irc, msg, from_state, to_state):
|
||||
"""Called when the state of the IRC connection changes.
|
||||
|
||||
`msg` is the message that triggered the transition, if any."""
|
||||
pass
|
||||
|
||||
def __call__(self, irc, msg):
|
||||
"""Used for handling each message."""
|
||||
method = self.dispatchCommand(msg.command, msg.args)
|
||||
@ -430,10 +437,24 @@ class IrcStateFsm(object):
|
||||
self.state, self.States.UNINITIALIZED)
|
||||
self.state = self.States.UNINITIALIZED
|
||||
|
||||
def _transition(self, to_state, expected_from=None):
|
||||
if expected_from is None or self.state in expected_from:
|
||||
def _transition(self, irc, msg, to_state, expected_from=None):
|
||||
"""Transitions to state `to_state`.
|
||||
|
||||
If `expected_from` is not `None`, first checks the current state is
|
||||
in the set.
|
||||
|
||||
After the transition, calls the
|
||||
`postTransition(irc, msg, from_state, to_state)` method of all objects
|
||||
in `irc.callbacks`.
|
||||
|
||||
`msg` may be None if the transition isn't triggered by a message, but
|
||||
`irc` may not."""
|
||||
from_state = self.state
|
||||
if expected_from is None or from_state in expected_from:
|
||||
log.debug('transition from %s to %s', self.state, to_state)
|
||||
self.state = to_state
|
||||
for callback in reversed(irc.callbacks):
|
||||
msg = callback.postTransition(irc, msg, from_state, to_state)
|
||||
else:
|
||||
raise ValueError('unexpected transition to %s while in state %s' %
|
||||
(to_state, self.state))
|
||||
@ -443,53 +464,53 @@ class IrcStateFsm(object):
|
||||
raise ValueError(('Connection in state %s, but expected to be '
|
||||
'in state %s') % (self.state, expected_states))
|
||||
|
||||
def on_init_messages_sent(self):
|
||||
def on_init_messages_sent(self, irc):
|
||||
'''As soon as USER/NICK/CAP LS are sent'''
|
||||
self._transition(self.States.INIT_CAP_NEGOTIATION, [
|
||||
self._transition(irc, None, self.States.INIT_CAP_NEGOTIATION, [
|
||||
self.States.UNINITIALIZED,
|
||||
])
|
||||
|
||||
def on_sasl_cap(self):
|
||||
def on_sasl_cap(self, irc, msg):
|
||||
'''Whenever we see the 'sasl' capability in a CAP LS response'''
|
||||
if self.state == self.States.INIT_CAP_NEGOTIATION:
|
||||
self._transition(self.States.INIT_SASL)
|
||||
self._transition(irc, msg, self.States.INIT_SASL)
|
||||
elif self.state == self.States.CONNECTED:
|
||||
self._transition(self.States.CONNECTED_SASL)
|
||||
self._transition(irc, msg, self.States.CONNECTED_SASL)
|
||||
else:
|
||||
raise ValueError('Got sasl cap while in state %s' % self.state)
|
||||
|
||||
def on_sasl_auth_finished(self):
|
||||
def on_sasl_auth_finished(self, irc, msg):
|
||||
'''When sasl auth either succeeded or failed.'''
|
||||
if self.state == self.States.INIT_SASL:
|
||||
self._transition(self.States.INIT_CAP_NEGOTIATION)
|
||||
self._transition(irc, msg, self.States.INIT_CAP_NEGOTIATION)
|
||||
elif self.state == self.States.CONNECTED_SASL:
|
||||
self._transition(self.States.CONNECTED)
|
||||
self._transition(irc, msg, self.States.CONNECTED)
|
||||
else:
|
||||
raise ValueError('Finished SASL auth while in state %s' % self.state)
|
||||
|
||||
def on_cap_end(self):
|
||||
def on_cap_end(self, irc, msg):
|
||||
'''When we send CAP END'''
|
||||
self._transition(self.States.INIT_WAITING_MOTD, [
|
||||
self._transition(irc, msg, self.States.INIT_WAITING_MOTD, [
|
||||
self.States.INIT_CAP_NEGOTIATION,
|
||||
])
|
||||
|
||||
def on_start_motd(self):
|
||||
def on_start_motd(self, irc, msg):
|
||||
'''On 375 (RPL_MOTDSTART)'''
|
||||
self._transition(self.States.INIT_MOTD, [
|
||||
self._transition(irc, msg, self.States.INIT_MOTD, [
|
||||
self.States.INIT_CAP_NEGOTIATION,
|
||||
self.States.INIT_WAITING_MOTD,
|
||||
])
|
||||
|
||||
def on_end_motd(self):
|
||||
def on_end_motd(self, irc, msg):
|
||||
'''On 376 (RPL_ENDOFMOTD) or 422 (ERR_NOMOTD)'''
|
||||
self._transition(self.States.CONNECTED, [
|
||||
self._transition(irc, msg, self.States.CONNECTED, [
|
||||
self.States.INIT_CAP_NEGOTIATION,
|
||||
self.States.INIT_WAITING_MOTD,
|
||||
self.States.INIT_MOTD
|
||||
])
|
||||
|
||||
def on_shutdown(self):
|
||||
self._transition(self.States.SHUTTING_DOWN)
|
||||
def on_shutdown(self, irc, msg):
|
||||
self._transition(irc, msg, self.States.SHUTTING_DOWN)
|
||||
|
||||
class IrcState(IrcCommandDispatcher, log.Firewalled):
|
||||
"""Maintains state of the Irc connection. Should also become smarter.
|
||||
@ -1222,7 +1243,7 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
|
||||
|
||||
self.sendAuthenticationMessages()
|
||||
|
||||
self.state.fsm.on_init_messages_sent()
|
||||
self.state.fsm.on_init_messages_sent(self)
|
||||
|
||||
def sendAuthenticationMessages(self):
|
||||
# Notes:
|
||||
@ -1244,7 +1265,13 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
|
||||
|
||||
self.sendMsg(ircmsgs.user(self.ident, self.user))
|
||||
|
||||
def capUpkeep(self):
|
||||
def capUpkeep(self, msg):
|
||||
"""
|
||||
Called after getting a CAP ACK/NAK to check it's consistent with what
|
||||
was requested, and to end the cap negotiation when we received all the
|
||||
ACK/NAKs we were waiting for.
|
||||
|
||||
`msg` is the message that triggered this call."""
|
||||
self.state.fsm.expect_state([
|
||||
# Normal CAP ACK / CAP NAK during cap negotiation
|
||||
IrcStateFsm.States.INIT_CAP_NEGOTIATION,
|
||||
@ -1268,18 +1295,18 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
|
||||
if self.state.fsm.state in [
|
||||
IrcStateFsm.States.INIT_CAP_NEGOTIATION,
|
||||
IrcStateFsm.States.CONNECTED]:
|
||||
self._maybeStartSasl()
|
||||
self._maybeStartSasl(msg)
|
||||
else:
|
||||
pass # Already in the middle of a SASL auth
|
||||
else:
|
||||
self.endCapabilityNegociation()
|
||||
self.endCapabilityNegociation(msg)
|
||||
else:
|
||||
log.debug('Waiting for ACK/NAK of capabilities: %r',
|
||||
self.state.capabilities_req - capabilities_responded)
|
||||
pass # Do nothing, we'll get more
|
||||
|
||||
def endCapabilityNegociation(self):
|
||||
self.state.fsm.on_cap_end()
|
||||
def endCapabilityNegociation(self, msg):
|
||||
self.state.fsm.on_cap_end(self, msg)
|
||||
self.sendMsg(ircmsgs.IrcMsg(command='CAP', args=('END',)))
|
||||
|
||||
def sendSaslString(self, string):
|
||||
@ -1287,7 +1314,7 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
|
||||
self.sendMsg(ircmsgs.IrcMsg(command='AUTHENTICATE',
|
||||
args=(chunk,)))
|
||||
|
||||
def tryNextSaslMechanism(self):
|
||||
def tryNextSaslMechanism(self, msg):
|
||||
self.state.fsm.expect_state([
|
||||
IrcStateFsm.States.INIT_SASL,
|
||||
IrcStateFsm.States.CONNECTED_SASL,
|
||||
@ -1301,14 +1328,14 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
|
||||
'aborting connection.')
|
||||
else:
|
||||
self.sasl_current_mechanism = None
|
||||
self.state.fsm.on_sasl_auth_finished()
|
||||
self.state.fsm.on_sasl_auth_finished(self, msg)
|
||||
if self.state.fsm.state == IrcStateFsm.States.INIT_CAP_NEGOTIATION:
|
||||
self.endCapabilityNegociation()
|
||||
self.endCapabilityNegociation(msg)
|
||||
|
||||
def _maybeStartSasl(self):
|
||||
def _maybeStartSasl(self, msg):
|
||||
if not self.sasl_authenticated and \
|
||||
'sasl' in self.state.capabilities_ack:
|
||||
self.state.fsm.on_sasl_cap()
|
||||
self.state.fsm.on_sasl_cap(self, msg)
|
||||
assert 'sasl' in self.state.capabilities_ls, (
|
||||
'Got "CAP ACK sasl" without receiving "CAP LS sasl" or '
|
||||
'"CAP NEW sasl" first.')
|
||||
@ -1318,7 +1345,7 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
|
||||
self.sasl_next_mechanisms = [
|
||||
x for x in self.sasl_next_mechanisms
|
||||
if x.lower() in available]
|
||||
self.tryNextSaslMechanism()
|
||||
self.tryNextSaslMechanism(msg)
|
||||
|
||||
def doAuthenticate(self, msg):
|
||||
self.state.fsm.expect_state([
|
||||
@ -1426,28 +1453,28 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
|
||||
def do903(self, msg):
|
||||
log.info('%s: SASL authentication successful', self.network)
|
||||
self.sasl_authenticated = True
|
||||
self.state.fsm.on_sasl_auth_finished()
|
||||
self.state.fsm.on_sasl_auth_finished(self, msg)
|
||||
if self.state.fsm.state == IrcStateFsm.States.INIT_CAP_NEGOTIATION:
|
||||
self.endCapabilityNegociation()
|
||||
self.endCapabilityNegociation(msg)
|
||||
|
||||
def do904(self, msg):
|
||||
log.warning('%s: SASL authentication failed (mechanism: %s)',
|
||||
self.network, self.sasl_current_mechanism)
|
||||
self.tryNextSaslMechanism()
|
||||
self.tryNextSaslMechanism(msg)
|
||||
|
||||
def do905(self, msg):
|
||||
log.warning('%s: SASL authentication failed because the username or '
|
||||
'password is too long.', self.network)
|
||||
self.tryNextSaslMechanism()
|
||||
self.tryNextSaslMechanism(msg)
|
||||
|
||||
def do906(self, msg):
|
||||
log.warning('%s: SASL authentication aborted', self.network)
|
||||
self.tryNextSaslMechanism()
|
||||
self.tryNextSaslMechanism(msg)
|
||||
|
||||
def do907(self, msg):
|
||||
log.warning('%s: Attempted SASL authentication when we were already '
|
||||
'authenticated.', self.network)
|
||||
self.tryNextSaslMechanism()
|
||||
self.tryNextSaslMechanism(msg)
|
||||
|
||||
def do908(self, msg):
|
||||
log.info('%s: Supported SASL mechanisms: %s',
|
||||
@ -1464,7 +1491,7 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
|
||||
self.network, caps)
|
||||
self.state.capabilities_ack.update(caps)
|
||||
|
||||
self.capUpkeep()
|
||||
self.capUpkeep(msg)
|
||||
|
||||
def doCapNak(self, msg):
|
||||
if len(msg.args) != 3:
|
||||
@ -1475,9 +1502,9 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
|
||||
self.state.capabilities_nak.update(caps)
|
||||
log.warning('%s: Server refused capabilities: %L',
|
||||
self.network, caps)
|
||||
self.capUpkeep()
|
||||
self.capUpkeep(msg)
|
||||
|
||||
def _onCapSts(self, policy):
|
||||
def _onCapSts(self, policy, msg):
|
||||
secure_connection = self.driver.currentServer.force_tls_verification \
|
||||
or (self.driver.ssl and self.driver.anyCertValidationEnabled())
|
||||
|
||||
@ -1504,19 +1531,19 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
|
||||
self.driver.currentServer)
|
||||
# Reconnect to the server, but with TLS *and* certificate
|
||||
# validation this time.
|
||||
self.state.fsm.on_shutdown()
|
||||
self.state.fsm.on_shutdown(self, msg)
|
||||
self.driver.reconnect(
|
||||
server=Server(hostname, parsed_policy['port'], True),
|
||||
wait=True)
|
||||
|
||||
def _addCapabilities(self, capstring):
|
||||
def _addCapabilities(self, capstring, msg):
|
||||
for item in capstring.split():
|
||||
while item.startswith(('=', '~')):
|
||||
item = item[1:]
|
||||
if '=' in item:
|
||||
(cap, value) = item.split('=', 1)
|
||||
if cap == 'sts':
|
||||
self._onCapSts(value)
|
||||
self._onCapSts(value, msg)
|
||||
self.state.capabilities_ls[cap] = value
|
||||
else:
|
||||
if item == 'sts':
|
||||
@ -1532,9 +1559,9 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
|
||||
if msg.args[2] != '*':
|
||||
log.warning('Bad CAP LS from server: %r', msg)
|
||||
return
|
||||
self._addCapabilities(msg.args[3])
|
||||
self._addCapabilities(msg.args[3], msg)
|
||||
elif len(msg.args) == 3: # End of LS
|
||||
self._addCapabilities(msg.args[2])
|
||||
self._addCapabilities(msg.args[2], msg)
|
||||
if self.state.fsm.state == IrcStateFsm.States.SHUTTING_DOWN:
|
||||
return
|
||||
self.state.fsm.expect_state([
|
||||
@ -1558,7 +1585,7 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
|
||||
if new_caps:
|
||||
self._requestCaps(new_caps)
|
||||
else:
|
||||
self.endCapabilityNegociation()
|
||||
self.endCapabilityNegociation(msg)
|
||||
else:
|
||||
log.warning('Bad CAP LS from server: %r', msg)
|
||||
return
|
||||
@ -1590,7 +1617,7 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
|
||||
return
|
||||
caps = msg.args[2].split()
|
||||
assert caps, 'Empty list of capabilities'
|
||||
self._addCapabilities(msg.args[2])
|
||||
self._addCapabilities(msg.args[2], msg)
|
||||
if self.state.fsm.state == IrcStateFsm.States.SHUTTING_DOWN:
|
||||
return
|
||||
common_supported_unrequested_capabilities = (
|
||||
@ -1687,7 +1714,7 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
|
||||
log.info('Got start of MOTD from %s', self.server)
|
||||
|
||||
def do376(self, msg):
|
||||
self.state.fsm.on_end_motd()
|
||||
self.state.fsm.on_end_motd(self, msg)
|
||||
log.info('Got end of MOTD from %s', self.server)
|
||||
self.afterConnect = True
|
||||
# Let's reset nicks in case we had to use a weird one.
|
||||
|
Loading…
Reference in New Issue
Block a user