diff --git a/src/crypto.c b/src/crypto.c index 0c7cef75..b347a838 100644 --- a/src/crypto.c +++ b/src/crypto.c @@ -749,13 +749,15 @@ bool prf_plus_sha1(const void *key, size_t key_len, } /* Defined in 802.11-2012, Section 11.6.1.7.2 Key derivation function (KDF) */ -bool kdf_sha256(const void *key, size_t key_len, +bool crypto_kdf(enum l_checksum_type type, const void *key, size_t key_len, const void *prefix, size_t prefix_len, const void *data, size_t data_len, void *output, size_t size) { struct l_checksum *hmac; unsigned int i, offset = 0; unsigned int counter; + unsigned int chunk_size; + unsigned int n_iterations; uint8_t counter_le[2]; uint8_t length_le[2]; struct iovec iov[4] = { @@ -765,19 +767,21 @@ bool kdf_sha256(const void *key, size_t key_len, [3] = { .iov_base = length_le, .iov_len = 2 }, }; - hmac = l_checksum_new_hmac(L_CHECKSUM_SHA256, key, key_len); + hmac = l_checksum_new_hmac(type, key, key_len); if (!hmac) return false; + chunk_size = l_checksum_digest_length(type); + n_iterations = (size + chunk_size - 1) / chunk_size; + /* Length is denominated in bits, not bytes */ l_put_le16(size * 8, length_le); - /* KDF processes in 256-bit chunks (32 bytes) */ - for (i = 0, counter = 1; i < (size + 31) / 32; i++, counter++) { + for (i = 0, counter = 1; i < n_iterations; i++, counter++) { size_t len; - if (size - offset > 32) - len = 32; + if (size - offset > chunk_size) + len = chunk_size; else len = size - offset; @@ -794,49 +798,20 @@ bool kdf_sha256(const void *key, size_t key_len, return true; } +bool kdf_sha256(const void *key, size_t key_len, + const void *prefix, size_t prefix_len, + const void *data, size_t data_len, void *output, size_t size) +{ + return crypto_kdf(L_CHECKSUM_SHA256, key, key_len, prefix, prefix_len, + data, data_len, output, size); +} + bool kdf_sha384(const void *key, size_t key_len, const void *prefix, size_t prefix_len, const void *data, size_t data_len, void *output, size_t size) { - struct l_checksum *hmac; - unsigned int i, offset = 0; - unsigned int counter; - uint8_t counter_le[2]; - uint8_t length_le[2]; - struct iovec iov[4] = { - [0] = { .iov_base = counter_le, .iov_len = 2 }, - [1] = { .iov_base = (void *) prefix, .iov_len = prefix_len }, - [2] = { .iov_base = (void *) data, .iov_len = data_len }, - [3] = { .iov_base = length_le, .iov_len = 2 }, - }; - - hmac = l_checksum_new_hmac(L_CHECKSUM_SHA384, key, key_len); - if (!hmac) - return false; - - /* Length is denominated in bits, not bytes */ - l_put_le16(size * 8, length_le); - - /* KDF processes in 384-bit chunks (48 bytes) */ - for (i = 0, counter = 1; i < (size + 47) / 48; i++, counter++) { - size_t len; - - if (size - offset > 48) - len = 48; - else - len = size - offset; - - l_put_le16(counter, counter_le); - - l_checksum_updatev(hmac, iov, 4); - l_checksum_get_digest(hmac, output + offset, len); - - offset += len; - } - - l_checksum_free(hmac); - - return true; + return crypto_kdf(L_CHECKSUM_SHA384, key, key_len, prefix, prefix_len, + data, data_len, output, size); } /* @@ -939,14 +914,12 @@ static bool crypto_derive_ptk(const uint8_t *pmk, size_t pmk_len, } pos += 64; - if (type == L_CHECKSUM_SHA384) - return kdf_sha384(pmk, pmk_len, label, strlen(label), - data, sizeof(data), out_ptk, ptk_len); - else if (type == L_CHECKSUM_SHA256) - return kdf_sha256(pmk, pmk_len, label, strlen(label), + + if (type == L_CHECKSUM_SHA1) + return prf_sha1(pmk, pmk_len, label, strlen(label), data, sizeof(data), out_ptk, ptk_len); else - return prf_sha1(pmk, pmk_len, label, strlen(label), + return crypto_kdf(type, pmk, pmk_len, label, strlen(label), data, sizeof(data), out_ptk, ptk_len); } diff --git a/src/crypto.h b/src/crypto.h index 96be515a..d359da61 100644 --- a/src/crypto.h +++ b/src/crypto.h @@ -99,6 +99,9 @@ int crypto_psk_from_passphrase(const char *passphrase, const unsigned char *ssid, size_t ssid_len, unsigned char *out_psk); +bool crypto_kdf(enum l_checksum_type type, const void *key, size_t key_len, + const void *prefix, size_t prefix_len, + const void *data, size_t data_len, void *output, size_t size); bool kdf_sha256(const void *key, size_t key_len, const void *prefix, size_t prefix_len, const void *data, size_t data_len, void *output, size_t size);