From f77f48d0f3bd45c8bca1e22d94e9166e6db2646d Mon Sep 17 00:00:00 2001 From: Jeremy Fincher Date: Fri, 9 Apr 2004 05:22:56 +0000 Subject: [PATCH] Added some lockingEXCLAIM w00rEXCLAIM LocksEXCLAIM --- plugins/RSS.py | 101 +++++++++++++++++++++++++++++-------------------- 1 file changed, 60 insertions(+), 41 deletions(-) diff --git a/plugins/RSS.py b/plugins/RSS.py index 5bff9c8fb..2af4535cd 100644 --- a/plugins/RSS.py +++ b/plugins/RSS.py @@ -95,9 +95,9 @@ class RSS(callbacks.Privmsg): def __init__(self): callbacks.Privmsg.__init__(self) self.feedNames = sets.Set() + self.locks = {} self.lastRequest = {} self.cachedFeeds = {} - self.currentlyDownloading = sets.Set() for (name, url) in registry._cache.iteritems(): name = name.lower() if name.startswith('supybot.plugins.rss.feeds.'): @@ -121,53 +121,71 @@ class RSS(callbacks.Privmsg): url = name if self.willGetNewFeed(url): newFeeds.setdefault(url, []).append(channel) - for (feed, channels) in newFeeds.iteritems(): - t = threading.Thread(target=self._newHeadlines, - name='Fetching <%s>' % url, - args=(irc, channels, name, url)) - self.log.info('Spawning thread to fetch <%s>', url) - world.threadsSpawned += 1 - t.setDaemon(True) - t.start() + for (url, channels) in newFeeds.iteritems(): + # We check if we can acquire the lock right here because if we + # don't, we'll possibly end up spawning a lot of threads to get + # the feed, because this thread may run for a number of bytecodes + # before it switches to a thread that'll get the lock in + # _newHeadlines. + if self.locks[url].acquire(blocking=False): + try: + t = threading.Thread(target=self._newHeadlines, + name='Fetching <%s>' % url, + args=(irc, channels, name, url)) + self.log.info('Spawning thread to fetch <%s>', url) + world.threadsSpawned += 1 + t.setDaemon(True) + t.start() + finally: + self.locks[url].release() def _newHeadlines(self, irc, channels, name, url): try: - oldresults = self.cachedFeeds[url] - oldheadlines = self.getHeadlines(oldresults) - except KeyError: - oldheadlines = [] - newresults = self.getFeed(url) - newheadlines = self.getHeadlines(newresults) - for headline in oldheadlines: + # We acquire the lock here so there's only one announcement thread + # in this code at any given time. Otherwise, several announcement + # threads will getFeed (all blocking, in turn); then they'll all + # want to sent their news messages to the appropriate channels. + self.locks[url].acquire() try: - newheadlines.remove(headline) - except ValueError: - pass - if newheadlines: - for channel in channels: - bold = self.registryValue('bold', channel) - sep = self.registryValue('headlineSeparator', channel) - prefix = self.registryValue('announcementPrefix', channel) - pre = '%s%s: ' % (prefix, name) - if bold: - pre = ircutils.bold(pre) - irc.replies(newheadlines, prefixer=pre, joiner=sep, - to=channel, prefixName=False, private=True) + oldresults = self.cachedFeeds[url] + oldheadlines = self.getHeadlines(oldresults) + except KeyError: + oldheadlines = [] + newresults = self.getFeed(url) + newheadlines = self.getHeadlines(newresults) + for headline in oldheadlines: + try: + newheadlines.remove(headline) + except ValueError: + pass + if newheadlines: + for channel in channels: + bold = self.registryValue('bold', channel) + sep = self.registryValue('headlineSeparator', channel) + prefix = self.registryValue('announcementPrefix', channel) + pre = '%s%s: ' % (prefix, name) + if bold: + pre = ircutils.bold(pre) + irc.replies(newheadlines, prefixer=pre, joiner=sep, + to=channel, prefixName=False, private=True) + finally: + self.locks[url].release() def willGetNewFeed(self, url): now = time.time() wait = self.registryValue('waitPeriod') - if url in self.currentlyDownloading: - return False if url not in self.lastRequest or now - self.lastRequest[url] > wait: return True else: return False def getFeed(self, url): - if self.willGetNewFeed(url): - try: - self.currentlyDownloading.add(url) + try: + # This is the most obvious place to acquire the lock, because a + # malicious user could conceivably flood the bot with rss commands + # and DoS the website in question. + self.locks[url].acquire() + if self.willGetNewFeed(url): try: self.log.info('Downloading new feed from <%s>', url) results = rssparser.parse(url) @@ -176,13 +194,13 @@ class RSS(callbacks.Privmsg): raise callbacks.Error, 'Invalid (unparseable) RSS feed.' self.cachedFeeds[url] = results self.lastRequest[url] = time.time() - finally: - self.currentlyDownloading.discard(url) - try: - return self.cachedFeeds[url] - except KeyError: - self.lastRequest[url] = 0 - return {'items': {'title': 'Unable to download feed.'}} + try: + return self.cachedFeeds[url] + except KeyError: + self.lastRequest[url] = 0 + return {'items': [{'title': 'Unable to download feed.'}]} + finally: + self.locks[url].release() def getHeadlines(self, feed): return [utils.htmlToText(d['title'].strip()) for d in feed['items']] @@ -195,6 +213,7 @@ class RSS(callbacks.Privmsg): to 1800 (30 minutes) since that's what most websites prefer. """ % (name, url) name = callbacks.canonicalName(name) + self.locks[url] = threading.RLock() if hasattr(self, name): s = 'I already have a command in this plugin named %s' % name raise callbacks.Error, s