diff --git a/src/netdev.c b/src/netdev.c index d9d64932..1e64b943 100644 --- a/src/netdev.c +++ b/src/netdev.c @@ -83,6 +83,8 @@ struct netdev { uint32_t disconnect_cmd_id; enum netdev_result result; struct l_timeout *neighbor_report_timeout; + struct l_timeout *sa_query_timeout; + uint16_t sa_query_id; uint8_t prev_bssid[ETH_ALEN]; int8_t rssi_levels[16]; uint8_t rssi_levels_num; @@ -446,6 +448,11 @@ static void netdev_connect_free(struct netdev *netdev) l_timeout_remove(netdev->neighbor_report_timeout); } + if (netdev->sa_query_timeout) { + l_timeout_remove(netdev->sa_query_timeout); + netdev->sa_query_timeout = NULL; + } + netdev->operational = false; netdev->connected = false; netdev->connect_cb = NULL; @@ -2758,6 +2765,143 @@ static void netdev_neighbor_report_frame_event(struct netdev *netdev, l_timeout_remove(netdev->neighbor_report_timeout); } +static void netdev_sa_query_resp_frame_event(struct netdev *netdev, + const struct mmpdu_header *hdr, + const void *body, size_t body_len, + void *user_data) +{ + if (body_len < 4) { + l_debug("SA Query frame too short"); + return; + } + + l_debug("SA Query src="MAC" dest="MAC" bssid="MAC" transaction=%u", + MAC_STR(hdr->address_2), MAC_STR(hdr->address_1), + MAC_STR(hdr->address_3), l_get_u16(body + 2)); + + if (!netdev->sa_query_timeout) { + l_debug("no SA Query request sent"); + return; + } + + /* check if this is from our connected BSS */ + if (memcmp(hdr->address_2, netdev->handshake->aa, 6)) { + l_debug("received SA Query from non-connected AP"); + return; + } + + if (memcmp(body + 2, &netdev->sa_query_id, 2)) { + l_debug("SA Query transaction ID's did not match"); + return; + } + + l_info("SA Query response from connected BSS received, " + "keeping the connection active"); + + l_timeout_remove(netdev->sa_query_timeout); + netdev->sa_query_timeout = NULL; +} + +static void netdev_sa_query_req_cb(struct l_genl_msg *msg, + void *user_data) +{ + struct netdev *netdev = user_data; + + if (l_genl_msg_get_error(msg) < 0) { + l_debug("error sending SA Query request"); + + l_timeout_remove(netdev->sa_query_timeout); + netdev->sa_query_timeout = NULL; + } +} + +static void netdev_sa_query_timeout(struct l_timeout *timeout, + void *user_data) +{ + struct netdev *netdev = user_data; + struct l_genl_msg *msg; + + l_info("SA Query timed out, connection is invalid. Disconnecting..."); + + l_timeout_remove(netdev->sa_query_timeout); + netdev->sa_query_timeout = NULL; + + msg = netdev_build_cmd_disconnect(netdev, + MMPDU_REASON_CODE_PREV_AUTH_NOT_VALID); + netdev->disconnect_cmd_id = l_genl_family_send(nl80211, msg, + netdev_connect_failed, netdev, NULL); +} + +static void netdev_unprot_disconnect_event(struct l_genl_msg *msg, + struct netdev *netdev) +{ + const struct mmpdu_header *hdr = NULL; + struct l_genl_attr attr; + uint16_t type; + uint16_t len; + const void *data; + uint8_t action_frame[4]; + uint8_t reason_code; + + if (!netdev->connected) + return; + + /* ignore excessive disassociate requests */ + if (netdev->sa_query_timeout) + return; + + if (!l_genl_attr_init(&attr, msg)) + return; + + while (l_genl_attr_next(&attr, &type, &len, &data)) { + switch (type) { + case NL80211_ATTR_FRAME: + hdr = mpdu_validate(data, len); + break; + } + } + + /* check that ATTR_FRAME was actually included */ + if (!hdr) + return; + + /* get reason code, first byte of frame */ + reason_code = l_get_u8(mmpdu_body(hdr)); + + l_info("disconnect event, src="MAC" dest="MAC" bssid="MAC" reason=%u", + MAC_STR(hdr->address_2), MAC_STR(hdr->address_1), + MAC_STR(hdr->address_3), reason_code); + + if (memcmp(hdr->address_2, netdev->handshake->aa, 6)) { + l_debug("received invalid disassociate frame"); + return; + } + + if (reason_code != MMPDU_REASON_CODE_CLASS2_FRAME_FROM_NONAUTH_STA && + reason_code != + MMPDU_REASON_CODE_CLASS3_FRAME_FROM_NONASSOC_STA) { + l_debug("invalid reason code %u", reason_code); + return; + } + + action_frame[0] = 0x08; /* Category: SA Query */ + action_frame[1] = 0x00; /* SA Query Action: Request */ + + /* Transaction ID */ + l_getrandom(action_frame + 2, 2); + + if (!netdev_send_action_frame(netdev, netdev->handshake->aa, + action_frame, sizeof(action_frame), + netdev_sa_query_req_cb)) { + l_error("error sending SA Query action frame"); + return; + } + + netdev->sa_query_id = l_get_u16(action_frame + 2); + netdev->sa_query_timeout = l_timeout_create(3, + netdev_sa_query_timeout, netdev, NULL); +} + static void netdev_mlme_notify(struct l_genl_msg *msg, void *user_data) { struct netdev *netdev = NULL; @@ -2813,6 +2957,10 @@ static void netdev_mlme_notify(struct l_genl_msg *msg, void *user_data) case NL80211_CMD_SET_REKEY_OFFLOAD: netdev_rekey_offload_event(msg, netdev); break; + case NL80211_CMD_UNPROT_DEAUTHENTICATE: + case NL80211_CMD_UNPROT_DISASSOCIATE: + netdev_unprot_disconnect_event(msg, netdev); + break; } } @@ -3305,6 +3453,7 @@ static void netdev_create_from_genl(struct l_genl_msg *msg) struct ifinfomsg *rtmmsg; size_t bufsize; const uint8_t action_neighbor_report_prefix[2] = { 0x05, 0x05 }; + const uint8_t action_sa_query_resp_prefix[2] = { 0x08, 0x01 }; if (!l_genl_attr_init(&attr, msg)) return; @@ -3417,6 +3566,10 @@ static void netdev_create_from_genl(struct l_genl_msg *msg) sizeof(action_neighbor_report_prefix), netdev_neighbor_report_frame_event, NULL); + netdev_frame_watch_add(netdev, 0x00d0, action_sa_query_resp_prefix, + sizeof(action_sa_query_resp_prefix), + netdev_sa_query_resp_frame_event, NULL); + /* Set RSSI threshold for CQM notifications */ netdev_cqm_rssi_update(netdev); }