diff --git a/test/test_utils.py b/test/test_utils.py index dc5ee8e..1771123 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -130,5 +130,25 @@ class UtilsTestCase(unittest.TestCase): utils.remove_range( "-3--5", ["we", "love", "emotes"]) + def test_get_hostname_type(self): + self.assertEqual(utils.get_hostname_type("1.2.3.4"), 1) + self.assertEqual(utils.get_hostname_type("192.168.0.1"), 1) + self.assertEqual(utils.get_hostname_type("127.0.0.5"), 1) + + self.assertEqual(utils.get_hostname_type("0::1"), 2) + self.assertEqual(utils.get_hostname_type("::1"), 2) + self.assertEqual(utils.get_hostname_type("fc00::1234"), 2) + self.assertEqual(utils.get_hostname_type("1111:2222:3333:4444:5555:6666:7777:8888"), 2) + + self.assertEqual(utils.get_hostname_type("example.com"), False) + self.assertEqual(utils.get_hostname_type("abc.mynet.local"), False) + self.assertEqual(utils.get_hostname_type("123.example"), False) + + self.assertEqual(utils.get_hostname_type("123.456.789.000"), False) + self.assertEqual(utils.get_hostname_type("1::2::3"), False) + self.assertEqual(utils.get_hostname_type("1:"), False) + self.assertEqual(utils.get_hostname_type(":5"), False) + + if __name__ == '__main__': unittest.main() diff --git a/utils.py b/utils.py index 4101385..026a011 100644 --- a/utils.py +++ b/utils.py @@ -11,6 +11,7 @@ import importlib import os import collections import argparse +import ipaddress from .log import log from . import world, conf, structures @@ -750,3 +751,20 @@ def remove_range(rangestr, mylist): (subrange, rangestr)) return list(filter(lambda x: x is not None, mylist)) + +def get_hostname_type(address): + """ + Returns whether the given address is an IPv4 address (1), IPv6 address (2), or neither + (False; assumed to be a hostname instead). + """ + try: + ip = ipaddress.ip_address(address) + except ValueError: + return False + else: + if isinstance(ip, ipaddress.IPv4Address): + return 1 + elif isinstance(ip, ipaddress.IPv6Address): + return 2 + else: + raise ValueError("Got unknown value %r from ipaddress.ip_address()" % address)