diff --git a/src/irclib.py b/src/irclib.py index a8af82513..3d854d2c3 100644 --- a/src/irclib.py +++ b/src/irclib.py @@ -974,10 +974,12 @@ class Irc(IrcCommandDispatcher, log.Firewalled): self.startedAt = time.time() self.lastping = time.time() self.outstandingPing = False + self.capNegociationEnded = False self.resetSasl() def resetSasl(self): network_config = conf.supybot.networks.get(self.network) + self.sasl_authenticated = False self.sasl_username = network_config.sasl.username() self.sasl_password = network_config.sasl.password() self.sasl_ecdsa_key = network_config.sasl.ecdsa_key() @@ -1034,6 +1036,11 @@ class Irc(IrcCommandDispatcher, log.Firewalled): if self.sasl_next_mechanisms: self.REQUEST_CAPABILITIES.add('sasl') + def endCapabilityNegociation(self): + if not self.capNegociationEnded: + self.capNegociationEnded = True + self.sendMsg(ircmsgs.IrcMsg(command='CAP', args=('END',))) + def sendSaslString(self, string): for chunk in ircutils.authenticate_generator(string): self.sendMsg(ircmsgs.IrcMsg(command='AUTHENTICATE', @@ -1046,12 +1053,13 @@ class Irc(IrcCommandDispatcher, log.Firewalled): args=(self.sasl_current_mechanism.upper(),))) else: self.sasl_current_mechanism = None - self.sendMsg(ircmsgs.IrcMsg(command='CAP', args=('END',))) + self.endCapabilityNegociation() def filterSaslMechanisms(self, available): + available = set(map(str.lower, available)) self.sasl_next_mechanisms = [ x for x in self.sasl_next_mechanisms - if x in available] + if x.lower() in available] def doAuthenticate(self, msg): if not self.authenticate_decoder: @@ -1088,7 +1096,8 @@ class Irc(IrcCommandDispatcher, log.Firewalled): def do903(self, msg): log.info('%s: SASL authentication successful', self.network) - self.queueMsg(ircmsgs.IrcMsg(command='CAP', args=('END',))) + self.sasl_authenticated = True + self.endCapabilityNegociation() def do904(self, msg): log.warning('%s: SASL authentication failed', self.network) @@ -1123,6 +1132,8 @@ class Irc(IrcCommandDispatcher, log.Firewalled): self.doCapLs(msg) elif subcommand == 'DEL': self.doCapDel(msg) + elif subcommand == 'NEW': + self.doCapNew(msg) def doCapAck(self, msg): if len(msg.args) != 3: log.warning('Bad CAP ACK from server: %r', msg) @@ -1136,7 +1147,7 @@ class Irc(IrcCommandDispatcher, log.Firewalled): if 'sasl' in caps: self.tryNextSaslMechanism() else: - self.sendMsg(ircmsgs.IrcMsg(command='CAP', args=('END',))) + self.endCapabilityNegociation() def doCapNak(self, msg): if len(msg.args) != 3: log.warning('Bad CAP NAK from server: %r', msg) @@ -1146,7 +1157,7 @@ class Irc(IrcCommandDispatcher, log.Firewalled): self.state.capabilities_nak.update(caps) log.warning('%s: Server refused capabilities: %L', self.network, caps) - self.sendMsg(ircmsgs.IrcMsg(command='CAP', args=('END',))) + self.endCapabilityNegociation() def _addCapabilities(self, capstring): for item in capstring.split(): while item.startswith(('=', '~')): @@ -1179,8 +1190,7 @@ class Irc(IrcCommandDispatcher, log.Firewalled): self.sendMsg(ircmsgs.IrcMsg(command='CAP', args=('REQ', caps))) else: - self.sendMsg(ircmsgs.IrcMsg(command='CAP', - args=('END',))) + self.endCapabilityNegociation() else: log.warning('Bad CAP LS from server: %r', msg) return @@ -1199,9 +1209,29 @@ class Irc(IrcCommandDispatcher, log.Firewalled): except KeyError: pass try: - del self.state.capabilities_ack[cap] + self.state.capabilities_ack.remove(cap) except KeyError: pass + def doCapNew(self, msg): + if len(msg.args) != 3: + log.warning('Bad CAP NEW from server: %r', msg) + return + caps = msg.args[2].split() + assert caps, 'Empty list of capabilities' + self._addCapabilities(msg.args[2]) + if not self.sasl_authenticated and 'sasl' in self.state.capabilities_ls: + self.resetSasl() + s = self.state.capabilities_ls['sasl'] + if s is not None: + self.filterSaslMechanisms(set(s.split(','))) + common_supported_unrequested_capabilities = ( + set(self.state.capabilities_ls) & + self.REQUEST_CAPABILITIES - + self.state.capabilities_ack) + if common_supported_unrequested_capabilities: + caps = ' '.join(sorted(common_supported_unrequested_capabilities)) + self.sendMsg(ircmsgs.IrcMsg(command='CAP', + args=('REQ', caps))) def monitor(self, targets): """Increment a counter of how many callbacks monitor each target; diff --git a/test/test_irclib.py b/test/test_irclib.py index 963600919..882967158 100644 --- a/test/test_irclib.py +++ b/test/test_irclib.py @@ -531,7 +531,7 @@ class SaslTestCase(SupyTestCase): def setUp(self): pass - def startCapNegociation(self, sasl_attributes=None): + def startCapNegociation(self, caps='sasl'): m = self.irc.takeMsg() self.failUnless(m.command == 'CAP', 'Expected CAP, got %r.' % m) self.failUnless(m.args == ('LS', '302'), 'Expected CAP LS 302, got %r.' % m) @@ -542,20 +542,17 @@ class SaslTestCase(SupyTestCase): m = self.irc.takeMsg() self.failUnless(m.command == 'USER', 'Expected USER, got %r.' % m) - if sasl_attributes: - self.irc.feedMsg(ircmsgs.IrcMsg(command='CAP', - args=('*', 'LS', 'sasl=%s' % sasl_attributes))) - else: - self.irc.feedMsg(ircmsgs.IrcMsg(command='CAP', - args=('*', 'LS', 'sasl'))) - - m = self.irc.takeMsg() - self.failUnless(m.command == 'CAP', 'Expected CAP, got %r.' % m) - self.assertEqual(m.args[0], 'REQ', m) - self.assertEqual(m.args[1], 'sasl') - self.irc.feedMsg(ircmsgs.IrcMsg(command='CAP', - args=('*', 'ACK', 'sasl'))) + args=('*', 'LS', caps))) + + if caps: + m = self.irc.takeMsg() + self.failUnless(m.command == 'CAP', 'Expected CAP, got %r.' % m) + self.assertEqual(m.args[0], 'REQ', m) + self.assertEqual(m.args[1], 'sasl') + + self.irc.feedMsg(ircmsgs.IrcMsg(command='CAP', + args=('*', 'ACK', 'sasl'))) def endCapNegociation(self): m = self.irc.takeMsg() @@ -642,7 +639,7 @@ class SaslTestCase(SupyTestCase): self.assertEqual(self.irc.sasl_next_mechanisms, ['external', 'plain']) - self.startCapNegociation(sasl_attributes='foo,plain,bar') + self.startCapNegociation(caps='sasl=foo,plain,bar') m = self.irc.takeMsg() self.assertEqual(m, ircmsgs.IrcMsg(command='AUTHENTICATE', @@ -659,6 +656,61 @@ class SaslTestCase(SupyTestCase): self.endCapNegociation() + def testReauthenticate(self): + try: + conf.supybot.networks.test.sasl.username.setValue('jilles') + conf.supybot.networks.test.sasl.password.setValue('sesame') + self.irc = irclib.Irc('test') + finally: + conf.supybot.networks.test.sasl.username.setValue('') + conf.supybot.networks.test.sasl.password.setValue('') + self.assertEqual(self.irc.sasl_current_mechanism, None) + self.assertEqual(self.irc.sasl_next_mechanisms, ['plain']) + + self.startCapNegociation(caps='') + + self.endCapNegociation() + + while self.irc.takeMsg(): + pass + + self.irc.feedMsg(ircmsgs.IrcMsg(command='CAP', + args=('*', 'NEW', 'sasl=EXTERNAL'))) + + self.irc.takeMsg() # None. But even if it was CAP REQ sasl, it would be ok + self.assertEqual(self.irc.takeMsg(), None) + + try: + conf.supybot.networks.test.sasl.username.setValue('jilles') + conf.supybot.networks.test.sasl.password.setValue('sesame') + self.irc.feedMsg(ircmsgs.IrcMsg(command='CAP', + args=('*', 'DEL', 'sasl'))) + self.irc.feedMsg(ircmsgs.IrcMsg(command='CAP', + args=('*', 'NEW', 'sasl=PLAIN'))) + finally: + conf.supybot.networks.test.sasl.username.setValue('') + conf.supybot.networks.test.sasl.password.setValue('') + + m = self.irc.takeMsg() + self.failUnless(m.command == 'CAP', 'Expected CAP, got %r.' % m) + self.assertEqual(m.args[0], 'REQ', m) + self.assertEqual(m.args[1], 'sasl') + self.irc.feedMsg(ircmsgs.IrcMsg(command='CAP', + args=('*', 'ACK', 'sasl'))) + + m = self.irc.takeMsg() + self.assertEqual(m, ircmsgs.IrcMsg(command='AUTHENTICATE', + args=('PLAIN',))) + + self.irc.feedMsg(ircmsgs.IrcMsg(command='AUTHENTICATE', args=('+',))) + + m = self.irc.takeMsg() + self.assertEqual(m, ircmsgs.IrcMsg(command='AUTHENTICATE', + args=('amlsbGVzAGppbGxlcwBzZXNhbWU=',))) + + self.irc.feedMsg(ircmsgs.IrcMsg(command='900', args=('jilles',))) + self.irc.feedMsg(ircmsgs.IrcMsg(command='903', args=('jilles',))) + class IrcCallbackTestCase(SupyTestCase):