diff --git a/src/device.c b/src/device.c index e412168f..7899d9a4 100644 --- a/src/device.c +++ b/src/device.c @@ -649,7 +649,7 @@ static struct handshake_state *device_handshake_setup(struct device *device, struct handshake_state *hs; bool add_mde = false; - hs = handshake_state_new(netdev_get_ifindex(device->netdev)); + hs = netdev_handshake_state_new(device->netdev); if (security == SECURITY_PSK || security == SECURITY_8021X) { const struct l_settings *settings = iwd_get_config(); diff --git a/src/handshake.c b/src/handshake.c index b475bb8c..2dd58a26 100644 --- a/src/handshake.c +++ b/src/handshake.c @@ -70,25 +70,19 @@ void __handshake_set_install_igtk_func(handshake_install_igtk_func_t func) install_igtk = func; } -struct handshake_state *handshake_state_new(uint32_t ifindex) -{ - struct handshake_state *s; - - s = l_new(struct handshake_state, 1); - - s->ifindex = ifindex; - - return s; -} - void handshake_state_free(struct handshake_state *s) { + typeof(s->free) destroy = s->free; + l_free(s->ap_ie); l_free(s->own_ie); l_free(s->mde); l_free(s->fte); - l_free(s); + memset(s, 0, sizeof(*s)); + + if (destroy) + destroy(s); } void handshake_state_set_supplicant_address(struct handshake_state *s, @@ -342,7 +336,7 @@ void handshake_state_install_ptk(struct handshake_state *s) uint32_t cipher = ie_rsn_cipher_suite_to_cipher( s->pairwise_cipher); - install_tk(s->ifindex, s->aa, ptk->tk, cipher, s->user_data); + install_tk(s, ptk->tk, cipher); } } @@ -355,8 +349,8 @@ void handshake_state_install_gtk(struct handshake_state *s, uint32_t cipher = ie_rsn_cipher_suite_to_cipher(s->group_cipher); - install_gtk(s->ifindex, gtk_key_index, gtk, gtk_len, - rsc, rsc_len, cipher, s->user_data); + install_gtk(s, gtk_key_index, gtk, gtk_len, + rsc, rsc_len, cipher); } } @@ -370,8 +364,8 @@ void handshake_state_install_igtk(struct handshake_state *s, ie_rsn_cipher_suite_to_cipher( s->group_management_cipher); - install_igtk(s->ifindex, igtk_key_index, igtk, igtk_len, - ipn, 6, cipher, s->user_data); + install_igtk(s, igtk_key_index, igtk, igtk_len, + ipn, 6, cipher); } } diff --git a/src/handshake.h b/src/handshake.h index bd10a7a6..94bdec20 100644 --- a/src/handshake.h +++ b/src/handshake.h @@ -25,6 +25,8 @@ #include #include +struct handshake_state; + /* 802.11-2016 Table 12-6 in section 12.7.2 */ enum handshake_kde { HANDSHAKE_KDE_GTK = 0x000fac01, @@ -41,19 +43,18 @@ enum handshake_kde { }; typedef bool (*handshake_get_nonce_func_t)(uint8_t nonce[]); -typedef void (*handshake_install_tk_func_t)(uint32_t ifindex, const uint8_t *aa, - const uint8_t *tk, uint32_t cipher, - void *user_data); -typedef void (*handshake_install_gtk_func_t)(uint32_t ifindex, +typedef void (*handshake_install_tk_func_t)(struct handshake_state *hs, + const uint8_t *tk, uint32_t cipher); +typedef void (*handshake_install_gtk_func_t)(struct handshake_state *hs, uint8_t key_index, const uint8_t *gtk, uint8_t gtk_len, const uint8_t *rsc, uint8_t rsc_len, - uint32_t cipher, void *user_data); -typedef void (*handshake_install_igtk_func_t)(uint32_t ifindex, + uint32_t cipher); +typedef void (*handshake_install_igtk_func_t)(struct handshake_state *hs, uint8_t key_index, const uint8_t *igtk, uint8_t igtk_len, const uint8_t *ipn, uint8_t ipn_len, - uint32_t cipher, void *user_data); + uint32_t cipher); void __handshake_set_get_nonce_func(handshake_get_nonce_func_t func); void __handshake_set_install_tk_func(handshake_install_tk_func_t func); @@ -93,9 +94,10 @@ struct handshake_state { size_t r0khid_len; uint8_t r1khid[6]; void *user_data; + + void (*free)(struct handshake_state *s); }; -struct handshake_state *handshake_state_new(uint32_t ifindex); void handshake_state_free(struct handshake_state *s); void handshake_state_set_supplicant_address(struct handshake_state *s, diff --git a/src/netdev.c b/src/netdev.c index 38f16e97..0fa4a1d0 100644 --- a/src/netdev.c +++ b/src/netdev.c @@ -59,6 +59,15 @@ #define ENOTSUPP 524 #endif +struct netdev_handshake_state { + struct handshake_state super; + uint32_t pairwise_new_key_cmd_id; + uint32_t group_new_key_cmd_id; + uint32_t group_management_new_key_cmd_id; + uint32_t set_station_cmd_id; + struct netdev *netdev; +}; + struct netdev { uint32_t index; char name[IFNAMSIZ]; @@ -76,10 +85,6 @@ struct netdev { void *user_data; struct eapol_sm *sm; struct handshake_state *handshake; - uint32_t pairwise_new_key_cmd_id; - uint32_t group_new_key_cmd_id; - uint32_t group_management_new_key_cmd_id; - uint32_t set_station_cmd_id; uint32_t connect_cmd_id; uint32_t disconnect_cmd_id; enum netdev_result result; @@ -140,6 +145,49 @@ static void do_debug(const char *str, void *user_data) l_info("%s%s", prefix, str); } +static void netdev_handshake_state_free(struct handshake_state *hs) +{ + struct netdev_handshake_state *nhs = + container_of(hs, struct netdev_handshake_state, super); + + if (nhs->pairwise_new_key_cmd_id) { + l_genl_family_cancel(nl80211, nhs->pairwise_new_key_cmd_id); + nhs->pairwise_new_key_cmd_id = 0; + } + + if (nhs->group_new_key_cmd_id) { + l_genl_family_cancel(nl80211, nhs->group_new_key_cmd_id); + nhs->group_new_key_cmd_id = 0; + } + + if (nhs->group_management_new_key_cmd_id) { + l_genl_family_cancel(nl80211, + nhs->group_management_new_key_cmd_id); + nhs->group_management_new_key_cmd_id = 0; + } + + if (nhs->set_station_cmd_id) { + l_genl_family_cancel(nl80211, nhs->set_station_cmd_id); + nhs->set_station_cmd_id = 0; + } + + l_free(nhs); +} + +struct handshake_state *netdev_handshake_state_new(struct netdev *netdev) +{ + struct netdev_handshake_state *nhs; + + nhs = l_new(struct netdev_handshake_state, 1); + + nhs->super.ifindex = netdev->index; + nhs->super.free = netdev_handshake_state_free; + + nhs->netdev = netdev; + + return &nhs->super; +} + struct cb_data { netdev_command_func_t callback; void *user_data; @@ -472,27 +520,6 @@ static void netdev_connect_free(struct netdev *netdev) netdev_rssi_polling_update(netdev); - if (netdev->pairwise_new_key_cmd_id) { - l_genl_family_cancel(nl80211, netdev->pairwise_new_key_cmd_id); - netdev->pairwise_new_key_cmd_id = 0; - } - - if (netdev->group_new_key_cmd_id) { - l_genl_family_cancel(nl80211, netdev->group_new_key_cmd_id); - netdev->group_new_key_cmd_id = 0; - } - - if (netdev->group_management_new_key_cmd_id) { - l_genl_family_cancel(nl80211, - netdev->group_management_new_key_cmd_id); - netdev->group_management_new_key_cmd_id = 0; - } - - if (netdev->set_station_cmd_id) { - l_genl_family_cancel(nl80211, netdev->set_station_cmd_id); - netdev->set_station_cmd_id = 0; - } - if (netdev->connect_cmd_id) { l_genl_family_cancel(nl80211, netdev->connect_cmd_id); netdev->connect_cmd_id = 0; @@ -888,9 +915,10 @@ static void netdev_connect_ok(struct netdev *netdev) netdev_rssi_polling_update(netdev); } -static void netdev_setting_keys_failed(struct netdev *netdev, +static void netdev_setting_keys_failed(struct netdev_handshake_state *nhs, uint16_t reason_code) { + struct netdev *netdev = nhs->netdev; struct l_genl_msg *msg; /* @@ -902,17 +930,17 @@ static void netdev_setting_keys_failed(struct netdev *netdev, * * Cancel all pending commands, then de-authenticate */ - l_genl_family_cancel(nl80211, netdev->pairwise_new_key_cmd_id); - netdev->pairwise_new_key_cmd_id = 0; + l_genl_family_cancel(nl80211, nhs->pairwise_new_key_cmd_id); + nhs->pairwise_new_key_cmd_id = 0; - l_genl_family_cancel(nl80211, netdev->group_new_key_cmd_id); - netdev->group_new_key_cmd_id = 0; + l_genl_family_cancel(nl80211, nhs->group_new_key_cmd_id); + nhs->group_new_key_cmd_id = 0; - l_genl_family_cancel(nl80211, netdev->group_management_new_key_cmd_id); - netdev->group_management_new_key_cmd_id = 0; + l_genl_family_cancel(nl80211, nhs->group_management_new_key_cmd_id); + nhs->group_management_new_key_cmd_id = 0; - l_genl_family_cancel(nl80211, netdev->set_station_cmd_id); - netdev->set_station_cmd_id = 0; + l_genl_family_cancel(nl80211, nhs->set_station_cmd_id); + nhs->set_station_cmd_id = 0; netdev->result = NETDEV_RESULT_KEY_SETTING_FAILED; msg = netdev_build_cmd_disconnect(netdev, @@ -924,10 +952,11 @@ static void netdev_setting_keys_failed(struct netdev *netdev, static void netdev_set_station_cb(struct l_genl_msg *msg, void *user_data) { - struct netdev *netdev = user_data; + struct netdev_handshake_state *nhs = user_data; + struct netdev *netdev = nhs->netdev; int err; - netdev->set_station_cmd_id = 0; + nhs->set_station_cmd_id = 0; if (!netdev->connected) return; @@ -938,7 +967,7 @@ static void netdev_set_station_cb(struct l_genl_msg *msg, void *user_data) if (err < 0) { l_error("Set Station failed for ifindex %d", netdev->index); - netdev_setting_keys_failed(netdev, + netdev_setting_keys_failed(nhs, MMPDU_REASON_CODE_UNSPECIFIED); return; } @@ -968,29 +997,30 @@ static struct l_genl_msg *netdev_build_cmd_set_station(struct netdev *netdev, static void netdev_new_group_key_cb(struct l_genl_msg *msg, void *data) { - struct netdev *netdev = data; + struct netdev_handshake_state *nhs = data; + struct netdev *netdev = nhs->netdev; - netdev->group_new_key_cmd_id = 0; + nhs->group_new_key_cmd_id = 0; if (l_genl_msg_get_error(msg) >= 0) return; l_error("New Key for Group Key failed for ifindex: %d", netdev->index); - netdev_setting_keys_failed(netdev, MMPDU_REASON_CODE_UNSPECIFIED); + netdev_setting_keys_failed(nhs, MMPDU_REASON_CODE_UNSPECIFIED); } static void netdev_new_group_management_key_cb(struct l_genl_msg *msg, void *data) { - struct netdev *netdev = data; + struct netdev_handshake_state *nhs = data; + struct netdev *netdev = nhs->netdev; - netdev->group_management_new_key_cmd_id = 0; + nhs->group_management_new_key_cmd_id = 0; if (l_genl_msg_get_error(msg) < 0) { l_error("New Key for Group Mgmt failed for ifindex: %d", netdev->index); - netdev_setting_keys_failed(netdev, - MMPDU_REASON_CODE_UNSPECIFIED); + netdev_setting_keys_failed(nhs, MMPDU_REASON_CODE_UNSPECIFIED); } } @@ -1068,28 +1098,28 @@ static bool netdev_copy_tk(uint8_t *tk_buf, const uint8_t *tk, return true; } -static void netdev_set_gtk(uint32_t ifindex, uint8_t key_index, +static void netdev_set_gtk(struct handshake_state *hs, uint8_t key_index, const uint8_t *gtk, uint8_t gtk_len, const uint8_t *rsc, uint8_t rsc_len, - uint32_t cipher, void *user_data) + uint32_t cipher) { + struct netdev_handshake_state *nhs = + container_of(hs, struct netdev_handshake_state, super); + struct netdev *netdev = nhs->netdev; uint8_t gtk_buf[32]; - struct netdev *netdev; struct l_genl_msg *msg; - netdev = netdev_find(ifindex); - l_debug("%d", netdev->index); if (crypto_cipher_key_len(cipher) != gtk_len) { l_error("Unexpected key length: %d", gtk_len); - netdev_setting_keys_failed(netdev, + netdev_setting_keys_failed(nhs, MMPDU_REASON_CODE_INVALID_GROUP_CIPHER); return; } if (!netdev_copy_tk(gtk_buf, gtk, cipher, false)) { - netdev_setting_keys_failed(netdev, + netdev_setting_keys_failed(nhs, MMPDU_REASON_CODE_INVALID_GROUP_CIPHER); return; } @@ -1097,33 +1127,33 @@ static void netdev_set_gtk(uint32_t ifindex, uint8_t key_index, msg = netdev_build_cmd_new_key_group(netdev, cipher, key_index, gtk_buf, gtk_len, rsc, rsc_len); - netdev->group_new_key_cmd_id = + nhs->group_new_key_cmd_id = l_genl_family_send(nl80211, msg, netdev_new_group_key_cb, - netdev, NULL); + nhs, NULL); - if (netdev->group_new_key_cmd_id > 0) + if (nhs->group_new_key_cmd_id > 0) return; l_genl_msg_unref(msg); - netdev_setting_keys_failed(netdev, MMPDU_REASON_CODE_UNSPECIFIED); + netdev_setting_keys_failed(nhs, MMPDU_REASON_CODE_UNSPECIFIED); } -static void netdev_set_igtk(uint32_t ifindex, uint8_t key_index, +static void netdev_set_igtk(struct handshake_state *hs, uint8_t key_index, const uint8_t *igtk, uint8_t igtk_len, const uint8_t *ipn, uint8_t ipn_len, - uint32_t cipher, void *user_data) + uint32_t cipher) { + struct netdev_handshake_state *nhs = + container_of(hs, struct netdev_handshake_state, super); uint8_t igtk_buf[16]; - struct netdev *netdev; + struct netdev *netdev = nhs->netdev; struct l_genl_msg *msg; - netdev = netdev_find(ifindex); - l_debug("%d", netdev->index); if (crypto_cipher_key_len(cipher) != igtk_len) { l_error("Unexpected key length: %d", igtk_len); - netdev_setting_keys_failed(netdev, + netdev_setting_keys_failed(nhs, MMPDU_REASON_CODE_INVALID_GROUP_CIPHER); return; } @@ -1134,7 +1164,7 @@ static void netdev_set_igtk(uint32_t ifindex, uint8_t key_index, break; default: l_error("Unexpected cipher: %x", cipher); - netdev_setting_keys_failed(netdev, + netdev_setting_keys_failed(nhs, MMPDU_REASON_CODE_INVALID_GROUP_CIPHER); return; } @@ -1142,23 +1172,24 @@ static void netdev_set_igtk(uint32_t ifindex, uint8_t key_index, msg = netdev_build_cmd_new_key_group(netdev, cipher, key_index, igtk_buf, igtk_len, ipn, ipn_len); - netdev->group_management_new_key_cmd_id = + nhs->group_management_new_key_cmd_id = l_genl_family_send(nl80211, msg, netdev_new_group_management_key_cb, - netdev, NULL); + nhs, NULL); - if (netdev->group_management_new_key_cmd_id > 0) + if (nhs->group_management_new_key_cmd_id > 0) return; l_genl_msg_unref(msg); - netdev_setting_keys_failed(netdev, MMPDU_REASON_CODE_UNSPECIFIED); + netdev_setting_keys_failed(nhs, MMPDU_REASON_CODE_UNSPECIFIED); } static void netdev_new_pairwise_key_cb(struct l_genl_msg *msg, void *data) { - struct netdev *netdev = data; + struct netdev_handshake_state *nhs = data; + struct netdev *netdev = nhs->netdev; - netdev->pairwise_new_key_cmd_id = 0; + nhs->pairwise_new_key_cmd_id = 0; if (l_genl_msg_get_error(msg) < 0) { l_error("New Key for Pairwise Key failed for ifindex: %d", @@ -1173,15 +1204,15 @@ static void netdev_new_pairwise_key_cb(struct l_genl_msg *msg, void *data) */ msg = netdev_build_cmd_set_station(netdev, netdev->handshake->aa); - netdev->set_station_cmd_id = + nhs->set_station_cmd_id = l_genl_family_send(nl80211, msg, netdev_set_station_cb, - netdev, NULL); - if (netdev->set_station_cmd_id > 0) + nhs, NULL); + if (nhs->set_station_cmd_id > 0) return; l_genl_msg_unref(msg); error: - netdev_setting_keys_failed(netdev, MMPDU_REASON_CODE_UNSPECIFIED); + netdev_setting_keys_failed(nhs, MMPDU_REASON_CODE_UNSPECIFIED); } static struct l_genl_msg *netdev_build_cmd_new_key_pairwise( @@ -1205,19 +1236,16 @@ static struct l_genl_msg *netdev_build_cmd_new_key_pairwise( return msg; } -static void netdev_set_tk(uint32_t ifindex, const uint8_t *aa, - const uint8_t *tk, uint32_t cipher, - void *user_data) +static void netdev_set_tk(struct handshake_state *hs, + const uint8_t *tk, uint32_t cipher) { + struct netdev_handshake_state *nhs = + container_of(hs, struct netdev_handshake_state, super); uint8_t tk_buf[32]; - struct netdev *netdev; + struct netdev *netdev = nhs->netdev; struct l_genl_msg *msg; enum mmpdu_reason_code rc; - netdev = netdev_find(ifindex); - if (!netdev) - return; - l_debug("%d", netdev->index); if (netdev->event_filter) @@ -1229,17 +1257,17 @@ static void netdev_set_tk(uint32_t ifindex, const uint8_t *aa, goto invalid_key; rc = MMPDU_REASON_CODE_UNSPECIFIED; - msg = netdev_build_cmd_new_key_pairwise(netdev, cipher, aa, tk_buf, + msg = netdev_build_cmd_new_key_pairwise(netdev, cipher, hs->aa, tk_buf, crypto_cipher_key_len(cipher)); - netdev->pairwise_new_key_cmd_id = + nhs->pairwise_new_key_cmd_id = l_genl_family_send(nl80211, msg, netdev_new_pairwise_key_cb, - netdev, NULL); - if (netdev->pairwise_new_key_cmd_id > 0) + nhs, NULL); + if (nhs->pairwise_new_key_cmd_id > 0) return; l_genl_msg_unref(msg); invalid_key: - netdev_setting_keys_failed(netdev, rc); + netdev_setting_keys_failed(nhs, rc); } static void netdev_handshake_failed(uint32_t ifindex, @@ -2313,6 +2341,7 @@ int netdev_reassociate(struct netdev *netdev, struct scan_bss *target_bss, netdev_connect_cb_t cb, void *user_data) { struct l_genl_msg *cmd_connect; + struct netdev_handshake_state; struct handshake_state *old_hs; struct eapol_sm *sm = NULL, *old_sm; bool is_rsn = hs->own_ie != NULL; @@ -2340,22 +2369,6 @@ int netdev_reassociate(struct netdev *netdev, struct scan_bss *target_bss, netdev_rssi_polling_update(netdev); - /* - * Cancel commands that could be running because of EAPoL activity - * like re-keying, this way the callbacks for those commands don't - * have to check if failures resulted from the transition. - */ - if (netdev->group_new_key_cmd_id) { - l_genl_family_cancel(nl80211, netdev->group_new_key_cmd_id); - netdev->group_new_key_cmd_id = 0; - } - - if (netdev->group_management_new_key_cmd_id) { - l_genl_family_cancel(nl80211, - netdev->group_management_new_key_cmd_id); - netdev->group_management_new_key_cmd_id = 0; - } - if (old_sm) eapol_sm_free(old_sm); @@ -2486,6 +2499,7 @@ int netdev_fast_transition(struct netdev *netdev, struct scan_bss *target_bss, netdev_connect_cb_t cb) { struct l_genl_msg *cmd_authenticate; + struct netdev_handshake_state *nhs; uint8_t orig_snonce[32]; int err; @@ -2547,15 +2561,18 @@ int netdev_fast_transition(struct netdev *netdev, struct scan_bss *target_bss, * like re-keying, this way the callbacks for those commands don't * have to check if failures resulted from the transition. */ - if (netdev->group_new_key_cmd_id) { - l_genl_family_cancel(nl80211, netdev->group_new_key_cmd_id); - netdev->group_new_key_cmd_id = 0; + nhs = container_of(netdev->handshake, + struct netdev_handshake_state, super); + + if (nhs->group_new_key_cmd_id) { + l_genl_family_cancel(nl80211, nhs->group_new_key_cmd_id); + nhs->group_new_key_cmd_id = 0; } - if (netdev->group_management_new_key_cmd_id) { + if (nhs->group_management_new_key_cmd_id) { l_genl_family_cancel(nl80211, - netdev->group_management_new_key_cmd_id); - netdev->group_management_new_key_cmd_id = 0; + nhs->group_management_new_key_cmd_id); + nhs->group_management_new_key_cmd_id = 0; } netdev_rssi_polling_update(netdev); diff --git a/src/netdev.h b/src/netdev.h index a8399d2f..1f1ee45c 100644 --- a/src/netdev.h +++ b/src/netdev.h @@ -104,6 +104,8 @@ int netdev_set_4addr(struct netdev *netdev, bool use_4addr, bool netdev_get_4addr(struct netdev *netdev); const char *netdev_get_name(struct netdev *netdev); bool netdev_get_is_up(struct netdev *netdev); + +struct handshake_state *netdev_handshake_state_new(struct netdev *netdev); struct handshake_state *netdev_get_handshake(struct netdev *netdev); int netdev_connect(struct netdev *netdev, struct scan_bss *bss, diff --git a/src/wsc.c b/src/wsc.c index ea45d98d..99524a9f 100644 --- a/src/wsc.c +++ b/src/wsc.c @@ -408,11 +408,10 @@ static void wsc_connect(struct wsc *wsc) struct handshake_state *hs; struct l_settings *settings = l_settings_new(); struct scan_bss *bss = wsc->target; - uint32_t ifindex = netdev_get_ifindex(device_get_netdev(wsc->device)); wsc->target = NULL; - hs = handshake_state_new(ifindex); + hs = netdev_handshake_state_new(device_get_netdev(wsc->device)); l_settings_set_string(settings, "Security", "EAP-Identity", "WFA-SimpleConfig-Enrollee-1-0");