diff --git a/src/drivers/__init__.py b/src/drivers/__init__.py index 83e7ec7b1..54eebb2e2 100644 --- a/src/drivers/__init__.py +++ b/src/drivers/__init__.py @@ -97,7 +97,8 @@ class ServersMixin(object): def _applyStsPolicy(self, server): network = ircdb.networks.getNetwork(self.networkName) - policy = network.stsPolicies.get(server.hostname) + (policy_port, policy) = network.stsPolicies.get( + server.hostname, (None, None)) lastDisconnect = network.lastDisconnectTimes.get(server.hostname) if policy is None or lastDisconnect is None: @@ -107,22 +108,22 @@ class ServersMixin(object): # The policy was stored, which means it was received on a secure # connection. - policy = ircutils.parseStsPolicy(log, policy, parseDuration=True) + policy = ircutils.parseStsPolicy(log, policy, secure_connection=True) if lastDisconnect + policy['duration'] < time.time(): log.info('STS policy expired, removing.') network.expireStsPolicy(server.hostname) return server - if server.port == policy['port']: + if server.port == policy_port: log.info('Using STS policy, port %s', server.port) else: log.info('Using STS policy: changing port from %s to %s.', - server.port, policy['port']) + server.port, policy_port) # Change the port, and force TLS verification, as required by the STS # specification. - return Server(server.hostname, policy['port'], server.attempt, + return Server(server.hostname, policy_port, server.attempt, force_tls_verification=True) def die(self): diff --git a/src/ircdb.py b/src/ircdb.py index 2c88e5b83..895318962 100644 --- a/src/ircdb.py +++ b/src/ircdb.py @@ -509,9 +509,10 @@ class IrcNetwork(object): (self.__class__.__name__, self.stsPolicies, self.lastDisconnectTimes) - def addStsPolicy(self, server, stsPolicy): - assert isinstance(stsPolicy, str) - self.stsPolicies[server] = stsPolicy + def addStsPolicy(self, server, port, stsPolicy): + assert isinstance(port, int), repr(port) + assert isinstance(stsPolicy, str), repr(stsPolicy) + self.stsPolicies[server] = (port, stsPolicy) def expireStsPolicy(self, server): if server in self.stsPolicies: @@ -526,8 +527,10 @@ class IrcNetwork(object): fd.write(s) fd.write(os.linesep) - for (server, stsPolicy) in sorted(self.stsPolicies.items()): - write('stsPolicy %s %s' % (server, stsPolicy)) + for (server, (port, stsPolicy)) in sorted(self.stsPolicies.items()): + assert isinstance(port, int), repr(port) + assert isinstance(stsPolicy, str), repr(stsPolicy) + write('stsPolicy %s %s %s' % (server, port, stsPolicy)) for (server, disconnectTime) in \ sorted(self.lastDisconnectTimes.items()): @@ -667,8 +670,12 @@ class IrcNetworkCreator(Creator): IrcNetworkCreator.name = rest def stspolicy(self, rest, lineno): - (server, stsPolicy) = rest.split() - self.net.addStsPolicy(server, stsPolicy) + L = rest.split() + if len(L) == 2: + # Old policy missing a port. Discard it + return + (server, policyPort, stsPolicy) = L + self.net.addStsPolicy(server, int(policyPort), stsPolicy) def lastdisconnecttime(self, rest, lineno): (server, when) = rest.split() diff --git a/src/irclib.py b/src/irclib.py index b40f8af6b..846741745 100644 --- a/src/irclib.py +++ b/src/irclib.py @@ -2050,7 +2050,7 @@ class Irc(IrcCommandDispatcher, log.Firewalled): or (self.driver.ssl and self.driver.anyCertValidationEnabled()) parsed_policy = ircutils.parseStsPolicy( - log, policy, parseDuration=secure_connection) + log, policy, secure_connection=secure_connection) if parsed_policy is None: # There was an error (and it was logged). Ignore it and proceed # with the connection. @@ -2065,9 +2065,14 @@ class Irc(IrcCommandDispatcher, log.Firewalled): # For future-proofing (because we don't want to write an invalid # value), we write the raw policy received from the server instead # of the parsed one. - log.debug('Storing STS policy: %s', policy) + log.debug('Storing STS policy for %s (TLS port %s): %s', + self.driver.currentServer.hostname, + self.driver.currentServer.port, + policy) ircdb.networks.getNetwork(self.network).addStsPolicy( - self.driver.currentServer.hostname, policy) + self.driver.currentServer.hostname, + self.driver.currentServer.port, + policy) else: hostname = self.driver.currentServer.hostname attempt = self.driver.currentServer.attempt diff --git a/src/ircutils.py b/src/ircutils.py index 80f139002..76608761d 100644 --- a/src/ircutils.py +++ b/src/ircutils.py @@ -1073,11 +1073,15 @@ def parseCapabilityKeyValue(s): return d -def parseStsPolicy(logger, policy, parseDuration): +def parseStsPolicy(logger, policy, secure_connection): parsed_policy = parseCapabilityKeyValue(policy) for key in ('port', 'duration'): - if key == 'duration' and not parseDuration: + if key == 'duration' and not secure_connection: + if key in parsed_policy: + del parsed_policy[key] + continue + elif key == 'port' and secure_connection: if key in parsed_policy: del parsed_policy[key] continue diff --git a/test/test_drivers.py b/test/test_drivers.py index b17cd9aa2..4942f16c8 100644 --- a/test/test_drivers.py +++ b/test/test_drivers.py @@ -39,7 +39,7 @@ class DriversTestCase(SupyTestCase): def testValidStsPolicy(self): irc = irclib.Irc('test') net = ircdb.networks.getNetwork('test') - net.addStsPolicy('example.com', 'duration=10,port=6697') + net.addStsPolicy('example.com', 6697, 'duration=10,port=12345') net.addDisconnection('example.com') with conf.supybot.networks.test.servers.context( @@ -64,7 +64,7 @@ class DriversTestCase(SupyTestCase): def testExpiredStsPolicy(self): irc = irclib.Irc('test') net = ircdb.networks.getNetwork('test') - net.addStsPolicy('example.com', 'duration=10,port=6697') + net.addStsPolicy('example.com', 6697, 'duration=10') net.addDisconnection('example.com') timeFastForward(16) @@ -81,7 +81,7 @@ class DriversTestCase(SupyTestCase): def testRescheduledStsPolicy(self): irc = irclib.Irc('test') net = ircdb.networks.getNetwork('test') - net.addStsPolicy('example.com', 'duration=10,port=6697') + net.addStsPolicy('example.com', 6697, 'duration=10') net.addDisconnection('example.com') with conf.supybot.networks.test.servers.context( diff --git a/test/test_ircdb.py b/test/test_ircdb.py index f5ea8a75b..7b1c06641 100644 --- a/test/test_ircdb.py +++ b/test/test_ircdb.py @@ -358,11 +358,11 @@ class IrcNetworkTestCase(IrcdbTestCase): def testStsPolicy(self): n = ircdb.IrcNetwork() - n.addStsPolicy('foo', 'bar') - n.addStsPolicy('baz', 'qux') + n.addStsPolicy('foo', 123, 'bar') + n.addStsPolicy('baz', 456, 'qux') self.assertEqual(n.stsPolicies, { - 'foo': 'bar', - 'baz': 'qux', + 'foo': (123, 'bar'), + 'baz': (456, 'qux'), }) def testAddDisconnection(self): @@ -374,8 +374,8 @@ class IrcNetworkTestCase(IrcdbTestCase): def testPreserve(self): n = ircdb.IrcNetwork() - n.addStsPolicy('foo', 'sts1') - n.addStsPolicy('bar', 'sts2') + n.addStsPolicy('foo', 123, 'sts1') + n.addStsPolicy('bar', 456,'sts2') n.addDisconnection('foo') n.addDisconnection('baz') disconnect_time_foo = n.lastDisconnectTimes['foo'] @@ -384,8 +384,8 @@ class IrcNetworkTestCase(IrcdbTestCase): n.preserve(fd, indent=' ') fd.seek(0) self.assertCountEqual(fd.read().split('\n'), [ - ' stsPolicy foo sts1', - ' stsPolicy bar sts2', + ' stsPolicy foo 123 sts1', + ' stsPolicy bar 456 sts2', ' lastDisconnectTime foo %d' % disconnect_time_foo, ' lastDisconnectTime baz %d' % disconnect_time_baz, '', @@ -467,8 +467,8 @@ class NetworksDictionaryTestCase(IrcdbTestCase): def testPreserveOne(self): n = ircdb.IrcNetwork() - n.addStsPolicy('foo', 'sts1') - n.addStsPolicy('bar', 'sts2') + n.addStsPolicy('foo', 123, 'sts1') + n.addStsPolicy('bar', 456, 'sts2') n.addDisconnection('foo') n.addDisconnection('baz') disconnect_time_foo = n.lastDisconnectTimes['foo'] @@ -486,8 +486,8 @@ class NetworksDictionaryTestCase(IrcdbTestCase): lines = fd.getvalue().split('\n') self.assertEqual(lines.pop(0), 'network foonet') self.assertCountEqual(lines, [ - ' stsPolicy foo sts1', - ' stsPolicy bar sts2', + ' stsPolicy foo 123 sts1', + ' stsPolicy bar 456 sts2', ' lastDisconnectTime foo %d' % disconnect_time_foo, ' lastDisconnectTime baz %d' % disconnect_time_baz, '', @@ -496,15 +496,15 @@ class NetworksDictionaryTestCase(IrcdbTestCase): def testPreserveThree(self): n = ircdb.IrcNetwork() - n.addStsPolicy('foo', 'sts1') + n.addStsPolicy('foo', 123, 'sts1') self.networks.setNetwork('foonet', n) n = ircdb.IrcNetwork() - n.addStsPolicy('bar', 'sts2') + n.addStsPolicy('bar', 456, 'sts2') self.networks.setNetwork('barnet', n) n = ircdb.IrcNetwork() - n.addStsPolicy('baz', 'sts3') + n.addStsPolicy('baz', 789, 'sts3') self.networks.setNetwork('baznet', n) fd = io.StringIO() @@ -518,13 +518,13 @@ class NetworksDictionaryTestCase(IrcdbTestCase): fd.seek(0) self.assertEqual(fd.getvalue(), 'network barnet\n' - ' stsPolicy bar sts2\n' + ' stsPolicy bar 456 sts2\n' '\n' 'network baznet\n' - ' stsPolicy baz sts3\n' + ' stsPolicy baz 789 sts3\n' '\n' 'network foonet\n' - ' stsPolicy foo sts1\n' + ' stsPolicy foo 123 sts1\n' '\n' ) diff --git a/test/test_irclib.py b/test/test_irclib.py index 15b2a31ae..cfb15f8fd 100644 --- a/test/test_irclib.py +++ b/test/test_irclib.py @@ -759,16 +759,27 @@ class StsTestCase(SupyTestCase): self.irc.driver.ssl = True self.irc.driver.currentServer = drivers.Server('irc.test', 6697, None, False) self.irc.feedMsg(ircmsgs.IrcMsg(command='CAP', - args=('*', 'LS', 'sts=duration=42,port=6697'))) + args=('*', 'LS', 'sts=duration=42,port=12345'))) self.assertEqual(ircdb.networks.getNetwork('test').stsPolicies, { - 'irc.test': 'duration=42,port=6697'}) + 'irc.test': (6697, 'duration=42,port=12345')}) + self.irc.driver.reconnect.assert_not_called() + + def testStsInSecureConnectionNoPort(self): + self.irc.driver.anyCertValidationEnabled.return_value = True + self.irc.driver.ssl = True + self.irc.driver.currentServer = drivers.Server('irc.test', 6697, None, False) + self.irc.feedMsg(ircmsgs.IrcMsg(command='CAP', + args=('*', 'LS', 'sts=duration=42'))) + + self.assertEqual(ircdb.networks.getNetwork('test').stsPolicies, { + 'irc.test': (6697, 'duration=42')}) self.irc.driver.reconnect.assert_not_called() def testStsInInsecureTlsConnection(self): self.irc.driver.anyCertValidationEnabled.return_value = False self.irc.driver.ssl = True - self.irc.driver.currentServer = drivers.Server('irc.test', 6697, None, False) + self.irc.driver.currentServer = drivers.Server('irc.test', 6667, None, False) self.irc.feedMsg(ircmsgs.IrcMsg(command='CAP', args=('*', 'LS', 'sts=duration=42,port=6697')))