From 45ff70907f522b44514c8f29cc7e06d31332994c Mon Sep 17 00:00:00 2001 From: Valentin Lorentz Date: Fri, 25 Oct 2019 23:17:10 +0200 Subject: [PATCH 1/7] [WIP] Start reworking Irc around a FSM. To keep track of connection state instead of a complex implicit flow between handling functions. --- src/irclib.py | 210 ++++++++++++++++++++++++++++++++++++++------ test/test_irclib.py | 2 + 2 files changed, 186 insertions(+), 26 deletions(-) diff --git a/src/irclib.py b/src/irclib.py index 95b555733..1157bc10a 100644 --- a/src/irclib.py +++ b/src/irclib.py @@ -30,6 +30,7 @@ import re import copy import time +import enum import random import base64 import textwrap @@ -389,14 +390,108 @@ class ChannelState(utils.python.Object): Batch = collections.namedtuple('Batch', 'type arguments messages') +class IrcStateFsm(object): + '''Finite State Machine keeping track of what part of the connection + initialization we are in.''' + __slots__ = ('state',) + + @enum.unique + class States(enum.Enum): + UNINITIALIZED = 10 + '''Nothing received yet (except server notices)''' + + INIT_CAP_NEGOTIATION = 20 + '''Sent CAP LS, did not send CAP END yet''' + + INIT_SASL = 30 + '''In an AUTHENTICATE session''' + + INIT_WAITING_MOTD = 50 + '''Waiting for start of MOTD''' + + INIT_MOTD = 60 + '''Waiting for end of MOTD''' + + CONNECTED = 70 + '''Normal state of the connections''' + + CONNECTED_SASL = 80 + '''Doing SASL authentication in the middle of a connection.''' + + def __init__(self): + self.reset() + + def reset(self): + self.state = self.States.UNINITIALIZED + + def _transition(self, to_state, expected_from=None): + if expected_from is None or self.state in expected_from: + log.debug('transition from %s to %s', self.state, to_state) + self.state = to_state + else: + raise ValueError('unexpected transition to %s while in state %s' % + (to_state, self.state)) + + def expect_state(self, expected_states): + if self.state not in expected_states: + raise ValueError(('Connection in state %s, but expected to be ' + 'in state %s') % (self.state, expected_states)) + + def on_init_messages_sent(self): + '''As soon as USER/NICK/CAP LS are sent''' + self._transition(self.States.INIT_CAP_NEGOTIATION, [ + self.States.UNINITIALIZED, + ]) + + def on_sasl_cap(self): + '''Whenever we see the 'sasl' capability in a CAP LS response''' + if self.state == self.States.INIT_CAP_NEGOTIATION: + self._transition(self.States.INIT_SASL) + elif self.state == self.States.CONNECTED: + self._transition(self.States.CONNECTED_SASL) + else: + raise ValueError('Got sasl cap while in state %s' % self.state) + + def on_sasl_auth_finished(self): + '''When sasl auth either succeeded or failed.''' + if self.state == self.States.INIT_SASL: + self._transition(self.States.INIT_CAP_NEGOTIATION) + elif self.state == self.States.CONNECTED_SASL: + self._transition(self.States.CONNECTED) + else: + raise ValueError('Finished SASL auth while in state %s' % self.state) + + def on_cap_end(self): + '''When we send CAP END''' + self._transition(self.States.INIT_WAITING_MOTD, [ + self.States.INIT_CAP_NEGOTIATION, + ]) + + def on_start_motd(self): + '''On 375 (RPL_MOTDSTART)''' + self._transition(self.States.INIT_MOTD, [ + self.States.INIT_CAP_NEGOTIATION, + self.States.INIT_WAITING_MOTD, + ]) + + def on_end_motd(self): + '''On 376 (RPL_ENDOFMOTD) or 422 (ERR_NOMOTD)''' + self._transition(self.States.CONNECTED, [ + self.States.INIT_CAP_NEGOTIATION, + self.States.INIT_WAITING_MOTD, + self.States.INIT_MOTD + ]) + class IrcState(IrcCommandDispatcher, log.Firewalled): """Maintains state of the Irc connection. Should also become smarter. """ __firewalled__ = {'addMsg': None} def __init__(self, history=None, supported=None, nicksToHostmasks=None, channels=None, + capabilities_req=None, capabilities_ack=None, capabilities_nak=None, capabilities_ls=None): + self.fsm = IrcStateFsm() if history is None: history = RingBuffer(conf.supybot.protocols.irc.maxHistoryLength()) if supported is None: @@ -405,6 +500,7 @@ class IrcState(IrcCommandDispatcher, log.Firewalled): nicksToHostmasks = ircutils.IrcDict() if channels is None: channels = ircutils.IrcDict() + self.capabilities_req = capabilities_req or set() self.capabilities_ack = capabilities_ack or set() self.capabilities_nak = capabilities_nak or set() self.capabilities_ls = capabilities_ls or {} @@ -417,6 +513,7 @@ class IrcState(IrcCommandDispatcher, log.Firewalled): def reset(self): """Resets the state to normal, unconnected state.""" + self.fsm.reset() self.history.reset() self.history.resize(conf.supybot.protocols.irc.maxHistoryLength()) self.ircd = None @@ -424,6 +521,7 @@ class IrcState(IrcCommandDispatcher, log.Firewalled): self.supported.clear() self.nicksToHostmasks.clear() self.batches = {} + self.capabilities_req = set() self.capabilities_ack = set() self.capabilities_nak = set() self.capabilities_ls = {} @@ -1115,6 +1213,8 @@ class Irc(IrcCommandDispatcher, log.Firewalled): self.sendAuthenticationMessages() + self.state.fsm.on_init_messages_sent() + def sendAuthenticationMessages(self): # Notes: # * using sendMsg instead of queueMsg because these messages cannot @@ -1135,10 +1235,43 @@ class Irc(IrcCommandDispatcher, log.Firewalled): self.sendMsg(ircmsgs.user(self.ident, self.user)) + def capUpkeep(self): + self.state.fsm.expect_state([ + # Normal CAP ACK / CAP NAK during cap negotiation + IrcStateFsm.States.INIT_CAP_NEGOTIATION, + # CAP ACK / CAP NAK after a CAP NEW (probably) + IrcStateFsm.States.CONNECTED, + ]) + + capabilities_responded = (self.state.capabilities_ack | + self.state.capabilities_nak) + if not capabilities_responded <= self.state.capabilities_req: + log.error('Server responded with unrequested ACK/NAK ' + 'capabilities: req=%r, ack=%r, nak=%r', + self.state.capabilities_req, + self.state.capabilities_ack, + self.state.capabilities_nak) + self.driver.reconnect() + elif capabilities_responded == self.state.capabilities_req: + log.debug('Got all capabilities ACKed/NAKed') + # We got all the capabilities we asked for + if 'sasl' in self.state.capabilities_ack: + if self.state.fsm.state in [ + IrcStateFsm.States.INIT_CAP_NEGOTIATION, + IrcStateFsm.States.CONNECTED]: + self._maybeStartSasl() + else: + pass # Already in the middle of a SASL auth + else: + self.endCapabilityNegociation() + else: + log.debug('Waiting for ACK/NAK of capabilities: %r', + self.state.capabilities_req - capabilities_responded) + pass # Do nothing, we'll get more + def endCapabilityNegociation(self): - if not self.capNegociationEnded: - self.capNegociationEnded = True - self.sendMsg(ircmsgs.IrcMsg(command='CAP', args=('END',))) + self.state.fsm.on_cap_end() + self.sendMsg(ircmsgs.IrcMsg(command='CAP', args=('END',))) def sendSaslString(self, string): for chunk in ircutils.authenticate_generator(string): @@ -1146,6 +1279,10 @@ class Irc(IrcCommandDispatcher, log.Firewalled): args=(chunk,))) def tryNextSaslMechanism(self): + self.state.fsm.expect_state([ + IrcStateFsm.States.INIT_SASL, + IrcStateFsm.States.CONNECTED_SASL, + ]) if self.sasl_next_mechanisms: self.sasl_current_mechanism = self.sasl_next_mechanisms.pop(0) self.sendMsg(ircmsgs.IrcMsg(command='AUTHENTICATE', @@ -1155,15 +1292,30 @@ class Irc(IrcCommandDispatcher, log.Firewalled): 'aborting connection.') else: self.sasl_current_mechanism = None - self.endCapabilityNegociation() + self.state.fsm.on_sasl_auth_finished() + if self.state.fsm.state == IrcStateFsm.States.INIT_CAP_NEGOTIATION: + self.endCapabilityNegociation() - def filterSaslMechanisms(self, available): - available = set(map(str.lower, available)) - self.sasl_next_mechanisms = [ - x for x in self.sasl_next_mechanisms - if x.lower() in available] + def _maybeStartSasl(self): + if not self.sasl_authenticated and \ + 'sasl' in self.state.capabilities_ack: + self.state.fsm.on_sasl_cap() + assert 'sasl' in self.state.capabilities_ls, ( + 'Got "CAP ACK sasl" without receiving "CAP LS sasl" or ' + '"CAP NEW sasl" first.') + s = self.state.capabilities_ls['sasl'] + if s is not None: + available = set(map(str.lower, s.split(','))) + self.sasl_next_mechanisms = [ + x for x in self.sasl_next_mechanisms + if x.lower() in available] + self.tryNextSaslMechanism() def doAuthenticate(self, msg): + self.state.fsm.expect_state([ + IrcStateFsm.States.INIT_SASL, + IrcStateFsm.States.CONNECTED_SASL, + ]) if not self.authenticate_decoder: self.authenticate_decoder = ircutils.AuthenticateDecoder() self.authenticate_decoder.feed(msg) @@ -1265,7 +1417,9 @@ class Irc(IrcCommandDispatcher, log.Firewalled): def do903(self, msg): log.info('%s: SASL authentication successful', self.network) self.sasl_authenticated = True - self.endCapabilityNegociation() + self.state.fsm.on_sasl_auth_finished() + if self.state.fsm.state == IrcStateFsm.States.INIT_CAP_NEGOTIATION: + self.endCapabilityNegociation() def do904(self, msg): log.warning('%s: SASL authentication failed (mechanism: %s)', @@ -1301,10 +1455,8 @@ class Irc(IrcCommandDispatcher, log.Firewalled): self.network, caps) self.state.capabilities_ack.update(caps) - if 'sasl' in caps: - self.tryNextSaslMechanism() - else: - self.endCapabilityNegociation() + self.capUpkeep() + def doCapNak(self, msg): if len(msg.args) != 3: log.warning('Bad CAP NAK from server: %r', msg) @@ -1314,7 +1466,8 @@ class Irc(IrcCommandDispatcher, log.Firewalled): self.state.capabilities_nak.update(caps) log.warning('%s: Server refused capabilities: %L', self.network, caps) - self.endCapabilityNegociation() + self.capUpkeep() + def _addCapabilities(self, capstring): for item in capstring.split(): while item.startswith(('=', '~')): @@ -1324,6 +1477,7 @@ class Irc(IrcCommandDispatcher, log.Firewalled): self.state.capabilities_ls[cap] = value else: self.state.capabilities_ls[item] = None + def doCapLs(self, msg): if len(msg.args) == 4: # Multi-line LS @@ -1333,12 +1487,14 @@ class Irc(IrcCommandDispatcher, log.Firewalled): self._addCapabilities(msg.args[3]) elif len(msg.args) == 3: # End of LS self._addCapabilities(msg.args[2]) - - if 'sasl' in self.state.capabilities_ls: - s = self.state.capabilities_ls['sasl'] - if s is not None: - self.filterSaslMechanisms(set(s.split(','))) - + self.state.fsm.expect_state([ + # Normal case: + IrcStateFsm.States.INIT_CAP_NEGOTIATION, + # Should only happen if a plugin sends a CAP LS (which they + # shouldn't do): + IrcStateFsm.States.CONNECTED, + IrcStateFsm.States.CONNECTED_SASL, + ]) # Normally at this point, self.state.capabilities_ack should be # empty; but let's just make sure we're not requesting the same # caps twice for no reason. @@ -1356,6 +1512,7 @@ class Irc(IrcCommandDispatcher, log.Firewalled): else: log.warning('Bad CAP LS from server: %r', msg) return + def doCapDel(self, msg): if len(msg.args) != 3: log.warning('Bad CAP DEL from server: %r', msg) @@ -1374,18 +1531,16 @@ class Irc(IrcCommandDispatcher, log.Firewalled): self.state.capabilities_ack.remove(cap) except KeyError: pass + def doCapNew(self, msg): + # Note that in theory, this method may be called at any time, even + # before CAP END (or even before the initial CAP LS). if len(msg.args) != 3: log.warning('Bad CAP NEW from server: %r', msg) return caps = msg.args[2].split() assert caps, 'Empty list of capabilities' self._addCapabilities(msg.args[2]) - if not self.sasl_authenticated and 'sasl' in self.state.capabilities_ls: - self.resetSasl() - s = self.state.capabilities_ls['sasl'] - if s is not None: - self.filterSaslMechanisms(set(s.split(','))) common_supported_unrequested_capabilities = ( set(self.state.capabilities_ls) & self.REQUEST_CAPABILITIES - @@ -1394,6 +1549,8 @@ class Irc(IrcCommandDispatcher, log.Firewalled): self._requestCaps(common_supported_unrequested_capabilities) def _requestCaps(self, caps): + self.state.capabilities_req |= caps + caps = ' '.join(sorted(caps)) # textwrap works here because in ASCII, all chars are 1 bytes: cap_lines = textwrap.wrap(caps, MAX_LINE_SIZE-len('CAP REQ :')) @@ -1474,6 +1631,7 @@ class Irc(IrcCommandDispatcher, log.Firewalled): self.outstandingPing = False def do376(self, msg): + self.state.fsm.on_end_motd() log.info('Got end of MOTD from %s', self.server) self.afterConnect = True # Let's reset nicks in case we had to use a weird one. diff --git a/test/test_irclib.py b/test/test_irclib.py index bcd2719a8..63fff6db2 100644 --- a/test/test_irclib.py +++ b/test/test_irclib.py @@ -832,6 +832,8 @@ class SaslTestCase(SupyTestCase): while self.irc.takeMsg(): pass + self.irc.feedMsg(ircmsgs.IrcMsg(command='422')) # ERR_NOMOTD + self.irc.feedMsg(ircmsgs.IrcMsg(command='CAP', args=('*', 'NEW', 'sasl=EXTERNAL'))) From ff5edd95a3e75f389ba033b73df030ee9f7e7e47 Mon Sep 17 00:00:00 2001 From: Valentin Lorentz Date: Sat, 7 Dec 2019 20:19:03 +0100 Subject: [PATCH 2/7] Remove Twisted. There's no reason to use it anymore instead of Socket. It's already missing features compared to Socket, and I don't want to maintain it anymore so it will keep getting worse. --- src/conf.py | 6 +- src/drivers/Socket.py | 4 +- src/drivers/Twisted.py | 160 ----------------------------------------- 3 files changed, 3 insertions(+), 167 deletions(-) delete mode 100644 src/drivers/Twisted.py diff --git a/src/conf.py b/src/conf.py index 758e406ea..c1ad5ac16 100644 --- a/src/conf.py +++ b/src/conf.py @@ -880,13 +880,11 @@ registerGlobalValue(supybot.drivers, 'poll', class ValidDriverModule(registry.OnlySomeStrings): __slots__ = () - validStrings = ('default', 'Socket', 'Twisted') + validStrings = ('default', 'Socket') registerGlobalValue(supybot.drivers, 'module', ValidDriverModule('default', _("""Determines what driver module the - bot will use. The default is Socket which is simple and stable - and supports SSL. Twisted doesn't work if the IRC server which - you are connecting to has IPv6 (most of them do)."""))) + bot will use. Current, the only (and default) driver is Socket."""))) registerGlobalValue(supybot.drivers, 'maxReconnectWait', registry.PositiveFloat(300.0, _("""Determines the maximum time the bot will diff --git a/src/drivers/Socket.py b/src/drivers/Socket.py index cc265f9e5..29987d559 100644 --- a/src/drivers/Socket.py +++ b/src/drivers/Socket.py @@ -82,9 +82,7 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin): self.resetDelay() if self.networkGroup.get('ssl').value and 'ssl' not in globals(): drivers.log.error('The Socket driver can not connect to SSL ' - 'servers for your Python version. Try the ' - 'Twisted driver instead, or install a Python' - 'version that supports SSL (2.6 and greater).') + 'servers for your Python version.') self.ssl = False else: self.ssl = self.networkGroup.get('ssl').value diff --git a/src/drivers/Twisted.py b/src/drivers/Twisted.py deleted file mode 100644 index dc65b90ae..000000000 --- a/src/drivers/Twisted.py +++ /dev/null @@ -1,160 +0,0 @@ -### -# Copyright (c) 2002-2004, Jeremiah Fincher -# Copyright (c) 2009, James McCoy -# All rights reserved. -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# * Redistributions of source code must retain the above copyright notice, -# this list of conditions, and the following disclaimer. -# * Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions, and the following disclaimer in the -# documentation and/or other materials provided with the distribution. -# * Neither the name of the author of this software nor the name of -# contributors to this software may be used to endorse or promote products -# derived from this software without specific prior written consent. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE -# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE -# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR -# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF -# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS -# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN -# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) -# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE -# POSSIBILITY OF SUCH DAMAGE. -### - -from .. import conf, drivers - -from twisted.names import client -from twisted.internet import reactor, error -from twisted.protocols.basic import LineReceiver -from twisted.internet.protocol import ReconnectingClientFactory - - -# This hack prevents the standard Twisted resolver from starting any -# threads, which allows for a clean shut-down in Twisted>=2.0 -reactor.installResolver(client.createResolver()) - - -try: - from OpenSSL import SSL - from twisted.internet import ssl -except ImportError: - drivers.log.debug('PyOpenSSL is not available, ' - 'cannot connect to SSL servers.') - SSL = None - -class TwistedRunnerDriver(drivers.IrcDriver): - def name(self): - return self.__class__.__name__ - - def run(self): - try: - reactor.iterate(conf.supybot.drivers.poll()) - except: - drivers.log.exception('Uncaught exception outside reactor:') - -class SupyIrcProtocol(LineReceiver): - delimiter = '\n' - MAX_LENGTH = 1024 - def __init__(self): - self.mostRecentCall = reactor.callLater(0.1, self.checkIrcForMsgs) - - def lineReceived(self, line): - msg = drivers.parseMsg(line) - if msg is not None: - self.irc.feedMsg(msg) - - def checkIrcForMsgs(self): - if self.connected: - msg = self.irc.takeMsg() - while msg: - self.transport.write(str(msg)) - msg = self.irc.takeMsg() - self.mostRecentCall = reactor.callLater(0.1, self.checkIrcForMsgs) - - def connectionLost(self, r): - self.mostRecentCall.cancel() - if r.check(error.ConnectionDone): - drivers.log.disconnect(self.factory.currentServer) - else: - drivers.log.disconnect(self.factory.currentServer, errorMsg(r)) - if self.irc.zombie: - self.factory.stopTrying() - while self.irc.takeMsg(): - continue - else: - self.irc.reset() - - def connectionMade(self): - self.factory.resetDelay() - self.irc.driver = self - - def die(self): - drivers.log.die(self.irc) - self.factory.stopTrying() - self.transport.loseConnection() - - def reconnect(self, wait=None): - # We ignore wait here, because we handled our own waiting. - drivers.log.reconnect(self.irc.network) - self.transport.loseConnection() - -def errorMsg(reason): - return reason.getErrorMessage() - -class SupyReconnectingFactory(ReconnectingClientFactory, drivers.ServersMixin): - maxDelay = property(lambda self: conf.supybot.drivers.maxReconnectWait()) - protocol = SupyIrcProtocol - def __init__(self, irc): - drivers.log.warning('Twisted driver is deprecated. You should ' - 'consider switching to Socket (set ' - 'supybot.drivers.module to Socket).') - self.irc = irc - drivers.ServersMixin.__init__(self, irc) - (server, port) = self._getNextServer() - vhost = conf.supybot.protocols.irc.vhost() - if self.networkGroup.get('ssl').value: - self.connectSSL(server, port, vhost) - else: - self.connectTCP(server, port, vhost) - - def connectTCP(self, server, port, vhost): - """Connect to the server with a standard TCP connection.""" - reactor.connectTCP(server, port, self, bindAddress=(vhost, 0)) - - def connectSSL(self, server, port, vhost): - """Connect to the server using an SSL socket.""" - drivers.log.info('Attempting an SSL connection.') - if SSL: - reactor.connectSSL(server, port, self, - ssl.ClientContextFactory(), bindAddress=(vhost, 0)) - else: - drivers.log.error('PyOpenSSL is not available. Not connecting.') - - def clientConnectionFailed(self, connector, r): - drivers.log.connectError(self.currentServer, errorMsg(r)) - (connector.host, connector.port) = self._getNextServer() - ReconnectingClientFactory.clientConnectionFailed(self, connector,r) - - def clientConnectionLost(self, connector, r): - (connector.host, connector.port) = self._getNextServer() - ReconnectingClientFactory.clientConnectionLost(self, connector, r) - - def startedConnecting(self, connector): - drivers.log.connect(self.currentServer) - - def buildProtocol(self, addr): - protocol = ReconnectingClientFactory.buildProtocol(self, addr) - protocol.irc = self.irc - return protocol - -Driver = SupyReconnectingFactory -poller = TwistedRunnerDriver() - -# vim:set shiftwidth=4 softtabstop=4 expandtab textwidth=79: From ecc2c32950efd0389d9d2c6b37ea7fb2728ab1d1 Mon Sep 17 00:00:00 2001 From: Valentin Lorentz Date: Sat, 7 Dec 2019 23:33:04 +0100 Subject: [PATCH 3/7] Add support for storing STS policies. If on an insecure connection: reconnect. If on a secure connect: store it and do nothing else. For now, stored STS policies are not read when connecting to an insecure server. --- plugins/Owner/plugin.py | 6 +- src/conf.py | 16 +++- src/drivers/Socket.py | 52 ++++++++----- src/drivers/__init__.py | 11 ++- src/ircdb.py | 157 ++++++++++++++++++++++++++++++++++++++-- src/irclib.py | 29 ++++++++ src/ircutils.py | 25 +++++++ test/test_ircdb.py | 129 +++++++++++++++++++++++++++++++++ test/test_irclib.py | 59 ++++++++++++++- 9 files changed, 449 insertions(+), 35 deletions(-) diff --git a/plugins/Owner/plugin.py b/plugins/Owner/plugin.py index 056ce8b02..6f40736b7 100644 --- a/plugins/Owner/plugin.py +++ b/plugins/Owner/plugin.py @@ -148,14 +148,14 @@ class Owner(callbacks.Plugin): def _connect(self, network, serverPort=None, password='', ssl=False): try: group = conf.supybot.networks.get(network) - (server, port) = group.servers()[0] + group.servers()[0] except (registry.NonExistentRegistryEntry, IndexError): if serverPort is None: raise ValueError('connect requires a (server, port) ' \ 'if the network is not registered.') conf.registerNetwork(network, password, ssl) - serverS = '%s:%s' % serverPort - conf.supybot.networks.get(network).servers.append(serverS) + server = '%s:%s' % serverPort + conf.supybot.networks.get(network).servers.append(server) assert conf.supybot.networks.get(network).servers(), \ 'No servers are set for the %s network.' % network self.log.debug('Creating new Irc for %s.', network) diff --git a/src/conf.py b/src/conf.py index c1ad5ac16..a0b0bd269 100644 --- a/src/conf.py +++ b/src/conf.py @@ -273,15 +273,17 @@ class Servers(registry.SpaceSeparatedListOfStrings): return s def convert(self, s): + from .drivers import Server + s = self.normalize(s) - (server, port) = s.rsplit(':', 1) + (hostname, port) = s.rsplit(':', 1) # support for `[ipv6]:port` format - if server.startswith("[") and server.endswith("]"): - server = server[1:-1] + if hostname.startswith("[") and hostname.endswith("]"): + hostname = hostname[1:-1] port = int(port) - return (server, port) + return Server(hostname, port, force_tls_verification=False) def __call__(self): L = registry.SpaceSeparatedListOfStrings.__call__(self) @@ -1039,6 +1041,12 @@ registerGlobalValue(supybot.databases.channels, 'filename', for the channels database. This file will go into the directory specified by the supybot.directories.conf variable."""))) +registerGroup(supybot.databases, 'networks') +registerGlobalValue(supybot.databases.networks, 'filename', + registry.String('networks.conf', _("""Determines what filename will be used + for the networks database. This file will go into the directory specified + by the supybot.directories.conf variable."""))) + # TODO This will need to do more in the future (such as making sure link.allow # will let the link occur), but for now let's just leave it as this. class ChannelSpecific(registry.Boolean): diff --git a/src/drivers/Socket.py b/src/drivers/Socket.py index 29987d559..ff473c768 100644 --- a/src/drivers/Socket.py +++ b/src/drivers/Socket.py @@ -35,12 +35,12 @@ Contains simple socket drivers. Asyncore bugged (haha, pun!) me. from __future__ import division import os +import sys import time import errno import threading import select import socket -import sys try: import ipaddress # Python >= 3.3 or backported ipaddress @@ -221,7 +221,7 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin): def connect(self, **kwargs): self.reconnect(reset=False, **kwargs) - def reconnect(self, wait=False, reset=True): + def reconnect(self, wait=False, reset=True, server=None): self._attempt += 1 self.nextReconnectTime = None if self.connected: @@ -242,7 +242,7 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin): if wait: self.scheduleReconnect() return - self.server = self._getNextServer() + self.server = server or self._getNextServer() network_config = getattr(conf.supybot.networks, self.irc.network) socks_proxy = network_config.socksproxy() try: @@ -252,20 +252,20 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin): log.error('Cannot use socks proxy (SocksiPy not installed), ' 'using direct connection instead.') socks_proxy = '' - if socks_proxy: - address = self.server[0] else: try: - address = utils.net.getAddressFromHostname(self.server[0], - attempt=self._attempt) + hostname = utils.net.getAddressFromHostname( + self.server.hostname, + attempt=self._attempt) except (socket.gaierror, socket.error) as e: drivers.log.connectError(self.currentServer, e) self.scheduleReconnect() return - port = self.server[1] drivers.log.connect(self.currentServer) try: - self.conn = utils.net.getSocket(address, port=port, + self.conn = utils.net.getSocket( + self.server.hostname, + port=self.server.port, socks_proxy=socks_proxy, vhost=conf.supybot.protocols.irc.vhost(), vhostv6=conf.supybot.protocols.irc.vhostv6(), @@ -280,12 +280,12 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin): try: # Connect before SSL, otherwise SSL is disabled if we use SOCKS. # See http://stackoverflow.com/q/16136916/539465 - self.conn.connect((address, port)) - if network_config.ssl(): + self.conn.connect((self.server.hostname, self.server.port)) + if network_config.ssl() or self.server.force_tls_verification: self.starttls() # Suppress this warning for loopback IPs. - targetip = address + targetip = hostname if sys.version_info[0] < 3: # Backported Python 2 ipaddress demands unicode instead of str targetip = targetip.decode('utf-8') @@ -361,6 +361,15 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin): def name(self): return '%s(%s)' % (self.__class__.__name__, self.irc) + def anyCertValidationEnabled(self): + """Returns whether any kind of certificate validation is enabled, other + than Server.force_tls_verification.""" + return any([ + conf.supybot.protocols.ssl.verifyCertificates(), + network_config.ssl.serverFingerprints(), + network_config.ssl.authorityCertificate(), + ]) + def starttls(self): assert 'ssl' in globals() network_config = getattr(conf.supybot.networks, self.irc.network) @@ -373,15 +382,20 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin): drivers.log.warning('Could not find cert file %s.' % certfile) certfile = None - verifyCertificates = conf.supybot.protocols.ssl.verifyCertificates() - if not verifyCertificates: - drivers.log.warning('Not checking SSL certificates, connections ' - 'are vulnerable to man-in-the-middle attacks. Set ' - 'supybot.protocols.ssl.verifyCertificates to "true" ' - 'to enable validity checks.') + if self.server.force_tls_verification \ + and not self.anyCertValidationEnabled(): + verifyCertificates = True + else: + verifyCertificates = conf.supybot.protocols.ssl.verifyCertificates() + if not verifyCertificates: + drivers.log.warning('Not checking SSL certificates, connections ' + 'are vulnerable to man-in-the-middle attacks. Set ' + 'supybot.protocols.ssl.verifyCertificates to "true" ' + 'to enable validity checks.') try: self.conn = utils.net.ssl_wrap_socket(self.conn, - logger=drivers.log, hostname=self.server[0], + logger=drivers.log, + hostname=self.server.hostname, certfile=certfile, verify=verifyCertificates, trusted_fingerprints=network_config.ssl.serverFingerprints(), diff --git a/src/drivers/__init__.py b/src/drivers/__init__.py index 89f90e481..f1de62c17 100644 --- a/src/drivers/__init__.py +++ b/src/drivers/__init__.py @@ -33,10 +33,19 @@ Contains various drivers (network, file, and otherwise) for using IRC objects. """ import socket +from collections import namedtuple from .. import conf, ircmsgs, log as supylog, utils from ..utils import minisix + +Server = namedtuple('Server', 'hostname port force_tls_verification') +# force_tls_verification=True implies two things: +# 1. force TLS to be enabled for this server +# 2. ensure there is some kind of verification. If the user did not enable +# any, use standard PKI validation. + + _drivers = {} _deadDrivers = set() _newDrivers = [] @@ -80,7 +89,7 @@ class ServersMixin(object): assert self.servers, 'Servers value for %s is empty.' % \ self.networkGroup._name server = self.servers.pop(0) - self.currentServer = '%s:%s' % server + self.currentServer = '%s:%s' % (server.hostname, server.port) return server diff --git a/src/ircdb.py b/src/ircdb.py index 773caf565..0f23ffc15 100644 --- a/src/ircdb.py +++ b/src/ircdb.py @@ -496,6 +496,44 @@ class IrcChannel(object): fd.write(os.linesep) +class IrcNetwork(object): + """This class holds dynamic information about a network that should be + preserved across restarts.""" + __slots__ = ('name', 'stsPolicies', 'lastDisconnectTimes') + + def __init__(self, name=None, stsPolicies=None, lastDisconnectTimes=None): + self.name = name + self.stsPolicies = stsPolicies or {} + self.lastDisconnectTimes = lastDisconnectTimes or {} + + def __repr__(self): + return '%s(name=%r, stsPolicy=%r, lastDisconnectTimes=%s)' % \ + (self.__class__.__name, self.name, self.stsPolicy, + self.lastDisconnectTimes) + + def addStsPolicy(self, server, stsPolicy): + assert isinstance(stsPolicy, str) + self.stsPolicies[server] = stsPolicy + + def addDisconnection(self, server): + self.lastDisconnectTimes[server] = int(time.time()) + + def preserve(self, fd, indent=''): + def write(s): + fd.write(indent) + fd.write(s) + fd.write(os.linesep) + + for (server, stsPolicy) in sorted(self.stsPolicies.items()): + write('stsPolicy %s %s' % (server, stsPolicy)) + + for (server, disconnectTime) in \ + sorted(self.lastDisconnectTimes.items()): + write('lastDisconnectTime %s %s' % (server, disconnectTime)) + + fd.write(os.linesep) + + class Creator(object): __slots__ = () def badCommand(self, command, rest, lineno): @@ -615,6 +653,31 @@ class IrcChannelCreator(Creator): IrcChannelCreator.name = None +class IrcNetworkCreator(Creator): + __slots__ = ('net', 'networks') + + def __init__(self, networks): + self.net = IrcNetwork() + self.networks = networks + + def network(self, rest, lineno): + self.net.name = rest + + def stspolicy(self, rest, lineno): + (server, stsPolicy) = rest.split() + self.net.addStsPolicy(server, stsPolicy) + + def lastdisconnecttime(self, rest, lineno): + (server, when) = rest.split() + when = int(when) + self.net.lastDisconnectTimes[server] = when + + def finish(self): + if self.net.name: + self.networks.setNetwork(self.net) + self.name = None + + class DuplicateHostmask(ValueError): pass @@ -666,10 +729,8 @@ class UsersDictionary(utils.IterableMap): """Flushes the database to its file.""" if not self.noFlush: if self.filename is not None: - L = list(self.users.items()) - L.sort() fd = utils.file.AtomicFile(self.filename) - for (id, u) in L: + for (id, u) in sorted(self.users.items()): fd.write('user %s' % id) fd.write(os.linesep) u.preserve(fd, indent=' ') @@ -861,7 +922,7 @@ class ChannelsDictionary(utils.IterableMap): if not self.noFlush: if self.filename is not None: fd = utils.file.AtomicFile(self.filename) - for (channel, c) in self.channels.items(): + for (channel, c) in sorted(self.channels.items()): fd.write('channel %s' % channel) fd.write(os.linesep) c.preserve(fd, indent=' ') @@ -907,6 +968,83 @@ class ChannelsDictionary(utils.IterableMap): def items(self): return self.channels.items() +class NetworksDictionary(utils.IterableMap): + __slots__ = ('noFlush', 'filename', 'networks') + + def __init__(self): + self.noFlush = False + self.filename = None + self.networks = ircutils.IrcDict() + + def open(self, filename): + self.noFlush = True + try: + self.filename = filename + reader = unpreserve.Reader(IrcNetworkCreator, self) + try: + reader.readFile(filename) + self.noFlush = False + self.flush() + except EnvironmentError as e: + log.error('Invalid network database, resetting to empty.') + log.error('Exact error: %s', utils.exnToString(e)) + except Exception as e: + log.error('Invalid network database, resetting to empty.') + log.exception('Exact error:') + finally: + self.noFlush = False + + def flush(self): + """Flushes the network database to its file.""" + if not self.noFlush: + if self.filename is not None: + fd = utils.file.AtomicFile(self.filename) + for (network, net) in sorted(self.networks.items()): + fd.write('network %s' % network) + fd.write(os.linesep) + net.preserve(fd, indent=' ') + fd.close() + else: + log.warning('NetworksDictionary.flush without self.filename.') + else: + log.debug('Not flushing NetworksDictionary because of noFlush.') + + def close(self): + self.flush() + if self.flush in world.flushers: + world.flushers.remove(self.flush) + self.networks.clear() + + def reload(self): + """Reloads the network database from its file.""" + if self.filename is not None: + self.networks.clear() + try: + self.open(self.filename) + except EnvironmentError as e: + log.warning('NetworksDictionary.reload failed: %s', e) + else: + log.warning('NetworksDictionary.reload without self.filename.') + + def getNetwork(self, network): + """Returns an IrcNetwork object for the given network.""" + network = network.lower() + if network in self.networks: + return self.networks[network] + else: + c = IrcNetwork() + self.networks[network] = c + return c + + def setNetwork(self, network, ircNetwork): + """Sets a given network to the IrcNetwork object given.""" + network = network.lower() + self.networks[network] = ircNetwork + self.flush() + + def items(self): + return self.networks.items() + class IgnoresDB(object): __slots__ = ('filename', 'hostmasks') @@ -996,6 +1134,14 @@ try: except EnvironmentError as e: log.warning('Couldn\'t open channel database: %s', e) +try: + networkFile = os.path.join(confDir, + conf.supybot.databases.networks.filename()) + networks = NetworksDictionary() + networks.open(networkFile) +except EnvironmentError as e: + log.warning('Couldn\'t open network database: %s', e) + try: ignoreFile = os.path.join(confDir, conf.supybot.databases.ignores.filename()) @@ -1006,8 +1152,9 @@ except EnvironmentError as e: world.flushers.append(users.flush) -world.flushers.append(ignores.flush) world.flushers.append(channels.flush) +world.flushers.append(networks.flush) +world.flushers.append(ignores.flush) ### diff --git a/src/irclib.py b/src/irclib.py index 1157bc10a..4686ae9ff 100644 --- a/src/irclib.py +++ b/src/irclib.py @@ -55,6 +55,7 @@ except ImportError: scram = None from . import conf, ircdb, ircmsgs, ircutils, log, utils, world +from .drivers import Server from .utils.str import rsplit from .utils.iter import chain from .utils.structures import smallqueue, RingBuffer @@ -1468,14 +1469,42 @@ class Irc(IrcCommandDispatcher, log.Firewalled): self.network, caps) self.capUpkeep() + def _onCapSts(self, policy): + parsed_policy = ircutils._parseStsPolicy(log, policy) + if parsed_policy is None: + # There was an error (and it was logged). Abort the connection. + self.driver.reconnect() + return + + if not self.driver.ssl or not self.driver.anyCertValidationEnabled(): + hostname = self.driver.server.hostname + # Reconnect to the server, but with TLS *and* certificate + # validation this time. + self.driver.reconnect( + server=Server(hostname, parsed_policy['port'], True)) + else: + # TLS is enabled and certificate is verified; write the STS policy + # in stone. + # For future-proofing (because we don't want to write an invalid + # value), we write the raw policy received from the server instead + # of the parsed one. + ircdb.networks.getNetwork(self.network).addStsPolicy( + self.driver.server.hostname, policy) + def _addCapabilities(self, capstring): for item in capstring.split(): while item.startswith(('=', '~')): item = item[1:] if '=' in item: (cap, value) = item.split('=', 1) + if cap == 'sts': + self._onCapSts(value) self.state.capabilities_ls[cap] = value else: + if item == 'sts': + log.error('Got "sts" capability without value. Aborting ' + 'connection.') + self.driver.reconnect() self.state.capabilities_ls[item] = None def doCapLs(self, msg): diff --git a/src/ircutils.py b/src/ircutils.py index 8bc8c0540..2c9556943 100644 --- a/src/ircutils.py +++ b/src/ircutils.py @@ -931,6 +931,31 @@ class AuthenticateDecoder(object): return base64.b64decode(b''.join(self.chunks)) +def _parseStsPolicy(logger, policy): + parsed_policy = {} + for kv in policy.split(','): + if '=' in kv: + (k, v) = kv.split('=', 1) + parsed_policy[k] = v + else: + parsed_policy[kv] = None + + for key in ('port', 'duration'): + if parsed_policy.get(key) is None: + logger.error('Missing or empty "%s" key in STS policy.' + 'Aborting connection.', key) + return None + try: + parsed_policy[key] = int(parsed_policy[key]) + except ValueError: + logger.error('Expected integer as value for key "%s" in STS ' + 'policy, got %r instead. Aborting connection.', + key, parsed_policy[key]) + return None + + return parsed_policy + + numerics = { # <= 2.10 # Reply diff --git a/test/test_ircdb.py b/test/test_ircdb.py index 58a6627d0..d8f716618 100644 --- a/test/test_ircdb.py +++ b/test/test_ircdb.py @@ -31,11 +31,13 @@ from supybot.test import * import os import unittest +import unittest.mock import supybot.conf as conf import supybot.world as world import supybot.ircdb as ircdb import supybot.ircutils as ircutils +from supybot.utils.minisix import io class IrcdbTestCase(SupyTestCase): def setUp(self): @@ -347,6 +349,50 @@ class IrcChannelTestCase(IrcdbTestCase): c.removeBan(banmask) self.assertFalse(c.checkIgnored(prefix)) +class IrcNetworkTestCase(IrcdbTestCase): + def testDefaults(self): + n = ircdb.IrcNetwork() + self.assertIsNone(n.name) + self.assertEqual(n.stsPolicies, {}) + self.assertEqual(n.lastDisconnectTimes, {}) + + def testStsPolicy(self): + n = ircdb.IrcNetwork() + n.addStsPolicy('foo', 'bar') + n.addStsPolicy('baz', 'qux') + self.assertEqual(n.stsPolicies, { + 'foo': 'bar', + 'baz': 'qux', + }) + + def testAddDisconnection(self): + n = ircdb.IrcNetwork() + min_ = int(time.time()) + n.addDisconnection('foo') + max_ = int(time.time()) + self.assertTrue(min_ <= n.lastDisconnectTimes['foo'] <= max_) + + def testPreserve(self): + n = ircdb.IrcNetwork('foonet') + n.addStsPolicy('foo', 'sts1') + n.addStsPolicy('bar', 'sts2') + n.addDisconnection('foo') + n.addDisconnection('baz') + disconnect_time_foo = n.lastDisconnectTimes['foo'] + disconnect_time_baz = n.lastDisconnectTimes['baz'] + fd = io.StringIO() + n.preserve(fd, indent=' ') + fd.seek(0) + self.assertCountEqual(fd.read().split('\n'), [ + ' stsPolicy foo sts1', + ' stsPolicy bar sts2', + ' lastDisconnectTime foo %d' % disconnect_time_foo, + ' lastDisconnectTime baz %d' % disconnect_time_baz, + '', + '', + ]) + + class UsersDictionaryTestCase(IrcdbTestCase): filename = os.path.join(conf.supybot.directories.conf(), 'UsersDictionaryTestCase.conf') @@ -401,6 +447,89 @@ class UsersDictionaryTestCase(IrcdbTestCase): self.assertRaises(ValueError, self.users.setUser, u2) +class NetworksDictionaryTestCase(IrcdbTestCase): + filename = os.path.join(conf.supybot.directories.conf(), + 'NetworksDictionaryTestCase.conf') + def setUp(self): + try: + os.remove(self.filename) + except: + pass + self.networks = ircdb.NetworksDictionary() + IrcdbTestCase.setUp(self) + + def testGetSetNetwork(self): + self.assertEqual(self.networks.getNetwork('foo').name, None) + + n = ircdb.IrcNetwork() + n.name = 'foo' + self.networks.setNetwork('foo', n) + self.assertEqual(self.networks.getNetwork('foo').name, 'foo') + + def testPreserveOne(self): + n = ircdb.IrcNetwork('foonet') + n.addStsPolicy('foo', 'sts1') + n.addStsPolicy('bar', 'sts2') + n.addDisconnection('foo') + n.addDisconnection('baz') + disconnect_time_foo = n.lastDisconnectTimes['foo'] + disconnect_time_baz = n.lastDisconnectTimes['baz'] + self.networks.setNetwork('foonet', n) + + fd = io.StringIO() + fd.close = lambda: None + self.networks.filename = 'blah' + original_Atomicfile = utils.file.AtomicFile + with unittest.mock.patch( + 'supybot.utils.file.AtomicFile', return_value=fd): + self.networks.flush() + + lines = fd.getvalue().split('\n') + self.assertEqual(lines.pop(0), 'network foonet') + self.assertCountEqual(lines, [ + ' stsPolicy foo sts1', + ' stsPolicy bar sts2', + ' lastDisconnectTime foo %d' % disconnect_time_foo, + ' lastDisconnectTime baz %d' % disconnect_time_baz, + '', + '', + ]) + + def testPreserveThree(self): + n = ircdb.IrcNetwork('foonet') + n.addStsPolicy('foo', 'sts1') + self.networks.setNetwork('foonet', n) + + n = ircdb.IrcNetwork('barnet') + n.addStsPolicy('bar', 'sts2') + self.networks.setNetwork('barnet', n) + + n = ircdb.IrcNetwork('baznet') + n.addStsPolicy('baz', 'sts3') + self.networks.setNetwork('baznet', n) + + fd = io.StringIO() + fd.close = lambda: None + self.networks.filename = 'blah' + original_Atomicfile = utils.file.AtomicFile + with unittest.mock.patch( + 'supybot.utils.file.AtomicFile', return_value=fd): + self.networks.flush() + + fd.seek(0) + self.assertEqual(fd.getvalue(), + 'network barnet\n' + ' stsPolicy bar sts2\n' + '\n' + 'network baznet\n' + ' stsPolicy baz sts3\n' + '\n' + 'network foonet\n' + ' stsPolicy foo sts1\n' + '\n' + ) + + class CheckCapabilityTestCase(IrcdbTestCase): filename = os.path.join(conf.supybot.directories.conf(), 'CheckCapabilityTestCase.conf') diff --git a/test/test_irclib.py b/test/test_irclib.py index 63fff6db2..569750a06 100644 --- a/test/test_irclib.py +++ b/test/test_irclib.py @@ -27,14 +27,15 @@ # POSSIBILITY OF SUCH DAMAGE. ### -from supybot.test import * - import copy import pickle -import warnings +import unittest.mock + +from supybot.test import * import supybot.conf as conf import supybot.irclib as irclib +import supybot.drivers as drivers import supybot.ircmsgs as ircmsgs import supybot.ircutils as ircutils @@ -497,6 +498,58 @@ class IrcCapsTestCase(SupyTestCase): self.assertEqual(m.args[0], 'REQ', m) self.assertEqual(m.args[1], 'b'*400) +class StsTestCase(SupyTestCase): + def setUp(self): + self.irc = irclib.Irc('test') + + m = self.irc.takeMsg() + self.failUnless(m.command == 'CAP', 'Expected CAP, got %r.' % m) + self.failUnless(m.args == ('LS', '302'), 'Expected CAP LS 302, got %r.' % m) + + m = self.irc.takeMsg() + self.failUnless(m.command == 'NICK', 'Expected NICK, got %r.' % m) + + m = self.irc.takeMsg() + self.failUnless(m.command == 'USER', 'Expected USER, got %r.' % m) + + self.irc.driver = unittest.mock.Mock() + + def tearDown(self): + ircdb.networks.networks = {} + + def testStsInSecureConnection(self): + self.irc.driver.anyCertValidationEnabled.return_value = True + self.irc.driver.ssl = True + self.irc.driver.server = drivers.Server('irc.test', 6697, False) + self.irc.feedMsg(ircmsgs.IrcMsg(command='CAP', + args=('*', 'LS', 'sts=duration=42,port=6697'))) + + self.assertEqual(ircdb.networks.getNetwork('test').stsPolicies, { + 'irc.test': 'duration=42,port=6697'}) + self.irc.driver.reconnect.assert_not_called() + + def testStsInInsecureTlsConnection(self): + self.irc.driver.anyCertValidationEnabled.return_value = False + self.irc.driver.ssl = True + self.irc.driver.server = drivers.Server('irc.test', 6697, False) + self.irc.feedMsg(ircmsgs.IrcMsg(command='CAP', + args=('*', 'LS', 'sts=duration=42,port=6697'))) + + self.assertEqual(ircdb.networks.getNetwork('test').stsPolicies, {}) + self.irc.driver.reconnect.assert_called_once_with( + server=drivers.Server('irc.test', 6697, True)) + + def testStsInCleartextConnection(self): + self.irc.driver.anyCertValidationEnabled.return_value = False + self.irc.driver.ssl = True + self.irc.driver.server = drivers.Server('irc.test', 6667, False) + self.irc.feedMsg(ircmsgs.IrcMsg(command='CAP', + args=('*', 'LS', 'sts=duration=42,port=6697'))) + + self.assertEqual(ircdb.networks.getNetwork('test').stsPolicies, {}) + self.irc.driver.reconnect.assert_called_once_with( + server=drivers.Server('irc.test', 6697, True)) + class IrcTestCase(SupyTestCase): def setUp(self): self.irc = irclib.Irc('test') From 51ff013fcc6fb6c10264db3ee91115752c10f462 Mon Sep 17 00:00:00 2001 From: Valentin Lorentz Date: Sun, 8 Dec 2019 15:54:48 +0100 Subject: [PATCH 4/7] Apply STS policies when connecting to a server. --- src/drivers/Socket.py | 21 +++++--- src/drivers/__init__.py | 45 ++++++++++++++--- src/ircdb.py | 8 ++- src/irclib.py | 19 ++++--- src/ircutils.py | 6 ++- src/world.py | 3 +- test/test_drivers.py | 108 ++++++++++++++++++++++++++++++++++++++++ test/test_irclib.py | 26 ++++++++++ 8 files changed, 210 insertions(+), 26 deletions(-) create mode 100644 test/test_drivers.py diff --git a/src/drivers/Socket.py b/src/drivers/Socket.py index ff473c768..1770f2e95 100644 --- a/src/drivers/Socket.py +++ b/src/drivers/Socket.py @@ -225,6 +225,7 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin): self._attempt += 1 self.nextReconnectTime = None if self.connected: + self.onDisconnect() drivers.log.reconnect(self.irc.network) if self in self._instances: self._instances.remove(self) @@ -242,7 +243,7 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin): if wait: self.scheduleReconnect() return - self.server = server or self._getNextServer() + self.currentServer = server or self._getNextServer() network_config = getattr(conf.supybot.networks, self.irc.network) socks_proxy = network_config.socksproxy() try: @@ -255,7 +256,7 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin): else: try: hostname = utils.net.getAddressFromHostname( - self.server.hostname, + self.currentServer.hostname, attempt=self._attempt) except (socket.gaierror, socket.error) as e: drivers.log.connectError(self.currentServer, e) @@ -264,8 +265,8 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin): drivers.log.connect(self.currentServer) try: self.conn = utils.net.getSocket( - self.server.hostname, - port=self.server.port, + self.currentServer.hostname, + port=self.currentServer.port, socks_proxy=socks_proxy, vhost=conf.supybot.protocols.irc.vhost(), vhostv6=conf.supybot.protocols.irc.vhostv6(), @@ -280,8 +281,10 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin): try: # Connect before SSL, otherwise SSL is disabled if we use SOCKS. # See http://stackoverflow.com/q/16136916/539465 - self.conn.connect((self.server.hostname, self.server.port)) - if network_config.ssl() or self.server.force_tls_verification: + self.conn.connect( + (self.currentServer.hostname, self.currentServer.port)) + if network_config.ssl() \ + or self.currentServer.force_tls_verification: self.starttls() # Suppress this warning for loopback IPs. @@ -351,6 +354,8 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin): if self.writeCheckTime is not None: self.writeCheckTime = None drivers.log.die(self.irc) + drivers.IrcDriver.die(self) + drivers.ServersMixin.die(self) def _reallyDie(self): if self.conn is not None: @@ -382,7 +387,7 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin): drivers.log.warning('Could not find cert file %s.' % certfile) certfile = None - if self.server.force_tls_verification \ + if self.currentServer.force_tls_verification \ and not self.anyCertValidationEnabled(): verifyCertificates = True else: @@ -395,7 +400,7 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin): try: self.conn = utils.net.ssl_wrap_socket(self.conn, logger=drivers.log, - hostname=self.server.hostname, + hostname=self.currentServer.hostname, certfile=certfile, verify=verifyCertificates, trusted_fingerprints=network_config.ssl.serverFingerprints(), diff --git a/src/drivers/__init__.py b/src/drivers/__init__.py index f1de62c17..969454b00 100644 --- a/src/drivers/__init__.py +++ b/src/drivers/__init__.py @@ -32,10 +32,11 @@ Contains various drivers (network, file, and otherwise) for using IRC objects. """ +import time import socket from collections import namedtuple -from .. import conf, ircmsgs, log as supylog, utils +from .. import conf, ircdb, ircmsgs, ircutils, log as supylog, utils from ..utils import minisix @@ -73,6 +74,7 @@ class IrcDriver(object): class ServersMixin(object): def __init__(self, irc, servers=()): + self.networkName = irc.network self.networkGroup = conf.supybot.networks.get(irc.network) self.servers = servers super(ServersMixin, self).__init__() @@ -89,8 +91,36 @@ class ServersMixin(object): assert self.servers, 'Servers value for %s is empty.' % \ self.networkGroup._name server = self.servers.pop(0) - self.currentServer = '%s:%s' % (server.hostname, server.port) - return server + self.currentServer = self._applyStsPolicy(server) + return self.currentServer + + def _applyStsPolicy(self, server): + network = ircdb.networks.getNetwork(self.networkName) + policy = network.stsPolicies.get(server.hostname) + lastDisconnect = network.lastDisconnectTimes.get(server.hostname) + + if policy is None or lastDisconnect is None: + return server + + # The policy was stored, which means it was received on a secure + # connection. + policy = ircutils.parseStsPolicy(log, policy, parseDuration=True) + + if lastDisconnect + policy['duration'] < time.time(): + network.expireStsPolicy(server.hostname) + return server + + # Change the port, and force TLS verification, as required by the STS + # specification. + return Server(server.hostname, policy['port'], + force_tls_verification=True) + + def die(self): + self.onDisconnect() + + def onDisconnect(self): + network = ircdb.networks.getNetwork(self.networkName) + network.addDisconnection(self.currentServer.hostname) def empty(): @@ -138,7 +168,8 @@ def run(): class Log(object): """This is used to have a nice, consistent interface for drivers to use.""" def connect(self, server): - self.info('Connecting to %s.', server) + self.info('Connecting to %s:%s.', + server.hostname, server.port) def connectError(self, server, e): if isinstance(e, Exception): @@ -146,7 +177,8 @@ class Log(object): e = e.args[1] else: e = utils.exnToString(e) - self.warning('Error connecting to %s: %s', server, e) + self.warning('Error connecting to %s:%s: %s', + server.hostname, server.port, e) def disconnect(self, server, e=None): if e: @@ -156,7 +188,8 @@ class Log(object): e = str(e) if not e.endswith('.'): e += '.' - self.warning('Disconnect from %s: %s', server, e) + self.warning('Disconnect from %s:%s: %s', + server.hostname, server.port, e) else: self.info('Disconnect from %s.', server) diff --git a/src/ircdb.py b/src/ircdb.py index 0f23ffc15..462fe57b9 100644 --- a/src/ircdb.py +++ b/src/ircdb.py @@ -515,6 +515,10 @@ class IrcNetwork(object): assert isinstance(stsPolicy, str) self.stsPolicies[server] = stsPolicy + def expireStsPolicy(self, server): + if server in self.stsPolicies: + del self.stsPolicies[server] + def addDisconnection(self, server): self.lastDisconnectTimes[server] = int(time.time()) @@ -674,8 +678,8 @@ class IrcNetworkCreator(Creator): def finish(self): if self.net.name: - self.networks.setNetwork(self.net) - self.name = None + self.networks.setNetwork(self.net.name, self.net) + self.net = IrcNetwork() class DuplicateHostmask(ValueError): diff --git a/src/irclib.py b/src/irclib.py index 4686ae9ff..ff19d3719 100644 --- a/src/irclib.py +++ b/src/irclib.py @@ -1470,19 +1470,16 @@ class Irc(IrcCommandDispatcher, log.Firewalled): self.capUpkeep() def _onCapSts(self, policy): - parsed_policy = ircutils._parseStsPolicy(log, policy) + secure_connection = self.driver.ssl and self.driver.anyCertValidationEnabled() + + parsed_policy = ircutils.parseStsPolicy( + log, policy, parseDuration=secure_connection) if parsed_policy is None: # There was an error (and it was logged). Abort the connection. self.driver.reconnect() return - if not self.driver.ssl or not self.driver.anyCertValidationEnabled(): - hostname = self.driver.server.hostname - # Reconnect to the server, but with TLS *and* certificate - # validation this time. - self.driver.reconnect( - server=Server(hostname, parsed_policy['port'], True)) - else: + if secure_connection: # TLS is enabled and certificate is verified; write the STS policy # in stone. # For future-proofing (because we don't want to write an invalid @@ -1490,6 +1487,12 @@ class Irc(IrcCommandDispatcher, log.Firewalled): # of the parsed one. ircdb.networks.getNetwork(self.network).addStsPolicy( self.driver.server.hostname, policy) + else: + hostname = self.driver.server.hostname + # Reconnect to the server, but with TLS *and* certificate + # validation this time. + self.driver.reconnect( + server=Server(hostname, parsed_policy['port'], True)) def _addCapabilities(self, capstring): for item in capstring.split(): diff --git a/src/ircutils.py b/src/ircutils.py index 2c9556943..2d56e469a 100644 --- a/src/ircutils.py +++ b/src/ircutils.py @@ -931,7 +931,7 @@ class AuthenticateDecoder(object): return base64.b64decode(b''.join(self.chunks)) -def _parseStsPolicy(logger, policy): +def parseStsPolicy(logger, policy, parseDuration): parsed_policy = {} for kv in policy.split(','): if '=' in kv: @@ -941,6 +941,10 @@ def _parseStsPolicy(logger, policy): parsed_policy[kv] = None for key in ('port', 'duration'): + if key == 'duration' and not parseDuration: + if key in parsed_policy: + del parsed_policy[key] + continue if parsed_policy.get(key) is None: logger.error('Missing or empty "%s" key in STS policy.' 'Aborting connection.', key) diff --git a/src/world.py b/src/world.py index eec3b3e65..b00d112d9 100644 --- a/src/world.py +++ b/src/world.py @@ -42,7 +42,7 @@ import multiprocessing import re -from . import conf, drivers, ircutils, log, registry +from . import conf, ircutils, log, registry from .utils import minisix startedAt = time.time() # Just in case it doesn't get set later. @@ -193,6 +193,7 @@ def upkeep(): def makeDriversDie(): """Kills drivers.""" + from . import drivers log.info('Killing Driver objects.') for driver in drivers._drivers.values(): driver.die() diff --git a/test/test_drivers.py b/test/test_drivers.py new file mode 100644 index 000000000..cc7b4fafd --- /dev/null +++ b/test/test_drivers.py @@ -0,0 +1,108 @@ +## +# Copyright (c) 2019, Valentin Lorentz +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions, and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions, and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of the author of this software nor the name of +# contributors to this software may be used to endorse or promote products +# derived from this software without specific prior written consent. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +# POSSIBILITY OF SUCH DAMAGE. +### + +from supybot.test import * +import supybot.ircdb as ircdb +import supybot.irclib as irclib +import supybot.drivers as drivers + +class DriversTestCase(SupyTestCase): + def tearDown(self): + ircdb.networks.networks = {} + + def testValidStsPolicy(self): + irc = irclib.Irc('test') + net = ircdb.networks.getNetwork('test') + net.addStsPolicy('example.com', 'duration=10,port=6697') + net.addDisconnection('example.com') + + with conf.supybot.networks.test.servers.context( + ['example.com:6667', 'example.org:6667']): + + driver = drivers.ServersMixin(irc) + + self.assertEqual( + driver._getNextServer(), + drivers.Server('example.com', 6697, True)) + driver.die() + + self.assertEqual( + driver._getNextServer(), + drivers.Server('example.org', 6667, False)) + driver.die() + + self.assertEqual( + driver._getNextServer(), + drivers.Server('example.com', 6697, True)) + + def testExpiredStsPolicy(self): + irc = irclib.Irc('test') + net = ircdb.networks.getNetwork('test') + net.addStsPolicy('example.com', 'duration=10,port=6697') + net.addDisconnection('example.com') + + timeFastForward(16) + + with conf.supybot.networks.test.servers.context( + ['example.com:6667']): + + driver = drivers.ServersMixin(irc) + + self.assertEqual( + driver._getNextServer(), + drivers.Server('example.com', 6667, False)) + + def testRescheduledStsPolicy(self): + irc = irclib.Irc('test') + net = ircdb.networks.getNetwork('test') + net.addStsPolicy('example.com', 'duration=10,port=6697') + net.addDisconnection('example.com') + + with conf.supybot.networks.test.servers.context( + ['example.com:6667', 'example.org:6667']): + + driver = drivers.ServersMixin(irc) + + timeFastForward(8) + + self.assertEqual( + driver._getNextServer(), + drivers.Server('example.com', 6697, True)) + driver.die() + + self.assertEqual( + driver._getNextServer(), + drivers.Server('example.org', 6667, False)) + driver.die() + + timeFastForward(8) + + self.assertEqual( + driver._getNextServer(), + drivers.Server('example.com', 6697, True)) diff --git a/test/test_irclib.py b/test/test_irclib.py index 569750a06..288c3aef4 100644 --- a/test/test_irclib.py +++ b/test/test_irclib.py @@ -550,6 +550,32 @@ class StsTestCase(SupyTestCase): self.irc.driver.reconnect.assert_called_once_with( server=drivers.Server('irc.test', 6697, True)) + def testStsInCleartextConnectionInvalidDuration(self): + # "Servers MAY send this key to all clients, but insecurely + # connected clients MUST ignore it." + self.irc.driver.anyCertValidationEnabled.return_value = False + self.irc.driver.ssl = True + self.irc.driver.server = drivers.Server('irc.test', 6667, False) + self.irc.feedMsg(ircmsgs.IrcMsg(command='CAP', + args=('*', 'LS', 'sts=duration=foo,port=6697'))) + + self.assertEqual(ircdb.networks.getNetwork('test').stsPolicies, {}) + self.irc.driver.reconnect.assert_called_once_with( + server=drivers.Server('irc.test', 6697, True)) + + def testStsInCleartextConnectionNoDuration(self): + # "Servers MAY send this key to all clients, but insecurely + # connected clients MUST ignore it." + self.irc.driver.anyCertValidationEnabled.return_value = False + self.irc.driver.ssl = True + self.irc.driver.server = drivers.Server('irc.test', 6667, False) + self.irc.feedMsg(ircmsgs.IrcMsg(command='CAP', + args=('*', 'LS', 'sts=port=6697'))) + + self.assertEqual(ircdb.networks.getNetwork('test').stsPolicies, {}) + self.irc.driver.reconnect.assert_called_once_with( + server=drivers.Server('irc.test', 6697, True)) + class IrcTestCase(SupyTestCase): def setUp(self): self.irc = irclib.Irc('test') From 22120ee862f0728170998b93598dc2d7f7c2c165 Mon Sep 17 00:00:00 2001 From: Valentin Lorentz Date: Sun, 8 Dec 2019 21:25:59 +0100 Subject: [PATCH 5/7] Fix various issues with STS handling. --- src/drivers/Socket.py | 12 +++++++++--- src/drivers/__init__.py | 6 ++++++ src/ircdb.py | 16 ++++++++-------- src/irclib.py | 34 +++++++++++++++++++++++++++------- test/test_ircdb.py | 16 +++++++--------- test/test_irclib.py | 22 +++++++++++++--------- 6 files changed, 70 insertions(+), 36 deletions(-) diff --git a/src/drivers/Socket.py b/src/drivers/Socket.py index 1770f2e95..0017c6d20 100644 --- a/src/drivers/Socket.py +++ b/src/drivers/Socket.py @@ -241,6 +241,9 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin): else: drivers.log.debug('Not resetting %s.', self.irc) if wait: + if server is not None: + # Make this server be the next one to be used. + self.servers.insert(0, server) self.scheduleReconnect() return self.currentServer = server or self._getNextServer() @@ -283,8 +286,8 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin): # See http://stackoverflow.com/q/16136916/539465 self.conn.connect( (self.currentServer.hostname, self.currentServer.port)) - if network_config.ssl() \ - or self.currentServer.force_tls_verification: + if network_config.ssl() or \ + self.currentServer.force_tls_verification: self.starttls() # Suppress this warning for loopback IPs. @@ -294,6 +297,7 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin): targetip = targetip.decode('utf-8') elif (not network_config.requireStarttls()) and \ (not network_config.ssl()) and \ + (not self.currentServer.force_tls_verification) and \ (ipaddress is None or not ipaddress.ip_address(targetip).is_loopback): drivers.log.warning(('Connection to network %s ' 'does not use SSL/TLS, which makes it vulnerable to ' @@ -369,6 +373,7 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin): def anyCertValidationEnabled(self): """Returns whether any kind of certificate validation is enabled, other than Server.force_tls_verification.""" + network_config = getattr(conf.supybot.networks, self.irc.network) return any([ conf.supybot.protocols.ssl.verifyCertificates(), network_config.ssl.serverFingerprints(), @@ -392,7 +397,8 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin): verifyCertificates = True else: verifyCertificates = conf.supybot.protocols.ssl.verifyCertificates() - if not verifyCertificates: + if not self.currentServer.force_tls_verification \ + and not self.anyCertValidationEnabled(): drivers.log.warning('Not checking SSL certificates, connections ' 'are vulnerable to man-in-the-middle attacks. Set ' 'supybot.protocols.ssl.verifyCertificates to "true" ' diff --git a/src/drivers/__init__.py b/src/drivers/__init__.py index 969454b00..fcc14cc06 100644 --- a/src/drivers/__init__.py +++ b/src/drivers/__init__.py @@ -100,6 +100,8 @@ class ServersMixin(object): lastDisconnect = network.lastDisconnectTimes.get(server.hostname) if policy is None or lastDisconnect is None: + log.debug('No STS policy, or never disconnected from this server. %r %r', + policy, lastDisconnect) return server # The policy was stored, which means it was received on a secure @@ -107,9 +109,13 @@ class ServersMixin(object): policy = ircutils.parseStsPolicy(log, policy, parseDuration=True) if lastDisconnect + policy['duration'] < time.time(): + log.info('STS policy expired, removing.') network.expireStsPolicy(server.hostname) return server + log.info('Using STS policy: changing port from %s to %s.', + server.port, policy['port']) + # Change the port, and force TLS verification, as required by the STS # specification. return Server(server.hostname, policy['port'], diff --git a/src/ircdb.py b/src/ircdb.py index 462fe57b9..6c965a2cb 100644 --- a/src/ircdb.py +++ b/src/ircdb.py @@ -499,16 +499,15 @@ class IrcChannel(object): class IrcNetwork(object): """This class holds dynamic information about a network that should be preserved across restarts.""" - __slots__ = ('name', 'stsPolicies', 'lastDisconnectTimes') + __slots__ = ('stsPolicies', 'lastDisconnectTimes') - def __init__(self, name=None, stsPolicies=None, lastDisconnectTimes=None): - self.name = name + def __init__(self, stsPolicies=None, lastDisconnectTimes=None): self.stsPolicies = stsPolicies or {} self.lastDisconnectTimes = lastDisconnectTimes or {} def __repr__(self): - return '%s(name=%r, stsPolicy=%r, lastDisconnectTimes=%s)' % \ - (self.__class__.__name, self.name, self.stsPolicy, + return '%s(stsPolicies=%r, lastDisconnectTimes=%s)' % \ + (self.__class__.__name__, self.stsPolicies, self.lastDisconnectTimes) def addStsPolicy(self, server, stsPolicy): @@ -659,13 +658,14 @@ class IrcChannelCreator(Creator): class IrcNetworkCreator(Creator): __slots__ = ('net', 'networks') + name = None def __init__(self, networks): self.net = IrcNetwork() self.networks = networks def network(self, rest, lineno): - self.net.name = rest + IrcNetworkCreator.name = rest def stspolicy(self, rest, lineno): (server, stsPolicy) = rest.split() @@ -677,8 +677,8 @@ class IrcNetworkCreator(Creator): self.net.lastDisconnectTimes[server] = when def finish(self): - if self.net.name: - self.networks.setNetwork(self.net.name, self.net) + if self.name: + self.networks.setNetwork(self.name, self.net) self.net = IrcNetwork() diff --git a/src/irclib.py b/src/irclib.py index ff19d3719..db9a988bc 100644 --- a/src/irclib.py +++ b/src/irclib.py @@ -419,10 +419,15 @@ class IrcStateFsm(object): CONNECTED_SASL = 80 '''Doing SASL authentication in the middle of a connection.''' + SHUTTING_DOWN = 100 + def __init__(self): self.reset() def reset(self): + if getattr(self, 'state', None) is not None: + log.debug('resetting from %s to %s', + self.state, self.States.UNINITIALIZED) self.state = self.States.UNINITIALIZED def _transition(self, to_state, expected_from=None): @@ -483,6 +488,9 @@ class IrcStateFsm(object): self.States.INIT_MOTD ]) + def on_shutdown(self): + self._transition(self.States.SHUTTING_DOWN) + class IrcState(IrcCommandDispatcher, log.Firewalled): """Maintains state of the Irc connection. Should also become smarter. """ @@ -1252,7 +1260,7 @@ class Irc(IrcCommandDispatcher, log.Firewalled): self.state.capabilities_req, self.state.capabilities_ack, self.state.capabilities_nak) - self.driver.reconnect() + self.driver.reconnect(wait=True) elif capabilities_responded == self.state.capabilities_req: log.debug('Got all capabilities ACKed/NAKed') # We got all the capabilities we asked for @@ -1470,13 +1478,14 @@ class Irc(IrcCommandDispatcher, log.Firewalled): self.capUpkeep() def _onCapSts(self, policy): - secure_connection = self.driver.ssl and self.driver.anyCertValidationEnabled() + secure_connection = self.driver.currentServer.force_tls_verification \ + or (self.driver.ssl and self.driver.anyCertValidationEnabled()) parsed_policy = ircutils.parseStsPolicy( log, policy, parseDuration=secure_connection) if parsed_policy is None: # There was an error (and it was logged). Abort the connection. - self.driver.reconnect() + self.driver.reconnect(wait=True) return if secure_connection: @@ -1485,14 +1494,20 @@ class Irc(IrcCommandDispatcher, log.Firewalled): # For future-proofing (because we don't want to write an invalid # value), we write the raw policy received from the server instead # of the parsed one. + log.debug('Storing STS policy: %s', policy) ircdb.networks.getNetwork(self.network).addStsPolicy( - self.driver.server.hostname, policy) + self.driver.currentServer.hostname, policy) else: - hostname = self.driver.server.hostname + hostname = self.driver.currentServer.hostname + log.info('Got STS policy over insecure connection; ' + 'reconnecting to secure port. %r', + self.driver.currentServer) # Reconnect to the server, but with TLS *and* certificate # validation this time. + self.state.fsm.on_shutdown() self.driver.reconnect( - server=Server(hostname, parsed_policy['port'], True)) + server=Server(hostname, parsed_policy['port'], True), + wait=True) def _addCapabilities(self, capstring): for item in capstring.split(): @@ -1507,9 +1522,10 @@ class Irc(IrcCommandDispatcher, log.Firewalled): if item == 'sts': log.error('Got "sts" capability without value. Aborting ' 'connection.') - self.driver.reconnect() + self.driver.reconnect(wait=True) self.state.capabilities_ls[item] = None + def doCapLs(self, msg): if len(msg.args) == 4: # Multi-line LS @@ -1519,6 +1535,8 @@ class Irc(IrcCommandDispatcher, log.Firewalled): self._addCapabilities(msg.args[3]) elif len(msg.args) == 3: # End of LS self._addCapabilities(msg.args[2]) + if self.state.fsm.state == IrcStateFsm.States.SHUTTING_DOWN: + return self.state.fsm.expect_state([ # Normal case: IrcStateFsm.States.INIT_CAP_NEGOTIATION, @@ -1573,6 +1591,8 @@ class Irc(IrcCommandDispatcher, log.Firewalled): caps = msg.args[2].split() assert caps, 'Empty list of capabilities' self._addCapabilities(msg.args[2]) + if self.state.fsm.state == IrcStateFsm.States.SHUTTING_DOWN: + return common_supported_unrequested_capabilities = ( set(self.state.capabilities_ls) & self.REQUEST_CAPABILITIES - diff --git a/test/test_ircdb.py b/test/test_ircdb.py index d8f716618..add6b6064 100644 --- a/test/test_ircdb.py +++ b/test/test_ircdb.py @@ -352,7 +352,6 @@ class IrcChannelTestCase(IrcdbTestCase): class IrcNetworkTestCase(IrcdbTestCase): def testDefaults(self): n = ircdb.IrcNetwork() - self.assertIsNone(n.name) self.assertEqual(n.stsPolicies, {}) self.assertEqual(n.lastDisconnectTimes, {}) @@ -373,7 +372,7 @@ class IrcNetworkTestCase(IrcdbTestCase): self.assertTrue(min_ <= n.lastDisconnectTimes['foo'] <= max_) def testPreserve(self): - n = ircdb.IrcNetwork('foonet') + n = ircdb.IrcNetwork() n.addStsPolicy('foo', 'sts1') n.addStsPolicy('bar', 'sts2') n.addDisconnection('foo') @@ -459,15 +458,14 @@ class NetworksDictionaryTestCase(IrcdbTestCase): IrcdbTestCase.setUp(self) def testGetSetNetwork(self): - self.assertEqual(self.networks.getNetwork('foo').name, None) + self.assertEqual(self.networks.getNetwork('foo').stsPolicies, {}) n = ircdb.IrcNetwork() - n.name = 'foo' self.networks.setNetwork('foo', n) - self.assertEqual(self.networks.getNetwork('foo').name, 'foo') + self.assertEqual(self.networks.getNetwork('foo').stsPolicies, {}) def testPreserveOne(self): - n = ircdb.IrcNetwork('foonet') + n = ircdb.IrcNetwork() n.addStsPolicy('foo', 'sts1') n.addStsPolicy('bar', 'sts2') n.addDisconnection('foo') @@ -496,15 +494,15 @@ class NetworksDictionaryTestCase(IrcdbTestCase): ]) def testPreserveThree(self): - n = ircdb.IrcNetwork('foonet') + n = ircdb.IrcNetwork() n.addStsPolicy('foo', 'sts1') self.networks.setNetwork('foonet', n) - n = ircdb.IrcNetwork('barnet') + n = ircdb.IrcNetwork() n.addStsPolicy('bar', 'sts2') self.networks.setNetwork('barnet', n) - n = ircdb.IrcNetwork('baznet') + n = ircdb.IrcNetwork() n.addStsPolicy('baz', 'sts3') self.networks.setNetwork('baznet', n) diff --git a/test/test_irclib.py b/test/test_irclib.py index 288c3aef4..cdc0ea62d 100644 --- a/test/test_irclib.py +++ b/test/test_irclib.py @@ -520,7 +520,7 @@ class StsTestCase(SupyTestCase): def testStsInSecureConnection(self): self.irc.driver.anyCertValidationEnabled.return_value = True self.irc.driver.ssl = True - self.irc.driver.server = drivers.Server('irc.test', 6697, False) + self.irc.driver.currentServer = drivers.Server('irc.test', 6697, False) self.irc.feedMsg(ircmsgs.IrcMsg(command='CAP', args=('*', 'LS', 'sts=duration=42,port=6697'))) @@ -531,50 +531,54 @@ class StsTestCase(SupyTestCase): def testStsInInsecureTlsConnection(self): self.irc.driver.anyCertValidationEnabled.return_value = False self.irc.driver.ssl = True - self.irc.driver.server = drivers.Server('irc.test', 6697, False) + self.irc.driver.currentServer = drivers.Server('irc.test', 6697, False) self.irc.feedMsg(ircmsgs.IrcMsg(command='CAP', args=('*', 'LS', 'sts=duration=42,port=6697'))) self.assertEqual(ircdb.networks.getNetwork('test').stsPolicies, {}) self.irc.driver.reconnect.assert_called_once_with( - server=drivers.Server('irc.test', 6697, True)) + server=drivers.Server('irc.test', 6697, True), + wait=True) def testStsInCleartextConnection(self): self.irc.driver.anyCertValidationEnabled.return_value = False self.irc.driver.ssl = True - self.irc.driver.server = drivers.Server('irc.test', 6667, False) + self.irc.driver.currentServer = drivers.Server('irc.test', 6667, False) self.irc.feedMsg(ircmsgs.IrcMsg(command='CAP', args=('*', 'LS', 'sts=duration=42,port=6697'))) self.assertEqual(ircdb.networks.getNetwork('test').stsPolicies, {}) self.irc.driver.reconnect.assert_called_once_with( - server=drivers.Server('irc.test', 6697, True)) + server=drivers.Server('irc.test', 6697, True), + wait=True) def testStsInCleartextConnectionInvalidDuration(self): # "Servers MAY send this key to all clients, but insecurely # connected clients MUST ignore it." self.irc.driver.anyCertValidationEnabled.return_value = False self.irc.driver.ssl = True - self.irc.driver.server = drivers.Server('irc.test', 6667, False) + self.irc.driver.currentServer = drivers.Server('irc.test', 6667, False) self.irc.feedMsg(ircmsgs.IrcMsg(command='CAP', args=('*', 'LS', 'sts=duration=foo,port=6697'))) self.assertEqual(ircdb.networks.getNetwork('test').stsPolicies, {}) self.irc.driver.reconnect.assert_called_once_with( - server=drivers.Server('irc.test', 6697, True)) + server=drivers.Server('irc.test', 6697, True), + wait=True) def testStsInCleartextConnectionNoDuration(self): # "Servers MAY send this key to all clients, but insecurely # connected clients MUST ignore it." self.irc.driver.anyCertValidationEnabled.return_value = False self.irc.driver.ssl = True - self.irc.driver.server = drivers.Server('irc.test', 6667, False) + self.irc.driver.currentServer = drivers.Server('irc.test', 6667, False) self.irc.feedMsg(ircmsgs.IrcMsg(command='CAP', args=('*', 'LS', 'sts=port=6697'))) self.assertEqual(ircdb.networks.getNetwork('test').stsPolicies, {}) self.irc.driver.reconnect.assert_called_once_with( - server=drivers.Server('irc.test', 6697, True)) + server=drivers.Server('irc.test', 6697, True), + wait=True) class IrcTestCase(SupyTestCase): def setUp(self): From f7130f2629d99a7e6382145c5cf1f988089093ac Mon Sep 17 00:00:00 2001 From: Valentin Lorentz Date: Fri, 1 May 2020 20:19:00 +0200 Subject: [PATCH 6/7] Add missing transition trigger on MOTD start. --- src/irclib.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/irclib.py b/src/irclib.py index db9a988bc..b556934f6 100644 --- a/src/irclib.py +++ b/src/irclib.py @@ -1682,6 +1682,10 @@ class Irc(IrcCommandDispatcher, log.Firewalled): """Handles PONG messages.""" self.outstandingPing = False + def do375(self, msg): + self.state.fsm.on_start_motd(self, msg) + log.info('Got start of MOTD from %s', self.server) + def do376(self, msg): self.state.fsm.on_end_motd() log.info('Got end of MOTD from %s', self.server) From 309fc1233b22cd714d247a64ad97bdbe2a3e8c39 Mon Sep 17 00:00:00 2001 From: Valentin Lorentz Date: Fri, 1 May 2020 20:19:53 +0200 Subject: [PATCH 7/7] Add postTransition method to IrcCallback, called when irc.state.fsm changes. --- src/irclib.py | 121 ++++++++++++++++++++++++++++++-------------------- 1 file changed, 74 insertions(+), 47 deletions(-) diff --git a/src/irclib.py b/src/irclib.py index b556934f6..36a244c4f 100644 --- a/src/irclib.py +++ b/src/irclib.py @@ -122,6 +122,7 @@ class IrcCallback(IrcCommandDispatcher, log.Firewalled): '__call__': None, 'inFilter': lambda self, irc, msg: msg, 'outFilter': lambda self, irc, msg: msg, + 'postTransition': None, 'name': lambda self: self.__class__.__name__, 'callPrecedence': lambda self, irc: ([], []), } @@ -172,6 +173,12 @@ class IrcCallback(IrcCommandDispatcher, log.Firewalled): """ return msg + def postTransition(self, irc, msg, from_state, to_state): + """Called when the state of the IRC connection changes. + + `msg` is the message that triggered the transition, if any.""" + pass + def __call__(self, irc, msg): """Used for handling each message.""" method = self.dispatchCommand(msg.command, msg.args) @@ -430,10 +437,24 @@ class IrcStateFsm(object): self.state, self.States.UNINITIALIZED) self.state = self.States.UNINITIALIZED - def _transition(self, to_state, expected_from=None): - if expected_from is None or self.state in expected_from: + def _transition(self, irc, msg, to_state, expected_from=None): + """Transitions to state `to_state`. + + If `expected_from` is not `None`, first checks the current state is + in the set. + + After the transition, calls the + `postTransition(irc, msg, from_state, to_state)` method of all objects + in `irc.callbacks`. + + `msg` may be None if the transition isn't triggered by a message, but + `irc` may not.""" + from_state = self.state + if expected_from is None or from_state in expected_from: log.debug('transition from %s to %s', self.state, to_state) self.state = to_state + for callback in reversed(irc.callbacks): + msg = callback.postTransition(irc, msg, from_state, to_state) else: raise ValueError('unexpected transition to %s while in state %s' % (to_state, self.state)) @@ -443,53 +464,53 @@ class IrcStateFsm(object): raise ValueError(('Connection in state %s, but expected to be ' 'in state %s') % (self.state, expected_states)) - def on_init_messages_sent(self): + def on_init_messages_sent(self, irc): '''As soon as USER/NICK/CAP LS are sent''' - self._transition(self.States.INIT_CAP_NEGOTIATION, [ + self._transition(irc, None, self.States.INIT_CAP_NEGOTIATION, [ self.States.UNINITIALIZED, ]) - def on_sasl_cap(self): + def on_sasl_cap(self, irc, msg): '''Whenever we see the 'sasl' capability in a CAP LS response''' if self.state == self.States.INIT_CAP_NEGOTIATION: - self._transition(self.States.INIT_SASL) + self._transition(irc, msg, self.States.INIT_SASL) elif self.state == self.States.CONNECTED: - self._transition(self.States.CONNECTED_SASL) + self._transition(irc, msg, self.States.CONNECTED_SASL) else: raise ValueError('Got sasl cap while in state %s' % self.state) - def on_sasl_auth_finished(self): + def on_sasl_auth_finished(self, irc, msg): '''When sasl auth either succeeded or failed.''' if self.state == self.States.INIT_SASL: - self._transition(self.States.INIT_CAP_NEGOTIATION) + self._transition(irc, msg, self.States.INIT_CAP_NEGOTIATION) elif self.state == self.States.CONNECTED_SASL: - self._transition(self.States.CONNECTED) + self._transition(irc, msg, self.States.CONNECTED) else: raise ValueError('Finished SASL auth while in state %s' % self.state) - def on_cap_end(self): + def on_cap_end(self, irc, msg): '''When we send CAP END''' - self._transition(self.States.INIT_WAITING_MOTD, [ + self._transition(irc, msg, self.States.INIT_WAITING_MOTD, [ self.States.INIT_CAP_NEGOTIATION, ]) - def on_start_motd(self): + def on_start_motd(self, irc, msg): '''On 375 (RPL_MOTDSTART)''' - self._transition(self.States.INIT_MOTD, [ + self._transition(irc, msg, self.States.INIT_MOTD, [ self.States.INIT_CAP_NEGOTIATION, self.States.INIT_WAITING_MOTD, ]) - def on_end_motd(self): + def on_end_motd(self, irc, msg): '''On 376 (RPL_ENDOFMOTD) or 422 (ERR_NOMOTD)''' - self._transition(self.States.CONNECTED, [ + self._transition(irc, msg, self.States.CONNECTED, [ self.States.INIT_CAP_NEGOTIATION, self.States.INIT_WAITING_MOTD, self.States.INIT_MOTD ]) - def on_shutdown(self): - self._transition(self.States.SHUTTING_DOWN) + def on_shutdown(self, irc, msg): + self._transition(irc, msg, self.States.SHUTTING_DOWN) class IrcState(IrcCommandDispatcher, log.Firewalled): """Maintains state of the Irc connection. Should also become smarter. @@ -1222,7 +1243,7 @@ class Irc(IrcCommandDispatcher, log.Firewalled): self.sendAuthenticationMessages() - self.state.fsm.on_init_messages_sent() + self.state.fsm.on_init_messages_sent(self) def sendAuthenticationMessages(self): # Notes: @@ -1244,7 +1265,13 @@ class Irc(IrcCommandDispatcher, log.Firewalled): self.sendMsg(ircmsgs.user(self.ident, self.user)) - def capUpkeep(self): + def capUpkeep(self, msg): + """ + Called after getting a CAP ACK/NAK to check it's consistent with what + was requested, and to end the cap negotiation when we received all the + ACK/NAKs we were waiting for. + + `msg` is the message that triggered this call.""" self.state.fsm.expect_state([ # Normal CAP ACK / CAP NAK during cap negotiation IrcStateFsm.States.INIT_CAP_NEGOTIATION, @@ -1268,18 +1295,18 @@ class Irc(IrcCommandDispatcher, log.Firewalled): if self.state.fsm.state in [ IrcStateFsm.States.INIT_CAP_NEGOTIATION, IrcStateFsm.States.CONNECTED]: - self._maybeStartSasl() + self._maybeStartSasl(msg) else: pass # Already in the middle of a SASL auth else: - self.endCapabilityNegociation() + self.endCapabilityNegociation(msg) else: log.debug('Waiting for ACK/NAK of capabilities: %r', self.state.capabilities_req - capabilities_responded) pass # Do nothing, we'll get more - def endCapabilityNegociation(self): - self.state.fsm.on_cap_end() + def endCapabilityNegociation(self, msg): + self.state.fsm.on_cap_end(self, msg) self.sendMsg(ircmsgs.IrcMsg(command='CAP', args=('END',))) def sendSaslString(self, string): @@ -1287,7 +1314,7 @@ class Irc(IrcCommandDispatcher, log.Firewalled): self.sendMsg(ircmsgs.IrcMsg(command='AUTHENTICATE', args=(chunk,))) - def tryNextSaslMechanism(self): + def tryNextSaslMechanism(self, msg): self.state.fsm.expect_state([ IrcStateFsm.States.INIT_SASL, IrcStateFsm.States.CONNECTED_SASL, @@ -1301,14 +1328,14 @@ class Irc(IrcCommandDispatcher, log.Firewalled): 'aborting connection.') else: self.sasl_current_mechanism = None - self.state.fsm.on_sasl_auth_finished() + self.state.fsm.on_sasl_auth_finished(self, msg) if self.state.fsm.state == IrcStateFsm.States.INIT_CAP_NEGOTIATION: - self.endCapabilityNegociation() + self.endCapabilityNegociation(msg) - def _maybeStartSasl(self): + def _maybeStartSasl(self, msg): if not self.sasl_authenticated and \ 'sasl' in self.state.capabilities_ack: - self.state.fsm.on_sasl_cap() + self.state.fsm.on_sasl_cap(self, msg) assert 'sasl' in self.state.capabilities_ls, ( 'Got "CAP ACK sasl" without receiving "CAP LS sasl" or ' '"CAP NEW sasl" first.') @@ -1318,7 +1345,7 @@ class Irc(IrcCommandDispatcher, log.Firewalled): self.sasl_next_mechanisms = [ x for x in self.sasl_next_mechanisms if x.lower() in available] - self.tryNextSaslMechanism() + self.tryNextSaslMechanism(msg) def doAuthenticate(self, msg): self.state.fsm.expect_state([ @@ -1426,28 +1453,28 @@ class Irc(IrcCommandDispatcher, log.Firewalled): def do903(self, msg): log.info('%s: SASL authentication successful', self.network) self.sasl_authenticated = True - self.state.fsm.on_sasl_auth_finished() + self.state.fsm.on_sasl_auth_finished(self, msg) if self.state.fsm.state == IrcStateFsm.States.INIT_CAP_NEGOTIATION: - self.endCapabilityNegociation() + self.endCapabilityNegociation(msg) def do904(self, msg): log.warning('%s: SASL authentication failed (mechanism: %s)', self.network, self.sasl_current_mechanism) - self.tryNextSaslMechanism() + self.tryNextSaslMechanism(msg) def do905(self, msg): log.warning('%s: SASL authentication failed because the username or ' 'password is too long.', self.network) - self.tryNextSaslMechanism() + self.tryNextSaslMechanism(msg) def do906(self, msg): log.warning('%s: SASL authentication aborted', self.network) - self.tryNextSaslMechanism() + self.tryNextSaslMechanism(msg) def do907(self, msg): log.warning('%s: Attempted SASL authentication when we were already ' 'authenticated.', self.network) - self.tryNextSaslMechanism() + self.tryNextSaslMechanism(msg) def do908(self, msg): log.info('%s: Supported SASL mechanisms: %s', @@ -1464,7 +1491,7 @@ class Irc(IrcCommandDispatcher, log.Firewalled): self.network, caps) self.state.capabilities_ack.update(caps) - self.capUpkeep() + self.capUpkeep(msg) def doCapNak(self, msg): if len(msg.args) != 3: @@ -1475,9 +1502,9 @@ class Irc(IrcCommandDispatcher, log.Firewalled): self.state.capabilities_nak.update(caps) log.warning('%s: Server refused capabilities: %L', self.network, caps) - self.capUpkeep() + self.capUpkeep(msg) - def _onCapSts(self, policy): + def _onCapSts(self, policy, msg): secure_connection = self.driver.currentServer.force_tls_verification \ or (self.driver.ssl and self.driver.anyCertValidationEnabled()) @@ -1504,19 +1531,19 @@ class Irc(IrcCommandDispatcher, log.Firewalled): self.driver.currentServer) # Reconnect to the server, but with TLS *and* certificate # validation this time. - self.state.fsm.on_shutdown() + self.state.fsm.on_shutdown(self, msg) self.driver.reconnect( server=Server(hostname, parsed_policy['port'], True), wait=True) - def _addCapabilities(self, capstring): + def _addCapabilities(self, capstring, msg): for item in capstring.split(): while item.startswith(('=', '~')): item = item[1:] if '=' in item: (cap, value) = item.split('=', 1) if cap == 'sts': - self._onCapSts(value) + self._onCapSts(value, msg) self.state.capabilities_ls[cap] = value else: if item == 'sts': @@ -1532,9 +1559,9 @@ class Irc(IrcCommandDispatcher, log.Firewalled): if msg.args[2] != '*': log.warning('Bad CAP LS from server: %r', msg) return - self._addCapabilities(msg.args[3]) + self._addCapabilities(msg.args[3], msg) elif len(msg.args) == 3: # End of LS - self._addCapabilities(msg.args[2]) + self._addCapabilities(msg.args[2], msg) if self.state.fsm.state == IrcStateFsm.States.SHUTTING_DOWN: return self.state.fsm.expect_state([ @@ -1558,7 +1585,7 @@ class Irc(IrcCommandDispatcher, log.Firewalled): if new_caps: self._requestCaps(new_caps) else: - self.endCapabilityNegociation() + self.endCapabilityNegociation(msg) else: log.warning('Bad CAP LS from server: %r', msg) return @@ -1590,7 +1617,7 @@ class Irc(IrcCommandDispatcher, log.Firewalled): return caps = msg.args[2].split() assert caps, 'Empty list of capabilities' - self._addCapabilities(msg.args[2]) + self._addCapabilities(msg.args[2], msg) if self.state.fsm.state == IrcStateFsm.States.SHUTTING_DOWN: return common_supported_unrequested_capabilities = ( @@ -1687,7 +1714,7 @@ class Irc(IrcCommandDispatcher, log.Firewalled): log.info('Got start of MOTD from %s', self.server) def do376(self, msg): - self.state.fsm.on_end_motd() + self.state.fsm.on_end_motd(self, msg) log.info('Got end of MOTD from %s', self.server) self.afterConnect = True # Let's reset nicks in case we had to use a weird one.