mirror of
https://github.com/Mikaela/Limnoria.git
synced 2025-11-13 14:17:21 +01:00
Add postTransition method to IrcCallback, called when irc.state.fsm changes.
This commit is contained in:
parent
f7130f2629
commit
309fc1233b
121
src/irclib.py
121
src/irclib.py
@ -122,6 +122,7 @@ class IrcCallback(IrcCommandDispatcher, log.Firewalled):
|
|||||||
'__call__': None,
|
'__call__': None,
|
||||||
'inFilter': lambda self, irc, msg: msg,
|
'inFilter': lambda self, irc, msg: msg,
|
||||||
'outFilter': lambda self, irc, msg: msg,
|
'outFilter': lambda self, irc, msg: msg,
|
||||||
|
'postTransition': None,
|
||||||
'name': lambda self: self.__class__.__name__,
|
'name': lambda self: self.__class__.__name__,
|
||||||
'callPrecedence': lambda self, irc: ([], []),
|
'callPrecedence': lambda self, irc: ([], []),
|
||||||
}
|
}
|
||||||
@ -172,6 +173,12 @@ class IrcCallback(IrcCommandDispatcher, log.Firewalled):
|
|||||||
"""
|
"""
|
||||||
return msg
|
return msg
|
||||||
|
|
||||||
|
def postTransition(self, irc, msg, from_state, to_state):
|
||||||
|
"""Called when the state of the IRC connection changes.
|
||||||
|
|
||||||
|
`msg` is the message that triggered the transition, if any."""
|
||||||
|
pass
|
||||||
|
|
||||||
def __call__(self, irc, msg):
|
def __call__(self, irc, msg):
|
||||||
"""Used for handling each message."""
|
"""Used for handling each message."""
|
||||||
method = self.dispatchCommand(msg.command, msg.args)
|
method = self.dispatchCommand(msg.command, msg.args)
|
||||||
@ -430,10 +437,24 @@ class IrcStateFsm(object):
|
|||||||
self.state, self.States.UNINITIALIZED)
|
self.state, self.States.UNINITIALIZED)
|
||||||
self.state = self.States.UNINITIALIZED
|
self.state = self.States.UNINITIALIZED
|
||||||
|
|
||||||
def _transition(self, to_state, expected_from=None):
|
def _transition(self, irc, msg, to_state, expected_from=None):
|
||||||
if expected_from is None or self.state in expected_from:
|
"""Transitions to state `to_state`.
|
||||||
|
|
||||||
|
If `expected_from` is not `None`, first checks the current state is
|
||||||
|
in the set.
|
||||||
|
|
||||||
|
After the transition, calls the
|
||||||
|
`postTransition(irc, msg, from_state, to_state)` method of all objects
|
||||||
|
in `irc.callbacks`.
|
||||||
|
|
||||||
|
`msg` may be None if the transition isn't triggered by a message, but
|
||||||
|
`irc` may not."""
|
||||||
|
from_state = self.state
|
||||||
|
if expected_from is None or from_state in expected_from:
|
||||||
log.debug('transition from %s to %s', self.state, to_state)
|
log.debug('transition from %s to %s', self.state, to_state)
|
||||||
self.state = to_state
|
self.state = to_state
|
||||||
|
for callback in reversed(irc.callbacks):
|
||||||
|
msg = callback.postTransition(irc, msg, from_state, to_state)
|
||||||
else:
|
else:
|
||||||
raise ValueError('unexpected transition to %s while in state %s' %
|
raise ValueError('unexpected transition to %s while in state %s' %
|
||||||
(to_state, self.state))
|
(to_state, self.state))
|
||||||
@ -443,53 +464,53 @@ class IrcStateFsm(object):
|
|||||||
raise ValueError(('Connection in state %s, but expected to be '
|
raise ValueError(('Connection in state %s, but expected to be '
|
||||||
'in state %s') % (self.state, expected_states))
|
'in state %s') % (self.state, expected_states))
|
||||||
|
|
||||||
def on_init_messages_sent(self):
|
def on_init_messages_sent(self, irc):
|
||||||
'''As soon as USER/NICK/CAP LS are sent'''
|
'''As soon as USER/NICK/CAP LS are sent'''
|
||||||
self._transition(self.States.INIT_CAP_NEGOTIATION, [
|
self._transition(irc, None, self.States.INIT_CAP_NEGOTIATION, [
|
||||||
self.States.UNINITIALIZED,
|
self.States.UNINITIALIZED,
|
||||||
])
|
])
|
||||||
|
|
||||||
def on_sasl_cap(self):
|
def on_sasl_cap(self, irc, msg):
|
||||||
'''Whenever we see the 'sasl' capability in a CAP LS response'''
|
'''Whenever we see the 'sasl' capability in a CAP LS response'''
|
||||||
if self.state == self.States.INIT_CAP_NEGOTIATION:
|
if self.state == self.States.INIT_CAP_NEGOTIATION:
|
||||||
self._transition(self.States.INIT_SASL)
|
self._transition(irc, msg, self.States.INIT_SASL)
|
||||||
elif self.state == self.States.CONNECTED:
|
elif self.state == self.States.CONNECTED:
|
||||||
self._transition(self.States.CONNECTED_SASL)
|
self._transition(irc, msg, self.States.CONNECTED_SASL)
|
||||||
else:
|
else:
|
||||||
raise ValueError('Got sasl cap while in state %s' % self.state)
|
raise ValueError('Got sasl cap while in state %s' % self.state)
|
||||||
|
|
||||||
def on_sasl_auth_finished(self):
|
def on_sasl_auth_finished(self, irc, msg):
|
||||||
'''When sasl auth either succeeded or failed.'''
|
'''When sasl auth either succeeded or failed.'''
|
||||||
if self.state == self.States.INIT_SASL:
|
if self.state == self.States.INIT_SASL:
|
||||||
self._transition(self.States.INIT_CAP_NEGOTIATION)
|
self._transition(irc, msg, self.States.INIT_CAP_NEGOTIATION)
|
||||||
elif self.state == self.States.CONNECTED_SASL:
|
elif self.state == self.States.CONNECTED_SASL:
|
||||||
self._transition(self.States.CONNECTED)
|
self._transition(irc, msg, self.States.CONNECTED)
|
||||||
else:
|
else:
|
||||||
raise ValueError('Finished SASL auth while in state %s' % self.state)
|
raise ValueError('Finished SASL auth while in state %s' % self.state)
|
||||||
|
|
||||||
def on_cap_end(self):
|
def on_cap_end(self, irc, msg):
|
||||||
'''When we send CAP END'''
|
'''When we send CAP END'''
|
||||||
self._transition(self.States.INIT_WAITING_MOTD, [
|
self._transition(irc, msg, self.States.INIT_WAITING_MOTD, [
|
||||||
self.States.INIT_CAP_NEGOTIATION,
|
self.States.INIT_CAP_NEGOTIATION,
|
||||||
])
|
])
|
||||||
|
|
||||||
def on_start_motd(self):
|
def on_start_motd(self, irc, msg):
|
||||||
'''On 375 (RPL_MOTDSTART)'''
|
'''On 375 (RPL_MOTDSTART)'''
|
||||||
self._transition(self.States.INIT_MOTD, [
|
self._transition(irc, msg, self.States.INIT_MOTD, [
|
||||||
self.States.INIT_CAP_NEGOTIATION,
|
self.States.INIT_CAP_NEGOTIATION,
|
||||||
self.States.INIT_WAITING_MOTD,
|
self.States.INIT_WAITING_MOTD,
|
||||||
])
|
])
|
||||||
|
|
||||||
def on_end_motd(self):
|
def on_end_motd(self, irc, msg):
|
||||||
'''On 376 (RPL_ENDOFMOTD) or 422 (ERR_NOMOTD)'''
|
'''On 376 (RPL_ENDOFMOTD) or 422 (ERR_NOMOTD)'''
|
||||||
self._transition(self.States.CONNECTED, [
|
self._transition(irc, msg, self.States.CONNECTED, [
|
||||||
self.States.INIT_CAP_NEGOTIATION,
|
self.States.INIT_CAP_NEGOTIATION,
|
||||||
self.States.INIT_WAITING_MOTD,
|
self.States.INIT_WAITING_MOTD,
|
||||||
self.States.INIT_MOTD
|
self.States.INIT_MOTD
|
||||||
])
|
])
|
||||||
|
|
||||||
def on_shutdown(self):
|
def on_shutdown(self, irc, msg):
|
||||||
self._transition(self.States.SHUTTING_DOWN)
|
self._transition(irc, msg, self.States.SHUTTING_DOWN)
|
||||||
|
|
||||||
class IrcState(IrcCommandDispatcher, log.Firewalled):
|
class IrcState(IrcCommandDispatcher, log.Firewalled):
|
||||||
"""Maintains state of the Irc connection. Should also become smarter.
|
"""Maintains state of the Irc connection. Should also become smarter.
|
||||||
@ -1222,7 +1243,7 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
|
|||||||
|
|
||||||
self.sendAuthenticationMessages()
|
self.sendAuthenticationMessages()
|
||||||
|
|
||||||
self.state.fsm.on_init_messages_sent()
|
self.state.fsm.on_init_messages_sent(self)
|
||||||
|
|
||||||
def sendAuthenticationMessages(self):
|
def sendAuthenticationMessages(self):
|
||||||
# Notes:
|
# Notes:
|
||||||
@ -1244,7 +1265,13 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
|
|||||||
|
|
||||||
self.sendMsg(ircmsgs.user(self.ident, self.user))
|
self.sendMsg(ircmsgs.user(self.ident, self.user))
|
||||||
|
|
||||||
def capUpkeep(self):
|
def capUpkeep(self, msg):
|
||||||
|
"""
|
||||||
|
Called after getting a CAP ACK/NAK to check it's consistent with what
|
||||||
|
was requested, and to end the cap negotiation when we received all the
|
||||||
|
ACK/NAKs we were waiting for.
|
||||||
|
|
||||||
|
`msg` is the message that triggered this call."""
|
||||||
self.state.fsm.expect_state([
|
self.state.fsm.expect_state([
|
||||||
# Normal CAP ACK / CAP NAK during cap negotiation
|
# Normal CAP ACK / CAP NAK during cap negotiation
|
||||||
IrcStateFsm.States.INIT_CAP_NEGOTIATION,
|
IrcStateFsm.States.INIT_CAP_NEGOTIATION,
|
||||||
@ -1268,18 +1295,18 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
|
|||||||
if self.state.fsm.state in [
|
if self.state.fsm.state in [
|
||||||
IrcStateFsm.States.INIT_CAP_NEGOTIATION,
|
IrcStateFsm.States.INIT_CAP_NEGOTIATION,
|
||||||
IrcStateFsm.States.CONNECTED]:
|
IrcStateFsm.States.CONNECTED]:
|
||||||
self._maybeStartSasl()
|
self._maybeStartSasl(msg)
|
||||||
else:
|
else:
|
||||||
pass # Already in the middle of a SASL auth
|
pass # Already in the middle of a SASL auth
|
||||||
else:
|
else:
|
||||||
self.endCapabilityNegociation()
|
self.endCapabilityNegociation(msg)
|
||||||
else:
|
else:
|
||||||
log.debug('Waiting for ACK/NAK of capabilities: %r',
|
log.debug('Waiting for ACK/NAK of capabilities: %r',
|
||||||
self.state.capabilities_req - capabilities_responded)
|
self.state.capabilities_req - capabilities_responded)
|
||||||
pass # Do nothing, we'll get more
|
pass # Do nothing, we'll get more
|
||||||
|
|
||||||
def endCapabilityNegociation(self):
|
def endCapabilityNegociation(self, msg):
|
||||||
self.state.fsm.on_cap_end()
|
self.state.fsm.on_cap_end(self, msg)
|
||||||
self.sendMsg(ircmsgs.IrcMsg(command='CAP', args=('END',)))
|
self.sendMsg(ircmsgs.IrcMsg(command='CAP', args=('END',)))
|
||||||
|
|
||||||
def sendSaslString(self, string):
|
def sendSaslString(self, string):
|
||||||
@ -1287,7 +1314,7 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
|
|||||||
self.sendMsg(ircmsgs.IrcMsg(command='AUTHENTICATE',
|
self.sendMsg(ircmsgs.IrcMsg(command='AUTHENTICATE',
|
||||||
args=(chunk,)))
|
args=(chunk,)))
|
||||||
|
|
||||||
def tryNextSaslMechanism(self):
|
def tryNextSaslMechanism(self, msg):
|
||||||
self.state.fsm.expect_state([
|
self.state.fsm.expect_state([
|
||||||
IrcStateFsm.States.INIT_SASL,
|
IrcStateFsm.States.INIT_SASL,
|
||||||
IrcStateFsm.States.CONNECTED_SASL,
|
IrcStateFsm.States.CONNECTED_SASL,
|
||||||
@ -1301,14 +1328,14 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
|
|||||||
'aborting connection.')
|
'aborting connection.')
|
||||||
else:
|
else:
|
||||||
self.sasl_current_mechanism = None
|
self.sasl_current_mechanism = None
|
||||||
self.state.fsm.on_sasl_auth_finished()
|
self.state.fsm.on_sasl_auth_finished(self, msg)
|
||||||
if self.state.fsm.state == IrcStateFsm.States.INIT_CAP_NEGOTIATION:
|
if self.state.fsm.state == IrcStateFsm.States.INIT_CAP_NEGOTIATION:
|
||||||
self.endCapabilityNegociation()
|
self.endCapabilityNegociation(msg)
|
||||||
|
|
||||||
def _maybeStartSasl(self):
|
def _maybeStartSasl(self, msg):
|
||||||
if not self.sasl_authenticated and \
|
if not self.sasl_authenticated and \
|
||||||
'sasl' in self.state.capabilities_ack:
|
'sasl' in self.state.capabilities_ack:
|
||||||
self.state.fsm.on_sasl_cap()
|
self.state.fsm.on_sasl_cap(self, msg)
|
||||||
assert 'sasl' in self.state.capabilities_ls, (
|
assert 'sasl' in self.state.capabilities_ls, (
|
||||||
'Got "CAP ACK sasl" without receiving "CAP LS sasl" or '
|
'Got "CAP ACK sasl" without receiving "CAP LS sasl" or '
|
||||||
'"CAP NEW sasl" first.')
|
'"CAP NEW sasl" first.')
|
||||||
@ -1318,7 +1345,7 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
|
|||||||
self.sasl_next_mechanisms = [
|
self.sasl_next_mechanisms = [
|
||||||
x for x in self.sasl_next_mechanisms
|
x for x in self.sasl_next_mechanisms
|
||||||
if x.lower() in available]
|
if x.lower() in available]
|
||||||
self.tryNextSaslMechanism()
|
self.tryNextSaslMechanism(msg)
|
||||||
|
|
||||||
def doAuthenticate(self, msg):
|
def doAuthenticate(self, msg):
|
||||||
self.state.fsm.expect_state([
|
self.state.fsm.expect_state([
|
||||||
@ -1426,28 +1453,28 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
|
|||||||
def do903(self, msg):
|
def do903(self, msg):
|
||||||
log.info('%s: SASL authentication successful', self.network)
|
log.info('%s: SASL authentication successful', self.network)
|
||||||
self.sasl_authenticated = True
|
self.sasl_authenticated = True
|
||||||
self.state.fsm.on_sasl_auth_finished()
|
self.state.fsm.on_sasl_auth_finished(self, msg)
|
||||||
if self.state.fsm.state == IrcStateFsm.States.INIT_CAP_NEGOTIATION:
|
if self.state.fsm.state == IrcStateFsm.States.INIT_CAP_NEGOTIATION:
|
||||||
self.endCapabilityNegociation()
|
self.endCapabilityNegociation(msg)
|
||||||
|
|
||||||
def do904(self, msg):
|
def do904(self, msg):
|
||||||
log.warning('%s: SASL authentication failed (mechanism: %s)',
|
log.warning('%s: SASL authentication failed (mechanism: %s)',
|
||||||
self.network, self.sasl_current_mechanism)
|
self.network, self.sasl_current_mechanism)
|
||||||
self.tryNextSaslMechanism()
|
self.tryNextSaslMechanism(msg)
|
||||||
|
|
||||||
def do905(self, msg):
|
def do905(self, msg):
|
||||||
log.warning('%s: SASL authentication failed because the username or '
|
log.warning('%s: SASL authentication failed because the username or '
|
||||||
'password is too long.', self.network)
|
'password is too long.', self.network)
|
||||||
self.tryNextSaslMechanism()
|
self.tryNextSaslMechanism(msg)
|
||||||
|
|
||||||
def do906(self, msg):
|
def do906(self, msg):
|
||||||
log.warning('%s: SASL authentication aborted', self.network)
|
log.warning('%s: SASL authentication aborted', self.network)
|
||||||
self.tryNextSaslMechanism()
|
self.tryNextSaslMechanism(msg)
|
||||||
|
|
||||||
def do907(self, msg):
|
def do907(self, msg):
|
||||||
log.warning('%s: Attempted SASL authentication when we were already '
|
log.warning('%s: Attempted SASL authentication when we were already '
|
||||||
'authenticated.', self.network)
|
'authenticated.', self.network)
|
||||||
self.tryNextSaslMechanism()
|
self.tryNextSaslMechanism(msg)
|
||||||
|
|
||||||
def do908(self, msg):
|
def do908(self, msg):
|
||||||
log.info('%s: Supported SASL mechanisms: %s',
|
log.info('%s: Supported SASL mechanisms: %s',
|
||||||
@ -1464,7 +1491,7 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
|
|||||||
self.network, caps)
|
self.network, caps)
|
||||||
self.state.capabilities_ack.update(caps)
|
self.state.capabilities_ack.update(caps)
|
||||||
|
|
||||||
self.capUpkeep()
|
self.capUpkeep(msg)
|
||||||
|
|
||||||
def doCapNak(self, msg):
|
def doCapNak(self, msg):
|
||||||
if len(msg.args) != 3:
|
if len(msg.args) != 3:
|
||||||
@ -1475,9 +1502,9 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
|
|||||||
self.state.capabilities_nak.update(caps)
|
self.state.capabilities_nak.update(caps)
|
||||||
log.warning('%s: Server refused capabilities: %L',
|
log.warning('%s: Server refused capabilities: %L',
|
||||||
self.network, caps)
|
self.network, caps)
|
||||||
self.capUpkeep()
|
self.capUpkeep(msg)
|
||||||
|
|
||||||
def _onCapSts(self, policy):
|
def _onCapSts(self, policy, msg):
|
||||||
secure_connection = self.driver.currentServer.force_tls_verification \
|
secure_connection = self.driver.currentServer.force_tls_verification \
|
||||||
or (self.driver.ssl and self.driver.anyCertValidationEnabled())
|
or (self.driver.ssl and self.driver.anyCertValidationEnabled())
|
||||||
|
|
||||||
@ -1504,19 +1531,19 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
|
|||||||
self.driver.currentServer)
|
self.driver.currentServer)
|
||||||
# Reconnect to the server, but with TLS *and* certificate
|
# Reconnect to the server, but with TLS *and* certificate
|
||||||
# validation this time.
|
# validation this time.
|
||||||
self.state.fsm.on_shutdown()
|
self.state.fsm.on_shutdown(self, msg)
|
||||||
self.driver.reconnect(
|
self.driver.reconnect(
|
||||||
server=Server(hostname, parsed_policy['port'], True),
|
server=Server(hostname, parsed_policy['port'], True),
|
||||||
wait=True)
|
wait=True)
|
||||||
|
|
||||||
def _addCapabilities(self, capstring):
|
def _addCapabilities(self, capstring, msg):
|
||||||
for item in capstring.split():
|
for item in capstring.split():
|
||||||
while item.startswith(('=', '~')):
|
while item.startswith(('=', '~')):
|
||||||
item = item[1:]
|
item = item[1:]
|
||||||
if '=' in item:
|
if '=' in item:
|
||||||
(cap, value) = item.split('=', 1)
|
(cap, value) = item.split('=', 1)
|
||||||
if cap == 'sts':
|
if cap == 'sts':
|
||||||
self._onCapSts(value)
|
self._onCapSts(value, msg)
|
||||||
self.state.capabilities_ls[cap] = value
|
self.state.capabilities_ls[cap] = value
|
||||||
else:
|
else:
|
||||||
if item == 'sts':
|
if item == 'sts':
|
||||||
@ -1532,9 +1559,9 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
|
|||||||
if msg.args[2] != '*':
|
if msg.args[2] != '*':
|
||||||
log.warning('Bad CAP LS from server: %r', msg)
|
log.warning('Bad CAP LS from server: %r', msg)
|
||||||
return
|
return
|
||||||
self._addCapabilities(msg.args[3])
|
self._addCapabilities(msg.args[3], msg)
|
||||||
elif len(msg.args) == 3: # End of LS
|
elif len(msg.args) == 3: # End of LS
|
||||||
self._addCapabilities(msg.args[2])
|
self._addCapabilities(msg.args[2], msg)
|
||||||
if self.state.fsm.state == IrcStateFsm.States.SHUTTING_DOWN:
|
if self.state.fsm.state == IrcStateFsm.States.SHUTTING_DOWN:
|
||||||
return
|
return
|
||||||
self.state.fsm.expect_state([
|
self.state.fsm.expect_state([
|
||||||
@ -1558,7 +1585,7 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
|
|||||||
if new_caps:
|
if new_caps:
|
||||||
self._requestCaps(new_caps)
|
self._requestCaps(new_caps)
|
||||||
else:
|
else:
|
||||||
self.endCapabilityNegociation()
|
self.endCapabilityNegociation(msg)
|
||||||
else:
|
else:
|
||||||
log.warning('Bad CAP LS from server: %r', msg)
|
log.warning('Bad CAP LS from server: %r', msg)
|
||||||
return
|
return
|
||||||
@ -1590,7 +1617,7 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
|
|||||||
return
|
return
|
||||||
caps = msg.args[2].split()
|
caps = msg.args[2].split()
|
||||||
assert caps, 'Empty list of capabilities'
|
assert caps, 'Empty list of capabilities'
|
||||||
self._addCapabilities(msg.args[2])
|
self._addCapabilities(msg.args[2], msg)
|
||||||
if self.state.fsm.state == IrcStateFsm.States.SHUTTING_DOWN:
|
if self.state.fsm.state == IrcStateFsm.States.SHUTTING_DOWN:
|
||||||
return
|
return
|
||||||
common_supported_unrequested_capabilities = (
|
common_supported_unrequested_capabilities = (
|
||||||
@ -1687,7 +1714,7 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
|
|||||||
log.info('Got start of MOTD from %s', self.server)
|
log.info('Got start of MOTD from %s', self.server)
|
||||||
|
|
||||||
def do376(self, msg):
|
def do376(self, msg):
|
||||||
self.state.fsm.on_end_motd()
|
self.state.fsm.on_end_motd(self, msg)
|
||||||
log.info('Got end of MOTD from %s', self.server)
|
log.info('Got end of MOTD from %s', self.server)
|
||||||
self.afterConnect = True
|
self.afterConnect = True
|
||||||
# Let's reset nicks in case we had to use a weird one.
|
# Let's reset nicks in case we had to use a weird one.
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user