From 22120ee862f0728170998b93598dc2d7f7c2c165 Mon Sep 17 00:00:00 2001 From: Valentin Lorentz Date: Sun, 8 Dec 2019 21:25:59 +0100 Subject: [PATCH] Fix various issues with STS handling. --- src/drivers/Socket.py | 12 +++++++++--- src/drivers/__init__.py | 6 ++++++ src/ircdb.py | 16 ++++++++-------- src/irclib.py | 34 +++++++++++++++++++++++++++------- test/test_ircdb.py | 16 +++++++--------- test/test_irclib.py | 22 +++++++++++++--------- 6 files changed, 70 insertions(+), 36 deletions(-) diff --git a/src/drivers/Socket.py b/src/drivers/Socket.py index 1770f2e95..0017c6d20 100644 --- a/src/drivers/Socket.py +++ b/src/drivers/Socket.py @@ -241,6 +241,9 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin): else: drivers.log.debug('Not resetting %s.', self.irc) if wait: + if server is not None: + # Make this server be the next one to be used. + self.servers.insert(0, server) self.scheduleReconnect() return self.currentServer = server or self._getNextServer() @@ -283,8 +286,8 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin): # See http://stackoverflow.com/q/16136916/539465 self.conn.connect( (self.currentServer.hostname, self.currentServer.port)) - if network_config.ssl() \ - or self.currentServer.force_tls_verification: + if network_config.ssl() or \ + self.currentServer.force_tls_verification: self.starttls() # Suppress this warning for loopback IPs. @@ -294,6 +297,7 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin): targetip = targetip.decode('utf-8') elif (not network_config.requireStarttls()) and \ (not network_config.ssl()) and \ + (not self.currentServer.force_tls_verification) and \ (ipaddress is None or not ipaddress.ip_address(targetip).is_loopback): drivers.log.warning(('Connection to network %s ' 'does not use SSL/TLS, which makes it vulnerable to ' @@ -369,6 +373,7 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin): def anyCertValidationEnabled(self): """Returns whether any kind of certificate validation is enabled, other than Server.force_tls_verification.""" + network_config = getattr(conf.supybot.networks, self.irc.network) return any([ conf.supybot.protocols.ssl.verifyCertificates(), network_config.ssl.serverFingerprints(), @@ -392,7 +397,8 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin): verifyCertificates = True else: verifyCertificates = conf.supybot.protocols.ssl.verifyCertificates() - if not verifyCertificates: + if not self.currentServer.force_tls_verification \ + and not self.anyCertValidationEnabled(): drivers.log.warning('Not checking SSL certificates, connections ' 'are vulnerable to man-in-the-middle attacks. Set ' 'supybot.protocols.ssl.verifyCertificates to "true" ' diff --git a/src/drivers/__init__.py b/src/drivers/__init__.py index 969454b00..fcc14cc06 100644 --- a/src/drivers/__init__.py +++ b/src/drivers/__init__.py @@ -100,6 +100,8 @@ class ServersMixin(object): lastDisconnect = network.lastDisconnectTimes.get(server.hostname) if policy is None or lastDisconnect is None: + log.debug('No STS policy, or never disconnected from this server. %r %r', + policy, lastDisconnect) return server # The policy was stored, which means it was received on a secure @@ -107,9 +109,13 @@ class ServersMixin(object): policy = ircutils.parseStsPolicy(log, policy, parseDuration=True) if lastDisconnect + policy['duration'] < time.time(): + log.info('STS policy expired, removing.') network.expireStsPolicy(server.hostname) return server + log.info('Using STS policy: changing port from %s to %s.', + server.port, policy['port']) + # Change the port, and force TLS verification, as required by the STS # specification. return Server(server.hostname, policy['port'], diff --git a/src/ircdb.py b/src/ircdb.py index 462fe57b9..6c965a2cb 100644 --- a/src/ircdb.py +++ b/src/ircdb.py @@ -499,16 +499,15 @@ class IrcChannel(object): class IrcNetwork(object): """This class holds dynamic information about a network that should be preserved across restarts.""" - __slots__ = ('name', 'stsPolicies', 'lastDisconnectTimes') + __slots__ = ('stsPolicies', 'lastDisconnectTimes') - def __init__(self, name=None, stsPolicies=None, lastDisconnectTimes=None): - self.name = name + def __init__(self, stsPolicies=None, lastDisconnectTimes=None): self.stsPolicies = stsPolicies or {} self.lastDisconnectTimes = lastDisconnectTimes or {} def __repr__(self): - return '%s(name=%r, stsPolicy=%r, lastDisconnectTimes=%s)' % \ - (self.__class__.__name, self.name, self.stsPolicy, + return '%s(stsPolicies=%r, lastDisconnectTimes=%s)' % \ + (self.__class__.__name__, self.stsPolicies, self.lastDisconnectTimes) def addStsPolicy(self, server, stsPolicy): @@ -659,13 +658,14 @@ class IrcChannelCreator(Creator): class IrcNetworkCreator(Creator): __slots__ = ('net', 'networks') + name = None def __init__(self, networks): self.net = IrcNetwork() self.networks = networks def network(self, rest, lineno): - self.net.name = rest + IrcNetworkCreator.name = rest def stspolicy(self, rest, lineno): (server, stsPolicy) = rest.split() @@ -677,8 +677,8 @@ class IrcNetworkCreator(Creator): self.net.lastDisconnectTimes[server] = when def finish(self): - if self.net.name: - self.networks.setNetwork(self.net.name, self.net) + if self.name: + self.networks.setNetwork(self.name, self.net) self.net = IrcNetwork() diff --git a/src/irclib.py b/src/irclib.py index ff19d3719..db9a988bc 100644 --- a/src/irclib.py +++ b/src/irclib.py @@ -419,10 +419,15 @@ class IrcStateFsm(object): CONNECTED_SASL = 80 '''Doing SASL authentication in the middle of a connection.''' + SHUTTING_DOWN = 100 + def __init__(self): self.reset() def reset(self): + if getattr(self, 'state', None) is not None: + log.debug('resetting from %s to %s', + self.state, self.States.UNINITIALIZED) self.state = self.States.UNINITIALIZED def _transition(self, to_state, expected_from=None): @@ -483,6 +488,9 @@ class IrcStateFsm(object): self.States.INIT_MOTD ]) + def on_shutdown(self): + self._transition(self.States.SHUTTING_DOWN) + class IrcState(IrcCommandDispatcher, log.Firewalled): """Maintains state of the Irc connection. Should also become smarter. """ @@ -1252,7 +1260,7 @@ class Irc(IrcCommandDispatcher, log.Firewalled): self.state.capabilities_req, self.state.capabilities_ack, self.state.capabilities_nak) - self.driver.reconnect() + self.driver.reconnect(wait=True) elif capabilities_responded == self.state.capabilities_req: log.debug('Got all capabilities ACKed/NAKed') # We got all the capabilities we asked for @@ -1470,13 +1478,14 @@ class Irc(IrcCommandDispatcher, log.Firewalled): self.capUpkeep() def _onCapSts(self, policy): - secure_connection = self.driver.ssl and self.driver.anyCertValidationEnabled() + secure_connection = self.driver.currentServer.force_tls_verification \ + or (self.driver.ssl and self.driver.anyCertValidationEnabled()) parsed_policy = ircutils.parseStsPolicy( log, policy, parseDuration=secure_connection) if parsed_policy is None: # There was an error (and it was logged). Abort the connection. - self.driver.reconnect() + self.driver.reconnect(wait=True) return if secure_connection: @@ -1485,14 +1494,20 @@ class Irc(IrcCommandDispatcher, log.Firewalled): # For future-proofing (because we don't want to write an invalid # value), we write the raw policy received from the server instead # of the parsed one. + log.debug('Storing STS policy: %s', policy) ircdb.networks.getNetwork(self.network).addStsPolicy( - self.driver.server.hostname, policy) + self.driver.currentServer.hostname, policy) else: - hostname = self.driver.server.hostname + hostname = self.driver.currentServer.hostname + log.info('Got STS policy over insecure connection; ' + 'reconnecting to secure port. %r', + self.driver.currentServer) # Reconnect to the server, but with TLS *and* certificate # validation this time. + self.state.fsm.on_shutdown() self.driver.reconnect( - server=Server(hostname, parsed_policy['port'], True)) + server=Server(hostname, parsed_policy['port'], True), + wait=True) def _addCapabilities(self, capstring): for item in capstring.split(): @@ -1507,9 +1522,10 @@ class Irc(IrcCommandDispatcher, log.Firewalled): if item == 'sts': log.error('Got "sts" capability without value. Aborting ' 'connection.') - self.driver.reconnect() + self.driver.reconnect(wait=True) self.state.capabilities_ls[item] = None + def doCapLs(self, msg): if len(msg.args) == 4: # Multi-line LS @@ -1519,6 +1535,8 @@ class Irc(IrcCommandDispatcher, log.Firewalled): self._addCapabilities(msg.args[3]) elif len(msg.args) == 3: # End of LS self._addCapabilities(msg.args[2]) + if self.state.fsm.state == IrcStateFsm.States.SHUTTING_DOWN: + return self.state.fsm.expect_state([ # Normal case: IrcStateFsm.States.INIT_CAP_NEGOTIATION, @@ -1573,6 +1591,8 @@ class Irc(IrcCommandDispatcher, log.Firewalled): caps = msg.args[2].split() assert caps, 'Empty list of capabilities' self._addCapabilities(msg.args[2]) + if self.state.fsm.state == IrcStateFsm.States.SHUTTING_DOWN: + return common_supported_unrequested_capabilities = ( set(self.state.capabilities_ls) & self.REQUEST_CAPABILITIES - diff --git a/test/test_ircdb.py b/test/test_ircdb.py index d8f716618..add6b6064 100644 --- a/test/test_ircdb.py +++ b/test/test_ircdb.py @@ -352,7 +352,6 @@ class IrcChannelTestCase(IrcdbTestCase): class IrcNetworkTestCase(IrcdbTestCase): def testDefaults(self): n = ircdb.IrcNetwork() - self.assertIsNone(n.name) self.assertEqual(n.stsPolicies, {}) self.assertEqual(n.lastDisconnectTimes, {}) @@ -373,7 +372,7 @@ class IrcNetworkTestCase(IrcdbTestCase): self.assertTrue(min_ <= n.lastDisconnectTimes['foo'] <= max_) def testPreserve(self): - n = ircdb.IrcNetwork('foonet') + n = ircdb.IrcNetwork() n.addStsPolicy('foo', 'sts1') n.addStsPolicy('bar', 'sts2') n.addDisconnection('foo') @@ -459,15 +458,14 @@ class NetworksDictionaryTestCase(IrcdbTestCase): IrcdbTestCase.setUp(self) def testGetSetNetwork(self): - self.assertEqual(self.networks.getNetwork('foo').name, None) + self.assertEqual(self.networks.getNetwork('foo').stsPolicies, {}) n = ircdb.IrcNetwork() - n.name = 'foo' self.networks.setNetwork('foo', n) - self.assertEqual(self.networks.getNetwork('foo').name, 'foo') + self.assertEqual(self.networks.getNetwork('foo').stsPolicies, {}) def testPreserveOne(self): - n = ircdb.IrcNetwork('foonet') + n = ircdb.IrcNetwork() n.addStsPolicy('foo', 'sts1') n.addStsPolicy('bar', 'sts2') n.addDisconnection('foo') @@ -496,15 +494,15 @@ class NetworksDictionaryTestCase(IrcdbTestCase): ]) def testPreserveThree(self): - n = ircdb.IrcNetwork('foonet') + n = ircdb.IrcNetwork() n.addStsPolicy('foo', 'sts1') self.networks.setNetwork('foonet', n) - n = ircdb.IrcNetwork('barnet') + n = ircdb.IrcNetwork() n.addStsPolicy('bar', 'sts2') self.networks.setNetwork('barnet', n) - n = ircdb.IrcNetwork('baznet') + n = ircdb.IrcNetwork() n.addStsPolicy('baz', 'sts3') self.networks.setNetwork('baznet', n) diff --git a/test/test_irclib.py b/test/test_irclib.py index 288c3aef4..cdc0ea62d 100644 --- a/test/test_irclib.py +++ b/test/test_irclib.py @@ -520,7 +520,7 @@ class StsTestCase(SupyTestCase): def testStsInSecureConnection(self): self.irc.driver.anyCertValidationEnabled.return_value = True self.irc.driver.ssl = True - self.irc.driver.server = drivers.Server('irc.test', 6697, False) + self.irc.driver.currentServer = drivers.Server('irc.test', 6697, False) self.irc.feedMsg(ircmsgs.IrcMsg(command='CAP', args=('*', 'LS', 'sts=duration=42,port=6697'))) @@ -531,50 +531,54 @@ class StsTestCase(SupyTestCase): def testStsInInsecureTlsConnection(self): self.irc.driver.anyCertValidationEnabled.return_value = False self.irc.driver.ssl = True - self.irc.driver.server = drivers.Server('irc.test', 6697, False) + self.irc.driver.currentServer = drivers.Server('irc.test', 6697, False) self.irc.feedMsg(ircmsgs.IrcMsg(command='CAP', args=('*', 'LS', 'sts=duration=42,port=6697'))) self.assertEqual(ircdb.networks.getNetwork('test').stsPolicies, {}) self.irc.driver.reconnect.assert_called_once_with( - server=drivers.Server('irc.test', 6697, True)) + server=drivers.Server('irc.test', 6697, True), + wait=True) def testStsInCleartextConnection(self): self.irc.driver.anyCertValidationEnabled.return_value = False self.irc.driver.ssl = True - self.irc.driver.server = drivers.Server('irc.test', 6667, False) + self.irc.driver.currentServer = drivers.Server('irc.test', 6667, False) self.irc.feedMsg(ircmsgs.IrcMsg(command='CAP', args=('*', 'LS', 'sts=duration=42,port=6697'))) self.assertEqual(ircdb.networks.getNetwork('test').stsPolicies, {}) self.irc.driver.reconnect.assert_called_once_with( - server=drivers.Server('irc.test', 6697, True)) + server=drivers.Server('irc.test', 6697, True), + wait=True) def testStsInCleartextConnectionInvalidDuration(self): # "Servers MAY send this key to all clients, but insecurely # connected clients MUST ignore it." self.irc.driver.anyCertValidationEnabled.return_value = False self.irc.driver.ssl = True - self.irc.driver.server = drivers.Server('irc.test', 6667, False) + self.irc.driver.currentServer = drivers.Server('irc.test', 6667, False) self.irc.feedMsg(ircmsgs.IrcMsg(command='CAP', args=('*', 'LS', 'sts=duration=foo,port=6697'))) self.assertEqual(ircdb.networks.getNetwork('test').stsPolicies, {}) self.irc.driver.reconnect.assert_called_once_with( - server=drivers.Server('irc.test', 6697, True)) + server=drivers.Server('irc.test', 6697, True), + wait=True) def testStsInCleartextConnectionNoDuration(self): # "Servers MAY send this key to all clients, but insecurely # connected clients MUST ignore it." self.irc.driver.anyCertValidationEnabled.return_value = False self.irc.driver.ssl = True - self.irc.driver.server = drivers.Server('irc.test', 6667, False) + self.irc.driver.currentServer = drivers.Server('irc.test', 6667, False) self.irc.feedMsg(ircmsgs.IrcMsg(command='CAP', args=('*', 'LS', 'sts=port=6697'))) self.assertEqual(ircdb.networks.getNetwork('test').stsPolicies, {}) self.irc.driver.reconnect.assert_called_once_with( - server=drivers.Server('irc.test', 6697, True)) + server=drivers.Server('irc.test', 6697, True), + wait=True) class IrcTestCase(SupyTestCase): def setUp(self):