diff --git a/classes.py b/classes.py index 6e17281..653a770 100644 --- a/classes.py +++ b/classes.py @@ -24,7 +24,7 @@ try: except ImportError: raise ImportError("PyLink requires ircmatch to function; please install it and try again.") -from . import world, utils, structures, conf, __version__ +from . import world, utils, structures, conf, __version__, selectdriver from .log import * from .utils import ProtocolError # Compatibility with PyLink 1.x @@ -1300,10 +1300,10 @@ class IRCNetwork(PyLinkNetworkCoreWithUtils): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self._connection_thread = None self._queue = None self._ping_timer = None self._socket = None + self._selector_key = None def _init_vars(self, *args, **kwargs): super()._init_vars(*args, **kwargs) @@ -1335,171 +1335,154 @@ class IRCNetwork(PyLinkNetworkCoreWithUtils): else: log.error(*args, **kwargs) - def _connect(self): + def connect(self): """ - Runs the connect loop for the IRC object. This is usually called by - __init__ in a separate thread to allow multiple concurrent connections. + Connects to the network. """ - while True: - self._pre_connect() + self._pre_connect() - ip = self.serverdata["ip"] - port = self.serverdata["port"] - checks_ok = True - try: - # Set the socket type (IPv6 or IPv4). - stype = socket.AF_INET6 if self.serverdata.get("ipv6") else socket.AF_INET + ip = self.serverdata["ip"] + port = self.serverdata["port"] + checks_ok = True + try: + # Set the socket type (IPv6 or IPv4). + stype = socket.AF_INET6 if self.serverdata.get("ipv6") else socket.AF_INET - # Creat the socket. - self._socket = socket.socket(stype) - self._socket.setblocking(0) + # Creat the socket. + self._socket = socket.socket(stype) - # Set the socket bind if applicable. - if 'bindhost' in self.serverdata: - self._socket.bind((self.serverdata['bindhost'], 0)) + # Set the socket bind if applicable. + if 'bindhost' in self.serverdata: + self._socket.bind((self.serverdata['bindhost'], 0)) - # Set the connection timeouts. Initial connection timeout is a - # lot smaller than the timeout after we've connected; this is - # intentional. - self._socket.settimeout(self.pingfreq) + # Set the connection timeouts. Initial connection timeout is a + # lot smaller than the timeout after we've connected; this is + # intentional. + self._socket.settimeout(self.pingfreq) - # Resolve hostnames if it's not an IP address already. - old_ip = ip - ip = socket.getaddrinfo(ip, port, stype)[0][-1][0] - log.debug('(%s) Resolving address %s to %s', self.name, old_ip, ip) + # Resolve hostnames if it's not an IP address already. + old_ip = ip + ip = socket.getaddrinfo(ip, port, stype)[0][-1][0] + log.debug('(%s) Resolving address %s to %s', self.name, old_ip, ip) - # Enable SSL if set to do so. - self.ssl = self.serverdata.get('ssl') - if self.ssl: - log.info('(%s) Attempting SSL for this connection...', self.name) - certfile = self.serverdata.get('ssl_certfile') - keyfile = self.serverdata.get('ssl_keyfile') + # Enable SSL if set to do so. + self.ssl = self.serverdata.get('ssl') + if self.ssl: + log.info('(%s) Attempting SSL for this connection...', self.name) + certfile = self.serverdata.get('ssl_certfile') + keyfile = self.serverdata.get('ssl_keyfile') - context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) - # Disable SSLv2 and SSLv3 - these are insecure - context.options |= ssl.OP_NO_SSLv2 - context.options |= ssl.OP_NO_SSLv3 - - # Cert and key files are optional, load them if specified. - if certfile and keyfile: - try: - context.load_cert_chain(certfile, keyfile) - except OSError: - log.exception('(%s) Caught OSError trying to ' - 'initialize the SSL connection; ' - 'are "ssl_certfile" and ' - '"ssl_keyfile" set correctly?', - self.name) - checks_ok = False - - self._socket = context.wrap_socket(self._socket) - - log.info("Connecting to network %r on %s:%s", self.name, ip, port) - self._socket.connect((ip, port)) - self._socket.settimeout(self.pingtimeout) - - # If SSL was enabled, optionally verify the certificate - # fingerprint for some added security. I don't bother to check - # the entire certificate for validity, since most IRC networks - # self-sign their certificates anyways. - if self.ssl and checks_ok: - peercert = self._socket.getpeercert(binary_form=True) - - # Hash type is configurable using the ssl_fingerprint_type - # value, and defaults to sha256. - hashtype = self.serverdata.get('ssl_fingerprint_type', 'sha256').lower() + context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + # Disable SSLv2 and SSLv3 - these are insecure + context.options |= ssl.OP_NO_SSLv2 + context.options |= ssl.OP_NO_SSLv3 + # Cert and key files are optional, load them if specified. + if certfile and keyfile: try: - hashfunc = getattr(hashlib, hashtype) - except AttributeError: - log.error('(%s) Unsupported SSL certificate fingerprint type %r given, disconnecting...', - self.name, hashtype) - checks_ok = False - else: - fp = hashfunc(peercert).hexdigest() - expected_fp = self.serverdata.get('ssl_fingerprint') + context.load_cert_chain(certfile, keyfile) + except OSError: + log.exception('(%s) Caught OSError trying to ' + 'initialize the SSL connection; ' + 'are "ssl_certfile" and ' + '"ssl_keyfile" set correctly?', + self.name) + checks_ok = False - if expected_fp and checks_ok: - if fp != expected_fp: - # SSL Fingerprint doesn't match; break. - log.error('(%s) Uplink\'s SSL certificate ' - 'fingerprint (%s) does not match the ' - 'one configured: expected %r, got %r; ' - 'disconnecting...', self.name, hashtype, - expected_fp, fp) - checks_ok = False - else: - log.info('(%s) Uplink SSL certificate fingerprint ' - '(%s) verified: %r', self.name, hashtype, - fp) + self._socket = context.wrap_socket(self._socket) + + self._selector_key = selectdriver.register(self) + log.info("Connecting to network %r on %s:%s", self.name, ip, port) + self._socket.connect((ip, port)) + self._socket.settimeout(self.pingtimeout) + + # If SSL was enabled, optionally verify the certificate + # fingerprint for some added security. I don't bother to check + # the entire certificate for validity, since most IRC networks + # self-sign their certificates anyways. + if self.ssl and checks_ok: + peercert = self._socket.getpeercert(binary_form=True) + + # Hash type is configurable using the ssl_fingerprint_type + # value, and defaults to sha256. + hashtype = self.serverdata.get('ssl_fingerprint_type', 'sha256').lower() + + try: + hashfunc = getattr(hashlib, hashtype) + except AttributeError: + log.error('(%s) Unsupported SSL certificate fingerprint type %r given, disconnecting...', + self.name, hashtype) + checks_ok = False + else: + fp = hashfunc(peercert).hexdigest() + expected_fp = self.serverdata.get('ssl_fingerprint') + + if expected_fp and checks_ok: + if fp != expected_fp: + # SSL Fingerprint doesn't match; break. + log.error('(%s) Uplink\'s SSL certificate ' + 'fingerprint (%s) does not match the ' + 'one configured: expected %r, got %r; ' + 'disconnecting...', self.name, hashtype, + expected_fp, fp) + checks_ok = False else: - log.info('(%s) Uplink\'s SSL certificate fingerprint (%s) ' - 'is %r. You can enhance the security of your ' - 'link by specifying this in a "ssl_fingerprint"' - ' option in your server block.', self.name, - hashtype, fp) + log.info('(%s) Uplink SSL certificate fingerprint ' + '(%s) verified: %r', self.name, hashtype, + fp) + else: + log.info('(%s) Uplink\'s SSL certificate fingerprint (%s) ' + 'is %r. You can enhance the security of your ' + 'link by specifying this in a "ssl_fingerprint"' + ' option in your server block.', self.name, + hashtype, fp) - if checks_ok: + if checks_ok: - self._queue_thread = threading.Thread(name="Queue thread for %s" % self.name, - target=self._process_queue, daemon=True) - self._queue_thread.start() + self._queue_thread = threading.Thread(name="Queue thread for %s" % self.name, + target=self._process_queue, daemon=True) + self._queue_thread.start() - self.sid = self.serverdata.get("sid") - # All our checks passed, get the protocol module to connect and run the listen - # loop. This also updates any SID values should the protocol module do so. - self.post_connect() + self.sid = self.serverdata.get("sid") + # All our checks passed, get the protocol module to connect and run the listen + # loop. This also updates any SID values should the protocol module do so. + self.post_connect() - log.info('(%s) Enumerating our own SID %s', self.name, self.sid) - host = self.hostname() + log.info('(%s) Enumerating our own SID %s', self.name, self.sid) + host = self.hostname() - self.servers[self.sid] = Server(self, None, host, internal=True, - desc=self.serverdata.get('serverdesc') - or conf.conf['pylink']['serverdesc']) - - log.info('(%s) Starting ping schedulers....', self.name) - self._schedule_ping() - log.info('(%s) Server ready; listening for data.', self.name) - self.autoconnect_active_multiplier = 1 # Reset any extra autoconnect delays - self._run_irc() - else: # Configuration error :( - log.error('(%s) A configuration error was encountered ' - 'trying to set up this connection. Please check' - ' your configuration file and try again.', - self.name) - # _run_irc() or the protocol module it called raised an exception, meaning we've disconnected! - # Note: socket.error, ConnectionError, IOError, etc. are included in OSError since Python 3.3, - # so we don't need to explicitly catch them here. - # We also catch SystemExit here as a way to abort out connection threads properly, and stop the - # IRC connection from freezing instead. - except (OSError, RuntimeError, SystemExit) as e: - self._log_connection_error('(%s) Disconnected from IRC:', self.name, exc_info=True) + self.servers[self.sid] = Server(self, None, host, internal=True, + desc=self.serverdata.get('serverdesc') + or conf.conf['pylink']['serverdesc']) + log.info('(%s) Starting ping schedulers....', self.name) + self._schedule_ping() + log.info('(%s) Server ready; listening for data.', self.name) + self.autoconnect_active_multiplier = 1 # Reset any extra autoconnect delays + else: # Configuration error :( + log.error('(%s) A configuration error was encountered ' + 'trying to set up this connection. Please check' + ' your configuration file and try again.', + self.name) + # _run_irc() or the protocol module it called raised an exception, meaning we've disconnected! + # Note: socket.error, ConnectionError, IOError, etc. are included in OSError since Python 3.3, + # so we don't need to explicitly catch them here. + # We also catch SystemExit here as a way to abort out connection threads properly, and stop the + # IRC connection from freezing instead. + except (OSError, RuntimeError, SystemExit) as e: + self._log_connection_error('(%s) Disconnected from IRC:', self.name, exc_info=True) if not self._aborted.is_set(): self.disconnect() if not self._run_autoconnect(): return - def connect(self): - log.debug('(%s) calling _connect() (world.testing=%s)', self.name, world.testing) - if world.testing: - # HACK: Don't thread if we're running tests. - self._connect() - else: - if self._connection_thread and self._connection_thread.is_alive(): - raise RuntimeError("Refusing to start multiple connection threads for network %r!" % self.name) - - self._connection_thread = threading.Thread(target=self._connect, - name="Listener for %s" % - self.name) - self._connection_thread.start() - def disconnect(self): """Handle disconnects from the remote server.""" self._pre_disconnect() if self._socket is not None: + selectdriver.unregister(self) try: log.debug('(%s) disconnect: Shutting down socket.', self.name) self._socket.shutdown(socket.SHUT_RDWR) @@ -1523,6 +1506,9 @@ class IRCNetwork(PyLinkNetworkCoreWithUtils): self._ping_timer.cancel() self._post_disconnect() + if self._run_autoconnect(): + self.connect() + def handle_events(self, line): raise NotImplementedError @@ -1547,38 +1533,35 @@ class IRCNetwork(PyLinkNetworkCoreWithUtils): return hook_args def _run_irc(self): - """Main IRC loop which listens for messages.""" - buf = b"" - data = b"" - while (not self._aborted.is_set()) and not world.shutting_down.is_set(): - - try: - data = self._socket.recv(2048) - except BlockingIOError: - log.debug('(%s) No data to read, trying again later...', self.name) - if self._aborted.wait(self.SOCKET_REPOLL_WAIT): - break - continue - except OSError: - # Suppress socket read warnings from lingering recv() calls if - # we've been told to shutdown. - if self._aborted.is_set(): - return - raise - - buf += data - if not data: - self._log_connection_error('(%s) Connection lost, disconnecting.', self.name) - return - elif (time.time() - self.lastping) > self.pingtimeout: - self._log_connection_error('(%s) Connection timed out.', self.name) + """ + Message handler, called when select() has data to read. + """ + buf = b'' + data = b'' + try: + data = self._socket.recv(2048) + except OSError: + # Suppress socket read warnings from lingering recv() calls if + # we've been told to shutdown. + if self._aborted.is_set(): return + raise - while b'\n' in buf: - line, buf = buf.split(b'\n', 1) - line = line.strip(b'\r') - line = line.decode(self.encoding, "replace") - self.parse_irc_command(line) + buf += data + if not data: + self._log_connection_error('(%s) Connection lost, disconnecting.', self.name) + self.disconnect() + return + elif (time.time() - self.lastping) > self.pingtimeout: + self._log_connection_error('(%s) Connection timed out.', self.name) + self.disconnect() + return + + while b'\n' in buf: + line, buf = buf.split(b'\n', 1) + line = line.strip(b'\r') + line = line.decode(self.encoding, "replace") + self.parse_irc_command(line) def _send(self, data): """Sends raw text to the uplink server.""" diff --git a/launcher.py b/launcher.py index ed769be..3233660 100644 --- a/launcher.py +++ b/launcher.py @@ -44,7 +44,7 @@ def main(): conf.load_conf(args.config) from pylinkirc.log import log - from pylinkirc import classes, utils, coremods + from pylinkirc import classes, utils, coremods, selectdriver world.daemon = args.daemonize if args.daemonize: @@ -177,3 +177,4 @@ def main(): world.started.set() log.info("Loaded plugins: %s", ', '.join(sorted(world.plugins.keys()))) + selectdriver.start() diff --git a/selectdriver.py b/selectdriver.py new file mode 100644 index 0000000..f924dd6 --- /dev/null +++ b/selectdriver.py @@ -0,0 +1,44 @@ +""" +Socket handling driver using the selectors module. epoll, kqueue, and devpoll +are used internally when available. +""" + +import selectors +import threading + +from pylinkirc import world +from pylinkirc.log import log + +SELECT_TIMEOUT = 0.5 + +selector = selectors.DefaultSelector() + +def _process_conns(): + """Main loop which processes connected sockets.""" + + while not world.shutting_down.is_set(): + for socketkey, mask in selector.select(timeout=SELECT_TIMEOUT): + irc = socketkey.data + if mask & selectors.EVENT_READ: + irc._run_irc() + +def register(irc): + """ + Registers a network to the global selectors instance. + """ + log.debug('selectdriver: registering %s for network %s', irc._socket, irc.name) + selector.register(irc._socket, selectors.EVENT_READ, data=irc) + +def unregister(irc): + """ + Removes a network from the global selectors instance. + """ + log.debug('selectdriver: de-registering %s for network %s', irc._socket, irc.name) + selector.unregister(irc._socket) + +def start(): + """ + Starts a thread to process connections. + """ + t = threading.Thread(target=_process_conns, name="Selector driver loop") + t.start()