diff --git a/src/irclib.py b/src/irclib.py index 01a2c7b79..6ea0b4aee 100644 --- a/src/irclib.py +++ b/src/irclib.py @@ -476,9 +476,12 @@ class ChannelState(utils.python.Object): return ret -Batch = collections.namedtuple('Batch', 'type arguments messages') +Batch = collections.namedtuple('Batch', 'name type arguments messages parent_batch') """Represents a batch of messages, see -""" + + +Only access attributes by their name and do not create Batch objects +in plugins; so we can extend the structure without breaking plugins.""" class IrcStateFsm(object): @@ -752,10 +755,11 @@ class IrcState(IrcCommandDispatcher, log.Firewalled): if ircutils.isUserHostmask(msg.prefix) and not msg.command == 'NICK': self.nicksToHostmasks[msg.nick] = msg.prefix if 'batch' in msg.server_tags: - batch = msg.server_tags['batch'] - assert batch in self.batches, \ - 'Server references undeclared batch %s' % batch - self.batches[batch].messages.append(msg) + batch_name = msg.server_tags['batch'] + assert batch_name in self.batches, \ + 'Server references undeclared batch %r' % batch_name + for batch in self.getParentBatches(msg): + batch.messages.append(msg) method = self.dispatchCommand(msg.command, msg.args) if method is not None: method(irc, msg) @@ -768,6 +772,55 @@ class IrcState(IrcCommandDispatcher, log.Firewalled): """Returns the hostmask for a given nick.""" return self.nicksToHostmasks[nick] + def getParentBatches(self, msg): + """Given an IrcMsg, returns a list of all batches that contain it, + innermost first. + + Raises ValueError if ``msg`` is not in a batch; + or if it is in a batch that has already ended. + This restriction may be relaxed in the future. + + This means that you should not call ``getParentBatches`` + on a message that was already processed. + + For example, assume Limnoria received the following:: + + :irc.host BATCH +outer example.com/foo + @batch=outer :irc.host BATCH +inner example.com/bar + @batch=inner :nick!user@host PRIVMSG #channel :Hi + @batch=outer :irc.host BATCH -inner + :irc.host BATCH -outer + + If you call getParentBatches on any of the middle three messages, + you get ``[Batch(name='inner', ...), Batch(name='outer', ...)]``. + And if you call getParentBatches on either the first or the last + message, you get ``[Batch(name='outer', ...)]`` + + And you may only call `getParentBatches`` on the PRIVMSG + if only the first three messages were processed. + """ + batch = msg.tagged('batch') + if not batch: + # msg is not a BATCH command + batch_name = msg.server_tags.get('batch') + if batch_name: + batch = self.batches.get(batch_name) + if not batch: + raise ValueError( + 'Called getParentBatches for a message in a batch that ' + 'already ended.' + ) + else: + raise ValueError( + 'Called getParentBatches for a message not in a batch.') + + batches = [] + while batch: + batches.append(batch) + batch = batch.parent_batch + + return batches + def do004(self, irc, msg): """Handles parsing the 004 reply @@ -1017,8 +1070,20 @@ class IrcState(IrcCommandDispatcher, log.Firewalled): if msg.args[0].startswith('+'): batch_type = msg.args[1] batch_arguments = tuple(msg.args[2:]) - self.batches[batch_name] = Batch(type=batch_type, - arguments=batch_arguments, messages=[msg]) + + # Both are possibly None: + parent_batch_name = msg.server_tags.get("batch") + parent_batch = self.batches.get(parent_batch_name) + + batch = Batch( + name=batch_name, + type=batch_type, + arguments=batch_arguments, + messages=[msg], + parent_batch=parent_batch + ) + msg.tag('batch', batch) + self.batches[batch_name] = batch elif msg.args[0].startswith('-'): batch = self.batches.pop(batch_name) batch.messages.append(msg) diff --git a/test/test_irclib.py b/test/test_irclib.py index 3f822493d..0e7851d6c 100644 --- a/test/test_irclib.py +++ b/test/test_irclib.py @@ -29,6 +29,7 @@ import copy import pickle +import textwrap import unittest.mock from supybot.test import * @@ -1012,7 +1013,77 @@ class IrcTestCase(SupyTestCase): self.irc.feedMsg(m4) finally: self.irc.removeCallback(c.name()) - self.assertEqual(c.batch, irclib.Batch('netjoin', (), [m1, m2, m3, m4])) + self.assertEqual( + c.batch, + irclib.Batch('name', 'netjoin', (), [m1, m2, m3, m4], None), + repr(c.batch) + ) + + maxDiff = None + + def testBatchNested(self): + self.irc.reset() + logs = textwrap.dedent(''' + :irc.host BATCH +outer example.com/foo + @batch=outer :irc.host BATCH +inner example.com/bar + @batch=inner :nick!user@host PRIVMSG #channel :Hi + @batch=outer :irc.host BATCH -inner + :irc.host BATCH -outer + ''') + msgs = [ircmsgs.IrcMsg(s) for s in logs.split('\n') if s] + + # Feed 'BATCH +outer', it should be added in state + self.irc.feedMsg(msgs[0]) + outer = irclib.Batch('outer', 'example.com/foo', (), msgs[0:1], None) + self.assertEqual(self.irc.state.batches, {'outer': outer}) + msg1 = self.irc.state.history[-1] + self.assertEqual(msg1.tagged('batch'), outer) + self.assertEqual(self.irc.state.getParentBatches(msg1), [outer]) + + # Feed 'BATCH +inner', it should be added in state + self.irc.feedMsg(msgs[1]) + outer = irclib.Batch('outer', 'example.com/foo', (), msgs[0:2], None) + inner = irclib.Batch('inner', 'example.com/bar', (), msgs[1:2], outer) + self.assertIs(self.irc.state.batches['inner'].parent_batch, + self.irc.state.batches['outer']) + self.assertEqual(dict(self.irc.state.batches), + {'outer': outer, 'inner': inner}) + msg2 = self.irc.state.history[-1] + self.assertEqual(msg2.tagged('batch'), inner) + self.assertEqual(self.irc.state.getParentBatches(msg2), [inner, outer]) + + # Feed 'PRIVMSG' + self.irc.feedMsg(msgs[2]) + outer = irclib.Batch('outer', 'example.com/foo', (), msgs[0:3], None) + inner = irclib.Batch('inner', 'example.com/bar', (), msgs[1:3], outer) + self.assertIs(self.irc.state.batches['inner'].parent_batch, + self.irc.state.batches['outer']) + self.assertEqual(self.irc.state.batches, + {'outer': outer, 'inner': inner}) + msg3 = self.irc.state.history[-1] + self.assertEqual(msg3.tagged('batch'), None) + self.assertEqual(self.irc.state.getParentBatches(msg3), [inner, outer]) + + # Feed 'BATCH -inner', it should be remove from state + self.irc.feedMsg(msgs[3]) + outer = irclib.Batch('outer', 'example.com/foo', (), msgs[0:4], None) + inner = irclib.Batch('inner', 'example.com/bar', (), msgs[1:4], outer) + self.assertEqual(self.irc.state.batches, {'outer': outer}) + self.assertIs(self.irc.state.history[-1].tagged('batch').parent_batch, + self.irc.state.batches['outer']) + msg4 = self.irc.state.history[-1] + self.assertEqual(msg4.tagged('batch'), inner) + self.assertEqual(self.irc.state.getParentBatches(msg4), [inner, outer]) + + # Feed 'BATCH -outer', it should be remove from state + self.irc.feedMsg(msgs[4]) + outer = irclib.Batch('outer', 'example.com/foo', (), msgs[0:5], None) + inner = irclib.Batch('inner', 'example.com/bar', (), msgs[1:4], outer) + self.assertEqual(self.irc.state.batches, {}) + msg5 = self.irc.state.history[-1] + self.assertEqual(msg5.tagged('batch'), outer) + self.assertEqual(self.irc.state.getParentBatches(msg5), [outer]) + class SaslTestCase(SupyTestCase, CapNegMixin): def setUp(self):