From 849e3a59498dfb90b0d24f0dbc4b9e4fe1488a6d Mon Sep 17 00:00:00 2001 From: Johannes Bauer Date: Fri, 25 Oct 2019 11:08:20 +0200 Subject: [PATCH] Implemented finding of keyserver and unlocking of volumes We'll now parse the response messages on the client side, abort after a previously defined timeout and trigger the LUKS unlocking process, if requested (although the latter isn't fully implemented yet). --- Makefile | 2 +- README.md | 2 +- argparse_client.c | 44 +++++++++++++--- argparse_client.h | 13 +++-- argparse_edit.c | 2 +- argparse_edit.h | 2 +- argparse_server.c | 2 +- argparse_server.h | 2 +- blacklist.c | 10 +--- blacklist.h | 2 +- client.c | 111 ++++++++++++++++++++++++++++----------- keydb.c | 20 +++++++ keydb.h | 2 + parsers/parser_client.py | 2 + pgmopts.c | 8 +++ pgmopts.h | 2 + server.c | 4 +- udp.c | 31 +++++++++-- udp.h | 7 +-- util.c | 9 ++++ util.h | 1 + 21 files changed, 209 insertions(+), 69 deletions(-) diff --git a/Makefile b/Makefile index 0db18bc..f1965f1 100644 --- a/Makefile +++ b/Makefile @@ -48,7 +48,7 @@ test_s: luksrku ./luksrku server -vv testdata/server.bin test_c: luksrku - ./luksrku client -vv testdata/client.bin + ./luksrku client -vv --no-luks testdata/client.bin .c.o: $(CC) $(CFLAGS) -c -o $@ $< diff --git a/README.md b/README.md index 56fb7b6..9027b60 100644 --- a/README.md +++ b/README.md @@ -22,7 +22,7 @@ passphrases are based on 256 bit long secrets and are converted to Base64 for easier handling (when setting up everything initially). The binary protocol that runs between both is intentionally extremely simple to -allow for easy code review. +allow for easy code review. It exclusively uses fixed message lengths. The key database is encrypted itself, using AES256-GCM, a 128 bit randomized initialization vector and authenticated with a 128 bit authentication tag. Key diff --git a/argparse_client.c b/argparse_client.c index 62d6e52..0ca499c 100644 --- a/argparse_client.c +++ b/argparse_client.c @@ -5,7 +5,7 @@ * * Do not edit it by hand, your changes will be overwritten. * - * Generated at: 2019-10-23 20:13:13 + * Generated at: 2019-10-25 11:06:30 */ #include @@ -21,19 +21,24 @@ static enum argparse_client_option_t last_parsed_option; static char last_error_message[256]; static const char *option_texts[] = { + [ARG_CLIENT_TIMEOUT] = "-t / --timeout", [ARG_CLIENT_PORT] = "-p / --port", + [ARG_CLIENT_NO_LUKS] = "--no-luks", [ARG_CLIENT_VERBOSE] = "-v / --verbose", [ARG_CLIENT_FILENAME] = "filename", [ARG_CLIENT_HOSTNAME] = "hostname", }; enum argparse_client_option_internal_t { + ARG_CLIENT_TIMEOUT_SHORT = 't', ARG_CLIENT_PORT_SHORT = 'p', ARG_CLIENT_VERBOSE_SHORT = 'v', - ARG_CLIENT_PORT_LONG = 1000, - ARG_CLIENT_VERBOSE_LONG = 1001, - ARG_CLIENT_FILENAME_LONG = 1002, - ARG_CLIENT_HOSTNAME_LONG = 1003, + ARG_CLIENT_TIMEOUT_LONG = 1000, + ARG_CLIENT_PORT_LONG = 1001, + ARG_CLIENT_NO_LUKS_LONG = 1002, + ARG_CLIENT_VERBOSE_LONG = 1003, + ARG_CLIENT_FILENAME_LONG = 1004, + ARG_CLIENT_HOSTNAME_LONG = 1005, }; static void errmsg_callback(const char *errmsg, ...) { @@ -54,9 +59,11 @@ static void errmsg_option_callback(enum argparse_client_option_t error_option, c bool argparse_client_parse(int argc, char **argv, argparse_client_callback_t argument_callback, argparse_client_plausibilization_callback_t plausibilization_callback) { last_parsed_option = ARGPARSE_CLIENT_NO_OPTION; - const char *short_options = "p:v"; + const char *short_options = "t:p:v"; struct option long_options[] = { + { "timeout", required_argument, 0, ARG_CLIENT_TIMEOUT_LONG }, { "port", required_argument, 0, ARG_CLIENT_PORT_LONG }, + { "no-luks", no_argument, 0, ARG_CLIENT_NO_LUKS_LONG }, { "verbose", no_argument, 0, ARG_CLIENT_VERBOSE_LONG }, { "filename", required_argument, 0, ARG_CLIENT_FILENAME_LONG }, { "hostname", required_argument, 0, ARG_CLIENT_HOSTNAME_LONG }, @@ -71,6 +78,14 @@ bool argparse_client_parse(int argc, char **argv, argparse_client_callback_t arg last_error_message[0] = 0; enum argparse_client_option_internal_t arg = (enum argparse_client_option_internal_t)optval; switch (arg) { + case ARG_CLIENT_TIMEOUT_SHORT: + case ARG_CLIENT_TIMEOUT_LONG: + last_parsed_option = ARG_CLIENT_TIMEOUT; + if (!argument_callback(ARG_CLIENT_TIMEOUT, optarg, errmsg_callback)) { + return false; + } + break; + case ARG_CLIENT_PORT_SHORT: case ARG_CLIENT_PORT_LONG: last_parsed_option = ARG_CLIENT_PORT; @@ -79,6 +94,13 @@ bool argparse_client_parse(int argc, char **argv, argparse_client_callback_t arg } break; + case ARG_CLIENT_NO_LUKS_LONG: + last_parsed_option = ARG_CLIENT_NO_LUKS; + if (!argument_callback(ARG_CLIENT_NO_LUKS, optarg, errmsg_callback)) { + return false; + } + break; + case ARG_CLIENT_VERBOSE_SHORT: case ARG_CLIENT_VERBOSE_LONG: last_parsed_option = ARG_CLIENT_VERBOSE; @@ -127,7 +149,7 @@ bool argparse_client_parse(int argc, char **argv, argparse_client_callback_t arg } void argparse_client_show_syntax(void) { - fprintf(stderr, "usage: luksrku client [-p port] [-v] filename [hostname]\n"); + fprintf(stderr, "usage: luksrku client [-t secs] [-p port] [--no-luks] [-v] filename [hostname]\n"); fprintf(stderr, "\n"); fprintf(stderr, "Connects to a luksrku key server and unlocks local LUKS volumes.\n"); fprintf(stderr, "\n"); @@ -139,8 +161,14 @@ void argparse_client_show_syntax(void) { fprintf(stderr, " hostname is attempted.\n"); fprintf(stderr, "\n"); fprintf(stderr, "optional arguments:\n"); + fprintf(stderr, " -t secs, --timeout secs\n"); + fprintf(stderr, " When searching for a keyserver and not all volumes can\n"); + fprintf(stderr, " be unlocked, abort after this period of time, given in\n"); + fprintf(stderr, " seconds. Defaults to 60 seconds.\n"); fprintf(stderr, " -p port, --port port Port that is used for both UDP and TCP communication.\n"); fprintf(stderr, " Defaults to 23170.\n"); + fprintf(stderr, " --no-luks Do not call LUKS/cryptsetup. Useful for testing\n"); + fprintf(stderr, " unlocking procedure.\n"); fprintf(stderr, " -v, --verbose Increase verbosity. Can be specified multiple times.\n"); } @@ -166,7 +194,9 @@ void argparse_client_parse_or_quit(int argc, char **argv, argparse_client_callba static const char *option_enum_to_str(enum argparse_client_option_t option) { switch (option) { + case ARG_CLIENT_TIMEOUT: return "ARG_CLIENT_TIMEOUT"; case ARG_CLIENT_PORT: return "ARG_CLIENT_PORT"; + case ARG_CLIENT_NO_LUKS: return "ARG_CLIENT_NO_LUKS"; case ARG_CLIENT_VERBOSE: return "ARG_CLIENT_VERBOSE"; case ARG_CLIENT_FILENAME: return "ARG_CLIENT_FILENAME"; case ARG_CLIENT_HOSTNAME: return "ARG_CLIENT_HOSTNAME"; diff --git a/argparse_client.h b/argparse_client.h index b152d0c..f2b4b57 100644 --- a/argparse_client.h +++ b/argparse_client.h @@ -5,7 +5,7 @@ * * Do not edit it by hand, your changes will be overwritten. * - * Generated at: 2019-10-23 20:13:13 + * Generated at: 2019-10-25 11:06:30 */ #ifndef __ARGPARSE_CLIENT_H__ @@ -13,6 +13,7 @@ #include +#define ARGPARSE_CLIENT_DEFAULT_TIMEOUT 60 #define ARGPARSE_CLIENT_DEFAULT_PORT 23170 #define ARGPARSE_CLIENT_DEFAULT_VERBOSE 0 @@ -20,10 +21,12 @@ #define ARGPARSE_CLIENT_POSITIONAL_ARG 1 enum argparse_client_option_t { - ARG_CLIENT_PORT = 2, - ARG_CLIENT_VERBOSE = 3, - ARG_CLIENT_FILENAME = 4, - ARG_CLIENT_HOSTNAME = 5, + ARG_CLIENT_TIMEOUT = 2, + ARG_CLIENT_PORT = 3, + ARG_CLIENT_NO_LUKS = 4, + ARG_CLIENT_VERBOSE = 5, + ARG_CLIENT_FILENAME = 6, + ARG_CLIENT_HOSTNAME = 7, }; typedef void (*argparse_client_errmsg_callback_t)(const char *errmsg, ...); diff --git a/argparse_edit.c b/argparse_edit.c index bcce7e3..d5ff33d 100644 --- a/argparse_edit.c +++ b/argparse_edit.c @@ -5,7 +5,7 @@ * * Do not edit it by hand, your changes will be overwritten. * - * Generated at: 2019-10-23 20:13:13 + * Generated at: 2019-10-25 11:06:30 */ #include diff --git a/argparse_edit.h b/argparse_edit.h index d723043..b261c57 100644 --- a/argparse_edit.h +++ b/argparse_edit.h @@ -5,7 +5,7 @@ * * Do not edit it by hand, your changes will be overwritten. * - * Generated at: 2019-10-23 20:13:13 + * Generated at: 2019-10-25 11:06:30 */ #ifndef __ARGPARSE_EDIT_H__ diff --git a/argparse_server.c b/argparse_server.c index 4497507..aa6c494 100644 --- a/argparse_server.c +++ b/argparse_server.c @@ -5,7 +5,7 @@ * * Do not edit it by hand, your changes will be overwritten. * - * Generated at: 2019-10-23 20:13:13 + * Generated at: 2019-10-25 11:06:30 */ #include diff --git a/argparse_server.h b/argparse_server.h index c13b5bc..09d2456 100644 --- a/argparse_server.h +++ b/argparse_server.h @@ -5,7 +5,7 @@ * * Do not edit it by hand, your changes will be overwritten. * - * Generated at: 2019-10-23 20:13:13 + * Generated at: 2019-10-25 11:06:30 */ #ifndef __ARGPARSE_SERVER_H__ diff --git a/blacklist.c b/blacklist.c index 5762684..b2e87c7 100644 --- a/blacklist.c +++ b/blacklist.c @@ -24,20 +24,12 @@ #include #include #include -#include #include "blacklist.h" #include "global.h" +#include "util.h" static struct blacklist_entry_t blacklist[BLACKLIST_ENTRY_COUNT]; -static double now(void) { - struct timeval tv; - if (gettimeofday(&tv, NULL)) { - return 0; - } - return tv.tv_sec + (tv.tv_usec * 1e-6); -} - static bool blacklist_entry_expired(int index) { return now() > blacklist[index].entered + BLACKLIST_ENTRY_TIMEOUT_SECS; } diff --git a/blacklist.h b/blacklist.h index b704adf..f358ea0 100644 --- a/blacklist.h +++ b/blacklist.h @@ -28,7 +28,7 @@ #include #define BLACKLIST_ENTRY_COUNT 32 -#define BLACKLIST_ENTRY_TIMEOUT_SECS 60 +#define BLACKLIST_ENTRY_TIMEOUT_SECS 15 struct blacklist_entry_t { uint32_t ip; diff --git a/client.c b/client.c index d38793f..21da4ab 100644 --- a/client.c +++ b/client.c @@ -42,12 +42,14 @@ #include "blacklist.h" #include "keydb.h" #include "uuid.h" +#include "udp.h" struct keyclient_t { const struct pgmopts_client_t *opts; struct keydb_t *keydb; bool volume_unlocked[MAX_VOLUMES_PER_HOST]; unsigned char identifier[ASCII_UUID_BUFSIZE]; + double broadcast_start_time; }; static int psk_client_callback(SSL *ssl, const EVP_MD *md, const unsigned char **id, size_t *idlen, SSL_SESSION **sessptr) { @@ -58,6 +60,36 @@ static int psk_client_callback(SSL *ssl, const EVP_MD *md, const unsigned char * return openssl_tls13_psk_establish_session(ssl, key_client->keydb->hosts[0].tls_psk, PSK_SIZE_BYTES, EVP_sha256(), sessptr); } +static bool do_unlock_luks_volume(const struct volume_entry_t *volume, const struct msg_t *unlock_msg) { + return true; +} + +static bool unlock_luks_volume(struct keyclient_t *keyclient, const struct msg_t *unlock_msg) { + const struct host_entry_t *host = &keyclient->keydb->hosts[0]; + const struct volume_entry_t* volume = keydb_get_volume_by_uuid(host, unlock_msg->volume_uuid); + if (!volume) { + char volume_uuid_str[ASCII_UUID_BUFSIZE]; + sprintf_uuid(volume_uuid_str, unlock_msg->volume_uuid); + log_msg(LLVL_WARNING, "Keyserver provided key for unlocking volume UUID %s, but this volume does not need unlocking on the client side.", volume_uuid_str); + return false; + } + + /* Volume! */ + int volume_index = keydb_get_volume_index(host, volume); + if (volume_index != -1) { + if (keyclient->opts->no_luks) { + keyclient->volume_unlocked[volume_index] = true; + } else { + keyclient->volume_unlocked[volume_index] = do_unlock_luks_volume(volume, unlock_msg); + } + } else { + log_msg(LLVL_FATAL, "Error calculating volume offset for volume %p from base %p.", volume, host->volumes); + return false; + } + + return true; +} + static bool contact_keyserver_socket(struct keyclient_t *keyclient, int sd) { struct generic_tls_ctx_t gctx; if (!create_generic_tls_context(&gctx, false)) { @@ -83,10 +115,13 @@ static bool contact_keyserver_socket(struct keyclient_t *keyclient, int sd) { log_openssl(LLVL_FATAL, "SSL_read returned %d bytes when we expected to read %d", bytes_read, sizeof(msg)); break; } - if (should_log(LLVL_TRACE)) { - char uuid_str[ASCII_UUID_BUFSIZE]; - sprintf_uuid(uuid_str, msg.volume_uuid); - log_msg(LLVL_TRACE, "Received LUKS key to unlock volume with UUID %s", uuid_str); + char uuid_str[ASCII_UUID_BUFSIZE]; + sprintf_uuid(uuid_str, msg.volume_uuid); + log_msg(LLVL_TRACE, "Received LUKS key to unlock volume with UUID %s", uuid_str); + if (unlock_luks_volume(keyclient, &msg)) { + log_msg(LLVL_DEBUG, "Successfully unlocked volume with UUID %s", uuid_str); + } else { + log_msg(LLVL_ERROR, "Failed to unlocked volume with UUID %s", uuid_str); } } OPENSSL_cleanse(&msg, sizeof(msg)); @@ -152,50 +187,62 @@ static bool contact_keyserver_hostname(struct keyclient_t *keyclient, const char return success; } -static int create_udp_socket(void) { - int sd = socket(AF_INET, SOCK_DGRAM, 0); - if (sd < 0) { - log_libc(LLVL_ERROR, "Unable to create UDP server socket(2)"); - return -1; - } - { - int value = 1; - if (setsockopt(sd, SOL_SOCKET, SO_BROADCAST, &value, sizeof(value))) { - log_libc(LLVL_ERROR, "Unable to set UDP socket in broadcast mode using setsockopt(2)"); - close(sd); - return -1; +static bool all_volumes_unlocked(struct keyclient_t *keyclient) { + const unsigned int volume_count = keyclient->keydb->hosts[0].volume_count; + for (unsigned int i = 0; i < volume_count; i++) { + if (!keyclient->volume_unlocked[i]) { + return false; } } - return sd; -} - -static bool send_udp_broadcast_message(int sd, unsigned int port, const void *data, unsigned int length) { - struct sockaddr_in destination; - memset(&destination, 0, sizeof(struct sockaddr_in)); - destination.sin_family = AF_INET; - destination.sin_port = htons(port); - destination.sin_addr.s_addr = htonl(INADDR_BROADCAST); - - if (sendto(sd, data, length, 0, (struct sockaddr *)&destination, sizeof(struct sockaddr_in)) < 0) { - log_libc(LLVL_ERROR, "Unable to sendto(2)"); - return false; - } return true; } +static bool abort_searching_for_keyserver(struct keyclient_t *keyclient) { + if (all_volumes_unlocked(keyclient)) { + log_msg(LLVL_DEBUG, "All volumes unlocked successfully."); + return true; + } + + if (keyclient->opts->timeout_seconds) { + double time_passed = now() - keyclient->broadcast_start_time; + if (time_passed >= keyclient->opts->timeout_seconds) { + log_msg(LLVL_WARNING, "Could not unlock all volumes after %u seconds, giving up.", keyclient->opts->timeout_seconds); + return true; + } + } + + return false; +} static bool broadcast_for_keyserver(struct keyclient_t *keyclient) { - int sd = create_udp_socket(); + int sd = create_udp_socket(0, true, 1000); if (sd == -1) { return false; } + keyclient->broadcast_start_time = now(); struct udp_query_t query; memcpy(query.magic, UDP_MESSAGE_MAGIC, sizeof(query.magic)); memcpy(query.host_uuid, keyclient->keydb->hosts[0].host_uuid, 16); while (true) { send_udp_broadcast_message(sd, keyclient->opts->port, &query, sizeof(query)); - sleep(1); + + struct sockaddr_in src = { + .sin_family = AF_INET, + .sin_port = htons(keyclient->opts->port), + .sin_addr.s_addr = htonl(INADDR_ANY), + }; + struct udp_response_t response; + if (wait_udp_response(sd, &response, &src)) { + log_msg(LLVL_DEBUG, "Potential keyserver found at %d.%d.%d.%d", PRINTF_FORMAT_IP(&src)); + if (!contact_keyserver_ipv4(keyclient, &src, keyclient->opts->port)) { + log_msg(LLVL_WARNING, "Keyserver announced at %d.%d.%d.%d, but connection to it failed.", PRINTF_FORMAT_IP(&src)); + } + } + + if (abort_searching_for_keyserver(keyclient)) { + break; + } } return true; } diff --git a/keydb.c b/keydb.c index 6c93930..a655a3d 100644 --- a/keydb.c +++ b/keydb.c @@ -105,6 +105,26 @@ struct volume_entry_t* keydb_get_volume_by_name(struct host_entry_t *host, const return (index >= 0) ? &host->volumes[index] : NULL; } +const struct volume_entry_t* keydb_get_volume_by_uuid(const struct host_entry_t *host, const uint8_t uuid[static 16]) { + for (unsigned int i = 0; i < host->volume_count; i++) { + const struct volume_entry_t *volume = &host->volumes[i]; + if (!memcmp(volume->volume_uuid, uuid, 16)) { + return volume; + } + } + return NULL; +} + +int keydb_get_volume_index(const struct host_entry_t *host, const struct volume_entry_t *volume) { + int offset = volume - host->volumes; + if (offset < 0) { + return -1; + } else if ((unsigned int)offset >= host->volume_count) { + return -1; + } + return offset; +} + struct host_entry_t* keydb_get_host_by_name(struct keydb_t *keydb, const char *host_name) { const int index = keydb_get_host_index_by_name(keydb, host_name); return (index >= 0) ? &keydb->hosts[index] : NULL; diff --git a/keydb.h b/keydb.h index 2ec4c65..e62855c 100644 --- a/keydb.h +++ b/keydb.h @@ -58,6 +58,8 @@ struct keydb_t* keydb_new(void); struct keydb_t* keydb_export_public(struct host_entry_t *host); void keydb_free(struct keydb_t *keydb); struct volume_entry_t* keydb_get_volume_by_name(struct host_entry_t *host, const char *devmapper_name); +const struct volume_entry_t* keydb_get_volume_by_uuid(const struct host_entry_t *host, const uint8_t uuid[static 16]); +int keydb_get_volume_index(const struct host_entry_t *host, const struct volume_entry_t *volume); struct host_entry_t* keydb_get_host_by_name(struct keydb_t *keydb, const char *host_name); const struct host_entry_t* keydb_get_host_by_uuid(const struct keydb_t *keydb, const uint8_t uuid[static 16]); bool keydb_add_host(struct keydb_t **keydb, const char *host_name); diff --git a/parsers/parser_client.py b/parsers/parser_client.py index 159d0aa..d944483 100755 --- a/parsers/parser_client.py +++ b/parsers/parser_client.py @@ -1,6 +1,8 @@ import argparse parser = argparse.ArgumentParser(prog = "luksrku client", description = "Connects to a luksrku key server and unlocks local LUKS volumes.", add_help = False) +parser.add_argument("-t", "--timeout", metavar = "secs", default = 60, help = "When searching for a keyserver and not all volumes can be unlocked, abort after this period of time, given in seconds. Defaults to %(default)d seconds.") parser.add_argument("-p", "--port", metavar = "port", default = 23170, help = "Port that is used for both UDP and TCP communication. Defaults to %(default)d.") +parser.add_argument("--no-luks", action = "store_true", help = "Do not call LUKS/cryptsetup. Useful for testing unlocking procedure.") parser.add_argument("-v", "--verbose", action = "count", default = 0, help = "Increase verbosity. Can be specified multiple times.") parser.add_argument("filename", metavar = "filename", help = "Exported database file to load TLS-PSKs and list of disks from.") parser.add_argument("hostname", metavar = "hostname", nargs = "?", help = "When hostname is given, auto-searching for suitable servers is disabled and only a connection to the given hostname is attempted.") diff --git a/pgmopts.c b/pgmopts.c index 4ae2857..5d35376 100644 --- a/pgmopts.c +++ b/pgmopts.c @@ -97,6 +97,14 @@ static bool client_callback(enum argparse_client_option_t option, const char *va pgmopts_rw.client.port = atoi(value); break; + case ARG_CLIENT_TIMEOUT: + pgmopts_rw.client.timeout_seconds = atoi(value); + break; + + case ARG_CLIENT_NO_LUKS: + pgmopts_rw.client.no_luks = true; + break; + case ARG_CLIENT_VERBOSE: pgmopts_rw.client.verbosity++; break; diff --git a/pgmopts.h b/pgmopts.h index 5085c49..e5f6888 100644 --- a/pgmopts.h +++ b/pgmopts.h @@ -48,6 +48,8 @@ struct pgmopts_client_t { const char *filename; const char *hostname; unsigned int port; + unsigned int timeout_seconds; + bool no_luks; unsigned int verbosity; }; diff --git a/server.c b/server.c index a1c8702..79b4dff 100644 --- a/server.c +++ b/server.c @@ -176,7 +176,7 @@ static void udp_handler_thread(void *vctx) { while (true) { struct udp_query_t rx_msg; struct sockaddr_in origin; - if (!wait_udp_query(client->udp_sd, client->port, &rx_msg, &origin, 1000)) { + if (!wait_udp_query(client->udp_sd, &rx_msg, &origin)) { continue; } @@ -240,7 +240,7 @@ bool keyserver_start(const struct pgmopts_server_t *opts) { } if (opts->answer_udp_queries) { - keyserver.udp_sd = create_udp_socket(opts->port, false); + keyserver.udp_sd = create_udp_socket(opts->port, false, 1000); if (keyserver.udp_sd == -1) { success = false; break; diff --git a/udp.c b/udp.c index 5d9a2fb..6bfc1b8 100644 --- a/udp.c +++ b/udp.c @@ -27,12 +27,13 @@ #include #include #include +#include #include #include "log.h" #include "udp.h" -int create_udp_socket(unsigned int listen_port, bool send_broadcast) { +int create_udp_socket(unsigned int listen_port, bool send_broadcast, unsigned int rx_timeout_millis) { int sd = socket(AF_INET, SOCK_DGRAM, 0); if (sd < 0) { log_libc(LLVL_ERROR, "Unable to create UDP server socket(2)"); @@ -59,10 +60,21 @@ int create_udp_socket(unsigned int listen_port, bool send_broadcast) { return -1; } } + if (rx_timeout_millis) { + struct timeval tv = { + .tv_sec = rx_timeout_millis / 1000, + .tv_usec = (rx_timeout_millis % 1000) * 1000, + }; + if (setsockopt(sd, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)) < 0) { + log_libc(LLVL_ERROR, "Unable to set UDP receive timeout to %u ms.", rx_timeout_millis); + close(sd); + return -1; + } + } return sd; } -bool wait_udp_message(int sd, int port, void *data, unsigned int length, struct sockaddr_in *source, unsigned int timeout_millis) { +bool wait_udp_message(int sd, void *data, unsigned int length, struct sockaddr_in *source) { fprintf(stderr, "RECV...\n"); socklen_t socklen = sizeof(struct sockaddr_in); ssize_t rx_bytes = recvfrom(sd,data, length, 0, (struct sockaddr*)source, &socklen); @@ -92,8 +104,8 @@ bool send_udp_broadcast_message(int sd, int port, const void *data, unsigned int return send_udp_message(sd, &destination, data, length, false); } -bool wait_udp_query(int sd, int port, struct udp_query_t *query, struct sockaddr_in *source, unsigned int timeout_millis) { - bool rx_successful = wait_udp_message(sd, port, query, sizeof(struct udp_query_t), source, timeout_millis); +bool wait_udp_query(int sd, struct udp_query_t *query, struct sockaddr_in *source) { + bool rx_successful = wait_udp_message(sd, query, sizeof(struct udp_query_t), source); if (rx_successful) { /* Also check if the message contains the correct magic */ if (!memcmp(query->magic, UDP_MESSAGE_MAGIC, UDP_MESSAGE_MAGIC_SIZE)) { @@ -102,3 +114,14 @@ bool wait_udp_query(int sd, int port, struct udp_query_t *query, struct sockaddr } return false; } + +bool wait_udp_response(int sd, struct udp_response_t *response, struct sockaddr_in *source) { + bool rx_successful = wait_udp_message(sd, response, sizeof(struct udp_response_t), source); + if (rx_successful) { + /* Also check if the message contains the correct magic */ + if (!memcmp(response->magic, UDP_MESSAGE_MAGIC, UDP_MESSAGE_MAGIC_SIZE)) { + return true; + } + } + return false; +} diff --git a/udp.h b/udp.h index 1ae0e43..835de59 100644 --- a/udp.h +++ b/udp.h @@ -28,11 +28,12 @@ #include "msg.h" /*************** AUTO GENERATED SECTION FOLLOWS ***************/ -int create_udp_socket(unsigned int listen_port, bool send_broadcast); -bool wait_udp_message(int sd, int port, void *data, unsigned int length, struct sockaddr_in *source, unsigned int timeout_millis); +int create_udp_socket(unsigned int listen_port, bool send_broadcast, unsigned int rx_timeout_millis); +bool wait_udp_message(int sd, void *data, unsigned int length, struct sockaddr_in *source); bool send_udp_message(int sd, struct sockaddr_in *destination, const void *data, unsigned int length, bool is_response); bool send_udp_broadcast_message(int sd, int port, const void *data, unsigned int length); -bool wait_udp_query(int sd, int port, struct udp_query_t *query, struct sockaddr_in *source, unsigned int timeout_millis); +bool wait_udp_query(int sd, struct udp_query_t *query, struct sockaddr_in *source); +bool wait_udp_response(int sd, struct udp_response_t *response, struct sockaddr_in *source); /*************** AUTO GENERATED SECTION ENDS ***************/ #endif diff --git a/util.c b/util.c index 8d56d06..739d2ee 100644 --- a/util.c +++ b/util.c @@ -24,6 +24,7 @@ #include #include #include +#include #include #include "util.h" @@ -211,3 +212,11 @@ bool ascii_encode(char *dest, unsigned int dest_buffer_size, const uint8_t *sour *dest = 0; return true; } + +double now(void) { + struct timeval tv; + if (gettimeofday(&tv, NULL)) { + return 0; + } + return tv.tv_sec + (tv.tv_usec * 1e-6); +} diff --git a/util.h b/util.h index a568612..4f40aa0 100644 --- a/util.h +++ b/util.h @@ -43,6 +43,7 @@ bool buffer_randomize(uint8_t *buffer, unsigned int length); bool is_zero(const void *data, unsigned int length); bool array_remove(void *base, unsigned int element_size, unsigned int element_count, unsigned int remove_element_index); bool ascii_encode(char *dest, unsigned int dest_buffer_size, const uint8_t *source_data, unsigned int source_data_length); +double now(void); /*************** AUTO GENERATED SECTION ENDS ***************/ #endif