Refactoring of server code

Consolidate server state into one struct, similar to our client
solution.
This commit is contained in:
Johannes Bauer 2019-10-24 17:04:49 +02:00
parent 39ced77b98
commit 60b1b2bf39
2 changed files with 36 additions and 26 deletions

View File

@ -46,7 +46,7 @@ test_s: luksrku
./luksrku server -vv testdata/server.bin ./luksrku server -vv testdata/server.bin
test_c: luksrku test_c: luksrku
./luksrku client -vv testdata/client.bin 127.0.0.1 ./luksrku client -vv testdata/client.bin
.c.o: .c.o:
$(CC) $(CFLAGS) -c -o $@ $< $(CC) $(CFLAGS) -c -o $@ $<

View File

@ -48,13 +48,20 @@
#include "keydb.h" #include "keydb.h"
#include "signals.h" #include "signals.h"
struct client_ctx_t { struct client_thread_ctx_t {
struct generic_tls_ctx_t *gctx; struct generic_tls_ctx_t *gctx;
const struct keydb_t *keydb; const struct keydb_t *keydb;
const struct host_entry_t *host; const struct host_entry_t *host;
int fd; int fd;
}; };
struct keyserver_t {
struct keydb_t* keydb;
struct generic_tls_ctx_t gctx;
const struct pgmopts_server_t *opts;
int tcp_sd;
};
static int create_tcp_server_socket(int port) { static int create_tcp_server_socket(int port) {
int sd = socket(AF_INET, SOCK_STREAM, 0); int sd = socket(AF_INET, SOCK_STREAM, 0);
if (sd < 0) { if (sd < 0) {
@ -86,7 +93,7 @@ static int create_tcp_server_socket(int port) {
} }
static int psk_server_callback(SSL *ssl, const unsigned char *identity, size_t identity_len, SSL_SESSION **sessptr) { static int psk_server_callback(SSL *ssl, const unsigned char *identity, size_t identity_len, SSL_SESSION **sessptr) {
struct client_ctx_t *ctx = (struct client_ctx_t*)SSL_get_app_data(ssl); struct client_thread_ctx_t *ctx = (struct client_thread_ctx_t*)SSL_get_app_data(ssl);
if (identity_len != ASCII_UUID_CHARACTER_COUNT) { if (identity_len != ASCII_UUID_CHARACTER_COUNT) {
log_msg(LLVL_WARNING, "Received client identity of length %d, cannot be a UUID.", identity_len); log_msg(LLVL_WARNING, "Received client identity of length %d, cannot be a UUID.", identity_len);
@ -117,7 +124,7 @@ static int psk_server_callback(SSL *ssl, const unsigned char *identity, size_t i
} }
static void client_handler_thread(void *vctx) { static void client_handler_thread(void *vctx) {
struct client_ctx_t *client = (struct client_ctx_t*)vctx; struct client_thread_ctx_t *client = (struct client_thread_ctx_t*)vctx;
SSL *ssl = SSL_new(client->gctx->ctx); SSL *ssl = SSL_new(client->gctx->ctx);
if (ssl) { if (ssl) {
@ -157,36 +164,38 @@ static void client_handler_thread(void *vctx) {
bool keyserver_start(const struct pgmopts_server_t *opts) { bool keyserver_start(const struct pgmopts_server_t *opts) {
bool success = true; bool success = true;
struct keydb_t* keydb = NULL; struct keyserver_t keyserver = {
struct generic_tls_ctx_t gctx = { 0 }; .opts = opts,
.tcp_sd = -1,
};
do { do {
/* We ignore SIGPIPE or the server will die when clients disconnect suddenly */ /* We ignore SIGPIPE or the server will die when clients disconnect suddenly */
ignore_signal(SIGPIPE); ignore_signal(SIGPIPE);
/* Load key database first */ /* Load key database first */
keydb = keydb_read(opts->filename); keyserver.keydb = keydb_read(opts->filename);
if (!keydb) { if (!keyserver.keydb) {
log_msg(LLVL_FATAL, "Failed to load key database: %s", opts->filename); log_msg(LLVL_FATAL, "Failed to load key database: %s", opts->filename);
success = false; success = false;
break; break;
} }
if (!keydb->server_database) { if (!keyserver.keydb->server_database) {
log_msg(LLVL_FATAL, "Not a server key database: %s", opts->filename); log_msg(LLVL_FATAL, "Not a server key database: %s", opts->filename);
success = false; success = false;
break; break;
} }
if (!create_generic_tls_context(&gctx, true)) { if (!create_generic_tls_context(&keyserver.gctx, true)) {
log_msg(LLVL_FATAL, "Failed to create OpenSSL server context."); log_msg(LLVL_FATAL, "Failed to create OpenSSL server context.");
success = false; success = false;
break; break;
} }
SSL_CTX_set_psk_find_session_callback(gctx.ctx, psk_server_callback); SSL_CTX_set_psk_find_session_callback(keyserver.gctx.ctx, psk_server_callback);
int tcp_sock = create_tcp_server_socket(opts->port); keyserver.tcp_sd = create_tcp_server_socket(opts->port);
if (tcp_sock == -1) { if (keyserver.tcp_sd == -1) {
log_msg(LLVL_ERROR, "Cannot start server without server socket."); log_msg(LLVL_ERROR, "Cannot start server without server socket.");
success = false; success = false;
break; break;
@ -195,29 +204,30 @@ bool keyserver_start(const struct pgmopts_server_t *opts) {
while (true) { while (true) {
struct sockaddr_in addr; struct sockaddr_in addr;
unsigned int len = sizeof(addr); unsigned int len = sizeof(addr);
int client = accept(tcp_sock, (struct sockaddr*)&addr, &len); int client = accept(keyserver.tcp_sd, (struct sockaddr*)&addr, &len);
if (client < 0) { if (client < 0) {
log_libc(LLVL_ERROR, "Unable to accept(2)"); log_libc(LLVL_ERROR, "Unable to accept(2)");
close(tcp_sock); success = false;
free_generic_tls_context(&gctx); break;
return false;
} }
/* Client has connected, fire up client thread. */ /* Client has connected, fire up client thread. */
struct client_ctx_t client_ctx = { struct client_thread_ctx_t client_ctx = {
.gctx = &gctx, .gctx = &keyserver.gctx,
.keydb = keydb, .keydb = keyserver.keydb,
.fd = client, .fd = client,
}; };
if (!pthread_create_detached_thread(client_handler_thread, &client_ctx, sizeof(client_ctx))) { if (!pthread_create_detached_thread(client_handler_thread, &client_ctx, sizeof(client_ctx))) {
log_libc(LLVL_FATAL, "Unable to pthread_attr_init(3)"); log_libc(LLVL_FATAL, "Unable to create detached thread.");
close(tcp_sock); success = false;
free_generic_tls_context(&gctx); break;
return false;
} }
} }
} while (false); } while (false);
free_generic_tls_context(&gctx); if (keyserver.tcp_sd != -1) {
keydb_free(keydb); close(keyserver.tcp_sd);
}
free_generic_tls_context(&keyserver.gctx);
keydb_free(keyserver.keydb);
return success; return success;
} }