diff --git a/src/netdev.c b/src/netdev.c index e3b5fbca..e35e7df4 100644 --- a/src/netdev.c +++ b/src/netdev.c @@ -551,11 +551,6 @@ static void netdev_connect_free(struct netdev *netdev) netdev->sm = NULL; } - if (netdev->sae_sm) { - sae_sm_free(netdev->sae_sm); - netdev->sae_sm = NULL; - } - if (netdev->ap) { auth_proto_free(netdev->ap); netdev->ap = NULL; @@ -2336,12 +2331,6 @@ ft_error: netdev, NULL); } -static void netdev_sae_process(struct netdev *netdev, const uint8_t *from, - const uint8_t *frame, size_t frame_len) -{ - sae_rx_packet(netdev->sae_sm, from, frame, frame_len); -} - static void netdev_authenticate_event(struct l_genl_msg *msg, struct netdev *netdev) { @@ -2370,7 +2359,7 @@ static void netdev_authenticate_event(struct l_genl_msg *msg, * the FT Associate command is included in the attached frame and is * not available in the Authenticate command callback. */ - if (!netdev->in_ft && !netdev->sae_sm && !netdev->ap) + if (!netdev->in_ft && !netdev->ap) return; if (!l_genl_attr_init(&attr, msg)) { @@ -2384,10 +2373,7 @@ static void netdev_authenticate_event(struct l_genl_msg *msg, case NL80211_ATTR_TIMED_OUT: l_warn("authentication timed out"); - if (netdev->sae_sm) { - sae_timeout(netdev->sae_sm); - return; - } else if (auth_proto_auth_timeout(netdev->ap)) + if (auth_proto_auth_timeout(netdev->ap)) return; goto auth_error; @@ -2407,10 +2393,6 @@ static void netdev_authenticate_event(struct l_genl_msg *msg, if (netdev->in_ft) netdev_ft_process_authenticate(netdev, frame, frame_len); - else if (netdev->sae_sm) - netdev_sae_process(netdev, - ((struct mmpdu_header *)frame)->address_2, - frame + 26, frame_len - 26); else if (netdev->ap) { ret = auth_proto_rx_authenticate(netdev->ap, frame, frame_len); if (ret == 0 || ret == -EAGAIN) @@ -2553,12 +2535,6 @@ static void netdev_cmd_connect_cb(struct l_genl_msg *msg, void *user_data) NULL, netdev->user_data); - /* the SAE SM can be freed */ - if (netdev->sae_sm) { - sae_sm_free(netdev->sae_sm); - netdev->sae_sm = NULL; - } - /* * We register the eapol state machine here, in case the PAE * socket receives EAPoL packets before the nl80211 socket @@ -2597,21 +2573,57 @@ static struct l_genl_msg *netdev_build_cmd_authenticate(struct netdev *netdev, return msg; } -static void netdev_sae_complete(uint16_t status, void *user_data) +static void netdev_auth_cb(struct l_genl_msg *msg, void *user_data) +{ + struct netdev *netdev = user_data; + + if (l_genl_msg_get_error(msg) < 0) { + l_error("Error sending CMD_AUTHENTICATE"); + + netdev_connect_failed(netdev, + NETDEV_RESULT_AUTHENTICATION_FAILED, + MMPDU_STATUS_CODE_UNSPECIFIED); + return; + } +} + +static void netdev_assoc_cb(struct l_genl_msg *msg, void *user_data) +{ + struct netdev *netdev = user_data; + + if (l_genl_msg_get_error(msg) < 0) { + l_error("Error sending CMD_ASSOCIATE"); + + netdev_connect_failed(netdev, NETDEV_RESULT_ASSOCIATION_FAILED, + MMPDU_STATUS_CODE_UNSPECIFIED); + } +} + +static void netdev_sae_tx_authenticate(const uint8_t *body, + size_t body_len, void *user_data) { struct netdev *netdev = user_data; struct l_genl_msg *msg; - struct iovec iov[3]; - int iov_elems = 0; - sae_sm_free(netdev->sae_sm); - netdev->sae_sm = NULL; + msg = netdev_build_cmd_authenticate(netdev, NL80211_AUTHTYPE_SAE, + netdev->handshake->aa); - if (status != 0) { - l_error("SAE exchange failed on %u result %u", - netdev->index, status); - goto auth_failed; + l_genl_msg_append_attr(msg, NL80211_ATTR_AUTH_DATA, body_len, body); + + if (!l_genl_family_send(nl80211, msg, netdev_auth_cb, netdev, NULL)) { + l_genl_msg_unref(msg); + netdev_connect_failed(netdev, + NETDEV_RESULT_AUTHENTICATION_FAILED, + MMPDU_STATUS_CODE_UNSPECIFIED); } +} + +static void netdev_sae_tx_associate(void *user_data) +{ + struct netdev *netdev = user_data; + struct l_genl_msg *msg; + struct iovec iov[2]; + int iov_elems = 0; msg = netdev_build_cmd_associate_common(netdev); @@ -2627,65 +2639,10 @@ static void netdev_sae_complete(uint16_t status, void *user_data) l_genl_msg_append_attrv(msg, NL80211_ATTR_IE, iov, iov_elems); - /* netdev_cmd_connect_cb can be reused */ - netdev->connect_cmd_id = l_genl_family_send(nl80211, msg, - netdev_cmd_connect_cb, - netdev, NULL); - - if (!netdev->connect_cmd_id) - goto auth_failed; - - /* - * Kick off EAPoL sm early in case the first EAPoL packet comes prior - * to the netdev_cmd_connect_cb - */ - netdev->sm = eapol_sm_new(netdev->handshake); - return; - -auth_failed: - netdev_connect_failed(netdev, NETDEV_RESULT_AUTHENTICATION_FAILED, - MMPDU_STATUS_CODE_UNSPECIFIED); -} - -static void netdev_tx_sae_frame_cb(struct l_genl_msg *msg, - void *user_data) -{ - int err = l_genl_msg_get_error(msg); - - if (err < 0) - l_debug("SAE: CMD_AUTHENTICATE failed: %s", strerror(err)); -} - -static int netdev_tx_sae_frame(const uint8_t *dest, const uint8_t *body, - size_t body_len, void *user_data) -{ - struct netdev *netdev = user_data; - struct l_genl_msg *msg; - - msg = netdev_build_cmd_authenticate(netdev, NL80211_AUTHTYPE_SAE, dest); - - l_genl_msg_append_attr(msg, NL80211_ATTR_AUTH_DATA, body_len, body); - - if (!l_genl_family_send(nl80211, msg, netdev_tx_sae_frame_cb, - netdev, NULL)) { + if (!l_genl_family_send(nl80211, msg, netdev_assoc_cb, netdev, NULL)) { l_genl_msg_unref(msg); - return -EINVAL; - } - - return 0; -} - -static void netdev_auth_cb(struct l_genl_msg *msg, void *user_data) -{ - struct netdev *netdev = user_data; - - if (l_genl_msg_get_error(msg) < 0) { - l_error("Error sending CMD_AUTHENTICATE"); - - netdev_connect_failed(netdev, - NETDEV_RESULT_AUTHENTICATION_FAILED, + netdev_connect_failed(netdev, NETDEV_RESULT_ASSOCIATION_FAILED, MMPDU_STATUS_CODE_UNSPECIFIED); - return; } } @@ -2707,18 +2664,6 @@ static void netdev_owe_tx_authenticate(void *user_data) } } -static void netdev_assoc_cb(struct l_genl_msg *msg, void *user_data) -{ - struct netdev *netdev = user_data; - - if (l_genl_msg_get_error(msg) < 0) { - l_error("Error sending CMD_ASSOCIATE"); - - netdev_connect_failed(netdev, NETDEV_RESULT_ASSOCIATION_FAILED, - MMPDU_STATUS_CODE_UNSPECIFIED); - } -} - static void netdev_owe_tx_associate(struct iovec *ie_iov, size_t iov_len, void *user_data) { @@ -2911,10 +2856,7 @@ static int netdev_connect_common(struct netdev *netdev, NL80211_EXT_FEATURE_CAN_REPLACE_PTK0)) handshake_state_set_no_rekey(hs, true); - if (netdev->sae_sm) - sae_start(netdev->sae_sm); - else - auth_proto_start(netdev->ap); + auth_proto_start(netdev->ap); return 0; } @@ -2937,8 +2879,9 @@ int netdev_connect(struct netdev *netdev, struct scan_bss *bss, switch (hs->akm_suite) { case IE_RSN_AKM_SUITE_SAE_SHA256: case IE_RSN_AKM_SUITE_FT_OVER_SAE_SHA256: - netdev->sae_sm = sae_sm_new(hs, netdev_tx_sae_frame, - netdev_sae_complete, netdev); + netdev->ap = sae_sm_new(hs, netdev_sae_tx_authenticate, + netdev_sae_tx_associate, + netdev); break; case IE_RSN_AKM_SUITE_OWE: netdev->ap = owe_sm_new(hs, netdev_owe_tx_authenticate, diff --git a/src/sae.c b/src/sae.c index 49089adf..0922b278 100644 --- a/src/sae.c +++ b/src/sae.c @@ -32,6 +32,7 @@ #include "src/crypto.h" #include "src/mpdu.h" #include "src/sae.h" +#include "src/auth-proto.h" #define SAE_RETRANSMIT_TIMEOUT 2 #define SAE_SYNC_MAX 3 @@ -44,6 +45,7 @@ enum sae_state { }; struct sae_sm { + struct auth_proto ap; struct handshake_state *handshake; struct l_ecc_point *pwe; enum sae_state state; @@ -71,8 +73,8 @@ struct sae_sm { /* remote peer */ uint8_t peer[6]; - sae_tx_packet_func_t tx; - sae_complete_func_t complete; + sae_tx_authenticate_func_t tx_auth; + sae_tx_associate_func_t tx_assoc; void *user_data; }; @@ -109,7 +111,7 @@ static struct l_ecc_scalar *sae_pwd_value(const struct l_ecc_curve *curve, if (!kdf_sha256(pwd_seed, 32, "SAE Hunting and Pecking", strlen("SAE Hunting and Pecking"), prime, len, pwd_value, len)) - return false; + return NULL; return l_ecc_scalar_new(curve, pwd_value, sizeof(pwd_value)); } @@ -154,11 +156,6 @@ static bool sae_cn(const uint8_t *kck, uint16_t send_confirm, return (ret == 32); } -static void sae_authentication_failed(struct sae_sm *sm, uint16_t reason) -{ - sm->complete(reason, sm->user_data); -} - static void sae_reject_authentication(struct sae_sm *sm, uint16_t reason) { uint8_t reject[6]; @@ -176,9 +173,7 @@ static void sae_reject_authentication(struct sae_sm *sm, uint16_t reason) ptr += 2; } - sm->tx(sm->peer, reject, ptr - reject, sm->user_data); - - sae_authentication_failed(sm, reason); + sm->tx_auth(reject, ptr - reject, sm->user_data); } static struct l_ecc_scalar *sae_new_residue(const struct l_ecc_curve *curve, @@ -410,10 +405,10 @@ static void sae_send_confirm(struct sae_sm *sm) sm->state = SAE_STATE_CONFIRMED; - sm->tx(sm->peer, body, 38, sm->user_data); + sm->tx_auth(body, 38, sm->user_data); } -static void sae_process_commit(struct sae_sm *sm, const uint8_t *from, +static int sae_process_commit(struct sae_sm *sm, const uint8_t *from, const uint8_t *frame, size_t len) { uint8_t *ptr = (uint8_t *) frame; @@ -445,9 +440,9 @@ static void sae_process_commit(struct sae_sm *sm, const uint8_t *from, ptr += 2; if (group != sm->group) { - l_error("unsupported group: %u", group); - reason = MMPDU_STATUS_CODE_UNSUPP_FINITE_CYCLIC_GROUP; - goto reject; + sae_reject_authentication(sm, + MMPDU_STATUS_CODE_UNSUPP_FINITE_CYCLIC_GROUP); + return 0; } sm->p_scalar = l_ecc_scalar_new(sm->curve, ptr, nbytes); @@ -472,7 +467,7 @@ static void sae_process_commit(struct sae_sm *sm, const uint8_t *from, /* possible reflection attack, silently discard message */ l_warn("peer scalar or element matched own, discarding frame"); - return; + return 0; } sm->sc++; @@ -535,10 +530,11 @@ static void sae_process_commit(struct sae_sm *sm, const uint8_t *from, sae_send_confirm(sm); - return; + return 0; reject: sae_reject_authentication(sm, reason); + return -EBADMSG; } static bool sae_verify_confirm(struct sae_sm *sm, const uint8_t *frame) @@ -559,7 +555,7 @@ static bool sae_verify_confirm(struct sae_sm *sm, const uint8_t *frame) return true; } -static void sae_process_confirm(struct sae_sm *sm, const uint8_t *from, +static int sae_process_confirm(struct sae_sm *sm, const uint8_t *from, const uint8_t *frame, size_t len) { const uint8_t *ptr = frame; @@ -583,17 +579,18 @@ static void sae_process_confirm(struct sae_sm *sm, const uint8_t *from, handshake_state_set_pmkid(sm->handshake, sm->pmkid); handshake_state_set_pmk(sm->handshake, sm->pmk, 32); - sm->complete(0, sm->user_data); - sm->state = SAE_STATE_ACCEPTED; - return; + sm->tx_assoc(sm->user_data); + + return 0; reject: sae_reject_authentication(sm, MMPDU_REASON_CODE_UNSPECIFIED); + return -EBADMSG; } -static void sae_send_commit(struct sae_sm *sm, bool retry) +static bool sae_send_commit(struct sae_sm *sm, bool retry) { struct handshake_state *hs = sm->handshake; /* regular commit + possible 256 byte token + 6 bytes header */ @@ -601,19 +598,23 @@ static void sae_send_commit(struct sae_sm *sm, bool retry) size_t len; if (!sae_build_commit(sm, hs->spa, hs->aa, commit, &len, retry)) - return; + return false; sm->state = SAE_STATE_COMMITTED; - sm->tx(sm->peer, commit, len, sm->user_data); + sm->tx_auth(commit, len, sm->user_data); + + return true; } -void sae_timeout(struct sae_sm *sm) +static bool sae_timeout(struct auth_proto *ap) { + struct sae_sm *sm = l_container_of(ap, struct sae_sm, ap); + /* regardless of state, reject if sync exceeds max */ if (sm->sync > SAE_SYNC_MAX) { sae_reject_authentication(sm, MMPDU_REASON_CODE_UNSPECIFIED); - return; + return false; } sm->sync++; @@ -629,8 +630,10 @@ void sae_timeout(struct sae_sm *sm) default: /* should never happen */ l_error("SAE timeout in bad state %u", sm->state); - return; + return false; } + + return true; } /* @@ -664,7 +667,7 @@ static void sae_process_anti_clogging(struct sae_sm *sm, const uint8_t *ptr, /* * 802.11-2016 - 12.4.8.6.3 Protocol instance behavior - Nothing state */ -static bool sae_verify_nothing(struct sae_sm *sm, uint16_t transaction, +static int sae_verify_nothing(struct sae_sm *sm, uint16_t transaction, uint16_t status, const uint8_t *frame, size_t len) { @@ -674,22 +677,20 @@ static bool sae_verify_nothing(struct sae_sm *sm, uint16_t transaction, * yet supported. */ if (transaction != SAE_STATE_COMMITTED) - return false; + return -EBADMSG; /* frame shall be silently discarded and Del event sent */ - if (status != 0) { - sae_authentication_failed(sm, MMPDU_REASON_CODE_UNSPECIFIED); - return false; - } + if (status != 0) + return -EBADMSG; /* reject with unsupported group */ if (l_get_le16(frame) != sm->group) { sae_reject_authentication(sm, MMPDU_STATUS_CODE_UNSUPP_FINITE_CYCLIC_GROUP); - return false; + return -EBADMSG; } - return true; + return 0; } static void sae_reset_state(struct sae_sm *sm) @@ -714,7 +715,7 @@ static void sae_reset_state(struct sae_sm *sm) /* * 802.11-2016 - 12.4.8.6.4 Protocol instance behavior - Committed state */ -static bool sae_verify_committed(struct sae_sm *sm, uint16_t transaction, +static int sae_verify_committed(struct sae_sm *sm, uint16_t transaction, uint16_t status, const uint8_t *frame, size_t len) { @@ -728,23 +729,20 @@ static bool sae_verify_committed(struct sae_sm *sm, uint16_t transaction, * the peer... */ if (transaction == SAE_STATE_CONFIRMED) { - if (sm->sync > SAE_SYNC_MAX) { - sae_authentication_failed(sm, - MMPDU_REASON_CODE_UNSPECIFIED); - return false; - } + if (sm->sync > SAE_SYNC_MAX) + return -EBADMSG; sm->sync++; sae_send_commit(sm, true); - return false; + return -EAGAIN; } switch (status) { case MMPDU_STATUS_CODE_ANTI_CLOGGING_TOKEN_REQ: sae_process_anti_clogging(sm, frame, len); - return false; + return -EAGAIN; case MMPDU_STATUS_CODE_UNSUPP_FINITE_CYCLIC_GROUP: /* * TODO: hostapd in its current state does not include the @@ -766,7 +764,7 @@ static bool sae_verify_committed(struct sae_sm *sm, uint16_t transaction, if (len == 0) l_warn("AP did not include group number in response!"); else if (len >= 2 && (l_get_le16(frame) != sm->group)) - return false; + return -EBADMSG; sm->group_retry++; @@ -798,23 +796,17 @@ static bool sae_verify_committed(struct sae_sm *sm, uint16_t transaction, sae_send_commit(sm, false); - return false; + return -EAGAIN; case 0: - if (len < 2) { - sae_authentication_failed(sm, - MMPDU_REASON_CODE_UNSPECIFIED); - return false; - } + if (len < 2) + return -EBADMSG; if (l_get_le16(frame) == sm->group) - return true; + return 0; if (!l_ecc_curve_get_ike_group(l_get_le16(frame))) { - if (sm->sync > SAE_SYNC_MAX) { - sae_authentication_failed(sm, - MMPDU_REASON_CODE_UNSPECIFIED); - return false; - } + if (sm->sync > SAE_SYNC_MAX) + return -EBADMSG; sm->sync++; @@ -842,7 +834,7 @@ static bool sae_verify_committed(struct sae_sm *sm, uint16_t transaction, */ sae_send_commit(sm, true); - return false; + return 0; } /* @@ -872,50 +864,48 @@ static bool sae_verify_committed(struct sae_sm *sm, uint16_t transaction, * called. */ - return true; + return 0; default: /* * If the Status is some other nonzero value, the frame shall * be silently discarded... */ - return false; + return 0; } reject_unsupp_group: sae_reject_authentication(sm, MMPDU_STATUS_CODE_UNSUPP_FINITE_CYCLIC_GROUP); - return false; + return MMPDU_STATUS_CODE_UNSUPP_FINITE_CYCLIC_GROUP; } /* * 802.11-2016 - 12.4.8.6.5 Protocol instance behavior - Confirmed state */ -static bool sae_verify_confirmed(struct sae_sm *sm, uint16_t trans, +static int sae_verify_confirmed(struct sae_sm *sm, uint16_t trans, uint16_t status, const uint8_t *frame, size_t len) { if (trans == SAE_STATE_CONFIRMED) - return true; + return 0; /* * If the Status is nonzero, the frame shall be silently discarded... */ if (status != 0) - return false; + return 0; /* * If Sync is greater than dot11RSNASAESync, the protocol instance * shall send the parent process a Del event and transitions back to * Nothing state. */ - if (sm->sync > SAE_SYNC_MAX) { - sae_authentication_failed(sm, MMPDU_REASON_CODE_UNSPECIFIED); - return false; - } + if (sm->sync > SAE_SYNC_MAX) + return -EBADMSG; /* frame shall be silently discarded */ if (l_get_le16(frame) != sm->group) - return false; + return 0; /* * the protocol instance shall increment Sync, increment Sc, and @@ -927,7 +917,7 @@ static bool sae_verify_confirmed(struct sae_sm *sm, uint16_t trans, sae_send_commit(sm, true); sae_send_confirm(sm); - return false; + return 0; } /* @@ -945,10 +935,8 @@ static bool sae_verify_accepted(struct sae_sm *sm, uint16_t trans, return false; } - if (sm->sync > SAE_SYNC_MAX) { - sae_authentication_failed(sm, MMPDU_REASON_CODE_UNSPECIFIED); + if (sm->sync > SAE_SYNC_MAX) return false; - } sc = l_get_le16(frame); @@ -985,12 +973,12 @@ static bool sae_verify_accepted(struct sae_sm *sm, uint16_t trans, return false; } -static bool sae_verify_packet(struct sae_sm *sm, uint16_t trans, +static int sae_verify_packet(struct sae_sm *sm, uint16_t trans, uint16_t status, const uint8_t *frame, size_t len) { if (trans != SAE_STATE_COMMITTED && trans != SAE_STATE_CONFIRMED) - return false; + return -EBADMSG; switch (sm->state) { case SAE_STATE_NOTHING: @@ -1004,26 +992,36 @@ static bool sae_verify_packet(struct sae_sm *sm, uint16_t trans, } /* should never get here */ - return false; + return -1; } -void sae_rx_packet(struct sae_sm *sm, const uint8_t *from, const uint8_t *frame, - size_t len) +static int sae_rx_authenticate(struct auth_proto *ap, + const uint8_t *frame, size_t len) { - uint16_t transaction; - uint16_t status; - const uint8_t *ptr = frame; + struct sae_sm *sm = l_container_of(ap, struct sae_sm, ap); + const struct mmpdu_header *hdr = mpdu_validate(frame, len); + const struct mmpdu_authentication *auth; + int ret; + + if (!hdr) { + l_debug("Auth frame header did not validate"); + goto reject; + } + + auth = mmpdu_body(hdr); + + if (!auth) { + l_debug("Auth frame body did not validate"); + goto reject; + } + + len -= sizeof(struct mmpdu_header); if (len < 4) { l_error("bad packet length"); goto reject; } - transaction = l_get_le16(ptr); - ptr += 2; - status = l_get_le16(ptr); - ptr += 2; - /* * TODO: Hostapd seems to not include the group number when rejecting * with an unsupported group, which violates the spec. This means our @@ -1032,42 +1030,81 @@ void sae_rx_packet(struct sae_sm *sm, const uint8_t *from, const uint8_t *frame, * code, as well as add the check in the verify function to allow for * this missing group number. */ - if (len == 4 && status != - MMPDU_STATUS_CODE_UNSUPP_FINITE_CYCLIC_GROUP) { - sae_authentication_failed(sm, status); - return; - } + if (len == 4 && auth->status != + MMPDU_STATUS_CODE_UNSUPP_FINITE_CYCLIC_GROUP) + goto reject; - if (!sae_verify_packet(sm, transaction, status, ptr, len - 4)) - return; + ret = sae_verify_packet(sm, auth->transaction_sequence, auth->status, + auth->ies, len - 6); + if (ret != 0) + return ret; - switch (transaction) { + switch (auth->transaction_sequence) { case SAE_STATE_COMMITTED: - sae_process_commit(sm, from, ptr, len - 4); - return; + return sae_process_commit(sm, hdr->address_2, auth->ies, + len - 2); case SAE_STATE_CONFIRMED: - sae_process_confirm(sm, from, ptr, len - 4); - return; + return sae_process_confirm(sm, hdr->address_2, auth->ies, + len - 2); default: - l_error("invalid transaction sequence %u", transaction); + l_error("invalid transaction sequence %u", + auth->transaction_sequence); } reject: sae_reject_authentication(sm, MMPDU_REASON_CODE_UNSPECIFIED); + + return -EBADMSG; } -void sae_start(struct sae_sm *sm) +static int sae_rx_associate(struct auth_proto *ap, const uint8_t *frame, + size_t len) { + const struct mmpdu_header *mpdu = NULL; + const struct mmpdu_association_response *body; + + mpdu = mpdu_validate(frame, len); + if (!mpdu) { + l_error("could not process frame"); + return -EBADMSG; + } + + body = mmpdu_body(mpdu); + + if (body->status_code != 0) + return (int) body->status_code; + + return 0; +} + +static bool sae_start(struct auth_proto *ap) +{ + struct sae_sm *sm = l_container_of(ap, struct sae_sm, ap); + if (sm->handshake->authenticator) memcpy(sm->peer, sm->handshake->spa, 6); else memcpy(sm->peer, sm->handshake->aa, 6); - sae_send_commit(sm, false); + return sae_send_commit(sm, false); } -struct sae_sm *sae_sm_new(struct handshake_state *hs, sae_tx_packet_func_t tx, - sae_complete_func_t complete, void *user_data) +static void sae_free(struct auth_proto *ap) +{ + struct sae_sm *sm = l_container_of(ap, struct sae_sm, ap); + + sae_reset_state(sm); + + /* zero out whole structure, including keys */ + explicit_bzero(sm, sizeof(struct sae_sm)); + + l_free(sm); +} + +struct auth_proto *sae_sm_new(struct handshake_state *hs, + sae_tx_authenticate_func_t tx_auth, + sae_tx_associate_func_t tx_assoc, + void *user_data) { struct sae_sm *sm; @@ -1076,8 +1113,8 @@ struct sae_sm *sae_sm_new(struct handshake_state *hs, sae_tx_packet_func_t tx, if (!sm) return NULL; - sm->tx = tx; - sm->complete = complete; + sm->tx_auth = tx_auth; + sm->tx_assoc = tx_assoc; sm->user_data = user_data; sm->handshake = hs; sm->state = SAE_STATE_NOTHING; @@ -1085,15 +1122,11 @@ struct sae_sm *sae_sm_new(struct handshake_state *hs, sae_tx_packet_func_t tx, sm->group = sm->ecc_groups[sm->group_retry]; sm->curve = l_ecc_curve_get_ike_group(sm->group); - return sm; -} - -void sae_sm_free(struct sae_sm *sm) -{ - sae_reset_state(sm); - - /* zero out whole structure, including keys */ - memset(sm, 0, sizeof(struct sae_sm)); - - l_free(sm); + sm->ap.start = sae_start; + sm->ap.free = sae_free; + sm->ap.rx_authenticate = sae_rx_authenticate; + sm->ap.rx_associate = sae_rx_associate; + sm->ap.auth_timeout = sae_timeout; + + return &sm->ap; } diff --git a/src/sae.h b/src/sae.h index f6c42f61..c56092ee 100644 --- a/src/sae.h +++ b/src/sae.h @@ -23,17 +23,12 @@ struct sae_sm; struct handshake_state; -typedef int (*sae_tx_packet_func_t)(const uint8_t *dest, const uint8_t *frame, - size_t len, void *user_data); +typedef void (*sae_tx_authenticate_func_t)(const uint8_t *data, size_t len, + void *user_data); +typedef void (*sae_tx_associate_func_t)(void *user_data); -typedef void (*sae_complete_func_t)(uint16_t status, void *user_data); +struct auth_proto *sae_sm_new(struct handshake_state *hs, + sae_tx_authenticate_func_t tx_auth, + sae_tx_associate_func_t tx_assoc, + void *user_data); -struct sae_sm *sae_sm_new(struct handshake_state *hs, sae_tx_packet_func_t tx, - sae_complete_func_t complete, void *user_data); -void sae_sm_free(struct sae_sm *sm); - -void sae_rx_packet(struct sae_sm *sm, const uint8_t *src, - const uint8_t *frame, size_t len); -void sae_timeout(struct sae_sm *sm); - -void sae_start(struct sae_sm *sm);