diff --git a/src/crypto.c b/src/crypto.c index fe91a0b3..b811aecf 100644 --- a/src/crypto.c +++ b/src/crypto.c @@ -610,6 +610,79 @@ bool prf_sha1(const void *key, size_t key_len, return true; } +/* PRF+ from RFC 5295 Section 3.1.2 (also RFC 4306 Section 2.13) */ +bool prf_plus(enum l_checksum_type type, const void *key, size_t key_len, + const char *label, void *out, size_t out_len, + size_t n_extra, ...) +{ + struct iovec iov[n_extra + 3]; + uint8_t *t = out; + size_t t_len = 0; + uint8_t count = 1; + uint8_t *out_ptr = out; + va_list va; + struct l_checksum *hmac; + ssize_t ret; + size_t i; + + iov[1].iov_base = (void *) label; + iov[1].iov_len = strlen(label); + + /* Include the '\0' from the label in S if extra arguments provided */ + if (n_extra) + iov[1].iov_len += 1; + + va_start(va, n_extra); + + for (i = 0; i < n_extra; i++) { + iov[i + 2].iov_base = va_arg(va, void *); + iov[i + 2].iov_len = va_arg(va, size_t); + } + + va_end(va); + + iov[n_extra + 2].iov_base = &count; + iov[n_extra + 2].iov_len = 1; + + hmac = l_checksum_new_hmac(type, key, key_len); + if (!hmac) + return false; + + while (out_len > 0) { + iov[0].iov_base = t; + iov[0].iov_len = t_len; + + if (!l_checksum_updatev(hmac, iov, n_extra + 3)) { + l_checksum_free(hmac); + return false; + } + + ret = l_checksum_get_digest(hmac, out_ptr, out_len); + if (ret < 0) { + l_checksum_free(hmac); + return false; + } + + /* + * RFC specifies that T(0) = empty string, so after the first + * iteration we update the length for T(1)...T(N) + */ + t_len = ret; + t = out_ptr; + count++; + + out_len -= ret; + out_ptr += ret; + + if (out_len) + l_checksum_reset(hmac); + } + + l_checksum_free(hmac); + + return true; +} + bool prf_plus_sha1(const void *key, size_t key_len, const void *label, size_t label_len, const void *seed, size_t seed_len, @@ -812,59 +885,9 @@ bool hkdf_extract(enum l_checksum_type type, const void *key, } bool hkdf_expand(enum l_checksum_type type, const uint8_t *key, size_t key_len, - const char *info, size_t info_len, void *out, - size_t out_len) + const char *info, void *out, size_t out_len) { - uint8_t *t = out; - size_t t_len = 0; - struct l_checksum *hmac; - uint8_t count = 1; - uint8_t *out_ptr = out; - - hmac = l_checksum_new_hmac(type, key, key_len); - if (!hmac) - return false; - - while (out_len > 0) { - ssize_t ret; - struct iovec iov[3]; - - iov[0].iov_base = t; - iov[0].iov_len = t_len; - iov[1].iov_base = (void *) info; - iov[1].iov_len = info_len; - iov[2].iov_base = &count; - iov[2].iov_len = 1; - - if (!l_checksum_updatev(hmac, iov, 3)) { - l_checksum_free(hmac); - return false; - } - - ret = l_checksum_get_digest(hmac, out_ptr, out_len); - if (ret < 0) { - l_checksum_free(hmac); - return false; - } - - /* - * RFC specifies that T(0) = empty string, so after the first - * iteration we update the length for T(1)...T(N) - */ - t_len = ret; - t = out_ptr; - count++; - - out_len -= ret; - out_ptr += ret; - - if (out_len) - l_checksum_reset(hmac); - } - - l_checksum_free(hmac); - - return true; + return prf_plus(type, key, key_len, info, out, out_len, 0); } /* diff --git a/src/crypto.h b/src/crypto.h index 65e6c6ef..a6f1ff09 100644 --- a/src/crypto.h +++ b/src/crypto.h @@ -109,12 +109,16 @@ bool prf_sha1(const void *key, size_t key_len, bool prf_plus_sha1(const void *key, size_t key_len, const void *prefix, size_t prefix_len, const void *data, size_t data_len, void *output, size_t size); + +bool prf_plus(enum l_checksum_type type, const void *key, size_t key_len, + const char *label, void *out, size_t out_len, + size_t n_extra, ...); + bool hkdf_extract(enum l_checksum_type type, const void *key, size_t key_len, uint8_t num_args, void *out, ...); bool hkdf_expand(enum l_checksum_type type, const uint8_t *key, size_t key_len, - const char *info, size_t info_len, void *out, - size_t out_len); + const char *info, void *out, size_t out_len); bool crypto_derive_pairwise_ptk(const uint8_t *pmk, size_t pmk_len, const uint8_t *addr1, const uint8_t *addr2, diff --git a/src/erp.c b/src/erp.c index d77a926e..b16aa625 100644 --- a/src/erp.c +++ b/src/erp.c @@ -278,17 +278,15 @@ static bool erp_derive_emsk_name(const uint8_t *session_id, size_t session_len, char buf[static 17]) { uint8_t hex[8]; - char info[7] = { 'E', 'M', 'S', 'K', '\0', 0x0, 0x8}; + uint16_t eight = L_CPU_TO_BE16(8); char *ascii; - if (!hkdf_expand(L_CHECKSUM_SHA256, session_id, session_len, info, - sizeof(info), hex, 8)) + if (!prf_plus(L_CHECKSUM_SHA256, session_id, session_len, "EMSK", + hex, 8, 1, &eight, sizeof(eight))) return false; ascii = l_util_hexstring(hex, 8); - strcpy(buf, ascii); - l_free(ascii); return true; @@ -308,26 +306,17 @@ static bool erp_derive_emsk_name(const uint8_t *session_id, size_t session_len, static bool erp_derive_reauth_keys(const uint8_t *emsk, size_t emsk_len, void *r_rk, void *r_ik) { - char info[256]; - char *ptr; + uint16_t len = L_CPU_TO_BE16(emsk_len); + uint8_t cryptosuite = ERP_CRYPTOSUITE_SHA256_128; - ptr = info + l_strlcpy(info, ERP_RRK_LABEL, sizeof(info)) + 1; - - l_put_be16(emsk_len, ptr); - ptr += 2; - - if (!hkdf_expand(L_CHECKSUM_SHA256, emsk, emsk_len, (const char *)info, - ptr - info, r_rk, emsk_len)) + if (!prf_plus(L_CHECKSUM_SHA256, emsk, emsk_len, ERP_RRK_LABEL, + r_rk, emsk_len, 1, + &len, sizeof(len))) return false; - ptr = info + l_strlcpy(info, ERP_RIK_LABEL, sizeof(info)) + 1; - - *ptr++ = ERP_CRYPTOSUITE_SHA256_128; - l_put_be16(emsk_len, ptr); - ptr += 2; - - if (!hkdf_expand(L_CHECKSUM_SHA256, r_rk, emsk_len, (const char *) info, - ptr - info, r_ik, emsk_len)) + if (!prf_plus(L_CHECKSUM_SHA256, r_rk, emsk_len, ERP_RIK_LABEL, + r_ik, emsk_len, 2 + &cryptosuite, 1, &len, sizeof(len))) return false; return true; @@ -411,11 +400,10 @@ int erp_rx_packet(struct erp_state *erp, const uint8_t *pkt, size_t len) struct erp_tlv_iter iter; enum eap_erp_cryptosuite cs; uint8_t hash[16]; - char info[256]; - char *ptr = info; const uint8_t *nai = NULL; uint8_t type; uint16_t seq; + uint16_t length; bool r; /* @@ -503,16 +491,14 @@ int erp_rx_packet(struct erp_state *erp, const uint8_t *pkt, size_t len) /* * RFC 6696 Section 4.6 - rMSK Derivation */ - strcpy(ptr, ERP_RMSK_LABEL); - ptr += strlen(ERP_RMSK_LABEL); - *ptr++ = '\0'; - l_put_be16(erp->seq, ptr); - ptr += 2; - l_put_be16(64, ptr); - ptr += 2; + seq = L_CPU_TO_BE16(erp->seq); + length = L_CPU_TO_BE16(64); - if (!hkdf_expand(L_CHECKSUM_SHA256, erp->r_rk, erp->cache->emsk_len, - info, ptr - info, erp->rmsk, erp->cache->emsk_len)) + if (!prf_plus(L_CHECKSUM_SHA256, erp->r_rk, erp->cache->emsk_len, + ERP_RMSK_LABEL, + erp->rmsk, erp->cache->emsk_len, 2, + &seq, sizeof(seq), + &length, sizeof(length))) goto eap_failed; return 0; diff --git a/src/owe.c b/src/owe.c index 2297d966..0b96bcb2 100644 --- a/src/owe.c +++ b/src/owe.c @@ -193,8 +193,7 @@ static bool owe_compute_keys(struct owe_sm *owe, const void *public_key, goto failed; /* PMK = HKDF-expand(prk, "OWE Key Generation", n) */ - if (!hkdf_expand(type, prk, nbytes, "OWE Key Generation", - strlen("OWE Key Generation"), pmk, nbytes)) + if (!hkdf_expand(type, prk, nbytes, "OWE Key Generation", pmk, nbytes)) goto failed; sha = l_checksum_new(type);