diff --git a/src/drivers/Socket.py b/src/drivers/Socket.py index ff473c768..1770f2e95 100644 --- a/src/drivers/Socket.py +++ b/src/drivers/Socket.py @@ -225,6 +225,7 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin): self._attempt += 1 self.nextReconnectTime = None if self.connected: + self.onDisconnect() drivers.log.reconnect(self.irc.network) if self in self._instances: self._instances.remove(self) @@ -242,7 +243,7 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin): if wait: self.scheduleReconnect() return - self.server = server or self._getNextServer() + self.currentServer = server or self._getNextServer() network_config = getattr(conf.supybot.networks, self.irc.network) socks_proxy = network_config.socksproxy() try: @@ -255,7 +256,7 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin): else: try: hostname = utils.net.getAddressFromHostname( - self.server.hostname, + self.currentServer.hostname, attempt=self._attempt) except (socket.gaierror, socket.error) as e: drivers.log.connectError(self.currentServer, e) @@ -264,8 +265,8 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin): drivers.log.connect(self.currentServer) try: self.conn = utils.net.getSocket( - self.server.hostname, - port=self.server.port, + self.currentServer.hostname, + port=self.currentServer.port, socks_proxy=socks_proxy, vhost=conf.supybot.protocols.irc.vhost(), vhostv6=conf.supybot.protocols.irc.vhostv6(), @@ -280,8 +281,10 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin): try: # Connect before SSL, otherwise SSL is disabled if we use SOCKS. # See http://stackoverflow.com/q/16136916/539465 - self.conn.connect((self.server.hostname, self.server.port)) - if network_config.ssl() or self.server.force_tls_verification: + self.conn.connect( + (self.currentServer.hostname, self.currentServer.port)) + if network_config.ssl() \ + or self.currentServer.force_tls_verification: self.starttls() # Suppress this warning for loopback IPs. @@ -351,6 +354,8 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin): if self.writeCheckTime is not None: self.writeCheckTime = None drivers.log.die(self.irc) + drivers.IrcDriver.die(self) + drivers.ServersMixin.die(self) def _reallyDie(self): if self.conn is not None: @@ -382,7 +387,7 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin): drivers.log.warning('Could not find cert file %s.' % certfile) certfile = None - if self.server.force_tls_verification \ + if self.currentServer.force_tls_verification \ and not self.anyCertValidationEnabled(): verifyCertificates = True else: @@ -395,7 +400,7 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin): try: self.conn = utils.net.ssl_wrap_socket(self.conn, logger=drivers.log, - hostname=self.server.hostname, + hostname=self.currentServer.hostname, certfile=certfile, verify=verifyCertificates, trusted_fingerprints=network_config.ssl.serverFingerprints(), diff --git a/src/drivers/__init__.py b/src/drivers/__init__.py index f1de62c17..969454b00 100644 --- a/src/drivers/__init__.py +++ b/src/drivers/__init__.py @@ -32,10 +32,11 @@ Contains various drivers (network, file, and otherwise) for using IRC objects. """ +import time import socket from collections import namedtuple -from .. import conf, ircmsgs, log as supylog, utils +from .. import conf, ircdb, ircmsgs, ircutils, log as supylog, utils from ..utils import minisix @@ -73,6 +74,7 @@ class IrcDriver(object): class ServersMixin(object): def __init__(self, irc, servers=()): + self.networkName = irc.network self.networkGroup = conf.supybot.networks.get(irc.network) self.servers = servers super(ServersMixin, self).__init__() @@ -89,8 +91,36 @@ class ServersMixin(object): assert self.servers, 'Servers value for %s is empty.' % \ self.networkGroup._name server = self.servers.pop(0) - self.currentServer = '%s:%s' % (server.hostname, server.port) - return server + self.currentServer = self._applyStsPolicy(server) + return self.currentServer + + def _applyStsPolicy(self, server): + network = ircdb.networks.getNetwork(self.networkName) + policy = network.stsPolicies.get(server.hostname) + lastDisconnect = network.lastDisconnectTimes.get(server.hostname) + + if policy is None or lastDisconnect is None: + return server + + # The policy was stored, which means it was received on a secure + # connection. + policy = ircutils.parseStsPolicy(log, policy, parseDuration=True) + + if lastDisconnect + policy['duration'] < time.time(): + network.expireStsPolicy(server.hostname) + return server + + # Change the port, and force TLS verification, as required by the STS + # specification. + return Server(server.hostname, policy['port'], + force_tls_verification=True) + + def die(self): + self.onDisconnect() + + def onDisconnect(self): + network = ircdb.networks.getNetwork(self.networkName) + network.addDisconnection(self.currentServer.hostname) def empty(): @@ -138,7 +168,8 @@ def run(): class Log(object): """This is used to have a nice, consistent interface for drivers to use.""" def connect(self, server): - self.info('Connecting to %s.', server) + self.info('Connecting to %s:%s.', + server.hostname, server.port) def connectError(self, server, e): if isinstance(e, Exception): @@ -146,7 +177,8 @@ class Log(object): e = e.args[1] else: e = utils.exnToString(e) - self.warning('Error connecting to %s: %s', server, e) + self.warning('Error connecting to %s:%s: %s', + server.hostname, server.port, e) def disconnect(self, server, e=None): if e: @@ -156,7 +188,8 @@ class Log(object): e = str(e) if not e.endswith('.'): e += '.' - self.warning('Disconnect from %s: %s', server, e) + self.warning('Disconnect from %s:%s: %s', + server.hostname, server.port, e) else: self.info('Disconnect from %s.', server) diff --git a/src/ircdb.py b/src/ircdb.py index 0f23ffc15..462fe57b9 100644 --- a/src/ircdb.py +++ b/src/ircdb.py @@ -515,6 +515,10 @@ class IrcNetwork(object): assert isinstance(stsPolicy, str) self.stsPolicies[server] = stsPolicy + def expireStsPolicy(self, server): + if server in self.stsPolicies: + del self.stsPolicies[server] + def addDisconnection(self, server): self.lastDisconnectTimes[server] = int(time.time()) @@ -674,8 +678,8 @@ class IrcNetworkCreator(Creator): def finish(self): if self.net.name: - self.networks.setNetwork(self.net) - self.name = None + self.networks.setNetwork(self.net.name, self.net) + self.net = IrcNetwork() class DuplicateHostmask(ValueError): diff --git a/src/irclib.py b/src/irclib.py index 4686ae9ff..ff19d3719 100644 --- a/src/irclib.py +++ b/src/irclib.py @@ -1470,19 +1470,16 @@ class Irc(IrcCommandDispatcher, log.Firewalled): self.capUpkeep() def _onCapSts(self, policy): - parsed_policy = ircutils._parseStsPolicy(log, policy) + secure_connection = self.driver.ssl and self.driver.anyCertValidationEnabled() + + parsed_policy = ircutils.parseStsPolicy( + log, policy, parseDuration=secure_connection) if parsed_policy is None: # There was an error (and it was logged). Abort the connection. self.driver.reconnect() return - if not self.driver.ssl or not self.driver.anyCertValidationEnabled(): - hostname = self.driver.server.hostname - # Reconnect to the server, but with TLS *and* certificate - # validation this time. - self.driver.reconnect( - server=Server(hostname, parsed_policy['port'], True)) - else: + if secure_connection: # TLS is enabled and certificate is verified; write the STS policy # in stone. # For future-proofing (because we don't want to write an invalid @@ -1490,6 +1487,12 @@ class Irc(IrcCommandDispatcher, log.Firewalled): # of the parsed one. ircdb.networks.getNetwork(self.network).addStsPolicy( self.driver.server.hostname, policy) + else: + hostname = self.driver.server.hostname + # Reconnect to the server, but with TLS *and* certificate + # validation this time. + self.driver.reconnect( + server=Server(hostname, parsed_policy['port'], True)) def _addCapabilities(self, capstring): for item in capstring.split(): diff --git a/src/ircutils.py b/src/ircutils.py index 2c9556943..2d56e469a 100644 --- a/src/ircutils.py +++ b/src/ircutils.py @@ -931,7 +931,7 @@ class AuthenticateDecoder(object): return base64.b64decode(b''.join(self.chunks)) -def _parseStsPolicy(logger, policy): +def parseStsPolicy(logger, policy, parseDuration): parsed_policy = {} for kv in policy.split(','): if '=' in kv: @@ -941,6 +941,10 @@ def _parseStsPolicy(logger, policy): parsed_policy[kv] = None for key in ('port', 'duration'): + if key == 'duration' and not parseDuration: + if key in parsed_policy: + del parsed_policy[key] + continue if parsed_policy.get(key) is None: logger.error('Missing or empty "%s" key in STS policy.' 'Aborting connection.', key) diff --git a/src/world.py b/src/world.py index eec3b3e65..b00d112d9 100644 --- a/src/world.py +++ b/src/world.py @@ -42,7 +42,7 @@ import multiprocessing import re -from . import conf, drivers, ircutils, log, registry +from . import conf, ircutils, log, registry from .utils import minisix startedAt = time.time() # Just in case it doesn't get set later. @@ -193,6 +193,7 @@ def upkeep(): def makeDriversDie(): """Kills drivers.""" + from . import drivers log.info('Killing Driver objects.') for driver in drivers._drivers.values(): driver.die() diff --git a/test/test_drivers.py b/test/test_drivers.py new file mode 100644 index 000000000..cc7b4fafd --- /dev/null +++ b/test/test_drivers.py @@ -0,0 +1,108 @@ +## +# Copyright (c) 2019, Valentin Lorentz +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions, and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions, and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of the author of this software nor the name of +# contributors to this software may be used to endorse or promote products +# derived from this software without specific prior written consent. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +# POSSIBILITY OF SUCH DAMAGE. +### + +from supybot.test import * +import supybot.ircdb as ircdb +import supybot.irclib as irclib +import supybot.drivers as drivers + +class DriversTestCase(SupyTestCase): + def tearDown(self): + ircdb.networks.networks = {} + + def testValidStsPolicy(self): + irc = irclib.Irc('test') + net = ircdb.networks.getNetwork('test') + net.addStsPolicy('example.com', 'duration=10,port=6697') + net.addDisconnection('example.com') + + with conf.supybot.networks.test.servers.context( + ['example.com:6667', 'example.org:6667']): + + driver = drivers.ServersMixin(irc) + + self.assertEqual( + driver._getNextServer(), + drivers.Server('example.com', 6697, True)) + driver.die() + + self.assertEqual( + driver._getNextServer(), + drivers.Server('example.org', 6667, False)) + driver.die() + + self.assertEqual( + driver._getNextServer(), + drivers.Server('example.com', 6697, True)) + + def testExpiredStsPolicy(self): + irc = irclib.Irc('test') + net = ircdb.networks.getNetwork('test') + net.addStsPolicy('example.com', 'duration=10,port=6697') + net.addDisconnection('example.com') + + timeFastForward(16) + + with conf.supybot.networks.test.servers.context( + ['example.com:6667']): + + driver = drivers.ServersMixin(irc) + + self.assertEqual( + driver._getNextServer(), + drivers.Server('example.com', 6667, False)) + + def testRescheduledStsPolicy(self): + irc = irclib.Irc('test') + net = ircdb.networks.getNetwork('test') + net.addStsPolicy('example.com', 'duration=10,port=6697') + net.addDisconnection('example.com') + + with conf.supybot.networks.test.servers.context( + ['example.com:6667', 'example.org:6667']): + + driver = drivers.ServersMixin(irc) + + timeFastForward(8) + + self.assertEqual( + driver._getNextServer(), + drivers.Server('example.com', 6697, True)) + driver.die() + + self.assertEqual( + driver._getNextServer(), + drivers.Server('example.org', 6667, False)) + driver.die() + + timeFastForward(8) + + self.assertEqual( + driver._getNextServer(), + drivers.Server('example.com', 6697, True)) diff --git a/test/test_irclib.py b/test/test_irclib.py index 569750a06..288c3aef4 100644 --- a/test/test_irclib.py +++ b/test/test_irclib.py @@ -550,6 +550,32 @@ class StsTestCase(SupyTestCase): self.irc.driver.reconnect.assert_called_once_with( server=drivers.Server('irc.test', 6697, True)) + def testStsInCleartextConnectionInvalidDuration(self): + # "Servers MAY send this key to all clients, but insecurely + # connected clients MUST ignore it." + self.irc.driver.anyCertValidationEnabled.return_value = False + self.irc.driver.ssl = True + self.irc.driver.server = drivers.Server('irc.test', 6667, False) + self.irc.feedMsg(ircmsgs.IrcMsg(command='CAP', + args=('*', 'LS', 'sts=duration=foo,port=6697'))) + + self.assertEqual(ircdb.networks.getNetwork('test').stsPolicies, {}) + self.irc.driver.reconnect.assert_called_once_with( + server=drivers.Server('irc.test', 6697, True)) + + def testStsInCleartextConnectionNoDuration(self): + # "Servers MAY send this key to all clients, but insecurely + # connected clients MUST ignore it." + self.irc.driver.anyCertValidationEnabled.return_value = False + self.irc.driver.ssl = True + self.irc.driver.server = drivers.Server('irc.test', 6667, False) + self.irc.feedMsg(ircmsgs.IrcMsg(command='CAP', + args=('*', 'LS', 'sts=port=6697'))) + + self.assertEqual(ircdb.networks.getNetwork('test').stsPolicies, {}) + self.irc.driver.reconnect.assert_called_once_with( + server=drivers.Server('irc.test', 6697, True)) + class IrcTestCase(SupyTestCase): def setUp(self): self.irc = irclib.Irc('test')