Add support for storing STS policies.

If on an insecure connection: reconnect.
If on a secure connect: store it and do nothing else.

For now, stored STS policies are not read when connecting to an
insecure server.
This commit is contained in:
Valentin Lorentz 2019-12-07 23:33:04 +01:00
parent ff5edd95a3
commit ecc2c32950
9 changed files with 449 additions and 35 deletions

View File

@ -148,14 +148,14 @@ class Owner(callbacks.Plugin):
def _connect(self, network, serverPort=None, password='', ssl=False): def _connect(self, network, serverPort=None, password='', ssl=False):
try: try:
group = conf.supybot.networks.get(network) group = conf.supybot.networks.get(network)
(server, port) = group.servers()[0] group.servers()[0]
except (registry.NonExistentRegistryEntry, IndexError): except (registry.NonExistentRegistryEntry, IndexError):
if serverPort is None: if serverPort is None:
raise ValueError('connect requires a (server, port) ' \ raise ValueError('connect requires a (server, port) ' \
'if the network is not registered.') 'if the network is not registered.')
conf.registerNetwork(network, password, ssl) conf.registerNetwork(network, password, ssl)
serverS = '%s:%s' % serverPort server = '%s:%s' % serverPort
conf.supybot.networks.get(network).servers.append(serverS) conf.supybot.networks.get(network).servers.append(server)
assert conf.supybot.networks.get(network).servers(), \ assert conf.supybot.networks.get(network).servers(), \
'No servers are set for the %s network.' % network 'No servers are set for the %s network.' % network
self.log.debug('Creating new Irc for %s.', network) self.log.debug('Creating new Irc for %s.', network)

View File

@ -273,15 +273,17 @@ class Servers(registry.SpaceSeparatedListOfStrings):
return s return s
def convert(self, s): def convert(self, s):
from .drivers import Server
s = self.normalize(s) s = self.normalize(s)
(server, port) = s.rsplit(':', 1) (hostname, port) = s.rsplit(':', 1)
# support for `[ipv6]:port` format # support for `[ipv6]:port` format
if server.startswith("[") and server.endswith("]"): if hostname.startswith("[") and hostname.endswith("]"):
server = server[1:-1] hostname = hostname[1:-1]
port = int(port) port = int(port)
return (server, port) return Server(hostname, port, force_tls_verification=False)
def __call__(self): def __call__(self):
L = registry.SpaceSeparatedListOfStrings.__call__(self) L = registry.SpaceSeparatedListOfStrings.__call__(self)
@ -1039,6 +1041,12 @@ registerGlobalValue(supybot.databases.channels, 'filename',
for the channels database. This file will go into the directory specified for the channels database. This file will go into the directory specified
by the supybot.directories.conf variable."""))) by the supybot.directories.conf variable.""")))
registerGroup(supybot.databases, 'networks')
registerGlobalValue(supybot.databases.networks, 'filename',
registry.String('networks.conf', _("""Determines what filename will be used
for the networks database. This file will go into the directory specified
by the supybot.directories.conf variable.""")))
# TODO This will need to do more in the future (such as making sure link.allow # TODO This will need to do more in the future (such as making sure link.allow
# will let the link occur), but for now let's just leave it as this. # will let the link occur), but for now let's just leave it as this.
class ChannelSpecific(registry.Boolean): class ChannelSpecific(registry.Boolean):

View File

@ -35,12 +35,12 @@ Contains simple socket drivers. Asyncore bugged (haha, pun!) me.
from __future__ import division from __future__ import division
import os import os
import sys
import time import time
import errno import errno
import threading import threading
import select import select
import socket import socket
import sys
try: try:
import ipaddress # Python >= 3.3 or backported ipaddress import ipaddress # Python >= 3.3 or backported ipaddress
@ -221,7 +221,7 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin):
def connect(self, **kwargs): def connect(self, **kwargs):
self.reconnect(reset=False, **kwargs) self.reconnect(reset=False, **kwargs)
def reconnect(self, wait=False, reset=True): def reconnect(self, wait=False, reset=True, server=None):
self._attempt += 1 self._attempt += 1
self.nextReconnectTime = None self.nextReconnectTime = None
if self.connected: if self.connected:
@ -242,7 +242,7 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin):
if wait: if wait:
self.scheduleReconnect() self.scheduleReconnect()
return return
self.server = self._getNextServer() self.server = server or self._getNextServer()
network_config = getattr(conf.supybot.networks, self.irc.network) network_config = getattr(conf.supybot.networks, self.irc.network)
socks_proxy = network_config.socksproxy() socks_proxy = network_config.socksproxy()
try: try:
@ -252,20 +252,20 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin):
log.error('Cannot use socks proxy (SocksiPy not installed), ' log.error('Cannot use socks proxy (SocksiPy not installed), '
'using direct connection instead.') 'using direct connection instead.')
socks_proxy = '' socks_proxy = ''
if socks_proxy:
address = self.server[0]
else: else:
try: try:
address = utils.net.getAddressFromHostname(self.server[0], hostname = utils.net.getAddressFromHostname(
attempt=self._attempt) self.server.hostname,
attempt=self._attempt)
except (socket.gaierror, socket.error) as e: except (socket.gaierror, socket.error) as e:
drivers.log.connectError(self.currentServer, e) drivers.log.connectError(self.currentServer, e)
self.scheduleReconnect() self.scheduleReconnect()
return return
port = self.server[1]
drivers.log.connect(self.currentServer) drivers.log.connect(self.currentServer)
try: try:
self.conn = utils.net.getSocket(address, port=port, self.conn = utils.net.getSocket(
self.server.hostname,
port=self.server.port,
socks_proxy=socks_proxy, socks_proxy=socks_proxy,
vhost=conf.supybot.protocols.irc.vhost(), vhost=conf.supybot.protocols.irc.vhost(),
vhostv6=conf.supybot.protocols.irc.vhostv6(), vhostv6=conf.supybot.protocols.irc.vhostv6(),
@ -280,12 +280,12 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin):
try: try:
# Connect before SSL, otherwise SSL is disabled if we use SOCKS. # Connect before SSL, otherwise SSL is disabled if we use SOCKS.
# See http://stackoverflow.com/q/16136916/539465 # See http://stackoverflow.com/q/16136916/539465
self.conn.connect((address, port)) self.conn.connect((self.server.hostname, self.server.port))
if network_config.ssl(): if network_config.ssl() or self.server.force_tls_verification:
self.starttls() self.starttls()
# Suppress this warning for loopback IPs. # Suppress this warning for loopback IPs.
targetip = address targetip = hostname
if sys.version_info[0] < 3: if sys.version_info[0] < 3:
# Backported Python 2 ipaddress demands unicode instead of str # Backported Python 2 ipaddress demands unicode instead of str
targetip = targetip.decode('utf-8') targetip = targetip.decode('utf-8')
@ -361,6 +361,15 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin):
def name(self): def name(self):
return '%s(%s)' % (self.__class__.__name__, self.irc) return '%s(%s)' % (self.__class__.__name__, self.irc)
def anyCertValidationEnabled(self):
"""Returns whether any kind of certificate validation is enabled, other
than Server.force_tls_verification."""
return any([
conf.supybot.protocols.ssl.verifyCertificates(),
network_config.ssl.serverFingerprints(),
network_config.ssl.authorityCertificate(),
])
def starttls(self): def starttls(self):
assert 'ssl' in globals() assert 'ssl' in globals()
network_config = getattr(conf.supybot.networks, self.irc.network) network_config = getattr(conf.supybot.networks, self.irc.network)
@ -373,15 +382,20 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin):
drivers.log.warning('Could not find cert file %s.' % drivers.log.warning('Could not find cert file %s.' %
certfile) certfile)
certfile = None certfile = None
verifyCertificates = conf.supybot.protocols.ssl.verifyCertificates() if self.server.force_tls_verification \
if not verifyCertificates: and not self.anyCertValidationEnabled():
drivers.log.warning('Not checking SSL certificates, connections ' verifyCertificates = True
'are vulnerable to man-in-the-middle attacks. Set ' else:
'supybot.protocols.ssl.verifyCertificates to "true" ' verifyCertificates = conf.supybot.protocols.ssl.verifyCertificates()
'to enable validity checks.') if not verifyCertificates:
drivers.log.warning('Not checking SSL certificates, connections '
'are vulnerable to man-in-the-middle attacks. Set '
'supybot.protocols.ssl.verifyCertificates to "true" '
'to enable validity checks.')
try: try:
self.conn = utils.net.ssl_wrap_socket(self.conn, self.conn = utils.net.ssl_wrap_socket(self.conn,
logger=drivers.log, hostname=self.server[0], logger=drivers.log,
hostname=self.server.hostname,
certfile=certfile, certfile=certfile,
verify=verifyCertificates, verify=verifyCertificates,
trusted_fingerprints=network_config.ssl.serverFingerprints(), trusted_fingerprints=network_config.ssl.serverFingerprints(),

View File

@ -33,10 +33,19 @@ Contains various drivers (network, file, and otherwise) for using IRC objects.
""" """
import socket import socket
from collections import namedtuple
from .. import conf, ircmsgs, log as supylog, utils from .. import conf, ircmsgs, log as supylog, utils
from ..utils import minisix from ..utils import minisix
Server = namedtuple('Server', 'hostname port force_tls_verification')
# force_tls_verification=True implies two things:
# 1. force TLS to be enabled for this server
# 2. ensure there is some kind of verification. If the user did not enable
# any, use standard PKI validation.
_drivers = {} _drivers = {}
_deadDrivers = set() _deadDrivers = set()
_newDrivers = [] _newDrivers = []
@ -80,7 +89,7 @@ class ServersMixin(object):
assert self.servers, 'Servers value for %s is empty.' % \ assert self.servers, 'Servers value for %s is empty.' % \
self.networkGroup._name self.networkGroup._name
server = self.servers.pop(0) server = self.servers.pop(0)
self.currentServer = '%s:%s' % server self.currentServer = '%s:%s' % (server.hostname, server.port)
return server return server

View File

@ -496,6 +496,44 @@ class IrcChannel(object):
fd.write(os.linesep) fd.write(os.linesep)
class IrcNetwork(object):
"""This class holds dynamic information about a network that should be
preserved across restarts."""
__slots__ = ('name', 'stsPolicies', 'lastDisconnectTimes')
def __init__(self, name=None, stsPolicies=None, lastDisconnectTimes=None):
self.name = name
self.stsPolicies = stsPolicies or {}
self.lastDisconnectTimes = lastDisconnectTimes or {}
def __repr__(self):
return '%s(name=%r, stsPolicy=%r, lastDisconnectTimes=%s)' % \
(self.__class__.__name, self.name, self.stsPolicy,
self.lastDisconnectTimes)
def addStsPolicy(self, server, stsPolicy):
assert isinstance(stsPolicy, str)
self.stsPolicies[server] = stsPolicy
def addDisconnection(self, server):
self.lastDisconnectTimes[server] = int(time.time())
def preserve(self, fd, indent=''):
def write(s):
fd.write(indent)
fd.write(s)
fd.write(os.linesep)
for (server, stsPolicy) in sorted(self.stsPolicies.items()):
write('stsPolicy %s %s' % (server, stsPolicy))
for (server, disconnectTime) in \
sorted(self.lastDisconnectTimes.items()):
write('lastDisconnectTime %s %s' % (server, disconnectTime))
fd.write(os.linesep)
class Creator(object): class Creator(object):
__slots__ = () __slots__ = ()
def badCommand(self, command, rest, lineno): def badCommand(self, command, rest, lineno):
@ -615,6 +653,31 @@ class IrcChannelCreator(Creator):
IrcChannelCreator.name = None IrcChannelCreator.name = None
class IrcNetworkCreator(Creator):
__slots__ = ('net', 'networks')
def __init__(self, networks):
self.net = IrcNetwork()
self.networks = networks
def network(self, rest, lineno):
self.net.name = rest
def stspolicy(self, rest, lineno):
(server, stsPolicy) = rest.split()
self.net.addStsPolicy(server, stsPolicy)
def lastdisconnecttime(self, rest, lineno):
(server, when) = rest.split()
when = int(when)
self.net.lastDisconnectTimes[server] = when
def finish(self):
if self.net.name:
self.networks.setNetwork(self.net)
self.name = None
class DuplicateHostmask(ValueError): class DuplicateHostmask(ValueError):
pass pass
@ -666,10 +729,8 @@ class UsersDictionary(utils.IterableMap):
"""Flushes the database to its file.""" """Flushes the database to its file."""
if not self.noFlush: if not self.noFlush:
if self.filename is not None: if self.filename is not None:
L = list(self.users.items())
L.sort()
fd = utils.file.AtomicFile(self.filename) fd = utils.file.AtomicFile(self.filename)
for (id, u) in L: for (id, u) in sorted(self.users.items()):
fd.write('user %s' % id) fd.write('user %s' % id)
fd.write(os.linesep) fd.write(os.linesep)
u.preserve(fd, indent=' ') u.preserve(fd, indent=' ')
@ -861,7 +922,7 @@ class ChannelsDictionary(utils.IterableMap):
if not self.noFlush: if not self.noFlush:
if self.filename is not None: if self.filename is not None:
fd = utils.file.AtomicFile(self.filename) fd = utils.file.AtomicFile(self.filename)
for (channel, c) in self.channels.items(): for (channel, c) in sorted(self.channels.items()):
fd.write('channel %s' % channel) fd.write('channel %s' % channel)
fd.write(os.linesep) fd.write(os.linesep)
c.preserve(fd, indent=' ') c.preserve(fd, indent=' ')
@ -907,6 +968,83 @@ class ChannelsDictionary(utils.IterableMap):
def items(self): def items(self):
return self.channels.items() return self.channels.items()
class NetworksDictionary(utils.IterableMap):
__slots__ = ('noFlush', 'filename', 'networks')
def __init__(self):
self.noFlush = False
self.filename = None
self.networks = ircutils.IrcDict()
def open(self, filename):
self.noFlush = True
try:
self.filename = filename
reader = unpreserve.Reader(IrcNetworkCreator, self)
try:
reader.readFile(filename)
self.noFlush = False
self.flush()
except EnvironmentError as e:
log.error('Invalid network database, resetting to empty.')
log.error('Exact error: %s', utils.exnToString(e))
except Exception as e:
log.error('Invalid network database, resetting to empty.')
log.exception('Exact error:')
finally:
self.noFlush = False
def flush(self):
"""Flushes the network database to its file."""
if not self.noFlush:
if self.filename is not None:
fd = utils.file.AtomicFile(self.filename)
for (network, net) in sorted(self.networks.items()):
fd.write('network %s' % network)
fd.write(os.linesep)
net.preserve(fd, indent=' ')
fd.close()
else:
log.warning('NetworksDictionary.flush without self.filename.')
else:
log.debug('Not flushing NetworksDictionary because of noFlush.')
def close(self):
self.flush()
if self.flush in world.flushers:
world.flushers.remove(self.flush)
self.networks.clear()
def reload(self):
"""Reloads the network database from its file."""
if self.filename is not None:
self.networks.clear()
try:
self.open(self.filename)
except EnvironmentError as e:
log.warning('NetworksDictionary.reload failed: %s', e)
else:
log.warning('NetworksDictionary.reload without self.filename.')
def getNetwork(self, network):
"""Returns an IrcNetwork object for the given network."""
network = network.lower()
if network in self.networks:
return self.networks[network]
else:
c = IrcNetwork()
self.networks[network] = c
return c
def setNetwork(self, network, ircNetwork):
"""Sets a given network to the IrcNetwork object given."""
network = network.lower()
self.networks[network] = ircNetwork
self.flush()
def items(self):
return self.networks.items()
class IgnoresDB(object): class IgnoresDB(object):
__slots__ = ('filename', 'hostmasks') __slots__ = ('filename', 'hostmasks')
@ -996,6 +1134,14 @@ try:
except EnvironmentError as e: except EnvironmentError as e:
log.warning('Couldn\'t open channel database: %s', e) log.warning('Couldn\'t open channel database: %s', e)
try:
networkFile = os.path.join(confDir,
conf.supybot.databases.networks.filename())
networks = NetworksDictionary()
networks.open(networkFile)
except EnvironmentError as e:
log.warning('Couldn\'t open network database: %s', e)
try: try:
ignoreFile = os.path.join(confDir, ignoreFile = os.path.join(confDir,
conf.supybot.databases.ignores.filename()) conf.supybot.databases.ignores.filename())
@ -1006,8 +1152,9 @@ except EnvironmentError as e:
world.flushers.append(users.flush) world.flushers.append(users.flush)
world.flushers.append(ignores.flush)
world.flushers.append(channels.flush) world.flushers.append(channels.flush)
world.flushers.append(networks.flush)
world.flushers.append(ignores.flush)
### ###

View File

@ -55,6 +55,7 @@ except ImportError:
scram = None scram = None
from . import conf, ircdb, ircmsgs, ircutils, log, utils, world from . import conf, ircdb, ircmsgs, ircutils, log, utils, world
from .drivers import Server
from .utils.str import rsplit from .utils.str import rsplit
from .utils.iter import chain from .utils.iter import chain
from .utils.structures import smallqueue, RingBuffer from .utils.structures import smallqueue, RingBuffer
@ -1468,14 +1469,42 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
self.network, caps) self.network, caps)
self.capUpkeep() self.capUpkeep()
def _onCapSts(self, policy):
parsed_policy = ircutils._parseStsPolicy(log, policy)
if parsed_policy is None:
# There was an error (and it was logged). Abort the connection.
self.driver.reconnect()
return
if not self.driver.ssl or not self.driver.anyCertValidationEnabled():
hostname = self.driver.server.hostname
# Reconnect to the server, but with TLS *and* certificate
# validation this time.
self.driver.reconnect(
server=Server(hostname, parsed_policy['port'], True))
else:
# TLS is enabled and certificate is verified; write the STS policy
# in stone.
# For future-proofing (because we don't want to write an invalid
# value), we write the raw policy received from the server instead
# of the parsed one.
ircdb.networks.getNetwork(self.network).addStsPolicy(
self.driver.server.hostname, policy)
def _addCapabilities(self, capstring): def _addCapabilities(self, capstring):
for item in capstring.split(): for item in capstring.split():
while item.startswith(('=', '~')): while item.startswith(('=', '~')):
item = item[1:] item = item[1:]
if '=' in item: if '=' in item:
(cap, value) = item.split('=', 1) (cap, value) = item.split('=', 1)
if cap == 'sts':
self._onCapSts(value)
self.state.capabilities_ls[cap] = value self.state.capabilities_ls[cap] = value
else: else:
if item == 'sts':
log.error('Got "sts" capability without value. Aborting '
'connection.')
self.driver.reconnect()
self.state.capabilities_ls[item] = None self.state.capabilities_ls[item] = None
def doCapLs(self, msg): def doCapLs(self, msg):

View File

@ -931,6 +931,31 @@ class AuthenticateDecoder(object):
return base64.b64decode(b''.join(self.chunks)) return base64.b64decode(b''.join(self.chunks))
def _parseStsPolicy(logger, policy):
parsed_policy = {}
for kv in policy.split(','):
if '=' in kv:
(k, v) = kv.split('=', 1)
parsed_policy[k] = v
else:
parsed_policy[kv] = None
for key in ('port', 'duration'):
if parsed_policy.get(key) is None:
logger.error('Missing or empty "%s" key in STS policy.'
'Aborting connection.', key)
return None
try:
parsed_policy[key] = int(parsed_policy[key])
except ValueError:
logger.error('Expected integer as value for key "%s" in STS '
'policy, got %r instead. Aborting connection.',
key, parsed_policy[key])
return None
return parsed_policy
numerics = { numerics = {
# <= 2.10 # <= 2.10
# Reply # Reply

View File

@ -31,11 +31,13 @@ from supybot.test import *
import os import os
import unittest import unittest
import unittest.mock
import supybot.conf as conf import supybot.conf as conf
import supybot.world as world import supybot.world as world
import supybot.ircdb as ircdb import supybot.ircdb as ircdb
import supybot.ircutils as ircutils import supybot.ircutils as ircutils
from supybot.utils.minisix import io
class IrcdbTestCase(SupyTestCase): class IrcdbTestCase(SupyTestCase):
def setUp(self): def setUp(self):
@ -347,6 +349,50 @@ class IrcChannelTestCase(IrcdbTestCase):
c.removeBan(banmask) c.removeBan(banmask)
self.assertFalse(c.checkIgnored(prefix)) self.assertFalse(c.checkIgnored(prefix))
class IrcNetworkTestCase(IrcdbTestCase):
def testDefaults(self):
n = ircdb.IrcNetwork()
self.assertIsNone(n.name)
self.assertEqual(n.stsPolicies, {})
self.assertEqual(n.lastDisconnectTimes, {})
def testStsPolicy(self):
n = ircdb.IrcNetwork()
n.addStsPolicy('foo', 'bar')
n.addStsPolicy('baz', 'qux')
self.assertEqual(n.stsPolicies, {
'foo': 'bar',
'baz': 'qux',
})
def testAddDisconnection(self):
n = ircdb.IrcNetwork()
min_ = int(time.time())
n.addDisconnection('foo')
max_ = int(time.time())
self.assertTrue(min_ <= n.lastDisconnectTimes['foo'] <= max_)
def testPreserve(self):
n = ircdb.IrcNetwork('foonet')
n.addStsPolicy('foo', 'sts1')
n.addStsPolicy('bar', 'sts2')
n.addDisconnection('foo')
n.addDisconnection('baz')
disconnect_time_foo = n.lastDisconnectTimes['foo']
disconnect_time_baz = n.lastDisconnectTimes['baz']
fd = io.StringIO()
n.preserve(fd, indent=' ')
fd.seek(0)
self.assertCountEqual(fd.read().split('\n'), [
' stsPolicy foo sts1',
' stsPolicy bar sts2',
' lastDisconnectTime foo %d' % disconnect_time_foo,
' lastDisconnectTime baz %d' % disconnect_time_baz,
'',
'',
])
class UsersDictionaryTestCase(IrcdbTestCase): class UsersDictionaryTestCase(IrcdbTestCase):
filename = os.path.join(conf.supybot.directories.conf(), filename = os.path.join(conf.supybot.directories.conf(),
'UsersDictionaryTestCase.conf') 'UsersDictionaryTestCase.conf')
@ -401,6 +447,89 @@ class UsersDictionaryTestCase(IrcdbTestCase):
self.assertRaises(ValueError, self.users.setUser, u2) self.assertRaises(ValueError, self.users.setUser, u2)
class NetworksDictionaryTestCase(IrcdbTestCase):
filename = os.path.join(conf.supybot.directories.conf(),
'NetworksDictionaryTestCase.conf')
def setUp(self):
try:
os.remove(self.filename)
except:
pass
self.networks = ircdb.NetworksDictionary()
IrcdbTestCase.setUp(self)
def testGetSetNetwork(self):
self.assertEqual(self.networks.getNetwork('foo').name, None)
n = ircdb.IrcNetwork()
n.name = 'foo'
self.networks.setNetwork('foo', n)
self.assertEqual(self.networks.getNetwork('foo').name, 'foo')
def testPreserveOne(self):
n = ircdb.IrcNetwork('foonet')
n.addStsPolicy('foo', 'sts1')
n.addStsPolicy('bar', 'sts2')
n.addDisconnection('foo')
n.addDisconnection('baz')
disconnect_time_foo = n.lastDisconnectTimes['foo']
disconnect_time_baz = n.lastDisconnectTimes['baz']
self.networks.setNetwork('foonet', n)
fd = io.StringIO()
fd.close = lambda: None
self.networks.filename = 'blah'
original_Atomicfile = utils.file.AtomicFile
with unittest.mock.patch(
'supybot.utils.file.AtomicFile', return_value=fd):
self.networks.flush()
lines = fd.getvalue().split('\n')
self.assertEqual(lines.pop(0), 'network foonet')
self.assertCountEqual(lines, [
' stsPolicy foo sts1',
' stsPolicy bar sts2',
' lastDisconnectTime foo %d' % disconnect_time_foo,
' lastDisconnectTime baz %d' % disconnect_time_baz,
'',
'',
])
def testPreserveThree(self):
n = ircdb.IrcNetwork('foonet')
n.addStsPolicy('foo', 'sts1')
self.networks.setNetwork('foonet', n)
n = ircdb.IrcNetwork('barnet')
n.addStsPolicy('bar', 'sts2')
self.networks.setNetwork('barnet', n)
n = ircdb.IrcNetwork('baznet')
n.addStsPolicy('baz', 'sts3')
self.networks.setNetwork('baznet', n)
fd = io.StringIO()
fd.close = lambda: None
self.networks.filename = 'blah'
original_Atomicfile = utils.file.AtomicFile
with unittest.mock.patch(
'supybot.utils.file.AtomicFile', return_value=fd):
self.networks.flush()
fd.seek(0)
self.assertEqual(fd.getvalue(),
'network barnet\n'
' stsPolicy bar sts2\n'
'\n'
'network baznet\n'
' stsPolicy baz sts3\n'
'\n'
'network foonet\n'
' stsPolicy foo sts1\n'
'\n'
)
class CheckCapabilityTestCase(IrcdbTestCase): class CheckCapabilityTestCase(IrcdbTestCase):
filename = os.path.join(conf.supybot.directories.conf(), filename = os.path.join(conf.supybot.directories.conf(),
'CheckCapabilityTestCase.conf') 'CheckCapabilityTestCase.conf')

View File

@ -27,14 +27,15 @@
# POSSIBILITY OF SUCH DAMAGE. # POSSIBILITY OF SUCH DAMAGE.
### ###
from supybot.test import *
import copy import copy
import pickle import pickle
import warnings import unittest.mock
from supybot.test import *
import supybot.conf as conf import supybot.conf as conf
import supybot.irclib as irclib import supybot.irclib as irclib
import supybot.drivers as drivers
import supybot.ircmsgs as ircmsgs import supybot.ircmsgs as ircmsgs
import supybot.ircutils as ircutils import supybot.ircutils as ircutils
@ -497,6 +498,58 @@ class IrcCapsTestCase(SupyTestCase):
self.assertEqual(m.args[0], 'REQ', m) self.assertEqual(m.args[0], 'REQ', m)
self.assertEqual(m.args[1], 'b'*400) self.assertEqual(m.args[1], 'b'*400)
class StsTestCase(SupyTestCase):
def setUp(self):
self.irc = irclib.Irc('test')
m = self.irc.takeMsg()
self.failUnless(m.command == 'CAP', 'Expected CAP, got %r.' % m)
self.failUnless(m.args == ('LS', '302'), 'Expected CAP LS 302, got %r.' % m)
m = self.irc.takeMsg()
self.failUnless(m.command == 'NICK', 'Expected NICK, got %r.' % m)
m = self.irc.takeMsg()
self.failUnless(m.command == 'USER', 'Expected USER, got %r.' % m)
self.irc.driver = unittest.mock.Mock()
def tearDown(self):
ircdb.networks.networks = {}
def testStsInSecureConnection(self):
self.irc.driver.anyCertValidationEnabled.return_value = True
self.irc.driver.ssl = True
self.irc.driver.server = drivers.Server('irc.test', 6697, False)
self.irc.feedMsg(ircmsgs.IrcMsg(command='CAP',
args=('*', 'LS', 'sts=duration=42,port=6697')))
self.assertEqual(ircdb.networks.getNetwork('test').stsPolicies, {
'irc.test': 'duration=42,port=6697'})
self.irc.driver.reconnect.assert_not_called()
def testStsInInsecureTlsConnection(self):
self.irc.driver.anyCertValidationEnabled.return_value = False
self.irc.driver.ssl = True
self.irc.driver.server = drivers.Server('irc.test', 6697, False)
self.irc.feedMsg(ircmsgs.IrcMsg(command='CAP',
args=('*', 'LS', 'sts=duration=42,port=6697')))
self.assertEqual(ircdb.networks.getNetwork('test').stsPolicies, {})
self.irc.driver.reconnect.assert_called_once_with(
server=drivers.Server('irc.test', 6697, True))
def testStsInCleartextConnection(self):
self.irc.driver.anyCertValidationEnabled.return_value = False
self.irc.driver.ssl = True
self.irc.driver.server = drivers.Server('irc.test', 6667, False)
self.irc.feedMsg(ircmsgs.IrcMsg(command='CAP',
args=('*', 'LS', 'sts=duration=42,port=6697')))
self.assertEqual(ircdb.networks.getNetwork('test').stsPolicies, {})
self.irc.driver.reconnect.assert_called_once_with(
server=drivers.Server('irc.test', 6697, True))
class IrcTestCase(SupyTestCase): class IrcTestCase(SupyTestCase):
def setUp(self): def setUp(self):
self.irc = irclib.Irc('test') self.irc = irclib.Irc('test')