Apply STS policies when connecting to a server.

This commit is contained in:
Valentin Lorentz 2019-12-08 15:54:48 +01:00
parent ecc2c32950
commit 51ff013fcc
8 changed files with 210 additions and 26 deletions

View File

@ -225,6 +225,7 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin):
self._attempt += 1
self.nextReconnectTime = None
if self.connected:
self.onDisconnect()
drivers.log.reconnect(self.irc.network)
if self in self._instances:
self._instances.remove(self)
@ -242,7 +243,7 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin):
if wait:
self.scheduleReconnect()
return
self.server = server or self._getNextServer()
self.currentServer = server or self._getNextServer()
network_config = getattr(conf.supybot.networks, self.irc.network)
socks_proxy = network_config.socksproxy()
try:
@ -255,7 +256,7 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin):
else:
try:
hostname = utils.net.getAddressFromHostname(
self.server.hostname,
self.currentServer.hostname,
attempt=self._attempt)
except (socket.gaierror, socket.error) as e:
drivers.log.connectError(self.currentServer, e)
@ -264,8 +265,8 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin):
drivers.log.connect(self.currentServer)
try:
self.conn = utils.net.getSocket(
self.server.hostname,
port=self.server.port,
self.currentServer.hostname,
port=self.currentServer.port,
socks_proxy=socks_proxy,
vhost=conf.supybot.protocols.irc.vhost(),
vhostv6=conf.supybot.protocols.irc.vhostv6(),
@ -280,8 +281,10 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin):
try:
# Connect before SSL, otherwise SSL is disabled if we use SOCKS.
# See http://stackoverflow.com/q/16136916/539465
self.conn.connect((self.server.hostname, self.server.port))
if network_config.ssl() or self.server.force_tls_verification:
self.conn.connect(
(self.currentServer.hostname, self.currentServer.port))
if network_config.ssl() \
or self.currentServer.force_tls_verification:
self.starttls()
# Suppress this warning for loopback IPs.
@ -351,6 +354,8 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin):
if self.writeCheckTime is not None:
self.writeCheckTime = None
drivers.log.die(self.irc)
drivers.IrcDriver.die(self)
drivers.ServersMixin.die(self)
def _reallyDie(self):
if self.conn is not None:
@ -382,7 +387,7 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin):
drivers.log.warning('Could not find cert file %s.' %
certfile)
certfile = None
if self.server.force_tls_verification \
if self.currentServer.force_tls_verification \
and not self.anyCertValidationEnabled():
verifyCertificates = True
else:
@ -395,7 +400,7 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin):
try:
self.conn = utils.net.ssl_wrap_socket(self.conn,
logger=drivers.log,
hostname=self.server.hostname,
hostname=self.currentServer.hostname,
certfile=certfile,
verify=verifyCertificates,
trusted_fingerprints=network_config.ssl.serverFingerprints(),

View File

@ -32,10 +32,11 @@
Contains various drivers (network, file, and otherwise) for using IRC objects.
"""
import time
import socket
from collections import namedtuple
from .. import conf, ircmsgs, log as supylog, utils
from .. import conf, ircdb, ircmsgs, ircutils, log as supylog, utils
from ..utils import minisix
@ -73,6 +74,7 @@ class IrcDriver(object):
class ServersMixin(object):
def __init__(self, irc, servers=()):
self.networkName = irc.network
self.networkGroup = conf.supybot.networks.get(irc.network)
self.servers = servers
super(ServersMixin, self).__init__()
@ -89,8 +91,36 @@ class ServersMixin(object):
assert self.servers, 'Servers value for %s is empty.' % \
self.networkGroup._name
server = self.servers.pop(0)
self.currentServer = '%s:%s' % (server.hostname, server.port)
return server
self.currentServer = self._applyStsPolicy(server)
return self.currentServer
def _applyStsPolicy(self, server):
network = ircdb.networks.getNetwork(self.networkName)
policy = network.stsPolicies.get(server.hostname)
lastDisconnect = network.lastDisconnectTimes.get(server.hostname)
if policy is None or lastDisconnect is None:
return server
# The policy was stored, which means it was received on a secure
# connection.
policy = ircutils.parseStsPolicy(log, policy, parseDuration=True)
if lastDisconnect + policy['duration'] < time.time():
network.expireStsPolicy(server.hostname)
return server
# Change the port, and force TLS verification, as required by the STS
# specification.
return Server(server.hostname, policy['port'],
force_tls_verification=True)
def die(self):
self.onDisconnect()
def onDisconnect(self):
network = ircdb.networks.getNetwork(self.networkName)
network.addDisconnection(self.currentServer.hostname)
def empty():
@ -138,7 +168,8 @@ def run():
class Log(object):
"""This is used to have a nice, consistent interface for drivers to use."""
def connect(self, server):
self.info('Connecting to %s.', server)
self.info('Connecting to %s:%s.',
server.hostname, server.port)
def connectError(self, server, e):
if isinstance(e, Exception):
@ -146,7 +177,8 @@ class Log(object):
e = e.args[1]
else:
e = utils.exnToString(e)
self.warning('Error connecting to %s: %s', server, e)
self.warning('Error connecting to %s:%s: %s',
server.hostname, server.port, e)
def disconnect(self, server, e=None):
if e:
@ -156,7 +188,8 @@ class Log(object):
e = str(e)
if not e.endswith('.'):
e += '.'
self.warning('Disconnect from %s: %s', server, e)
self.warning('Disconnect from %s:%s: %s',
server.hostname, server.port, e)
else:
self.info('Disconnect from %s.', server)

View File

@ -515,6 +515,10 @@ class IrcNetwork(object):
assert isinstance(stsPolicy, str)
self.stsPolicies[server] = stsPolicy
def expireStsPolicy(self, server):
if server in self.stsPolicies:
del self.stsPolicies[server]
def addDisconnection(self, server):
self.lastDisconnectTimes[server] = int(time.time())
@ -674,8 +678,8 @@ class IrcNetworkCreator(Creator):
def finish(self):
if self.net.name:
self.networks.setNetwork(self.net)
self.name = None
self.networks.setNetwork(self.net.name, self.net)
self.net = IrcNetwork()
class DuplicateHostmask(ValueError):

View File

@ -1470,19 +1470,16 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
self.capUpkeep()
def _onCapSts(self, policy):
parsed_policy = ircutils._parseStsPolicy(log, policy)
secure_connection = self.driver.ssl and self.driver.anyCertValidationEnabled()
parsed_policy = ircutils.parseStsPolicy(
log, policy, parseDuration=secure_connection)
if parsed_policy is None:
# There was an error (and it was logged). Abort the connection.
self.driver.reconnect()
return
if not self.driver.ssl or not self.driver.anyCertValidationEnabled():
hostname = self.driver.server.hostname
# Reconnect to the server, but with TLS *and* certificate
# validation this time.
self.driver.reconnect(
server=Server(hostname, parsed_policy['port'], True))
else:
if secure_connection:
# TLS is enabled and certificate is verified; write the STS policy
# in stone.
# For future-proofing (because we don't want to write an invalid
@ -1490,6 +1487,12 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
# of the parsed one.
ircdb.networks.getNetwork(self.network).addStsPolicy(
self.driver.server.hostname, policy)
else:
hostname = self.driver.server.hostname
# Reconnect to the server, but with TLS *and* certificate
# validation this time.
self.driver.reconnect(
server=Server(hostname, parsed_policy['port'], True))
def _addCapabilities(self, capstring):
for item in capstring.split():

View File

@ -931,7 +931,7 @@ class AuthenticateDecoder(object):
return base64.b64decode(b''.join(self.chunks))
def _parseStsPolicy(logger, policy):
def parseStsPolicy(logger, policy, parseDuration):
parsed_policy = {}
for kv in policy.split(','):
if '=' in kv:
@ -941,6 +941,10 @@ def _parseStsPolicy(logger, policy):
parsed_policy[kv] = None
for key in ('port', 'duration'):
if key == 'duration' and not parseDuration:
if key in parsed_policy:
del parsed_policy[key]
continue
if parsed_policy.get(key) is None:
logger.error('Missing or empty "%s" key in STS policy.'
'Aborting connection.', key)

View File

@ -42,7 +42,7 @@ import multiprocessing
import re
from . import conf, drivers, ircutils, log, registry
from . import conf, ircutils, log, registry
from .utils import minisix
startedAt = time.time() # Just in case it doesn't get set later.
@ -193,6 +193,7 @@ def upkeep():
def makeDriversDie():
"""Kills drivers."""
from . import drivers
log.info('Killing Driver objects.')
for driver in drivers._drivers.values():
driver.die()

108
test/test_drivers.py Normal file
View File

@ -0,0 +1,108 @@
##
# Copyright (c) 2019, Valentin Lorentz
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice,
# this list of conditions, and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions, and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of the author of this software nor the name of
# contributors to this software may be used to endorse or promote products
# derived from this software without specific prior written consent.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
###
from supybot.test import *
import supybot.ircdb as ircdb
import supybot.irclib as irclib
import supybot.drivers as drivers
class DriversTestCase(SupyTestCase):
def tearDown(self):
ircdb.networks.networks = {}
def testValidStsPolicy(self):
irc = irclib.Irc('test')
net = ircdb.networks.getNetwork('test')
net.addStsPolicy('example.com', 'duration=10,port=6697')
net.addDisconnection('example.com')
with conf.supybot.networks.test.servers.context(
['example.com:6667', 'example.org:6667']):
driver = drivers.ServersMixin(irc)
self.assertEqual(
driver._getNextServer(),
drivers.Server('example.com', 6697, True))
driver.die()
self.assertEqual(
driver._getNextServer(),
drivers.Server('example.org', 6667, False))
driver.die()
self.assertEqual(
driver._getNextServer(),
drivers.Server('example.com', 6697, True))
def testExpiredStsPolicy(self):
irc = irclib.Irc('test')
net = ircdb.networks.getNetwork('test')
net.addStsPolicy('example.com', 'duration=10,port=6697')
net.addDisconnection('example.com')
timeFastForward(16)
with conf.supybot.networks.test.servers.context(
['example.com:6667']):
driver = drivers.ServersMixin(irc)
self.assertEqual(
driver._getNextServer(),
drivers.Server('example.com', 6667, False))
def testRescheduledStsPolicy(self):
irc = irclib.Irc('test')
net = ircdb.networks.getNetwork('test')
net.addStsPolicy('example.com', 'duration=10,port=6697')
net.addDisconnection('example.com')
with conf.supybot.networks.test.servers.context(
['example.com:6667', 'example.org:6667']):
driver = drivers.ServersMixin(irc)
timeFastForward(8)
self.assertEqual(
driver._getNextServer(),
drivers.Server('example.com', 6697, True))
driver.die()
self.assertEqual(
driver._getNextServer(),
drivers.Server('example.org', 6667, False))
driver.die()
timeFastForward(8)
self.assertEqual(
driver._getNextServer(),
drivers.Server('example.com', 6697, True))

View File

@ -550,6 +550,32 @@ class StsTestCase(SupyTestCase):
self.irc.driver.reconnect.assert_called_once_with(
server=drivers.Server('irc.test', 6697, True))
def testStsInCleartextConnectionInvalidDuration(self):
# "Servers MAY send this key to all clients, but insecurely
# connected clients MUST ignore it."
self.irc.driver.anyCertValidationEnabled.return_value = False
self.irc.driver.ssl = True
self.irc.driver.server = drivers.Server('irc.test', 6667, False)
self.irc.feedMsg(ircmsgs.IrcMsg(command='CAP',
args=('*', 'LS', 'sts=duration=foo,port=6697')))
self.assertEqual(ircdb.networks.getNetwork('test').stsPolicies, {})
self.irc.driver.reconnect.assert_called_once_with(
server=drivers.Server('irc.test', 6697, True))
def testStsInCleartextConnectionNoDuration(self):
# "Servers MAY send this key to all clients, but insecurely
# connected clients MUST ignore it."
self.irc.driver.anyCertValidationEnabled.return_value = False
self.irc.driver.ssl = True
self.irc.driver.server = drivers.Server('irc.test', 6667, False)
self.irc.feedMsg(ircmsgs.IrcMsg(command='CAP',
args=('*', 'LS', 'sts=port=6697')))
self.assertEqual(ircdb.networks.getNetwork('test').stsPolicies, {})
self.irc.driver.reconnect.assert_called_once_with(
server=drivers.Server('irc.test', 6697, True))
class IrcTestCase(SupyTestCase):
def setUp(self):
self.irc = irclib.Irc('test')