More concise UID generators

This commit is contained in:
James Lu 2021-07-14 21:56:48 -07:00
parent bc3a7abe02
commit 46cc621df1
6 changed files with 187 additions and 47 deletions

View File

@ -1,4 +1,4 @@
[FORMAT]
max-line-length=120
good-names=ip
good-names=ip,f,i

View File

@ -9,49 +9,34 @@ from pylinkirc import conf
from pylinkirc.classes import IRCNetwork, ProtocolError
from pylinkirc.log import log
__all__ = ['IncrementalUIDGenerator', 'IRCCommonProtocol', 'IRCS2SProtocol']
__all__ = ['UIDGenerator', 'IRCCommonProtocol', 'IRCS2SProtocol']
class IncrementalUIDGenerator():
class UIDGenerator():
"""
Incremental UID Generator module, adapted from InspIRCd source:
https://github.com/inspircd/inspircd/blob/f449c6b296ab/src/server.cpp#L85-L156
Generate UIDs for IRC S2S.
"""
def __init__(self, sid):
if not (hasattr(self, 'allowedchars') and hasattr(self, 'length')):
raise RuntimeError("Allowed characters list not defined. Subclass "
"%s by defining self.allowedchars and self.length "
"and then calling super().__init__()." % self.__class__.__name__)
self.uidchars = [self.allowedchars[0]]*self.length
self.sid = str(sid)
def increment(self, pos=None):
"""
Increments the UID generator to the next available UID.
"""
# Position starts at 1 less than the UID length.
if pos is None:
pos = self.length - 1
# If we're at the last character in the list of allowed ones, reset
# and increment the next level above.
if self.uidchars[pos] == self.allowedchars[-1]:
self.uidchars[pos] = self.allowedchars[0]
self.increment(pos-1)
else:
# Find what position in the allowed characters list we're currently
# on, and add one.
idx = self.allowedchars.find(self.uidchars[pos])
self.uidchars[pos] = self.allowedchars[idx+1]
def __init__(self, uidchars, length, sid):
self.uidchars = uidchars # corpus of characters to choose from
self.length = length # desired length of uid part, padded with uidchars[0]
self.sid = str(sid) # server id (prefixed to every result)
self.counter = 0
def next_uid(self):
"""
Returns the next unused UID for the server.
"""
uid = self.sid + ''.join(self.uidchars)
self.increment()
return uid
uid = ''
num = self.counter
if num >= (len(self.uidchars) ** self.length):
raise RuntimeError("UID overflowed")
while num > 0:
num, index = divmod(num, len(self.uidchars))
uid = self.uidchars[index] + uid
self.counter += 1
uid = uid.rjust(self.length, self.uidchars[0])
return self.sid + uid
class IRCCommonProtocol(IRCNetwork):

View File

@ -4,6 +4,7 @@ p10.py: P10 protocol module for PyLink, supporting Nefarious IRCu and others.
import base64
import socket
import string
import struct
import time
from ipaddress import ip_address
@ -16,13 +17,13 @@ from pylinkirc.protocols.ircs2s_common import *
__all__ = ['P10Protocol']
class P10UIDGenerator(IncrementalUIDGenerator):
"""Implements an incremental P10 UID Generator."""
class P10UIDGenerator(UIDGenerator):
"""Implements a P10 UID Generator."""
def __init__(self, sid):
self.allowedchars = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789[]'
self.length = 3
super().__init__(sid)
def __init__(self, sid):
uidchars = string.ascii_uppercase + string.ascii_lowercase + string.digits + '[]'
length = 3
super().__init__(uidchars, length, sid)
def p10b64encode(num, length=2):
"""

View File

@ -87,18 +87,16 @@ class TS6SIDGenerator():
sid = ''.join(self.output)
return sid
class TS6UIDGenerator(IncrementalUIDGenerator):
class TS6UIDGenerator(UIDGenerator):
"""Implements an incremental TS6 UID Generator."""
def __init__(self, sid):
# Define the options for IncrementalUIDGenerator, and then
# initialize its functions.
# TS6 UIDs are 6 characters in length (9 including the SID).
# They go from ABCDEFGHIJKLMNOPQRSTUVWXYZ -> 0123456789 -> wrap around:
# e.g. AAAAAA, AAAAAB ..., AAAAA8, AAAAA9, AAAABA, etc.
self.allowedchars = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456879'
self.length = 6
super().__init__(sid)
uidchars = string.ascii_uppercase + string.digits
length = 6
super().__init__(uidchars, length, sid)
class TS6BaseProtocol(IRCS2SProtocol):
def __init__(self, *args, **kwargs):

86
test/test_protocol_p10.py Normal file
View File

@ -0,0 +1,86 @@
"""
Tests for protocols/p10
"""
import unittest
from pylinkirc.protocols import p10
class P10UIDGeneratorTest(unittest.TestCase):
def setUp(self):
self.uidgen = p10.P10UIDGenerator('HI')
def test_initial_UID(self):
expected = [
"HIAAA",
"HIAAB",
"HIAAC",
"HIAAD",
"HIAAE",
"HIAAF"
]
self.uidgen.counter = 0
actual = [self.uidgen.next_uid() for i in range(6)]
self.assertEqual(expected, actual)
def test_rollover_first_lowercase(self):
expected = [
"HIAAY",
"HIAAZ",
"HIAAa",
"HIAAb",
"HIAAc",
"HIAAd",
]
self.uidgen.counter = 24
actual = [self.uidgen.next_uid() for i in range(6)]
self.assertEqual(expected, actual)
def test_rollover_first_num(self):
expected = [
"HIAAz",
"HIAA0",
"HIAA1",
"HIAA2",
"HIAA3",
"HIAA4",
]
self.uidgen.counter = 26*2-1
actual = [self.uidgen.next_uid() for i in range(6)]
self.assertEqual(expected, actual)
def test_rollover_second(self):
expected = [
"HIAA8",
"HIAA9",
"HIAA[",
"HIAA]",
"HIABA",
"HIABB",
"HIABC",
"HIABD",
]
self.uidgen.counter = 26*2+10-2
actual = [self.uidgen.next_uid() for i in range(8)]
self.assertEqual(expected, actual)
def test_rollover_third(self):
expected = [
"HIE]9",
"HIE][",
"HIE]]",
"HIFAA",
"HIFAB",
"HIFAC",
]
self.uidgen.counter = 5*64**2 - 3
actual = [self.uidgen.next_uid() for i in range(6)]
self.assertEqual(expected, actual)
def test_overflow(self):
self.uidgen.counter = 64**3-1
self.assertTrue(self.uidgen.next_uid())
self.assertRaises(RuntimeError, self.uidgen.next_uid)
if __name__ == '__main__':
unittest.main()

View File

@ -0,0 +1,70 @@
"""
Tests for protocols/ts6_common
"""
import unittest
from pylinkirc.protocols import ts6_common
class TS6UIDGeneratorTest(unittest.TestCase):
def setUp(self):
self.uidgen = ts6_common.TS6UIDGenerator('123')
def test_initial_UID(self):
expected = [
"123AAAAAA",
"123AAAAAB",
"123AAAAAC",
"123AAAAAD",
"123AAAAAE",
"123AAAAAF",
]
self.uidgen.counter = 0
actual = [self.uidgen.next_uid() for i in range(6)]
self.assertEqual(expected, actual)
def test_rollover_first_num(self):
expected = [
"123AAAAAY",
"123AAAAAZ",
"123AAAAA0",
"123AAAAA1",
"123AAAAA2",
"123AAAAA3",
]
self.uidgen.counter = 24
actual = [self.uidgen.next_uid() for i in range(6)]
self.assertEqual(expected, actual)
def test_rollover_second(self):
expected = [
"123AAAAA8",
"123AAAAA9",
"123AAAABA",
"123AAAABB",
"123AAAABC",
"123AAAABD",
]
self.uidgen.counter = 36 - 2
actual = [self.uidgen.next_uid() for i in range(6)]
self.assertEqual(expected, actual)
def test_rollover_third(self):
expected = [
"123AAAE98",
"123AAAE99",
"123AAAFAA",
"123AAAFAB",
"123AAAFAC",
]
self.uidgen.counter = 5*36**2 - 2
actual = [self.uidgen.next_uid() for i in range(5)]
self.assertEqual(expected, actual)
def test_overflow(self):
self.uidgen.counter = 36**6-1
self.assertTrue(self.uidgen.next_uid())
self.assertRaises(RuntimeError, self.uidgen.next_uid)
if __name__ == '__main__':
unittest.main()