diff --git a/src/netdev.c b/src/netdev.c index ff574d65..e3b5fbca 100644 --- a/src/netdev.c +++ b/src/netdev.c @@ -97,7 +97,6 @@ struct netdev { void *user_data; struct eapol_sm *sm; struct sae_sm *sae_sm; - struct owe_sm *owe; struct auth_proto *ap; struct handshake_state *handshake; uint32_t connect_cmd_id; @@ -557,11 +556,6 @@ static void netdev_connect_free(struct netdev *netdev) netdev->sae_sm = NULL; } - if (netdev->owe) { - owe_sm_free(netdev->owe); - netdev->owe = NULL; - } - if (netdev->ap) { auth_proto_free(netdev->ap); netdev->ap = NULL; @@ -2376,7 +2370,7 @@ static void netdev_authenticate_event(struct l_genl_msg *msg, * the FT Associate command is included in the attached frame and is * not available in the Authenticate command callback. */ - if (!netdev->in_ft && !netdev->sae_sm && !netdev->owe && !netdev->ap) + if (!netdev->in_ft && !netdev->sae_sm && !netdev->ap) return; if (!l_genl_attr_init(&attr, msg)) { @@ -2417,8 +2411,6 @@ static void netdev_authenticate_event(struct l_genl_msg *msg, netdev_sae_process(netdev, ((struct mmpdu_header *)frame)->address_2, frame + 26, frame_len - 26); - else if (netdev->owe) - owe_rx_authenticate(netdev->owe); else if (netdev->ap) { ret = auth_proto_rx_authenticate(netdev->ap, frame, frame_len); if (ret == 0 || ret == -EAGAIN) @@ -2453,8 +2445,7 @@ static void netdev_associate_event(struct l_genl_msg *msg, if (!netdev->connected || netdev->aborting) return; - if (!netdev->owe && !netdev->in_ft && !netdev->handshake->mde && - !netdev->ap) + if (!netdev->in_ft && !netdev->handshake->mde && !netdev->ap) return; if (!l_genl_attr_init(&attr, msg)) { @@ -2483,10 +2474,7 @@ static void netdev_associate_event(struct l_genl_msg *msg, if (!frame) goto assoc_failed; - if (netdev->owe) { - owe_rx_associate(netdev->owe, frame, frame_len); - return; - } else if (netdev->ap) { + if (netdev->ap) { ret = auth_proto_rx_associate(netdev->ap, frame, frame_len); if (ret == 0) { auth_proto_free(netdev->ap); @@ -2749,38 +2737,6 @@ static void netdev_owe_tx_associate(struct iovec *ie_iov, size_t iov_len, } } -static void netdev_owe_complete(uint16_t status, void *user_data) -{ - struct netdev *netdev = user_data; - - switch (status) { - case 0: /* success */ - break; - case MMPDU_STATUS_CODE_UNSUPP_FINITE_CYCLIC_GROUP: - if (owe_retry(netdev->owe)) { - netdev->ignore_connect_event = true; - return; - } - /* fall through */ - default: - netdev->result = NETDEV_RESULT_ASSOCIATION_FAILED; - netdev->last_code = status; - netdev->expect_connect_failure = true; - - goto free_owe; - } - - netdev->ignore_connect_event = true; - - netdev->sm = eapol_sm_new(netdev->handshake); - eapol_register(netdev->sm); - eapol_start(netdev->sm); - -free_owe: - owe_sm_free(netdev->owe); - netdev->owe = NULL; -} - static void netdev_fils_tx_authenticate(const uint8_t *body, size_t body_len, void *user_data) @@ -2957,8 +2913,6 @@ static int netdev_connect_common(struct netdev *netdev, if (netdev->sae_sm) sae_start(netdev->sae_sm); - else if (netdev->owe) - owe_start(netdev->owe); else auth_proto_start(netdev->ap); @@ -2987,9 +2941,8 @@ int netdev_connect(struct netdev *netdev, struct scan_bss *bss, netdev_sae_complete, netdev); break; case IE_RSN_AKM_SUITE_OWE: - netdev->owe = owe_sm_new(hs, netdev_owe_tx_authenticate, + netdev->ap = owe_sm_new(hs, netdev_owe_tx_authenticate, netdev_owe_tx_associate, - netdev_owe_complete, netdev); break; case IE_RSN_AKM_SUITE_FILS_SHA256: diff --git a/src/owe.c b/src/owe.c index 3e455e68..a8e66da2 100644 --- a/src/owe.c +++ b/src/owe.c @@ -31,8 +31,10 @@ #include "src/handshake.h" #include "src/owe.h" #include "src/mpdu.h" +#include "src/auth-proto.h" struct owe_sm { + struct auth_proto ap; struct handshake_state *hs; const struct l_ecc_curve *curve; struct l_ecc_scalar *private; @@ -43,7 +45,6 @@ struct owe_sm { owe_tx_authenticate_func_t auth_tx; owe_tx_associate_func_t assoc_tx; - owe_complete_func_t complete; void *user_data; }; @@ -71,43 +72,30 @@ static bool owe_reset(struct owe_sm *owe) return true; } -struct owe_sm *owe_sm_new(struct handshake_state *hs, - owe_tx_authenticate_func_t auth, - owe_tx_associate_func_t assoc, - owe_complete_func_t complete, void *user_data) +static void owe_free(struct auth_proto *ap) { - struct owe_sm *owe = l_new(struct owe_sm, 1); + struct owe_sm *owe = l_container_of(ap, struct owe_sm, ap); - owe->hs = hs; - owe->auth_tx = auth; - owe->assoc_tx = assoc; - owe->user_data = user_data; - owe->complete = complete; - owe->ecc_groups = l_ecc_curve_get_supported_ike_groups(); - - if (!owe_reset(owe)) { - l_free(owe); - return NULL; - } - - return owe; -} - -void owe_sm_free(struct owe_sm *owe) -{ l_ecc_scalar_free(owe->private); l_ecc_point_free(owe->public_key); l_free(owe); } -void owe_start(struct owe_sm *owe) +static bool owe_start(struct auth_proto *ap) { + struct owe_sm *owe = l_container_of(ap, struct owe_sm, ap); + owe->auth_tx(owe->user_data); + + return true; } -void owe_rx_authenticate(struct owe_sm *owe) +static int owe_rx_authenticate(struct auth_proto *ap, const uint8_t *frame, + size_t frame_len) { + struct owe_sm *owe = l_container_of(ap, struct owe_sm, ap); + uint8_t buf[5 + L_ECC_SCALAR_MAX_BYTES]; struct iovec iov[3]; int iov_elems = 0; @@ -138,6 +126,8 @@ void owe_rx_authenticate(struct owe_sm *owe) iov_elems++; owe->assoc_tx(iov, iov_elems, owe->user_data); + + return 0; } /* @@ -229,8 +219,26 @@ failed: return false; } -void owe_rx_associate(struct owe_sm *owe, const uint8_t *frame, size_t len) +static bool owe_retry(struct owe_sm *owe) { + /* retry with another group, if possible */ + owe->retry++; + + if (!owe_reset(owe)) + return false; + + l_debug("OWE retrying with group %u", owe->group); + + owe_rx_authenticate(&owe->ap, NULL, 0); + + return true; +} + +static int owe_rx_associate(struct auth_proto *ap, const uint8_t *frame, + size_t len) +{ + struct owe_sm *owe = l_container_of(ap, struct owe_sm, ap); + const struct mmpdu_header *mpdu = NULL; const struct mmpdu_association_response *body; struct ie_tlv_iter iter; @@ -243,15 +251,18 @@ void owe_rx_associate(struct owe_sm *owe, const uint8_t *frame, size_t len) mpdu = mpdu_validate(frame, len); if (!mpdu) { l_error("could not process frame"); - goto owe_failed; + return -EBADMSG; } body = mmpdu_body(mpdu); - if (body->status_code) { - owe->complete(body->status_code, owe->user_data); - return; - } + if (body->status_code == MMPDU_STATUS_CODE_UNSUPP_FINITE_CYCLIC_GROUP) { + if (!owe_retry(owe)) + goto owe_bad_status; + + return -EAGAIN; + } else if (body->status_code) + goto owe_bad_status; ie_tlv_iter_init(&iter, body->ies, (const uint8_t *) mpdu + len - body->ies); @@ -271,7 +282,7 @@ void owe_rx_associate(struct owe_sm *owe, const uint8_t *frame, size_t len) case IE_TYPE_RSN: if (ie_parse_rsne(&iter, &info) < 0) { l_error("could not parse RSN IE"); - goto owe_failed; + goto invalid_ies; } /* @@ -282,7 +293,7 @@ void owe_rx_associate(struct owe_sm *owe, const uint8_t *frame, size_t len) */ if (info.akm_suites != IE_RSN_AKM_SUITE_OWE) { l_error("OWE AKM not included"); - goto owe_failed; + goto invalid_ies; } akm_found = true; @@ -295,39 +306,51 @@ void owe_rx_associate(struct owe_sm *owe, const uint8_t *frame, size_t len) if (!owe_dh || owe_dh_len < 34 || !akm_found) { l_error("associate response did not include proper OWE IE's"); - goto owe_failed; + goto invalid_ies; } if (l_get_le16(owe_dh) != owe->group) { l_error("associate response contained unsupported group %u", l_get_le16(owe_dh)); - goto owe_failed; + return -EBADMSG; } if (!owe_compute_keys(owe, owe_dh + 2, owe_dh_len - 2)) { l_error("could not compute OWE keys"); - goto owe_failed; + return -EBADMSG; } - owe->complete(0, owe->user_data); + return 0; - return; +invalid_ies: + return MMPDU_STATUS_CODE_INVALID_ELEMENT; -owe_failed: - owe->complete(MMPDU_REASON_CODE_UNSPECIFIED, owe->user_data); +owe_bad_status: + return (int)body->status_code; } -bool owe_retry(struct owe_sm *owe) +struct auth_proto *owe_sm_new(struct handshake_state *hs, + owe_tx_authenticate_func_t auth, + owe_tx_associate_func_t assoc, + void *user_data) { - /* retry with another group, if possible */ - owe->retry++; + struct owe_sm *owe = l_new(struct owe_sm, 1); - if (!owe_reset(owe)) - return false; + owe->hs = hs; + owe->auth_tx = auth; + owe->assoc_tx = assoc; + owe->user_data = user_data; + owe->ecc_groups = l_ecc_curve_get_supported_ike_groups(); - l_debug("OWE retrying with group %u", owe->group); + owe->ap.start = owe_start; + owe->ap.free = owe_free; + owe->ap.rx_authenticate = owe_rx_authenticate; + owe->ap.rx_associate = owe_rx_associate; - owe_rx_authenticate(owe); + if (!owe_reset(owe)) { + l_free(owe); + return NULL; + } - return true; + return &owe->ap; } diff --git a/src/owe.h b/src/owe.h index 75c91bba..f2c0d3e9 100644 --- a/src/owe.h +++ b/src/owe.h @@ -26,15 +26,8 @@ struct handshake_state; typedef void (*owe_tx_authenticate_func_t)(void *user_data); typedef void (*owe_tx_associate_func_t)(struct iovec *ie_iov, size_t iov_len, void *user_data); -typedef void (*owe_complete_func_t)(uint16_t status, void *user_data); -struct owe_sm *owe_sm_new(struct handshake_state *hs, +struct auth_proto *owe_sm_new(struct handshake_state *hs, owe_tx_authenticate_func_t auth, owe_tx_associate_func_t assoc, - owe_complete_func_t complete, void *user_data); -void owe_sm_free(struct owe_sm *owe); - -void owe_start(struct owe_sm *owe); -bool owe_retry(struct owe_sm *owe); -void owe_rx_authenticate(struct owe_sm *owe); -void owe_rx_associate(struct owe_sm *owe, const uint8_t *frame, size_t len); + void *user_data);