diff --git a/src/ircdb.py b/src/ircdb.py index a504bbe7f..03711051d 100644 --- a/src/ircdb.py +++ b/src/ircdb.py @@ -44,6 +44,7 @@ import conf import utils import world import ircutils +import unpreserve from structures import PersistentDictionary def fromChannelCapability(capability): @@ -280,6 +281,22 @@ class IrcUser(object): """Unsets a use's authenticated hostmask.""" self.auth = None + def preserve(self, fd, indent=''): + def write(s): + fd.write(indent) + fd.write(s) + fd.write(os.linesep) + write('name %s' % self.name) + write('ignore %s' % self.ignore) + write('secure %s' % self.secure) + write('hashed %s' % self.hashed) + write('password %s' % self.password) + for capability in self.capabilities: + write('capability %s' % capability) + for hostmask in self.hostmasks: + write('hostmask %s' % hostmask) + fd.write(os.linesep) + class IrcChannel(object): """This class holds the capabilities, bans, and ignores of a channel. @@ -369,31 +386,157 @@ class IrcChannel(object): return True return False -class UsersDB(object): + def preserve(self, fd, indent=''): + def write(s): + fd.write(indent) + fd.write(s) + fd.write(os.linesep) + write('lobotomized %s' % self.lobotomized) + write('defaultAllow %s' % self.defaultAllow) + for ban in self.bans: + write('ban %s' % ban) + for ignore in self.ignores: + write('ignore %s' % ignore) + for capability in self.capabilities: + write('capability %s' % capability) + fd.write(os.linesep) + + +class Creator(object): + def command(self, command, rest, lineno): + raise ValueError, 'Invalid command on line %s: %s' % (lineno, command) + +class IrcUserCreator(Creator): + id = None + def __init__(self): + self.u = IrcUser() + + def user(self, rest, lineno): + if self.id is not None: + raise ValueError, 'Unexpected user command on line %s.' % lineno + IrcUserCreator.id = int(rest) + + def name(self, rest, lineno): + if self.id is None: + raise ValueError, 'Unexpected user description without id.' + self.u.name = rest + + def ignore(self, rest, lineno): + if self.id is None: + raise ValueError, 'Unexpected user description without id.' + self.u.ignore = bool(eval(rest)) + + def secure(self, rest, lineno): + if self.id is None: + raise ValueError, 'Unexpected user description without id.' + self.u.secure = bool(eval(rest)) + + def hashed(self, rest, lineno): + if self.id is None: + raise ValueError, 'Unexpected user description without id.' + self.u.hashed = bool(eval(rest)) + + def password(self, rest, lineno): + if self.id is None: + raise ValueError, 'Unexpected user description without id.' + self.u.password = rest + + def hostmask(self, rest, lineno): + if self.id is None: + raise ValueError, 'Unexpected user description without id.' + self.u.hostmasks.append(rest) + + def capability(self, rest, lineno): + if self.id is None: + raise ValueError, 'Unexpected user description without id.' + self.u.capabilities.add(rest) + + def finish(self): + if self.u.name: + users.setUser(self.id, self.u) + IrcUserCreator.id = None + +class IrcChannelCreator(Creator): + name = None + def __init__(self): + self.c = IrcChannel() + self.hadChannel = bool(self.name) + + def channel(self, rest, lineno): + if self.name is not None: + raise ValueError, 'Unexpected channel command on line %s' % lineno + IrcChannelCreator.name = rest + + def lobotomized(self, rest, lineno): + if self.name is None: + raise ValueError, 'Unexpected channel description without channel.' + self.c.lobotomized = bool(eval(rest)) + + def defaultallow(self, rest, lineno): + if self.name is None: + raise ValueError, 'Unexpected channel description without channel.' + self.c.defaultAllow = bool(eval(rest)) + + def capability(self, rest, lineno): + if self.name is None: + raise ValueError, 'Unexpected channel description without channel.' + self.c.capabilities.add(rest) + + def ban(self, rest, lineno): + if self.name is None: + raise ValueError, 'Unexpected channel description without channel.' + self.c.bans.append(rest) + + def ignore(self, rest, lineno): + if self.name is None: + raise ValueError, 'Unexpected channel description without channel.' + self.c.ignores.append(rest) + + def finish(self): + if self.hadChannel: + channels.setChannel(self.name, self.c) + IrcChannelCreator.name = None + + +class UsersDictionary(utils.IterableMap): """A simple serialized-to-file User Database.""" - def __init__(self, filename): - self.filename = filename - if os.path.exists(filename): - fd = file(filename, 'r') - s = fd.read() - fd.close() - IrcSet = ircutils.IrcSet - (self.nextId, self.users) = eval(_normalize(s)) - else: - self.nextId = 1 - self.users = [None] + def __init__(self): + self.filename = None + self.users = {} + self.nextId = 1 self._nameCache = {} self._hostmaskCache = {} + def open(self, filename): + self.filename = filename + reader = unpreserve.Reader(IrcUserCreator) + reader.readFile(filename) + def reload(self): """Reloads the database from its file.""" - self.__init__(self.filename) + if self.filename is not None: + self.nextId = 0 + self.users.clear() + self.open(self.filename) + else: + log.warning('UsersDictionary.reload called without self.filename.') def flush(self): """Flushes the database to its file.""" - fd = file(self.filename, 'w') - fd.write(repr((self.nextId, self.users))) - fd.close() + if self.filename is not None: + L = self.users.items() + L.sort() + fd = file(self.filename, 'w') + for (id, u) in L: + fd.write('user %s' % id) + fd.write(os.linesep) + u.preserve(fd, indent=' ') + fd.close() + else: + log.warning('UsersDictionary.flush called without self.filename.') + + def iteritems(self): + return self.users.iteritems() def getUserId(self, s): """Returns the user ID of a given name or hostmask.""" @@ -402,9 +545,7 @@ class UsersDB(object): return self._hostmaskCache[s] except KeyError: ids = [] - for (id, user) in enumerate(self.users): - if user is None: - continue + for (id, user) in self.users.iteritems(): if user.checkHostmask(s): ids.append(id) if len(ids) == 1: @@ -424,9 +565,7 @@ class UsersDB(object): try: return self._nameCache[s] except KeyError: - for (id, user) in enumerate(self.users): - if user is None: - continue + for (id, user) in self.users.items(): if s == user.name.lower(): self._nameCache[s] = id self._nameCache[id] = s @@ -439,13 +578,7 @@ class UsersDB(object): if not isinstance(id, int): # Must be a string. Get the UserId first. id = self.getUserId(id) - try: - ret = self.users[id] - if ret is None: - raise KeyError, id - return ret - except IndexError: - raise KeyError, id + return self.users[id] def hasUser(self, id): """Returns the database has a user given its id, name, or hostmask.""" @@ -455,19 +588,13 @@ class UsersDB(object): except KeyError: return False - def __iter__(self): - x = ifilter(None, self.users) - x.next() # Skip the bot user. - return x - def numUsers(self): - return ilen(self) + return len(self.users) def setUser(self, id, user): """Sets a user (given its id) to the IrcUser given it.""" assert isinstance(id, int), 'setUser takes an integer userId.' - if (not 0 <= id < len(self.users)) or self.users[id] is None: - raise KeyError, id + self.nextId = max(self.nextId, id) try: if self.getUserId(user.name) != id: s = '%s is already registered to someone else.' % user.name @@ -493,9 +620,7 @@ class UsersDB(object): def delUser(self, id): """Removes a user from the database.""" - if not 0 <= id < len(self.users) or self.users[id] is None: - raise KeyError, id - self.users[id] = None + del self.users[id] if id in self._nameCache: del self._nameCache[self._nameCache[id]] del self._nameCache[id] @@ -510,54 +635,79 @@ class UsersDB(object): user = IrcUser() id = self.nextId self.nextId += 1 - self.users.append(user) + self.users[id] = user self.flush() return (id, user) class ChannelsDictionary(utils.IterableMap): - def __init__(self, filename): + def __init__(self): + self.filename = None + self.channels = ircutils.IrcDict() + + def open(self, filename): self.filename = filename - Set = sets.Set - self.dict = PersistentDictionary(filename, globals(), locals()) + reader = unpreserve.Reader(IrcChannelCreator) + reader.readFile(filename) + + def flush(self): + """Flushes the channel database to its file.""" + if self.filename: + fd = file(self.filename, 'w') + for (channel, c) in self.channels.iteritems(): + fd.write('channel %s' % channel) + fd.write(os.linesep) + c.preserve(fd, indent=' ') + fd.close() + else: + log.warning('ChannelsDictionary.flush called with self.filename.') + + def reload(self): + """Reloads the channel database from its file.""" + if self.filename: + self.channels.clear() + self.open(self.filename) + else: + log.warning('ChannelsDictionary.reload called with self.filename.') def getChannel(self, channel): """Returns an IrcChannel object for the given channel.""" channel = channel.lower() - if channel in self.dict: - return self.dict[channel] + if channel in self.channels: + return self.channels[channel] else: c = IrcChannel() - self.dict[channel] = c + self.channels[channel] = c return c def setChannel(self, channel, ircChannel): """Sets a given channel to the IrcChannel object given.""" channel = channel.lower() - self.dict[channel] = ircChannel + self.channels[channel] = ircChannel self.flush() - def flush(self): - """Flushes the channel database to its file.""" - self.dict.flush() - - def reload(self): - """Reloads the channel database from its file.""" - self.__init__(self.filename) - def iteritems(self): - return self.dict.iteritems() + return self.channels.iteritems() ### # Later, I might add some special handling for botnet. ### confDir = conf.supybot.directories.conf() -users = UsersDB(os.path.join(confDir, - conf.supybot.databases.users.filename())) -channels = ChannelsDictionary(os.path.join(confDir, - conf.supybot.databases.channels.filename())) +users = UsersDictionary() +try: + users.open(os.path.join(confDir, conf.supybot.databases.users.filename())) +except EnvironmentError, e: + log.warning('Couldn\'t open user database: %s', e) + +channels = ChannelsDictionary() +try: + channelFile = conf.supybot.databases.channels.filename() + channels.open(os.path.join(confDir,channelFile)) +except EnvironmentError, e: + log.warning('Couldn\'t open channel database: %s', e) + ### # Useful functions for checking credentials. diff --git a/src/unpreserve.py b/src/unpreserve.py new file mode 100644 index 000000000..fcf347ceb --- /dev/null +++ b/src/unpreserve.py @@ -0,0 +1,77 @@ +#!/usr/bin/env python + +### +# Copyright (c) 2004, Jeremiah Fincher +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions, and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions, and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of the author of this software nor the name of +# contributors to this software may be used to endorse or promote products +# derived from this software without specific prior written consent. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +# POSSIBILITY OF SUCH DAMAGE. +### + +__revision__ = "$Id$" + +class Reader(object): + def __init__(self, Creator): + self.Creator = Creator + self.creator = None + self.modifiedCreator = False + self.indent = None + + def normalizeCommand(self, s): + return s.lower() + + def readFile(self, filename): + self.read(file(filename)) + + def read(self, fd): + lineno = 0 + for line in fd: + lineno += 1 + if not line.strip(): + continue + line = line.rstrip('\r\n') + line = line.replace('\t', ' '*8) + s = line.lstrip(' ') + indent = len(line) - len(s) + if indent != self.indent: + # New indentation level. + if self.creator is not None: + self.creator.finish() + self.creator = self.Creator() + self.modifiedCreator = False + self.indent = indent + (command, rest) = s.split(None, 1) + command = self.normalizeCommand(command) + self.modifiedCreator = True + if hasattr(self.creator, command): + command = getattr(self.creator, command) + command(rest, lineno) + else: + self.creator.command(command, rest, lineno) + if self.modifiedCreator: + self.creator.finish() + + +# vim:set shiftwidth=4 tabstop=8 expandtab textwidth=78: +