diff --git a/src/eapol.c b/src/eapol.c index 6047c5d4..ca4e0464 100644 --- a/src/eapol.c +++ b/src/eapol.c @@ -434,8 +434,8 @@ struct eapol_sm { uint32_t ifindex; uint8_t spa[6]; uint8_t aa[6]; - uint8_t *ap_rsn; - uint8_t *own_rsn; + uint8_t *ap_ie; + uint8_t *own_ie; uint8_t pmk[32]; uint64_t replay_counter; uint8_t snonce[32]; @@ -445,14 +445,15 @@ struct eapol_sm { struct l_timeout *timeout; bool have_snonce:1; bool have_replay:1; + bool wpa_ie:1; }; static void eapol_sm_destroy(void *value) { struct eapol_sm *sm = value; - l_free(sm->ap_rsn); - l_free(sm->own_rsn); + l_free(sm->ap_ie); + l_free(sm->own_ie); l_timeout_remove(sm->timeout); @@ -488,24 +489,48 @@ void eapol_sm_set_pmk(struct eapol_sm *sm, const uint8_t *pmk) memcpy(sm->pmk, pmk, sizeof(sm->pmk)); } -void eapol_sm_set_ap_rsn(struct eapol_sm *sm, const uint8_t *rsn_ie, - size_t len) +static void eapol_sm_set_ap_ie(struct eapol_sm *sm, const uint8_t *ie, + size_t len, bool is_wpa) { - if (rsn_ie[1] + 2u != len) + if (ie[1] + 2u != len) return; - l_free(sm->ap_rsn); - sm->ap_rsn = l_memdup(rsn_ie, len); + l_free(sm->ap_ie); + sm->ap_ie = l_memdup(ie, len); + sm->wpa_ie = is_wpa; +} + +static void eapol_sm_set_own_ie(struct eapol_sm *sm, const uint8_t *ie, + size_t len, bool is_wpa) +{ + if (ie[1] + 2u != len) + return; + + l_free(sm->own_ie); + sm->own_ie = l_memdup(ie, len); + sm->wpa_ie = is_wpa; +} + +void eapol_sm_set_ap_rsn(struct eapol_sm *sm, const uint8_t *rsn_ie, size_t len) +{ + eapol_sm_set_ap_ie(sm, rsn_ie, len, false); } void eapol_sm_set_own_rsn(struct eapol_sm *sm, const uint8_t *rsn_ie, size_t len) { - if (rsn_ie[1] + 2u != len) - return; + eapol_sm_set_own_ie(sm, rsn_ie, len, false); +} - l_free(sm->own_rsn); - sm->own_rsn = l_memdup(rsn_ie, len); +void eapol_sm_set_ap_wpa(struct eapol_sm *sm, const uint8_t *wpa_ie, size_t len) +{ + eapol_sm_set_ap_ie(sm, wpa_ie, len, true); +} + +void eapol_sm_set_own_wpa(struct eapol_sm *sm, const uint8_t *wpa_ie, + size_t len) +{ + eapol_sm_set_own_ie(sm, wpa_ie, len, true); } void eapol_sm_set_user_data(struct eapol_sm *sm, void *user_data) @@ -592,7 +617,7 @@ static void eapol_handle_ptk_1_of_4(uint32_t ifindex, struct eapol_sm *sm, ek->key_descriptor_version, sm->replay_counter, sm->snonce, - sm->own_rsn[1] + 2, sm->own_rsn); + sm->own_ie[1] + 2, sm->own_ie); if (!eapol_calculate_mic(ptk->kck, step2, mic)) { l_info("MIC calculation failed. " @@ -781,7 +806,7 @@ static void eapol_handle_ptk_3_of_4(uint32_t ifindex, return; } - if (!eapol_ap_rsne_matches(rsne, sm->ap_rsn)) { + if (!eapol_ap_rsne_matches(rsne, sm->ap_ie)) { handshake_failed(ifindex, sm, MPDU_REASON_CODE_IE_DIFFERENT); return; } diff --git a/src/eapol.h b/src/eapol.h index 05e2cf63..66342cf7 100644 --- a/src/eapol.h +++ b/src/eapol.h @@ -160,6 +160,10 @@ void eapol_sm_set_ap_rsn(struct eapol_sm *sm, const uint8_t *rsn_ie, size_t len); void eapol_sm_set_own_rsn(struct eapol_sm *sm, const uint8_t *rsn_ie, size_t len); +void eapol_sm_set_ap_wpa(struct eapol_sm *sm, const uint8_t *wpa_ie, + size_t len); +void eapol_sm_set_own_wpa(struct eapol_sm *sm, const uint8_t *wpa_ie, + size_t len); void eapol_sm_set_user_data(struct eapol_sm *sm, void *user_data); struct l_io *eapol_open_pae(uint32_t index);