When getting STS policy over insecure connection, reuse the exact same IP address

Otherwise, if some IP addresses don't work (eg. all odd ones), the bot will
consecutively fail because it can't connect, then connect + get STS + reconnect,
then fail again, then connect + get STS, etc.
This commit is contained in:
Valentin Lorentz 2021-01-11 23:22:21 +01:00
parent ba77de0946
commit 772ec8d6a9
6 changed files with 27 additions and 20 deletions

View File

@ -284,7 +284,7 @@ class Servers(registry.SpaceSeparatedListOfStrings):
hostname = hostname[1:-1]
port = int(port)
return Server(hostname, port, force_tls_verification=False)
return Server(hostname, port, None, force_tls_verification=False)
def __call__(self):
L = registry.SpaceSeparatedListOfStrings.__call__(self)

View File

@ -257,6 +257,10 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin):
return
self.currentServer = server or self._getNextServer()
network_config = getattr(conf.supybot.networks, self.irc.network)
if self.currentServer.attempt is None:
self.currentServer = self.currentServer._replace(attempt=self._attempt)
else:
self._attempt = self.currentServer.attempt
socks_proxy = network_config.socksproxy()
try:
if socks_proxy:

View File

@ -40,7 +40,7 @@ from .. import conf, ircdb, ircmsgs, ircutils, log as supylog, utils
from ..utils import minisix
Server = namedtuple('Server', 'hostname port force_tls_verification')
Server = namedtuple('Server', 'hostname port attempt force_tls_verification')
# force_tls_verification=True implies two things:
# 1. force TLS to be enabled for this server
# 2. ensure there is some kind of verification. If the user did not enable
@ -118,7 +118,7 @@ class ServersMixin(object):
# Change the port, and force TLS verification, as required by the STS
# specification.
return Server(server.hostname, policy['port'],
return Server(server.hostname, policy['port'], server.attempt,
force_tls_verification=True)
def die(self):

View File

@ -1823,14 +1823,17 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
self.driver.currentServer.hostname, policy)
else:
hostname = self.driver.currentServer.hostname
attempt = self.driver.currentServer.attempt
log.info('Got STS policy over insecure connection; '
'reconnecting to secure port. %r',
self.driver.currentServer)
# Reconnect to the server, but with TLS *and* certificate
# validation this time.
self.state.fsm.on_shutdown(self, msg)
self.driver.reconnect(
server=Server(hostname, parsed_policy['port'], True),
server=Server(hostname, parsed_policy['port'], attempt, True),
wait=True)
def _addCapabilities(self, capstring, msg):

View File

@ -49,17 +49,17 @@ class DriversTestCase(SupyTestCase):
self.assertEqual(
driver._getNextServer(),
drivers.Server('example.com', 6697, True))
drivers.Server('example.com', 6697, None, True))
driver.die()
self.assertEqual(
driver._getNextServer(),
drivers.Server('example.org', 6667, False))
drivers.Server('example.org', 6667, None, False))
driver.die()
self.assertEqual(
driver._getNextServer(),
drivers.Server('example.com', 6697, True))
drivers.Server('example.com', 6697, None, True))
def testExpiredStsPolicy(self):
irc = irclib.Irc('test')
@ -76,7 +76,7 @@ class DriversTestCase(SupyTestCase):
self.assertEqual(
driver._getNextServer(),
drivers.Server('example.com', 6667, False))
drivers.Server('example.com', 6667, None, False))
def testRescheduledStsPolicy(self):
irc = irclib.Irc('test')
@ -93,16 +93,16 @@ class DriversTestCase(SupyTestCase):
self.assertEqual(
driver._getNextServer(),
drivers.Server('example.com', 6697, True))
drivers.Server('example.com', 6697, None, True))
driver.die()
self.assertEqual(
driver._getNextServer(),
drivers.Server('example.org', 6667, False))
drivers.Server('example.org', 6667, None, False))
driver.die()
timeFastForward(8)
self.assertEqual(
driver._getNextServer(),
drivers.Server('example.com', 6697, True))
drivers.Server('example.com', 6697, None, True))

View File

@ -696,7 +696,7 @@ class StsTestCase(SupyTestCase):
def testStsInSecureConnection(self):
self.irc.driver.anyCertValidationEnabled.return_value = True
self.irc.driver.ssl = True
self.irc.driver.currentServer = drivers.Server('irc.test', 6697, False)
self.irc.driver.currentServer = drivers.Server('irc.test', 6697, None, False)
self.irc.feedMsg(ircmsgs.IrcMsg(command='CAP',
args=('*', 'LS', 'sts=duration=42,port=6697')))
@ -707,25 +707,25 @@ class StsTestCase(SupyTestCase):
def testStsInInsecureTlsConnection(self):
self.irc.driver.anyCertValidationEnabled.return_value = False
self.irc.driver.ssl = True
self.irc.driver.currentServer = drivers.Server('irc.test', 6697, False)
self.irc.driver.currentServer = drivers.Server('irc.test', 6697, None, False)
self.irc.feedMsg(ircmsgs.IrcMsg(command='CAP',
args=('*', 'LS', 'sts=duration=42,port=6697')))
self.assertEqual(ircdb.networks.getNetwork('test').stsPolicies, {})
self.irc.driver.reconnect.assert_called_once_with(
server=drivers.Server('irc.test', 6697, True),
server=drivers.Server('irc.test', 6697, None, True),
wait=True)
def testStsInCleartextConnection(self):
self.irc.driver.anyCertValidationEnabled.return_value = False
self.irc.driver.ssl = True
self.irc.driver.currentServer = drivers.Server('irc.test', 6667, False)
self.irc.driver.currentServer = drivers.Server('irc.test', 6667, None, False)
self.irc.feedMsg(ircmsgs.IrcMsg(command='CAP',
args=('*', 'LS', 'sts=duration=42,port=6697')))
self.assertEqual(ircdb.networks.getNetwork('test').stsPolicies, {})
self.irc.driver.reconnect.assert_called_once_with(
server=drivers.Server('irc.test', 6697, True),
server=drivers.Server('irc.test', 6697, None, True),
wait=True)
def testStsInCleartextConnectionInvalidDuration(self):
@ -733,13 +733,13 @@ class StsTestCase(SupyTestCase):
# connected clients MUST ignore it."
self.irc.driver.anyCertValidationEnabled.return_value = False
self.irc.driver.ssl = True
self.irc.driver.currentServer = drivers.Server('irc.test', 6667, False)
self.irc.driver.currentServer = drivers.Server('irc.test', 6667, None, False)
self.irc.feedMsg(ircmsgs.IrcMsg(command='CAP',
args=('*', 'LS', 'sts=duration=foo,port=6697')))
self.assertEqual(ircdb.networks.getNetwork('test').stsPolicies, {})
self.irc.driver.reconnect.assert_called_once_with(
server=drivers.Server('irc.test', 6697, True),
server=drivers.Server('irc.test', 6697, None, True),
wait=True)
def testStsInCleartextConnectionNoDuration(self):
@ -747,13 +747,13 @@ class StsTestCase(SupyTestCase):
# connected clients MUST ignore it."
self.irc.driver.anyCertValidationEnabled.return_value = False
self.irc.driver.ssl = True
self.irc.driver.currentServer = drivers.Server('irc.test', 6667, False)
self.irc.driver.currentServer = drivers.Server('irc.test', 6667, None, False)
self.irc.feedMsg(ircmsgs.IrcMsg(command='CAP',
args=('*', 'LS', 'sts=port=6697')))
self.assertEqual(ircdb.networks.getNetwork('test').stsPolicies, {})
self.irc.driver.reconnect.assert_called_once_with(
server=drivers.Server('irc.test', 6697, True),
server=drivers.Server('irc.test', 6697, None, True),
wait=True)
class IrcTestCase(SupyTestCase):