diff --git a/plugins/Herald.py b/plugins/Herald.py index b93358d2c..d97ce7883 100644 --- a/plugins/Herald.py +++ b/plugins/Herald.py @@ -42,6 +42,7 @@ import time import log import conf import utils +import world import ircdb import ircmsgs import ircutils @@ -49,47 +50,16 @@ import privmsgs import registry import callbacks +filename = os.path.join(conf.supybot.directories.data(), 'Herald.db') -class HeraldDB(object): - def __init__(self): - self.heralds = {} - dataDir = conf.supybot.directories.data() - self.filename = os.path.join(dataDir, 'Herald.db') - self.open() - - def open(self): - dataDir = conf.supybot.directories.data() - if os.path.exists(self.filename): - fd = file(self.filename) - for line in fd: - line = line.rstrip() - try: - (idChannel, msg) = line.split(':', 1) - (id, channel) = idChannel.split(',', 1) - id = int(id) - except ValueError: - log.warning('Invalid line in HeraldDB: %r', line) - continue - self.heralds[(id, channel)] = msg - fd.close() - - def close(self): - fd = file(self.filename, 'w') - L = self.heralds.items() - L.sort() - for ((id, channel), msg) in L: - fd.write('%s,%s:%s%s' % (id, channel, msg, os.linesep)) - fd.close() - - def getHerald(self, id, channel): - return self.heralds[(id, channel)] - - def setHerald(self, id, channel, msg): - self.heralds[(id, channel)] = msg - - def delHerald(self, id, channel): - del self.heralds[(id, channel)] +class HeraldDB(plugins.ChannelUserDatabase): + def serialize(self, v): + return [v] + def deserialize(self, L): + if len(L) != 1: + raise ValueError + return L[0] conf.registerPlugin('Herald') conf.registerChannelValue(conf.supybot.plugins.Herald, 'heralding', @@ -107,11 +77,14 @@ conf.registerChannelValue(conf.supybot.plugins.Herald, 'throttleTimeAfterPart', class Herald(callbacks.Privmsg): def __init__(self): callbacks.Privmsg.__init__(self) - self.db = HeraldDB() - self.lastParts = {} - self.lastHerald = {} + self.db = HeraldDB(filename) + world.flushers.append(self.db.flush) + self.lastParts = plugins.ChannelUserDictionary() + self.lastHerald = plugins.ChannelUserDictionary() def die(self): + if self.db.flush in world.flushers: + world.flushers.remove(self.db.flush) self.db.close() callbacks.Privmsg.die(self) @@ -120,17 +93,17 @@ class Herald(callbacks.Privmsg): if self.registryValue('heralding', channel): try: id = ircdb.users.getUserId(msg.prefix) - herald = self.db.getHerald(id, channel) + herald = self.db[channel, id] except KeyError: return now = time.time() throttle = self.registryValue('throttleTime', channel) - if now - self.lastHerald.get((id, channel), 0) > throttle: - if (id, channel) in self.lastParts: + if now - self.lastHerald.get((channel, id), 0) > throttle: + if (channel, id) in self.lastParts: i = self.registryValue('throttleTimeAfterPart', channel) - if now - self.lastParts[(id, channel)] < i: + if now - self.lastParts[channel, id] < i: return - self.lastHerald[(id, channel)] = now + self.lastHerald[channel, id] = now irc.queueMsg(ircmsgs.privmsg(channel, herald)) def doPart(self, irc, msg): @@ -165,7 +138,7 @@ class Herald(callbacks.Privmsg): except KeyError: irc.errorNoUser() return - self.db.setHerald(id, channel, herald) + self.db[channel, id] = herald irc.replySuccess() def remove(self, irc, msg, args): @@ -183,7 +156,7 @@ class Herald(callbacks.Privmsg): except KeyError: irc.errorNoUser() return - self.db.delHerald(id, channel) + del self.db[channel, id] irc.replySuccess() diff --git a/src/plugins.py b/src/plugins.py index 8f463c75c..c03925dde 100644 --- a/src/plugins.py +++ b/src/plugins.py @@ -36,15 +36,18 @@ import fix import gc import os import re +import csv import sys import sets import time import types import random import urllib2 +import UserDict import threading import cdb +import log import conf import utils import world @@ -157,6 +160,86 @@ class ChannelDBHandler(object): gc.collect() +class ChannelUserDictionary(UserDict.DictMixin): + def __init__(self): + self.channels = ircutils.IrcDict() + + def __getitem__(self, (channel, id)): + return self.channels[channel][id] + + def __setitem__(self, (channel, id), v): + if channel not in self.channels: + self.channels[channel] = {} + self.channels[channel][id] = v + + def __delitem__(self, (channel, id)): + del self.channels[channel][id] + + def iteritems(self): + for (channel, ids) in self.channels.iteritems(): + for (id, v) in ids.iteritems(): + yield ((channel, id), v) + + def keys(self): + L = [] + for (k, _) in self.iteritems(): + L.append(k) + return L + + +class ChannelUserDatabase(ChannelUserDictionary): + def __init__(self, filename): + ChannelUserDictionary.__init__(self) + self.filename = filename + try: + fd = file(self.filename) + except EnvironmentError, e: + log.warning('Couldn\'t open %s: %s.', self.filename, e) + return + reader = csv.reader(fd) + try: + lineno = 0 + for t in reader: + lineno += 1 + try: + channel = t.pop(0) + id = t.pop(0) + id = int(id) + v = self.deserialize(t) + self[channel, id] = v + except Exception, e: + log.warning('Invalid line #%s in %s.', + lineno, self.__class__.__name__) + except Exception, e: # This catches exceptions from csv.reader. + log.warning('Invalid line #%s in %s.', + lineno, self.__class__.__name__) + + def flush(self): + fd = file(self.filename, 'w') + writer = csv.writer(fd) + items = self.items() + items.sort() + for ((channel, id), v) in items: + L = self.serialize(v) + L.insert(0, id) + L.insert(0, channel) + writer.writerow(L) + fd.close() + + def close(self): + self.flush() + self.clear() + + def deserialize(self, L): + """Should take a list of strings and return an object to be accessed + via self.get(channel, id).""" + raise NotImplementedError + + def serialize(self, x): + """Should take an object (as returned by self.get(channel, id)) and + return a list (of any type serializable to csv).""" + raise NotImplementedError + class PeriodicFileDownloader(object): """A class to periodically download a file/files.