mirror of
https://github.com/Mikaela/Limnoria.git
synced 2024-12-25 12:12:54 +01:00
When getting STS policy over insecure connection, reuse the exact same IP address
Otherwise, if some IP addresses don't work (eg. all odd ones), the bot will consecutively fail because it can't connect, then connect + get STS + reconnect, then fail again, then connect + get STS, etc.
This commit is contained in:
parent
ba77de0946
commit
772ec8d6a9
@ -284,7 +284,7 @@ class Servers(registry.SpaceSeparatedListOfStrings):
|
|||||||
hostname = hostname[1:-1]
|
hostname = hostname[1:-1]
|
||||||
|
|
||||||
port = int(port)
|
port = int(port)
|
||||||
return Server(hostname, port, force_tls_verification=False)
|
return Server(hostname, port, None, force_tls_verification=False)
|
||||||
|
|
||||||
def __call__(self):
|
def __call__(self):
|
||||||
L = registry.SpaceSeparatedListOfStrings.__call__(self)
|
L = registry.SpaceSeparatedListOfStrings.__call__(self)
|
||||||
|
@ -257,6 +257,10 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin):
|
|||||||
return
|
return
|
||||||
self.currentServer = server or self._getNextServer()
|
self.currentServer = server or self._getNextServer()
|
||||||
network_config = getattr(conf.supybot.networks, self.irc.network)
|
network_config = getattr(conf.supybot.networks, self.irc.network)
|
||||||
|
if self.currentServer.attempt is None:
|
||||||
|
self.currentServer = self.currentServer._replace(attempt=self._attempt)
|
||||||
|
else:
|
||||||
|
self._attempt = self.currentServer.attempt
|
||||||
socks_proxy = network_config.socksproxy()
|
socks_proxy = network_config.socksproxy()
|
||||||
try:
|
try:
|
||||||
if socks_proxy:
|
if socks_proxy:
|
||||||
|
@ -40,7 +40,7 @@ from .. import conf, ircdb, ircmsgs, ircutils, log as supylog, utils
|
|||||||
from ..utils import minisix
|
from ..utils import minisix
|
||||||
|
|
||||||
|
|
||||||
Server = namedtuple('Server', 'hostname port force_tls_verification')
|
Server = namedtuple('Server', 'hostname port attempt force_tls_verification')
|
||||||
# force_tls_verification=True implies two things:
|
# force_tls_verification=True implies two things:
|
||||||
# 1. force TLS to be enabled for this server
|
# 1. force TLS to be enabled for this server
|
||||||
# 2. ensure there is some kind of verification. If the user did not enable
|
# 2. ensure there is some kind of verification. If the user did not enable
|
||||||
@ -118,7 +118,7 @@ class ServersMixin(object):
|
|||||||
|
|
||||||
# Change the port, and force TLS verification, as required by the STS
|
# Change the port, and force TLS verification, as required by the STS
|
||||||
# specification.
|
# specification.
|
||||||
return Server(server.hostname, policy['port'],
|
return Server(server.hostname, policy['port'], server.attempt,
|
||||||
force_tls_verification=True)
|
force_tls_verification=True)
|
||||||
|
|
||||||
def die(self):
|
def die(self):
|
||||||
|
@ -1823,14 +1823,17 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
|
|||||||
self.driver.currentServer.hostname, policy)
|
self.driver.currentServer.hostname, policy)
|
||||||
else:
|
else:
|
||||||
hostname = self.driver.currentServer.hostname
|
hostname = self.driver.currentServer.hostname
|
||||||
|
attempt = self.driver.currentServer.attempt
|
||||||
|
|
||||||
log.info('Got STS policy over insecure connection; '
|
log.info('Got STS policy over insecure connection; '
|
||||||
'reconnecting to secure port. %r',
|
'reconnecting to secure port. %r',
|
||||||
self.driver.currentServer)
|
self.driver.currentServer)
|
||||||
# Reconnect to the server, but with TLS *and* certificate
|
# Reconnect to the server, but with TLS *and* certificate
|
||||||
# validation this time.
|
# validation this time.
|
||||||
self.state.fsm.on_shutdown(self, msg)
|
self.state.fsm.on_shutdown(self, msg)
|
||||||
|
|
||||||
self.driver.reconnect(
|
self.driver.reconnect(
|
||||||
server=Server(hostname, parsed_policy['port'], True),
|
server=Server(hostname, parsed_policy['port'], attempt, True),
|
||||||
wait=True)
|
wait=True)
|
||||||
|
|
||||||
def _addCapabilities(self, capstring, msg):
|
def _addCapabilities(self, capstring, msg):
|
||||||
|
@ -49,17 +49,17 @@ class DriversTestCase(SupyTestCase):
|
|||||||
|
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
driver._getNextServer(),
|
driver._getNextServer(),
|
||||||
drivers.Server('example.com', 6697, True))
|
drivers.Server('example.com', 6697, None, True))
|
||||||
driver.die()
|
driver.die()
|
||||||
|
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
driver._getNextServer(),
|
driver._getNextServer(),
|
||||||
drivers.Server('example.org', 6667, False))
|
drivers.Server('example.org', 6667, None, False))
|
||||||
driver.die()
|
driver.die()
|
||||||
|
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
driver._getNextServer(),
|
driver._getNextServer(),
|
||||||
drivers.Server('example.com', 6697, True))
|
drivers.Server('example.com', 6697, None, True))
|
||||||
|
|
||||||
def testExpiredStsPolicy(self):
|
def testExpiredStsPolicy(self):
|
||||||
irc = irclib.Irc('test')
|
irc = irclib.Irc('test')
|
||||||
@ -76,7 +76,7 @@ class DriversTestCase(SupyTestCase):
|
|||||||
|
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
driver._getNextServer(),
|
driver._getNextServer(),
|
||||||
drivers.Server('example.com', 6667, False))
|
drivers.Server('example.com', 6667, None, False))
|
||||||
|
|
||||||
def testRescheduledStsPolicy(self):
|
def testRescheduledStsPolicy(self):
|
||||||
irc = irclib.Irc('test')
|
irc = irclib.Irc('test')
|
||||||
@ -93,16 +93,16 @@ class DriversTestCase(SupyTestCase):
|
|||||||
|
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
driver._getNextServer(),
|
driver._getNextServer(),
|
||||||
drivers.Server('example.com', 6697, True))
|
drivers.Server('example.com', 6697, None, True))
|
||||||
driver.die()
|
driver.die()
|
||||||
|
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
driver._getNextServer(),
|
driver._getNextServer(),
|
||||||
drivers.Server('example.org', 6667, False))
|
drivers.Server('example.org', 6667, None, False))
|
||||||
driver.die()
|
driver.die()
|
||||||
|
|
||||||
timeFastForward(8)
|
timeFastForward(8)
|
||||||
|
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
driver._getNextServer(),
|
driver._getNextServer(),
|
||||||
drivers.Server('example.com', 6697, True))
|
drivers.Server('example.com', 6697, None, True))
|
||||||
|
@ -696,7 +696,7 @@ class StsTestCase(SupyTestCase):
|
|||||||
def testStsInSecureConnection(self):
|
def testStsInSecureConnection(self):
|
||||||
self.irc.driver.anyCertValidationEnabled.return_value = True
|
self.irc.driver.anyCertValidationEnabled.return_value = True
|
||||||
self.irc.driver.ssl = True
|
self.irc.driver.ssl = True
|
||||||
self.irc.driver.currentServer = drivers.Server('irc.test', 6697, False)
|
self.irc.driver.currentServer = drivers.Server('irc.test', 6697, None, False)
|
||||||
self.irc.feedMsg(ircmsgs.IrcMsg(command='CAP',
|
self.irc.feedMsg(ircmsgs.IrcMsg(command='CAP',
|
||||||
args=('*', 'LS', 'sts=duration=42,port=6697')))
|
args=('*', 'LS', 'sts=duration=42,port=6697')))
|
||||||
|
|
||||||
@ -707,25 +707,25 @@ class StsTestCase(SupyTestCase):
|
|||||||
def testStsInInsecureTlsConnection(self):
|
def testStsInInsecureTlsConnection(self):
|
||||||
self.irc.driver.anyCertValidationEnabled.return_value = False
|
self.irc.driver.anyCertValidationEnabled.return_value = False
|
||||||
self.irc.driver.ssl = True
|
self.irc.driver.ssl = True
|
||||||
self.irc.driver.currentServer = drivers.Server('irc.test', 6697, False)
|
self.irc.driver.currentServer = drivers.Server('irc.test', 6697, None, False)
|
||||||
self.irc.feedMsg(ircmsgs.IrcMsg(command='CAP',
|
self.irc.feedMsg(ircmsgs.IrcMsg(command='CAP',
|
||||||
args=('*', 'LS', 'sts=duration=42,port=6697')))
|
args=('*', 'LS', 'sts=duration=42,port=6697')))
|
||||||
|
|
||||||
self.assertEqual(ircdb.networks.getNetwork('test').stsPolicies, {})
|
self.assertEqual(ircdb.networks.getNetwork('test').stsPolicies, {})
|
||||||
self.irc.driver.reconnect.assert_called_once_with(
|
self.irc.driver.reconnect.assert_called_once_with(
|
||||||
server=drivers.Server('irc.test', 6697, True),
|
server=drivers.Server('irc.test', 6697, None, True),
|
||||||
wait=True)
|
wait=True)
|
||||||
|
|
||||||
def testStsInCleartextConnection(self):
|
def testStsInCleartextConnection(self):
|
||||||
self.irc.driver.anyCertValidationEnabled.return_value = False
|
self.irc.driver.anyCertValidationEnabled.return_value = False
|
||||||
self.irc.driver.ssl = True
|
self.irc.driver.ssl = True
|
||||||
self.irc.driver.currentServer = drivers.Server('irc.test', 6667, False)
|
self.irc.driver.currentServer = drivers.Server('irc.test', 6667, None, False)
|
||||||
self.irc.feedMsg(ircmsgs.IrcMsg(command='CAP',
|
self.irc.feedMsg(ircmsgs.IrcMsg(command='CAP',
|
||||||
args=('*', 'LS', 'sts=duration=42,port=6697')))
|
args=('*', 'LS', 'sts=duration=42,port=6697')))
|
||||||
|
|
||||||
self.assertEqual(ircdb.networks.getNetwork('test').stsPolicies, {})
|
self.assertEqual(ircdb.networks.getNetwork('test').stsPolicies, {})
|
||||||
self.irc.driver.reconnect.assert_called_once_with(
|
self.irc.driver.reconnect.assert_called_once_with(
|
||||||
server=drivers.Server('irc.test', 6697, True),
|
server=drivers.Server('irc.test', 6697, None, True),
|
||||||
wait=True)
|
wait=True)
|
||||||
|
|
||||||
def testStsInCleartextConnectionInvalidDuration(self):
|
def testStsInCleartextConnectionInvalidDuration(self):
|
||||||
@ -733,13 +733,13 @@ class StsTestCase(SupyTestCase):
|
|||||||
# connected clients MUST ignore it."
|
# connected clients MUST ignore it."
|
||||||
self.irc.driver.anyCertValidationEnabled.return_value = False
|
self.irc.driver.anyCertValidationEnabled.return_value = False
|
||||||
self.irc.driver.ssl = True
|
self.irc.driver.ssl = True
|
||||||
self.irc.driver.currentServer = drivers.Server('irc.test', 6667, False)
|
self.irc.driver.currentServer = drivers.Server('irc.test', 6667, None, False)
|
||||||
self.irc.feedMsg(ircmsgs.IrcMsg(command='CAP',
|
self.irc.feedMsg(ircmsgs.IrcMsg(command='CAP',
|
||||||
args=('*', 'LS', 'sts=duration=foo,port=6697')))
|
args=('*', 'LS', 'sts=duration=foo,port=6697')))
|
||||||
|
|
||||||
self.assertEqual(ircdb.networks.getNetwork('test').stsPolicies, {})
|
self.assertEqual(ircdb.networks.getNetwork('test').stsPolicies, {})
|
||||||
self.irc.driver.reconnect.assert_called_once_with(
|
self.irc.driver.reconnect.assert_called_once_with(
|
||||||
server=drivers.Server('irc.test', 6697, True),
|
server=drivers.Server('irc.test', 6697, None, True),
|
||||||
wait=True)
|
wait=True)
|
||||||
|
|
||||||
def testStsInCleartextConnectionNoDuration(self):
|
def testStsInCleartextConnectionNoDuration(self):
|
||||||
@ -747,13 +747,13 @@ class StsTestCase(SupyTestCase):
|
|||||||
# connected clients MUST ignore it."
|
# connected clients MUST ignore it."
|
||||||
self.irc.driver.anyCertValidationEnabled.return_value = False
|
self.irc.driver.anyCertValidationEnabled.return_value = False
|
||||||
self.irc.driver.ssl = True
|
self.irc.driver.ssl = True
|
||||||
self.irc.driver.currentServer = drivers.Server('irc.test', 6667, False)
|
self.irc.driver.currentServer = drivers.Server('irc.test', 6667, None, False)
|
||||||
self.irc.feedMsg(ircmsgs.IrcMsg(command='CAP',
|
self.irc.feedMsg(ircmsgs.IrcMsg(command='CAP',
|
||||||
args=('*', 'LS', 'sts=port=6697')))
|
args=('*', 'LS', 'sts=port=6697')))
|
||||||
|
|
||||||
self.assertEqual(ircdb.networks.getNetwork('test').stsPolicies, {})
|
self.assertEqual(ircdb.networks.getNetwork('test').stsPolicies, {})
|
||||||
self.irc.driver.reconnect.assert_called_once_with(
|
self.irc.driver.reconnect.assert_called_once_with(
|
||||||
server=drivers.Server('irc.test', 6697, True),
|
server=drivers.Server('irc.test', 6697, None, True),
|
||||||
wait=True)
|
wait=True)
|
||||||
|
|
||||||
class IrcTestCase(SupyTestCase):
|
class IrcTestCase(SupyTestCase):
|
||||||
|
Loading…
Reference in New Issue
Block a user