From 772ec8d6a9d7b0d91c3c14664e745528cef1f112 Mon Sep 17 00:00:00 2001 From: Valentin Lorentz Date: Mon, 11 Jan 2021 23:22:21 +0100 Subject: [PATCH] When getting STS policy over insecure connection, reuse the exact same IP address Otherwise, if some IP addresses don't work (eg. all odd ones), the bot will consecutively fail because it can't connect, then connect + get STS + reconnect, then fail again, then connect + get STS, etc. --- src/conf.py | 2 +- src/drivers/Socket.py | 4 ++++ src/drivers/__init__.py | 4 ++-- src/irclib.py | 5 ++++- test/test_drivers.py | 14 +++++++------- test/test_irclib.py | 18 +++++++++--------- 6 files changed, 27 insertions(+), 20 deletions(-) diff --git a/src/conf.py b/src/conf.py index 3d955917e..0e416b218 100644 --- a/src/conf.py +++ b/src/conf.py @@ -284,7 +284,7 @@ class Servers(registry.SpaceSeparatedListOfStrings): hostname = hostname[1:-1] port = int(port) - return Server(hostname, port, force_tls_verification=False) + return Server(hostname, port, None, force_tls_verification=False) def __call__(self): L = registry.SpaceSeparatedListOfStrings.__call__(self) diff --git a/src/drivers/Socket.py b/src/drivers/Socket.py index bd603642e..0fb332384 100644 --- a/src/drivers/Socket.py +++ b/src/drivers/Socket.py @@ -257,6 +257,10 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin): return self.currentServer = server or self._getNextServer() network_config = getattr(conf.supybot.networks, self.irc.network) + if self.currentServer.attempt is None: + self.currentServer = self.currentServer._replace(attempt=self._attempt) + else: + self._attempt = self.currentServer.attempt socks_proxy = network_config.socksproxy() try: if socks_proxy: diff --git a/src/drivers/__init__.py b/src/drivers/__init__.py index fcc14cc06..c57a0bcbe 100644 --- a/src/drivers/__init__.py +++ b/src/drivers/__init__.py @@ -40,7 +40,7 @@ from .. import conf, ircdb, ircmsgs, ircutils, log as supylog, utils from ..utils import minisix -Server = namedtuple('Server', 'hostname port force_tls_verification') +Server = namedtuple('Server', 'hostname port attempt force_tls_verification') # force_tls_verification=True implies two things: # 1. force TLS to be enabled for this server # 2. ensure there is some kind of verification. If the user did not enable @@ -118,7 +118,7 @@ class ServersMixin(object): # Change the port, and force TLS verification, as required by the STS # specification. - return Server(server.hostname, policy['port'], + return Server(server.hostname, policy['port'], server.attempt, force_tls_verification=True) def die(self): diff --git a/src/irclib.py b/src/irclib.py index 623dc1168..5b78b8d22 100644 --- a/src/irclib.py +++ b/src/irclib.py @@ -1823,14 +1823,17 @@ class Irc(IrcCommandDispatcher, log.Firewalled): self.driver.currentServer.hostname, policy) else: hostname = self.driver.currentServer.hostname + attempt = self.driver.currentServer.attempt + log.info('Got STS policy over insecure connection; ' 'reconnecting to secure port. %r', self.driver.currentServer) # Reconnect to the server, but with TLS *and* certificate # validation this time. self.state.fsm.on_shutdown(self, msg) + self.driver.reconnect( - server=Server(hostname, parsed_policy['port'], True), + server=Server(hostname, parsed_policy['port'], attempt, True), wait=True) def _addCapabilities(self, capstring, msg): diff --git a/test/test_drivers.py b/test/test_drivers.py index cc7b4fafd..99487797e 100644 --- a/test/test_drivers.py +++ b/test/test_drivers.py @@ -49,17 +49,17 @@ class DriversTestCase(SupyTestCase): self.assertEqual( driver._getNextServer(), - drivers.Server('example.com', 6697, True)) + drivers.Server('example.com', 6697, None, True)) driver.die() self.assertEqual( driver._getNextServer(), - drivers.Server('example.org', 6667, False)) + drivers.Server('example.org', 6667, None, False)) driver.die() self.assertEqual( driver._getNextServer(), - drivers.Server('example.com', 6697, True)) + drivers.Server('example.com', 6697, None, True)) def testExpiredStsPolicy(self): irc = irclib.Irc('test') @@ -76,7 +76,7 @@ class DriversTestCase(SupyTestCase): self.assertEqual( driver._getNextServer(), - drivers.Server('example.com', 6667, False)) + drivers.Server('example.com', 6667, None, False)) def testRescheduledStsPolicy(self): irc = irclib.Irc('test') @@ -93,16 +93,16 @@ class DriversTestCase(SupyTestCase): self.assertEqual( driver._getNextServer(), - drivers.Server('example.com', 6697, True)) + drivers.Server('example.com', 6697, None, True)) driver.die() self.assertEqual( driver._getNextServer(), - drivers.Server('example.org', 6667, False)) + drivers.Server('example.org', 6667, None, False)) driver.die() timeFastForward(8) self.assertEqual( driver._getNextServer(), - drivers.Server('example.com', 6697, True)) + drivers.Server('example.com', 6697, None, True)) diff --git a/test/test_irclib.py b/test/test_irclib.py index 2ddabe276..1d82e066a 100644 --- a/test/test_irclib.py +++ b/test/test_irclib.py @@ -696,7 +696,7 @@ class StsTestCase(SupyTestCase): def testStsInSecureConnection(self): self.irc.driver.anyCertValidationEnabled.return_value = True self.irc.driver.ssl = True - self.irc.driver.currentServer = drivers.Server('irc.test', 6697, False) + self.irc.driver.currentServer = drivers.Server('irc.test', 6697, None, False) self.irc.feedMsg(ircmsgs.IrcMsg(command='CAP', args=('*', 'LS', 'sts=duration=42,port=6697'))) @@ -707,25 +707,25 @@ class StsTestCase(SupyTestCase): def testStsInInsecureTlsConnection(self): self.irc.driver.anyCertValidationEnabled.return_value = False self.irc.driver.ssl = True - self.irc.driver.currentServer = drivers.Server('irc.test', 6697, False) + self.irc.driver.currentServer = drivers.Server('irc.test', 6697, None, False) self.irc.feedMsg(ircmsgs.IrcMsg(command='CAP', args=('*', 'LS', 'sts=duration=42,port=6697'))) self.assertEqual(ircdb.networks.getNetwork('test').stsPolicies, {}) self.irc.driver.reconnect.assert_called_once_with( - server=drivers.Server('irc.test', 6697, True), + server=drivers.Server('irc.test', 6697, None, True), wait=True) def testStsInCleartextConnection(self): self.irc.driver.anyCertValidationEnabled.return_value = False self.irc.driver.ssl = True - self.irc.driver.currentServer = drivers.Server('irc.test', 6667, False) + self.irc.driver.currentServer = drivers.Server('irc.test', 6667, None, False) self.irc.feedMsg(ircmsgs.IrcMsg(command='CAP', args=('*', 'LS', 'sts=duration=42,port=6697'))) self.assertEqual(ircdb.networks.getNetwork('test').stsPolicies, {}) self.irc.driver.reconnect.assert_called_once_with( - server=drivers.Server('irc.test', 6697, True), + server=drivers.Server('irc.test', 6697, None, True), wait=True) def testStsInCleartextConnectionInvalidDuration(self): @@ -733,13 +733,13 @@ class StsTestCase(SupyTestCase): # connected clients MUST ignore it." self.irc.driver.anyCertValidationEnabled.return_value = False self.irc.driver.ssl = True - self.irc.driver.currentServer = drivers.Server('irc.test', 6667, False) + self.irc.driver.currentServer = drivers.Server('irc.test', 6667, None, False) self.irc.feedMsg(ircmsgs.IrcMsg(command='CAP', args=('*', 'LS', 'sts=duration=foo,port=6697'))) self.assertEqual(ircdb.networks.getNetwork('test').stsPolicies, {}) self.irc.driver.reconnect.assert_called_once_with( - server=drivers.Server('irc.test', 6697, True), + server=drivers.Server('irc.test', 6697, None, True), wait=True) def testStsInCleartextConnectionNoDuration(self): @@ -747,13 +747,13 @@ class StsTestCase(SupyTestCase): # connected clients MUST ignore it." self.irc.driver.anyCertValidationEnabled.return_value = False self.irc.driver.ssl = True - self.irc.driver.currentServer = drivers.Server('irc.test', 6667, False) + self.irc.driver.currentServer = drivers.Server('irc.test', 6667, None, False) self.irc.feedMsg(ircmsgs.IrcMsg(command='CAP', args=('*', 'LS', 'sts=port=6697'))) self.assertEqual(ircdb.networks.getNetwork('test').stsPolicies, {}) self.irc.driver.reconnect.assert_called_once_with( - server=drivers.Server('irc.test', 6697, True), + server=drivers.Server('irc.test', 6697, None, True), wait=True) class IrcTestCase(SupyTestCase):