diff --git a/src/irclib.py b/src/irclib.py index 867db17a0..aa8f71aae 100644 --- a/src/irclib.py +++ b/src/irclib.py @@ -968,19 +968,11 @@ class Irc(IrcCommandDispatcher, log.Firewalled): self.password = conf.supybot.networks.get(self.network).password() self.sasl_username = \ conf.supybot.networks.get(self.network).sasl.username() - # TODO Find a better way to fix this - if hasattr(self.sasl_username, 'decode'): - self.sasl_username = self.sasl_username.decode('utf-8') self.sasl_password = \ conf.supybot.networks.get(self.network).sasl.password() - # TODO Find a better way to fix this - if hasattr(self.sasl_password, 'decode'): - self.sasl_password = self.sasl_password.decode('utf-8') self.sasl_ecdsa_key = \ conf.supybot.networks.get(self.network).sasl.ecdsa_key() - # TODO Find a better way to fix this - if hasattr(self.sasl_ecdsa_key, 'decode'): - self.sasl_ecdsa_key = self.sasl_ecdsa_key.decode('utf-8') + self.authenticate_decoder = None self.prefix = '%s!%s@%s' % (self.nick, self.ident, 'unset.domain') # The rest. self.lastTake = 0 @@ -1038,33 +1030,41 @@ class Irc(IrcCommandDispatcher, log.Firewalled): self.REQUEST_CAPABILITIES.add('sasl') def doAuthenticate(self, msg): - if len(msg.args) == 1 and msg.args[0] == '+': + if not self.authenticate_decoder: + self.authenticate_decoder = ircutils.AuthenticateDecoder() + self.authenticate_decoder.feed(msg) + if not self.authenticate_decoder.ready: + return # Waiting for other messages + string = self.authenticate_decoder.get() + self.authenticate_decoder = None + if string == b'': log.info('%s: Authenticating using SASL.', self.network) if self.sasl == 'external': - authstring = '+' + authstring = b'' elif self.sasl == 'ecdsa-nist256p-challenge': - authstring = base64.b64encode( - self.sasl_username.encode('utf-8')).decode('utf-8') + authstring = self.sasl_username.encode('utf-8') elif self.sasl == 'plain': - authstring = base64.b64encode('\0'.join([ - self.sasl_username, - self.sasl_username, - self.sasl_password - ]).encode('utf-8')).decode('utf-8') + authstring = b'\0'.join([ + self.sasl_username.encode('utf-8'), + self.sasl_username.encode('utf-8'), + self.sasl_password.encode('utf-8'), + ]) - self.sendMsg(ircmsgs.IrcMsg(command='AUTHENTICATE', args=(authstring,))) - elif (len(msg.args) == 1 and msg.args[0] != '+' and - self.sasl == 'ecdsa-nist256p-challenge'): + for chunk in ircutils.authenticate_generator(authstring): + self.sendMsg(ircmsgs.IrcMsg(command='AUTHENTICATE', + args=(chunk,))) + elif (string != b'' and self.sasl == 'ecdsa-nist256p-challenge'): try: - private_key = SigningKey.from_pem(open(self.sasl_ecdsa_key). - read()) - authstring = base64.b64encode( - private_key.sign(base64.b64decode(msg.args[0].encode()))).decode('utf-8') + with open(self.sasl_ecdsa_key) as fd: + private_key = SigningKey.from_pem(fd.read()) + authstring = private_key.sign(base64.b64decode(msg.args[0].encode())) + chunks = ircutils.authenticate_generator(authstring) except (BadDigestError, OSError, ValueError): - authstring = "*" - - self.sendMsg(ircmsgs.IrcMsg(command='AUTHENTICATE', args=(authstring,))) + chunks = ['*'] + for chunk in chunks: + self.sendMsg(ircmsgs.IrcMsg(command='AUTHENTICATE', + args=(chunk,))) def doCap(self, msg): subcommand = msg.args[1] diff --git a/src/ircmsgs.py b/src/ircmsgs.py index a2062215a..5b27a7675 100644 --- a/src/ircmsgs.py +++ b/src/ircmsgs.py @@ -37,6 +37,7 @@ object (which, as you'll read later, is quite...full-featured :)) import re import time +import base64 import datetime import functools @@ -880,6 +881,7 @@ def monitor(subcommand, nicks=None, prefix='', msg=None): return IrcMsg(prefix=prefix, command='MONITOR', args=(subcommand, nicks), msg=msg) + def error(s, msg=None): return IrcMsg(command='ERROR', args=(s,), msg=msg) diff --git a/src/ircutils.py b/src/ircutils.py index f4df17b31..264adc09c 100644 --- a/src/ircutils.py +++ b/src/ircutils.py @@ -41,6 +41,7 @@ from __future__ import print_function import re import sys import time +import base64 import random import string import textwrap @@ -868,6 +869,38 @@ def standardSubstitute(irc, msg, text, env=None): t.idpattern = '[a-zA-Z][a-zA-Z0-9]*' return t.safe_substitute(vars) + + +AUTHENTICATE_CHUNK_SIZE = 400 +def authenticate_generator(authstring, base64ify=True): + if base64ify: + authstring = base64.b64encode(authstring) + if minisix.PY3: + authstring = authstring.decode() + # +1 so we get an empty string at the end if len(authstring) is a multiple + # of AUTHENTICATE_CHUNK_SIZE (including 0) + for n in range(0, len(authstring)+1, AUTHENTICATE_CHUNK_SIZE): + chunk = authstring[n:n+AUTHENTICATE_CHUNK_SIZE] or '+' + yield chunk + +class AuthenticateDecoder(object): + def __init__(self): + self.chunks = [] + self.ready = False + def feed(self, msg): + assert msg.command == 'AUTHENTICATE' + chunk = msg.args[0] + if chunk == '+' or len(chunk) != AUTHENTICATE_CHUNK_SIZE: + self.ready = True + if chunk != '+': + if minisix.PY3: + chunk = chunk.encode() + self.chunks.append(chunk) + def get(self): + assert self.ready + return base64.b64decode(b''.join(self.chunks)) + + numerics = { # <= 2.10 # Reply diff --git a/test/test_ircutils.py b/test/test_ircutils.py index 35baefdd3..908fca7b2 100644 --- a/test/test_ircutils.py +++ b/test/test_ircutils.py @@ -391,6 +391,41 @@ class IrcStringTestCase(SupyTestCase): self.failUnless(s1 == s2) self.failIf(s1 != s2) +class AuthenticateTestCase(SupyTestCase): + PAIRS = [ + (b'', ['+']), + (b'foo'*150, [ + 'Zm9v'*100, + 'Zm9v'*50 + ]), + (b'foo'*200, [ + 'Zm9v'*100, + 'Zm9v'*100, + '+']) + ] + def assertMessages(self, got, should): + got = list(got) + for (s1, s2) in zip(got, should): + self.assertEqual(s1, s2, (got, should)) + + def testGenerator(self): + for (decoded, encoded) in self.PAIRS: + self.assertMessages( + ircutils.authenticate_generator(decoded), + encoded) + + def testDecoder(self): + for (decoded, encoded) in self.PAIRS: + decoder = ircutils.AuthenticateDecoder() + for chunk in encoded: + self.assertFalse(decoder.ready, (decoded, encoded)) + decoder.feed(ircmsgs.IrcMsg(command='AUTHENTICATE', + args=(chunk,))) + self.assertTrue(decoder.ready) + self.assertEqual(decoder.get(), decoded) + + + # vim:set shiftwidth=4 softtabstop=4 expandtab textwidth=79: