Fix various issues with STS handling.

This commit is contained in:
Valentin Lorentz 2019-12-08 21:25:59 +01:00
parent 51ff013fcc
commit 22120ee862
6 changed files with 70 additions and 36 deletions

View File

@ -241,6 +241,9 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin):
else: else:
drivers.log.debug('Not resetting %s.', self.irc) drivers.log.debug('Not resetting %s.', self.irc)
if wait: if wait:
if server is not None:
# Make this server be the next one to be used.
self.servers.insert(0, server)
self.scheduleReconnect() self.scheduleReconnect()
return return
self.currentServer = server or self._getNextServer() self.currentServer = server or self._getNextServer()
@ -283,8 +286,8 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin):
# See http://stackoverflow.com/q/16136916/539465 # See http://stackoverflow.com/q/16136916/539465
self.conn.connect( self.conn.connect(
(self.currentServer.hostname, self.currentServer.port)) (self.currentServer.hostname, self.currentServer.port))
if network_config.ssl() \ if network_config.ssl() or \
or self.currentServer.force_tls_verification: self.currentServer.force_tls_verification:
self.starttls() self.starttls()
# Suppress this warning for loopback IPs. # Suppress this warning for loopback IPs.
@ -294,6 +297,7 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin):
targetip = targetip.decode('utf-8') targetip = targetip.decode('utf-8')
elif (not network_config.requireStarttls()) and \ elif (not network_config.requireStarttls()) and \
(not network_config.ssl()) and \ (not network_config.ssl()) and \
(not self.currentServer.force_tls_verification) and \
(ipaddress is None or not ipaddress.ip_address(targetip).is_loopback): (ipaddress is None or not ipaddress.ip_address(targetip).is_loopback):
drivers.log.warning(('Connection to network %s ' drivers.log.warning(('Connection to network %s '
'does not use SSL/TLS, which makes it vulnerable to ' 'does not use SSL/TLS, which makes it vulnerable to '
@ -369,6 +373,7 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin):
def anyCertValidationEnabled(self): def anyCertValidationEnabled(self):
"""Returns whether any kind of certificate validation is enabled, other """Returns whether any kind of certificate validation is enabled, other
than Server.force_tls_verification.""" than Server.force_tls_verification."""
network_config = getattr(conf.supybot.networks, self.irc.network)
return any([ return any([
conf.supybot.protocols.ssl.verifyCertificates(), conf.supybot.protocols.ssl.verifyCertificates(),
network_config.ssl.serverFingerprints(), network_config.ssl.serverFingerprints(),
@ -392,7 +397,8 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin):
verifyCertificates = True verifyCertificates = True
else: else:
verifyCertificates = conf.supybot.protocols.ssl.verifyCertificates() verifyCertificates = conf.supybot.protocols.ssl.verifyCertificates()
if not verifyCertificates: if not self.currentServer.force_tls_verification \
and not self.anyCertValidationEnabled():
drivers.log.warning('Not checking SSL certificates, connections ' drivers.log.warning('Not checking SSL certificates, connections '
'are vulnerable to man-in-the-middle attacks. Set ' 'are vulnerable to man-in-the-middle attacks. Set '
'supybot.protocols.ssl.verifyCertificates to "true" ' 'supybot.protocols.ssl.verifyCertificates to "true" '

View File

@ -100,6 +100,8 @@ class ServersMixin(object):
lastDisconnect = network.lastDisconnectTimes.get(server.hostname) lastDisconnect = network.lastDisconnectTimes.get(server.hostname)
if policy is None or lastDisconnect is None: if policy is None or lastDisconnect is None:
log.debug('No STS policy, or never disconnected from this server. %r %r',
policy, lastDisconnect)
return server return server
# The policy was stored, which means it was received on a secure # The policy was stored, which means it was received on a secure
@ -107,9 +109,13 @@ class ServersMixin(object):
policy = ircutils.parseStsPolicy(log, policy, parseDuration=True) policy = ircutils.parseStsPolicy(log, policy, parseDuration=True)
if lastDisconnect + policy['duration'] < time.time(): if lastDisconnect + policy['duration'] < time.time():
log.info('STS policy expired, removing.')
network.expireStsPolicy(server.hostname) network.expireStsPolicy(server.hostname)
return server return server
log.info('Using STS policy: changing port from %s to %s.',
server.port, policy['port'])
# Change the port, and force TLS verification, as required by the STS # Change the port, and force TLS verification, as required by the STS
# specification. # specification.
return Server(server.hostname, policy['port'], return Server(server.hostname, policy['port'],

View File

@ -499,16 +499,15 @@ class IrcChannel(object):
class IrcNetwork(object): class IrcNetwork(object):
"""This class holds dynamic information about a network that should be """This class holds dynamic information about a network that should be
preserved across restarts.""" preserved across restarts."""
__slots__ = ('name', 'stsPolicies', 'lastDisconnectTimes') __slots__ = ('stsPolicies', 'lastDisconnectTimes')
def __init__(self, name=None, stsPolicies=None, lastDisconnectTimes=None): def __init__(self, stsPolicies=None, lastDisconnectTimes=None):
self.name = name
self.stsPolicies = stsPolicies or {} self.stsPolicies = stsPolicies or {}
self.lastDisconnectTimes = lastDisconnectTimes or {} self.lastDisconnectTimes = lastDisconnectTimes or {}
def __repr__(self): def __repr__(self):
return '%s(name=%r, stsPolicy=%r, lastDisconnectTimes=%s)' % \ return '%s(stsPolicies=%r, lastDisconnectTimes=%s)' % \
(self.__class__.__name, self.name, self.stsPolicy, (self.__class__.__name__, self.stsPolicies,
self.lastDisconnectTimes) self.lastDisconnectTimes)
def addStsPolicy(self, server, stsPolicy): def addStsPolicy(self, server, stsPolicy):
@ -659,13 +658,14 @@ class IrcChannelCreator(Creator):
class IrcNetworkCreator(Creator): class IrcNetworkCreator(Creator):
__slots__ = ('net', 'networks') __slots__ = ('net', 'networks')
name = None
def __init__(self, networks): def __init__(self, networks):
self.net = IrcNetwork() self.net = IrcNetwork()
self.networks = networks self.networks = networks
def network(self, rest, lineno): def network(self, rest, lineno):
self.net.name = rest IrcNetworkCreator.name = rest
def stspolicy(self, rest, lineno): def stspolicy(self, rest, lineno):
(server, stsPolicy) = rest.split() (server, stsPolicy) = rest.split()
@ -677,8 +677,8 @@ class IrcNetworkCreator(Creator):
self.net.lastDisconnectTimes[server] = when self.net.lastDisconnectTimes[server] = when
def finish(self): def finish(self):
if self.net.name: if self.name:
self.networks.setNetwork(self.net.name, self.net) self.networks.setNetwork(self.name, self.net)
self.net = IrcNetwork() self.net = IrcNetwork()

View File

@ -419,10 +419,15 @@ class IrcStateFsm(object):
CONNECTED_SASL = 80 CONNECTED_SASL = 80
'''Doing SASL authentication in the middle of a connection.''' '''Doing SASL authentication in the middle of a connection.'''
SHUTTING_DOWN = 100
def __init__(self): def __init__(self):
self.reset() self.reset()
def reset(self): def reset(self):
if getattr(self, 'state', None) is not None:
log.debug('resetting from %s to %s',
self.state, self.States.UNINITIALIZED)
self.state = self.States.UNINITIALIZED self.state = self.States.UNINITIALIZED
def _transition(self, to_state, expected_from=None): def _transition(self, to_state, expected_from=None):
@ -483,6 +488,9 @@ class IrcStateFsm(object):
self.States.INIT_MOTD self.States.INIT_MOTD
]) ])
def on_shutdown(self):
self._transition(self.States.SHUTTING_DOWN)
class IrcState(IrcCommandDispatcher, log.Firewalled): class IrcState(IrcCommandDispatcher, log.Firewalled):
"""Maintains state of the Irc connection. Should also become smarter. """Maintains state of the Irc connection. Should also become smarter.
""" """
@ -1252,7 +1260,7 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
self.state.capabilities_req, self.state.capabilities_req,
self.state.capabilities_ack, self.state.capabilities_ack,
self.state.capabilities_nak) self.state.capabilities_nak)
self.driver.reconnect() self.driver.reconnect(wait=True)
elif capabilities_responded == self.state.capabilities_req: elif capabilities_responded == self.state.capabilities_req:
log.debug('Got all capabilities ACKed/NAKed') log.debug('Got all capabilities ACKed/NAKed')
# We got all the capabilities we asked for # We got all the capabilities we asked for
@ -1470,13 +1478,14 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
self.capUpkeep() self.capUpkeep()
def _onCapSts(self, policy): def _onCapSts(self, policy):
secure_connection = self.driver.ssl and self.driver.anyCertValidationEnabled() secure_connection = self.driver.currentServer.force_tls_verification \
or (self.driver.ssl and self.driver.anyCertValidationEnabled())
parsed_policy = ircutils.parseStsPolicy( parsed_policy = ircutils.parseStsPolicy(
log, policy, parseDuration=secure_connection) log, policy, parseDuration=secure_connection)
if parsed_policy is None: if parsed_policy is None:
# There was an error (and it was logged). Abort the connection. # There was an error (and it was logged). Abort the connection.
self.driver.reconnect() self.driver.reconnect(wait=True)
return return
if secure_connection: if secure_connection:
@ -1485,14 +1494,20 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
# For future-proofing (because we don't want to write an invalid # For future-proofing (because we don't want to write an invalid
# value), we write the raw policy received from the server instead # value), we write the raw policy received from the server instead
# of the parsed one. # of the parsed one.
log.debug('Storing STS policy: %s', policy)
ircdb.networks.getNetwork(self.network).addStsPolicy( ircdb.networks.getNetwork(self.network).addStsPolicy(
self.driver.server.hostname, policy) self.driver.currentServer.hostname, policy)
else: else:
hostname = self.driver.server.hostname hostname = self.driver.currentServer.hostname
log.info('Got STS policy over insecure connection; '
'reconnecting to secure port. %r',
self.driver.currentServer)
# Reconnect to the server, but with TLS *and* certificate # Reconnect to the server, but with TLS *and* certificate
# validation this time. # validation this time.
self.state.fsm.on_shutdown()
self.driver.reconnect( self.driver.reconnect(
server=Server(hostname, parsed_policy['port'], True)) server=Server(hostname, parsed_policy['port'], True),
wait=True)
def _addCapabilities(self, capstring): def _addCapabilities(self, capstring):
for item in capstring.split(): for item in capstring.split():
@ -1507,9 +1522,10 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
if item == 'sts': if item == 'sts':
log.error('Got "sts" capability without value. Aborting ' log.error('Got "sts" capability without value. Aborting '
'connection.') 'connection.')
self.driver.reconnect() self.driver.reconnect(wait=True)
self.state.capabilities_ls[item] = None self.state.capabilities_ls[item] = None
def doCapLs(self, msg): def doCapLs(self, msg):
if len(msg.args) == 4: if len(msg.args) == 4:
# Multi-line LS # Multi-line LS
@ -1519,6 +1535,8 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
self._addCapabilities(msg.args[3]) self._addCapabilities(msg.args[3])
elif len(msg.args) == 3: # End of LS elif len(msg.args) == 3: # End of LS
self._addCapabilities(msg.args[2]) self._addCapabilities(msg.args[2])
if self.state.fsm.state == IrcStateFsm.States.SHUTTING_DOWN:
return
self.state.fsm.expect_state([ self.state.fsm.expect_state([
# Normal case: # Normal case:
IrcStateFsm.States.INIT_CAP_NEGOTIATION, IrcStateFsm.States.INIT_CAP_NEGOTIATION,
@ -1573,6 +1591,8 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
caps = msg.args[2].split() caps = msg.args[2].split()
assert caps, 'Empty list of capabilities' assert caps, 'Empty list of capabilities'
self._addCapabilities(msg.args[2]) self._addCapabilities(msg.args[2])
if self.state.fsm.state == IrcStateFsm.States.SHUTTING_DOWN:
return
common_supported_unrequested_capabilities = ( common_supported_unrequested_capabilities = (
set(self.state.capabilities_ls) & set(self.state.capabilities_ls) &
self.REQUEST_CAPABILITIES - self.REQUEST_CAPABILITIES -

View File

@ -352,7 +352,6 @@ class IrcChannelTestCase(IrcdbTestCase):
class IrcNetworkTestCase(IrcdbTestCase): class IrcNetworkTestCase(IrcdbTestCase):
def testDefaults(self): def testDefaults(self):
n = ircdb.IrcNetwork() n = ircdb.IrcNetwork()
self.assertIsNone(n.name)
self.assertEqual(n.stsPolicies, {}) self.assertEqual(n.stsPolicies, {})
self.assertEqual(n.lastDisconnectTimes, {}) self.assertEqual(n.lastDisconnectTimes, {})
@ -373,7 +372,7 @@ class IrcNetworkTestCase(IrcdbTestCase):
self.assertTrue(min_ <= n.lastDisconnectTimes['foo'] <= max_) self.assertTrue(min_ <= n.lastDisconnectTimes['foo'] <= max_)
def testPreserve(self): def testPreserve(self):
n = ircdb.IrcNetwork('foonet') n = ircdb.IrcNetwork()
n.addStsPolicy('foo', 'sts1') n.addStsPolicy('foo', 'sts1')
n.addStsPolicy('bar', 'sts2') n.addStsPolicy('bar', 'sts2')
n.addDisconnection('foo') n.addDisconnection('foo')
@ -459,15 +458,14 @@ class NetworksDictionaryTestCase(IrcdbTestCase):
IrcdbTestCase.setUp(self) IrcdbTestCase.setUp(self)
def testGetSetNetwork(self): def testGetSetNetwork(self):
self.assertEqual(self.networks.getNetwork('foo').name, None) self.assertEqual(self.networks.getNetwork('foo').stsPolicies, {})
n = ircdb.IrcNetwork() n = ircdb.IrcNetwork()
n.name = 'foo'
self.networks.setNetwork('foo', n) self.networks.setNetwork('foo', n)
self.assertEqual(self.networks.getNetwork('foo').name, 'foo') self.assertEqual(self.networks.getNetwork('foo').stsPolicies, {})
def testPreserveOne(self): def testPreserveOne(self):
n = ircdb.IrcNetwork('foonet') n = ircdb.IrcNetwork()
n.addStsPolicy('foo', 'sts1') n.addStsPolicy('foo', 'sts1')
n.addStsPolicy('bar', 'sts2') n.addStsPolicy('bar', 'sts2')
n.addDisconnection('foo') n.addDisconnection('foo')
@ -496,15 +494,15 @@ class NetworksDictionaryTestCase(IrcdbTestCase):
]) ])
def testPreserveThree(self): def testPreserveThree(self):
n = ircdb.IrcNetwork('foonet') n = ircdb.IrcNetwork()
n.addStsPolicy('foo', 'sts1') n.addStsPolicy('foo', 'sts1')
self.networks.setNetwork('foonet', n) self.networks.setNetwork('foonet', n)
n = ircdb.IrcNetwork('barnet') n = ircdb.IrcNetwork()
n.addStsPolicy('bar', 'sts2') n.addStsPolicy('bar', 'sts2')
self.networks.setNetwork('barnet', n) self.networks.setNetwork('barnet', n)
n = ircdb.IrcNetwork('baznet') n = ircdb.IrcNetwork()
n.addStsPolicy('baz', 'sts3') n.addStsPolicy('baz', 'sts3')
self.networks.setNetwork('baznet', n) self.networks.setNetwork('baznet', n)

View File

@ -520,7 +520,7 @@ class StsTestCase(SupyTestCase):
def testStsInSecureConnection(self): def testStsInSecureConnection(self):
self.irc.driver.anyCertValidationEnabled.return_value = True self.irc.driver.anyCertValidationEnabled.return_value = True
self.irc.driver.ssl = True self.irc.driver.ssl = True
self.irc.driver.server = drivers.Server('irc.test', 6697, False) self.irc.driver.currentServer = drivers.Server('irc.test', 6697, False)
self.irc.feedMsg(ircmsgs.IrcMsg(command='CAP', self.irc.feedMsg(ircmsgs.IrcMsg(command='CAP',
args=('*', 'LS', 'sts=duration=42,port=6697'))) args=('*', 'LS', 'sts=duration=42,port=6697')))
@ -531,50 +531,54 @@ class StsTestCase(SupyTestCase):
def testStsInInsecureTlsConnection(self): def testStsInInsecureTlsConnection(self):
self.irc.driver.anyCertValidationEnabled.return_value = False self.irc.driver.anyCertValidationEnabled.return_value = False
self.irc.driver.ssl = True self.irc.driver.ssl = True
self.irc.driver.server = drivers.Server('irc.test', 6697, False) self.irc.driver.currentServer = drivers.Server('irc.test', 6697, False)
self.irc.feedMsg(ircmsgs.IrcMsg(command='CAP', self.irc.feedMsg(ircmsgs.IrcMsg(command='CAP',
args=('*', 'LS', 'sts=duration=42,port=6697'))) args=('*', 'LS', 'sts=duration=42,port=6697')))
self.assertEqual(ircdb.networks.getNetwork('test').stsPolicies, {}) self.assertEqual(ircdb.networks.getNetwork('test').stsPolicies, {})
self.irc.driver.reconnect.assert_called_once_with( self.irc.driver.reconnect.assert_called_once_with(
server=drivers.Server('irc.test', 6697, True)) server=drivers.Server('irc.test', 6697, True),
wait=True)
def testStsInCleartextConnection(self): def testStsInCleartextConnection(self):
self.irc.driver.anyCertValidationEnabled.return_value = False self.irc.driver.anyCertValidationEnabled.return_value = False
self.irc.driver.ssl = True self.irc.driver.ssl = True
self.irc.driver.server = drivers.Server('irc.test', 6667, False) self.irc.driver.currentServer = drivers.Server('irc.test', 6667, False)
self.irc.feedMsg(ircmsgs.IrcMsg(command='CAP', self.irc.feedMsg(ircmsgs.IrcMsg(command='CAP',
args=('*', 'LS', 'sts=duration=42,port=6697'))) args=('*', 'LS', 'sts=duration=42,port=6697')))
self.assertEqual(ircdb.networks.getNetwork('test').stsPolicies, {}) self.assertEqual(ircdb.networks.getNetwork('test').stsPolicies, {})
self.irc.driver.reconnect.assert_called_once_with( self.irc.driver.reconnect.assert_called_once_with(
server=drivers.Server('irc.test', 6697, True)) server=drivers.Server('irc.test', 6697, True),
wait=True)
def testStsInCleartextConnectionInvalidDuration(self): def testStsInCleartextConnectionInvalidDuration(self):
# "Servers MAY send this key to all clients, but insecurely # "Servers MAY send this key to all clients, but insecurely
# connected clients MUST ignore it." # connected clients MUST ignore it."
self.irc.driver.anyCertValidationEnabled.return_value = False self.irc.driver.anyCertValidationEnabled.return_value = False
self.irc.driver.ssl = True self.irc.driver.ssl = True
self.irc.driver.server = drivers.Server('irc.test', 6667, False) self.irc.driver.currentServer = drivers.Server('irc.test', 6667, False)
self.irc.feedMsg(ircmsgs.IrcMsg(command='CAP', self.irc.feedMsg(ircmsgs.IrcMsg(command='CAP',
args=('*', 'LS', 'sts=duration=foo,port=6697'))) args=('*', 'LS', 'sts=duration=foo,port=6697')))
self.assertEqual(ircdb.networks.getNetwork('test').stsPolicies, {}) self.assertEqual(ircdb.networks.getNetwork('test').stsPolicies, {})
self.irc.driver.reconnect.assert_called_once_with( self.irc.driver.reconnect.assert_called_once_with(
server=drivers.Server('irc.test', 6697, True)) server=drivers.Server('irc.test', 6697, True),
wait=True)
def testStsInCleartextConnectionNoDuration(self): def testStsInCleartextConnectionNoDuration(self):
# "Servers MAY send this key to all clients, but insecurely # "Servers MAY send this key to all clients, but insecurely
# connected clients MUST ignore it." # connected clients MUST ignore it."
self.irc.driver.anyCertValidationEnabled.return_value = False self.irc.driver.anyCertValidationEnabled.return_value = False
self.irc.driver.ssl = True self.irc.driver.ssl = True
self.irc.driver.server = drivers.Server('irc.test', 6667, False) self.irc.driver.currentServer = drivers.Server('irc.test', 6667, False)
self.irc.feedMsg(ircmsgs.IrcMsg(command='CAP', self.irc.feedMsg(ircmsgs.IrcMsg(command='CAP',
args=('*', 'LS', 'sts=port=6697'))) args=('*', 'LS', 'sts=port=6697')))
self.assertEqual(ircdb.networks.getNetwork('test').stsPolicies, {}) self.assertEqual(ircdb.networks.getNetwork('test').stsPolicies, {})
self.irc.driver.reconnect.assert_called_once_with( self.irc.driver.reconnect.assert_called_once_with(
server=drivers.Server('irc.test', 6697, True)) server=drivers.Server('irc.test', 6697, True),
wait=True)
class IrcTestCase(SupyTestCase): class IrcTestCase(SupyTestCase):
def setUp(self): def setUp(self):