Make vault threadsafe

We might have multiple processes accessing the vault and need to always
keep a proper reference count.
This commit is contained in:
Johannes Bauer 2019-10-25 16:30:46 +02:00
parent 54063ec025
commit 0bf0759c9c
2 changed files with 38 additions and 15 deletions

49
vault.c
View File

@ -28,8 +28,10 @@
#include <openssl/crypto.h> #include <openssl/crypto.h>
#include <openssl/rand.h> #include <openssl/rand.h>
#include <openssl/evp.h> #include <openssl/evp.h>
#include <pthread.h>
#include "vault.h" #include "vault.h"
#include "util.h" #include "util.h"
#include "log.h"
static bool vault_derive_key(const struct vault_t *vault, uint8_t key[static 32]) { static bool vault_derive_key(const struct vault_t *vault, uint8_t key[static 32]) {
/* Derive the AES key from it */ /* Derive the AES key from it */
@ -73,6 +75,11 @@ struct vault_t* vault_init(unsigned int data_length, double target_derivation_ti
return NULL; return NULL;
} }
if (pthread_mutex_init(&vault->mutex, NULL)) {
log_libc(LLVL_FATAL, "Unable to initialize vault mutex.");
free(vault);
return NULL;
}
vault->key = malloc(DEFAULT_KEY_LENGTH_BYTES); vault->key = malloc(DEFAULT_KEY_LENGTH_BYTES);
vault->key_length = DEFAULT_KEY_LENGTH_BYTES; vault->key_length = DEFAULT_KEY_LENGTH_BYTES;
if (!vault->key) { if (!vault->key) {
@ -85,7 +92,7 @@ struct vault_t* vault_init(unsigned int data_length, double target_derivation_ti
vault_free(vault); vault_free(vault);
return NULL; return NULL;
} }
vault->is_open = true; vault->reference_count = 1;
vault->data_length = data_length; vault->data_length = data_length;
vault_calibrate_derivation_time(vault, target_derivation_time); vault_calibrate_derivation_time(vault, target_derivation_time);
@ -101,11 +108,7 @@ static void vault_destroy_content(struct vault_t *vault) {
} }
} }
bool vault_open(struct vault_t *vault) { static bool vault_decrypt(struct vault_t *vault) {
if (vault->is_open) {
return true;
}
uint8_t dkey[32]; uint8_t dkey[32];
if (!vault_derive_key(vault, dkey)) { if (!vault_derive_key(vault, dkey)) {
return false; return false;
@ -152,7 +155,6 @@ bool vault_open(struct vault_t *vault) {
} while (false); } while (false);
if (success) { if (success) {
vault->is_open = true;
OPENSSL_cleanse(vault->key, vault->key_length); OPENSSL_cleanse(vault->key, vault->key_length);
OPENSSL_cleanse(vault->auth_tag, 16); OPENSSL_cleanse(vault->auth_tag, 16);
} else { } else {
@ -165,11 +167,19 @@ bool vault_open(struct vault_t *vault) {
return success; return success;
} }
bool vault_close(struct vault_t *vault) { bool vault_open(struct vault_t *vault) {
if (!vault->is_open) { bool success = true;
return true; pthread_mutex_lock(&vault->mutex);
vault->reference_count++;
if (vault->reference_count == 1) {
/* Vault was closed, we need to decrypt it. */
success = vault_decrypt(vault);
} }
pthread_mutex_unlock(&vault->mutex);
return success;
}
static bool vault_encrypt(struct vault_t *vault) {
/* Generate a new key source */ /* Generate a new key source */
if (RAND_bytes(vault->key, vault->key_length) != 1) { if (RAND_bytes(vault->key, vault->key_length) != 1) {
return false; return false;
@ -223,9 +233,7 @@ bool vault_close(struct vault_t *vault) {
} }
} while (false); } while (false);
if (success) { if (!success) {
vault->is_open = false;
} else {
/* Vault may be in an inconsistent state. Destroy contents. */ /* Vault may be in an inconsistent state. Destroy contents. */
vault_destroy_content(vault); vault_destroy_content(vault);
} }
@ -235,6 +243,19 @@ bool vault_close(struct vault_t *vault) {
return success; return success;
} }
bool vault_close(struct vault_t *vault) {
bool success = true;
pthread_mutex_lock(&vault->mutex);
vault->reference_count--;
if (vault->reference_count == 0) {
/* Vault is now closed, we need to encrypt it. */
success = vault_encrypt(vault);
}
pthread_mutex_unlock(&vault->mutex);
return success;
}
void vault_free(struct vault_t *vault) { void vault_free(struct vault_t *vault) {
vault_destroy_content(vault); vault_destroy_content(vault);
free(vault->data); free(vault->data);
@ -252,7 +273,7 @@ static void dump(const uint8_t *data, unsigned int length) {
} }
int main(void) { int main(void) {
/* gcc -Wall -std=c11 -Wmissing-prototypes -Wstrict-prototypes -Werror=implicit-function-declaration -Wimplicit-fallthrough -Wshadow -pie -fPIE -fsanitize=address -fsanitize=undefined -fsanitize=leak -o vault vault.c -lasan -lubsan -lcrypto /* gcc -D__TEST_VAULT__ -Wall -std=c11 -Wmissing-prototypes -Wstrict-prototypes -Werror=implicit-function-declaration -Wimplicit-fallthrough -Wshadow -pie -fPIE -fsanitize=address -fsanitize=undefined -fsanitize=leak -pthread -o vault vault.c util.c log.c -lcrypto
*/ */
struct vault_t *vault = vault_init(64, 0.1); struct vault_t *vault = vault_init(64, 0.1);
dump(vault->data, vault->data_length); dump(vault->data, vault->data_length);

View File

@ -26,9 +26,11 @@
#include <stdbool.h> #include <stdbool.h>
#include <stdint.h> #include <stdint.h>
#include <pthread.h>
struct vault_t { struct vault_t {
bool is_open; pthread_mutex_t mutex;
unsigned int reference_count;
void *data; void *data;
unsigned int data_length; unsigned int data_length;
uint8_t *key; uint8_t *key;