From 48f5a051bc098a48f5ab65eb398ab82fa9e80ea5 Mon Sep 17 00:00:00 2001 From: James Prestwood Date: Thu, 10 Jan 2019 11:51:11 -0800 Subject: [PATCH] sae: update SAE to use ELL API's --- src/sae.c | 421 ++++++++++++++++++++++++------------------------------ 1 file changed, 190 insertions(+), 231 deletions(-) diff --git a/src/sae.c b/src/sae.c index 76e4c8ad..5772c77e 100644 --- a/src/sae.c +++ b/src/sae.c @@ -27,7 +27,6 @@ #include "src/handshake.h" #include "src/crypto.h" #include "src/mpdu.h" -#include "src/ecc.h" #include "src/sae.h" #define SAE_RETRANSMIT_TIMEOUT 2 @@ -42,13 +41,14 @@ enum sae_state { struct sae_sm { struct handshake_state *handshake; - struct ecc_point pwe; + struct l_ecc_point *pwe; enum sae_state state; - uint64_t rand[NUM_ECC_DIGITS]; - uint64_t scalar[NUM_ECC_DIGITS]; - uint64_t p_scalar[NUM_ECC_DIGITS]; - struct ecc_point element; - struct ecc_point p_element; + const struct l_ecc_curve *curve; + struct l_ecc_scalar *rand; + struct l_ecc_scalar *scalar; + struct l_ecc_scalar *p_scalar; + struct l_ecc_point *element; + struct l_ecc_point *p_element; uint16_t send_confirm; uint8_t kck[32]; uint8_t pmk[32]; @@ -69,60 +69,6 @@ struct sae_sm { void *user_data; }; -static uint64_t curve_p[NUM_ECC_DIGITS] = CURVE_P_32; -static uint64_t curve_n[NUM_ECC_DIGITS] = CURVE_N_32; - -/* calculate random quadratic residue */ -static void sae_get_qr(uint64_t *qr) -{ - l_getrandom(qr, 32); - - while (vli_legendre(qr, curve_p) != -1) - l_getrandom(qr, 32); -} - -/* calculate random quadratic non-residue */ -static void sae_get_qnr(uint64_t *qnr) -{ - l_getrandom(qnr, 32); - - while (vli_legendre(qnr, curve_p) != 1) - l_getrandom(qnr, 32); -} - -/* blinding technique to determine if 'value' is quadratic residue */ -static bool sae_is_quadratic_residue(uint64_t *value, uint64_t *qr, - uint64_t *qnr) -{ - uint64_t y_sqr[NUM_ECC_DIGITS]; - uint64_t r[NUM_ECC_DIGITS]; - uint64_t num[NUM_ECC_DIGITS]; - - ecc_compute_y_sqr(y_sqr, value); - - l_getrandom(r, 32); - - while (vli_cmp(r, curve_p) >= 0) - l_getrandom(r, 32); - - vli_mod_mult_fast(num, y_sqr, r); - vli_mod_mult_fast(num, num, r); - - if (r[0] & 1) { - vli_mod_mult_fast(num, num, qr); - - if (vli_legendre(num, curve_p) == -1) - return true; - } else { - vli_mod_mult_fast(num, num, qnr); - - if (vli_legendre(num, curve_p) == 1) - return true; - } - - return false; -} - static bool sae_pwd_seed(const uint8_t *addr1, const uint8_t *addr2, uint8_t *base, size_t base_len, uint8_t counter, uint8_t *out) @@ -141,25 +87,38 @@ static bool sae_pwd_seed(const uint8_t *addr1, const uint8_t *addr2, &counter, 1); } -static bool sae_pwd_value(uint8_t *pwd_seed, uint64_t *pwd_value) +static struct l_ecc_scalar *sae_pwd_value(const struct l_ecc_curve *curve, + uint8_t *pwd_seed) { - uint64_t prime[NUM_ECC_DIGITS]; + uint8_t pwd_value[L_ECC_SCALAR_MAX_BYTES]; + uint8_t prime[L_ECC_SCALAR_MAX_BYTES]; + ssize_t len; + struct l_ecc_scalar *p = l_ecc_curve_get_prime(curve); - memcpy(prime, curve_p, 32); + len = l_ecc_scalar_get_data(p, prime, sizeof(prime)); - ecc_be2native(prime); + l_ecc_scalar_free(p); - return kdf_sha256(pwd_seed, 32, "SAE Hunting and Pecking", - strlen("SAE Hunting and Pecking"), prime, 32, - pwd_value, 32); + if (!kdf_sha256(pwd_seed, 32, "SAE Hunting and Pecking", + strlen("SAE Hunting and Pecking"), prime, len, + pwd_value, 32)) + return false; + + return l_ecc_scalar_new(curve, pwd_value, sizeof(pwd_value)); } /* IEEE 802.11-2016 - Section 12.4.2 Assumptions on SAE */ static bool sae_cn(const uint8_t *kck, uint16_t send_confirm, - const uint64_t *scalar1, const uint64_t *element1, - const uint64_t *scalar2, const uint64_t *element2, + struct l_ecc_scalar *scalar1, + struct l_ecc_point *element1, + struct l_ecc_scalar *scalar2, + struct l_ecc_point *element2, uint8_t *confirm) { + uint8_t s1[L_ECC_SCALAR_MAX_BYTES]; + uint8_t s2[L_ECC_SCALAR_MAX_BYTES]; + uint8_t e1[L_ECC_POINT_MAX_BYTES]; + uint8_t e2[L_ECC_POINT_MAX_BYTES]; struct l_checksum *hmac; struct iovec iov[5]; int ret; @@ -170,14 +129,14 @@ static bool sae_cn(const uint8_t *kck, uint16_t send_confirm, iov[0].iov_base = &send_confirm; iov[0].iov_len = 2; - iov[1].iov_base = (void *) scalar1; - iov[1].iov_len = 32; - iov[2].iov_base = (void *) element1; - iov[2].iov_len = 64; - iov[3].iov_base = (void *) scalar2; - iov[3].iov_len = 32; - iov[4].iov_base = (void *) element2; - iov[4].iov_len = 64; + iov[1].iov_base = (void *) s1; + iov[1].iov_len = l_ecc_scalar_get_data(scalar1, s1, sizeof(s1)); + iov[2].iov_base = (void *) e1; + iov[2].iov_len = l_ecc_point_get_data(element1, e1, sizeof(e1)); + iov[3].iov_base = (void *) s2; + iov[3].iov_len = l_ecc_scalar_get_data(scalar2, s2, sizeof(s2));; + iov[4].iov_base = (void *) e2; + iov[4].iov_len = l_ecc_point_get_data(element2, e2, sizeof(e2));; l_checksum_updatev(hmac, iov, 5); @@ -217,28 +176,89 @@ static void sae_reject_authentication(struct sae_sm *sm, uint16_t reason) sae_authentication_failed(sm, reason); } +static struct l_ecc_scalar *sae_new_residue(const struct l_ecc_curve *curve, + bool residue) +{ + struct l_ecc_scalar *s = l_ecc_scalar_new_random(curve); + + while (l_ecc_scalar_legendre(s) != ((residue) ? -1 : 1)) { + l_ecc_scalar_free(s); + s = l_ecc_scalar_new_random(curve); + } + + return s; +} + +static bool sae_is_quadradic_residue(const struct l_ecc_curve *curve, + struct l_ecc_scalar *value, + struct l_ecc_scalar *qr, + struct l_ecc_scalar *qnr) +{ + uint64_t rbuf[L_ECC_MAX_DIGITS]; + struct l_ecc_scalar *y_sqr = l_ecc_scalar_new(curve, NULL, 0); + struct l_ecc_scalar *r = l_ecc_scalar_new_random(curve); + struct l_ecc_scalar *num = l_ecc_scalar_new(curve, NULL, 0); + size_t bytes; + + l_ecc_scalar_sum_x(y_sqr, value); + + l_ecc_scalar_multiply(num, y_sqr, r); + l_ecc_scalar_multiply(num, num, r); + + l_ecc_scalar_free(y_sqr); + + bytes = l_ecc_scalar_get_data(r, rbuf, sizeof(rbuf)); + l_ecc_scalar_free(r); + + if (bytes <= 0) { + l_ecc_scalar_free(num); + return false; + } + + if (rbuf[bytes / 8] & 1) { + l_ecc_scalar_multiply(num, num, qr); + + if (l_ecc_scalar_legendre(num) == -1) { + l_ecc_scalar_free(num); + return true; + } + } else { + l_ecc_scalar_multiply(num, num, qnr); + + if (l_ecc_scalar_legendre(num) == 1) { + l_ecc_scalar_free(num); + return true; + } + } + + l_ecc_scalar_free(num); + + return false; +} + /* * IEEE 802.11-2016 Section 12.4.4.2.2 * Generation of the password element with ECC groups */ -static bool sae_compute_pwe(char *password, const uint8_t *addr1, - const uint8_t *addr2, struct ecc_point *pwe) +static bool sae_compute_pwe(struct sae_sm *sm, char *password, + const uint8_t *addr1, const uint8_t *addr2) { bool found = false; uint8_t counter = 1; uint8_t k = 20; uint8_t pwd_seed[32]; - uint64_t pwd_value[NUM_ECC_DIGITS]; + struct l_ecc_scalar *pwd_value; uint8_t random[32]; uint8_t *base = (uint8_t *) password; size_t base_len = strlen(password); uint8_t save[32] = { 0 }; - uint64_t qr[NUM_ECC_DIGITS]; - uint64_t qnr[NUM_ECC_DIGITS]; + struct l_ecc_scalar *qr; + struct l_ecc_scalar *qnr; + uint8_t x[L_ECC_SCALAR_MAX_BYTES]; /* create qr/qnr prior to beginning hunting-and-pecking loop */ - sae_get_qr(qr); - sae_get_qnr(qnr); + qr = sae_new_residue(sm->curve, true); + qnr = sae_new_residue(sm->curve, false); do { /* pwd-seed = H(max(addr1, addr2) || min(addr1, addr2), @@ -247,87 +267,50 @@ static bool sae_compute_pwe(char *password, const uint8_t *addr1, */ sae_pwd_seed(addr1, addr2, base, base_len, counter, pwd_seed); - sae_pwd_value(pwd_seed, pwd_value); + pwd_value = sae_pwd_value(sm->curve, pwd_seed); - ecc_be2native(pwd_value); + if (sae_is_quadradic_residue(sm->curve, pwd_value, qr, qnr)) { + if (found == false) { + l_ecc_scalar_get_data(pwd_value, x, sizeof(x)); - /* if (pwd-value < p) { */ - if (vli_cmp(pwd_value, curve_p) < 0) { - if (sae_is_quadratic_residue(pwd_value, qr, qnr)) { - if (found == false) { - memcpy(pwe->x, pwd_value, 32); - memcpy(save, pwd_seed, 32); + memcpy(save, pwd_seed, 32); - l_getrandom(random, 32); - base = random; - base_len = 32; + l_getrandom(random, 32); + base = random; + base_len = 32; - found = true; - } + found = true; } } + l_ecc_scalar_free(pwd_value); + counter++; } while ((counter <= k) || (found == false)); + l_ecc_scalar_free(qr); + l_ecc_scalar_free(qnr); + if (!found) { l_error("max PWE iterations reached!"); return false; } - if (!ecc_compute_y(pwe->y, pwe->x)) { - /* should always return true */ + if (!(save[31] & 1)) + sm->pwe = l_ecc_point_from_data(sm->curve, + L_ECC_POINT_TYPE_COMPRESSED_BIT1, + x, sizeof(x)); + else + sm->pwe = l_ecc_point_from_data(sm->curve, + L_ECC_POINT_TYPE_COMPRESSED_BIT0, + x, sizeof(x)); + + if (!sm->pwe) { l_error("computing y failed, was x quadratic residue?"); return false; } - if ((pwe->y[0] & 1) != (save[31] & 1)) - vli_mod_sub(pwe->y, curve_p, pwe->y, curve_p); - - return true; -} - -/* commit-scalar = (rand + mask) mod r */ -static void sae_get_commit_scalar(uint64_t *scalar, uint64_t *mask, - uint64_t *rand) -{ - uint64_t _1[NUM_ECC_DIGITS] = { 1ull }; - - l_getrandom(rand, ECC_BYTES); - - /* ensure 1 < p_rand < r */ - while (!((vli_cmp(rand, _1) > 0) && - (vli_cmp(rand, curve_n) < 0))) - l_getrandom(rand, ECC_BYTES); - - l_getrandom(mask, ECC_BYTES); - - /* ensure 1 < p_mask < r */ - while (!((vli_cmp(mask, _1) > 0) && - (vli_cmp(mask, curve_n) < 0))) - l_getrandom(mask, ECC_BYTES); - - /* (rand + mask) mod r */ - vli_mod_add(scalar, rand, mask, curve_n); -} - -/* commit-element = inv(mask * PWE) */ -static bool sae_get_commit_element(struct ecc_point *element, - struct ecc_point *pwe, uint64_t *mask) -{ - /* p_mask * PWE */ - ecc_point_mult(element, pwe, mask, NULL, vli_num_bits(mask)); - - if (!ecc_valid_point(element)) - return false; - - /* inv(p_mask * PWE) */ - vli_sub(element->y, curve_p, element->y); - - if (!ecc_valid_point(element)) - return false; - return true; } @@ -335,10 +318,9 @@ static bool sae_build_commit(struct sae_sm *sm, const uint8_t *addr1, const uint8_t *addr2, uint8_t *commit, size_t *len, bool retry) { - uint64_t scalar[NUM_ECC_DIGITS]; - uint64_t mask[NUM_ECC_DIGITS]; - struct ecc_point element; + struct l_ecc_scalar *mask; uint8_t *ptr = commit; + struct l_ecc_scalar *order; if (retry) goto old_commit; @@ -348,18 +330,28 @@ static bool sae_build_commit(struct sae_sm *sm, const uint8_t *addr1, return false; } - if (!sae_compute_pwe(sm->handshake->passphrase, addr1, - addr2, &sm->pwe)) { + if (!sae_compute_pwe(sm, sm->handshake->passphrase, addr1, addr2)) { l_error("could not compute PWE"); return false; } - sae_get_commit_scalar(sm->scalar, mask, sm->rand); + sm->scalar = l_ecc_scalar_new(sm->curve, NULL, 0); + sm->rand = l_ecc_scalar_new_random(sm->curve); + mask = l_ecc_scalar_new_random(sm->curve); - if (!sae_get_commit_element(&sm->element, &sm->pwe, mask)) { - l_error("error calculating commit element"); - return false; - } + order = l_ecc_curve_get_order(sm->curve); + + /* commit-scalar = (rand + mask) mod r */ + l_ecc_scalar_add(sm->scalar, sm->rand, mask, order); + + l_ecc_scalar_free(order); + + /* commit-element = inv(mask * PWE) */ + sm->element = l_ecc_point_new(sm->curve); + l_ecc_point_multiply(sm->element, mask, sm->pwe); + l_ecc_point_inverse(sm->element); + + l_ecc_scalar_free(mask); /* * Several cases require retransmitting the same commit message. The @@ -367,13 +359,6 @@ static bool sae_build_commit(struct sae_sm *sm, const uint8_t *addr1, * timeout. */ old_commit: - memcpy(scalar, sm->scalar, 32); - memcpy(element.x, sm->element.x, 32); - memcpy(element.y, sm->element.y, 32); - - ecc_native2be(scalar); - ecc_native2be(element.x); - ecc_native2be(element.y); /* transaction */ l_put_le16(1, ptr); @@ -390,12 +375,8 @@ old_commit: ptr += sm->token_len; } - memcpy(ptr, scalar, 32); - ptr += 32; - memcpy(ptr, element.x, 32); - ptr += 32; - memcpy(ptr, element.y, 32); - ptr += 32; + ptr += l_ecc_scalar_get_data(sm->scalar, ptr, L_ECC_SCALAR_MAX_BYTES); + ptr += l_ecc_point_get_data(sm->element, ptr, L_ECC_POINT_MAX_BYTES); *len = ptr - commit; @@ -408,30 +389,12 @@ static void sae_send_confirm(struct sae_sm *sm) uint8_t body[38]; uint8_t *ptr = body; - ecc_native2be(sm->scalar); - ecc_native2be(sm->element.x); - ecc_native2be(sm->element.y); - ecc_native2be(sm->p_scalar); - ecc_native2be(sm->p_element.x); - ecc_native2be(sm->p_element.y); - /* * confirm = CN(KCK, send-confirm, commit-scalar, COMMIT-ELEMENT, * peer-commit-scalar, PEER-COMMIT-ELEMENT) */ - sae_cn(sm->kck, sm->sc, sm->scalar, (uint64_t *) &sm->element, - sm->p_scalar, (uint64_t *) &sm->p_element, confirm); - - /* - * in case of retransmition, we will need to reuse these values, so - * go back to native endianness for consistency. - */ - ecc_be2native(sm->scalar); - ecc_be2native(sm->element.x); - ecc_be2native(sm->element.y); - ecc_be2native(sm->p_scalar); - ecc_be2native(sm->p_element.x); - ecc_be2native(sm->p_element.y); + sae_cn(sm->kck, sm->sc, sm->scalar, sm->element, sm->p_scalar, + sm->p_element, confirm); l_put_le16(2, ptr); ptr += 2; @@ -451,14 +414,17 @@ static void sae_process_commit(struct sae_sm *sm, const uint8_t *from, const uint8_t *frame, size_t len) { uint8_t *ptr = (uint8_t *) frame; - uint64_t k[NUM_ECC_DIGITS]; - struct ecc_point k_point; + uint8_t k[L_ECC_SCALAR_MAX_BYTES]; + struct l_ecc_point *k_point; uint8_t zero_key[32] = { 0 }; uint8_t keyseed[32]; uint8_t kck_and_pmk[2][32]; - uint64_t tmp[NUM_ECC_DIGITS]; + uint8_t tmp[L_ECC_SCALAR_MAX_BYTES]; + struct l_ecc_scalar *tmp_scalar; uint16_t group; uint16_t reason = MMPDU_REASON_CODE_UNSPECIFIED; + ssize_t klen; + struct l_ecc_scalar *order; if (sm->state != SAE_STATE_COMMITTED) { l_error("bad state %u", sm->state); @@ -479,19 +445,14 @@ static void sae_process_commit(struct sae_sm *sm, const uint8_t *from, goto reject; } - memcpy(sm->p_scalar, ptr, 32); + sm->p_scalar = l_ecc_scalar_new(sm->curve, ptr, 32); ptr += 32; - memcpy(sm->p_element.x, ptr, 32); - ptr += 32; - memcpy(sm->p_element.y, ptr, 32); - ecc_be2native(sm->p_scalar); - ecc_be2native(sm->p_element.x); - ecc_be2native(sm->p_element.y); + sm->p_element = l_ecc_point_from_data(sm->curve, L_ECC_POINT_TYPE_FULL, + ptr, 64); - if (!memcmp(sm->p_scalar, sm->scalar, 32) || - !memcmp(sm->p_element.x, sm->element.x, 32) || - !memcmp(sm->p_element.y, sm->element.y, 32)) { + if (l_ecc_scalars_are_equal(sm->p_scalar, sm->scalar) || + l_ecc_points_are_equal(sm->p_element, sm->element)) { /* possible reflection attack, silently discard message */ l_warn("peer scalar or element matched own, discarding frame"); @@ -504,17 +465,16 @@ static void sae_process_commit(struct sae_sm *sm, const uint8_t *from, * K = scalar-op(rand, (element-op(scalar-op(peer-commit-scalar, PWE), * PEER-COMMIT-ELEMENT))) */ + k_point = l_ecc_point_new(sm->curve); /* k_point = scalar-op(peer-commit-scalar, PWE) */ - ecc_point_mult(&k_point, &sm->pwe, sm->p_scalar, NULL, - vli_num_bits(sm->p_scalar)); + l_ecc_point_multiply(k_point, sm->p_scalar, sm->pwe); /* k_point = element-op(k_point, PEER-COMMIT-ELEMENT) */ - ecc_point_add(&k_point, &k_point, &sm->p_element); + l_ecc_point_add(k_point, k_point, sm->p_element); /* k_point = scalar-op(rand, k_point) */ - ecc_point_mult(&k_point, &k_point, sm->rand, NULL, - vli_num_bits(sm->rand)); + l_ecc_point_multiply(k_point, sm->rand, k_point); /* * IEEE 802.11-2016 - Section 12.4.4.2.1 ECC group definition @@ -522,20 +482,22 @@ static void sae_process_commit(struct sae_sm *sm, const uint8_t *from, * point (x, y) that satisfies the curve equation to its x-coordinate— * i.e., if P = (x, y) then F(P) = x. */ - memcpy(k, k_point.x, 32); + klen = l_ecc_point_get_x(k_point, k, sizeof(k)); - ecc_native2be(k); + l_ecc_point_free(k_point); /* keyseed = H(<0>32, k) */ - hmac_sha256(zero_key, 32, k, 32, keyseed, 32); + hmac_sha256(zero_key, 32, k, klen, keyseed, 32); /* * kck_and_pmk = KDF-Hash-512(keyseed, "SAE KCK and PMK", (commit-scalar + peer-commit-scalar) mod r) */ - vli_mod_add(tmp, sm->p_scalar, sm->scalar, curve_n); + tmp_scalar = l_ecc_scalar_new(sm->curve, NULL, 0); + order = l_ecc_curve_get_order(sm->curve); - ecc_native2be(tmp); + l_ecc_scalar_add(tmp_scalar, sm->p_scalar, sm->scalar, order); + l_ecc_scalar_get_data(tmp_scalar, tmp, sizeof(tmp)); kdf_sha256(keyseed, 32, "SAE KCK and PMK", strlen("SAE KCK and PMK"), tmp, 32, kck_and_pmk, 64); @@ -546,8 +508,12 @@ static void sae_process_commit(struct sae_sm *sm, const uint8_t *from, /* * PMKID = L((commit-scalar + peer-commit-scalar) mod r, 0, 128) */ - vli_mod_add(tmp, sm->scalar, sm->p_scalar, curve_n); - ecc_native2be(tmp); + l_ecc_scalar_add(tmp_scalar, sm->scalar, sm->p_scalar, order); + l_ecc_scalar_get_data(tmp_scalar, tmp, sizeof(tmp)); + + l_ecc_scalar_free(order); + + l_ecc_scalar_free(tmp_scalar); /* don't set the handshakes pmkid until confirm is verified */ memcpy(sm->pmkid, tmp, 16); @@ -564,23 +530,8 @@ static bool sae_verify_confirm(struct sae_sm *sm, const uint8_t *frame) uint8_t check[32]; uint16_t rc = l_get_le16(frame); - ecc_native2be(sm->scalar); - ecc_native2be(sm->element.x); - ecc_native2be(sm->element.y); - ecc_native2be(sm->p_scalar); - ecc_native2be(sm->p_element.x); - ecc_native2be(sm->p_element.y); - - sae_cn(sm->kck, rc, sm->p_scalar, - (const uint64_t *) &sm->p_element, sm->scalar, - (const uint64_t *) &sm->element, check); - - ecc_be2native(sm->scalar); - ecc_be2native(sm->element.x); - ecc_be2native(sm->element.y); - ecc_be2native(sm->p_scalar); - ecc_be2native(sm->p_element.x); - ecc_be2native(sm->p_element.y); + sae_cn(sm->kck, rc, sm->p_scalar, sm->p_element, sm->scalar, + sm->element, check); if (memcmp(frame + 2, check, 32)) { l_error("confirm did not match"); @@ -978,6 +929,7 @@ struct sae_sm *sae_sm_new(struct handshake_state *hs, sae_tx_packet_func_t tx, sm->user_data = user_data; sm->handshake = hs; sm->state = SAE_STATE_NOTHING; + sm->curve = l_ecc_curve_get(19); return sm; } @@ -986,6 +938,13 @@ void sae_sm_free(struct sae_sm *sm) { l_free(sm->token); + l_ecc_scalar_free(sm->scalar); + l_ecc_scalar_free(sm->p_scalar); + l_ecc_scalar_free(sm->rand); + l_ecc_point_free(sm->element); + l_ecc_point_free(sm->p_element); + l_ecc_point_free(sm->pwe); + /* zero out whole structure, including keys */ memset(sm, 0, sizeof(struct sae_sm));