Handle AUTHENTICATE line splitting.

This commit is contained in:
Valentin Lorentz 2015-12-10 20:08:53 +01:00
parent 6a669c1483
commit 15d59d1153
4 changed files with 98 additions and 28 deletions

View File

@ -968,19 +968,11 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
self.password = conf.supybot.networks.get(self.network).password() self.password = conf.supybot.networks.get(self.network).password()
self.sasl_username = \ self.sasl_username = \
conf.supybot.networks.get(self.network).sasl.username() conf.supybot.networks.get(self.network).sasl.username()
# TODO Find a better way to fix this
if hasattr(self.sasl_username, 'decode'):
self.sasl_username = self.sasl_username.decode('utf-8')
self.sasl_password = \ self.sasl_password = \
conf.supybot.networks.get(self.network).sasl.password() conf.supybot.networks.get(self.network).sasl.password()
# TODO Find a better way to fix this
if hasattr(self.sasl_password, 'decode'):
self.sasl_password = self.sasl_password.decode('utf-8')
self.sasl_ecdsa_key = \ self.sasl_ecdsa_key = \
conf.supybot.networks.get(self.network).sasl.ecdsa_key() conf.supybot.networks.get(self.network).sasl.ecdsa_key()
# TODO Find a better way to fix this self.authenticate_decoder = None
if hasattr(self.sasl_ecdsa_key, 'decode'):
self.sasl_ecdsa_key = self.sasl_ecdsa_key.decode('utf-8')
self.prefix = '%s!%s@%s' % (self.nick, self.ident, 'unset.domain') self.prefix = '%s!%s@%s' % (self.nick, self.ident, 'unset.domain')
# The rest. # The rest.
self.lastTake = 0 self.lastTake = 0
@ -1038,33 +1030,41 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
self.REQUEST_CAPABILITIES.add('sasl') self.REQUEST_CAPABILITIES.add('sasl')
def doAuthenticate(self, msg): def doAuthenticate(self, msg):
if len(msg.args) == 1 and msg.args[0] == '+': if not self.authenticate_decoder:
self.authenticate_decoder = ircutils.AuthenticateDecoder()
self.authenticate_decoder.feed(msg)
if not self.authenticate_decoder.ready:
return # Waiting for other messages
string = self.authenticate_decoder.get()
self.authenticate_decoder = None
if string == b'':
log.info('%s: Authenticating using SASL.', self.network) log.info('%s: Authenticating using SASL.', self.network)
if self.sasl == 'external': if self.sasl == 'external':
authstring = '+' authstring = b''
elif self.sasl == 'ecdsa-nist256p-challenge': elif self.sasl == 'ecdsa-nist256p-challenge':
authstring = base64.b64encode( authstring = self.sasl_username.encode('utf-8')
self.sasl_username.encode('utf-8')).decode('utf-8')
elif self.sasl == 'plain': elif self.sasl == 'plain':
authstring = base64.b64encode('\0'.join([ authstring = b'\0'.join([
self.sasl_username, self.sasl_username.encode('utf-8'),
self.sasl_username, self.sasl_username.encode('utf-8'),
self.sasl_password self.sasl_password.encode('utf-8'),
]).encode('utf-8')).decode('utf-8') ])
self.sendMsg(ircmsgs.IrcMsg(command='AUTHENTICATE', args=(authstring,))) for chunk in ircutils.authenticate_generator(authstring):
elif (len(msg.args) == 1 and msg.args[0] != '+' and self.sendMsg(ircmsgs.IrcMsg(command='AUTHENTICATE',
self.sasl == 'ecdsa-nist256p-challenge'): args=(chunk,)))
elif (string != b'' and self.sasl == 'ecdsa-nist256p-challenge'):
try: try:
private_key = SigningKey.from_pem(open(self.sasl_ecdsa_key). with open(self.sasl_ecdsa_key) as fd:
read()) private_key = SigningKey.from_pem(fd.read())
authstring = base64.b64encode( authstring = private_key.sign(base64.b64decode(msg.args[0].encode()))
private_key.sign(base64.b64decode(msg.args[0].encode()))).decode('utf-8') chunks = ircutils.authenticate_generator(authstring)
except (BadDigestError, OSError, ValueError): except (BadDigestError, OSError, ValueError):
authstring = "*" chunks = ['*']
for chunk in chunks:
self.sendMsg(ircmsgs.IrcMsg(command='AUTHENTICATE', args=(authstring,))) self.sendMsg(ircmsgs.IrcMsg(command='AUTHENTICATE',
args=(chunk,)))
def doCap(self, msg): def doCap(self, msg):
subcommand = msg.args[1] subcommand = msg.args[1]

View File

@ -37,6 +37,7 @@ object (which, as you'll read later, is quite...full-featured :))
import re import re
import time import time
import base64
import datetime import datetime
import functools import functools
@ -880,6 +881,7 @@ def monitor(subcommand, nicks=None, prefix='', msg=None):
return IrcMsg(prefix=prefix, command='MONITOR', args=(subcommand, nicks), return IrcMsg(prefix=prefix, command='MONITOR', args=(subcommand, nicks),
msg=msg) msg=msg)
def error(s, msg=None): def error(s, msg=None):
return IrcMsg(command='ERROR', args=(s,), msg=msg) return IrcMsg(command='ERROR', args=(s,), msg=msg)

View File

@ -41,6 +41,7 @@ from __future__ import print_function
import re import re
import sys import sys
import time import time
import base64
import random import random
import string import string
import textwrap import textwrap
@ -868,6 +869,38 @@ def standardSubstitute(irc, msg, text, env=None):
t.idpattern = '[a-zA-Z][a-zA-Z0-9]*' t.idpattern = '[a-zA-Z][a-zA-Z0-9]*'
return t.safe_substitute(vars) return t.safe_substitute(vars)
AUTHENTICATE_CHUNK_SIZE = 400
def authenticate_generator(authstring, base64ify=True):
if base64ify:
authstring = base64.b64encode(authstring)
if minisix.PY3:
authstring = authstring.decode()
# +1 so we get an empty string at the end if len(authstring) is a multiple
# of AUTHENTICATE_CHUNK_SIZE (including 0)
for n in range(0, len(authstring)+1, AUTHENTICATE_CHUNK_SIZE):
chunk = authstring[n:n+AUTHENTICATE_CHUNK_SIZE] or '+'
yield chunk
class AuthenticateDecoder(object):
def __init__(self):
self.chunks = []
self.ready = False
def feed(self, msg):
assert msg.command == 'AUTHENTICATE'
chunk = msg.args[0]
if chunk == '+' or len(chunk) != AUTHENTICATE_CHUNK_SIZE:
self.ready = True
if chunk != '+':
if minisix.PY3:
chunk = chunk.encode()
self.chunks.append(chunk)
def get(self):
assert self.ready
return base64.b64decode(b''.join(self.chunks))
numerics = { numerics = {
# <= 2.10 # <= 2.10
# Reply # Reply

View File

@ -391,6 +391,41 @@ class IrcStringTestCase(SupyTestCase):
self.failUnless(s1 == s2) self.failUnless(s1 == s2)
self.failIf(s1 != s2) self.failIf(s1 != s2)
class AuthenticateTestCase(SupyTestCase):
PAIRS = [
(b'', ['+']),
(b'foo'*150, [
'Zm9v'*100,
'Zm9v'*50
]),
(b'foo'*200, [
'Zm9v'*100,
'Zm9v'*100,
'+'])
]
def assertMessages(self, got, should):
got = list(got)
for (s1, s2) in zip(got, should):
self.assertEqual(s1, s2, (got, should))
def testGenerator(self):
for (decoded, encoded) in self.PAIRS:
self.assertMessages(
ircutils.authenticate_generator(decoded),
encoded)
def testDecoder(self):
for (decoded, encoded) in self.PAIRS:
decoder = ircutils.AuthenticateDecoder()
for chunk in encoded:
self.assertFalse(decoder.ready, (decoded, encoded))
decoder.feed(ircmsgs.IrcMsg(command='AUTHENTICATE',
args=(chunk,)))
self.assertTrue(decoder.ready)
self.assertEqual(decoder.get(), decoded)
# vim:set shiftwidth=4 softtabstop=4 expandtab textwidth=79: # vim:set shiftwidth=4 softtabstop=4 expandtab textwidth=79: