diff --git a/src/irclib.py b/src/irclib.py index a1c2cb19f..bef32e4b1 100644 --- a/src/irclib.py +++ b/src/irclib.py @@ -1350,7 +1350,7 @@ class Irc(IrcCommandDispatcher, log.Firewalled): 'IRC specifications. If you know what you are doing, ' 'set supybot.protocols.irc.experimentalExtensions.') - if len(msg) < 2: + if len(msgs) < 2: raise ValueError( 'queueBatch called with less than two messages.') if msgs[0].command.upper() != 'BATCH' or msgs[0].args[0][0] != '+': @@ -1362,16 +1362,17 @@ class Irc(IrcCommandDispatcher, log.Firewalled): batch_name = msgs[0].args[0][1:] - if msgs[0].args[0][1:] != batch_name: + if msgs[-1].args[0][1:] != batch_name: raise ValueError( 'queueBatch called with mismatched BATCH name args.') - if any(msg.server_tags.get('batch') != batch_name for msg in msgs): + if any(msg.server_tags['batch'] != batch_name for msg in msgs[1:-1]): raise ValueError( 'queueBatch called with mismatched batch names.') return if batch_name in self._queued_batches: raise ValueError( 'queueBatch called with a batch name already in flight') + self._queued_batches[batch_name] = msgs # Enqueue only the start of the batch. When takeMsg sees it, it will @@ -1472,7 +1473,8 @@ class Irc(IrcCommandDispatcher, log.Firewalled): # regular queue, which means the fastqueue is empty. # But let's not take any risk, eg. if race condition # with a plugin appending directly to the fastqueue.) - batch_messages = self._queued_batches + batch_name = msg.args[0][1:] + batch_messages = self._queued_batches.pop(batch_name) if batch_messages[0] != msg: log.error('Enqueue "BATCH +" message does not match ' 'the one of the batch in flight.') diff --git a/test/test_irclib.py b/test/test_irclib.py index 4342cdfa3..efd173a82 100644 --- a/test/test_irclib.py +++ b/test/test_irclib.py @@ -1131,6 +1131,153 @@ class IrcTestCase(SupyTestCase): str(m), 'PRIVMSG #test :%s\r\n' % remaining_payload) +class BatchTestCase(SupyTestCase): + def setUp(self): + self.irc = irclib.Irc('test') + conf.supybot.protocols.irc.experimentalExtensions.setValue(True) + while self.irc.takeMsg() is not None: + self.irc.takeMsg() + + def tearDown(self): + conf.supybot.protocols.irc.experimentalExtensions.setValue(False) + + def testQueueBatch(self): + """Basic operation of queueBatch""" + msgs = [ + ircmsgs.IrcMsg('BATCH +label batchtype'), + ircmsgs.IrcMsg('@batch=label PRIVMSG #channel :hello'), + ircmsgs.IrcMsg('@batch=label PRIVMSG #channel :there'), + ircmsgs.IrcMsg('BATCH -label'), + ] + + self.irc.queueBatch(copy.deepcopy(msgs)) + for msg in msgs: + self.assertEqual(msg, self.irc.takeMsg()) + + def testQueueBatchStartMinus(self): + msgs = [ + ircmsgs.IrcMsg('BATCH -label batchtype'), + ircmsgs.IrcMsg('@batch=label PRIVMSG #channel :hello'), + ircmsgs.IrcMsg('BATCH -label'), + ] + + with self.assertRaises(ValueError): + self.irc.queueBatch(msgs) + self.assertIsNone(self.irc.takeMsg()) + + def testQueueBatchEndPlus(self): + msgs = [ + ircmsgs.IrcMsg('BATCH +label batchtype'), + ircmsgs.IrcMsg('@batch=label PRIVMSG #channel :hello'), + ircmsgs.IrcMsg('BATCH +label'), + ] + + with self.assertRaises(ValueError): + self.irc.queueBatch(msgs) + self.assertIsNone(self.irc.takeMsg()) + + def testQueueBatchMismatchStartEnd(self): + msgs = [ + ircmsgs.IrcMsg('BATCH +label batchtype'), + ircmsgs.IrcMsg('@batch=label PRIVMSG #channel :hello'), + ircmsgs.IrcMsg('BATCH -label2'), + ] + + with self.assertRaises(ValueError): + self.irc.queueBatch(msgs) + self.assertIsNone(self.irc.takeMsg()) + + def testQueueBatchMismatchInner(self): + msgs = [ + ircmsgs.IrcMsg('BATCH +label batchtype'), + ircmsgs.IrcMsg('@batch=label2 PRIVMSG #channel :hello'), + ircmsgs.IrcMsg('BATCH -label'), + ] + + with self.assertRaises(ValueError): + self.irc.queueBatch(msgs) + self.assertIsNone(self.irc.takeMsg()) + + def testQueueBatchTwice(self): + """Basic operation of queueBatch""" + all_msgs = [] + for label in ('label1', 'label2'): + msgs = [ + ircmsgs.IrcMsg('BATCH +%s batchtype' % label), + ircmsgs.IrcMsg('@batch=%s PRIVMSG #channel :hello' % label), + ircmsgs.IrcMsg('@batch=%s PRIVMSG #channel :there' % label), + ircmsgs.IrcMsg('BATCH -%s' % label), + ] + all_msgs.extend(msgs) + self.irc.queueBatch(copy.deepcopy(msgs)) + + for msg in all_msgs: + self.assertEqual(msg, self.irc.takeMsg()) + + def testQueueBatchDuplicate(self): + msgs = [ + ircmsgs.IrcMsg('BATCH +label batchtype'), + ircmsgs.IrcMsg('@batch=label PRIVMSG #channel :hello'), + ircmsgs.IrcMsg('BATCH -label'), + ] + + self.irc.queueBatch(copy.deepcopy(msgs)) + + with self.assertRaises(ValueError): + self.irc.queueBatch(copy.deepcopy(msgs)) + + for msg in msgs: + self.assertEqual(msg, self.irc.takeMsg()) + self.assertIsNone(self.irc.takeMsg()) + + def testQueueBatchReuse(self): + """We can reuse the same label after the batch is closed.""" + msgs = [ + ircmsgs.IrcMsg('BATCH +label batchtype'), + ircmsgs.IrcMsg('@batch=label PRIVMSG #channel :hello'), + ircmsgs.IrcMsg('BATCH -label'), + ] + + self.irc.queueBatch(copy.deepcopy(msgs)) + for msg in msgs: + self.assertEqual(msg, self.irc.takeMsg()) + + self.irc.queueBatch(copy.deepcopy(msgs)) + for msg in msgs: + self.assertEqual(msg, self.irc.takeMsg()) + + def testBatchInterleaved(self): + """Make sure it's not possible for an unrelated message to be sent + while a batch is open""" + msgs = [ + ircmsgs.IrcMsg('BATCH +label batchtype'), + ircmsgs.IrcMsg('@batch=label PRIVMSG #channel :hello'), + ircmsgs.IrcMsg('BATCH -label'), + ] + msg = ircmsgs.IrcMsg('PRIVMSG #channel :unrelated message') + + with self.subTest('sendMsg called before "BATCH +" is dequeued'): + self.irc.queueBatch(copy.deepcopy(msgs)) + self.irc.sendMsg(msg) + + self.assertEqual(msg, self.irc.takeMsg()) + self.assertEqual(msgs[0], self.irc.takeMsg()) + self.assertEqual(msgs[1], self.irc.takeMsg()) + self.assertEqual(msgs[2], self.irc.takeMsg()) + self.assertIsNone(self.irc.takeMsg()) + + with self.subTest('sendMsg called after "BATCH +" is dequeued'): + self.irc.queueBatch(copy.deepcopy(msgs)) + self.assertEqual(msgs[0], self.irc.takeMsg()) + + self.irc.sendMsg(msg) + + self.assertEqual(msgs[1], self.irc.takeMsg()) + self.assertEqual(msgs[2], self.irc.takeMsg()) + self.assertEqual(msg, self.irc.takeMsg()) + self.assertIsNone(self.irc.takeMsg()) + + class SaslTestCase(SupyTestCase, CapNegMixin): def setUp(self): pass