Merge branch 'sts' into testing

This commit is contained in:
Valentin Lorentz 2020-05-02 17:10:13 +02:00
commit fc4cc956ba
12 changed files with 904 additions and 245 deletions

View File

@ -148,14 +148,14 @@ class Owner(callbacks.Plugin):
def _connect(self, network, serverPort=None, password='', ssl=False):
try:
group = conf.supybot.networks.get(network)
(server, port) = group.servers()[0]
group.servers()[0]
except (registry.NonExistentRegistryEntry, IndexError):
if serverPort is None:
raise ValueError('connect requires a (server, port) ' \
'if the network is not registered.')
conf.registerNetwork(network, password, ssl)
serverS = '%s:%s' % serverPort
conf.supybot.networks.get(network).servers.append(serverS)
server = '%s:%s' % serverPort
conf.supybot.networks.get(network).servers.append(server)
assert conf.supybot.networks.get(network).servers(), \
'No servers are set for the %s network.' % network
self.log.debug('Creating new Irc for %s.', network)

View File

@ -273,15 +273,17 @@ class Servers(registry.SpaceSeparatedListOfStrings):
return s
def convert(self, s):
from .drivers import Server
s = self.normalize(s)
(server, port) = s.rsplit(':', 1)
(hostname, port) = s.rsplit(':', 1)
# support for `[ipv6]:port` format
if server.startswith("[") and server.endswith("]"):
server = server[1:-1]
if hostname.startswith("[") and hostname.endswith("]"):
hostname = hostname[1:-1]
port = int(port)
return (server, port)
return Server(hostname, port, force_tls_verification=False)
def __call__(self):
L = registry.SpaceSeparatedListOfStrings.__call__(self)
@ -880,13 +882,11 @@ registerGlobalValue(supybot.drivers, 'poll',
class ValidDriverModule(registry.OnlySomeStrings):
__slots__ = ()
validStrings = ('default', 'Socket', 'Twisted')
validStrings = ('default', 'Socket')
registerGlobalValue(supybot.drivers, 'module',
ValidDriverModule('default', _("""Determines what driver module the
bot will use. The default is Socket which is simple and stable
and supports SSL. Twisted doesn't work if the IRC server which
you are connecting to has IPv6 (most of them do).""")))
bot will use. Current, the only (and default) driver is Socket.""")))
registerGlobalValue(supybot.drivers, 'maxReconnectWait',
registry.PositiveFloat(300.0, _("""Determines the maximum time the bot will
@ -1041,6 +1041,12 @@ registerGlobalValue(supybot.databases.channels, 'filename',
for the channels database. This file will go into the directory specified
by the supybot.directories.conf variable.""")))
registerGroup(supybot.databases, 'networks')
registerGlobalValue(supybot.databases.networks, 'filename',
registry.String('networks.conf', _("""Determines what filename will be used
for the networks database. This file will go into the directory specified
by the supybot.directories.conf variable.""")))
# TODO This will need to do more in the future (such as making sure link.allow
# will let the link occur), but for now let's just leave it as this.
class ChannelSpecific(registry.Boolean):

View File

@ -35,12 +35,12 @@ Contains simple socket drivers. Asyncore bugged (haha, pun!) me.
from __future__ import division
import os
import sys
import time
import errno
import threading
import select
import socket
import sys
try:
import ipaddress # Python >= 3.3 or backported ipaddress
@ -82,9 +82,7 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin):
self.resetDelay()
if self.networkGroup.get('ssl').value and 'ssl' not in globals():
drivers.log.error('The Socket driver can not connect to SSL '
'servers for your Python version. Try the '
'Twisted driver instead, or install a Python'
'version that supports SSL (2.6 and greater).')
'servers for your Python version.')
self.ssl = False
else:
self.ssl = self.networkGroup.get('ssl').value
@ -223,10 +221,11 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin):
def connect(self, **kwargs):
self.reconnect(reset=False, **kwargs)
def reconnect(self, wait=False, reset=True):
def reconnect(self, wait=False, reset=True, server=None):
self._attempt += 1
self.nextReconnectTime = None
if self.connected:
self.onDisconnect()
drivers.log.reconnect(self.irc.network)
if self in self._instances:
self._instances.remove(self)
@ -242,9 +241,12 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin):
else:
drivers.log.debug('Not resetting %s.', self.irc)
if wait:
if server is not None:
# Make this server be the next one to be used.
self.servers.insert(0, server)
self.scheduleReconnect()
return
self.server = self._getNextServer()
self.currentServer = server or self._getNextServer()
network_config = getattr(conf.supybot.networks, self.irc.network)
socks_proxy = network_config.socksproxy()
try:
@ -254,20 +256,20 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin):
log.error('Cannot use socks proxy (SocksiPy not installed), '
'using direct connection instead.')
socks_proxy = ''
if socks_proxy:
address = self.server[0]
else:
try:
address = utils.net.getAddressFromHostname(self.server[0],
attempt=self._attempt)
hostname = utils.net.getAddressFromHostname(
self.currentServer.hostname,
attempt=self._attempt)
except (socket.gaierror, socket.error) as e:
drivers.log.connectError(self.currentServer, e)
self.scheduleReconnect()
return
port = self.server[1]
drivers.log.connect(self.currentServer)
try:
self.conn = utils.net.getSocket(address, port=port,
self.conn = utils.net.getSocket(
self.currentServer.hostname,
port=self.currentServer.port,
socks_proxy=socks_proxy,
vhost=conf.supybot.protocols.irc.vhost(),
vhostv6=conf.supybot.protocols.irc.vhostv6(),
@ -282,17 +284,20 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin):
try:
# Connect before SSL, otherwise SSL is disabled if we use SOCKS.
# See http://stackoverflow.com/q/16136916/539465
self.conn.connect((address, port))
if network_config.ssl():
self.conn.connect(
(self.currentServer.hostname, self.currentServer.port))
if network_config.ssl() or \
self.currentServer.force_tls_verification:
self.starttls()
# Suppress this warning for loopback IPs.
targetip = address
targetip = hostname
if sys.version_info[0] < 3:
# Backported Python 2 ipaddress demands unicode instead of str
targetip = targetip.decode('utf-8')
elif (not network_config.requireStarttls()) and \
(not network_config.ssl()) and \
(not self.currentServer.force_tls_verification) and \
(ipaddress is None or not ipaddress.ip_address(targetip).is_loopback):
drivers.log.warning(('Connection to network %s '
'does not use SSL/TLS, which makes it vulnerable to '
@ -353,6 +358,8 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin):
if self.writeCheckTime is not None:
self.writeCheckTime = None
drivers.log.die(self.irc)
drivers.IrcDriver.die(self)
drivers.ServersMixin.die(self)
def _reallyDie(self):
if self.conn is not None:
@ -363,6 +370,16 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin):
def name(self):
return '%s(%s)' % (self.__class__.__name__, self.irc)
def anyCertValidationEnabled(self):
"""Returns whether any kind of certificate validation is enabled, other
than Server.force_tls_verification."""
network_config = getattr(conf.supybot.networks, self.irc.network)
return any([
conf.supybot.protocols.ssl.verifyCertificates(),
network_config.ssl.serverFingerprints(),
network_config.ssl.authorityCertificate(),
])
def starttls(self):
assert 'ssl' in globals()
network_config = getattr(conf.supybot.networks, self.irc.network)
@ -375,15 +392,21 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin):
drivers.log.warning('Could not find cert file %s.' %
certfile)
certfile = None
verifyCertificates = conf.supybot.protocols.ssl.verifyCertificates()
if not verifyCertificates:
drivers.log.warning('Not checking SSL certificates, connections '
'are vulnerable to man-in-the-middle attacks. Set '
'supybot.protocols.ssl.verifyCertificates to "true" '
'to enable validity checks.')
if self.currentServer.force_tls_verification \
and not self.anyCertValidationEnabled():
verifyCertificates = True
else:
verifyCertificates = conf.supybot.protocols.ssl.verifyCertificates()
if not self.currentServer.force_tls_verification \
and not self.anyCertValidationEnabled():
drivers.log.warning('Not checking SSL certificates, connections '
'are vulnerable to man-in-the-middle attacks. Set '
'supybot.protocols.ssl.verifyCertificates to "true" '
'to enable validity checks.')
try:
self.conn = utils.net.ssl_wrap_socket(self.conn,
logger=drivers.log, hostname=self.server[0],
logger=drivers.log,
hostname=self.currentServer.hostname,
certfile=certfile,
verify=verifyCertificates,
trusted_fingerprints=network_config.ssl.serverFingerprints(),

View File

@ -1,160 +0,0 @@
###
# Copyright (c) 2002-2004, Jeremiah Fincher
# Copyright (c) 2009, James McCoy
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice,
# this list of conditions, and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions, and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of the author of this software nor the name of
# contributors to this software may be used to endorse or promote products
# derived from this software without specific prior written consent.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
###
from .. import conf, drivers
from twisted.names import client
from twisted.internet import reactor, error
from twisted.protocols.basic import LineReceiver
from twisted.internet.protocol import ReconnectingClientFactory
# This hack prevents the standard Twisted resolver from starting any
# threads, which allows for a clean shut-down in Twisted>=2.0
reactor.installResolver(client.createResolver())
try:
from OpenSSL import SSL
from twisted.internet import ssl
except ImportError:
drivers.log.debug('PyOpenSSL is not available, '
'cannot connect to SSL servers.')
SSL = None
class TwistedRunnerDriver(drivers.IrcDriver):
def name(self):
return self.__class__.__name__
def run(self):
try:
reactor.iterate(conf.supybot.drivers.poll())
except:
drivers.log.exception('Uncaught exception outside reactor:')
class SupyIrcProtocol(LineReceiver):
delimiter = '\n'
MAX_LENGTH = 1024
def __init__(self):
self.mostRecentCall = reactor.callLater(0.1, self.checkIrcForMsgs)
def lineReceived(self, line):
msg = drivers.parseMsg(line)
if msg is not None:
self.irc.feedMsg(msg)
def checkIrcForMsgs(self):
if self.connected:
msg = self.irc.takeMsg()
while msg:
self.transport.write(str(msg))
msg = self.irc.takeMsg()
self.mostRecentCall = reactor.callLater(0.1, self.checkIrcForMsgs)
def connectionLost(self, r):
self.mostRecentCall.cancel()
if r.check(error.ConnectionDone):
drivers.log.disconnect(self.factory.currentServer)
else:
drivers.log.disconnect(self.factory.currentServer, errorMsg(r))
if self.irc.zombie:
self.factory.stopTrying()
while self.irc.takeMsg():
continue
else:
self.irc.reset()
def connectionMade(self):
self.factory.resetDelay()
self.irc.driver = self
def die(self):
drivers.log.die(self.irc)
self.factory.stopTrying()
self.transport.loseConnection()
def reconnect(self, wait=None):
# We ignore wait here, because we handled our own waiting.
drivers.log.reconnect(self.irc.network)
self.transport.loseConnection()
def errorMsg(reason):
return reason.getErrorMessage()
class SupyReconnectingFactory(ReconnectingClientFactory, drivers.ServersMixin):
maxDelay = property(lambda self: conf.supybot.drivers.maxReconnectWait())
protocol = SupyIrcProtocol
def __init__(self, irc):
drivers.log.warning('Twisted driver is deprecated. You should '
'consider switching to Socket (set '
'supybot.drivers.module to Socket).')
self.irc = irc
drivers.ServersMixin.__init__(self, irc)
(server, port) = self._getNextServer()
vhost = conf.supybot.protocols.irc.vhost()
if self.networkGroup.get('ssl').value:
self.connectSSL(server, port, vhost)
else:
self.connectTCP(server, port, vhost)
def connectTCP(self, server, port, vhost):
"""Connect to the server with a standard TCP connection."""
reactor.connectTCP(server, port, self, bindAddress=(vhost, 0))
def connectSSL(self, server, port, vhost):
"""Connect to the server using an SSL socket."""
drivers.log.info('Attempting an SSL connection.')
if SSL:
reactor.connectSSL(server, port, self,
ssl.ClientContextFactory(), bindAddress=(vhost, 0))
else:
drivers.log.error('PyOpenSSL is not available. Not connecting.')
def clientConnectionFailed(self, connector, r):
drivers.log.connectError(self.currentServer, errorMsg(r))
(connector.host, connector.port) = self._getNextServer()
ReconnectingClientFactory.clientConnectionFailed(self, connector,r)
def clientConnectionLost(self, connector, r):
(connector.host, connector.port) = self._getNextServer()
ReconnectingClientFactory.clientConnectionLost(self, connector, r)
def startedConnecting(self, connector):
drivers.log.connect(self.currentServer)
def buildProtocol(self, addr):
protocol = ReconnectingClientFactory.buildProtocol(self, addr)
protocol.irc = self.irc
return protocol
Driver = SupyReconnectingFactory
poller = TwistedRunnerDriver()
# vim:set shiftwidth=4 softtabstop=4 expandtab textwidth=79:

View File

@ -32,11 +32,21 @@
Contains various drivers (network, file, and otherwise) for using IRC objects.
"""
import time
import socket
from collections import namedtuple
from .. import conf, ircmsgs, log as supylog, utils
from .. import conf, ircdb, ircmsgs, ircutils, log as supylog, utils
from ..utils import minisix
Server = namedtuple('Server', 'hostname port force_tls_verification')
# force_tls_verification=True implies two things:
# 1. force TLS to be enabled for this server
# 2. ensure there is some kind of verification. If the user did not enable
# any, use standard PKI validation.
_drivers = {}
_deadDrivers = set()
_newDrivers = []
@ -64,6 +74,7 @@ class IrcDriver(object):
class ServersMixin(object):
def __init__(self, irc, servers=()):
self.networkName = irc.network
self.networkGroup = conf.supybot.networks.get(irc.network)
self.servers = servers
super(ServersMixin, self).__init__()
@ -80,8 +91,42 @@ class ServersMixin(object):
assert self.servers, 'Servers value for %s is empty.' % \
self.networkGroup._name
server = self.servers.pop(0)
self.currentServer = '%s:%s' % server
return server
self.currentServer = self._applyStsPolicy(server)
return self.currentServer
def _applyStsPolicy(self, server):
network = ircdb.networks.getNetwork(self.networkName)
policy = network.stsPolicies.get(server.hostname)
lastDisconnect = network.lastDisconnectTimes.get(server.hostname)
if policy is None or lastDisconnect is None:
log.debug('No STS policy, or never disconnected from this server. %r %r',
policy, lastDisconnect)
return server
# The policy was stored, which means it was received on a secure
# connection.
policy = ircutils.parseStsPolicy(log, policy, parseDuration=True)
if lastDisconnect + policy['duration'] < time.time():
log.info('STS policy expired, removing.')
network.expireStsPolicy(server.hostname)
return server
log.info('Using STS policy: changing port from %s to %s.',
server.port, policy['port'])
# Change the port, and force TLS verification, as required by the STS
# specification.
return Server(server.hostname, policy['port'],
force_tls_verification=True)
def die(self):
self.onDisconnect()
def onDisconnect(self):
network = ircdb.networks.getNetwork(self.networkName)
network.addDisconnection(self.currentServer.hostname)
def empty():
@ -129,7 +174,8 @@ def run():
class Log(object):
"""This is used to have a nice, consistent interface for drivers to use."""
def connect(self, server):
self.info('Connecting to %s.', server)
self.info('Connecting to %s:%s.',
server.hostname, server.port)
def connectError(self, server, e):
if isinstance(e, Exception):
@ -137,7 +183,8 @@ class Log(object):
e = e.args[1]
else:
e = utils.exnToString(e)
self.warning('Error connecting to %s: %s', server, e)
self.warning('Error connecting to %s:%s: %s',
server.hostname, server.port, e)
def disconnect(self, server, e=None):
if e:
@ -147,7 +194,8 @@ class Log(object):
e = str(e)
if not e.endswith('.'):
e += '.'
self.warning('Disconnect from %s: %s', server, e)
self.warning('Disconnect from %s:%s: %s',
server.hostname, server.port, e)
else:
self.info('Disconnect from %s.', server)

View File

@ -496,6 +496,47 @@ class IrcChannel(object):
fd.write(os.linesep)
class IrcNetwork(object):
"""This class holds dynamic information about a network that should be
preserved across restarts."""
__slots__ = ('stsPolicies', 'lastDisconnectTimes')
def __init__(self, stsPolicies=None, lastDisconnectTimes=None):
self.stsPolicies = stsPolicies or {}
self.lastDisconnectTimes = lastDisconnectTimes or {}
def __repr__(self):
return '%s(stsPolicies=%r, lastDisconnectTimes=%s)' % \
(self.__class__.__name__, self.stsPolicies,
self.lastDisconnectTimes)
def addStsPolicy(self, server, stsPolicy):
assert isinstance(stsPolicy, str)
self.stsPolicies[server] = stsPolicy
def expireStsPolicy(self, server):
if server in self.stsPolicies:
del self.stsPolicies[server]
def addDisconnection(self, server):
self.lastDisconnectTimes[server] = int(time.time())
def preserve(self, fd, indent=''):
def write(s):
fd.write(indent)
fd.write(s)
fd.write(os.linesep)
for (server, stsPolicy) in sorted(self.stsPolicies.items()):
write('stsPolicy %s %s' % (server, stsPolicy))
for (server, disconnectTime) in \
sorted(self.lastDisconnectTimes.items()):
write('lastDisconnectTime %s %s' % (server, disconnectTime))
fd.write(os.linesep)
class Creator(object):
__slots__ = ()
def badCommand(self, command, rest, lineno):
@ -615,6 +656,32 @@ class IrcChannelCreator(Creator):
IrcChannelCreator.name = None
class IrcNetworkCreator(Creator):
__slots__ = ('net', 'networks')
name = None
def __init__(self, networks):
self.net = IrcNetwork()
self.networks = networks
def network(self, rest, lineno):
IrcNetworkCreator.name = rest
def stspolicy(self, rest, lineno):
(server, stsPolicy) = rest.split()
self.net.addStsPolicy(server, stsPolicy)
def lastdisconnecttime(self, rest, lineno):
(server, when) = rest.split()
when = int(when)
self.net.lastDisconnectTimes[server] = when
def finish(self):
if self.name:
self.networks.setNetwork(self.name, self.net)
self.net = IrcNetwork()
class DuplicateHostmask(ValueError):
pass
@ -666,10 +733,8 @@ class UsersDictionary(utils.IterableMap):
"""Flushes the database to its file."""
if not self.noFlush:
if self.filename is not None:
L = list(self.users.items())
L.sort()
fd = utils.file.AtomicFile(self.filename)
for (id, u) in L:
for (id, u) in sorted(self.users.items()):
fd.write('user %s' % id)
fd.write(os.linesep)
u.preserve(fd, indent=' ')
@ -861,7 +926,7 @@ class ChannelsDictionary(utils.IterableMap):
if not self.noFlush:
if self.filename is not None:
fd = utils.file.AtomicFile(self.filename)
for (channel, c) in self.channels.items():
for (channel, c) in sorted(self.channels.items()):
fd.write('channel %s' % channel)
fd.write(os.linesep)
c.preserve(fd, indent=' ')
@ -907,6 +972,83 @@ class ChannelsDictionary(utils.IterableMap):
def items(self):
return self.channels.items()
class NetworksDictionary(utils.IterableMap):
__slots__ = ('noFlush', 'filename', 'networks')
def __init__(self):
self.noFlush = False
self.filename = None
self.networks = ircutils.IrcDict()
def open(self, filename):
self.noFlush = True
try:
self.filename = filename
reader = unpreserve.Reader(IrcNetworkCreator, self)
try:
reader.readFile(filename)
self.noFlush = False
self.flush()
except EnvironmentError as e:
log.error('Invalid network database, resetting to empty.')
log.error('Exact error: %s', utils.exnToString(e))
except Exception as e:
log.error('Invalid network database, resetting to empty.')
log.exception('Exact error:')
finally:
self.noFlush = False
def flush(self):
"""Flushes the network database to its file."""
if not self.noFlush:
if self.filename is not None:
fd = utils.file.AtomicFile(self.filename)
for (network, net) in sorted(self.networks.items()):
fd.write('network %s' % network)
fd.write(os.linesep)
net.preserve(fd, indent=' ')
fd.close()
else:
log.warning('NetworksDictionary.flush without self.filename.')
else:
log.debug('Not flushing NetworksDictionary because of noFlush.')
def close(self):
self.flush()
if self.flush in world.flushers:
world.flushers.remove(self.flush)
self.networks.clear()
def reload(self):
"""Reloads the network database from its file."""
if self.filename is not None:
self.networks.clear()
try:
self.open(self.filename)
except EnvironmentError as e:
log.warning('NetworksDictionary.reload failed: %s', e)
else:
log.warning('NetworksDictionary.reload without self.filename.')
def getNetwork(self, network):
"""Returns an IrcNetwork object for the given network."""
network = network.lower()
if network in self.networks:
return self.networks[network]
else:
c = IrcNetwork()
self.networks[network] = c
return c
def setNetwork(self, network, ircNetwork):
"""Sets a given network to the IrcNetwork object given."""
network = network.lower()
self.networks[network] = ircNetwork
self.flush()
def items(self):
return self.networks.items()
class IgnoresDB(object):
__slots__ = ('filename', 'hostmasks')
@ -996,6 +1138,14 @@ try:
except EnvironmentError as e:
log.warning('Couldn\'t open channel database: %s', e)
try:
networkFile = os.path.join(confDir,
conf.supybot.databases.networks.filename())
networks = NetworksDictionary()
networks.open(networkFile)
except EnvironmentError as e:
log.warning('Couldn\'t open network database: %s', e)
try:
ignoreFile = os.path.join(confDir,
conf.supybot.databases.ignores.filename())
@ -1006,8 +1156,9 @@ except EnvironmentError as e:
world.flushers.append(users.flush)
world.flushers.append(ignores.flush)
world.flushers.append(channels.flush)
world.flushers.append(networks.flush)
world.flushers.append(ignores.flush)
###

View File

@ -30,6 +30,7 @@
import re
import copy
import time
import enum
import random
import base64
import textwrap
@ -54,6 +55,7 @@ except ImportError:
scram = None
from . import conf, ircdb, ircmsgs, ircutils, log, utils, world
from .drivers import Server
from .utils.str import rsplit
from .utils.iter import chain
from .utils.structures import smallqueue, RingBuffer
@ -120,6 +122,7 @@ class IrcCallback(IrcCommandDispatcher, log.Firewalled):
'__call__': None,
'inFilter': lambda self, irc, msg: msg,
'outFilter': lambda self, irc, msg: msg,
'postTransition': None,
'name': lambda self: self.__class__.__name__,
'callPrecedence': lambda self, irc: ([], []),
}
@ -170,6 +173,12 @@ class IrcCallback(IrcCommandDispatcher, log.Firewalled):
"""
return msg
def postTransition(self, irc, msg, from_state, to_state):
"""Called when the state of the IRC connection changes.
`msg` is the message that triggered the transition, if any."""
pass
def __call__(self, irc, msg):
"""Used for handling each message."""
method = self.dispatchCommand(msg.command, msg.args)
@ -389,14 +398,130 @@ class ChannelState(utils.python.Object):
Batch = collections.namedtuple('Batch', 'type arguments messages')
class IrcStateFsm(object):
'''Finite State Machine keeping track of what part of the connection
initialization we are in.'''
__slots__ = ('state',)
@enum.unique
class States(enum.Enum):
UNINITIALIZED = 10
'''Nothing received yet (except server notices)'''
INIT_CAP_NEGOTIATION = 20
'''Sent CAP LS, did not send CAP END yet'''
INIT_SASL = 30
'''In an AUTHENTICATE session'''
INIT_WAITING_MOTD = 50
'''Waiting for start of MOTD'''
INIT_MOTD = 60
'''Waiting for end of MOTD'''
CONNECTED = 70
'''Normal state of the connections'''
CONNECTED_SASL = 80
'''Doing SASL authentication in the middle of a connection.'''
SHUTTING_DOWN = 100
def __init__(self):
self.reset()
def reset(self):
if getattr(self, 'state', None) is not None:
log.debug('resetting from %s to %s',
self.state, self.States.UNINITIALIZED)
self.state = self.States.UNINITIALIZED
def _transition(self, irc, msg, to_state, expected_from=None):
"""Transitions to state `to_state`.
If `expected_from` is not `None`, first checks the current state is
in the set.
After the transition, calls the
`postTransition(irc, msg, from_state, to_state)` method of all objects
in `irc.callbacks`.
`msg` may be None if the transition isn't triggered by a message, but
`irc` may not."""
from_state = self.state
if expected_from is None or from_state in expected_from:
log.debug('transition from %s to %s', self.state, to_state)
self.state = to_state
for callback in reversed(irc.callbacks):
msg = callback.postTransition(irc, msg, from_state, to_state)
else:
raise ValueError('unexpected transition to %s while in state %s' %
(to_state, self.state))
def expect_state(self, expected_states):
if self.state not in expected_states:
raise ValueError(('Connection in state %s, but expected to be '
'in state %s') % (self.state, expected_states))
def on_init_messages_sent(self, irc):
'''As soon as USER/NICK/CAP LS are sent'''
self._transition(irc, None, self.States.INIT_CAP_NEGOTIATION, [
self.States.UNINITIALIZED,
])
def on_sasl_cap(self, irc, msg):
'''Whenever we see the 'sasl' capability in a CAP LS response'''
if self.state == self.States.INIT_CAP_NEGOTIATION:
self._transition(irc, msg, self.States.INIT_SASL)
elif self.state == self.States.CONNECTED:
self._transition(irc, msg, self.States.CONNECTED_SASL)
else:
raise ValueError('Got sasl cap while in state %s' % self.state)
def on_sasl_auth_finished(self, irc, msg):
'''When sasl auth either succeeded or failed.'''
if self.state == self.States.INIT_SASL:
self._transition(irc, msg, self.States.INIT_CAP_NEGOTIATION)
elif self.state == self.States.CONNECTED_SASL:
self._transition(irc, msg, self.States.CONNECTED)
else:
raise ValueError('Finished SASL auth while in state %s' % self.state)
def on_cap_end(self, irc, msg):
'''When we send CAP END'''
self._transition(irc, msg, self.States.INIT_WAITING_MOTD, [
self.States.INIT_CAP_NEGOTIATION,
])
def on_start_motd(self, irc, msg):
'''On 375 (RPL_MOTDSTART)'''
self._transition(irc, msg, self.States.INIT_MOTD, [
self.States.INIT_CAP_NEGOTIATION,
self.States.INIT_WAITING_MOTD,
])
def on_end_motd(self, irc, msg):
'''On 376 (RPL_ENDOFMOTD) or 422 (ERR_NOMOTD)'''
self._transition(irc, msg, self.States.CONNECTED, [
self.States.INIT_CAP_NEGOTIATION,
self.States.INIT_WAITING_MOTD,
self.States.INIT_MOTD
])
def on_shutdown(self, irc, msg):
self._transition(irc, msg, self.States.SHUTTING_DOWN)
class IrcState(IrcCommandDispatcher, log.Firewalled):
"""Maintains state of the Irc connection. Should also become smarter.
"""
__firewalled__ = {'addMsg': None}
def __init__(self, history=None, supported=None,
nicksToHostmasks=None, channels=None,
capabilities_req=None,
capabilities_ack=None, capabilities_nak=None,
capabilities_ls=None):
self.fsm = IrcStateFsm()
if history is None:
history = RingBuffer(conf.supybot.protocols.irc.maxHistoryLength())
if supported is None:
@ -405,6 +530,7 @@ class IrcState(IrcCommandDispatcher, log.Firewalled):
nicksToHostmasks = ircutils.IrcDict()
if channels is None:
channels = ircutils.IrcDict()
self.capabilities_req = capabilities_req or set()
self.capabilities_ack = capabilities_ack or set()
self.capabilities_nak = capabilities_nak or set()
self.capabilities_ls = capabilities_ls or {}
@ -417,6 +543,7 @@ class IrcState(IrcCommandDispatcher, log.Firewalled):
def reset(self):
"""Resets the state to normal, unconnected state."""
self.fsm.reset()
self.history.reset()
self.history.resize(conf.supybot.protocols.irc.maxHistoryLength())
self.ircd = None
@ -424,6 +551,7 @@ class IrcState(IrcCommandDispatcher, log.Firewalled):
self.supported.clear()
self.nicksToHostmasks.clear()
self.batches = {}
self.capabilities_req = set()
self.capabilities_ack = set()
self.capabilities_nak = set()
self.capabilities_ls = {}
@ -1115,6 +1243,8 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
self.sendAuthenticationMessages()
self.state.fsm.on_init_messages_sent(self)
def sendAuthenticationMessages(self):
# Notes:
# * using sendMsg instead of queueMsg because these messages cannot
@ -1135,17 +1265,60 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
self.sendMsg(ircmsgs.user(self.ident, self.user))
def endCapabilityNegociation(self):
if not self.capNegociationEnded:
self.capNegociationEnded = True
self.sendMsg(ircmsgs.IrcMsg(command='CAP', args=('END',)))
def capUpkeep(self, msg):
"""
Called after getting a CAP ACK/NAK to check it's consistent with what
was requested, and to end the cap negotiation when we received all the
ACK/NAKs we were waiting for.
`msg` is the message that triggered this call."""
self.state.fsm.expect_state([
# Normal CAP ACK / CAP NAK during cap negotiation
IrcStateFsm.States.INIT_CAP_NEGOTIATION,
# CAP ACK / CAP NAK after a CAP NEW (probably)
IrcStateFsm.States.CONNECTED,
])
capabilities_responded = (self.state.capabilities_ack |
self.state.capabilities_nak)
if not capabilities_responded <= self.state.capabilities_req:
log.error('Server responded with unrequested ACK/NAK '
'capabilities: req=%r, ack=%r, nak=%r',
self.state.capabilities_req,
self.state.capabilities_ack,
self.state.capabilities_nak)
self.driver.reconnect(wait=True)
elif capabilities_responded == self.state.capabilities_req:
log.debug('Got all capabilities ACKed/NAKed')
# We got all the capabilities we asked for
if 'sasl' in self.state.capabilities_ack:
if self.state.fsm.state in [
IrcStateFsm.States.INIT_CAP_NEGOTIATION,
IrcStateFsm.States.CONNECTED]:
self._maybeStartSasl(msg)
else:
pass # Already in the middle of a SASL auth
else:
self.endCapabilityNegociation(msg)
else:
log.debug('Waiting for ACK/NAK of capabilities: %r',
self.state.capabilities_req - capabilities_responded)
pass # Do nothing, we'll get more
def endCapabilityNegociation(self, msg):
self.state.fsm.on_cap_end(self, msg)
self.sendMsg(ircmsgs.IrcMsg(command='CAP', args=('END',)))
def sendSaslString(self, string):
for chunk in ircutils.authenticate_generator(string):
self.sendMsg(ircmsgs.IrcMsg(command='AUTHENTICATE',
args=(chunk,)))
def tryNextSaslMechanism(self):
def tryNextSaslMechanism(self, msg):
self.state.fsm.expect_state([
IrcStateFsm.States.INIT_SASL,
IrcStateFsm.States.CONNECTED_SASL,
])
if self.sasl_next_mechanisms:
self.sasl_current_mechanism = self.sasl_next_mechanisms.pop(0)
self.sendMsg(ircmsgs.IrcMsg(command='AUTHENTICATE',
@ -1155,15 +1328,30 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
'aborting connection.')
else:
self.sasl_current_mechanism = None
self.endCapabilityNegociation()
self.state.fsm.on_sasl_auth_finished(self, msg)
if self.state.fsm.state == IrcStateFsm.States.INIT_CAP_NEGOTIATION:
self.endCapabilityNegociation(msg)
def filterSaslMechanisms(self, available):
available = set(map(str.lower, available))
self.sasl_next_mechanisms = [
x for x in self.sasl_next_mechanisms
if x.lower() in available]
def _maybeStartSasl(self, msg):
if not self.sasl_authenticated and \
'sasl' in self.state.capabilities_ack:
self.state.fsm.on_sasl_cap(self, msg)
assert 'sasl' in self.state.capabilities_ls, (
'Got "CAP ACK sasl" without receiving "CAP LS sasl" or '
'"CAP NEW sasl" first.')
s = self.state.capabilities_ls['sasl']
if s is not None:
available = set(map(str.lower, s.split(',')))
self.sasl_next_mechanisms = [
x for x in self.sasl_next_mechanisms
if x.lower() in available]
self.tryNextSaslMechanism(msg)
def doAuthenticate(self, msg):
self.state.fsm.expect_state([
IrcStateFsm.States.INIT_SASL,
IrcStateFsm.States.CONNECTED_SASL,
])
if not self.authenticate_decoder:
self.authenticate_decoder = ircutils.AuthenticateDecoder()
self.authenticate_decoder.feed(msg)
@ -1265,26 +1453,28 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
def do903(self, msg):
log.info('%s: SASL authentication successful', self.network)
self.sasl_authenticated = True
self.endCapabilityNegociation()
self.state.fsm.on_sasl_auth_finished(self, msg)
if self.state.fsm.state == IrcStateFsm.States.INIT_CAP_NEGOTIATION:
self.endCapabilityNegociation(msg)
def do904(self, msg):
log.warning('%s: SASL authentication failed (mechanism: %s)',
self.network, self.sasl_current_mechanism)
self.tryNextSaslMechanism()
self.tryNextSaslMechanism(msg)
def do905(self, msg):
log.warning('%s: SASL authentication failed because the username or '
'password is too long.', self.network)
self.tryNextSaslMechanism()
self.tryNextSaslMechanism(msg)
def do906(self, msg):
log.warning('%s: SASL authentication aborted', self.network)
self.tryNextSaslMechanism()
self.tryNextSaslMechanism(msg)
def do907(self, msg):
log.warning('%s: Attempted SASL authentication when we were already '
'authenticated.', self.network)
self.tryNextSaslMechanism()
self.tryNextSaslMechanism(msg)
def do908(self, msg):
log.info('%s: Supported SASL mechanisms: %s',
@ -1301,10 +1491,8 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
self.network, caps)
self.state.capabilities_ack.update(caps)
if 'sasl' in caps:
self.tryNextSaslMechanism()
else:
self.endCapabilityNegociation()
self.capUpkeep(msg)
def doCapNak(self, msg):
if len(msg.args) != 3:
log.warning('Bad CAP NAK from server: %r', msg)
@ -1314,31 +1502,76 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
self.state.capabilities_nak.update(caps)
log.warning('%s: Server refused capabilities: %L',
self.network, caps)
self.endCapabilityNegociation()
def _addCapabilities(self, capstring):
self.capUpkeep(msg)
def _onCapSts(self, policy, msg):
secure_connection = self.driver.currentServer.force_tls_verification \
or (self.driver.ssl and self.driver.anyCertValidationEnabled())
parsed_policy = ircutils.parseStsPolicy(
log, policy, parseDuration=secure_connection)
if parsed_policy is None:
# There was an error (and it was logged). Abort the connection.
self.driver.reconnect(wait=True)
return
if secure_connection:
# TLS is enabled and certificate is verified; write the STS policy
# in stone.
# For future-proofing (because we don't want to write an invalid
# value), we write the raw policy received from the server instead
# of the parsed one.
log.debug('Storing STS policy: %s', policy)
ircdb.networks.getNetwork(self.network).addStsPolicy(
self.driver.currentServer.hostname, policy)
else:
hostname = self.driver.currentServer.hostname
log.info('Got STS policy over insecure connection; '
'reconnecting to secure port. %r',
self.driver.currentServer)
# Reconnect to the server, but with TLS *and* certificate
# validation this time.
self.state.fsm.on_shutdown(self, msg)
self.driver.reconnect(
server=Server(hostname, parsed_policy['port'], True),
wait=True)
def _addCapabilities(self, capstring, msg):
for item in capstring.split():
while item.startswith(('=', '~')):
item = item[1:]
if '=' in item:
(cap, value) = item.split('=', 1)
if cap == 'sts':
self._onCapSts(value, msg)
self.state.capabilities_ls[cap] = value
else:
if item == 'sts':
log.error('Got "sts" capability without value. Aborting '
'connection.')
self.driver.reconnect(wait=True)
self.state.capabilities_ls[item] = None
def doCapLs(self, msg):
if len(msg.args) == 4:
# Multi-line LS
if msg.args[2] != '*':
log.warning('Bad CAP LS from server: %r', msg)
return
self._addCapabilities(msg.args[3])
self._addCapabilities(msg.args[3], msg)
elif len(msg.args) == 3: # End of LS
self._addCapabilities(msg.args[2])
if 'sasl' in self.state.capabilities_ls:
s = self.state.capabilities_ls['sasl']
if s is not None:
self.filterSaslMechanisms(set(s.split(',')))
self._addCapabilities(msg.args[2], msg)
if self.state.fsm.state == IrcStateFsm.States.SHUTTING_DOWN:
return
self.state.fsm.expect_state([
# Normal case:
IrcStateFsm.States.INIT_CAP_NEGOTIATION,
# Should only happen if a plugin sends a CAP LS (which they
# shouldn't do):
IrcStateFsm.States.CONNECTED,
IrcStateFsm.States.CONNECTED_SASL,
])
# Normally at this point, self.state.capabilities_ack should be
# empty; but let's just make sure we're not requesting the same
# caps twice for no reason.
@ -1352,10 +1585,11 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
if new_caps:
self._requestCaps(new_caps)
else:
self.endCapabilityNegociation()
self.endCapabilityNegociation(msg)
else:
log.warning('Bad CAP LS from server: %r', msg)
return
def doCapDel(self, msg):
if len(msg.args) != 3:
log.warning('Bad CAP DEL from server: %r', msg)
@ -1374,18 +1608,18 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
self.state.capabilities_ack.remove(cap)
except KeyError:
pass
def doCapNew(self, msg):
# Note that in theory, this method may be called at any time, even
# before CAP END (or even before the initial CAP LS).
if len(msg.args) != 3:
log.warning('Bad CAP NEW from server: %r', msg)
return
caps = msg.args[2].split()
assert caps, 'Empty list of capabilities'
self._addCapabilities(msg.args[2])
if not self.sasl_authenticated and 'sasl' in self.state.capabilities_ls:
self.resetSasl()
s = self.state.capabilities_ls['sasl']
if s is not None:
self.filterSaslMechanisms(set(s.split(',')))
self._addCapabilities(msg.args[2], msg)
if self.state.fsm.state == IrcStateFsm.States.SHUTTING_DOWN:
return
common_supported_unrequested_capabilities = (
set(self.state.capabilities_ls) &
self.REQUEST_CAPABILITIES -
@ -1394,6 +1628,8 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
self._requestCaps(common_supported_unrequested_capabilities)
def _requestCaps(self, caps):
self.state.capabilities_req |= caps
caps = ' '.join(sorted(caps))
# textwrap works here because in ASCII, all chars are 1 bytes:
cap_lines = textwrap.wrap(caps, MAX_LINE_SIZE-len('CAP REQ :'))
@ -1473,7 +1709,12 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
"""Handles PONG messages."""
self.outstandingPing = False
def do375(self, msg):
self.state.fsm.on_start_motd(self, msg)
log.info('Got start of MOTD from %s', self.server)
def do376(self, msg):
self.state.fsm.on_end_motd(self, msg)
log.info('Got end of MOTD from %s', self.server)
self.afterConnect = True
# Let's reset nicks in case we had to use a weird one.

View File

@ -931,6 +931,35 @@ class AuthenticateDecoder(object):
return base64.b64decode(b''.join(self.chunks))
def parseStsPolicy(logger, policy, parseDuration):
parsed_policy = {}
for kv in policy.split(','):
if '=' in kv:
(k, v) = kv.split('=', 1)
parsed_policy[k] = v
else:
parsed_policy[kv] = None
for key in ('port', 'duration'):
if key == 'duration' and not parseDuration:
if key in parsed_policy:
del parsed_policy[key]
continue
if parsed_policy.get(key) is None:
logger.error('Missing or empty "%s" key in STS policy.'
'Aborting connection.', key)
return None
try:
parsed_policy[key] = int(parsed_policy[key])
except ValueError:
logger.error('Expected integer as value for key "%s" in STS '
'policy, got %r instead. Aborting connection.',
key, parsed_policy[key])
return None
return parsed_policy
numerics = {
# <= 2.10
# Reply

View File

@ -42,7 +42,7 @@ import multiprocessing
import re
from . import conf, drivers, ircutils, log, registry
from . import conf, ircutils, log, registry
from .utils import minisix
startedAt = time.time() # Just in case it doesn't get set later.
@ -193,6 +193,7 @@ def upkeep():
def makeDriversDie():
"""Kills drivers."""
from . import drivers
log.info('Killing Driver objects.')
for driver in drivers._drivers.values():
driver.die()

108
test/test_drivers.py Normal file
View File

@ -0,0 +1,108 @@
##
# Copyright (c) 2019, Valentin Lorentz
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice,
# this list of conditions, and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions, and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of the author of this software nor the name of
# contributors to this software may be used to endorse or promote products
# derived from this software without specific prior written consent.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
###
from supybot.test import *
import supybot.ircdb as ircdb
import supybot.irclib as irclib
import supybot.drivers as drivers
class DriversTestCase(SupyTestCase):
def tearDown(self):
ircdb.networks.networks = {}
def testValidStsPolicy(self):
irc = irclib.Irc('test')
net = ircdb.networks.getNetwork('test')
net.addStsPolicy('example.com', 'duration=10,port=6697')
net.addDisconnection('example.com')
with conf.supybot.networks.test.servers.context(
['example.com:6667', 'example.org:6667']):
driver = drivers.ServersMixin(irc)
self.assertEqual(
driver._getNextServer(),
drivers.Server('example.com', 6697, True))
driver.die()
self.assertEqual(
driver._getNextServer(),
drivers.Server('example.org', 6667, False))
driver.die()
self.assertEqual(
driver._getNextServer(),
drivers.Server('example.com', 6697, True))
def testExpiredStsPolicy(self):
irc = irclib.Irc('test')
net = ircdb.networks.getNetwork('test')
net.addStsPolicy('example.com', 'duration=10,port=6697')
net.addDisconnection('example.com')
timeFastForward(16)
with conf.supybot.networks.test.servers.context(
['example.com:6667']):
driver = drivers.ServersMixin(irc)
self.assertEqual(
driver._getNextServer(),
drivers.Server('example.com', 6667, False))
def testRescheduledStsPolicy(self):
irc = irclib.Irc('test')
net = ircdb.networks.getNetwork('test')
net.addStsPolicy('example.com', 'duration=10,port=6697')
net.addDisconnection('example.com')
with conf.supybot.networks.test.servers.context(
['example.com:6667', 'example.org:6667']):
driver = drivers.ServersMixin(irc)
timeFastForward(8)
self.assertEqual(
driver._getNextServer(),
drivers.Server('example.com', 6697, True))
driver.die()
self.assertEqual(
driver._getNextServer(),
drivers.Server('example.org', 6667, False))
driver.die()
timeFastForward(8)
self.assertEqual(
driver._getNextServer(),
drivers.Server('example.com', 6697, True))

View File

@ -31,11 +31,13 @@ from supybot.test import *
import os
import unittest
import unittest.mock
import supybot.conf as conf
import supybot.world as world
import supybot.ircdb as ircdb
import supybot.ircutils as ircutils
from supybot.utils.minisix import io
class IrcdbTestCase(SupyTestCase):
def setUp(self):
@ -347,6 +349,49 @@ class IrcChannelTestCase(IrcdbTestCase):
c.removeBan(banmask)
self.assertFalse(c.checkIgnored(prefix))
class IrcNetworkTestCase(IrcdbTestCase):
def testDefaults(self):
n = ircdb.IrcNetwork()
self.assertEqual(n.stsPolicies, {})
self.assertEqual(n.lastDisconnectTimes, {})
def testStsPolicy(self):
n = ircdb.IrcNetwork()
n.addStsPolicy('foo', 'bar')
n.addStsPolicy('baz', 'qux')
self.assertEqual(n.stsPolicies, {
'foo': 'bar',
'baz': 'qux',
})
def testAddDisconnection(self):
n = ircdb.IrcNetwork()
min_ = int(time.time())
n.addDisconnection('foo')
max_ = int(time.time())
self.assertTrue(min_ <= n.lastDisconnectTimes['foo'] <= max_)
def testPreserve(self):
n = ircdb.IrcNetwork()
n.addStsPolicy('foo', 'sts1')
n.addStsPolicy('bar', 'sts2')
n.addDisconnection('foo')
n.addDisconnection('baz')
disconnect_time_foo = n.lastDisconnectTimes['foo']
disconnect_time_baz = n.lastDisconnectTimes['baz']
fd = io.StringIO()
n.preserve(fd, indent=' ')
fd.seek(0)
self.assertCountEqual(fd.read().split('\n'), [
' stsPolicy foo sts1',
' stsPolicy bar sts2',
' lastDisconnectTime foo %d' % disconnect_time_foo,
' lastDisconnectTime baz %d' % disconnect_time_baz,
'',
'',
])
class UsersDictionaryTestCase(IrcdbTestCase):
filename = os.path.join(conf.supybot.directories.conf(),
'UsersDictionaryTestCase.conf')
@ -401,6 +446,88 @@ class UsersDictionaryTestCase(IrcdbTestCase):
self.assertRaises(ValueError, self.users.setUser, u2)
class NetworksDictionaryTestCase(IrcdbTestCase):
filename = os.path.join(conf.supybot.directories.conf(),
'NetworksDictionaryTestCase.conf')
def setUp(self):
try:
os.remove(self.filename)
except:
pass
self.networks = ircdb.NetworksDictionary()
IrcdbTestCase.setUp(self)
def testGetSetNetwork(self):
self.assertEqual(self.networks.getNetwork('foo').stsPolicies, {})
n = ircdb.IrcNetwork()
self.networks.setNetwork('foo', n)
self.assertEqual(self.networks.getNetwork('foo').stsPolicies, {})
def testPreserveOne(self):
n = ircdb.IrcNetwork()
n.addStsPolicy('foo', 'sts1')
n.addStsPolicy('bar', 'sts2')
n.addDisconnection('foo')
n.addDisconnection('baz')
disconnect_time_foo = n.lastDisconnectTimes['foo']
disconnect_time_baz = n.lastDisconnectTimes['baz']
self.networks.setNetwork('foonet', n)
fd = io.StringIO()
fd.close = lambda: None
self.networks.filename = 'blah'
original_Atomicfile = utils.file.AtomicFile
with unittest.mock.patch(
'supybot.utils.file.AtomicFile', return_value=fd):
self.networks.flush()
lines = fd.getvalue().split('\n')
self.assertEqual(lines.pop(0), 'network foonet')
self.assertCountEqual(lines, [
' stsPolicy foo sts1',
' stsPolicy bar sts2',
' lastDisconnectTime foo %d' % disconnect_time_foo,
' lastDisconnectTime baz %d' % disconnect_time_baz,
'',
'',
])
def testPreserveThree(self):
n = ircdb.IrcNetwork()
n.addStsPolicy('foo', 'sts1')
self.networks.setNetwork('foonet', n)
n = ircdb.IrcNetwork()
n.addStsPolicy('bar', 'sts2')
self.networks.setNetwork('barnet', n)
n = ircdb.IrcNetwork()
n.addStsPolicy('baz', 'sts3')
self.networks.setNetwork('baznet', n)
fd = io.StringIO()
fd.close = lambda: None
self.networks.filename = 'blah'
original_Atomicfile = utils.file.AtomicFile
with unittest.mock.patch(
'supybot.utils.file.AtomicFile', return_value=fd):
self.networks.flush()
fd.seek(0)
self.assertEqual(fd.getvalue(),
'network barnet\n'
' stsPolicy bar sts2\n'
'\n'
'network baznet\n'
' stsPolicy baz sts3\n'
'\n'
'network foonet\n'
' stsPolicy foo sts1\n'
'\n'
)
class CheckCapabilityTestCase(IrcdbTestCase):
filename = os.path.join(conf.supybot.directories.conf(),
'CheckCapabilityTestCase.conf')

View File

@ -27,14 +27,15 @@
# POSSIBILITY OF SUCH DAMAGE.
###
from supybot.test import *
import copy
import pickle
import warnings
import unittest.mock
from supybot.test import *
import supybot.conf as conf
import supybot.irclib as irclib
import supybot.drivers as drivers
import supybot.ircmsgs as ircmsgs
import supybot.ircutils as ircutils
@ -497,6 +498,88 @@ class IrcCapsTestCase(SupyTestCase):
self.assertEqual(m.args[0], 'REQ', m)
self.assertEqual(m.args[1], 'b'*400)
class StsTestCase(SupyTestCase):
def setUp(self):
self.irc = irclib.Irc('test')
m = self.irc.takeMsg()
self.failUnless(m.command == 'CAP', 'Expected CAP, got %r.' % m)
self.failUnless(m.args == ('LS', '302'), 'Expected CAP LS 302, got %r.' % m)
m = self.irc.takeMsg()
self.failUnless(m.command == 'NICK', 'Expected NICK, got %r.' % m)
m = self.irc.takeMsg()
self.failUnless(m.command == 'USER', 'Expected USER, got %r.' % m)
self.irc.driver = unittest.mock.Mock()
def tearDown(self):
ircdb.networks.networks = {}
def testStsInSecureConnection(self):
self.irc.driver.anyCertValidationEnabled.return_value = True
self.irc.driver.ssl = True
self.irc.driver.currentServer = drivers.Server('irc.test', 6697, False)
self.irc.feedMsg(ircmsgs.IrcMsg(command='CAP',
args=('*', 'LS', 'sts=duration=42,port=6697')))
self.assertEqual(ircdb.networks.getNetwork('test').stsPolicies, {
'irc.test': 'duration=42,port=6697'})
self.irc.driver.reconnect.assert_not_called()
def testStsInInsecureTlsConnection(self):
self.irc.driver.anyCertValidationEnabled.return_value = False
self.irc.driver.ssl = True
self.irc.driver.currentServer = drivers.Server('irc.test', 6697, False)
self.irc.feedMsg(ircmsgs.IrcMsg(command='CAP',
args=('*', 'LS', 'sts=duration=42,port=6697')))
self.assertEqual(ircdb.networks.getNetwork('test').stsPolicies, {})
self.irc.driver.reconnect.assert_called_once_with(
server=drivers.Server('irc.test', 6697, True),
wait=True)
def testStsInCleartextConnection(self):
self.irc.driver.anyCertValidationEnabled.return_value = False
self.irc.driver.ssl = True
self.irc.driver.currentServer = drivers.Server('irc.test', 6667, False)
self.irc.feedMsg(ircmsgs.IrcMsg(command='CAP',
args=('*', 'LS', 'sts=duration=42,port=6697')))
self.assertEqual(ircdb.networks.getNetwork('test').stsPolicies, {})
self.irc.driver.reconnect.assert_called_once_with(
server=drivers.Server('irc.test', 6697, True),
wait=True)
def testStsInCleartextConnectionInvalidDuration(self):
# "Servers MAY send this key to all clients, but insecurely
# connected clients MUST ignore it."
self.irc.driver.anyCertValidationEnabled.return_value = False
self.irc.driver.ssl = True
self.irc.driver.currentServer = drivers.Server('irc.test', 6667, False)
self.irc.feedMsg(ircmsgs.IrcMsg(command='CAP',
args=('*', 'LS', 'sts=duration=foo,port=6697')))
self.assertEqual(ircdb.networks.getNetwork('test').stsPolicies, {})
self.irc.driver.reconnect.assert_called_once_with(
server=drivers.Server('irc.test', 6697, True),
wait=True)
def testStsInCleartextConnectionNoDuration(self):
# "Servers MAY send this key to all clients, but insecurely
# connected clients MUST ignore it."
self.irc.driver.anyCertValidationEnabled.return_value = False
self.irc.driver.ssl = True
self.irc.driver.currentServer = drivers.Server('irc.test', 6667, False)
self.irc.feedMsg(ircmsgs.IrcMsg(command='CAP',
args=('*', 'LS', 'sts=port=6697')))
self.assertEqual(ircdb.networks.getNetwork('test').stsPolicies, {})
self.irc.driver.reconnect.assert_called_once_with(
server=drivers.Server('irc.test', 6697, True),
wait=True)
class IrcTestCase(SupyTestCase):
def setUp(self):
self.irc = irclib.Irc('test')
@ -832,6 +915,8 @@ class SaslTestCase(SupyTestCase):
while self.irc.takeMsg():
pass
self.irc.feedMsg(ircmsgs.IrcMsg(command='422')) # ERR_NOMOTD
self.irc.feedMsg(ircmsgs.IrcMsg(command='CAP',
args=('*', 'NEW', 'sasl=EXTERNAL')))