3
0
mirror of https://github.com/jlu5/PyLink.git synced 2024-12-25 12:12:53 +01:00

Rework inbound connection handling to use select

Closes #588.
This commit is contained in:
James Lu 2018-03-17 11:01:32 -07:00
parent 57f77c676d
commit f7ab2564fe
3 changed files with 196 additions and 168 deletions

View File

@ -24,7 +24,7 @@ try:
except ImportError: except ImportError:
raise ImportError("PyLink requires ircmatch to function; please install it and try again.") raise ImportError("PyLink requires ircmatch to function; please install it and try again.")
from . import world, utils, structures, conf, __version__ from . import world, utils, structures, conf, __version__, selectdriver
from .log import * from .log import *
from .utils import ProtocolError # Compatibility with PyLink 1.x from .utils import ProtocolError # Compatibility with PyLink 1.x
@ -1300,10 +1300,10 @@ class IRCNetwork(PyLinkNetworkCoreWithUtils):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self._connection_thread = None
self._queue = None self._queue = None
self._ping_timer = None self._ping_timer = None
self._socket = None self._socket = None
self._selector_key = None
def _init_vars(self, *args, **kwargs): def _init_vars(self, *args, **kwargs):
super()._init_vars(*args, **kwargs) super()._init_vars(*args, **kwargs)
@ -1335,171 +1335,154 @@ class IRCNetwork(PyLinkNetworkCoreWithUtils):
else: else:
log.error(*args, **kwargs) log.error(*args, **kwargs)
def _connect(self): def connect(self):
""" """
Runs the connect loop for the IRC object. This is usually called by Connects to the network.
__init__ in a separate thread to allow multiple concurrent connections.
""" """
while True: self._pre_connect()
self._pre_connect()
ip = self.serverdata["ip"] ip = self.serverdata["ip"]
port = self.serverdata["port"] port = self.serverdata["port"]
checks_ok = True checks_ok = True
try: try:
# Set the socket type (IPv6 or IPv4). # Set the socket type (IPv6 or IPv4).
stype = socket.AF_INET6 if self.serverdata.get("ipv6") else socket.AF_INET stype = socket.AF_INET6 if self.serverdata.get("ipv6") else socket.AF_INET
# Creat the socket. # Creat the socket.
self._socket = socket.socket(stype) self._socket = socket.socket(stype)
self._socket.setblocking(0)
# Set the socket bind if applicable. # Set the socket bind if applicable.
if 'bindhost' in self.serverdata: if 'bindhost' in self.serverdata:
self._socket.bind((self.serverdata['bindhost'], 0)) self._socket.bind((self.serverdata['bindhost'], 0))
# Set the connection timeouts. Initial connection timeout is a # Set the connection timeouts. Initial connection timeout is a
# lot smaller than the timeout after we've connected; this is # lot smaller than the timeout after we've connected; this is
# intentional. # intentional.
self._socket.settimeout(self.pingfreq) self._socket.settimeout(self.pingfreq)
# Resolve hostnames if it's not an IP address already. # Resolve hostnames if it's not an IP address already.
old_ip = ip old_ip = ip
ip = socket.getaddrinfo(ip, port, stype)[0][-1][0] ip = socket.getaddrinfo(ip, port, stype)[0][-1][0]
log.debug('(%s) Resolving address %s to %s', self.name, old_ip, ip) log.debug('(%s) Resolving address %s to %s', self.name, old_ip, ip)
# Enable SSL if set to do so. # Enable SSL if set to do so.
self.ssl = self.serverdata.get('ssl') self.ssl = self.serverdata.get('ssl')
if self.ssl: if self.ssl:
log.info('(%s) Attempting SSL for this connection...', self.name) log.info('(%s) Attempting SSL for this connection...', self.name)
certfile = self.serverdata.get('ssl_certfile') certfile = self.serverdata.get('ssl_certfile')
keyfile = self.serverdata.get('ssl_keyfile') keyfile = self.serverdata.get('ssl_keyfile')
context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
# Disable SSLv2 and SSLv3 - these are insecure # Disable SSLv2 and SSLv3 - these are insecure
context.options |= ssl.OP_NO_SSLv2 context.options |= ssl.OP_NO_SSLv2
context.options |= ssl.OP_NO_SSLv3 context.options |= ssl.OP_NO_SSLv3
# Cert and key files are optional, load them if specified.
if certfile and keyfile:
try:
context.load_cert_chain(certfile, keyfile)
except OSError:
log.exception('(%s) Caught OSError trying to '
'initialize the SSL connection; '
'are "ssl_certfile" and '
'"ssl_keyfile" set correctly?',
self.name)
checks_ok = False
self._socket = context.wrap_socket(self._socket)
log.info("Connecting to network %r on %s:%s", self.name, ip, port)
self._socket.connect((ip, port))
self._socket.settimeout(self.pingtimeout)
# If SSL was enabled, optionally verify the certificate
# fingerprint for some added security. I don't bother to check
# the entire certificate for validity, since most IRC networks
# self-sign their certificates anyways.
if self.ssl and checks_ok:
peercert = self._socket.getpeercert(binary_form=True)
# Hash type is configurable using the ssl_fingerprint_type
# value, and defaults to sha256.
hashtype = self.serverdata.get('ssl_fingerprint_type', 'sha256').lower()
# Cert and key files are optional, load them if specified.
if certfile and keyfile:
try: try:
hashfunc = getattr(hashlib, hashtype) context.load_cert_chain(certfile, keyfile)
except AttributeError: except OSError:
log.error('(%s) Unsupported SSL certificate fingerprint type %r given, disconnecting...', log.exception('(%s) Caught OSError trying to '
self.name, hashtype) 'initialize the SSL connection; '
checks_ok = False 'are "ssl_certfile" and '
else: '"ssl_keyfile" set correctly?',
fp = hashfunc(peercert).hexdigest() self.name)
expected_fp = self.serverdata.get('ssl_fingerprint') checks_ok = False
if expected_fp and checks_ok: self._socket = context.wrap_socket(self._socket)
if fp != expected_fp:
# SSL Fingerprint doesn't match; break. self._selector_key = selectdriver.register(self)
log.error('(%s) Uplink\'s SSL certificate ' log.info("Connecting to network %r on %s:%s", self.name, ip, port)
'fingerprint (%s) does not match the ' self._socket.connect((ip, port))
'one configured: expected %r, got %r; ' self._socket.settimeout(self.pingtimeout)
'disconnecting...', self.name, hashtype,
expected_fp, fp) # If SSL was enabled, optionally verify the certificate
checks_ok = False # fingerprint for some added security. I don't bother to check
else: # the entire certificate for validity, since most IRC networks
log.info('(%s) Uplink SSL certificate fingerprint ' # self-sign their certificates anyways.
'(%s) verified: %r', self.name, hashtype, if self.ssl and checks_ok:
fp) peercert = self._socket.getpeercert(binary_form=True)
# Hash type is configurable using the ssl_fingerprint_type
# value, and defaults to sha256.
hashtype = self.serverdata.get('ssl_fingerprint_type', 'sha256').lower()
try:
hashfunc = getattr(hashlib, hashtype)
except AttributeError:
log.error('(%s) Unsupported SSL certificate fingerprint type %r given, disconnecting...',
self.name, hashtype)
checks_ok = False
else:
fp = hashfunc(peercert).hexdigest()
expected_fp = self.serverdata.get('ssl_fingerprint')
if expected_fp and checks_ok:
if fp != expected_fp:
# SSL Fingerprint doesn't match; break.
log.error('(%s) Uplink\'s SSL certificate '
'fingerprint (%s) does not match the '
'one configured: expected %r, got %r; '
'disconnecting...', self.name, hashtype,
expected_fp, fp)
checks_ok = False
else: else:
log.info('(%s) Uplink\'s SSL certificate fingerprint (%s) ' log.info('(%s) Uplink SSL certificate fingerprint '
'is %r. You can enhance the security of your ' '(%s) verified: %r', self.name, hashtype,
'link by specifying this in a "ssl_fingerprint"' fp)
' option in your server block.', self.name, else:
hashtype, fp) log.info('(%s) Uplink\'s SSL certificate fingerprint (%s) '
'is %r. You can enhance the security of your '
'link by specifying this in a "ssl_fingerprint"'
' option in your server block.', self.name,
hashtype, fp)
if checks_ok: if checks_ok:
self._queue_thread = threading.Thread(name="Queue thread for %s" % self.name, self._queue_thread = threading.Thread(name="Queue thread for %s" % self.name,
target=self._process_queue, daemon=True) target=self._process_queue, daemon=True)
self._queue_thread.start() self._queue_thread.start()
self.sid = self.serverdata.get("sid") self.sid = self.serverdata.get("sid")
# All our checks passed, get the protocol module to connect and run the listen # All our checks passed, get the protocol module to connect and run the listen
# loop. This also updates any SID values should the protocol module do so. # loop. This also updates any SID values should the protocol module do so.
self.post_connect() self.post_connect()
log.info('(%s) Enumerating our own SID %s', self.name, self.sid) log.info('(%s) Enumerating our own SID %s', self.name, self.sid)
host = self.hostname() host = self.hostname()
self.servers[self.sid] = Server(self, None, host, internal=True, self.servers[self.sid] = Server(self, None, host, internal=True,
desc=self.serverdata.get('serverdesc') desc=self.serverdata.get('serverdesc')
or conf.conf['pylink']['serverdesc']) or conf.conf['pylink']['serverdesc'])
log.info('(%s) Starting ping schedulers....', self.name)
self._schedule_ping()
log.info('(%s) Server ready; listening for data.', self.name)
self.autoconnect_active_multiplier = 1 # Reset any extra autoconnect delays
self._run_irc()
else: # Configuration error :(
log.error('(%s) A configuration error was encountered '
'trying to set up this connection. Please check'
' your configuration file and try again.',
self.name)
# _run_irc() or the protocol module it called raised an exception, meaning we've disconnected!
# Note: socket.error, ConnectionError, IOError, etc. are included in OSError since Python 3.3,
# so we don't need to explicitly catch them here.
# We also catch SystemExit here as a way to abort out connection threads properly, and stop the
# IRC connection from freezing instead.
except (OSError, RuntimeError, SystemExit) as e:
self._log_connection_error('(%s) Disconnected from IRC:', self.name, exc_info=True)
log.info('(%s) Starting ping schedulers....', self.name)
self._schedule_ping()
log.info('(%s) Server ready; listening for data.', self.name)
self.autoconnect_active_multiplier = 1 # Reset any extra autoconnect delays
else: # Configuration error :(
log.error('(%s) A configuration error was encountered '
'trying to set up this connection. Please check'
' your configuration file and try again.',
self.name)
# _run_irc() or the protocol module it called raised an exception, meaning we've disconnected!
# Note: socket.error, ConnectionError, IOError, etc. are included in OSError since Python 3.3,
# so we don't need to explicitly catch them here.
# We also catch SystemExit here as a way to abort out connection threads properly, and stop the
# IRC connection from freezing instead.
except (OSError, RuntimeError, SystemExit) as e:
self._log_connection_error('(%s) Disconnected from IRC:', self.name, exc_info=True)
if not self._aborted.is_set(): if not self._aborted.is_set():
self.disconnect() self.disconnect()
if not self._run_autoconnect(): if not self._run_autoconnect():
return return
def connect(self):
log.debug('(%s) calling _connect() (world.testing=%s)', self.name, world.testing)
if world.testing:
# HACK: Don't thread if we're running tests.
self._connect()
else:
if self._connection_thread and self._connection_thread.is_alive():
raise RuntimeError("Refusing to start multiple connection threads for network %r!" % self.name)
self._connection_thread = threading.Thread(target=self._connect,
name="Listener for %s" %
self.name)
self._connection_thread.start()
def disconnect(self): def disconnect(self):
"""Handle disconnects from the remote server.""" """Handle disconnects from the remote server."""
self._pre_disconnect() self._pre_disconnect()
if self._socket is not None: if self._socket is not None:
selectdriver.unregister(self)
try: try:
log.debug('(%s) disconnect: Shutting down socket.', self.name) log.debug('(%s) disconnect: Shutting down socket.', self.name)
self._socket.shutdown(socket.SHUT_RDWR) self._socket.shutdown(socket.SHUT_RDWR)
@ -1523,6 +1506,9 @@ class IRCNetwork(PyLinkNetworkCoreWithUtils):
self._ping_timer.cancel() self._ping_timer.cancel()
self._post_disconnect() self._post_disconnect()
if self._run_autoconnect():
self.connect()
def handle_events(self, line): def handle_events(self, line):
raise NotImplementedError raise NotImplementedError
@ -1547,38 +1533,35 @@ class IRCNetwork(PyLinkNetworkCoreWithUtils):
return hook_args return hook_args
def _run_irc(self): def _run_irc(self):
"""Main IRC loop which listens for messages.""" """
buf = b"" Message handler, called when select() has data to read.
data = b"" """
while (not self._aborted.is_set()) and not world.shutting_down.is_set(): buf = b''
data = b''
try: try:
data = self._socket.recv(2048) data = self._socket.recv(2048)
except BlockingIOError: except OSError:
log.debug('(%s) No data to read, trying again later...', self.name) # Suppress socket read warnings from lingering recv() calls if
if self._aborted.wait(self.SOCKET_REPOLL_WAIT): # we've been told to shutdown.
break if self._aborted.is_set():
continue
except OSError:
# Suppress socket read warnings from lingering recv() calls if
# we've been told to shutdown.
if self._aborted.is_set():
return
raise
buf += data
if not data:
self._log_connection_error('(%s) Connection lost, disconnecting.', self.name)
return
elif (time.time() - self.lastping) > self.pingtimeout:
self._log_connection_error('(%s) Connection timed out.', self.name)
return return
raise
while b'\n' in buf: buf += data
line, buf = buf.split(b'\n', 1) if not data:
line = line.strip(b'\r') self._log_connection_error('(%s) Connection lost, disconnecting.', self.name)
line = line.decode(self.encoding, "replace") self.disconnect()
self.parse_irc_command(line) return
elif (time.time() - self.lastping) > self.pingtimeout:
self._log_connection_error('(%s) Connection timed out.', self.name)
self.disconnect()
return
while b'\n' in buf:
line, buf = buf.split(b'\n', 1)
line = line.strip(b'\r')
line = line.decode(self.encoding, "replace")
self.parse_irc_command(line)
def _send(self, data): def _send(self, data):
"""Sends raw text to the uplink server.""" """Sends raw text to the uplink server."""

View File

@ -44,7 +44,7 @@ def main():
conf.load_conf(args.config) conf.load_conf(args.config)
from pylinkirc.log import log from pylinkirc.log import log
from pylinkirc import classes, utils, coremods from pylinkirc import classes, utils, coremods, selectdriver
world.daemon = args.daemonize world.daemon = args.daemonize
if args.daemonize: if args.daemonize:
@ -177,3 +177,4 @@ def main():
world.started.set() world.started.set()
log.info("Loaded plugins: %s", ', '.join(sorted(world.plugins.keys()))) log.info("Loaded plugins: %s", ', '.join(sorted(world.plugins.keys())))
selectdriver.start()

44
selectdriver.py Normal file
View File

@ -0,0 +1,44 @@
"""
Socket handling driver using the selectors module. epoll, kqueue, and devpoll
are used internally when available.
"""
import selectors
import threading
from pylinkirc import world
from pylinkirc.log import log
SELECT_TIMEOUT = 0.5
selector = selectors.DefaultSelector()
def _process_conns():
"""Main loop which processes connected sockets."""
while not world.shutting_down.is_set():
for socketkey, mask in selector.select(timeout=SELECT_TIMEOUT):
irc = socketkey.data
if mask & selectors.EVENT_READ:
irc._run_irc()
def register(irc):
"""
Registers a network to the global selectors instance.
"""
log.debug('selectdriver: registering %s for network %s', irc._socket, irc.name)
selector.register(irc._socket, selectors.EVENT_READ, data=irc)
def unregister(irc):
"""
Removes a network from the global selectors instance.
"""
log.debug('selectdriver: de-registering %s for network %s', irc._socket, irc.name)
selector.unregister(irc._socket)
def start():
"""
Starts a thread to process connections.
"""
t = threading.Thread(target=_process_conns, name="Selector driver loop")
t.start()