From 46cc621df1449c2f7fd491d269aee198b131ba52 Mon Sep 17 00:00:00 2001 From: James Lu Date: Wed, 14 Jul 2021 21:56:48 -0700 Subject: [PATCH] More concise UID generators --- .pylintrc | 2 +- protocols/ircs2s_common.py | 53 +++++++------------- protocols/p10.py | 13 ++--- protocols/ts6_common.py | 10 ++-- test/test_protocol_p10.py | 86 ++++++++++++++++++++++++++++++++ test/test_protocol_ts6_common.py | 70 ++++++++++++++++++++++++++ 6 files changed, 187 insertions(+), 47 deletions(-) create mode 100644 test/test_protocol_p10.py create mode 100644 test/test_protocol_ts6_common.py diff --git a/.pylintrc b/.pylintrc index c304fe0..0f7017e 100644 --- a/.pylintrc +++ b/.pylintrc @@ -1,4 +1,4 @@ [FORMAT] max-line-length=120 -good-names=ip +good-names=ip,f,i diff --git a/protocols/ircs2s_common.py b/protocols/ircs2s_common.py index 8cd4593..8371ad7 100644 --- a/protocols/ircs2s_common.py +++ b/protocols/ircs2s_common.py @@ -9,49 +9,34 @@ from pylinkirc import conf from pylinkirc.classes import IRCNetwork, ProtocolError from pylinkirc.log import log -__all__ = ['IncrementalUIDGenerator', 'IRCCommonProtocol', 'IRCS2SProtocol'] +__all__ = ['UIDGenerator', 'IRCCommonProtocol', 'IRCS2SProtocol'] - -class IncrementalUIDGenerator(): +class UIDGenerator(): """ - Incremental UID Generator module, adapted from InspIRCd source: - https://github.com/inspircd/inspircd/blob/f449c6b296ab/src/server.cpp#L85-L156 + Generate UIDs for IRC S2S. """ - def __init__(self, sid): - if not (hasattr(self, 'allowedchars') and hasattr(self, 'length')): - raise RuntimeError("Allowed characters list not defined. Subclass " - "%s by defining self.allowedchars and self.length " - "and then calling super().__init__()." % self.__class__.__name__) - self.uidchars = [self.allowedchars[0]]*self.length - self.sid = str(sid) - - def increment(self, pos=None): - """ - Increments the UID generator to the next available UID. - """ - # Position starts at 1 less than the UID length. - if pos is None: - pos = self.length - 1 - - # If we're at the last character in the list of allowed ones, reset - # and increment the next level above. - if self.uidchars[pos] == self.allowedchars[-1]: - self.uidchars[pos] = self.allowedchars[0] - self.increment(pos-1) - else: - # Find what position in the allowed characters list we're currently - # on, and add one. - idx = self.allowedchars.find(self.uidchars[pos]) - self.uidchars[pos] = self.allowedchars[idx+1] + def __init__(self, uidchars, length, sid): + self.uidchars = uidchars # corpus of characters to choose from + self.length = length # desired length of uid part, padded with uidchars[0] + self.sid = str(sid) # server id (prefixed to every result) + self.counter = 0 def next_uid(self): """ Returns the next unused UID for the server. """ - uid = self.sid + ''.join(self.uidchars) - self.increment() - return uid + uid = '' + num = self.counter + if num >= (len(self.uidchars) ** self.length): + raise RuntimeError("UID overflowed") + while num > 0: + num, index = divmod(num, len(self.uidchars)) + uid = self.uidchars[index] + uid + + self.counter += 1 + uid = uid.rjust(self.length, self.uidchars[0]) + return self.sid + uid class IRCCommonProtocol(IRCNetwork): diff --git a/protocols/p10.py b/protocols/p10.py index 5334dd9..441a6b7 100644 --- a/protocols/p10.py +++ b/protocols/p10.py @@ -4,6 +4,7 @@ p10.py: P10 protocol module for PyLink, supporting Nefarious IRCu and others. import base64 import socket +import string import struct import time from ipaddress import ip_address @@ -16,13 +17,13 @@ from pylinkirc.protocols.ircs2s_common import * __all__ = ['P10Protocol'] -class P10UIDGenerator(IncrementalUIDGenerator): - """Implements an incremental P10 UID Generator.""" +class P10UIDGenerator(UIDGenerator): + """Implements a P10 UID Generator.""" - def __init__(self, sid): - self.allowedchars = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789[]' - self.length = 3 - super().__init__(sid) + def __init__(self, sid): + uidchars = string.ascii_uppercase + string.ascii_lowercase + string.digits + '[]' + length = 3 + super().__init__(uidchars, length, sid) def p10b64encode(num, length=2): """ diff --git a/protocols/ts6_common.py b/protocols/ts6_common.py index 273f9c0..db55397 100644 --- a/protocols/ts6_common.py +++ b/protocols/ts6_common.py @@ -87,18 +87,16 @@ class TS6SIDGenerator(): sid = ''.join(self.output) return sid -class TS6UIDGenerator(IncrementalUIDGenerator): +class TS6UIDGenerator(UIDGenerator): """Implements an incremental TS6 UID Generator.""" def __init__(self, sid): - # Define the options for IncrementalUIDGenerator, and then - # initialize its functions. # TS6 UIDs are 6 characters in length (9 including the SID). # They go from ABCDEFGHIJKLMNOPQRSTUVWXYZ -> 0123456789 -> wrap around: # e.g. AAAAAA, AAAAAB ..., AAAAA8, AAAAA9, AAAABA, etc. - self.allowedchars = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456879' - self.length = 6 - super().__init__(sid) + uidchars = string.ascii_uppercase + string.digits + length = 6 + super().__init__(uidchars, length, sid) class TS6BaseProtocol(IRCS2SProtocol): def __init__(self, *args, **kwargs): diff --git a/test/test_protocol_p10.py b/test/test_protocol_p10.py new file mode 100644 index 0000000..e1d90ad --- /dev/null +++ b/test/test_protocol_p10.py @@ -0,0 +1,86 @@ +""" +Tests for protocols/p10 +""" + +import unittest + +from pylinkirc.protocols import p10 + +class P10UIDGeneratorTest(unittest.TestCase): + def setUp(self): + self.uidgen = p10.P10UIDGenerator('HI') + + def test_initial_UID(self): + expected = [ + "HIAAA", + "HIAAB", + "HIAAC", + "HIAAD", + "HIAAE", + "HIAAF" + ] + self.uidgen.counter = 0 + actual = [self.uidgen.next_uid() for i in range(6)] + self.assertEqual(expected, actual) + + def test_rollover_first_lowercase(self): + expected = [ + "HIAAY", + "HIAAZ", + "HIAAa", + "HIAAb", + "HIAAc", + "HIAAd", + ] + self.uidgen.counter = 24 + actual = [self.uidgen.next_uid() for i in range(6)] + self.assertEqual(expected, actual) + + def test_rollover_first_num(self): + expected = [ + "HIAAz", + "HIAA0", + "HIAA1", + "HIAA2", + "HIAA3", + "HIAA4", + ] + self.uidgen.counter = 26*2-1 + actual = [self.uidgen.next_uid() for i in range(6)] + self.assertEqual(expected, actual) + + def test_rollover_second(self): + expected = [ + "HIAA8", + "HIAA9", + "HIAA[", + "HIAA]", + "HIABA", + "HIABB", + "HIABC", + "HIABD", + ] + self.uidgen.counter = 26*2+10-2 + actual = [self.uidgen.next_uid() for i in range(8)] + self.assertEqual(expected, actual) + + def test_rollover_third(self): + expected = [ + "HIE]9", + "HIE][", + "HIE]]", + "HIFAA", + "HIFAB", + "HIFAC", + ] + self.uidgen.counter = 5*64**2 - 3 + actual = [self.uidgen.next_uid() for i in range(6)] + self.assertEqual(expected, actual) + + def test_overflow(self): + self.uidgen.counter = 64**3-1 + self.assertTrue(self.uidgen.next_uid()) + self.assertRaises(RuntimeError, self.uidgen.next_uid) + +if __name__ == '__main__': + unittest.main() diff --git a/test/test_protocol_ts6_common.py b/test/test_protocol_ts6_common.py new file mode 100644 index 0000000..7fbfa93 --- /dev/null +++ b/test/test_protocol_ts6_common.py @@ -0,0 +1,70 @@ +""" +Tests for protocols/ts6_common +""" + +import unittest + +from pylinkirc.protocols import ts6_common + +class TS6UIDGeneratorTest(unittest.TestCase): + def setUp(self): + self.uidgen = ts6_common.TS6UIDGenerator('123') + + def test_initial_UID(self): + expected = [ + "123AAAAAA", + "123AAAAAB", + "123AAAAAC", + "123AAAAAD", + "123AAAAAE", + "123AAAAAF", + ] + self.uidgen.counter = 0 + actual = [self.uidgen.next_uid() for i in range(6)] + self.assertEqual(expected, actual) + + def test_rollover_first_num(self): + expected = [ + "123AAAAAY", + "123AAAAAZ", + "123AAAAA0", + "123AAAAA1", + "123AAAAA2", + "123AAAAA3", + ] + self.uidgen.counter = 24 + actual = [self.uidgen.next_uid() for i in range(6)] + self.assertEqual(expected, actual) + + def test_rollover_second(self): + expected = [ + "123AAAAA8", + "123AAAAA9", + "123AAAABA", + "123AAAABB", + "123AAAABC", + "123AAAABD", + ] + self.uidgen.counter = 36 - 2 + actual = [self.uidgen.next_uid() for i in range(6)] + self.assertEqual(expected, actual) + + def test_rollover_third(self): + expected = [ + "123AAAE98", + "123AAAE99", + "123AAAFAA", + "123AAAFAB", + "123AAAFAC", + ] + self.uidgen.counter = 5*36**2 - 2 + actual = [self.uidgen.next_uid() for i in range(5)] + self.assertEqual(expected, actual) + + def test_overflow(self): + self.uidgen.counter = 36**6-1 + self.assertTrue(self.uidgen.next_uid()) + self.assertRaises(RuntimeError, self.uidgen.next_uid) + +if __name__ == '__main__': + unittest.main()