diff --git a/src/erp.c b/src/erp.c index c0eb38a6..08445f9f 100644 --- a/src/erp.c +++ b/src/erp.c @@ -512,10 +512,12 @@ eap_failed: return -EINVAL; } -void erp_get_rmsk(struct erp_state *erp, void **rmsk, size_t *rmsk_len) +const void *erp_get_rmsk(struct erp_state *erp, size_t *rmsk_len) { - *rmsk = erp->rmsk; - *rmsk_len = erp->cache->emsk_len; + if (rmsk_len) + *rmsk_len = erp->cache->emsk_len; + + return erp->rmsk; } void erp_init(void) diff --git a/src/erp.h b/src/erp.h index 517ae911..657b1b7b 100644 --- a/src/erp.h +++ b/src/erp.h @@ -39,7 +39,7 @@ void erp_free(struct erp_state *erp); bool erp_start(struct erp_state *erp); int erp_rx_packet(struct erp_state *erp, const uint8_t *erp_data, size_t len); -void erp_get_rmsk(struct erp_state *erp, void **rmsk, size_t *rmsk_len); +const void *erp_get_rmsk(struct erp_state *erp, size_t *rmsk_len); void erp_cache_add(const char *id, const void *session_id, size_t session_len, const void *emsk, size_t emsk_len, diff --git a/src/fils.c b/src/fils.c index 58a13111..be211c45 100644 --- a/src/fils.c +++ b/src/fils.c @@ -119,9 +119,10 @@ static void fils_erp_tx_func(const uint8_t *eap_data, size_t len, fils->auth(data, ptr - data + tlv_len, fils->user_data); } -static int fils_derive_key_data(struct fils_sm *fils, const void *rmsk, - size_t rmsk_len) +static int fils_derive_key_data(struct fils_sm *fils) { + const void *rmsk; + size_t rmsk_len; struct ie_tlv_builder builder; uint8_t key[FILS_NONCE_LEN * 2]; uint8_t key_data[64 + 48 + 16]; /* largest ICK, KEK, TK */ @@ -133,6 +134,8 @@ static int fils_derive_key_data(struct fils_sm *fils, const void *rmsk, bool sha384; unsigned int ie_len; + rmsk = erp_get_rmsk(fils->erp, &rmsk_len); + /* * IEEE 802.11ai - Section 12.12.2.5.3 */ @@ -289,8 +292,6 @@ void fils_rx_authenticate(struct fils_sm *fils, const uint8_t *frame, const uint8_t *session = NULL; const uint8_t *wrapped = NULL; size_t wrapped_len = 0; - void *rmsk; - size_t rmsk_len; if (!hdr) { l_debug("Auth frame header did not validate"); @@ -353,10 +354,7 @@ void fils_rx_authenticate(struct fils_sm *fils, const uint8_t *frame, if (erp_rx_packet(fils->erp, wrapped, wrapped_len) < 0) goto auth_failed; - erp_get_rmsk(fils->erp, &rmsk, &rmsk_len); - - fils_derive_key_data(fils, rmsk, rmsk_len); - + fils_derive_key_data(fils); return; auth_failed: