diff --git a/src/crypto.c b/src/crypto.c index d11d2dfe..57b640c4 100644 --- a/src/crypto.c +++ b/src/crypto.c @@ -400,10 +400,12 @@ bool kdf_sha256(const void *key, size_t key_len, * Max operations for nonces are with the nonces treated as positive integers * converted as specified in 8.2.2. */ -bool crypto_derive_ptk(const uint8_t *pmk, size_t pmk_len, const char *label, - const uint8_t *addr1, const uint8_t *addr2, - const uint8_t *nonce1, const uint8_t *nonce2, - uint8_t *out_ptk, size_t ptk_len) +static bool crypto_derive_ptk(const uint8_t *pmk, size_t pmk_len, + const char *label, + const uint8_t *addr1, const uint8_t *addr2, + const uint8_t *nonce1, const uint8_t *nonce2, + uint8_t *out_ptk, size_t ptk_len, + bool use_sha256) { /* Nonce length is 32 */ uint8_t data[ETH_ALEN * 2 + 64]; @@ -431,16 +433,22 @@ bool crypto_derive_ptk(const uint8_t *pmk, size_t pmk_len, const char *label, pos += 64; - return prf_sha1(pmk, pmk_len, label, strlen(label), - data, sizeof(data), out_ptk, ptk_len); + if (use_sha256) + return kdf_sha256(pmk, pmk_len, label, strlen(label), + data, sizeof(data), out_ptk, ptk_len); + else + return prf_sha1(pmk, pmk_len, label, strlen(label), + data, sizeof(data), out_ptk, ptk_len); } bool crypto_derive_pairwise_ptk(const uint8_t *pmk, const uint8_t *addr1, const uint8_t *addr2, const uint8_t *nonce1, const uint8_t *nonce2, - struct crypto_ptk *out_ptk, size_t ptk_len) + struct crypto_ptk *out_ptk, size_t ptk_len, + bool use_sha256) { return crypto_derive_ptk(pmk, 32, "Pairwise key expansion", addr1, addr2, nonce1, nonce2, - (uint8_t *) out_ptk, ptk_len); + (uint8_t *) out_ptk, ptk_len, + use_sha256); } diff --git a/src/crypto.h b/src/crypto.h index 6cd00b5b..9b663383 100644 --- a/src/crypto.h +++ b/src/crypto.h @@ -69,12 +69,8 @@ bool prf_sha1(const void *key, size_t key_len, const void *prefix, size_t prefix_len, const void *data, size_t data_len, void *output, size_t size); -bool crypto_derive_ptk(const uint8_t *pmk, size_t pmk_len, const char *label, - const uint8_t *addr1, const uint8_t *addr2, - const uint8_t *nonce1, const uint8_t *nonce2, - uint8_t *out_ptk, size_t ptk_len); - bool crypto_derive_pairwise_ptk(const uint8_t *pmk, const uint8_t *addr1, const uint8_t *addr2, const uint8_t *nonce1, const uint8_t *nonce2, - struct crypto_ptk *out_ptk, size_t ptk_len); + struct crypto_ptk *out_ptk, size_t ptk_len, + bool use_sha256);