3
0
mirror of https://github.com/jlu5/PyLink.git synced 2024-11-27 04:59:24 +01:00

More concise UID generators

This commit is contained in:
James Lu 2021-07-14 21:56:48 -07:00
parent bc3a7abe02
commit 46cc621df1
6 changed files with 187 additions and 47 deletions

View File

@ -1,4 +1,4 @@
[FORMAT] [FORMAT]
max-line-length=120 max-line-length=120
good-names=ip good-names=ip,f,i

View File

@ -9,49 +9,34 @@ from pylinkirc import conf
from pylinkirc.classes import IRCNetwork, ProtocolError from pylinkirc.classes import IRCNetwork, ProtocolError
from pylinkirc.log import log from pylinkirc.log import log
__all__ = ['IncrementalUIDGenerator', 'IRCCommonProtocol', 'IRCS2SProtocol'] __all__ = ['UIDGenerator', 'IRCCommonProtocol', 'IRCS2SProtocol']
class UIDGenerator():
class IncrementalUIDGenerator():
""" """
Incremental UID Generator module, adapted from InspIRCd source: Generate UIDs for IRC S2S.
https://github.com/inspircd/inspircd/blob/f449c6b296ab/src/server.cpp#L85-L156
""" """
def __init__(self, sid): def __init__(self, uidchars, length, sid):
if not (hasattr(self, 'allowedchars') and hasattr(self, 'length')): self.uidchars = uidchars # corpus of characters to choose from
raise RuntimeError("Allowed characters list not defined. Subclass " self.length = length # desired length of uid part, padded with uidchars[0]
"%s by defining self.allowedchars and self.length " self.sid = str(sid) # server id (prefixed to every result)
"and then calling super().__init__()." % self.__class__.__name__) self.counter = 0
self.uidchars = [self.allowedchars[0]]*self.length
self.sid = str(sid)
def increment(self, pos=None):
"""
Increments the UID generator to the next available UID.
"""
# Position starts at 1 less than the UID length.
if pos is None:
pos = self.length - 1
# If we're at the last character in the list of allowed ones, reset
# and increment the next level above.
if self.uidchars[pos] == self.allowedchars[-1]:
self.uidchars[pos] = self.allowedchars[0]
self.increment(pos-1)
else:
# Find what position in the allowed characters list we're currently
# on, and add one.
idx = self.allowedchars.find(self.uidchars[pos])
self.uidchars[pos] = self.allowedchars[idx+1]
def next_uid(self): def next_uid(self):
""" """
Returns the next unused UID for the server. Returns the next unused UID for the server.
""" """
uid = self.sid + ''.join(self.uidchars) uid = ''
self.increment() num = self.counter
return uid if num >= (len(self.uidchars) ** self.length):
raise RuntimeError("UID overflowed")
while num > 0:
num, index = divmod(num, len(self.uidchars))
uid = self.uidchars[index] + uid
self.counter += 1
uid = uid.rjust(self.length, self.uidchars[0])
return self.sid + uid
class IRCCommonProtocol(IRCNetwork): class IRCCommonProtocol(IRCNetwork):

View File

@ -4,6 +4,7 @@ p10.py: P10 protocol module for PyLink, supporting Nefarious IRCu and others.
import base64 import base64
import socket import socket
import string
import struct import struct
import time import time
from ipaddress import ip_address from ipaddress import ip_address
@ -16,13 +17,13 @@ from pylinkirc.protocols.ircs2s_common import *
__all__ = ['P10Protocol'] __all__ = ['P10Protocol']
class P10UIDGenerator(IncrementalUIDGenerator): class P10UIDGenerator(UIDGenerator):
"""Implements an incremental P10 UID Generator.""" """Implements a P10 UID Generator."""
def __init__(self, sid): def __init__(self, sid):
self.allowedchars = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789[]' uidchars = string.ascii_uppercase + string.ascii_lowercase + string.digits + '[]'
self.length = 3 length = 3
super().__init__(sid) super().__init__(uidchars, length, sid)
def p10b64encode(num, length=2): def p10b64encode(num, length=2):
""" """

View File

@ -87,18 +87,16 @@ class TS6SIDGenerator():
sid = ''.join(self.output) sid = ''.join(self.output)
return sid return sid
class TS6UIDGenerator(IncrementalUIDGenerator): class TS6UIDGenerator(UIDGenerator):
"""Implements an incremental TS6 UID Generator.""" """Implements an incremental TS6 UID Generator."""
def __init__(self, sid): def __init__(self, sid):
# Define the options for IncrementalUIDGenerator, and then
# initialize its functions.
# TS6 UIDs are 6 characters in length (9 including the SID). # TS6 UIDs are 6 characters in length (9 including the SID).
# They go from ABCDEFGHIJKLMNOPQRSTUVWXYZ -> 0123456789 -> wrap around: # They go from ABCDEFGHIJKLMNOPQRSTUVWXYZ -> 0123456789 -> wrap around:
# e.g. AAAAAA, AAAAAB ..., AAAAA8, AAAAA9, AAAABA, etc. # e.g. AAAAAA, AAAAAB ..., AAAAA8, AAAAA9, AAAABA, etc.
self.allowedchars = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456879' uidchars = string.ascii_uppercase + string.digits
self.length = 6 length = 6
super().__init__(sid) super().__init__(uidchars, length, sid)
class TS6BaseProtocol(IRCS2SProtocol): class TS6BaseProtocol(IRCS2SProtocol):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):

86
test/test_protocol_p10.py Normal file
View File

@ -0,0 +1,86 @@
"""
Tests for protocols/p10
"""
import unittest
from pylinkirc.protocols import p10
class P10UIDGeneratorTest(unittest.TestCase):
def setUp(self):
self.uidgen = p10.P10UIDGenerator('HI')
def test_initial_UID(self):
expected = [
"HIAAA",
"HIAAB",
"HIAAC",
"HIAAD",
"HIAAE",
"HIAAF"
]
self.uidgen.counter = 0
actual = [self.uidgen.next_uid() for i in range(6)]
self.assertEqual(expected, actual)
def test_rollover_first_lowercase(self):
expected = [
"HIAAY",
"HIAAZ",
"HIAAa",
"HIAAb",
"HIAAc",
"HIAAd",
]
self.uidgen.counter = 24
actual = [self.uidgen.next_uid() for i in range(6)]
self.assertEqual(expected, actual)
def test_rollover_first_num(self):
expected = [
"HIAAz",
"HIAA0",
"HIAA1",
"HIAA2",
"HIAA3",
"HIAA4",
]
self.uidgen.counter = 26*2-1
actual = [self.uidgen.next_uid() for i in range(6)]
self.assertEqual(expected, actual)
def test_rollover_second(self):
expected = [
"HIAA8",
"HIAA9",
"HIAA[",
"HIAA]",
"HIABA",
"HIABB",
"HIABC",
"HIABD",
]
self.uidgen.counter = 26*2+10-2
actual = [self.uidgen.next_uid() for i in range(8)]
self.assertEqual(expected, actual)
def test_rollover_third(self):
expected = [
"HIE]9",
"HIE][",
"HIE]]",
"HIFAA",
"HIFAB",
"HIFAC",
]
self.uidgen.counter = 5*64**2 - 3
actual = [self.uidgen.next_uid() for i in range(6)]
self.assertEqual(expected, actual)
def test_overflow(self):
self.uidgen.counter = 64**3-1
self.assertTrue(self.uidgen.next_uid())
self.assertRaises(RuntimeError, self.uidgen.next_uid)
if __name__ == '__main__':
unittest.main()

View File

@ -0,0 +1,70 @@
"""
Tests for protocols/ts6_common
"""
import unittest
from pylinkirc.protocols import ts6_common
class TS6UIDGeneratorTest(unittest.TestCase):
def setUp(self):
self.uidgen = ts6_common.TS6UIDGenerator('123')
def test_initial_UID(self):
expected = [
"123AAAAAA",
"123AAAAAB",
"123AAAAAC",
"123AAAAAD",
"123AAAAAE",
"123AAAAAF",
]
self.uidgen.counter = 0
actual = [self.uidgen.next_uid() for i in range(6)]
self.assertEqual(expected, actual)
def test_rollover_first_num(self):
expected = [
"123AAAAAY",
"123AAAAAZ",
"123AAAAA0",
"123AAAAA1",
"123AAAAA2",
"123AAAAA3",
]
self.uidgen.counter = 24
actual = [self.uidgen.next_uid() for i in range(6)]
self.assertEqual(expected, actual)
def test_rollover_second(self):
expected = [
"123AAAAA8",
"123AAAAA9",
"123AAAABA",
"123AAAABB",
"123AAAABC",
"123AAAABD",
]
self.uidgen.counter = 36 - 2
actual = [self.uidgen.next_uid() for i in range(6)]
self.assertEqual(expected, actual)
def test_rollover_third(self):
expected = [
"123AAAE98",
"123AAAE99",
"123AAAFAA",
"123AAAFAB",
"123AAAFAC",
]
self.uidgen.counter = 5*36**2 - 2
actual = [self.uidgen.next_uid() for i in range(5)]
self.assertEqual(expected, actual)
def test_overflow(self):
self.uidgen.counter = 36**6-1
self.assertTrue(self.uidgen.next_uid())
self.assertRaises(RuntimeError, self.uidgen.next_uid)
if __name__ == '__main__':
unittest.main()