3
0
mirror of https://github.com/jlu5/PyLink.git synced 2025-01-25 19:54:25 +01:00

relay: add locks in db read/writes (thread safety)

This commit is contained in:
James Lu 2016-12-09 17:43:50 -08:00
parent e40b2f6529
commit 2b4943a780

View File

@ -14,6 +14,7 @@ relayusers = defaultdict(dict)
relayservers = defaultdict(dict) relayservers = defaultdict(dict)
spawnlocks = defaultdict(threading.RLock) spawnlocks = defaultdict(threading.RLock)
spawnlocks_servers = defaultdict(threading.RLock) spawnlocks_servers = defaultdict(threading.RLock)
db_lock = threading.RLock()
dbname = utils.getDatabaseName('pylinkrelay') dbname = utils.getDatabaseName('pylinkrelay')
datastore = structures.PickleDataStore('pylinkrelay', dbname) datastore = structures.PickleDataStore('pylinkrelay', dbname)
@ -34,15 +35,16 @@ def initializeAll(irc):
# which would break connections. # which would break connections.
world.started.wait(2) world.started.wait(2)
for chanpair, entrydata in db.items(): with db_lock:
# Iterate over all the channels stored in our relay links DB. for chanpair, entrydata in db.items():
network, channel = chanpair # Iterate over all the channels stored in our relay links DB.
network, channel = chanpair
# Initialize each relay channel on their home network, and on every linked one too. # Initialize each relay channel on their home network, and on every linked one too.
initializeChannel(irc, channel)
for link in entrydata['links']:
network, channel = link
initializeChannel(irc, channel) initializeChannel(irc, channel)
for link in entrydata['links']:
network, channel = link
initializeChannel(irc, channel)
def main(irc=None): def main(irc=None):
"""Main function, called during plugin loading at start.""" """Main function, called during plugin loading at start."""
@ -410,12 +412,13 @@ def getOrigUser(irc, user, targetirc=None):
def getRelay(chanpair): def getRelay(chanpair):
"""Finds the matching relay entry name for the given (network name, channel) """Finds the matching relay entry name for the given (network name, channel)
pair, if one exists.""" pair, if one exists."""
if chanpair in db: # This chanpair is a shared channel; others link to it with db_lock:
return chanpair if chanpair in db: # This chanpair is a shared channel; others link to it
# This chanpair is linked *to* a remote channel return chanpair
for name, dbentry in db.items(): # This chanpair is linked *to* a remote channel
if chanpair in dbentry['links']: for name, dbentry in db.items():
return name if chanpair in dbentry['links']:
return name
def getRemoteChan(irc, remoteirc, channel): def getRemoteChan(irc, remoteirc, channel):
"""Returns the linked channel name for the given channel on remoteirc, """Returns the linked channel name for the given channel on remoteirc,
@ -428,9 +431,10 @@ def getRemoteChan(irc, remoteirc, channel):
if chanpair[0] == remotenetname: if chanpair[0] == remotenetname:
return chanpair[1] return chanpair[1]
else: else:
for link in db[chanpair]['links']: with db_lock:
if link[0] == remotenetname: for link in db[chanpair]['links']:
return link[1] if link[0] == remotenetname:
return link[1]
def initializeChannel(irc, channel): def initializeChannel(irc, channel):
"""Initializes a relay channel (merge local/remote users, set modes, etc.).""" """Initializes a relay channel (merge local/remote users, set modes, etc.)."""
@ -441,8 +445,9 @@ def initializeChannel(irc, channel):
log.debug('(%s) relay.initializeChannel: relay pair found to be %s', irc.name, relay) log.debug('(%s) relay.initializeChannel: relay pair found to be %s', irc.name, relay)
queued_users = [] queued_users = []
if relay: if relay:
all_links = db[relay]['links'].copy() with db_lock:
all_links.update((relay,)) all_links = db[relay]['links'].copy()
all_links.update((relay,))
log.debug('(%s) relay.initializeChannel: all_links: %s', irc.name, all_links) log.debug('(%s) relay.initializeChannel: all_links: %s', irc.name, all_links)
# Iterate over all the remote channels linked in this relay. # Iterate over all the remote channels linked in this relay.
@ -524,12 +529,14 @@ def checkClaim(irc, channel, sender, chanobj=None):
sender_modes = getPrefixModes(irc, irc, channel, sender, mlist=mlist) sender_modes = getPrefixModes(irc, irc, channel, sender, mlist=mlist)
log.debug('(%s) relay.checkClaim: sender modes (%s/%s) are %s (mlist=%s)', irc.name, log.debug('(%s) relay.checkClaim: sender modes (%s/%s) are %s (mlist=%s)', irc.name,
sender, channel, sender_modes, mlist) sender, channel, sender_modes, mlist)
# XXX: stop hardcoding modes to check for and support mlist in isHalfopPlus and friends # XXX: stop hardcoding modes to check for and support mlist in isHalfopPlus and friends
return (not relay) or irc.name == relay[0] or not db[relay]['claim'] or \ with db_lock:
irc.name in db[relay]['claim'] or \ return (not relay) or irc.name == relay[0] or not db[relay]['claim'] or \
any([mode in sender_modes for mode in ('y', 'q', 'a', 'o', 'h')]) \ irc.name in db[relay]['claim'] or \
or irc.isInternalClient(sender) or \ any([mode in sender_modes for mode in ('y', 'q', 'a', 'o', 'h')]) \
irc.isInternalServer(sender) or irc.isInternalClient(sender) or \
irc.isInternalServer(sender)
def getSupportedUmodes(irc, remoteirc, modes): def getSupportedUmodes(irc, remoteirc, modes):
"""Given a list of user modes, filters out all of those not supported by the """Given a list of user modes, filters out all of those not supported by the
@ -1542,9 +1549,10 @@ def create(irc, source, args):
creator = irc.getHostmask(source) creator = irc.getHostmask(source)
# Create the relay database entry with the (network name, channel name) # Create the relay database entry with the (network name, channel name)
# pair - this is just a dict with various keys. # pair - this is just a dict with various keys.
db[(irc.name, channel)] = {'claim': [irc.name], 'links': set(), with db_lock:
'blocked_nets': set(), 'creator': creator, db[(irc.name, channel)] = {'claim': [irc.name], 'links': set(),
'ts': time.time()} 'blocked_nets': set(), 'creator': creator,
'ts': time.time()}
log.info('(%s) relay: Channel %s created by %s.', irc.name, channel, creator) log.info('(%s) relay: Channel %s created by %s.', irc.name, channel, creator)
initializeChannel(irc, channel) initializeChannel(irc, channel)
irc.reply('Done.') irc.reply('Done.')
@ -1554,9 +1562,10 @@ def _stop_relay(entry):
"""Internal function to deinitialize a relay link and its leaves.""" """Internal function to deinitialize a relay link and its leaves."""
network, channel = entry network, channel = entry
# Iterate over all the channel links and deinitialize them. # Iterate over all the channel links and deinitialize them.
for link in db[entry]['links']: with db_lock:
removeChannel(world.networkobjects.get(link[0]), link[1]) for link in db[entry]['links']:
removeChannel(world.networkobjects.get(network), channel) removeChannel(world.networkobjects.get(link[0]), link[1])
removeChannel(world.networkobjects.get(network), channel)
def destroy(irc, source, args): def destroy(irc, source, args):
"""[<home network>] <channel> """[<home network>] <channel>
@ -1586,17 +1595,18 @@ def destroy(irc, source, args):
entry = (network, channel) entry = (network, channel)
if entry in db: with db_lock:
_stop_relay(entry) if entry in db:
del db[entry] _stop_relay(entry)
del db[entry]
log.info('(%s) relay: Channel %s destroyed by %s.', irc.name, log.info('(%s) relay: Channel %s destroyed by %s.', irc.name,
channel, irc.getHostmask(source)) channel, irc.getHostmask(source))
irc.reply('Done.') irc.reply('Done.')
else: else:
irc.error("No such channel %r exists. If you're trying to delink a channel from " irc.error("No such channel %r exists. If you're trying to delink a channel from "
"another network, use the DESTROY command." % channel) "another network, use the DESTROY command." % channel)
return return
destroy = utils.add_cmd(destroy, featured=True) destroy = utils.add_cmd(destroy, featured=True)
@utils.add_cmd @utils.add_cmd
@ -1613,20 +1623,20 @@ def purge(irc, source, args):
count = 0 count = 0
### XXX lock to make this thread safe! with db_lock:
for entry in db.copy(): for entry in db.copy():
# Entry was owned by the target network; remove it # Entry was owned by the target network; remove it
if entry[0] == network: if entry[0] == network:
count += 1 count += 1
_stop_relay(entry) _stop_relay(entry)
del db[entry] del db[entry]
else: else:
# Drop leaf channels involving the target network # Drop leaf channels involving the target network
for link in db[entry]['links'].copy(): for link in db[entry]['links'].copy():
if link[0] == network: if link[0] == network:
count += 1 count += 1
removeChannel(world.networkobjects.get(network), link[1]) removeChannel(world.networkobjects.get(network), link[1])
db[entry]['links'].remove(link) db[entry]['links'].remove(link)
irc.reply("Done. Purged %s entries involving the network %s." % (count, network)) irc.reply("Done. Purged %s entries involving the network %s." % (count, network))
@ -1686,7 +1696,8 @@ def link(irc, source, args):
return return
try: try:
entry = db[(remotenet, channel)] with db_lock:
entry = db[(remotenet, channel)]
except KeyError: except KeyError:
irc.error('No such relay %r exists.' % channel) irc.error('No such relay %r exists.' % channel)
return return
@ -1747,13 +1758,15 @@ def delink(irc, source, args):
"network).") "network).")
return return
else: else:
for link in db[entry]['links'].copy(): with db_lock:
if link[0] == remotenet: for link in db[entry]['links'].copy():
removeChannel(world.networkobjects.get(remotenet), link[1]) if link[0] == remotenet:
db[entry]['links'].remove(link) removeChannel(world.networkobjects.get(remotenet), link[1])
db[entry]['links'].remove(link)
else: else:
removeChannel(irc, channel) removeChannel(irc, channel)
db[entry]['links'].remove((irc.name, channel)) with db_lock:
db[entry]['links'].remove((irc.name, channel))
irc.reply('Done.') irc.reply('Done.')
log.info('(%s) relay: Channel %s delinked from %s%s by %s.', irc.name, log.info('(%s) relay: Channel %s delinked from %s%s by %s.', irc.name,
channel, entry[0], entry[1], irc.getHostmask(source)) channel, entry[0], entry[1], irc.getHostmask(source))
@ -1788,58 +1801,59 @@ def linked(irc, source, args):
irc.reply("Showing channels linked to %s:" % net, private=True) irc.reply("Showing channels linked to %s:" % net, private=True)
# Sort the list of shared channels when displaying # Sort the list of shared channels when displaying
for k, v in sorted(db.items()): with db_lock:
for k, v in sorted(db.items()):
# Skip if we're filtering by network and the network given isn't relayed # Skip if we're filtering by network and the network given isn't relayed
# to the channel. # to the channel.
if net and not (net == k[0] or net in [link[0] for link in v['links']]): if net and not (net == k[0] or net in [link[0] for link in v['links']]):
continue continue
# Bold each network/channel name pair # Bold each network/channel name pair
s = '\x02%s%s\x02 ' % k s = '\x02%s%s\x02 ' % k
remoteirc = world.networkobjects.get(k[0]) remoteirc = world.networkobjects.get(k[0])
channel = k[1] # Get the channel name from the network/channel pair channel = k[1] # Get the channel name from the network/channel pair
if remoteirc and channel in remoteirc.channels: if remoteirc and channel in remoteirc.channels:
c = remoteirc.channels[channel] c = remoteirc.channels[channel]
if ('s', None) in c.modes or ('p', None) in c.modes: if ('s', None) in c.modes or ('p', None) in c.modes:
# Only show secret channels to opers or those in the channel, and tag them as # Only show secret channels to opers or those in the channel, and tag them as
# [secret]. # [secret].
localchan = getRemoteChan(remoteirc, irc, channel) localchan = getRemoteChan(remoteirc, irc, channel)
if irc.isOper(source) or (localchan and source in irc.channels[localchan].users): if irc.isOper(source) or (localchan and source in irc.channels[localchan].users):
s += '\x02[secret]\x02 ' s += '\x02[secret]\x02 '
else: else:
continue continue
if v['links']: if v['links']:
# Sort, join up and output all the linked channel names. Silently drop # Sort, join up and output all the linked channel names. Silently drop
# entries for disconnected networks. # entries for disconnected networks.
s += ' '.join([''.join(link) for link in sorted(v['links']) if link[0] in world.networkobjects s += ' '.join([''.join(link) for link in sorted(v['links']) if link[0] in world.networkobjects
and world.networkobjects[link[0]].connected.is_set()]) and world.networkobjects[link[0]].connected.is_set()])
else: # Unless it's empty; then, well... just say no relays yet. else: # Unless it's empty; then, well... just say no relays yet.
s += '(no relays yet)' s += '(no relays yet)'
irc.reply(s, private=True) irc.reply(s, private=True)
if irc.isOper(source): if irc.isOper(source):
s = '' s = ''
# If the caller is an oper, we can show the hostmasks of people # If the caller is an oper, we can show the hostmasks of people
# that created all the available channels (Janus does this too!!) # that created all the available channels (Janus does this too!!)
creator = v.get('creator') creator = v.get('creator')
if creator: if creator:
# But only if the value actually exists (old DBs will have it # But only if the value actually exists (old DBs will have it
# missing). # missing).
s += ' by \x02%s\x02' % creator s += ' by \x02%s\x02' % creator
# Ditto for creation date # Ditto for creation date
ts = v.get('ts') ts = v.get('ts')
if ts: if ts:
s += ' on %s' % time.ctime(ts) s += ' on %s' % time.ctime(ts)
if s: # Indent to make the list look nicer if s: # Indent to make the list look nicer
irc.reply(' Channel created%s.' % s, private=True) irc.reply(' Channel created%s.' % s, private=True)
linked = utils.add_cmd(linked, featured=True) linked = utils.add_cmd(linked, featured=True)
@utils.add_cmd @utils.add_cmd
@ -1864,30 +1878,32 @@ def linkacl(irc, source, args):
if not relay: if not relay:
irc.error('No such relay %r exists.' % channel) irc.error('No such relay %r exists.' % channel)
return return
if cmd == 'list':
permissions.checkPermissions(irc, source, ['relay.linkacl.view'])
s = 'Blocked networks for \x02%s\x02: \x02%s\x02' % (channel, ', '.join(db[relay]['blocked_nets']) or '(empty)')
irc.reply(s)
return
permissions.checkPermissions(irc, source, ['relay.linkacl']) with db_lock:
try: if cmd == 'list':
remotenet = args[2] permissions.checkPermissions(irc, source, ['relay.linkacl.view'])
except IndexError: s = 'Blocked networks for \x02%s\x02: \x02%s\x02' % (channel, ', '.join(db[relay]['blocked_nets']) or '(empty)')
irc.error(missingargs) irc.reply(s)
return return
if cmd == 'deny':
db[relay]['blocked_nets'].add(remotenet) permissions.checkPermissions(irc, source, ['relay.linkacl'])
irc.reply('Done.')
elif cmd == 'allow':
try: try:
db[relay]['blocked_nets'].remove(remotenet) remotenet = args[2]
except KeyError: except IndexError:
irc.error('Network %r is not on the blacklist for %r.' % (remotenet, channel)) irc.error(missingargs)
else: return
if cmd == 'deny':
db[relay]['blocked_nets'].add(remotenet)
irc.reply('Done.') irc.reply('Done.')
else: elif cmd == 'allow':
irc.error('Unknown subcommand %r: valid ones are ALLOW, DENY, and LIST.' % cmd) try:
db[relay]['blocked_nets'].remove(remotenet)
except KeyError:
irc.error('Network %r is not on the blacklist for %r.' % (remotenet, channel))
else:
irc.reply('Done.')
else:
irc.error('Unknown subcommand %r: valid ones are ALLOW, DENY, and LIST.' % cmd)
@utils.add_cmd @utils.add_cmd
def showuser(irc, source, args): def showuser(irc, source, args):
@ -1982,20 +1998,21 @@ def claim(irc, source, args):
# We override getRelay() here to limit the search to the current network. # We override getRelay() here to limit the search to the current network.
relay = (irc.name, channel) relay = (irc.name, channel)
if relay not in db: with db_lock:
irc.error('No such relay %r exists.' % channel) if relay not in db:
return irc.error('No such relay %r exists.' % channel)
claimed = db[relay]["claim"] return
try: claimed = db[relay]["claim"]
nets = args[1].strip() try:
except IndexError: # No networks given. nets = args[1].strip()
irc.reply('Channel \x02%s\x02 is claimed by: %s' % except IndexError: # No networks given.
(channel, ', '.join(claimed) or '\x1D(none)\x1D')) irc.reply('Channel \x02%s\x02 is claimed by: %s' %
else: (channel, ', '.join(claimed) or '\x1D(none)\x1D'))
if nets == '-' or not nets:
claimed = set()
else: else:
claimed = set(nets.split(',')) if nets == '-' or not nets:
db[relay]["claim"] = claimed claimed = set()
irc.reply('CLAIM for channel \x02%s\x02 set to: %s' % else:
(channel, ', '.join(claimed) or '\x1D(none)\x1D')) claimed = set(nets.split(','))
db[relay]["claim"] = claimed
irc.reply('CLAIM for channel \x02%s\x02 set to: %s' %
(channel, ', '.join(claimed) or '\x1D(none)\x1D'))