Add support for STARTTLS (not tested).

This commit is contained in:
Valentin Lorentz 2015-12-12 16:40:48 +01:00
parent 30cb10e422
commit 4b1c766b42
3 changed files with 56 additions and 19 deletions

View File

@ -323,6 +323,10 @@ def registerNetwork(name, password='', ssl=False, sasl_username='',
registerGlobalValue(network, 'ssl', registry.Boolean(ssl, registerGlobalValue(network, 'ssl', registry.Boolean(ssl,
_("""Determines whether the bot will attempt to connect with SSL _("""Determines whether the bot will attempt to connect with SSL
sockets to %s.""") % name)) sockets to %s.""") % name))
registerGlobalValue(network, 'requireStarttls', registry.Boolean(False,
_("""Determines whether the bot will connect in plain text to %s
but require STARTTLS before authentication. This is ignored if the
connection already uses SSL.""") % name))
registerGlobalValue(network, 'certfile', registry.String('', registerGlobalValue(network, 'certfile', registry.String('',
_("""Determines what certificate file (if any) the bot will use to _("""Determines what certificate file (if any) the bot will use to
connect with SSL sockets to %s.""") % name)) connect with SSL sockets to %s.""") % name))

View File

@ -79,7 +79,9 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin):
'servers for your Python version. Try the ' 'servers for your Python version. Try the '
'Twisted driver instead, or install a Python' 'Twisted driver instead, or install a Python'
'version that supports SSL (2.6 and greater).') 'version that supports SSL (2.6 and greater).')
self.ssl = False
else: else:
self.ssl = self.networkGroup.get('ssl').value
self.connect() self.connect()
def getDelay(self): def getDelay(self):
@ -273,18 +275,7 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin):
self.conn.settimeout(max(10, conf.supybot.drivers.poll()*10)) self.conn.settimeout(max(10, conf.supybot.drivers.poll()*10))
try: try:
if getattr(conf.supybot.networks, self.irc.network).ssl(): if getattr(conf.supybot.networks, self.irc.network).ssl():
assert 'ssl' in globals() self.starttls()
certfile = getattr(conf.supybot.networks, self.irc.network) \
.certfile()
if not certfile:
certfile = conf.supybot.protocols.irc.certfile()
if not certfile:
certfile = None
elif not os.path.isfile(certfile):
drivers.log.warning('Could not find cert file %s.' %
certfile)
certfile = None
self.conn = ssl.wrap_socket(self.conn, certfile=certfile)
self.conn.connect((address, server[1])) self.conn.connect((address, server[1]))
def setTimeout(): def setTimeout():
self.conn.settimeout(conf.supybot.drivers.poll()) self.conn.settimeout(conf.supybot.drivers.poll())
@ -348,6 +339,20 @@ class SocketDriver(drivers.IrcDriver, drivers.ServersMixin):
def name(self): def name(self):
return '%s(%s)' % (self.__class__.__name__, self.irc) return '%s(%s)' % (self.__class__.__name__, self.irc)
def starttls(self):
assert 'ssl' in globals()
certfile = getattr(conf.supybot.networks, self.irc.network) \
.certfile()
if not certfile:
certfile = conf.supybot.protocols.irc.certfile()
if not certfile:
certfile = None
elif not os.path.isfile(certfile):
drivers.log.warning('Could not find cert file %s.' %
certfile)
certfile = None
self.conn = ssl.wrap_socket(self.conn, certfile=certfile)
Driver = SocketDriver Driver = SocketDriver

View File

@ -958,14 +958,15 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
def _setNonResettingVariables(self): def _setNonResettingVariables(self):
# Configuration stuff. # Configuration stuff.
network_config = conf.supybot.networks.get(self.network)
def get_value(name): def get_value(name):
return getattr(conf.supybot.networks.get(self.network), name)() or \ return getattr(network_config, name)() or \
getattr(conf.supybot, name)() getattr(conf.supybot, name)()
self.nick = get_value('nick') self.nick = get_value('nick')
self.user = get_value('user') self.user = get_value('user')
self.ident = get_value('ident') self.ident = get_value('ident')
self.alternateNicks = conf.supybot.nick.alternates()[:] self.alternateNicks = conf.supybot.nick.alternates()[:]
self.password = conf.supybot.networks.get(self.network).password() self.password = network_config.password()
self.prefix = '%s!%s@%s' % (self.nick, self.ident, 'unset.domain') self.prefix = '%s!%s@%s' % (self.nick, self.ident, 'unset.domain')
# The rest. # The rest.
self.lastTake = 0 self.lastTake = 0
@ -975,6 +976,8 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
self.lastping = time.time() self.lastping = time.time()
self.outstandingPing = False self.outstandingPing = False
self.capNegociationEnded = False self.capNegociationEnded = False
self.requireStarttls = not network_config.ssl() and \
network_config.requireStarttls()
self.resetSasl() self.resetSasl()
def resetSasl(self): def resetSasl(self):
@ -999,6 +1002,9 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
self.sasl_username and self.sasl_password: self.sasl_username and self.sasl_password:
self.sasl_next_mechanisms.append(mechanism) self.sasl_next_mechanisms.append(mechanism)
if self.sasl_next_mechanisms:
self.REQUEST_CAPABILITIES.add('sasl')
REQUEST_CAPABILITIES = set(['account-notify', 'extended-join', REQUEST_CAPABILITIES = set(['account-notify', 'extended-join',
'multi-prefix', 'metadata-notify', 'account-tag', 'multi-prefix', 'metadata-notify', 'account-tag',
@ -1012,15 +1018,33 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
return return
self.sendMsg(ircmsgs.IrcMsg(command='CAP', args=('LS', '302')))
if self.requireStarttls:
self.sendMsg(ircmsgs.IrcMsg(command='STARTTLS'))
else:
self.sendAuthenticationMessages()
def do670(self, irc, msg):
"""STARTTLS accepted."""
log.info('%s: Starting TLS session.', self.network)
self.requireStarttls = False
self.driver.starttls()
self.sendAuthenticationMessages()
def do691(self, irc, msg):
"""STARTTLS refused."""
log.error('%s: Server refused STARTTLS: %s', self.network, msg.args[0])
self.feedMsg(ircmsgs.error('STARTTLS upgrade refused by the server'))
self.driver.reconnect()
def sendAuthenticationMessages(self):
# Notes: # Notes:
# * using sendMsg instead of queueMsg because these messages cannot # * using sendMsg instead of queueMsg because these messages cannot
# be throttled. # be throttled.
self.sendMsg(ircmsgs.IrcMsg(command='CAP', args=('LS', '302')))
if self.password: if self.password:
log.info('%s: Queuing PASS command, not logging the password.', log.info('%s: Queuing PASS command, not logging the password.',
self.network) self.network)
self.sendMsg(ircmsgs.password(self.password)) self.sendMsg(ircmsgs.password(self.password))
log.debug('%s: Sending NICK command, nick is %s.', log.debug('%s: Sending NICK command, nick is %s.',
@ -1033,9 +1057,6 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
self.sendMsg(ircmsgs.user(self.ident, self.user)) self.sendMsg(ircmsgs.user(self.ident, self.user))
if self.sasl_next_mechanisms:
self.REQUEST_CAPABILITIES.add('sasl')
def endCapabilityNegociation(self): def endCapabilityNegociation(self):
if not self.capNegociationEnded: if not self.capNegociationEnded:
self.capNegociationEnded = True self.capNegociationEnded = True
@ -1182,6 +1203,13 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
s = self.state.capabilities_ls['sasl'] s = self.state.capabilities_ls['sasl']
if s is not None: if s is not None:
self.filterSaslMechanisms(set(s.split(','))) self.filterSaslMechanisms(set(s.split(',')))
if 'starttls' not in self.state.capabilities_ls and \
self.requireStarttls:
log.error('%s: Server does not support STARTTLS.', self.network)
self.feedMsg(ircmsgs.error('STARTTLS upgrade not supported '
'by the server'))
self.die()
return
# NOTE: Capabilities are requested in alphabetic order, because # NOTE: Capabilities are requested in alphabetic order, because
# sets are unordered, and their "order" is nondeterministic. # sets are unordered, and their "order" is nondeterministic.
# This is needed for the tests. # This is needed for the tests.