From 9c6643b773a8b1bd888c70ee84e0ca0f749c9833 Mon Sep 17 00:00:00 2001 From: Andrew Zaborowski Date: Mon, 12 Dec 2016 18:34:19 +0100 Subject: [PATCH] netdev: Always require handshake_state with netdev_connect --- src/device.c | 12 +++++------- src/netdev.c | 26 ++++++++++++-------------- 2 files changed, 17 insertions(+), 21 deletions(-) diff --git a/src/device.c b/src/device.c index b5b609ac..e15c9ba4 100644 --- a/src/device.c +++ b/src/device.c @@ -677,10 +677,12 @@ void device_connect_network(struct device *device, struct network *network, enum security security = network_get_security(network); struct wiphy *wiphy = device->wiphy; struct l_dbus *dbus = dbus_get_bus(); - struct handshake_state *hs = NULL; + struct handshake_state *hs; bool add_mde = false; uint8_t *mde; + hs = handshake_state_new(netdev_get_ifindex(device->netdev)); + if (security == SECURITY_PSK || security == SECURITY_8021X) { struct ie_rsn_info bss_info; uint8_t rsne_buf[256]; @@ -718,8 +720,6 @@ void device_connect_network(struct device *device, struct network *network, } else if (info.group_management_cipher != 0) info.mfpc = true; - hs = handshake_state_new(netdev_get_ifindex(device->netdev)); - ssid = network_get_ssid(network); handshake_state_set_ssid(hs, (void *) ssid, strlen(ssid)); @@ -758,8 +758,7 @@ void device_connect_network(struct device *device, struct network *network, mde[1] = 3; memcpy(mde + 2, bss->mde, 3); - if (hs) - handshake_state_set_mde(hs, mde); + handshake_state_set_mde(hs, mde); } else mde = NULL; @@ -768,8 +767,7 @@ void device_connect_network(struct device *device, struct network *network, if (netdev_connect(device->netdev, bss, hs, mde, device_netdev_event, device_connect_cb, device) < 0) { - if (hs) - handshake_state_free(hs); + handshake_state_free(hs); l_free(mde); diff --git a/src/netdev.c b/src/netdev.c index 85e65983..7c25b284 100644 --- a/src/netdev.c +++ b/src/netdev.c @@ -1067,6 +1067,7 @@ static void netdev_connect_event(struct l_genl_msg *msg, const uint16_t *status_code = NULL; const uint8_t *ies = NULL; size_t ies_len; + bool is_rsn = netdev->handshake->own_ie != NULL; l_debug(""); @@ -1139,10 +1140,10 @@ static void netdev_connect_event(struct l_genl_msg *msg, * in a non-RSN (12.4.2 vs. 12.4.3). */ - if (netdev->mde && !netdev->handshake && fte) + if (netdev->mde && !is_rsn && fte) goto error; - if (netdev->mde && netdev->handshake) { + if (netdev->mde && is_rsn) { struct ie_ft_info ft_info; uint8_t zeros[32]; @@ -1275,6 +1276,7 @@ static struct l_genl_msg *netdev_build_cmd_connect(struct netdev *netdev, struct l_genl_msg *msg; struct iovec iov[2]; int iov_elems = 0; + bool is_rsn = hs->own_ie != NULL; msg = l_genl_msg_new_sized(NL80211_CMD_CONNECT, 512); l_genl_msg_append_attr(msg, NL80211_ATTR_IFINDEX, 4, &netdev->index); @@ -1288,7 +1290,7 @@ static struct l_genl_msg *netdev_build_cmd_connect(struct netdev *netdev, if (bss->capability & IE_BSS_CAP_PRIVACY) l_genl_msg_append_attr(msg, NL80211_ATTR_PRIVACY, 0, NULL); - if (hs) { + if (is_rsn) { uint32_t nl_cipher; uint32_t nl_akm; uint32_t wpa_version; @@ -1330,11 +1332,9 @@ static struct l_genl_msg *netdev_build_cmd_connect(struct netdev *netdev, l_genl_msg_append_attr(msg, NL80211_ATTR_CONTROL_PORT, 0, NULL); - if (hs->own_ie) { - iov[iov_elems].iov_base = (void *) hs->own_ie; - iov[iov_elems].iov_len = hs->own_ie[1] + 2; - iov_elems += 1; - } + iov[iov_elems].iov_base = (void *) hs->own_ie; + iov[iov_elems].iov_len = hs->own_ie[1] + 2; + iov_elems += 1; } if (mde) { @@ -1377,10 +1377,8 @@ static int netdev_connect_common(struct netdev *netdev, if (mde) netdev->mde = l_memdup(mde, mde[1] + 2); - if (hs) { - handshake_state_set_authenticator_address(hs, bss->addr); - handshake_state_set_supplicant_address(hs, netdev->addr); - } + handshake_state_set_authenticator_address(hs, bss->addr); + handshake_state_set_supplicant_address(hs, netdev->addr); return 0; @@ -1393,7 +1391,7 @@ int netdev_connect(struct netdev *netdev, struct scan_bss *bss, { struct l_genl_msg *cmd_connect; struct eapol_sm *sm = NULL; - bool is_rsn = hs != NULL; + bool is_rsn = hs->own_ie != NULL; if (netdev->connected) return -EISCONN; @@ -1427,7 +1425,7 @@ int netdev_connect_wsc(struct netdev *netdev, struct scan_bss *bss, if (netdev->connected) return -EISCONN; - cmd_connect = netdev_build_cmd_connect(netdev, bss, NULL, NULL); + cmd_connect = netdev_build_cmd_connect(netdev, bss, hs, NULL); if (!cmd_connect) return -EINVAL;