Split 'CAP REQ' commands to not exceed 512 bytes.

This commit is contained in:
Valentin Lorentz 2019-10-25 23:07:31 +02:00
parent 0014b206ad
commit 9268356e97
2 changed files with 51 additions and 8 deletions

View File

@ -32,6 +32,7 @@ import copy
import time import time
import random import random
import base64 import base64
import textwrap
import collections import collections
try: try:
@ -1300,19 +1301,24 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
self._addCapabilities(msg.args[3]) self._addCapabilities(msg.args[3])
elif len(msg.args) == 3: # End of LS elif len(msg.args) == 3: # End of LS
self._addCapabilities(msg.args[2]) self._addCapabilities(msg.args[2])
common_supported_capabilities = set(self.state.capabilities_ls) & \
self.REQUEST_CAPABILITIES
if 'sasl' in self.state.capabilities_ls: if 'sasl' in self.state.capabilities_ls:
s = self.state.capabilities_ls['sasl'] s = self.state.capabilities_ls['sasl']
if s is not None: if s is not None:
self.filterSaslMechanisms(set(s.split(','))) self.filterSaslMechanisms(set(s.split(',')))
# Normally at this point, self.state.capabilities_ack should be
# empty; but let's just make sure we're not requesting the same
# caps twice for no reason.
new_caps = (
set(self.state.capabilities_ls) &
self.REQUEST_CAPABILITIES -
self.state.capabilities_ack)
# NOTE: Capabilities are requested in alphabetic order, because # NOTE: Capabilities are requested in alphabetic order, because
# sets are unordered, and their "order" is nondeterministic. # sets are unordered, and their "order" is nondeterministic.
# This is needed for the tests. # This is needed for the tests.
if common_supported_capabilities: if new_caps:
caps = ' '.join(sorted(common_supported_capabilities)) self._requestCaps(new_caps)
self.sendMsg(ircmsgs.IrcMsg(command='CAP',
args=('REQ', caps)))
else: else:
self.endCapabilityNegociation() self.endCapabilityNegociation()
else: else:
@ -1353,9 +1359,15 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
self.REQUEST_CAPABILITIES - self.REQUEST_CAPABILITIES -
self.state.capabilities_ack) self.state.capabilities_ack)
if common_supported_unrequested_capabilities: if common_supported_unrequested_capabilities:
caps = ' '.join(sorted(common_supported_unrequested_capabilities)) self._requestCaps(common_supported_unrequested_capabilities)
def _requestCaps(self, caps):
caps = ' '.join(sorted(caps))
# textwrap works here because in ASCII, all chars are 1 bytes:
cap_lines = textwrap.wrap(caps, MAX_LINE_SIZE-len('CAP REQ :'))
for cap_line in cap_lines:
self.sendMsg(ircmsgs.IrcMsg(command='CAP', self.sendMsg(ircmsgs.IrcMsg(command='CAP',
args=('REQ', caps))) args=('REQ', cap_line)))
def monitor(self, targets): def monitor(self, targets):
"""Increment a counter of how many callbacks monitor each target; """Increment a counter of how many callbacks monitor each target;

View File

@ -379,6 +379,37 @@ class IrcStateTestCase(SupyTestCase):
st = irclib.IrcState() st = irclib.IrcState()
self.assert_(st.addMsg(self.irc, ircmsgs.IrcMsg('MODE foo +i')) or 1) self.assert_(st.addMsg(self.irc, ircmsgs.IrcMsg('MODE foo +i')) or 1)
class IrcCapsTestCase(SupyTestCase):
def testReqLineLength(self):
self.irc = irclib.Irc('test')
m = self.irc.takeMsg()
self.failUnless(m.command == 'CAP', 'Expected CAP, got %r.' % m)
self.failUnless(m.args == ('LS', '302'), 'Expected CAP LS 302, got %r.' % m)
m = self.irc.takeMsg()
self.failUnless(m.command == 'NICK', 'Expected NICK, got %r.' % m)
m = self.irc.takeMsg()
self.failUnless(m.command == 'USER', 'Expected USER, got %r.' % m)
self.irc.REQUEST_CAPABILITIES = set(['a'*400, 'b'*400])
caps = ' '.join(self.irc.REQUEST_CAPABILITIES)
self.irc.feedMsg(ircmsgs.IrcMsg(command='CAP',
args=('*', 'LS', '*', 'a'*400)))
self.irc.feedMsg(ircmsgs.IrcMsg(command='CAP',
args=('*', 'LS', 'b'*400)))
m = self.irc.takeMsg()
self.failUnless(m.command == 'CAP', 'Expected CAP, got %r.' % m)
self.assertEqual(m.args[0], 'REQ', m)
self.assertEqual(m.args[1], 'a'*400)
m = self.irc.takeMsg()
self.failUnless(m.command == 'CAP', 'Expected CAP, got %r.' % m)
self.assertEqual(m.args[0], 'REQ', m)
self.assertEqual(m.args[1], 'b'*400)
class IrcTestCase(SupyTestCase): class IrcTestCase(SupyTestCase):
def setUp(self): def setUp(self):
self.irc = irclib.Irc('test') self.irc = irclib.Irc('test')