diff --git a/utils.py b/utils.py index 09366bc..2cb43ff 100644 --- a/utils.py +++ b/utils.py @@ -10,11 +10,25 @@ import re import inspect import importlib import os +import collections from log import log import world import conf +class KeyedDefaultdict(collections.defaultdict): + """ + Subclass of defaultdict allowing the key to be passed to the default factory. + """ + def __missing__(self, key): + if self.default_factory is None: + # If there is no default factory, just let defaultdict handle it + super().__missing__(self, key) + else: + value = self[key] = self.default_factory(key) + return value + + class NotAuthenticatedError(Exception): """ Exception raised by checkAuthenticated() when a user fails authentication