diff --git a/src/conf.py b/src/conf.py index 3d955917e..0e416b218 100644 --- a/src/conf.py +++ b/src/conf.py @@ -284,7 +284,7 @@ class Servers(registry.SpaceSeparatedListOfStrings): hostname = hostname[1:-1] port = int(port) - return Server(hostname, port, force_tls_verification=False) + return Server(hostname, port, None, force_tls_verification=False) def __call__(self): L = registry.SpaceSeparatedListOfStrings.__call__(self) diff --git a/src/drivers/Socket.py b/src/drivers/Socket.py index bd603642e..0fb332384 100644 --- a/src/drivers/Socket.py +++ b/src/drivers/Socket.py @@ -257,6 +257,10 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin): return self.currentServer = server or self._getNextServer() network_config = getattr(conf.supybot.networks, self.irc.network) + if self.currentServer.attempt is None: + self.currentServer = self.currentServer._replace(attempt=self._attempt) + else: + self._attempt = self.currentServer.attempt socks_proxy = network_config.socksproxy() try: if socks_proxy: diff --git a/src/drivers/__init__.py b/src/drivers/__init__.py index fcc14cc06..c57a0bcbe 100644 --- a/src/drivers/__init__.py +++ b/src/drivers/__init__.py @@ -40,7 +40,7 @@ from .. import conf, ircdb, ircmsgs, ircutils, log as supylog, utils from ..utils import minisix -Server = namedtuple('Server', 'hostname port force_tls_verification') +Server = namedtuple('Server', 'hostname port attempt force_tls_verification') # force_tls_verification=True implies two things: # 1. force TLS to be enabled for this server # 2. ensure there is some kind of verification. If the user did not enable @@ -118,7 +118,7 @@ class ServersMixin(object): # Change the port, and force TLS verification, as required by the STS # specification. - return Server(server.hostname, policy['port'], + return Server(server.hostname, policy['port'], server.attempt, force_tls_verification=True) def die(self): diff --git a/src/irclib.py b/src/irclib.py index 623dc1168..5b78b8d22 100644 --- a/src/irclib.py +++ b/src/irclib.py @@ -1823,14 +1823,17 @@ class Irc(IrcCommandDispatcher, log.Firewalled): self.driver.currentServer.hostname, policy) else: hostname = self.driver.currentServer.hostname + attempt = self.driver.currentServer.attempt + log.info('Got STS policy over insecure connection; ' 'reconnecting to secure port. %r', self.driver.currentServer) # Reconnect to the server, but with TLS *and* certificate # validation this time. self.state.fsm.on_shutdown(self, msg) + self.driver.reconnect( - server=Server(hostname, parsed_policy['port'], True), + server=Server(hostname, parsed_policy['port'], attempt, True), wait=True) def _addCapabilities(self, capstring, msg): diff --git a/test/test_drivers.py b/test/test_drivers.py index cc7b4fafd..99487797e 100644 --- a/test/test_drivers.py +++ b/test/test_drivers.py @@ -49,17 +49,17 @@ class DriversTestCase(SupyTestCase): self.assertEqual( driver._getNextServer(), - drivers.Server('example.com', 6697, True)) + drivers.Server('example.com', 6697, None, True)) driver.die() self.assertEqual( driver._getNextServer(), - drivers.Server('example.org', 6667, False)) + drivers.Server('example.org', 6667, None, False)) driver.die() self.assertEqual( driver._getNextServer(), - drivers.Server('example.com', 6697, True)) + drivers.Server('example.com', 6697, None, True)) def testExpiredStsPolicy(self): irc = irclib.Irc('test') @@ -76,7 +76,7 @@ class DriversTestCase(SupyTestCase): self.assertEqual( driver._getNextServer(), - drivers.Server('example.com', 6667, False)) + drivers.Server('example.com', 6667, None, False)) def testRescheduledStsPolicy(self): irc = irclib.Irc('test') @@ -93,16 +93,16 @@ class DriversTestCase(SupyTestCase): self.assertEqual( driver._getNextServer(), - drivers.Server('example.com', 6697, True)) + drivers.Server('example.com', 6697, None, True)) driver.die() self.assertEqual( driver._getNextServer(), - drivers.Server('example.org', 6667, False)) + drivers.Server('example.org', 6667, None, False)) driver.die() timeFastForward(8) self.assertEqual( driver._getNextServer(), - drivers.Server('example.com', 6697, True)) + drivers.Server('example.com', 6697, None, True)) diff --git a/test/test_irclib.py b/test/test_irclib.py index 2ddabe276..1d82e066a 100644 --- a/test/test_irclib.py +++ b/test/test_irclib.py @@ -696,7 +696,7 @@ class StsTestCase(SupyTestCase): def testStsInSecureConnection(self): self.irc.driver.anyCertValidationEnabled.return_value = True self.irc.driver.ssl = True - self.irc.driver.currentServer = drivers.Server('irc.test', 6697, False) + self.irc.driver.currentServer = drivers.Server('irc.test', 6697, None, False) self.irc.feedMsg(ircmsgs.IrcMsg(command='CAP', args=('*', 'LS', 'sts=duration=42,port=6697'))) @@ -707,25 +707,25 @@ class StsTestCase(SupyTestCase): def testStsInInsecureTlsConnection(self): self.irc.driver.anyCertValidationEnabled.return_value = False self.irc.driver.ssl = True - self.irc.driver.currentServer = drivers.Server('irc.test', 6697, False) + self.irc.driver.currentServer = drivers.Server('irc.test', 6697, None, False) self.irc.feedMsg(ircmsgs.IrcMsg(command='CAP', args=('*', 'LS', 'sts=duration=42,port=6697'))) self.assertEqual(ircdb.networks.getNetwork('test').stsPolicies, {}) self.irc.driver.reconnect.assert_called_once_with( - server=drivers.Server('irc.test', 6697, True), + server=drivers.Server('irc.test', 6697, None, True), wait=True) def testStsInCleartextConnection(self): self.irc.driver.anyCertValidationEnabled.return_value = False self.irc.driver.ssl = True - self.irc.driver.currentServer = drivers.Server('irc.test', 6667, False) + self.irc.driver.currentServer = drivers.Server('irc.test', 6667, None, False) self.irc.feedMsg(ircmsgs.IrcMsg(command='CAP', args=('*', 'LS', 'sts=duration=42,port=6697'))) self.assertEqual(ircdb.networks.getNetwork('test').stsPolicies, {}) self.irc.driver.reconnect.assert_called_once_with( - server=drivers.Server('irc.test', 6697, True), + server=drivers.Server('irc.test', 6697, None, True), wait=True) def testStsInCleartextConnectionInvalidDuration(self): @@ -733,13 +733,13 @@ class StsTestCase(SupyTestCase): # connected clients MUST ignore it." self.irc.driver.anyCertValidationEnabled.return_value = False self.irc.driver.ssl = True - self.irc.driver.currentServer = drivers.Server('irc.test', 6667, False) + self.irc.driver.currentServer = drivers.Server('irc.test', 6667, None, False) self.irc.feedMsg(ircmsgs.IrcMsg(command='CAP', args=('*', 'LS', 'sts=duration=foo,port=6697'))) self.assertEqual(ircdb.networks.getNetwork('test').stsPolicies, {}) self.irc.driver.reconnect.assert_called_once_with( - server=drivers.Server('irc.test', 6697, True), + server=drivers.Server('irc.test', 6697, None, True), wait=True) def testStsInCleartextConnectionNoDuration(self): @@ -747,13 +747,13 @@ class StsTestCase(SupyTestCase): # connected clients MUST ignore it." self.irc.driver.anyCertValidationEnabled.return_value = False self.irc.driver.ssl = True - self.irc.driver.currentServer = drivers.Server('irc.test', 6667, False) + self.irc.driver.currentServer = drivers.Server('irc.test', 6667, None, False) self.irc.feedMsg(ircmsgs.IrcMsg(command='CAP', args=('*', 'LS', 'sts=port=6697'))) self.assertEqual(ircdb.networks.getNetwork('test').stsPolicies, {}) self.irc.driver.reconnect.assert_called_once_with( - server=drivers.Server('irc.test', 6697, True), + server=drivers.Server('irc.test', 6697, None, True), wait=True) class IrcTestCase(SupyTestCase):