mirror of
https://github.com/jlu5/PyLink.git
synced 2024-11-23 11:09:22 +01:00
More concise UID generators
This commit is contained in:
parent
bc3a7abe02
commit
46cc621df1
@ -9,49 +9,34 @@ from pylinkirc import conf
|
||||
from pylinkirc.classes import IRCNetwork, ProtocolError
|
||||
from pylinkirc.log import log
|
||||
|
||||
__all__ = ['IncrementalUIDGenerator', 'IRCCommonProtocol', 'IRCS2SProtocol']
|
||||
__all__ = ['UIDGenerator', 'IRCCommonProtocol', 'IRCS2SProtocol']
|
||||
|
||||
|
||||
class IncrementalUIDGenerator():
|
||||
class UIDGenerator():
|
||||
"""
|
||||
Incremental UID Generator module, adapted from InspIRCd source:
|
||||
https://github.com/inspircd/inspircd/blob/f449c6b296ab/src/server.cpp#L85-L156
|
||||
Generate UIDs for IRC S2S.
|
||||
"""
|
||||
|
||||
def __init__(self, sid):
|
||||
if not (hasattr(self, 'allowedchars') and hasattr(self, 'length')):
|
||||
raise RuntimeError("Allowed characters list not defined. Subclass "
|
||||
"%s by defining self.allowedchars and self.length "
|
||||
"and then calling super().__init__()." % self.__class__.__name__)
|
||||
self.uidchars = [self.allowedchars[0]]*self.length
|
||||
self.sid = str(sid)
|
||||
|
||||
def increment(self, pos=None):
|
||||
"""
|
||||
Increments the UID generator to the next available UID.
|
||||
"""
|
||||
# Position starts at 1 less than the UID length.
|
||||
if pos is None:
|
||||
pos = self.length - 1
|
||||
|
||||
# If we're at the last character in the list of allowed ones, reset
|
||||
# and increment the next level above.
|
||||
if self.uidchars[pos] == self.allowedchars[-1]:
|
||||
self.uidchars[pos] = self.allowedchars[0]
|
||||
self.increment(pos-1)
|
||||
else:
|
||||
# Find what position in the allowed characters list we're currently
|
||||
# on, and add one.
|
||||
idx = self.allowedchars.find(self.uidchars[pos])
|
||||
self.uidchars[pos] = self.allowedchars[idx+1]
|
||||
def __init__(self, uidchars, length, sid):
|
||||
self.uidchars = uidchars # corpus of characters to choose from
|
||||
self.length = length # desired length of uid part, padded with uidchars[0]
|
||||
self.sid = str(sid) # server id (prefixed to every result)
|
||||
self.counter = 0
|
||||
|
||||
def next_uid(self):
|
||||
"""
|
||||
Returns the next unused UID for the server.
|
||||
"""
|
||||
uid = self.sid + ''.join(self.uidchars)
|
||||
self.increment()
|
||||
return uid
|
||||
uid = ''
|
||||
num = self.counter
|
||||
if num >= (len(self.uidchars) ** self.length):
|
||||
raise RuntimeError("UID overflowed")
|
||||
while num > 0:
|
||||
num, index = divmod(num, len(self.uidchars))
|
||||
uid = self.uidchars[index] + uid
|
||||
|
||||
self.counter += 1
|
||||
uid = uid.rjust(self.length, self.uidchars[0])
|
||||
return self.sid + uid
|
||||
|
||||
class IRCCommonProtocol(IRCNetwork):
|
||||
|
||||
|
@ -4,6 +4,7 @@ p10.py: P10 protocol module for PyLink, supporting Nefarious IRCu and others.
|
||||
|
||||
import base64
|
||||
import socket
|
||||
import string
|
||||
import struct
|
||||
import time
|
||||
from ipaddress import ip_address
|
||||
@ -16,13 +17,13 @@ from pylinkirc.protocols.ircs2s_common import *
|
||||
__all__ = ['P10Protocol']
|
||||
|
||||
|
||||
class P10UIDGenerator(IncrementalUIDGenerator):
|
||||
"""Implements an incremental P10 UID Generator."""
|
||||
class P10UIDGenerator(UIDGenerator):
|
||||
"""Implements a P10 UID Generator."""
|
||||
|
||||
def __init__(self, sid):
|
||||
self.allowedchars = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789[]'
|
||||
self.length = 3
|
||||
super().__init__(sid)
|
||||
def __init__(self, sid):
|
||||
uidchars = string.ascii_uppercase + string.ascii_lowercase + string.digits + '[]'
|
||||
length = 3
|
||||
super().__init__(uidchars, length, sid)
|
||||
|
||||
def p10b64encode(num, length=2):
|
||||
"""
|
||||
|
@ -87,18 +87,16 @@ class TS6SIDGenerator():
|
||||
sid = ''.join(self.output)
|
||||
return sid
|
||||
|
||||
class TS6UIDGenerator(IncrementalUIDGenerator):
|
||||
class TS6UIDGenerator(UIDGenerator):
|
||||
"""Implements an incremental TS6 UID Generator."""
|
||||
|
||||
def __init__(self, sid):
|
||||
# Define the options for IncrementalUIDGenerator, and then
|
||||
# initialize its functions.
|
||||
# TS6 UIDs are 6 characters in length (9 including the SID).
|
||||
# They go from ABCDEFGHIJKLMNOPQRSTUVWXYZ -> 0123456789 -> wrap around:
|
||||
# e.g. AAAAAA, AAAAAB ..., AAAAA8, AAAAA9, AAAABA, etc.
|
||||
self.allowedchars = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456879'
|
||||
self.length = 6
|
||||
super().__init__(sid)
|
||||
uidchars = string.ascii_uppercase + string.digits
|
||||
length = 6
|
||||
super().__init__(uidchars, length, sid)
|
||||
|
||||
class TS6BaseProtocol(IRCS2SProtocol):
|
||||
def __init__(self, *args, **kwargs):
|
||||
|
86
test/test_protocol_p10.py
Normal file
86
test/test_protocol_p10.py
Normal file
@ -0,0 +1,86 @@
|
||||
"""
|
||||
Tests for protocols/p10
|
||||
"""
|
||||
|
||||
import unittest
|
||||
|
||||
from pylinkirc.protocols import p10
|
||||
|
||||
class P10UIDGeneratorTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.uidgen = p10.P10UIDGenerator('HI')
|
||||
|
||||
def test_initial_UID(self):
|
||||
expected = [
|
||||
"HIAAA",
|
||||
"HIAAB",
|
||||
"HIAAC",
|
||||
"HIAAD",
|
||||
"HIAAE",
|
||||
"HIAAF"
|
||||
]
|
||||
self.uidgen.counter = 0
|
||||
actual = [self.uidgen.next_uid() for i in range(6)]
|
||||
self.assertEqual(expected, actual)
|
||||
|
||||
def test_rollover_first_lowercase(self):
|
||||
expected = [
|
||||
"HIAAY",
|
||||
"HIAAZ",
|
||||
"HIAAa",
|
||||
"HIAAb",
|
||||
"HIAAc",
|
||||
"HIAAd",
|
||||
]
|
||||
self.uidgen.counter = 24
|
||||
actual = [self.uidgen.next_uid() for i in range(6)]
|
||||
self.assertEqual(expected, actual)
|
||||
|
||||
def test_rollover_first_num(self):
|
||||
expected = [
|
||||
"HIAAz",
|
||||
"HIAA0",
|
||||
"HIAA1",
|
||||
"HIAA2",
|
||||
"HIAA3",
|
||||
"HIAA4",
|
||||
]
|
||||
self.uidgen.counter = 26*2-1
|
||||
actual = [self.uidgen.next_uid() for i in range(6)]
|
||||
self.assertEqual(expected, actual)
|
||||
|
||||
def test_rollover_second(self):
|
||||
expected = [
|
||||
"HIAA8",
|
||||
"HIAA9",
|
||||
"HIAA[",
|
||||
"HIAA]",
|
||||
"HIABA",
|
||||
"HIABB",
|
||||
"HIABC",
|
||||
"HIABD",
|
||||
]
|
||||
self.uidgen.counter = 26*2+10-2
|
||||
actual = [self.uidgen.next_uid() for i in range(8)]
|
||||
self.assertEqual(expected, actual)
|
||||
|
||||
def test_rollover_third(self):
|
||||
expected = [
|
||||
"HIE]9",
|
||||
"HIE][",
|
||||
"HIE]]",
|
||||
"HIFAA",
|
||||
"HIFAB",
|
||||
"HIFAC",
|
||||
]
|
||||
self.uidgen.counter = 5*64**2 - 3
|
||||
actual = [self.uidgen.next_uid() for i in range(6)]
|
||||
self.assertEqual(expected, actual)
|
||||
|
||||
def test_overflow(self):
|
||||
self.uidgen.counter = 64**3-1
|
||||
self.assertTrue(self.uidgen.next_uid())
|
||||
self.assertRaises(RuntimeError, self.uidgen.next_uid)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
70
test/test_protocol_ts6_common.py
Normal file
70
test/test_protocol_ts6_common.py
Normal file
@ -0,0 +1,70 @@
|
||||
"""
|
||||
Tests for protocols/ts6_common
|
||||
"""
|
||||
|
||||
import unittest
|
||||
|
||||
from pylinkirc.protocols import ts6_common
|
||||
|
||||
class TS6UIDGeneratorTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.uidgen = ts6_common.TS6UIDGenerator('123')
|
||||
|
||||
def test_initial_UID(self):
|
||||
expected = [
|
||||
"123AAAAAA",
|
||||
"123AAAAAB",
|
||||
"123AAAAAC",
|
||||
"123AAAAAD",
|
||||
"123AAAAAE",
|
||||
"123AAAAAF",
|
||||
]
|
||||
self.uidgen.counter = 0
|
||||
actual = [self.uidgen.next_uid() for i in range(6)]
|
||||
self.assertEqual(expected, actual)
|
||||
|
||||
def test_rollover_first_num(self):
|
||||
expected = [
|
||||
"123AAAAAY",
|
||||
"123AAAAAZ",
|
||||
"123AAAAA0",
|
||||
"123AAAAA1",
|
||||
"123AAAAA2",
|
||||
"123AAAAA3",
|
||||
]
|
||||
self.uidgen.counter = 24
|
||||
actual = [self.uidgen.next_uid() for i in range(6)]
|
||||
self.assertEqual(expected, actual)
|
||||
|
||||
def test_rollover_second(self):
|
||||
expected = [
|
||||
"123AAAAA8",
|
||||
"123AAAAA9",
|
||||
"123AAAABA",
|
||||
"123AAAABB",
|
||||
"123AAAABC",
|
||||
"123AAAABD",
|
||||
]
|
||||
self.uidgen.counter = 36 - 2
|
||||
actual = [self.uidgen.next_uid() for i in range(6)]
|
||||
self.assertEqual(expected, actual)
|
||||
|
||||
def test_rollover_third(self):
|
||||
expected = [
|
||||
"123AAAE98",
|
||||
"123AAAE99",
|
||||
"123AAAFAA",
|
||||
"123AAAFAB",
|
||||
"123AAAFAC",
|
||||
]
|
||||
self.uidgen.counter = 5*36**2 - 2
|
||||
actual = [self.uidgen.next_uid() for i in range(5)]
|
||||
self.assertEqual(expected, actual)
|
||||
|
||||
def test_overflow(self):
|
||||
self.uidgen.counter = 36**6-1
|
||||
self.assertTrue(self.uidgen.next_uid())
|
||||
self.assertRaises(RuntimeError, self.uidgen.next_uid)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
Loading…
Reference in New Issue
Block a user