mirror of
https://github.com/Mikaela/Limnoria.git
synced 2024-12-25 04:02:46 +01:00
Fix various issues with STS handling.
This commit is contained in:
parent
51ff013fcc
commit
22120ee862
@ -241,6 +241,9 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin):
|
||||
else:
|
||||
drivers.log.debug('Not resetting %s.', self.irc)
|
||||
if wait:
|
||||
if server is not None:
|
||||
# Make this server be the next one to be used.
|
||||
self.servers.insert(0, server)
|
||||
self.scheduleReconnect()
|
||||
return
|
||||
self.currentServer = server or self._getNextServer()
|
||||
@ -283,8 +286,8 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin):
|
||||
# See http://stackoverflow.com/q/16136916/539465
|
||||
self.conn.connect(
|
||||
(self.currentServer.hostname, self.currentServer.port))
|
||||
if network_config.ssl() \
|
||||
or self.currentServer.force_tls_verification:
|
||||
if network_config.ssl() or \
|
||||
self.currentServer.force_tls_verification:
|
||||
self.starttls()
|
||||
|
||||
# Suppress this warning for loopback IPs.
|
||||
@ -294,6 +297,7 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin):
|
||||
targetip = targetip.decode('utf-8')
|
||||
elif (not network_config.requireStarttls()) and \
|
||||
(not network_config.ssl()) and \
|
||||
(not self.currentServer.force_tls_verification) and \
|
||||
(ipaddress is None or not ipaddress.ip_address(targetip).is_loopback):
|
||||
drivers.log.warning(('Connection to network %s '
|
||||
'does not use SSL/TLS, which makes it vulnerable to '
|
||||
@ -369,6 +373,7 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin):
|
||||
def anyCertValidationEnabled(self):
|
||||
"""Returns whether any kind of certificate validation is enabled, other
|
||||
than Server.force_tls_verification."""
|
||||
network_config = getattr(conf.supybot.networks, self.irc.network)
|
||||
return any([
|
||||
conf.supybot.protocols.ssl.verifyCertificates(),
|
||||
network_config.ssl.serverFingerprints(),
|
||||
@ -392,7 +397,8 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin):
|
||||
verifyCertificates = True
|
||||
else:
|
||||
verifyCertificates = conf.supybot.protocols.ssl.verifyCertificates()
|
||||
if not verifyCertificates:
|
||||
if not self.currentServer.force_tls_verification \
|
||||
and not self.anyCertValidationEnabled():
|
||||
drivers.log.warning('Not checking SSL certificates, connections '
|
||||
'are vulnerable to man-in-the-middle attacks. Set '
|
||||
'supybot.protocols.ssl.verifyCertificates to "true" '
|
||||
|
@ -100,6 +100,8 @@ class ServersMixin(object):
|
||||
lastDisconnect = network.lastDisconnectTimes.get(server.hostname)
|
||||
|
||||
if policy is None or lastDisconnect is None:
|
||||
log.debug('No STS policy, or never disconnected from this server. %r %r',
|
||||
policy, lastDisconnect)
|
||||
return server
|
||||
|
||||
# The policy was stored, which means it was received on a secure
|
||||
@ -107,9 +109,13 @@ class ServersMixin(object):
|
||||
policy = ircutils.parseStsPolicy(log, policy, parseDuration=True)
|
||||
|
||||
if lastDisconnect + policy['duration'] < time.time():
|
||||
log.info('STS policy expired, removing.')
|
||||
network.expireStsPolicy(server.hostname)
|
||||
return server
|
||||
|
||||
log.info('Using STS policy: changing port from %s to %s.',
|
||||
server.port, policy['port'])
|
||||
|
||||
# Change the port, and force TLS verification, as required by the STS
|
||||
# specification.
|
||||
return Server(server.hostname, policy['port'],
|
||||
|
16
src/ircdb.py
16
src/ircdb.py
@ -499,16 +499,15 @@ class IrcChannel(object):
|
||||
class IrcNetwork(object):
|
||||
"""This class holds dynamic information about a network that should be
|
||||
preserved across restarts."""
|
||||
__slots__ = ('name', 'stsPolicies', 'lastDisconnectTimes')
|
||||
__slots__ = ('stsPolicies', 'lastDisconnectTimes')
|
||||
|
||||
def __init__(self, name=None, stsPolicies=None, lastDisconnectTimes=None):
|
||||
self.name = name
|
||||
def __init__(self, stsPolicies=None, lastDisconnectTimes=None):
|
||||
self.stsPolicies = stsPolicies or {}
|
||||
self.lastDisconnectTimes = lastDisconnectTimes or {}
|
||||
|
||||
def __repr__(self):
|
||||
return '%s(name=%r, stsPolicy=%r, lastDisconnectTimes=%s)' % \
|
||||
(self.__class__.__name, self.name, self.stsPolicy,
|
||||
return '%s(stsPolicies=%r, lastDisconnectTimes=%s)' % \
|
||||
(self.__class__.__name__, self.stsPolicies,
|
||||
self.lastDisconnectTimes)
|
||||
|
||||
def addStsPolicy(self, server, stsPolicy):
|
||||
@ -659,13 +658,14 @@ class IrcChannelCreator(Creator):
|
||||
|
||||
class IrcNetworkCreator(Creator):
|
||||
__slots__ = ('net', 'networks')
|
||||
name = None
|
||||
|
||||
def __init__(self, networks):
|
||||
self.net = IrcNetwork()
|
||||
self.networks = networks
|
||||
|
||||
def network(self, rest, lineno):
|
||||
self.net.name = rest
|
||||
IrcNetworkCreator.name = rest
|
||||
|
||||
def stspolicy(self, rest, lineno):
|
||||
(server, stsPolicy) = rest.split()
|
||||
@ -677,8 +677,8 @@ class IrcNetworkCreator(Creator):
|
||||
self.net.lastDisconnectTimes[server] = when
|
||||
|
||||
def finish(self):
|
||||
if self.net.name:
|
||||
self.networks.setNetwork(self.net.name, self.net)
|
||||
if self.name:
|
||||
self.networks.setNetwork(self.name, self.net)
|
||||
self.net = IrcNetwork()
|
||||
|
||||
|
||||
|
@ -419,10 +419,15 @@ class IrcStateFsm(object):
|
||||
CONNECTED_SASL = 80
|
||||
'''Doing SASL authentication in the middle of a connection.'''
|
||||
|
||||
SHUTTING_DOWN = 100
|
||||
|
||||
def __init__(self):
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
if getattr(self, 'state', None) is not None:
|
||||
log.debug('resetting from %s to %s',
|
||||
self.state, self.States.UNINITIALIZED)
|
||||
self.state = self.States.UNINITIALIZED
|
||||
|
||||
def _transition(self, to_state, expected_from=None):
|
||||
@ -483,6 +488,9 @@ class IrcStateFsm(object):
|
||||
self.States.INIT_MOTD
|
||||
])
|
||||
|
||||
def on_shutdown(self):
|
||||
self._transition(self.States.SHUTTING_DOWN)
|
||||
|
||||
class IrcState(IrcCommandDispatcher, log.Firewalled):
|
||||
"""Maintains state of the Irc connection. Should also become smarter.
|
||||
"""
|
||||
@ -1252,7 +1260,7 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
|
||||
self.state.capabilities_req,
|
||||
self.state.capabilities_ack,
|
||||
self.state.capabilities_nak)
|
||||
self.driver.reconnect()
|
||||
self.driver.reconnect(wait=True)
|
||||
elif capabilities_responded == self.state.capabilities_req:
|
||||
log.debug('Got all capabilities ACKed/NAKed')
|
||||
# We got all the capabilities we asked for
|
||||
@ -1470,13 +1478,14 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
|
||||
self.capUpkeep()
|
||||
|
||||
def _onCapSts(self, policy):
|
||||
secure_connection = self.driver.ssl and self.driver.anyCertValidationEnabled()
|
||||
secure_connection = self.driver.currentServer.force_tls_verification \
|
||||
or (self.driver.ssl and self.driver.anyCertValidationEnabled())
|
||||
|
||||
parsed_policy = ircutils.parseStsPolicy(
|
||||
log, policy, parseDuration=secure_connection)
|
||||
if parsed_policy is None:
|
||||
# There was an error (and it was logged). Abort the connection.
|
||||
self.driver.reconnect()
|
||||
self.driver.reconnect(wait=True)
|
||||
return
|
||||
|
||||
if secure_connection:
|
||||
@ -1485,14 +1494,20 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
|
||||
# For future-proofing (because we don't want to write an invalid
|
||||
# value), we write the raw policy received from the server instead
|
||||
# of the parsed one.
|
||||
log.debug('Storing STS policy: %s', policy)
|
||||
ircdb.networks.getNetwork(self.network).addStsPolicy(
|
||||
self.driver.server.hostname, policy)
|
||||
self.driver.currentServer.hostname, policy)
|
||||
else:
|
||||
hostname = self.driver.server.hostname
|
||||
hostname = self.driver.currentServer.hostname
|
||||
log.info('Got STS policy over insecure connection; '
|
||||
'reconnecting to secure port. %r',
|
||||
self.driver.currentServer)
|
||||
# Reconnect to the server, but with TLS *and* certificate
|
||||
# validation this time.
|
||||
self.state.fsm.on_shutdown()
|
||||
self.driver.reconnect(
|
||||
server=Server(hostname, parsed_policy['port'], True))
|
||||
server=Server(hostname, parsed_policy['port'], True),
|
||||
wait=True)
|
||||
|
||||
def _addCapabilities(self, capstring):
|
||||
for item in capstring.split():
|
||||
@ -1507,9 +1522,10 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
|
||||
if item == 'sts':
|
||||
log.error('Got "sts" capability without value. Aborting '
|
||||
'connection.')
|
||||
self.driver.reconnect()
|
||||
self.driver.reconnect(wait=True)
|
||||
self.state.capabilities_ls[item] = None
|
||||
|
||||
|
||||
def doCapLs(self, msg):
|
||||
if len(msg.args) == 4:
|
||||
# Multi-line LS
|
||||
@ -1519,6 +1535,8 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
|
||||
self._addCapabilities(msg.args[3])
|
||||
elif len(msg.args) == 3: # End of LS
|
||||
self._addCapabilities(msg.args[2])
|
||||
if self.state.fsm.state == IrcStateFsm.States.SHUTTING_DOWN:
|
||||
return
|
||||
self.state.fsm.expect_state([
|
||||
# Normal case:
|
||||
IrcStateFsm.States.INIT_CAP_NEGOTIATION,
|
||||
@ -1573,6 +1591,8 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
|
||||
caps = msg.args[2].split()
|
||||
assert caps, 'Empty list of capabilities'
|
||||
self._addCapabilities(msg.args[2])
|
||||
if self.state.fsm.state == IrcStateFsm.States.SHUTTING_DOWN:
|
||||
return
|
||||
common_supported_unrequested_capabilities = (
|
||||
set(self.state.capabilities_ls) &
|
||||
self.REQUEST_CAPABILITIES -
|
||||
|
@ -352,7 +352,6 @@ class IrcChannelTestCase(IrcdbTestCase):
|
||||
class IrcNetworkTestCase(IrcdbTestCase):
|
||||
def testDefaults(self):
|
||||
n = ircdb.IrcNetwork()
|
||||
self.assertIsNone(n.name)
|
||||
self.assertEqual(n.stsPolicies, {})
|
||||
self.assertEqual(n.lastDisconnectTimes, {})
|
||||
|
||||
@ -373,7 +372,7 @@ class IrcNetworkTestCase(IrcdbTestCase):
|
||||
self.assertTrue(min_ <= n.lastDisconnectTimes['foo'] <= max_)
|
||||
|
||||
def testPreserve(self):
|
||||
n = ircdb.IrcNetwork('foonet')
|
||||
n = ircdb.IrcNetwork()
|
||||
n.addStsPolicy('foo', 'sts1')
|
||||
n.addStsPolicy('bar', 'sts2')
|
||||
n.addDisconnection('foo')
|
||||
@ -459,15 +458,14 @@ class NetworksDictionaryTestCase(IrcdbTestCase):
|
||||
IrcdbTestCase.setUp(self)
|
||||
|
||||
def testGetSetNetwork(self):
|
||||
self.assertEqual(self.networks.getNetwork('foo').name, None)
|
||||
self.assertEqual(self.networks.getNetwork('foo').stsPolicies, {})
|
||||
|
||||
n = ircdb.IrcNetwork()
|
||||
n.name = 'foo'
|
||||
self.networks.setNetwork('foo', n)
|
||||
self.assertEqual(self.networks.getNetwork('foo').name, 'foo')
|
||||
self.assertEqual(self.networks.getNetwork('foo').stsPolicies, {})
|
||||
|
||||
def testPreserveOne(self):
|
||||
n = ircdb.IrcNetwork('foonet')
|
||||
n = ircdb.IrcNetwork()
|
||||
n.addStsPolicy('foo', 'sts1')
|
||||
n.addStsPolicy('bar', 'sts2')
|
||||
n.addDisconnection('foo')
|
||||
@ -496,15 +494,15 @@ class NetworksDictionaryTestCase(IrcdbTestCase):
|
||||
])
|
||||
|
||||
def testPreserveThree(self):
|
||||
n = ircdb.IrcNetwork('foonet')
|
||||
n = ircdb.IrcNetwork()
|
||||
n.addStsPolicy('foo', 'sts1')
|
||||
self.networks.setNetwork('foonet', n)
|
||||
|
||||
n = ircdb.IrcNetwork('barnet')
|
||||
n = ircdb.IrcNetwork()
|
||||
n.addStsPolicy('bar', 'sts2')
|
||||
self.networks.setNetwork('barnet', n)
|
||||
|
||||
n = ircdb.IrcNetwork('baznet')
|
||||
n = ircdb.IrcNetwork()
|
||||
n.addStsPolicy('baz', 'sts3')
|
||||
self.networks.setNetwork('baznet', n)
|
||||
|
||||
|
@ -520,7 +520,7 @@ class StsTestCase(SupyTestCase):
|
||||
def testStsInSecureConnection(self):
|
||||
self.irc.driver.anyCertValidationEnabled.return_value = True
|
||||
self.irc.driver.ssl = True
|
||||
self.irc.driver.server = drivers.Server('irc.test', 6697, False)
|
||||
self.irc.driver.currentServer = drivers.Server('irc.test', 6697, False)
|
||||
self.irc.feedMsg(ircmsgs.IrcMsg(command='CAP',
|
||||
args=('*', 'LS', 'sts=duration=42,port=6697')))
|
||||
|
||||
@ -531,50 +531,54 @@ class StsTestCase(SupyTestCase):
|
||||
def testStsInInsecureTlsConnection(self):
|
||||
self.irc.driver.anyCertValidationEnabled.return_value = False
|
||||
self.irc.driver.ssl = True
|
||||
self.irc.driver.server = drivers.Server('irc.test', 6697, False)
|
||||
self.irc.driver.currentServer = drivers.Server('irc.test', 6697, False)
|
||||
self.irc.feedMsg(ircmsgs.IrcMsg(command='CAP',
|
||||
args=('*', 'LS', 'sts=duration=42,port=6697')))
|
||||
|
||||
self.assertEqual(ircdb.networks.getNetwork('test').stsPolicies, {})
|
||||
self.irc.driver.reconnect.assert_called_once_with(
|
||||
server=drivers.Server('irc.test', 6697, True))
|
||||
server=drivers.Server('irc.test', 6697, True),
|
||||
wait=True)
|
||||
|
||||
def testStsInCleartextConnection(self):
|
||||
self.irc.driver.anyCertValidationEnabled.return_value = False
|
||||
self.irc.driver.ssl = True
|
||||
self.irc.driver.server = drivers.Server('irc.test', 6667, False)
|
||||
self.irc.driver.currentServer = drivers.Server('irc.test', 6667, False)
|
||||
self.irc.feedMsg(ircmsgs.IrcMsg(command='CAP',
|
||||
args=('*', 'LS', 'sts=duration=42,port=6697')))
|
||||
|
||||
self.assertEqual(ircdb.networks.getNetwork('test').stsPolicies, {})
|
||||
self.irc.driver.reconnect.assert_called_once_with(
|
||||
server=drivers.Server('irc.test', 6697, True))
|
||||
server=drivers.Server('irc.test', 6697, True),
|
||||
wait=True)
|
||||
|
||||
def testStsInCleartextConnectionInvalidDuration(self):
|
||||
# "Servers MAY send this key to all clients, but insecurely
|
||||
# connected clients MUST ignore it."
|
||||
self.irc.driver.anyCertValidationEnabled.return_value = False
|
||||
self.irc.driver.ssl = True
|
||||
self.irc.driver.server = drivers.Server('irc.test', 6667, False)
|
||||
self.irc.driver.currentServer = drivers.Server('irc.test', 6667, False)
|
||||
self.irc.feedMsg(ircmsgs.IrcMsg(command='CAP',
|
||||
args=('*', 'LS', 'sts=duration=foo,port=6697')))
|
||||
|
||||
self.assertEqual(ircdb.networks.getNetwork('test').stsPolicies, {})
|
||||
self.irc.driver.reconnect.assert_called_once_with(
|
||||
server=drivers.Server('irc.test', 6697, True))
|
||||
server=drivers.Server('irc.test', 6697, True),
|
||||
wait=True)
|
||||
|
||||
def testStsInCleartextConnectionNoDuration(self):
|
||||
# "Servers MAY send this key to all clients, but insecurely
|
||||
# connected clients MUST ignore it."
|
||||
self.irc.driver.anyCertValidationEnabled.return_value = False
|
||||
self.irc.driver.ssl = True
|
||||
self.irc.driver.server = drivers.Server('irc.test', 6667, False)
|
||||
self.irc.driver.currentServer = drivers.Server('irc.test', 6667, False)
|
||||
self.irc.feedMsg(ircmsgs.IrcMsg(command='CAP',
|
||||
args=('*', 'LS', 'sts=port=6697')))
|
||||
|
||||
self.assertEqual(ircdb.networks.getNetwork('test').stsPolicies, {})
|
||||
self.irc.driver.reconnect.assert_called_once_with(
|
||||
server=drivers.Server('irc.test', 6697, True))
|
||||
server=drivers.Server('irc.test', 6697, True),
|
||||
wait=True)
|
||||
|
||||
class IrcTestCase(SupyTestCase):
|
||||
def setUp(self):
|
||||
|
Loading…
Reference in New Issue
Block a user