diff --git a/Makefile b/Makefile index e037cf6..d7a2cbd 100644 --- a/Makefile +++ b/Makefile @@ -6,7 +6,7 @@ INSTALL_PREFIX := /usr/local/ CFLAGS := -Wall -Wextra -Wshadow -Wswitch -Wpointer-arith -Wcast-qual -Wstrict-prototypes -Wmissing-prototypes -Werror=implicit-function-declaration -Werror=format -Wno-unused-parameter CFLAGS += -O3 -std=c11 -pthread -D_POSIX_SOURCE -D_XOPEN_SOURCE=500 -DBUILD_REVISION='"$(BUILD_REVISION)"' CFLAGS += `pkg-config --cflags openssl` -#CFLAGS += -ggdb3 -DDEBUG -fsanitize=address -fsanitize=undefined -fsanitize=leak +CFLAGS += -ggdb3 -DDEBUG -fsanitize=address -fsanitize=undefined -fsanitize=leak PYPGMOPTS := ../Python/pypgmopts/pypgmopts LDFLAGS := `pkg-config --libs openssl` diff --git a/keydb.c b/keydb.c index edcc722..97fc7a7 100644 --- a/keydb.c +++ b/keydb.c @@ -98,16 +98,26 @@ static int keydb_get_host_index_by_name(struct keydb_t *keydb, const char *host_ return -1; } -struct volume_entry_t *keydb_get_volume_by_name(struct host_entry_t *host, const char *devmapper_name) { +struct volume_entry_t* keydb_get_volume_by_name(struct host_entry_t *host, const char *devmapper_name) { const int index = keydb_get_volume_index_by_name(host, devmapper_name); return (index >= 0) ? &host->volumes[index] : NULL; } -struct host_entry_t *keydb_get_host_by_name(struct keydb_t *keydb, const char *host_name) { +struct host_entry_t* keydb_get_host_by_name(struct keydb_t *keydb, const char *host_name) { const int index = keydb_get_host_index_by_name(keydb, host_name); return (index >= 0) ? &keydb->hosts[index] : NULL; } +const struct host_entry_t* keydb_get_host_by_uuid(const struct keydb_t *keydb, const uint8_t uuid[static 16]) { + for (unsigned int i = 0; i < keydb->host_count; i++) { + const struct host_entry_t *host = &keydb->hosts[i]; + if (!memcmp(host->host_uuid, uuid, 16)) { + return host; + } + } + return NULL; +} + bool keydb_add_host(struct keydb_t **keydb, const char *host_name) { if (strlen(host_name) > MAX_HOST_NAME_LENGTH - 1) { log_msg(LLVL_ERROR, "Host name \"%s\" exceeds maximum length of %d characters.", host_name, MAX_HOST_NAME_LENGTH - 1); diff --git a/keydb.h b/keydb.h index eb8dbe9..9e461df 100644 --- a/keydb.h +++ b/keydb.h @@ -57,8 +57,9 @@ struct keydb_t { struct keydb_t* keydb_new(void); struct keydb_t* keydb_export_public(struct host_entry_t *host); void keydb_free(struct keydb_t *keydb); -struct volume_entry_t *keydb_get_volume_by_name(struct host_entry_t *host, const char *devmapper_name); -struct host_entry_t *keydb_get_host_by_name(struct keydb_t *keydb, const char *host_name); +struct volume_entry_t* keydb_get_volume_by_name(struct host_entry_t *host, const char *devmapper_name); +struct host_entry_t* keydb_get_host_by_name(struct keydb_t *keydb, const char *host_name); +const struct host_entry_t* keydb_get_host_by_uuid(const struct keydb_t *keydb, const uint8_t uuid[static 16]); bool keydb_add_host(struct keydb_t **keydb, const char *host_name); bool keydb_del_host_by_name(struct keydb_t **keydb, const char *host_name); bool keydb_rekey_host(struct host_entry_t *host); diff --git a/server.c b/server.c index e0f711c..dd33fe7 100644 --- a/server.c +++ b/server.c @@ -42,6 +42,7 @@ #include "server.h" #include "luks.h" #include "pgmopts.h" +#include "uuid.h" static int create_tcp_server_socket(int port) { int s; @@ -305,22 +306,71 @@ static unsigned int psk_server_callback(SSL *ssl, const char *identity, unsigned #endif static int psk_server_callback(SSL *ssl, const unsigned char *identity, size_t identity_len, SSL_SESSION **sessptr) { - fprintf(stderr, "PSK server SSL %p identity %s len %ld sess %p\n", ssl, identity, identity_len, *sessptr); - SSL_SESSION *sess = SSL_SESSION_new(); - SSL_SESSION_set1_master_key(sess, (const unsigned char*)"\x00\x11\x22", 3); - const unsigned char tls13_aes128gcmsha256_id[] = { 0x13, 0x01 }; - const SSL_CIPHER *cipher = SSL_CIPHER_find(ssl, tls13_aes128gcmsha256_id); - if (!cipher) { + const struct keydb_t *keydb = (const struct keydb_t*)SSL_get_default_passwd_cb_userdata(ssl); + + if (identity_len != ASCII_UUID_CHARACTER_COUNT) { + log_msg(LLVL_WARNING, "Received client identity of length %d, cannot be a UUID.", identity_len); return 0; } - SSL_SESSION_set_cipher(sess, cipher); - SSL_SESSION_set_protocol_version(sess, TLS1_3_VERSION); + + char uuid_str[ASCII_UUID_BUFSIZE]; + memcpy(uuid_str, identity, ASCII_UUID_CHARACTER_COUNT); + uuid_str[ASCII_UUID_CHARACTER_COUNT] = 0; + if (!is_valid_uuid(uuid_str)) { + log_msg(LLVL_WARNING, "Received client identity of length %d, but not a valid UUID.", identity_len); + return 0; + } + + uint8_t uuid[16]; + if (!parse_uuid(uuid, uuid_str)) { + log_msg(LLVL_ERROR, "Failed to parse valid UUID."); + return 0; + } + + const struct host_entry_t *host = keydb_get_host_by_uuid(keydb, uuid); + if (!host) { + log_msg(LLVL_WARNING, "Client connected with client UUID %s, but not present in key database.", uuid_str); + return 0; + } + + const uint8_t tls13_aes128gcmsha256_id[] = { 0x13, 0x01 }; + const SSL_CIPHER *cipher = SSL_CIPHER_find(ssl, tls13_aes128gcmsha256_id); + if (!cipher) { + log_openssl(LLVL_ERROR, "Unable to look up SSL_CIPHER for TLSv1.3-PSK"); + return 0; + } + + SSL_SESSION *sess = SSL_SESSION_new(); + if (!sess) { + log_openssl(LLVL_ERROR, "Failed to create SSL_SESSION context for client."); + return 0; + } + + if (!SSL_SESSION_set1_master_key(sess, (const unsigned char*)"\x00\x11\x22", 3)) { + log_openssl(LLVL_ERROR, "Failed to set TLSv1.3-PSK master key."); + SSL_SESSION_free(sess); + return 0; + } + + if (!SSL_SESSION_set_cipher(sess, cipher)) { + log_openssl(LLVL_ERROR, "Failed to set TLSv1.3-PSK cipher."); + SSL_SESSION_free(sess); + return 0; + } + + if (!SSL_SESSION_set_protocol_version(sess, TLS1_3_VERSION)) { + log_openssl(LLVL_ERROR, "Failed to set TLSv1.3-PSK protocol version."); + SSL_SESSION_free(sess); + return 0; + } + *sessptr = sess; return 1; } struct client_ctx_t { struct generic_tls_ctx_t *gctx; + const struct keydb_t *keydb; int fd; }; @@ -329,11 +379,13 @@ static void *client_handler_thread(void *vctx) { SSL *ssl = SSL_new(client->gctx->ctx); SSL_set_fd(ssl, client->fd); + SSL_set_default_passwd_cb_userdata(ssl, (void*)client->keydb); if (SSL_accept(ssl) <= 0) { ERR_print_errors_fp(stderr); } else { - log_msg(LLVL_DEBUG, "Client connected, waiting for data..."); + log_msg(LLVL_DEBUG, "Client connected, sending their data..."); + /* while (true) { struct msg_t msg; int rxlen = SSL_read(ssl, &msg, sizeof(msg)); @@ -342,6 +394,7 @@ static void *client_handler_thread(void *vctx) { break; } } + */ fprintf(stderr, "done\n"); } @@ -353,6 +406,19 @@ static void *client_handler_thread(void *vctx) { } bool keyserver_start(const struct pgmopts_server_t *opts) { + /* Load key database first */ + struct keydb_t* keydb = keydb_read(opts->filename); + if (!keydb) { + log_msg(LLVL_FATAL, "Failed to load key database: %s", opts->filename); + return false; + } + + if (!keydb->server_database) { + log_msg(LLVL_FATAL, "Not a server key database: %s", opts->filename); + keydb_free(keydb); + return false; + } + struct generic_tls_ctx_t gctx; if (!create_generic_tls_context(&gctx, true)) { log_msg(LLVL_FATAL, "Failed to create OpenSSL server context."); @@ -361,9 +427,6 @@ bool keyserver_start(const struct pgmopts_server_t *opts) { SSL_CTX_set_psk_find_session_callback(gctx.ctx, psk_server_callback); - if (!SSL_CTX_use_psk_identity_hint(gctx.ctx, "watwatwat")) { - } - int tcp_sock = create_tcp_server_socket(opts->port); if (tcp_sock == -1) { log_msg(LLVL_ERROR, "Cannot start server without server socket."); @@ -391,6 +454,7 @@ bool keyserver_start(const struct pgmopts_server_t *opts) { return false; } client_ctx->gctx = &gctx; + client_ctx->keydb = keydb; client_ctx->fd = client; pthread_t thread; diff --git a/uuid.h b/uuid.h index 29826e6..1d6d933 100644 --- a/uuid.h +++ b/uuid.h @@ -27,8 +27,10 @@ #include #include +#define ASCII_UUID_CHARACTER_COUNT 36 + /* Already includes zero termination */ -#define ASCII_UUID_BUFSIZE 37 +#define ASCII_UUID_BUFSIZE (ASCII_UUID_CHARACTER_COUNT + 1) /*************** AUTO GENERATED SECTION FOLLOWS ***************/ bool is_valid_uuid(const char *ascii_uuid);