diff --git a/src/irclib.py b/src/irclib.py index f6dba1825..14f57aea8 100644 --- a/src/irclib.py +++ b/src/irclib.py @@ -1358,7 +1358,10 @@ class Irc(IrcCommandDispatcher, log.Firewalled): self._maybeStartSasl(msg) else: pass # Already in the middle of a SASL auth - else: + elif self.state.fsm.state != IrcStateFsm.States.CONNECTED: + # If we are still in the initial cap negotiation (ie. if this + # is not in response to a 'CAP NEW'), send a CAP END so the + # server sends us the MOTD self.endCapabilityNegociation(msg) else: log.debug('Waiting for ACK/NAK of capabilities: %r', diff --git a/test/test_irclib.py b/test/test_irclib.py index a3c328e5e..ca7ea0d97 100644 --- a/test/test_irclib.py +++ b/test/test_irclib.py @@ -45,6 +45,38 @@ msgs = [] rawmsgs = [] +class CapNegMixin: + """Utilities for handling the capability negotiation.""" + + def startCapNegociation(self, caps='sasl'): + m = self.irc.takeMsg() + self.assertTrue(m.command == 'CAP', 'Expected CAP, got %r.' % m) + self.assertTrue(m.args == ('LS', '302'), 'Expected CAP LS 302, got %r.' % m) + + m = self.irc.takeMsg() + self.assertTrue(m.command == 'NICK', 'Expected NICK, got %r.' % m) + + m = self.irc.takeMsg() + self.assertTrue(m.command == 'USER', 'Expected USER, got %r.' % m) + + self.irc.feedMsg(ircmsgs.IrcMsg(command='CAP', + args=('*', 'LS', caps))) + + if caps: + m = self.irc.takeMsg() + self.assertTrue(m.command == 'CAP', 'Expected CAP, got %r.' % m) + self.assertEqual(m.args[0], 'REQ', m) + self.assertEqual(m.args[1], 'sasl') + + self.irc.feedMsg(ircmsgs.IrcMsg(command='CAP', + args=('*', 'ACK', 'sasl'))) + + def endCapNegociation(self): + m = self.irc.takeMsg() + self.assertTrue(m.command == 'CAP', 'Expected CAP, got %r.' % m) + self.assertEqual(m.args, ('END',), m) + + class IrcCommandDispatcherTestCase(SupyTestCase): class DispatchedClass(irclib.IrcCommandDispatcher): def doPrivmsg(): @@ -473,7 +505,8 @@ class IrcStateTestCase(SupyTestCase): st = irclib.IrcState() self.assert_(st.addMsg(self.irc, ircmsgs.IrcMsg('MODE foo +i')) or 1) -class IrcCapsTestCase(SupyTestCase): + +class IrcCapsTestCase(SupyTestCase, CapNegMixin): def testReqLineLength(self): self.irc = irclib.Irc('test') @@ -572,6 +605,74 @@ class IrcCapsTestCase(SupyTestCase): m = self.irc.takeMsg() self.assertIsNone(m) + def testCapNew(self): + self.irc = irclib.Irc('test') + + self.assertEqual(self.irc.sasl_current_mechanism, None) + self.assertEqual(self.irc.sasl_next_mechanisms, []) + + self.startCapNegociation(caps='') + + self.endCapNegociation() + + while self.irc.takeMsg(): + pass + + self.irc.feedMsg(ircmsgs.IrcMsg(command='422')) # ERR_NOMOTD + + m = self.irc.takeMsg() + self.assertIsNone(m) + + self.irc.feedMsg(ircmsgs.IrcMsg( + command='CAP', args=['*', 'NEW', 'account-notify'])) + + m = self.irc.takeMsg() + self.assertEqual(m, + ircmsgs.IrcMsg(command='CAP', args=['REQ', 'account-notify'])) + + self.irc.feedMsg(ircmsgs.IrcMsg( + command='CAP', args=['*', 'ACK', 'account-notify'])) + + self.assertIn('account-notify', self.irc.state.capabilities_ack) + + def testCapNewEchomessageLabeledResponse(self): + self.irc = irclib.Irc('test') + + self.assertEqual(self.irc.sasl_current_mechanism, None) + self.assertEqual(self.irc.sasl_next_mechanisms, []) + + self.startCapNegociation(caps='') + + self.endCapNegociation() + + while self.irc.takeMsg(): + pass + + self.irc.feedMsg(ircmsgs.IrcMsg(command='422')) # ERR_NOMOTD + + m = self.irc.takeMsg() + self.assertIsNone(m) + + self.irc.feedMsg(ircmsgs.IrcMsg( + command='CAP', args=['*', 'NEW', 'echo-message'])) + + m = self.irc.takeMsg() + self.assertIsNone(m) + + self.irc.feedMsg(ircmsgs.IrcMsg( + command='CAP', args=['*', 'NEW', 'labeled-response'])) + + m = self.irc.takeMsg() + self.assertEqual(m, + ircmsgs.IrcMsg( + command='CAP', args=['REQ', 'echo-message labeled-response'])) + + self.irc.feedMsg(ircmsgs.IrcMsg( + command='CAP', args=['*', 'ACK', 'echo-message labeled-response'])) + + self.assertIn('echo-message', self.irc.state.capabilities_ack) + self.assertIn('labeled-response', self.irc.state.capabilities_ack) + class StsTestCase(SupyTestCase): def setUp(self): @@ -856,38 +957,10 @@ class IrcTestCase(SupyTestCase): self.irc.removeCallback(c.name()) self.assertEqual(c.batch, irclib.Batch('netjoin', (), [m1, m2, m3, m4])) -class SaslTestCase(SupyTestCase): +class SaslTestCase(SupyTestCase, CapNegMixin): def setUp(self): pass - def startCapNegociation(self, caps='sasl'): - m = self.irc.takeMsg() - self.assertTrue(m.command == 'CAP', 'Expected CAP, got %r.' % m) - self.assertTrue(m.args == ('LS', '302'), 'Expected CAP LS 302, got %r.' % m) - - m = self.irc.takeMsg() - self.assertTrue(m.command == 'NICK', 'Expected NICK, got %r.' % m) - - m = self.irc.takeMsg() - self.assertTrue(m.command == 'USER', 'Expected USER, got %r.' % m) - - self.irc.feedMsg(ircmsgs.IrcMsg(command='CAP', - args=('*', 'LS', caps))) - - if caps: - m = self.irc.takeMsg() - self.assertTrue(m.command == 'CAP', 'Expected CAP, got %r.' % m) - self.assertEqual(m.args[0], 'REQ', m) - self.assertEqual(m.args[1], 'sasl') - - self.irc.feedMsg(ircmsgs.IrcMsg(command='CAP', - args=('*', 'ACK', 'sasl'))) - - def endCapNegociation(self): - m = self.irc.takeMsg() - self.assertTrue(m.command == 'CAP', 'Expected CAP, got %r.' % m) - self.assertEqual(m.args, ('END',), m) - def testPlain(self): try: conf.supybot.networks.test.sasl.username.setValue('jilles')