Add support for storing STS policies.

If on an insecure connection: reconnect.
If on a secure connect: store it and do nothing else.

For now, stored STS policies are not read when connecting to an
insecure server.
This commit is contained in:
Valentin Lorentz 2019-12-07 23:33:04 +01:00
parent ff5edd95a3
commit ecc2c32950
9 changed files with 449 additions and 35 deletions

View File

@ -148,14 +148,14 @@ class Owner(callbacks.Plugin):
def _connect(self, network, serverPort=None, password='', ssl=False):
try:
group = conf.supybot.networks.get(network)
(server, port) = group.servers()[0]
group.servers()[0]
except (registry.NonExistentRegistryEntry, IndexError):
if serverPort is None:
raise ValueError('connect requires a (server, port) ' \
'if the network is not registered.')
conf.registerNetwork(network, password, ssl)
serverS = '%s:%s' % serverPort
conf.supybot.networks.get(network).servers.append(serverS)
server = '%s:%s' % serverPort
conf.supybot.networks.get(network).servers.append(server)
assert conf.supybot.networks.get(network).servers(), \
'No servers are set for the %s network.' % network
self.log.debug('Creating new Irc for %s.', network)

View File

@ -273,15 +273,17 @@ class Servers(registry.SpaceSeparatedListOfStrings):
return s
def convert(self, s):
from .drivers import Server
s = self.normalize(s)
(server, port) = s.rsplit(':', 1)
(hostname, port) = s.rsplit(':', 1)
# support for `[ipv6]:port` format
if server.startswith("[") and server.endswith("]"):
server = server[1:-1]
if hostname.startswith("[") and hostname.endswith("]"):
hostname = hostname[1:-1]
port = int(port)
return (server, port)
return Server(hostname, port, force_tls_verification=False)
def __call__(self):
L = registry.SpaceSeparatedListOfStrings.__call__(self)
@ -1039,6 +1041,12 @@ registerGlobalValue(supybot.databases.channels, 'filename',
for the channels database. This file will go into the directory specified
by the supybot.directories.conf variable.""")))
registerGroup(supybot.databases, 'networks')
registerGlobalValue(supybot.databases.networks, 'filename',
registry.String('networks.conf', _("""Determines what filename will be used
for the networks database. This file will go into the directory specified
by the supybot.directories.conf variable.""")))
# TODO This will need to do more in the future (such as making sure link.allow
# will let the link occur), but for now let's just leave it as this.
class ChannelSpecific(registry.Boolean):

View File

@ -35,12 +35,12 @@ Contains simple socket drivers. Asyncore bugged (haha, pun!) me.
from __future__ import division
import os
import sys
import time
import errno
import threading
import select
import socket
import sys
try:
import ipaddress # Python >= 3.3 or backported ipaddress
@ -221,7 +221,7 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin):
def connect(self, **kwargs):
self.reconnect(reset=False, **kwargs)
def reconnect(self, wait=False, reset=True):
def reconnect(self, wait=False, reset=True, server=None):
self._attempt += 1
self.nextReconnectTime = None
if self.connected:
@ -242,7 +242,7 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin):
if wait:
self.scheduleReconnect()
return
self.server = self._getNextServer()
self.server = server or self._getNextServer()
network_config = getattr(conf.supybot.networks, self.irc.network)
socks_proxy = network_config.socksproxy()
try:
@ -252,20 +252,20 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin):
log.error('Cannot use socks proxy (SocksiPy not installed), '
'using direct connection instead.')
socks_proxy = ''
if socks_proxy:
address = self.server[0]
else:
try:
address = utils.net.getAddressFromHostname(self.server[0],
hostname = utils.net.getAddressFromHostname(
self.server.hostname,
attempt=self._attempt)
except (socket.gaierror, socket.error) as e:
drivers.log.connectError(self.currentServer, e)
self.scheduleReconnect()
return
port = self.server[1]
drivers.log.connect(self.currentServer)
try:
self.conn = utils.net.getSocket(address, port=port,
self.conn = utils.net.getSocket(
self.server.hostname,
port=self.server.port,
socks_proxy=socks_proxy,
vhost=conf.supybot.protocols.irc.vhost(),
vhostv6=conf.supybot.protocols.irc.vhostv6(),
@ -280,12 +280,12 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin):
try:
# Connect before SSL, otherwise SSL is disabled if we use SOCKS.
# See http://stackoverflow.com/q/16136916/539465
self.conn.connect((address, port))
if network_config.ssl():
self.conn.connect((self.server.hostname, self.server.port))
if network_config.ssl() or self.server.force_tls_verification:
self.starttls()
# Suppress this warning for loopback IPs.
targetip = address
targetip = hostname
if sys.version_info[0] < 3:
# Backported Python 2 ipaddress demands unicode instead of str
targetip = targetip.decode('utf-8')
@ -361,6 +361,15 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin):
def name(self):
return '%s(%s)' % (self.__class__.__name__, self.irc)
def anyCertValidationEnabled(self):
"""Returns whether any kind of certificate validation is enabled, other
than Server.force_tls_verification."""
return any([
conf.supybot.protocols.ssl.verifyCertificates(),
network_config.ssl.serverFingerprints(),
network_config.ssl.authorityCertificate(),
])
def starttls(self):
assert 'ssl' in globals()
network_config = getattr(conf.supybot.networks, self.irc.network)
@ -373,6 +382,10 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin):
drivers.log.warning('Could not find cert file %s.' %
certfile)
certfile = None
if self.server.force_tls_verification \
and not self.anyCertValidationEnabled():
verifyCertificates = True
else:
verifyCertificates = conf.supybot.protocols.ssl.verifyCertificates()
if not verifyCertificates:
drivers.log.warning('Not checking SSL certificates, connections '
@ -381,7 +394,8 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin):
'to enable validity checks.')
try:
self.conn = utils.net.ssl_wrap_socket(self.conn,
logger=drivers.log, hostname=self.server[0],
logger=drivers.log,
hostname=self.server.hostname,
certfile=certfile,
verify=verifyCertificates,
trusted_fingerprints=network_config.ssl.serverFingerprints(),

View File

@ -33,10 +33,19 @@ Contains various drivers (network, file, and otherwise) for using IRC objects.
"""
import socket
from collections import namedtuple
from .. import conf, ircmsgs, log as supylog, utils
from ..utils import minisix
Server = namedtuple('Server', 'hostname port force_tls_verification')
# force_tls_verification=True implies two things:
# 1. force TLS to be enabled for this server
# 2. ensure there is some kind of verification. If the user did not enable
# any, use standard PKI validation.
_drivers = {}
_deadDrivers = set()
_newDrivers = []
@ -80,7 +89,7 @@ class ServersMixin(object):
assert self.servers, 'Servers value for %s is empty.' % \
self.networkGroup._name
server = self.servers.pop(0)
self.currentServer = '%s:%s' % server
self.currentServer = '%s:%s' % (server.hostname, server.port)
return server

View File

@ -496,6 +496,44 @@ class IrcChannel(object):
fd.write(os.linesep)
class IrcNetwork(object):
"""This class holds dynamic information about a network that should be
preserved across restarts."""
__slots__ = ('name', 'stsPolicies', 'lastDisconnectTimes')
def __init__(self, name=None, stsPolicies=None, lastDisconnectTimes=None):
self.name = name
self.stsPolicies = stsPolicies or {}
self.lastDisconnectTimes = lastDisconnectTimes or {}
def __repr__(self):
return '%s(name=%r, stsPolicy=%r, lastDisconnectTimes=%s)' % \
(self.__class__.__name, self.name, self.stsPolicy,
self.lastDisconnectTimes)
def addStsPolicy(self, server, stsPolicy):
assert isinstance(stsPolicy, str)
self.stsPolicies[server] = stsPolicy
def addDisconnection(self, server):
self.lastDisconnectTimes[server] = int(time.time())
def preserve(self, fd, indent=''):
def write(s):
fd.write(indent)
fd.write(s)
fd.write(os.linesep)
for (server, stsPolicy) in sorted(self.stsPolicies.items()):
write('stsPolicy %s %s' % (server, stsPolicy))
for (server, disconnectTime) in \
sorted(self.lastDisconnectTimes.items()):
write('lastDisconnectTime %s %s' % (server, disconnectTime))
fd.write(os.linesep)
class Creator(object):
__slots__ = ()
def badCommand(self, command, rest, lineno):
@ -615,6 +653,31 @@ class IrcChannelCreator(Creator):
IrcChannelCreator.name = None
class IrcNetworkCreator(Creator):
__slots__ = ('net', 'networks')
def __init__(self, networks):
self.net = IrcNetwork()
self.networks = networks
def network(self, rest, lineno):
self.net.name = rest
def stspolicy(self, rest, lineno):
(server, stsPolicy) = rest.split()
self.net.addStsPolicy(server, stsPolicy)
def lastdisconnecttime(self, rest, lineno):
(server, when) = rest.split()
when = int(when)
self.net.lastDisconnectTimes[server] = when
def finish(self):
if self.net.name:
self.networks.setNetwork(self.net)
self.name = None
class DuplicateHostmask(ValueError):
pass
@ -666,10 +729,8 @@ class UsersDictionary(utils.IterableMap):
"""Flushes the database to its file."""
if not self.noFlush:
if self.filename is not None:
L = list(self.users.items())
L.sort()
fd = utils.file.AtomicFile(self.filename)
for (id, u) in L:
for (id, u) in sorted(self.users.items()):
fd.write('user %s' % id)
fd.write(os.linesep)
u.preserve(fd, indent=' ')
@ -861,7 +922,7 @@ class ChannelsDictionary(utils.IterableMap):
if not self.noFlush:
if self.filename is not None:
fd = utils.file.AtomicFile(self.filename)
for (channel, c) in self.channels.items():
for (channel, c) in sorted(self.channels.items()):
fd.write('channel %s' % channel)
fd.write(os.linesep)
c.preserve(fd, indent=' ')
@ -907,6 +968,83 @@ class ChannelsDictionary(utils.IterableMap):
def items(self):
return self.channels.items()
class NetworksDictionary(utils.IterableMap):
__slots__ = ('noFlush', 'filename', 'networks')
def __init__(self):
self.noFlush = False
self.filename = None
self.networks = ircutils.IrcDict()
def open(self, filename):
self.noFlush = True
try:
self.filename = filename
reader = unpreserve.Reader(IrcNetworkCreator, self)
try:
reader.readFile(filename)
self.noFlush = False
self.flush()
except EnvironmentError as e:
log.error('Invalid network database, resetting to empty.')
log.error('Exact error: %s', utils.exnToString(e))
except Exception as e:
log.error('Invalid network database, resetting to empty.')
log.exception('Exact error:')
finally:
self.noFlush = False
def flush(self):
"""Flushes the network database to its file."""
if not self.noFlush:
if self.filename is not None:
fd = utils.file.AtomicFile(self.filename)
for (network, net) in sorted(self.networks.items()):
fd.write('network %s' % network)
fd.write(os.linesep)
net.preserve(fd, indent=' ')
fd.close()
else:
log.warning('NetworksDictionary.flush without self.filename.')
else:
log.debug('Not flushing NetworksDictionary because of noFlush.')
def close(self):
self.flush()
if self.flush in world.flushers:
world.flushers.remove(self.flush)
self.networks.clear()
def reload(self):
"""Reloads the network database from its file."""
if self.filename is not None:
self.networks.clear()
try:
self.open(self.filename)
except EnvironmentError as e:
log.warning('NetworksDictionary.reload failed: %s', e)
else:
log.warning('NetworksDictionary.reload without self.filename.')
def getNetwork(self, network):
"""Returns an IrcNetwork object for the given network."""
network = network.lower()
if network in self.networks:
return self.networks[network]
else:
c = IrcNetwork()
self.networks[network] = c
return c
def setNetwork(self, network, ircNetwork):
"""Sets a given network to the IrcNetwork object given."""
network = network.lower()
self.networks[network] = ircNetwork
self.flush()
def items(self):
return self.networks.items()
class IgnoresDB(object):
__slots__ = ('filename', 'hostmasks')
@ -996,6 +1134,14 @@ try:
except EnvironmentError as e:
log.warning('Couldn\'t open channel database: %s', e)
try:
networkFile = os.path.join(confDir,
conf.supybot.databases.networks.filename())
networks = NetworksDictionary()
networks.open(networkFile)
except EnvironmentError as e:
log.warning('Couldn\'t open network database: %s', e)
try:
ignoreFile = os.path.join(confDir,
conf.supybot.databases.ignores.filename())
@ -1006,8 +1152,9 @@ except EnvironmentError as e:
world.flushers.append(users.flush)
world.flushers.append(ignores.flush)
world.flushers.append(channels.flush)
world.flushers.append(networks.flush)
world.flushers.append(ignores.flush)
###

View File

@ -55,6 +55,7 @@ except ImportError:
scram = None
from . import conf, ircdb, ircmsgs, ircutils, log, utils, world
from .drivers import Server
from .utils.str import rsplit
from .utils.iter import chain
from .utils.structures import smallqueue, RingBuffer
@ -1468,14 +1469,42 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
self.network, caps)
self.capUpkeep()
def _onCapSts(self, policy):
parsed_policy = ircutils._parseStsPolicy(log, policy)
if parsed_policy is None:
# There was an error (and it was logged). Abort the connection.
self.driver.reconnect()
return
if not self.driver.ssl or not self.driver.anyCertValidationEnabled():
hostname = self.driver.server.hostname
# Reconnect to the server, but with TLS *and* certificate
# validation this time.
self.driver.reconnect(
server=Server(hostname, parsed_policy['port'], True))
else:
# TLS is enabled and certificate is verified; write the STS policy
# in stone.
# For future-proofing (because we don't want to write an invalid
# value), we write the raw policy received from the server instead
# of the parsed one.
ircdb.networks.getNetwork(self.network).addStsPolicy(
self.driver.server.hostname, policy)
def _addCapabilities(self, capstring):
for item in capstring.split():
while item.startswith(('=', '~')):
item = item[1:]
if '=' in item:
(cap, value) = item.split('=', 1)
if cap == 'sts':
self._onCapSts(value)
self.state.capabilities_ls[cap] = value
else:
if item == 'sts':
log.error('Got "sts" capability without value. Aborting '
'connection.')
self.driver.reconnect()
self.state.capabilities_ls[item] = None
def doCapLs(self, msg):

View File

@ -931,6 +931,31 @@ class AuthenticateDecoder(object):
return base64.b64decode(b''.join(self.chunks))
def _parseStsPolicy(logger, policy):
parsed_policy = {}
for kv in policy.split(','):
if '=' in kv:
(k, v) = kv.split('=', 1)
parsed_policy[k] = v
else:
parsed_policy[kv] = None
for key in ('port', 'duration'):
if parsed_policy.get(key) is None:
logger.error('Missing or empty "%s" key in STS policy.'
'Aborting connection.', key)
return None
try:
parsed_policy[key] = int(parsed_policy[key])
except ValueError:
logger.error('Expected integer as value for key "%s" in STS '
'policy, got %r instead. Aborting connection.',
key, parsed_policy[key])
return None
return parsed_policy
numerics = {
# <= 2.10
# Reply

View File

@ -31,11 +31,13 @@ from supybot.test import *
import os
import unittest
import unittest.mock
import supybot.conf as conf
import supybot.world as world
import supybot.ircdb as ircdb
import supybot.ircutils as ircutils
from supybot.utils.minisix import io
class IrcdbTestCase(SupyTestCase):
def setUp(self):
@ -347,6 +349,50 @@ class IrcChannelTestCase(IrcdbTestCase):
c.removeBan(banmask)
self.assertFalse(c.checkIgnored(prefix))
class IrcNetworkTestCase(IrcdbTestCase):
def testDefaults(self):
n = ircdb.IrcNetwork()
self.assertIsNone(n.name)
self.assertEqual(n.stsPolicies, {})
self.assertEqual(n.lastDisconnectTimes, {})
def testStsPolicy(self):
n = ircdb.IrcNetwork()
n.addStsPolicy('foo', 'bar')
n.addStsPolicy('baz', 'qux')
self.assertEqual(n.stsPolicies, {
'foo': 'bar',
'baz': 'qux',
})
def testAddDisconnection(self):
n = ircdb.IrcNetwork()
min_ = int(time.time())
n.addDisconnection('foo')
max_ = int(time.time())
self.assertTrue(min_ <= n.lastDisconnectTimes['foo'] <= max_)
def testPreserve(self):
n = ircdb.IrcNetwork('foonet')
n.addStsPolicy('foo', 'sts1')
n.addStsPolicy('bar', 'sts2')
n.addDisconnection('foo')
n.addDisconnection('baz')
disconnect_time_foo = n.lastDisconnectTimes['foo']
disconnect_time_baz = n.lastDisconnectTimes['baz']
fd = io.StringIO()
n.preserve(fd, indent=' ')
fd.seek(0)
self.assertCountEqual(fd.read().split('\n'), [
' stsPolicy foo sts1',
' stsPolicy bar sts2',
' lastDisconnectTime foo %d' % disconnect_time_foo,
' lastDisconnectTime baz %d' % disconnect_time_baz,
'',
'',
])
class UsersDictionaryTestCase(IrcdbTestCase):
filename = os.path.join(conf.supybot.directories.conf(),
'UsersDictionaryTestCase.conf')
@ -401,6 +447,89 @@ class UsersDictionaryTestCase(IrcdbTestCase):
self.assertRaises(ValueError, self.users.setUser, u2)
class NetworksDictionaryTestCase(IrcdbTestCase):
filename = os.path.join(conf.supybot.directories.conf(),
'NetworksDictionaryTestCase.conf')
def setUp(self):
try:
os.remove(self.filename)
except:
pass
self.networks = ircdb.NetworksDictionary()
IrcdbTestCase.setUp(self)
def testGetSetNetwork(self):
self.assertEqual(self.networks.getNetwork('foo').name, None)
n = ircdb.IrcNetwork()
n.name = 'foo'
self.networks.setNetwork('foo', n)
self.assertEqual(self.networks.getNetwork('foo').name, 'foo')
def testPreserveOne(self):
n = ircdb.IrcNetwork('foonet')
n.addStsPolicy('foo', 'sts1')
n.addStsPolicy('bar', 'sts2')
n.addDisconnection('foo')
n.addDisconnection('baz')
disconnect_time_foo = n.lastDisconnectTimes['foo']
disconnect_time_baz = n.lastDisconnectTimes['baz']
self.networks.setNetwork('foonet', n)
fd = io.StringIO()
fd.close = lambda: None
self.networks.filename = 'blah'
original_Atomicfile = utils.file.AtomicFile
with unittest.mock.patch(
'supybot.utils.file.AtomicFile', return_value=fd):
self.networks.flush()
lines = fd.getvalue().split('\n')
self.assertEqual(lines.pop(0), 'network foonet')
self.assertCountEqual(lines, [
' stsPolicy foo sts1',
' stsPolicy bar sts2',
' lastDisconnectTime foo %d' % disconnect_time_foo,
' lastDisconnectTime baz %d' % disconnect_time_baz,
'',
'',
])
def testPreserveThree(self):
n = ircdb.IrcNetwork('foonet')
n.addStsPolicy('foo', 'sts1')
self.networks.setNetwork('foonet', n)
n = ircdb.IrcNetwork('barnet')
n.addStsPolicy('bar', 'sts2')
self.networks.setNetwork('barnet', n)
n = ircdb.IrcNetwork('baznet')
n.addStsPolicy('baz', 'sts3')
self.networks.setNetwork('baznet', n)
fd = io.StringIO()
fd.close = lambda: None
self.networks.filename = 'blah'
original_Atomicfile = utils.file.AtomicFile
with unittest.mock.patch(
'supybot.utils.file.AtomicFile', return_value=fd):
self.networks.flush()
fd.seek(0)
self.assertEqual(fd.getvalue(),
'network barnet\n'
' stsPolicy bar sts2\n'
'\n'
'network baznet\n'
' stsPolicy baz sts3\n'
'\n'
'network foonet\n'
' stsPolicy foo sts1\n'
'\n'
)
class CheckCapabilityTestCase(IrcdbTestCase):
filename = os.path.join(conf.supybot.directories.conf(),
'CheckCapabilityTestCase.conf')

View File

@ -27,14 +27,15 @@
# POSSIBILITY OF SUCH DAMAGE.
###
from supybot.test import *
import copy
import pickle
import warnings
import unittest.mock
from supybot.test import *
import supybot.conf as conf
import supybot.irclib as irclib
import supybot.drivers as drivers
import supybot.ircmsgs as ircmsgs
import supybot.ircutils as ircutils
@ -497,6 +498,58 @@ class IrcCapsTestCase(SupyTestCase):
self.assertEqual(m.args[0], 'REQ', m)
self.assertEqual(m.args[1], 'b'*400)
class StsTestCase(SupyTestCase):
def setUp(self):
self.irc = irclib.Irc('test')
m = self.irc.takeMsg()
self.failUnless(m.command == 'CAP', 'Expected CAP, got %r.' % m)
self.failUnless(m.args == ('LS', '302'), 'Expected CAP LS 302, got %r.' % m)
m = self.irc.takeMsg()
self.failUnless(m.command == 'NICK', 'Expected NICK, got %r.' % m)
m = self.irc.takeMsg()
self.failUnless(m.command == 'USER', 'Expected USER, got %r.' % m)
self.irc.driver = unittest.mock.Mock()
def tearDown(self):
ircdb.networks.networks = {}
def testStsInSecureConnection(self):
self.irc.driver.anyCertValidationEnabled.return_value = True
self.irc.driver.ssl = True
self.irc.driver.server = drivers.Server('irc.test', 6697, False)
self.irc.feedMsg(ircmsgs.IrcMsg(command='CAP',
args=('*', 'LS', 'sts=duration=42,port=6697')))
self.assertEqual(ircdb.networks.getNetwork('test').stsPolicies, {
'irc.test': 'duration=42,port=6697'})
self.irc.driver.reconnect.assert_not_called()
def testStsInInsecureTlsConnection(self):
self.irc.driver.anyCertValidationEnabled.return_value = False
self.irc.driver.ssl = True
self.irc.driver.server = drivers.Server('irc.test', 6697, False)
self.irc.feedMsg(ircmsgs.IrcMsg(command='CAP',
args=('*', 'LS', 'sts=duration=42,port=6697')))
self.assertEqual(ircdb.networks.getNetwork('test').stsPolicies, {})
self.irc.driver.reconnect.assert_called_once_with(
server=drivers.Server('irc.test', 6697, True))
def testStsInCleartextConnection(self):
self.irc.driver.anyCertValidationEnabled.return_value = False
self.irc.driver.ssl = True
self.irc.driver.server = drivers.Server('irc.test', 6667, False)
self.irc.feedMsg(ircmsgs.IrcMsg(command='CAP',
args=('*', 'LS', 'sts=duration=42,port=6697')))
self.assertEqual(ircdb.networks.getNetwork('test').stsPolicies, {})
self.irc.driver.reconnect.assert_called_once_with(
server=drivers.Server('irc.test', 6697, True))
class IrcTestCase(SupyTestCase):
def setUp(self):
self.irc = irclib.Irc('test')