Merge branch 'rewrite-rss' into testing

This commit is contained in:
Valentin Lorentz 2014-07-31 22:53:44 +02:00
commit 7ac053d8f1
4 changed files with 71 additions and 10 deletions

View File

@ -29,13 +29,15 @@
# POSSIBILITY OF SUCH DAMAGE. # POSSIBILITY OF SUCH DAMAGE.
### ###
import re
import os
import sys
import json
import time import time
import types import types
import string import string
import socket import socket
import threading import threading
import re
import sys
import feedparser import feedparser
import supybot.conf as conf import supybot.conf as conf
@ -56,10 +58,13 @@ def get_feedName(irc, msg, args, state):
state.args.append(callbacks.canonicalName(args.pop(0))) state.args.append(callbacks.canonicalName(args.pop(0)))
addConverter('feedName', get_feedName) addConverter('feedName', get_feedName)
announced_headlines_filename = \
conf.supybot.directories.data.dirize('RSS_announced.flat')
class Feed: class Feed:
__slots__ = ('url', 'name', 'data', 'last_update', 'entries', __slots__ = ('url', 'name', 'data', 'last_update', 'entries',
'lock', 'announced_entries') 'lock', 'announced_entries')
def __init__(self, name, url, plugin_is_loading=False): def __init__(self, name, url, plugin_is_loading=False, announced=None):
assert name, name assert name, name
if not url: if not url:
assert utils.web.httpUrlRe.match(name), name assert utils.web.httpUrlRe.match(name), name
@ -72,7 +77,12 @@ class Feed:
self.last_update = time.time() if plugin_is_loading else 0 self.last_update = time.time() if plugin_is_loading else 0
self.entries = [] self.entries = []
self.lock = threading.Lock() self.lock = threading.Lock()
self.announced_entries = utils.structures.TruncatableSet() self.announced_entries = announced or \
utils.structures.TruncatableSet()
def __repr__(self):
return 'Feed(%r, %r, <bool>, %r)' % \
(self.name, self.url, self.announced_entries)
def get_command(self, plugin): def get_command(self, plugin):
docstring = format(_("""[<number of headlines>] docstring = format(_("""[<number of headlines>]
@ -105,6 +115,14 @@ def sort_feed_items(items, order):
return items return items
return sitems return sitems
def load_announces_db(fd):
return dict((name, utils.structures.TruncatableSet(entries))
for (name, entries) in json.load(fd).items())
def save_announces_db(db, fd):
json.dump(dict((name, list(entries)) for (name, entries) in db), fd)
class RSS(callbacks.Plugin): class RSS(callbacks.Plugin):
"""This plugin is useful both for announcing updates to RSS feeds in a """This plugin is useful both for announcing updates to RSS feeds in a
channel, and for retrieving the headlines of RSS feeds via command. Use channel, and for retrieving the headlines of RSS feeds via command. Use
@ -118,6 +136,11 @@ class RSS(callbacks.Plugin):
self.feed_names = callbacks.CanonicalNameDict() self.feed_names = callbacks.CanonicalNameDict()
# Scheme: {url: feed} # Scheme: {url: feed}
self.feeds = {} self.feeds = {}
if os.path.isfile(announced_headlines_filename):
with open(announced_headlines_filename) as fd:
announced = load_announces_db(fd)
else:
announced = {}
for name in self.registryValue('feeds'): for name in self.registryValue('feeds'):
self.assert_feed_does_not_exist(name) self.assert_feed_does_not_exist(name)
self.register_feed_config(name) self.register_feed_config(name)
@ -126,7 +149,20 @@ class RSS(callbacks.Plugin):
except registry.NonExistentRegistryEntry: except registry.NonExistentRegistryEntry:
self.log.warning('%s is not a registered feed, removing.',name) self.log.warning('%s is not a registered feed, removing.',name)
continue continue
self.register_feed(name, url, True) self.register_feed(name, url, True, announced.get(name, []))
world.flushers.append(self._flush)
def die(self):
self._flush()
world.flushers.remove(self._flush)
self.__parent.die()
def _flush(self):
l = [(f.name, f.announced_entries) for f in self.feeds.values()]
with utils.file.AtomicFile(announced_headlines_filename, 'wb',
backupDir='/dev/null') as fd:
save_announces_db(l, fd)
################## ##################
# Feed registering # Feed registering
@ -141,9 +177,9 @@ class RSS(callbacks.Plugin):
group = self.registryValue('feeds', value=False) group = self.registryValue('feeds', value=False)
conf.registerGlobalValue(group, name, registry.String(url, '')) conf.registerGlobalValue(group, name, registry.String(url, ''))
def register_feed(self, name, url, plugin_is_loading): def register_feed(self, name, url, plugin_is_loading, announced=[]):
self.feed_names[name] = url self.feed_names[name] = url
self.feeds[url] = Feed(name, url, plugin_is_loading) self.feeds[url] = Feed(name, url, plugin_is_loading, announced)
def remove_feed(self, feed): def remove_feed(self, feed):
del self.feed_names[feed.name] del self.feed_names[feed.name]

View File

@ -81,6 +81,23 @@ class RSSTestCase(ChannelPluginTestCase):
self._feedMsg('rss remove xkcd') self._feedMsg('rss remove xkcd')
feedparser._open_resource = old_open feedparser._open_resource = old_open
def testAnnounceReload(self):
old_open = feedparser._open_resource
feedparser._open_resource = constant(xkcd_old)
try:
with conf.supybot.plugins.RSS.waitPeriod.context(1):
self.assertNotError('rss add xkcd http://xkcd.com/rss.xml')
self.assertNotError('rss announce add xkcd')
self.assertNotError(' ')
self.assertNotError('reload RSS')
self.assertNoResponse(' ')
time.sleep(1.1)
self.assertNoResponse(' ')
finally:
self._feedMsg('rss announce remove xkcd')
self._feedMsg('rss remove xkcd')
feedparser._open_resource = old_open
if network: if network:
def testRssinfo(self): def testRssinfo(self):
self.assertNotError('rss info %s' % url) self.assertNotError('rss info %s' % url)

View File

@ -156,13 +156,19 @@ class AtomicFile(object):
# self.__parent = super(AtomicFile, self) # self.__parent = super(AtomicFile, self)
self._fd = open(self.tempFilename, mode) self._fd = open(self.tempFilename, mode)
def __enter__(self):
return self
def __exit__(self, exc_type, exc_value, traceback):
if exc_type:
self.rollback()
else:
self.close()
@property @property
def closed(self): def closed(self):
return self._fd.closed return self._fd.closed
def close(self):
return self._fd.close()
def write(self, data): def write(self, data):
return self._fd.write(data) return self._fd.write(data)

View File

@ -460,6 +460,8 @@ class TruncatableSet(collections.MutableSet):
def __init__(self, iterable=[]): def __init__(self, iterable=[]):
self._ordered_items = list(iterable) self._ordered_items = list(iterable)
self._items = set(self._ordered_items) self._items = set(self._ordered_items)
def __repr__(self):
return 'TruncatableSet({%r})' % self._items
def __contains__(self, item): def __contains__(self, item):
return item in self._items return item in self._items
def __iter__(self): def __iter__(self):