mirror of
https://github.com/Mikaela/Limnoria.git
synced 2024-11-29 22:29:24 +01:00
Add support for storing STS policies.
If on an insecure connection: reconnect. If on a secure connect: store it and do nothing else. For now, stored STS policies are not read when connecting to an insecure server.
This commit is contained in:
parent
ff5edd95a3
commit
ecc2c32950
@ -148,14 +148,14 @@ class Owner(callbacks.Plugin):
|
||||
def _connect(self, network, serverPort=None, password='', ssl=False):
|
||||
try:
|
||||
group = conf.supybot.networks.get(network)
|
||||
(server, port) = group.servers()[0]
|
||||
group.servers()[0]
|
||||
except (registry.NonExistentRegistryEntry, IndexError):
|
||||
if serverPort is None:
|
||||
raise ValueError('connect requires a (server, port) ' \
|
||||
'if the network is not registered.')
|
||||
conf.registerNetwork(network, password, ssl)
|
||||
serverS = '%s:%s' % serverPort
|
||||
conf.supybot.networks.get(network).servers.append(serverS)
|
||||
server = '%s:%s' % serverPort
|
||||
conf.supybot.networks.get(network).servers.append(server)
|
||||
assert conf.supybot.networks.get(network).servers(), \
|
||||
'No servers are set for the %s network.' % network
|
||||
self.log.debug('Creating new Irc for %s.', network)
|
||||
|
16
src/conf.py
16
src/conf.py
@ -273,15 +273,17 @@ class Servers(registry.SpaceSeparatedListOfStrings):
|
||||
return s
|
||||
|
||||
def convert(self, s):
|
||||
from .drivers import Server
|
||||
|
||||
s = self.normalize(s)
|
||||
(server, port) = s.rsplit(':', 1)
|
||||
(hostname, port) = s.rsplit(':', 1)
|
||||
|
||||
# support for `[ipv6]:port` format
|
||||
if server.startswith("[") and server.endswith("]"):
|
||||
server = server[1:-1]
|
||||
if hostname.startswith("[") and hostname.endswith("]"):
|
||||
hostname = hostname[1:-1]
|
||||
|
||||
port = int(port)
|
||||
return (server, port)
|
||||
return Server(hostname, port, force_tls_verification=False)
|
||||
|
||||
def __call__(self):
|
||||
L = registry.SpaceSeparatedListOfStrings.__call__(self)
|
||||
@ -1039,6 +1041,12 @@ registerGlobalValue(supybot.databases.channels, 'filename',
|
||||
for the channels database. This file will go into the directory specified
|
||||
by the supybot.directories.conf variable.""")))
|
||||
|
||||
registerGroup(supybot.databases, 'networks')
|
||||
registerGlobalValue(supybot.databases.networks, 'filename',
|
||||
registry.String('networks.conf', _("""Determines what filename will be used
|
||||
for the networks database. This file will go into the directory specified
|
||||
by the supybot.directories.conf variable.""")))
|
||||
|
||||
# TODO This will need to do more in the future (such as making sure link.allow
|
||||
# will let the link occur), but for now let's just leave it as this.
|
||||
class ChannelSpecific(registry.Boolean):
|
||||
|
@ -35,12 +35,12 @@ Contains simple socket drivers. Asyncore bugged (haha, pun!) me.
|
||||
from __future__ import division
|
||||
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import errno
|
||||
import threading
|
||||
import select
|
||||
import socket
|
||||
import sys
|
||||
|
||||
try:
|
||||
import ipaddress # Python >= 3.3 or backported ipaddress
|
||||
@ -221,7 +221,7 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin):
|
||||
def connect(self, **kwargs):
|
||||
self.reconnect(reset=False, **kwargs)
|
||||
|
||||
def reconnect(self, wait=False, reset=True):
|
||||
def reconnect(self, wait=False, reset=True, server=None):
|
||||
self._attempt += 1
|
||||
self.nextReconnectTime = None
|
||||
if self.connected:
|
||||
@ -242,7 +242,7 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin):
|
||||
if wait:
|
||||
self.scheduleReconnect()
|
||||
return
|
||||
self.server = self._getNextServer()
|
||||
self.server = server or self._getNextServer()
|
||||
network_config = getattr(conf.supybot.networks, self.irc.network)
|
||||
socks_proxy = network_config.socksproxy()
|
||||
try:
|
||||
@ -252,20 +252,20 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin):
|
||||
log.error('Cannot use socks proxy (SocksiPy not installed), '
|
||||
'using direct connection instead.')
|
||||
socks_proxy = ''
|
||||
if socks_proxy:
|
||||
address = self.server[0]
|
||||
else:
|
||||
try:
|
||||
address = utils.net.getAddressFromHostname(self.server[0],
|
||||
hostname = utils.net.getAddressFromHostname(
|
||||
self.server.hostname,
|
||||
attempt=self._attempt)
|
||||
except (socket.gaierror, socket.error) as e:
|
||||
drivers.log.connectError(self.currentServer, e)
|
||||
self.scheduleReconnect()
|
||||
return
|
||||
port = self.server[1]
|
||||
drivers.log.connect(self.currentServer)
|
||||
try:
|
||||
self.conn = utils.net.getSocket(address, port=port,
|
||||
self.conn = utils.net.getSocket(
|
||||
self.server.hostname,
|
||||
port=self.server.port,
|
||||
socks_proxy=socks_proxy,
|
||||
vhost=conf.supybot.protocols.irc.vhost(),
|
||||
vhostv6=conf.supybot.protocols.irc.vhostv6(),
|
||||
@ -280,12 +280,12 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin):
|
||||
try:
|
||||
# Connect before SSL, otherwise SSL is disabled if we use SOCKS.
|
||||
# See http://stackoverflow.com/q/16136916/539465
|
||||
self.conn.connect((address, port))
|
||||
if network_config.ssl():
|
||||
self.conn.connect((self.server.hostname, self.server.port))
|
||||
if network_config.ssl() or self.server.force_tls_verification:
|
||||
self.starttls()
|
||||
|
||||
# Suppress this warning for loopback IPs.
|
||||
targetip = address
|
||||
targetip = hostname
|
||||
if sys.version_info[0] < 3:
|
||||
# Backported Python 2 ipaddress demands unicode instead of str
|
||||
targetip = targetip.decode('utf-8')
|
||||
@ -361,6 +361,15 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin):
|
||||
def name(self):
|
||||
return '%s(%s)' % (self.__class__.__name__, self.irc)
|
||||
|
||||
def anyCertValidationEnabled(self):
|
||||
"""Returns whether any kind of certificate validation is enabled, other
|
||||
than Server.force_tls_verification."""
|
||||
return any([
|
||||
conf.supybot.protocols.ssl.verifyCertificates(),
|
||||
network_config.ssl.serverFingerprints(),
|
||||
network_config.ssl.authorityCertificate(),
|
||||
])
|
||||
|
||||
def starttls(self):
|
||||
assert 'ssl' in globals()
|
||||
network_config = getattr(conf.supybot.networks, self.irc.network)
|
||||
@ -373,6 +382,10 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin):
|
||||
drivers.log.warning('Could not find cert file %s.' %
|
||||
certfile)
|
||||
certfile = None
|
||||
if self.server.force_tls_verification \
|
||||
and not self.anyCertValidationEnabled():
|
||||
verifyCertificates = True
|
||||
else:
|
||||
verifyCertificates = conf.supybot.protocols.ssl.verifyCertificates()
|
||||
if not verifyCertificates:
|
||||
drivers.log.warning('Not checking SSL certificates, connections '
|
||||
@ -381,7 +394,8 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin):
|
||||
'to enable validity checks.')
|
||||
try:
|
||||
self.conn = utils.net.ssl_wrap_socket(self.conn,
|
||||
logger=drivers.log, hostname=self.server[0],
|
||||
logger=drivers.log,
|
||||
hostname=self.server.hostname,
|
||||
certfile=certfile,
|
||||
verify=verifyCertificates,
|
||||
trusted_fingerprints=network_config.ssl.serverFingerprints(),
|
||||
|
@ -33,10 +33,19 @@ Contains various drivers (network, file, and otherwise) for using IRC objects.
|
||||
"""
|
||||
|
||||
import socket
|
||||
from collections import namedtuple
|
||||
|
||||
from .. import conf, ircmsgs, log as supylog, utils
|
||||
from ..utils import minisix
|
||||
|
||||
|
||||
Server = namedtuple('Server', 'hostname port force_tls_verification')
|
||||
# force_tls_verification=True implies two things:
|
||||
# 1. force TLS to be enabled for this server
|
||||
# 2. ensure there is some kind of verification. If the user did not enable
|
||||
# any, use standard PKI validation.
|
||||
|
||||
|
||||
_drivers = {}
|
||||
_deadDrivers = set()
|
||||
_newDrivers = []
|
||||
@ -80,7 +89,7 @@ class ServersMixin(object):
|
||||
assert self.servers, 'Servers value for %s is empty.' % \
|
||||
self.networkGroup._name
|
||||
server = self.servers.pop(0)
|
||||
self.currentServer = '%s:%s' % server
|
||||
self.currentServer = '%s:%s' % (server.hostname, server.port)
|
||||
return server
|
||||
|
||||
|
||||
|
157
src/ircdb.py
157
src/ircdb.py
@ -496,6 +496,44 @@ class IrcChannel(object):
|
||||
fd.write(os.linesep)
|
||||
|
||||
|
||||
class IrcNetwork(object):
|
||||
"""This class holds dynamic information about a network that should be
|
||||
preserved across restarts."""
|
||||
__slots__ = ('name', 'stsPolicies', 'lastDisconnectTimes')
|
||||
|
||||
def __init__(self, name=None, stsPolicies=None, lastDisconnectTimes=None):
|
||||
self.name = name
|
||||
self.stsPolicies = stsPolicies or {}
|
||||
self.lastDisconnectTimes = lastDisconnectTimes or {}
|
||||
|
||||
def __repr__(self):
|
||||
return '%s(name=%r, stsPolicy=%r, lastDisconnectTimes=%s)' % \
|
||||
(self.__class__.__name, self.name, self.stsPolicy,
|
||||
self.lastDisconnectTimes)
|
||||
|
||||
def addStsPolicy(self, server, stsPolicy):
|
||||
assert isinstance(stsPolicy, str)
|
||||
self.stsPolicies[server] = stsPolicy
|
||||
|
||||
def addDisconnection(self, server):
|
||||
self.lastDisconnectTimes[server] = int(time.time())
|
||||
|
||||
def preserve(self, fd, indent=''):
|
||||
def write(s):
|
||||
fd.write(indent)
|
||||
fd.write(s)
|
||||
fd.write(os.linesep)
|
||||
|
||||
for (server, stsPolicy) in sorted(self.stsPolicies.items()):
|
||||
write('stsPolicy %s %s' % (server, stsPolicy))
|
||||
|
||||
for (server, disconnectTime) in \
|
||||
sorted(self.lastDisconnectTimes.items()):
|
||||
write('lastDisconnectTime %s %s' % (server, disconnectTime))
|
||||
|
||||
fd.write(os.linesep)
|
||||
|
||||
|
||||
class Creator(object):
|
||||
__slots__ = ()
|
||||
def badCommand(self, command, rest, lineno):
|
||||
@ -615,6 +653,31 @@ class IrcChannelCreator(Creator):
|
||||
IrcChannelCreator.name = None
|
||||
|
||||
|
||||
class IrcNetworkCreator(Creator):
|
||||
__slots__ = ('net', 'networks')
|
||||
|
||||
def __init__(self, networks):
|
||||
self.net = IrcNetwork()
|
||||
self.networks = networks
|
||||
|
||||
def network(self, rest, lineno):
|
||||
self.net.name = rest
|
||||
|
||||
def stspolicy(self, rest, lineno):
|
||||
(server, stsPolicy) = rest.split()
|
||||
self.net.addStsPolicy(server, stsPolicy)
|
||||
|
||||
def lastdisconnecttime(self, rest, lineno):
|
||||
(server, when) = rest.split()
|
||||
when = int(when)
|
||||
self.net.lastDisconnectTimes[server] = when
|
||||
|
||||
def finish(self):
|
||||
if self.net.name:
|
||||
self.networks.setNetwork(self.net)
|
||||
self.name = None
|
||||
|
||||
|
||||
class DuplicateHostmask(ValueError):
|
||||
pass
|
||||
|
||||
@ -666,10 +729,8 @@ class UsersDictionary(utils.IterableMap):
|
||||
"""Flushes the database to its file."""
|
||||
if not self.noFlush:
|
||||
if self.filename is not None:
|
||||
L = list(self.users.items())
|
||||
L.sort()
|
||||
fd = utils.file.AtomicFile(self.filename)
|
||||
for (id, u) in L:
|
||||
for (id, u) in sorted(self.users.items()):
|
||||
fd.write('user %s' % id)
|
||||
fd.write(os.linesep)
|
||||
u.preserve(fd, indent=' ')
|
||||
@ -861,7 +922,7 @@ class ChannelsDictionary(utils.IterableMap):
|
||||
if not self.noFlush:
|
||||
if self.filename is not None:
|
||||
fd = utils.file.AtomicFile(self.filename)
|
||||
for (channel, c) in self.channels.items():
|
||||
for (channel, c) in sorted(self.channels.items()):
|
||||
fd.write('channel %s' % channel)
|
||||
fd.write(os.linesep)
|
||||
c.preserve(fd, indent=' ')
|
||||
@ -907,6 +968,83 @@ class ChannelsDictionary(utils.IterableMap):
|
||||
def items(self):
|
||||
return self.channels.items()
|
||||
|
||||
class NetworksDictionary(utils.IterableMap):
|
||||
__slots__ = ('noFlush', 'filename', 'networks')
|
||||
|
||||
def __init__(self):
|
||||
self.noFlush = False
|
||||
self.filename = None
|
||||
self.networks = ircutils.IrcDict()
|
||||
|
||||
def open(self, filename):
|
||||
self.noFlush = True
|
||||
try:
|
||||
self.filename = filename
|
||||
reader = unpreserve.Reader(IrcNetworkCreator, self)
|
||||
try:
|
||||
reader.readFile(filename)
|
||||
self.noFlush = False
|
||||
self.flush()
|
||||
except EnvironmentError as e:
|
||||
log.error('Invalid network database, resetting to empty.')
|
||||
log.error('Exact error: %s', utils.exnToString(e))
|
||||
except Exception as e:
|
||||
log.error('Invalid network database, resetting to empty.')
|
||||
log.exception('Exact error:')
|
||||
finally:
|
||||
self.noFlush = False
|
||||
|
||||
def flush(self):
|
||||
"""Flushes the network database to its file."""
|
||||
if not self.noFlush:
|
||||
if self.filename is not None:
|
||||
fd = utils.file.AtomicFile(self.filename)
|
||||
for (network, net) in sorted(self.networks.items()):
|
||||
fd.write('network %s' % network)
|
||||
fd.write(os.linesep)
|
||||
net.preserve(fd, indent=' ')
|
||||
fd.close()
|
||||
else:
|
||||
log.warning('NetworksDictionary.flush without self.filename.')
|
||||
else:
|
||||
log.debug('Not flushing NetworksDictionary because of noFlush.')
|
||||
|
||||
def close(self):
|
||||
self.flush()
|
||||
if self.flush in world.flushers:
|
||||
world.flushers.remove(self.flush)
|
||||
self.networks.clear()
|
||||
|
||||
def reload(self):
|
||||
"""Reloads the network database from its file."""
|
||||
if self.filename is not None:
|
||||
self.networks.clear()
|
||||
try:
|
||||
self.open(self.filename)
|
||||
except EnvironmentError as e:
|
||||
log.warning('NetworksDictionary.reload failed: %s', e)
|
||||
else:
|
||||
log.warning('NetworksDictionary.reload without self.filename.')
|
||||
|
||||
def getNetwork(self, network):
|
||||
"""Returns an IrcNetwork object for the given network."""
|
||||
network = network.lower()
|
||||
if network in self.networks:
|
||||
return self.networks[network]
|
||||
else:
|
||||
c = IrcNetwork()
|
||||
self.networks[network] = c
|
||||
return c
|
||||
|
||||
def setNetwork(self, network, ircNetwork):
|
||||
"""Sets a given network to the IrcNetwork object given."""
|
||||
network = network.lower()
|
||||
self.networks[network] = ircNetwork
|
||||
self.flush()
|
||||
|
||||
def items(self):
|
||||
return self.networks.items()
|
||||
|
||||
|
||||
class IgnoresDB(object):
|
||||
__slots__ = ('filename', 'hostmasks')
|
||||
@ -996,6 +1134,14 @@ try:
|
||||
except EnvironmentError as e:
|
||||
log.warning('Couldn\'t open channel database: %s', e)
|
||||
|
||||
try:
|
||||
networkFile = os.path.join(confDir,
|
||||
conf.supybot.databases.networks.filename())
|
||||
networks = NetworksDictionary()
|
||||
networks.open(networkFile)
|
||||
except EnvironmentError as e:
|
||||
log.warning('Couldn\'t open network database: %s', e)
|
||||
|
||||
try:
|
||||
ignoreFile = os.path.join(confDir,
|
||||
conf.supybot.databases.ignores.filename())
|
||||
@ -1006,8 +1152,9 @@ except EnvironmentError as e:
|
||||
|
||||
|
||||
world.flushers.append(users.flush)
|
||||
world.flushers.append(ignores.flush)
|
||||
world.flushers.append(channels.flush)
|
||||
world.flushers.append(networks.flush)
|
||||
world.flushers.append(ignores.flush)
|
||||
|
||||
|
||||
###
|
||||
|
@ -55,6 +55,7 @@ except ImportError:
|
||||
scram = None
|
||||
|
||||
from . import conf, ircdb, ircmsgs, ircutils, log, utils, world
|
||||
from .drivers import Server
|
||||
from .utils.str import rsplit
|
||||
from .utils.iter import chain
|
||||
from .utils.structures import smallqueue, RingBuffer
|
||||
@ -1468,14 +1469,42 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
|
||||
self.network, caps)
|
||||
self.capUpkeep()
|
||||
|
||||
def _onCapSts(self, policy):
|
||||
parsed_policy = ircutils._parseStsPolicy(log, policy)
|
||||
if parsed_policy is None:
|
||||
# There was an error (and it was logged). Abort the connection.
|
||||
self.driver.reconnect()
|
||||
return
|
||||
|
||||
if not self.driver.ssl or not self.driver.anyCertValidationEnabled():
|
||||
hostname = self.driver.server.hostname
|
||||
# Reconnect to the server, but with TLS *and* certificate
|
||||
# validation this time.
|
||||
self.driver.reconnect(
|
||||
server=Server(hostname, parsed_policy['port'], True))
|
||||
else:
|
||||
# TLS is enabled and certificate is verified; write the STS policy
|
||||
# in stone.
|
||||
# For future-proofing (because we don't want to write an invalid
|
||||
# value), we write the raw policy received from the server instead
|
||||
# of the parsed one.
|
||||
ircdb.networks.getNetwork(self.network).addStsPolicy(
|
||||
self.driver.server.hostname, policy)
|
||||
|
||||
def _addCapabilities(self, capstring):
|
||||
for item in capstring.split():
|
||||
while item.startswith(('=', '~')):
|
||||
item = item[1:]
|
||||
if '=' in item:
|
||||
(cap, value) = item.split('=', 1)
|
||||
if cap == 'sts':
|
||||
self._onCapSts(value)
|
||||
self.state.capabilities_ls[cap] = value
|
||||
else:
|
||||
if item == 'sts':
|
||||
log.error('Got "sts" capability without value. Aborting '
|
||||
'connection.')
|
||||
self.driver.reconnect()
|
||||
self.state.capabilities_ls[item] = None
|
||||
|
||||
def doCapLs(self, msg):
|
||||
|
@ -931,6 +931,31 @@ class AuthenticateDecoder(object):
|
||||
return base64.b64decode(b''.join(self.chunks))
|
||||
|
||||
|
||||
def _parseStsPolicy(logger, policy):
|
||||
parsed_policy = {}
|
||||
for kv in policy.split(','):
|
||||
if '=' in kv:
|
||||
(k, v) = kv.split('=', 1)
|
||||
parsed_policy[k] = v
|
||||
else:
|
||||
parsed_policy[kv] = None
|
||||
|
||||
for key in ('port', 'duration'):
|
||||
if parsed_policy.get(key) is None:
|
||||
logger.error('Missing or empty "%s" key in STS policy.'
|
||||
'Aborting connection.', key)
|
||||
return None
|
||||
try:
|
||||
parsed_policy[key] = int(parsed_policy[key])
|
||||
except ValueError:
|
||||
logger.error('Expected integer as value for key "%s" in STS '
|
||||
'policy, got %r instead. Aborting connection.',
|
||||
key, parsed_policy[key])
|
||||
return None
|
||||
|
||||
return parsed_policy
|
||||
|
||||
|
||||
numerics = {
|
||||
# <= 2.10
|
||||
# Reply
|
||||
|
@ -31,11 +31,13 @@ from supybot.test import *
|
||||
|
||||
import os
|
||||
import unittest
|
||||
import unittest.mock
|
||||
|
||||
import supybot.conf as conf
|
||||
import supybot.world as world
|
||||
import supybot.ircdb as ircdb
|
||||
import supybot.ircutils as ircutils
|
||||
from supybot.utils.minisix import io
|
||||
|
||||
class IrcdbTestCase(SupyTestCase):
|
||||
def setUp(self):
|
||||
@ -347,6 +349,50 @@ class IrcChannelTestCase(IrcdbTestCase):
|
||||
c.removeBan(banmask)
|
||||
self.assertFalse(c.checkIgnored(prefix))
|
||||
|
||||
class IrcNetworkTestCase(IrcdbTestCase):
|
||||
def testDefaults(self):
|
||||
n = ircdb.IrcNetwork()
|
||||
self.assertIsNone(n.name)
|
||||
self.assertEqual(n.stsPolicies, {})
|
||||
self.assertEqual(n.lastDisconnectTimes, {})
|
||||
|
||||
def testStsPolicy(self):
|
||||
n = ircdb.IrcNetwork()
|
||||
n.addStsPolicy('foo', 'bar')
|
||||
n.addStsPolicy('baz', 'qux')
|
||||
self.assertEqual(n.stsPolicies, {
|
||||
'foo': 'bar',
|
||||
'baz': 'qux',
|
||||
})
|
||||
|
||||
def testAddDisconnection(self):
|
||||
n = ircdb.IrcNetwork()
|
||||
min_ = int(time.time())
|
||||
n.addDisconnection('foo')
|
||||
max_ = int(time.time())
|
||||
self.assertTrue(min_ <= n.lastDisconnectTimes['foo'] <= max_)
|
||||
|
||||
def testPreserve(self):
|
||||
n = ircdb.IrcNetwork('foonet')
|
||||
n.addStsPolicy('foo', 'sts1')
|
||||
n.addStsPolicy('bar', 'sts2')
|
||||
n.addDisconnection('foo')
|
||||
n.addDisconnection('baz')
|
||||
disconnect_time_foo = n.lastDisconnectTimes['foo']
|
||||
disconnect_time_baz = n.lastDisconnectTimes['baz']
|
||||
fd = io.StringIO()
|
||||
n.preserve(fd, indent=' ')
|
||||
fd.seek(0)
|
||||
self.assertCountEqual(fd.read().split('\n'), [
|
||||
' stsPolicy foo sts1',
|
||||
' stsPolicy bar sts2',
|
||||
' lastDisconnectTime foo %d' % disconnect_time_foo,
|
||||
' lastDisconnectTime baz %d' % disconnect_time_baz,
|
||||
'',
|
||||
'',
|
||||
])
|
||||
|
||||
|
||||
class UsersDictionaryTestCase(IrcdbTestCase):
|
||||
filename = os.path.join(conf.supybot.directories.conf(),
|
||||
'UsersDictionaryTestCase.conf')
|
||||
@ -401,6 +447,89 @@ class UsersDictionaryTestCase(IrcdbTestCase):
|
||||
self.assertRaises(ValueError, self.users.setUser, u2)
|
||||
|
||||
|
||||
class NetworksDictionaryTestCase(IrcdbTestCase):
|
||||
filename = os.path.join(conf.supybot.directories.conf(),
|
||||
'NetworksDictionaryTestCase.conf')
|
||||
def setUp(self):
|
||||
try:
|
||||
os.remove(self.filename)
|
||||
except:
|
||||
pass
|
||||
self.networks = ircdb.NetworksDictionary()
|
||||
IrcdbTestCase.setUp(self)
|
||||
|
||||
def testGetSetNetwork(self):
|
||||
self.assertEqual(self.networks.getNetwork('foo').name, None)
|
||||
|
||||
n = ircdb.IrcNetwork()
|
||||
n.name = 'foo'
|
||||
self.networks.setNetwork('foo', n)
|
||||
self.assertEqual(self.networks.getNetwork('foo').name, 'foo')
|
||||
|
||||
def testPreserveOne(self):
|
||||
n = ircdb.IrcNetwork('foonet')
|
||||
n.addStsPolicy('foo', 'sts1')
|
||||
n.addStsPolicy('bar', 'sts2')
|
||||
n.addDisconnection('foo')
|
||||
n.addDisconnection('baz')
|
||||
disconnect_time_foo = n.lastDisconnectTimes['foo']
|
||||
disconnect_time_baz = n.lastDisconnectTimes['baz']
|
||||
self.networks.setNetwork('foonet', n)
|
||||
|
||||
fd = io.StringIO()
|
||||
fd.close = lambda: None
|
||||
self.networks.filename = 'blah'
|
||||
original_Atomicfile = utils.file.AtomicFile
|
||||
with unittest.mock.patch(
|
||||
'supybot.utils.file.AtomicFile', return_value=fd):
|
||||
self.networks.flush()
|
||||
|
||||
lines = fd.getvalue().split('\n')
|
||||
self.assertEqual(lines.pop(0), 'network foonet')
|
||||
self.assertCountEqual(lines, [
|
||||
' stsPolicy foo sts1',
|
||||
' stsPolicy bar sts2',
|
||||
' lastDisconnectTime foo %d' % disconnect_time_foo,
|
||||
' lastDisconnectTime baz %d' % disconnect_time_baz,
|
||||
'',
|
||||
'',
|
||||
])
|
||||
|
||||
def testPreserveThree(self):
|
||||
n = ircdb.IrcNetwork('foonet')
|
||||
n.addStsPolicy('foo', 'sts1')
|
||||
self.networks.setNetwork('foonet', n)
|
||||
|
||||
n = ircdb.IrcNetwork('barnet')
|
||||
n.addStsPolicy('bar', 'sts2')
|
||||
self.networks.setNetwork('barnet', n)
|
||||
|
||||
n = ircdb.IrcNetwork('baznet')
|
||||
n.addStsPolicy('baz', 'sts3')
|
||||
self.networks.setNetwork('baznet', n)
|
||||
|
||||
fd = io.StringIO()
|
||||
fd.close = lambda: None
|
||||
self.networks.filename = 'blah'
|
||||
original_Atomicfile = utils.file.AtomicFile
|
||||
with unittest.mock.patch(
|
||||
'supybot.utils.file.AtomicFile', return_value=fd):
|
||||
self.networks.flush()
|
||||
|
||||
fd.seek(0)
|
||||
self.assertEqual(fd.getvalue(),
|
||||
'network barnet\n'
|
||||
' stsPolicy bar sts2\n'
|
||||
'\n'
|
||||
'network baznet\n'
|
||||
' stsPolicy baz sts3\n'
|
||||
'\n'
|
||||
'network foonet\n'
|
||||
' stsPolicy foo sts1\n'
|
||||
'\n'
|
||||
)
|
||||
|
||||
|
||||
class CheckCapabilityTestCase(IrcdbTestCase):
|
||||
filename = os.path.join(conf.supybot.directories.conf(),
|
||||
'CheckCapabilityTestCase.conf')
|
||||
|
@ -27,14 +27,15 @@
|
||||
# POSSIBILITY OF SUCH DAMAGE.
|
||||
###
|
||||
|
||||
from supybot.test import *
|
||||
|
||||
import copy
|
||||
import pickle
|
||||
import warnings
|
||||
import unittest.mock
|
||||
|
||||
from supybot.test import *
|
||||
|
||||
import supybot.conf as conf
|
||||
import supybot.irclib as irclib
|
||||
import supybot.drivers as drivers
|
||||
import supybot.ircmsgs as ircmsgs
|
||||
import supybot.ircutils as ircutils
|
||||
|
||||
@ -497,6 +498,58 @@ class IrcCapsTestCase(SupyTestCase):
|
||||
self.assertEqual(m.args[0], 'REQ', m)
|
||||
self.assertEqual(m.args[1], 'b'*400)
|
||||
|
||||
class StsTestCase(SupyTestCase):
|
||||
def setUp(self):
|
||||
self.irc = irclib.Irc('test')
|
||||
|
||||
m = self.irc.takeMsg()
|
||||
self.failUnless(m.command == 'CAP', 'Expected CAP, got %r.' % m)
|
||||
self.failUnless(m.args == ('LS', '302'), 'Expected CAP LS 302, got %r.' % m)
|
||||
|
||||
m = self.irc.takeMsg()
|
||||
self.failUnless(m.command == 'NICK', 'Expected NICK, got %r.' % m)
|
||||
|
||||
m = self.irc.takeMsg()
|
||||
self.failUnless(m.command == 'USER', 'Expected USER, got %r.' % m)
|
||||
|
||||
self.irc.driver = unittest.mock.Mock()
|
||||
|
||||
def tearDown(self):
|
||||
ircdb.networks.networks = {}
|
||||
|
||||
def testStsInSecureConnection(self):
|
||||
self.irc.driver.anyCertValidationEnabled.return_value = True
|
||||
self.irc.driver.ssl = True
|
||||
self.irc.driver.server = drivers.Server('irc.test', 6697, False)
|
||||
self.irc.feedMsg(ircmsgs.IrcMsg(command='CAP',
|
||||
args=('*', 'LS', 'sts=duration=42,port=6697')))
|
||||
|
||||
self.assertEqual(ircdb.networks.getNetwork('test').stsPolicies, {
|
||||
'irc.test': 'duration=42,port=6697'})
|
||||
self.irc.driver.reconnect.assert_not_called()
|
||||
|
||||
def testStsInInsecureTlsConnection(self):
|
||||
self.irc.driver.anyCertValidationEnabled.return_value = False
|
||||
self.irc.driver.ssl = True
|
||||
self.irc.driver.server = drivers.Server('irc.test', 6697, False)
|
||||
self.irc.feedMsg(ircmsgs.IrcMsg(command='CAP',
|
||||
args=('*', 'LS', 'sts=duration=42,port=6697')))
|
||||
|
||||
self.assertEqual(ircdb.networks.getNetwork('test').stsPolicies, {})
|
||||
self.irc.driver.reconnect.assert_called_once_with(
|
||||
server=drivers.Server('irc.test', 6697, True))
|
||||
|
||||
def testStsInCleartextConnection(self):
|
||||
self.irc.driver.anyCertValidationEnabled.return_value = False
|
||||
self.irc.driver.ssl = True
|
||||
self.irc.driver.server = drivers.Server('irc.test', 6667, False)
|
||||
self.irc.feedMsg(ircmsgs.IrcMsg(command='CAP',
|
||||
args=('*', 'LS', 'sts=duration=42,port=6697')))
|
||||
|
||||
self.assertEqual(ircdb.networks.getNetwork('test').stsPolicies, {})
|
||||
self.irc.driver.reconnect.assert_called_once_with(
|
||||
server=drivers.Server('irc.test', 6697, True))
|
||||
|
||||
class IrcTestCase(SupyTestCase):
|
||||
def setUp(self):
|
||||
self.irc = irclib.Irc('test')
|
||||
|
Loading…
Reference in New Issue
Block a user