Apply STS policies when connecting to a server.

This commit is contained in:
Valentin Lorentz 2019-12-08 15:54:48 +01:00
parent ecc2c32950
commit 51ff013fcc
8 changed files with 210 additions and 26 deletions

View File

@ -225,6 +225,7 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin):
self._attempt += 1 self._attempt += 1
self.nextReconnectTime = None self.nextReconnectTime = None
if self.connected: if self.connected:
self.onDisconnect()
drivers.log.reconnect(self.irc.network) drivers.log.reconnect(self.irc.network)
if self in self._instances: if self in self._instances:
self._instances.remove(self) self._instances.remove(self)
@ -242,7 +243,7 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin):
if wait: if wait:
self.scheduleReconnect() self.scheduleReconnect()
return return
self.server = server or self._getNextServer() self.currentServer = server or self._getNextServer()
network_config = getattr(conf.supybot.networks, self.irc.network) network_config = getattr(conf.supybot.networks, self.irc.network)
socks_proxy = network_config.socksproxy() socks_proxy = network_config.socksproxy()
try: try:
@ -255,7 +256,7 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin):
else: else:
try: try:
hostname = utils.net.getAddressFromHostname( hostname = utils.net.getAddressFromHostname(
self.server.hostname, self.currentServer.hostname,
attempt=self._attempt) attempt=self._attempt)
except (socket.gaierror, socket.error) as e: except (socket.gaierror, socket.error) as e:
drivers.log.connectError(self.currentServer, e) drivers.log.connectError(self.currentServer, e)
@ -264,8 +265,8 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin):
drivers.log.connect(self.currentServer) drivers.log.connect(self.currentServer)
try: try:
self.conn = utils.net.getSocket( self.conn = utils.net.getSocket(
self.server.hostname, self.currentServer.hostname,
port=self.server.port, port=self.currentServer.port,
socks_proxy=socks_proxy, socks_proxy=socks_proxy,
vhost=conf.supybot.protocols.irc.vhost(), vhost=conf.supybot.protocols.irc.vhost(),
vhostv6=conf.supybot.protocols.irc.vhostv6(), vhostv6=conf.supybot.protocols.irc.vhostv6(),
@ -280,8 +281,10 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin):
try: try:
# Connect before SSL, otherwise SSL is disabled if we use SOCKS. # Connect before SSL, otherwise SSL is disabled if we use SOCKS.
# See http://stackoverflow.com/q/16136916/539465 # See http://stackoverflow.com/q/16136916/539465
self.conn.connect((self.server.hostname, self.server.port)) self.conn.connect(
if network_config.ssl() or self.server.force_tls_verification: (self.currentServer.hostname, self.currentServer.port))
if network_config.ssl() \
or self.currentServer.force_tls_verification:
self.starttls() self.starttls()
# Suppress this warning for loopback IPs. # Suppress this warning for loopback IPs.
@ -351,6 +354,8 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin):
if self.writeCheckTime is not None: if self.writeCheckTime is not None:
self.writeCheckTime = None self.writeCheckTime = None
drivers.log.die(self.irc) drivers.log.die(self.irc)
drivers.IrcDriver.die(self)
drivers.ServersMixin.die(self)
def _reallyDie(self): def _reallyDie(self):
if self.conn is not None: if self.conn is not None:
@ -382,7 +387,7 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin):
drivers.log.warning('Could not find cert file %s.' % drivers.log.warning('Could not find cert file %s.' %
certfile) certfile)
certfile = None certfile = None
if self.server.force_tls_verification \ if self.currentServer.force_tls_verification \
and not self.anyCertValidationEnabled(): and not self.anyCertValidationEnabled():
verifyCertificates = True verifyCertificates = True
else: else:
@ -395,7 +400,7 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin):
try: try:
self.conn = utils.net.ssl_wrap_socket(self.conn, self.conn = utils.net.ssl_wrap_socket(self.conn,
logger=drivers.log, logger=drivers.log,
hostname=self.server.hostname, hostname=self.currentServer.hostname,
certfile=certfile, certfile=certfile,
verify=verifyCertificates, verify=verifyCertificates,
trusted_fingerprints=network_config.ssl.serverFingerprints(), trusted_fingerprints=network_config.ssl.serverFingerprints(),

View File

@ -32,10 +32,11 @@
Contains various drivers (network, file, and otherwise) for using IRC objects. Contains various drivers (network, file, and otherwise) for using IRC objects.
""" """
import time
import socket import socket
from collections import namedtuple from collections import namedtuple
from .. import conf, ircmsgs, log as supylog, utils from .. import conf, ircdb, ircmsgs, ircutils, log as supylog, utils
from ..utils import minisix from ..utils import minisix
@ -73,6 +74,7 @@ class IrcDriver(object):
class ServersMixin(object): class ServersMixin(object):
def __init__(self, irc, servers=()): def __init__(self, irc, servers=()):
self.networkName = irc.network
self.networkGroup = conf.supybot.networks.get(irc.network) self.networkGroup = conf.supybot.networks.get(irc.network)
self.servers = servers self.servers = servers
super(ServersMixin, self).__init__() super(ServersMixin, self).__init__()
@ -89,8 +91,36 @@ class ServersMixin(object):
assert self.servers, 'Servers value for %s is empty.' % \ assert self.servers, 'Servers value for %s is empty.' % \
self.networkGroup._name self.networkGroup._name
server = self.servers.pop(0) server = self.servers.pop(0)
self.currentServer = '%s:%s' % (server.hostname, server.port) self.currentServer = self._applyStsPolicy(server)
return server return self.currentServer
def _applyStsPolicy(self, server):
network = ircdb.networks.getNetwork(self.networkName)
policy = network.stsPolicies.get(server.hostname)
lastDisconnect = network.lastDisconnectTimes.get(server.hostname)
if policy is None or lastDisconnect is None:
return server
# The policy was stored, which means it was received on a secure
# connection.
policy = ircutils.parseStsPolicy(log, policy, parseDuration=True)
if lastDisconnect + policy['duration'] < time.time():
network.expireStsPolicy(server.hostname)
return server
# Change the port, and force TLS verification, as required by the STS
# specification.
return Server(server.hostname, policy['port'],
force_tls_verification=True)
def die(self):
self.onDisconnect()
def onDisconnect(self):
network = ircdb.networks.getNetwork(self.networkName)
network.addDisconnection(self.currentServer.hostname)
def empty(): def empty():
@ -138,7 +168,8 @@ def run():
class Log(object): class Log(object):
"""This is used to have a nice, consistent interface for drivers to use.""" """This is used to have a nice, consistent interface for drivers to use."""
def connect(self, server): def connect(self, server):
self.info('Connecting to %s.', server) self.info('Connecting to %s:%s.',
server.hostname, server.port)
def connectError(self, server, e): def connectError(self, server, e):
if isinstance(e, Exception): if isinstance(e, Exception):
@ -146,7 +177,8 @@ class Log(object):
e = e.args[1] e = e.args[1]
else: else:
e = utils.exnToString(e) e = utils.exnToString(e)
self.warning('Error connecting to %s: %s', server, e) self.warning('Error connecting to %s:%s: %s',
server.hostname, server.port, e)
def disconnect(self, server, e=None): def disconnect(self, server, e=None):
if e: if e:
@ -156,7 +188,8 @@ class Log(object):
e = str(e) e = str(e)
if not e.endswith('.'): if not e.endswith('.'):
e += '.' e += '.'
self.warning('Disconnect from %s: %s', server, e) self.warning('Disconnect from %s:%s: %s',
server.hostname, server.port, e)
else: else:
self.info('Disconnect from %s.', server) self.info('Disconnect from %s.', server)

View File

@ -515,6 +515,10 @@ class IrcNetwork(object):
assert isinstance(stsPolicy, str) assert isinstance(stsPolicy, str)
self.stsPolicies[server] = stsPolicy self.stsPolicies[server] = stsPolicy
def expireStsPolicy(self, server):
if server in self.stsPolicies:
del self.stsPolicies[server]
def addDisconnection(self, server): def addDisconnection(self, server):
self.lastDisconnectTimes[server] = int(time.time()) self.lastDisconnectTimes[server] = int(time.time())
@ -674,8 +678,8 @@ class IrcNetworkCreator(Creator):
def finish(self): def finish(self):
if self.net.name: if self.net.name:
self.networks.setNetwork(self.net) self.networks.setNetwork(self.net.name, self.net)
self.name = None self.net = IrcNetwork()
class DuplicateHostmask(ValueError): class DuplicateHostmask(ValueError):

View File

@ -1470,19 +1470,16 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
self.capUpkeep() self.capUpkeep()
def _onCapSts(self, policy): def _onCapSts(self, policy):
parsed_policy = ircutils._parseStsPolicy(log, policy) secure_connection = self.driver.ssl and self.driver.anyCertValidationEnabled()
parsed_policy = ircutils.parseStsPolicy(
log, policy, parseDuration=secure_connection)
if parsed_policy is None: if parsed_policy is None:
# There was an error (and it was logged). Abort the connection. # There was an error (and it was logged). Abort the connection.
self.driver.reconnect() self.driver.reconnect()
return return
if not self.driver.ssl or not self.driver.anyCertValidationEnabled(): if secure_connection:
hostname = self.driver.server.hostname
# Reconnect to the server, but with TLS *and* certificate
# validation this time.
self.driver.reconnect(
server=Server(hostname, parsed_policy['port'], True))
else:
# TLS is enabled and certificate is verified; write the STS policy # TLS is enabled and certificate is verified; write the STS policy
# in stone. # in stone.
# For future-proofing (because we don't want to write an invalid # For future-proofing (because we don't want to write an invalid
@ -1490,6 +1487,12 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
# of the parsed one. # of the parsed one.
ircdb.networks.getNetwork(self.network).addStsPolicy( ircdb.networks.getNetwork(self.network).addStsPolicy(
self.driver.server.hostname, policy) self.driver.server.hostname, policy)
else:
hostname = self.driver.server.hostname
# Reconnect to the server, but with TLS *and* certificate
# validation this time.
self.driver.reconnect(
server=Server(hostname, parsed_policy['port'], True))
def _addCapabilities(self, capstring): def _addCapabilities(self, capstring):
for item in capstring.split(): for item in capstring.split():

View File

@ -931,7 +931,7 @@ class AuthenticateDecoder(object):
return base64.b64decode(b''.join(self.chunks)) return base64.b64decode(b''.join(self.chunks))
def _parseStsPolicy(logger, policy): def parseStsPolicy(logger, policy, parseDuration):
parsed_policy = {} parsed_policy = {}
for kv in policy.split(','): for kv in policy.split(','):
if '=' in kv: if '=' in kv:
@ -941,6 +941,10 @@ def _parseStsPolicy(logger, policy):
parsed_policy[kv] = None parsed_policy[kv] = None
for key in ('port', 'duration'): for key in ('port', 'duration'):
if key == 'duration' and not parseDuration:
if key in parsed_policy:
del parsed_policy[key]
continue
if parsed_policy.get(key) is None: if parsed_policy.get(key) is None:
logger.error('Missing or empty "%s" key in STS policy.' logger.error('Missing or empty "%s" key in STS policy.'
'Aborting connection.', key) 'Aborting connection.', key)

View File

@ -42,7 +42,7 @@ import multiprocessing
import re import re
from . import conf, drivers, ircutils, log, registry from . import conf, ircutils, log, registry
from .utils import minisix from .utils import minisix
startedAt = time.time() # Just in case it doesn't get set later. startedAt = time.time() # Just in case it doesn't get set later.
@ -193,6 +193,7 @@ def upkeep():
def makeDriversDie(): def makeDriversDie():
"""Kills drivers.""" """Kills drivers."""
from . import drivers
log.info('Killing Driver objects.') log.info('Killing Driver objects.')
for driver in drivers._drivers.values(): for driver in drivers._drivers.values():
driver.die() driver.die()

108
test/test_drivers.py Normal file
View File

@ -0,0 +1,108 @@
##
# Copyright (c) 2019, Valentin Lorentz
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice,
# this list of conditions, and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions, and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of the author of this software nor the name of
# contributors to this software may be used to endorse or promote products
# derived from this software without specific prior written consent.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
###
from supybot.test import *
import supybot.ircdb as ircdb
import supybot.irclib as irclib
import supybot.drivers as drivers
class DriversTestCase(SupyTestCase):
def tearDown(self):
ircdb.networks.networks = {}
def testValidStsPolicy(self):
irc = irclib.Irc('test')
net = ircdb.networks.getNetwork('test')
net.addStsPolicy('example.com', 'duration=10,port=6697')
net.addDisconnection('example.com')
with conf.supybot.networks.test.servers.context(
['example.com:6667', 'example.org:6667']):
driver = drivers.ServersMixin(irc)
self.assertEqual(
driver._getNextServer(),
drivers.Server('example.com', 6697, True))
driver.die()
self.assertEqual(
driver._getNextServer(),
drivers.Server('example.org', 6667, False))
driver.die()
self.assertEqual(
driver._getNextServer(),
drivers.Server('example.com', 6697, True))
def testExpiredStsPolicy(self):
irc = irclib.Irc('test')
net = ircdb.networks.getNetwork('test')
net.addStsPolicy('example.com', 'duration=10,port=6697')
net.addDisconnection('example.com')
timeFastForward(16)
with conf.supybot.networks.test.servers.context(
['example.com:6667']):
driver = drivers.ServersMixin(irc)
self.assertEqual(
driver._getNextServer(),
drivers.Server('example.com', 6667, False))
def testRescheduledStsPolicy(self):
irc = irclib.Irc('test')
net = ircdb.networks.getNetwork('test')
net.addStsPolicy('example.com', 'duration=10,port=6697')
net.addDisconnection('example.com')
with conf.supybot.networks.test.servers.context(
['example.com:6667', 'example.org:6667']):
driver = drivers.ServersMixin(irc)
timeFastForward(8)
self.assertEqual(
driver._getNextServer(),
drivers.Server('example.com', 6697, True))
driver.die()
self.assertEqual(
driver._getNextServer(),
drivers.Server('example.org', 6667, False))
driver.die()
timeFastForward(8)
self.assertEqual(
driver._getNextServer(),
drivers.Server('example.com', 6697, True))

View File

@ -550,6 +550,32 @@ class StsTestCase(SupyTestCase):
self.irc.driver.reconnect.assert_called_once_with( self.irc.driver.reconnect.assert_called_once_with(
server=drivers.Server('irc.test', 6697, True)) server=drivers.Server('irc.test', 6697, True))
def testStsInCleartextConnectionInvalidDuration(self):
# "Servers MAY send this key to all clients, but insecurely
# connected clients MUST ignore it."
self.irc.driver.anyCertValidationEnabled.return_value = False
self.irc.driver.ssl = True
self.irc.driver.server = drivers.Server('irc.test', 6667, False)
self.irc.feedMsg(ircmsgs.IrcMsg(command='CAP',
args=('*', 'LS', 'sts=duration=foo,port=6697')))
self.assertEqual(ircdb.networks.getNetwork('test').stsPolicies, {})
self.irc.driver.reconnect.assert_called_once_with(
server=drivers.Server('irc.test', 6697, True))
def testStsInCleartextConnectionNoDuration(self):
# "Servers MAY send this key to all clients, but insecurely
# connected clients MUST ignore it."
self.irc.driver.anyCertValidationEnabled.return_value = False
self.irc.driver.ssl = True
self.irc.driver.server = drivers.Server('irc.test', 6667, False)
self.irc.feedMsg(ircmsgs.IrcMsg(command='CAP',
args=('*', 'LS', 'sts=port=6697')))
self.assertEqual(ircdb.networks.getNetwork('test').stsPolicies, {})
self.irc.driver.reconnect.assert_called_once_with(
server=drivers.Server('irc.test', 6697, True))
class IrcTestCase(SupyTestCase): class IrcTestCase(SupyTestCase):
def setUp(self): def setUp(self):
self.irc = irclib.Irc('test') self.irc = irclib.Irc('test')