When getting STS policy over insecure connection, reuse the exact same IP address

Otherwise, if some IP addresses don't work (eg. all odd ones), the bot will
consecutively fail because it can't connect, then connect + get STS + reconnect,
then fail again, then connect + get STS, etc.
This commit is contained in:
Valentin Lorentz 2021-01-11 23:22:21 +01:00
parent ba77de0946
commit 772ec8d6a9
6 changed files with 27 additions and 20 deletions

View File

@ -284,7 +284,7 @@ class Servers(registry.SpaceSeparatedListOfStrings):
hostname = hostname[1:-1] hostname = hostname[1:-1]
port = int(port) port = int(port)
return Server(hostname, port, force_tls_verification=False) return Server(hostname, port, None, force_tls_verification=False)
def __call__(self): def __call__(self):
L = registry.SpaceSeparatedListOfStrings.__call__(self) L = registry.SpaceSeparatedListOfStrings.__call__(self)

View File

@ -257,6 +257,10 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin):
return return
self.currentServer = server or self._getNextServer() self.currentServer = server or self._getNextServer()
network_config = getattr(conf.supybot.networks, self.irc.network) network_config = getattr(conf.supybot.networks, self.irc.network)
if self.currentServer.attempt is None:
self.currentServer = self.currentServer._replace(attempt=self._attempt)
else:
self._attempt = self.currentServer.attempt
socks_proxy = network_config.socksproxy() socks_proxy = network_config.socksproxy()
try: try:
if socks_proxy: if socks_proxy:

View File

@ -40,7 +40,7 @@ from .. import conf, ircdb, ircmsgs, ircutils, log as supylog, utils
from ..utils import minisix from ..utils import minisix
Server = namedtuple('Server', 'hostname port force_tls_verification') Server = namedtuple('Server', 'hostname port attempt force_tls_verification')
# force_tls_verification=True implies two things: # force_tls_verification=True implies two things:
# 1. force TLS to be enabled for this server # 1. force TLS to be enabled for this server
# 2. ensure there is some kind of verification. If the user did not enable # 2. ensure there is some kind of verification. If the user did not enable
@ -118,7 +118,7 @@ class ServersMixin(object):
# Change the port, and force TLS verification, as required by the STS # Change the port, and force TLS verification, as required by the STS
# specification. # specification.
return Server(server.hostname, policy['port'], return Server(server.hostname, policy['port'], server.attempt,
force_tls_verification=True) force_tls_verification=True)
def die(self): def die(self):

View File

@ -1823,14 +1823,17 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
self.driver.currentServer.hostname, policy) self.driver.currentServer.hostname, policy)
else: else:
hostname = self.driver.currentServer.hostname hostname = self.driver.currentServer.hostname
attempt = self.driver.currentServer.attempt
log.info('Got STS policy over insecure connection; ' log.info('Got STS policy over insecure connection; '
'reconnecting to secure port. %r', 'reconnecting to secure port. %r',
self.driver.currentServer) self.driver.currentServer)
# Reconnect to the server, but with TLS *and* certificate # Reconnect to the server, but with TLS *and* certificate
# validation this time. # validation this time.
self.state.fsm.on_shutdown(self, msg) self.state.fsm.on_shutdown(self, msg)
self.driver.reconnect( self.driver.reconnect(
server=Server(hostname, parsed_policy['port'], True), server=Server(hostname, parsed_policy['port'], attempt, True),
wait=True) wait=True)
def _addCapabilities(self, capstring, msg): def _addCapabilities(self, capstring, msg):

View File

@ -49,17 +49,17 @@ class DriversTestCase(SupyTestCase):
self.assertEqual( self.assertEqual(
driver._getNextServer(), driver._getNextServer(),
drivers.Server('example.com', 6697, True)) drivers.Server('example.com', 6697, None, True))
driver.die() driver.die()
self.assertEqual( self.assertEqual(
driver._getNextServer(), driver._getNextServer(),
drivers.Server('example.org', 6667, False)) drivers.Server('example.org', 6667, None, False))
driver.die() driver.die()
self.assertEqual( self.assertEqual(
driver._getNextServer(), driver._getNextServer(),
drivers.Server('example.com', 6697, True)) drivers.Server('example.com', 6697, None, True))
def testExpiredStsPolicy(self): def testExpiredStsPolicy(self):
irc = irclib.Irc('test') irc = irclib.Irc('test')
@ -76,7 +76,7 @@ class DriversTestCase(SupyTestCase):
self.assertEqual( self.assertEqual(
driver._getNextServer(), driver._getNextServer(),
drivers.Server('example.com', 6667, False)) drivers.Server('example.com', 6667, None, False))
def testRescheduledStsPolicy(self): def testRescheduledStsPolicy(self):
irc = irclib.Irc('test') irc = irclib.Irc('test')
@ -93,16 +93,16 @@ class DriversTestCase(SupyTestCase):
self.assertEqual( self.assertEqual(
driver._getNextServer(), driver._getNextServer(),
drivers.Server('example.com', 6697, True)) drivers.Server('example.com', 6697, None, True))
driver.die() driver.die()
self.assertEqual( self.assertEqual(
driver._getNextServer(), driver._getNextServer(),
drivers.Server('example.org', 6667, False)) drivers.Server('example.org', 6667, None, False))
driver.die() driver.die()
timeFastForward(8) timeFastForward(8)
self.assertEqual( self.assertEqual(
driver._getNextServer(), driver._getNextServer(),
drivers.Server('example.com', 6697, True)) drivers.Server('example.com', 6697, None, True))

View File

@ -696,7 +696,7 @@ class StsTestCase(SupyTestCase):
def testStsInSecureConnection(self): def testStsInSecureConnection(self):
self.irc.driver.anyCertValidationEnabled.return_value = True self.irc.driver.anyCertValidationEnabled.return_value = True
self.irc.driver.ssl = True self.irc.driver.ssl = True
self.irc.driver.currentServer = drivers.Server('irc.test', 6697, False) self.irc.driver.currentServer = drivers.Server('irc.test', 6697, None, False)
self.irc.feedMsg(ircmsgs.IrcMsg(command='CAP', self.irc.feedMsg(ircmsgs.IrcMsg(command='CAP',
args=('*', 'LS', 'sts=duration=42,port=6697'))) args=('*', 'LS', 'sts=duration=42,port=6697')))
@ -707,25 +707,25 @@ class StsTestCase(SupyTestCase):
def testStsInInsecureTlsConnection(self): def testStsInInsecureTlsConnection(self):
self.irc.driver.anyCertValidationEnabled.return_value = False self.irc.driver.anyCertValidationEnabled.return_value = False
self.irc.driver.ssl = True self.irc.driver.ssl = True
self.irc.driver.currentServer = drivers.Server('irc.test', 6697, False) self.irc.driver.currentServer = drivers.Server('irc.test', 6697, None, False)
self.irc.feedMsg(ircmsgs.IrcMsg(command='CAP', self.irc.feedMsg(ircmsgs.IrcMsg(command='CAP',
args=('*', 'LS', 'sts=duration=42,port=6697'))) args=('*', 'LS', 'sts=duration=42,port=6697')))
self.assertEqual(ircdb.networks.getNetwork('test').stsPolicies, {}) self.assertEqual(ircdb.networks.getNetwork('test').stsPolicies, {})
self.irc.driver.reconnect.assert_called_once_with( self.irc.driver.reconnect.assert_called_once_with(
server=drivers.Server('irc.test', 6697, True), server=drivers.Server('irc.test', 6697, None, True),
wait=True) wait=True)
def testStsInCleartextConnection(self): def testStsInCleartextConnection(self):
self.irc.driver.anyCertValidationEnabled.return_value = False self.irc.driver.anyCertValidationEnabled.return_value = False
self.irc.driver.ssl = True self.irc.driver.ssl = True
self.irc.driver.currentServer = drivers.Server('irc.test', 6667, False) self.irc.driver.currentServer = drivers.Server('irc.test', 6667, None, False)
self.irc.feedMsg(ircmsgs.IrcMsg(command='CAP', self.irc.feedMsg(ircmsgs.IrcMsg(command='CAP',
args=('*', 'LS', 'sts=duration=42,port=6697'))) args=('*', 'LS', 'sts=duration=42,port=6697')))
self.assertEqual(ircdb.networks.getNetwork('test').stsPolicies, {}) self.assertEqual(ircdb.networks.getNetwork('test').stsPolicies, {})
self.irc.driver.reconnect.assert_called_once_with( self.irc.driver.reconnect.assert_called_once_with(
server=drivers.Server('irc.test', 6697, True), server=drivers.Server('irc.test', 6697, None, True),
wait=True) wait=True)
def testStsInCleartextConnectionInvalidDuration(self): def testStsInCleartextConnectionInvalidDuration(self):
@ -733,13 +733,13 @@ class StsTestCase(SupyTestCase):
# connected clients MUST ignore it." # connected clients MUST ignore it."
self.irc.driver.anyCertValidationEnabled.return_value = False self.irc.driver.anyCertValidationEnabled.return_value = False
self.irc.driver.ssl = True self.irc.driver.ssl = True
self.irc.driver.currentServer = drivers.Server('irc.test', 6667, False) self.irc.driver.currentServer = drivers.Server('irc.test', 6667, None, False)
self.irc.feedMsg(ircmsgs.IrcMsg(command='CAP', self.irc.feedMsg(ircmsgs.IrcMsg(command='CAP',
args=('*', 'LS', 'sts=duration=foo,port=6697'))) args=('*', 'LS', 'sts=duration=foo,port=6697')))
self.assertEqual(ircdb.networks.getNetwork('test').stsPolicies, {}) self.assertEqual(ircdb.networks.getNetwork('test').stsPolicies, {})
self.irc.driver.reconnect.assert_called_once_with( self.irc.driver.reconnect.assert_called_once_with(
server=drivers.Server('irc.test', 6697, True), server=drivers.Server('irc.test', 6697, None, True),
wait=True) wait=True)
def testStsInCleartextConnectionNoDuration(self): def testStsInCleartextConnectionNoDuration(self):
@ -747,13 +747,13 @@ class StsTestCase(SupyTestCase):
# connected clients MUST ignore it." # connected clients MUST ignore it."
self.irc.driver.anyCertValidationEnabled.return_value = False self.irc.driver.anyCertValidationEnabled.return_value = False
self.irc.driver.ssl = True self.irc.driver.ssl = True
self.irc.driver.currentServer = drivers.Server('irc.test', 6667, False) self.irc.driver.currentServer = drivers.Server('irc.test', 6667, None, False)
self.irc.feedMsg(ircmsgs.IrcMsg(command='CAP', self.irc.feedMsg(ircmsgs.IrcMsg(command='CAP',
args=('*', 'LS', 'sts=port=6697'))) args=('*', 'LS', 'sts=port=6697')))
self.assertEqual(ircdb.networks.getNetwork('test').stsPolicies, {}) self.assertEqual(ircdb.networks.getNetwork('test').stsPolicies, {})
self.irc.driver.reconnect.assert_called_once_with( self.irc.driver.reconnect.assert_called_once_with(
server=drivers.Server('irc.test', 6697, True), server=drivers.Server('irc.test', 6697, None, True),
wait=True) wait=True)
class IrcTestCase(SupyTestCase): class IrcTestCase(SupyTestCase):