diff --git a/vault.c b/vault.c index 75a5d95..ed32443 100644 --- a/vault.c +++ b/vault.c @@ -28,8 +28,10 @@ #include #include #include +#include #include "vault.h" #include "util.h" +#include "log.h" static bool vault_derive_key(const struct vault_t *vault, uint8_t key[static 32]) { /* Derive the AES key from it */ @@ -73,6 +75,11 @@ struct vault_t* vault_init(unsigned int data_length, double target_derivation_ti return NULL; } + if (pthread_mutex_init(&vault->mutex, NULL)) { + log_libc(LLVL_FATAL, "Unable to initialize vault mutex."); + free(vault); + return NULL; + } vault->key = malloc(DEFAULT_KEY_LENGTH_BYTES); vault->key_length = DEFAULT_KEY_LENGTH_BYTES; if (!vault->key) { @@ -85,7 +92,7 @@ struct vault_t* vault_init(unsigned int data_length, double target_derivation_ti vault_free(vault); return NULL; } - vault->is_open = true; + vault->reference_count = 1; vault->data_length = data_length; vault_calibrate_derivation_time(vault, target_derivation_time); @@ -101,11 +108,7 @@ static void vault_destroy_content(struct vault_t *vault) { } } -bool vault_open(struct vault_t *vault) { - if (vault->is_open) { - return true; - } - +static bool vault_decrypt(struct vault_t *vault) { uint8_t dkey[32]; if (!vault_derive_key(vault, dkey)) { return false; @@ -152,7 +155,6 @@ bool vault_open(struct vault_t *vault) { } while (false); if (success) { - vault->is_open = true; OPENSSL_cleanse(vault->key, vault->key_length); OPENSSL_cleanse(vault->auth_tag, 16); } else { @@ -165,11 +167,19 @@ bool vault_open(struct vault_t *vault) { return success; } -bool vault_close(struct vault_t *vault) { - if (!vault->is_open) { - return true; +bool vault_open(struct vault_t *vault) { + bool success = true; + pthread_mutex_lock(&vault->mutex); + vault->reference_count++; + if (vault->reference_count == 1) { + /* Vault was closed, we need to decrypt it. */ + success = vault_decrypt(vault); } + pthread_mutex_unlock(&vault->mutex); + return success; +} +static bool vault_encrypt(struct vault_t *vault) { /* Generate a new key source */ if (RAND_bytes(vault->key, vault->key_length) != 1) { return false; @@ -223,9 +233,7 @@ bool vault_close(struct vault_t *vault) { } } while (false); - if (success) { - vault->is_open = false; - } else { + if (!success) { /* Vault may be in an inconsistent state. Destroy contents. */ vault_destroy_content(vault); } @@ -235,6 +243,19 @@ bool vault_close(struct vault_t *vault) { return success; } +bool vault_close(struct vault_t *vault) { + bool success = true; + pthread_mutex_lock(&vault->mutex); + vault->reference_count--; + if (vault->reference_count == 0) { + /* Vault is now closed, we need to encrypt it. */ + success = vault_encrypt(vault); + } + pthread_mutex_unlock(&vault->mutex); + return success; +} + + void vault_free(struct vault_t *vault) { vault_destroy_content(vault); free(vault->data); @@ -252,7 +273,7 @@ static void dump(const uint8_t *data, unsigned int length) { } int main(void) { - /* gcc -Wall -std=c11 -Wmissing-prototypes -Wstrict-prototypes -Werror=implicit-function-declaration -Wimplicit-fallthrough -Wshadow -pie -fPIE -fsanitize=address -fsanitize=undefined -fsanitize=leak -o vault vault.c -lasan -lubsan -lcrypto + /* gcc -D__TEST_VAULT__ -Wall -std=c11 -Wmissing-prototypes -Wstrict-prototypes -Werror=implicit-function-declaration -Wimplicit-fallthrough -Wshadow -pie -fPIE -fsanitize=address -fsanitize=undefined -fsanitize=leak -pthread -o vault vault.c util.c log.c -lcrypto */ struct vault_t *vault = vault_init(64, 0.1); dump(vault->data, vault->data_length); diff --git a/vault.h b/vault.h index c18cbb4..169834b 100644 --- a/vault.h +++ b/vault.h @@ -26,9 +26,11 @@ #include #include +#include struct vault_t { - bool is_open; + pthread_mutex_t mutex; + unsigned int reference_count; void *data; unsigned int data_length; uint8_t *key;