Aka: Fix 'factorial-complexity' recursion and command overriding.

This commit is contained in:
Valentin Lorentz 2013-12-11 16:01:01 +00:00
parent c774013e1f
commit 01278dc56c
3 changed files with 18 additions and 13 deletions

View File

@ -102,7 +102,7 @@ if sqlalchemy:
def has_aka(self, channel, name): def has_aka(self, channel, name):
name = callbacks.canonicalName(name) name = callbacks.canonicalName(name, preserve_spaces=True)
if sys.version_info[0] < 3 and isinstance(name, str): if sys.version_info[0] < 3 and isinstance(name, str):
name = name.decode('utf8') name = name.decode('utf8')
count = self.get_db(channel).query(Alias) \ count = self.get_db(channel).query(Alias) \
@ -114,7 +114,7 @@ if sqlalchemy:
return list_ return list_
def get_alias(self, channel, name): def get_alias(self, channel, name):
name = callbacks.canonicalName(name) name = callbacks.canonicalName(name, preserve_spaces=True)
if sys.version_info[0] < 3 and isinstance(name, str): if sys.version_info[0] < 3 and isinstance(name, str):
name = name.decode('utf8') name = name.decode('utf8')
try: try:
@ -124,7 +124,7 @@ if sqlalchemy:
return None return None
def add_aka(self, channel, name, alias): def add_aka(self, channel, name, alias):
name = callbacks.canonicalName(name) name = callbacks.canonicalName(name, preserve_spaces=True)
if self.has_aka(channel, name): if self.has_aka(channel, name):
raise AkaError(_('This Aka already exists.')) raise AkaError(_('This Aka already exists.'))
if sys.version_info[0] < 3: if sys.version_info[0] < 3:
@ -137,7 +137,7 @@ if sqlalchemy:
db.commit() db.commit()
def remove_aka(self, channel, name): def remove_aka(self, channel, name):
name = callbacks.canonicalName(name) name = callbacks.canonicalName(name, preserve_spaces=True)
if sys.version_info[0] < 3 and isinstance(name, str): if sys.version_info[0] < 3 and isinstance(name, str):
name = name.decode('utf8') name = name.decode('utf8')
db = self.get_db(channel) db = self.get_db(channel)
@ -145,7 +145,7 @@ if sqlalchemy:
db.commit() db.commit()
def lock_aka(self, channel, name, by): def lock_aka(self, channel, name, by):
name = callbacks.canonicalName(name) name = callbacks.canonicalName(name, preserve_spaces=True)
if sys.version_info[0] < 3 and isinstance(name, str): if sys.version_info[0] < 3 and isinstance(name, str):
name = name.decode('utf8') name = name.decode('utf8')
db = self.get_db(channel) db = self.get_db(channel)
@ -162,7 +162,7 @@ if sqlalchemy:
db.commit() db.commit()
def unlock_aka(self, channel, name, by): def unlock_aka(self, channel, name, by):
name = callbacks.canonicalName(name) name = callbacks.canonicalName(name, preserve_spaces=True)
if sys.version_info[0] < 3 and isinstance(name, str): if sys.version_info[0] < 3 and isinstance(name, str):
name = name.decode('utf8') name = name.decode('utf8')
db = self.get_db(channel) db = self.get_db(channel)
@ -179,7 +179,7 @@ if sqlalchemy:
db.commit() db.commit()
def get_aka_lock(self, channel, name): def get_aka_lock(self, channel, name):
name = callbacks.canonicalName(name) name = callbacks.canonicalName(name, preserve_spaces=True)
if sys.version_info[0] < 3 and isinstance(name, str): if sys.version_info[0] < 3 and isinstance(name, str):
name = name.decode('utf8') name = name.decode('utf8')
try: try:
@ -247,7 +247,7 @@ class Aka(callbacks.Plugin):
if len(args) > 1 and \ if len(args) > 1 and \
callbacks.canonicalName(args[0]) != self.canonicalName(): callbacks.canonicalName(args[0]) != self.canonicalName():
for cb in dynamic.irc.callbacks: # including this plugin for cb in dynamic.irc.callbacks: # including this plugin
if cb.getCommand(args[0:-1]): if cb.isCommandMethod(' '.join(args[0:-1])):
return False return False
if sys.version_info[0] < 3 and isinstance(name, str): if sys.version_info[0] < 3 and isinstance(name, str):
name = name.decode('utf8') name = name.decode('utf8')
@ -264,7 +264,7 @@ class Aka(callbacks.Plugin):
self._db.get_aka_list('global')) + self._db.get_aka_list('global')) +
['add', 'remove', 'lock', 'unlock', 'importaliasdatabase'])) ['add', 'remove', 'lock', 'unlock', 'importaliasdatabase']))
def getCommand(self, args): def getCommand(self, args, check_other_plugins=True):
canonicalName = callbacks.canonicalName canonicalName = callbacks.canonicalName
# All the code from here to the 'for' loop is copied from callbacks.py # All the code from here to the 'for' loop is copied from callbacks.py
assert args == map(canonicalName, args) assert args == map(canonicalName, args)
@ -273,10 +273,10 @@ class Aka(callbacks.Plugin):
if first == cb.canonicalName(): if first == cb.canonicalName():
return cb.getCommand(args[1:]) return cb.getCommand(args[1:])
if first == self.canonicalName() and len(args) > 1: if first == self.canonicalName() and len(args) > 1:
ret = self.getCommand(args[1:]) ret = self.getCommand(args[1:], False)
if ret: if ret:
return [first] + ret return [first] + ret
for i in xrange(len(args), 0, -1): for i in xrange(1, len(args)+1):
if self.isCommandMethod(callbacks.formatCommand(args[0:i])): if self.isCommandMethod(callbacks.formatCommand(args[0:i])):
return args[0:i] return args[0:i]
return [] return []

View File

@ -164,6 +164,9 @@ class AkaChannelTestCase(ChannelPluginTestCase):
def testNoOverride(self): def testNoOverride(self):
self.assertNotError('aka add "echo foo" "echo bar"') self.assertNotError('aka add "echo foo" "echo bar"')
self.assertResponse('echo foo', 'foo') self.assertResponse('echo foo', 'foo')
self.assertNotError('aka add foo "echo baz"')
self.assertNotError('aka add "foo bar" "echo qux"')
self.assertResponse('foo bar', 'baz')
def testRecursivity(self): def testRecursivity(self):
self.assertNotError('aka add fact ' self.assertNotError('aka add fact '

View File

@ -146,7 +146,7 @@ def addressed(nick, msg, **kwargs):
msg.tag('addressed', payload) msg.tag('addressed', payload)
return payload return payload
def canonicalName(command): def canonicalName(command, preserve_spaces=False):
"""Turn a command into its canonical form. """Turn a command into its canonical form.
Currently, this makes everything lowercase and removes all dashes and Currently, this makes everything lowercase and removes all dashes and
@ -156,7 +156,9 @@ def canonicalName(command):
command = command.encode('utf-8') command = command.encode('utf-8')
elif sys.version_info[0] >= 3 and isinstance(command, bytes): elif sys.version_info[0] >= 3 and isinstance(command, bytes):
command = command.decode() command = command.decode()
special = '\t -_' special = '\t-_'
if not preserve_spaces:
special += ' '
reAppend = '' reAppend = ''
while command and command[-1] in special: while command and command[-1] in special:
reAppend = command[-1] + reAppend reAppend = command[-1] + reAppend