3
0
mirror of https://github.com/jlu5/PyLink.git synced 2025-01-11 20:52:42 +01:00

relay: add locks in db read/writes (thread safety)

This commit is contained in:
James Lu 2016-12-09 17:43:50 -08:00
parent e40b2f6529
commit 2b4943a780

View File

@ -14,6 +14,7 @@ relayusers = defaultdict(dict)
relayservers = defaultdict(dict) relayservers = defaultdict(dict)
spawnlocks = defaultdict(threading.RLock) spawnlocks = defaultdict(threading.RLock)
spawnlocks_servers = defaultdict(threading.RLock) spawnlocks_servers = defaultdict(threading.RLock)
db_lock = threading.RLock()
dbname = utils.getDatabaseName('pylinkrelay') dbname = utils.getDatabaseName('pylinkrelay')
datastore = structures.PickleDataStore('pylinkrelay', dbname) datastore = structures.PickleDataStore('pylinkrelay', dbname)
@ -34,6 +35,7 @@ def initializeAll(irc):
# which would break connections. # which would break connections.
world.started.wait(2) world.started.wait(2)
with db_lock:
for chanpair, entrydata in db.items(): for chanpair, entrydata in db.items():
# Iterate over all the channels stored in our relay links DB. # Iterate over all the channels stored in our relay links DB.
network, channel = chanpair network, channel = chanpair
@ -410,6 +412,7 @@ def getOrigUser(irc, user, targetirc=None):
def getRelay(chanpair): def getRelay(chanpair):
"""Finds the matching relay entry name for the given (network name, channel) """Finds the matching relay entry name for the given (network name, channel)
pair, if one exists.""" pair, if one exists."""
with db_lock:
if chanpair in db: # This chanpair is a shared channel; others link to it if chanpair in db: # This chanpair is a shared channel; others link to it
return chanpair return chanpair
# This chanpair is linked *to* a remote channel # This chanpair is linked *to* a remote channel
@ -428,6 +431,7 @@ def getRemoteChan(irc, remoteirc, channel):
if chanpair[0] == remotenetname: if chanpair[0] == remotenetname:
return chanpair[1] return chanpair[1]
else: else:
with db_lock:
for link in db[chanpair]['links']: for link in db[chanpair]['links']:
if link[0] == remotenetname: if link[0] == remotenetname:
return link[1] return link[1]
@ -441,6 +445,7 @@ def initializeChannel(irc, channel):
log.debug('(%s) relay.initializeChannel: relay pair found to be %s', irc.name, relay) log.debug('(%s) relay.initializeChannel: relay pair found to be %s', irc.name, relay)
queued_users = [] queued_users = []
if relay: if relay:
with db_lock:
all_links = db[relay]['links'].copy() all_links = db[relay]['links'].copy()
all_links.update((relay,)) all_links.update((relay,))
log.debug('(%s) relay.initializeChannel: all_links: %s', irc.name, all_links) log.debug('(%s) relay.initializeChannel: all_links: %s', irc.name, all_links)
@ -524,7 +529,9 @@ def checkClaim(irc, channel, sender, chanobj=None):
sender_modes = getPrefixModes(irc, irc, channel, sender, mlist=mlist) sender_modes = getPrefixModes(irc, irc, channel, sender, mlist=mlist)
log.debug('(%s) relay.checkClaim: sender modes (%s/%s) are %s (mlist=%s)', irc.name, log.debug('(%s) relay.checkClaim: sender modes (%s/%s) are %s (mlist=%s)', irc.name,
sender, channel, sender_modes, mlist) sender, channel, sender_modes, mlist)
# XXX: stop hardcoding modes to check for and support mlist in isHalfopPlus and friends # XXX: stop hardcoding modes to check for and support mlist in isHalfopPlus and friends
with db_lock:
return (not relay) or irc.name == relay[0] or not db[relay]['claim'] or \ return (not relay) or irc.name == relay[0] or not db[relay]['claim'] or \
irc.name in db[relay]['claim'] or \ irc.name in db[relay]['claim'] or \
any([mode in sender_modes for mode in ('y', 'q', 'a', 'o', 'h')]) \ any([mode in sender_modes for mode in ('y', 'q', 'a', 'o', 'h')]) \
@ -1542,6 +1549,7 @@ def create(irc, source, args):
creator = irc.getHostmask(source) creator = irc.getHostmask(source)
# Create the relay database entry with the (network name, channel name) # Create the relay database entry with the (network name, channel name)
# pair - this is just a dict with various keys. # pair - this is just a dict with various keys.
with db_lock:
db[(irc.name, channel)] = {'claim': [irc.name], 'links': set(), db[(irc.name, channel)] = {'claim': [irc.name], 'links': set(),
'blocked_nets': set(), 'creator': creator, 'blocked_nets': set(), 'creator': creator,
'ts': time.time()} 'ts': time.time()}
@ -1554,6 +1562,7 @@ def _stop_relay(entry):
"""Internal function to deinitialize a relay link and its leaves.""" """Internal function to deinitialize a relay link and its leaves."""
network, channel = entry network, channel = entry
# Iterate over all the channel links and deinitialize them. # Iterate over all the channel links and deinitialize them.
with db_lock:
for link in db[entry]['links']: for link in db[entry]['links']:
removeChannel(world.networkobjects.get(link[0]), link[1]) removeChannel(world.networkobjects.get(link[0]), link[1])
removeChannel(world.networkobjects.get(network), channel) removeChannel(world.networkobjects.get(network), channel)
@ -1586,6 +1595,7 @@ def destroy(irc, source, args):
entry = (network, channel) entry = (network, channel)
with db_lock:
if entry in db: if entry in db:
_stop_relay(entry) _stop_relay(entry)
del db[entry] del db[entry]
@ -1613,7 +1623,7 @@ def purge(irc, source, args):
count = 0 count = 0
### XXX lock to make this thread safe! with db_lock:
for entry in db.copy(): for entry in db.copy():
# Entry was owned by the target network; remove it # Entry was owned by the target network; remove it
if entry[0] == network: if entry[0] == network:
@ -1686,6 +1696,7 @@ def link(irc, source, args):
return return
try: try:
with db_lock:
entry = db[(remotenet, channel)] entry = db[(remotenet, channel)]
except KeyError: except KeyError:
irc.error('No such relay %r exists.' % channel) irc.error('No such relay %r exists.' % channel)
@ -1747,12 +1758,14 @@ def delink(irc, source, args):
"network).") "network).")
return return
else: else:
with db_lock:
for link in db[entry]['links'].copy(): for link in db[entry]['links'].copy():
if link[0] == remotenet: if link[0] == remotenet:
removeChannel(world.networkobjects.get(remotenet), link[1]) removeChannel(world.networkobjects.get(remotenet), link[1])
db[entry]['links'].remove(link) db[entry]['links'].remove(link)
else: else:
removeChannel(irc, channel) removeChannel(irc, channel)
with db_lock:
db[entry]['links'].remove((irc.name, channel)) db[entry]['links'].remove((irc.name, channel))
irc.reply('Done.') irc.reply('Done.')
log.info('(%s) relay: Channel %s delinked from %s%s by %s.', irc.name, log.info('(%s) relay: Channel %s delinked from %s%s by %s.', irc.name,
@ -1788,6 +1801,7 @@ def linked(irc, source, args):
irc.reply("Showing channels linked to %s:" % net, private=True) irc.reply("Showing channels linked to %s:" % net, private=True)
# Sort the list of shared channels when displaying # Sort the list of shared channels when displaying
with db_lock:
for k, v in sorted(db.items()): for k, v in sorted(db.items()):
# Skip if we're filtering by network and the network given isn't relayed # Skip if we're filtering by network and the network given isn't relayed
@ -1864,6 +1878,8 @@ def linkacl(irc, source, args):
if not relay: if not relay:
irc.error('No such relay %r exists.' % channel) irc.error('No such relay %r exists.' % channel)
return return
with db_lock:
if cmd == 'list': if cmd == 'list':
permissions.checkPermissions(irc, source, ['relay.linkacl.view']) permissions.checkPermissions(irc, source, ['relay.linkacl.view'])
s = 'Blocked networks for \x02%s\x02: \x02%s\x02' % (channel, ', '.join(db[relay]['blocked_nets']) or '(empty)') s = 'Blocked networks for \x02%s\x02: \x02%s\x02' % (channel, ', '.join(db[relay]['blocked_nets']) or '(empty)')
@ -1982,6 +1998,7 @@ def claim(irc, source, args):
# We override getRelay() here to limit the search to the current network. # We override getRelay() here to limit the search to the current network.
relay = (irc.name, channel) relay = (irc.name, channel)
with db_lock:
if relay not in db: if relay not in db:
irc.error('No such relay %r exists.' % channel) irc.error('No such relay %r exists.' % channel)
return return