Add proper support for nested batches.

This commit is contained in:
Valentin Lorentz 2021-03-03 23:32:00 +01:00
parent 9719bb799e
commit 6f6dad8f7b
2 changed files with 145 additions and 9 deletions

View File

@ -476,9 +476,12 @@ class ChannelState(utils.python.Object):
return ret return ret
Batch = collections.namedtuple('Batch', 'type arguments messages') Batch = collections.namedtuple('Batch', 'name type arguments messages parent_batch')
"""Represents a batch of messages, see """Represents a batch of messages, see
<https://ircv3.net/specs/extensions/batch-3.2>""" <https://ircv3.net/specs/extensions/batch-3.2>
Only access attributes by their name and do not create Batch objects
in plugins; so we can extend the structure without breaking plugins."""
class IrcStateFsm(object): class IrcStateFsm(object):
@ -752,10 +755,11 @@ class IrcState(IrcCommandDispatcher, log.Firewalled):
if ircutils.isUserHostmask(msg.prefix) and not msg.command == 'NICK': if ircutils.isUserHostmask(msg.prefix) and not msg.command == 'NICK':
self.nicksToHostmasks[msg.nick] = msg.prefix self.nicksToHostmasks[msg.nick] = msg.prefix
if 'batch' in msg.server_tags: if 'batch' in msg.server_tags:
batch = msg.server_tags['batch'] batch_name = msg.server_tags['batch']
assert batch in self.batches, \ assert batch_name in self.batches, \
'Server references undeclared batch %s' % batch 'Server references undeclared batch %r' % batch_name
self.batches[batch].messages.append(msg) for batch in self.getParentBatches(msg):
batch.messages.append(msg)
method = self.dispatchCommand(msg.command, msg.args) method = self.dispatchCommand(msg.command, msg.args)
if method is not None: if method is not None:
method(irc, msg) method(irc, msg)
@ -768,6 +772,55 @@ class IrcState(IrcCommandDispatcher, log.Firewalled):
"""Returns the hostmask for a given nick.""" """Returns the hostmask for a given nick."""
return self.nicksToHostmasks[nick] return self.nicksToHostmasks[nick]
def getParentBatches(self, msg):
"""Given an IrcMsg, returns a list of all batches that contain it,
innermost first.
Raises ValueError if ``msg`` is not in a batch;
or if it is in a batch that has already ended.
This restriction may be relaxed in the future.
This means that you should not call ``getParentBatches``
on a message that was already processed.
For example, assume Limnoria received the following::
:irc.host BATCH +outer example.com/foo
@batch=outer :irc.host BATCH +inner example.com/bar
@batch=inner :nick!user@host PRIVMSG #channel :Hi
@batch=outer :irc.host BATCH -inner
:irc.host BATCH -outer
If you call getParentBatches on any of the middle three messages,
you get ``[Batch(name='inner', ...), Batch(name='outer', ...)]``.
And if you call getParentBatches on either the first or the last
message, you get ``[Batch(name='outer', ...)]``
And you may only call `getParentBatches`` on the PRIVMSG
if only the first three messages were processed.
"""
batch = msg.tagged('batch')
if not batch:
# msg is not a BATCH command
batch_name = msg.server_tags.get('batch')
if batch_name:
batch = self.batches.get(batch_name)
if not batch:
raise ValueError(
'Called getParentBatches for a message in a batch that '
'already ended.'
)
else:
raise ValueError(
'Called getParentBatches for a message not in a batch.')
batches = []
while batch:
batches.append(batch)
batch = batch.parent_batch
return batches
def do004(self, irc, msg): def do004(self, irc, msg):
"""Handles parsing the 004 reply """Handles parsing the 004 reply
@ -1017,8 +1070,20 @@ class IrcState(IrcCommandDispatcher, log.Firewalled):
if msg.args[0].startswith('+'): if msg.args[0].startswith('+'):
batch_type = msg.args[1] batch_type = msg.args[1]
batch_arguments = tuple(msg.args[2:]) batch_arguments = tuple(msg.args[2:])
self.batches[batch_name] = Batch(type=batch_type,
arguments=batch_arguments, messages=[msg]) # Both are possibly None:
parent_batch_name = msg.server_tags.get("batch")
parent_batch = self.batches.get(parent_batch_name)
batch = Batch(
name=batch_name,
type=batch_type,
arguments=batch_arguments,
messages=[msg],
parent_batch=parent_batch
)
msg.tag('batch', batch)
self.batches[batch_name] = batch
elif msg.args[0].startswith('-'): elif msg.args[0].startswith('-'):
batch = self.batches.pop(batch_name) batch = self.batches.pop(batch_name)
batch.messages.append(msg) batch.messages.append(msg)

View File

@ -29,6 +29,7 @@
import copy import copy
import pickle import pickle
import textwrap
import unittest.mock import unittest.mock
from supybot.test import * from supybot.test import *
@ -1012,7 +1013,77 @@ class IrcTestCase(SupyTestCase):
self.irc.feedMsg(m4) self.irc.feedMsg(m4)
finally: finally:
self.irc.removeCallback(c.name()) self.irc.removeCallback(c.name())
self.assertEqual(c.batch, irclib.Batch('netjoin', (), [m1, m2, m3, m4])) self.assertEqual(
c.batch,
irclib.Batch('name', 'netjoin', (), [m1, m2, m3, m4], None),
repr(c.batch)
)
maxDiff = None
def testBatchNested(self):
self.irc.reset()
logs = textwrap.dedent('''
:irc.host BATCH +outer example.com/foo
@batch=outer :irc.host BATCH +inner example.com/bar
@batch=inner :nick!user@host PRIVMSG #channel :Hi
@batch=outer :irc.host BATCH -inner
:irc.host BATCH -outer
''')
msgs = [ircmsgs.IrcMsg(s) for s in logs.split('\n') if s]
# Feed 'BATCH +outer', it should be added in state
self.irc.feedMsg(msgs[0])
outer = irclib.Batch('outer', 'example.com/foo', (), msgs[0:1], None)
self.assertEqual(self.irc.state.batches, {'outer': outer})
msg1 = self.irc.state.history[-1]
self.assertEqual(msg1.tagged('batch'), outer)
self.assertEqual(self.irc.state.getParentBatches(msg1), [outer])
# Feed 'BATCH +inner', it should be added in state
self.irc.feedMsg(msgs[1])
outer = irclib.Batch('outer', 'example.com/foo', (), msgs[0:2], None)
inner = irclib.Batch('inner', 'example.com/bar', (), msgs[1:2], outer)
self.assertIs(self.irc.state.batches['inner'].parent_batch,
self.irc.state.batches['outer'])
self.assertEqual(dict(self.irc.state.batches),
{'outer': outer, 'inner': inner})
msg2 = self.irc.state.history[-1]
self.assertEqual(msg2.tagged('batch'), inner)
self.assertEqual(self.irc.state.getParentBatches(msg2), [inner, outer])
# Feed 'PRIVMSG'
self.irc.feedMsg(msgs[2])
outer = irclib.Batch('outer', 'example.com/foo', (), msgs[0:3], None)
inner = irclib.Batch('inner', 'example.com/bar', (), msgs[1:3], outer)
self.assertIs(self.irc.state.batches['inner'].parent_batch,
self.irc.state.batches['outer'])
self.assertEqual(self.irc.state.batches,
{'outer': outer, 'inner': inner})
msg3 = self.irc.state.history[-1]
self.assertEqual(msg3.tagged('batch'), None)
self.assertEqual(self.irc.state.getParentBatches(msg3), [inner, outer])
# Feed 'BATCH -inner', it should be remove from state
self.irc.feedMsg(msgs[3])
outer = irclib.Batch('outer', 'example.com/foo', (), msgs[0:4], None)
inner = irclib.Batch('inner', 'example.com/bar', (), msgs[1:4], outer)
self.assertEqual(self.irc.state.batches, {'outer': outer})
self.assertIs(self.irc.state.history[-1].tagged('batch').parent_batch,
self.irc.state.batches['outer'])
msg4 = self.irc.state.history[-1]
self.assertEqual(msg4.tagged('batch'), inner)
self.assertEqual(self.irc.state.getParentBatches(msg4), [inner, outer])
# Feed 'BATCH -outer', it should be remove from state
self.irc.feedMsg(msgs[4])
outer = irclib.Batch('outer', 'example.com/foo', (), msgs[0:5], None)
inner = irclib.Batch('inner', 'example.com/bar', (), msgs[1:4], outer)
self.assertEqual(self.irc.state.batches, {})
msg5 = self.irc.state.history[-1]
self.assertEqual(msg5.tagged('batch'), outer)
self.assertEqual(self.irc.state.getParentBatches(msg5), [outer])
class SaslTestCase(SupyTestCase, CapNegMixin): class SaslTestCase(SupyTestCase, CapNegMixin):
def setUp(self): def setUp(self):