diff --git a/src/dpp.c b/src/dpp.c index cedf5bfe..6379f1fd 100644 --- a/src/dpp.c +++ b/src/dpp.c @@ -61,6 +61,9 @@ static struct l_genl_family *nl80211; static uint8_t broadcast[] = { 0xff, 0xff, 0xff, 0xff, 0xff, 0xff }; static struct l_queue *dpp_list; static uint32_t mlme_watch; +static uint32_t unicast_watch; + +static uint8_t dpp_prefix[] = { 0x04, 0x09, 0x50, 0x6f, 0x9a, 0x1a, 0x01 }; enum dpp_state { DPP_STATE_NOTHING, @@ -127,6 +130,8 @@ struct dpp_sm { uint8_t frame_retry; struct l_dbus_message *pending; + + bool mcast_support : 1; }; static void dpp_free_auth_data(struct dpp_sm *dpp) @@ -1403,6 +1408,9 @@ static void authenticate_request(struct dpp_sm *dpp, const uint8_t *from, const void *ad1 = body + 8; uint32_t freq; + if (util_is_broadcast_address(from)) + return; + if (dpp->state != DPP_STATE_PRESENCE) return; @@ -1577,11 +1585,10 @@ auth_request_failed: dpp_free_auth_data(dpp); } -static void dpp_handle_frame(const struct mmpdu_header *frame, - const void *body, size_t body_len, - int rssi, void *user_data) +static void dpp_handle_frame(struct dpp_sm *dpp, + const struct mmpdu_header *frame, + const void *body, size_t body_len) { - struct dpp_sm *dpp = user_data; const uint8_t *ptr; /* @@ -1664,11 +1671,103 @@ static void dpp_mlme_notify(struct l_genl_msg *msg, void *user_data) dpp_send_frame(dpp, &iov, 1, dpp->current_freq); } +static void dpp_unicast_notify(struct l_genl_msg *msg, void *user_data) +{ + struct dpp_sm *dpp; + const uint64_t *wdev_id = NULL; + struct l_genl_attr attr; + uint16_t type, len, frame_len; + const void *data; + const struct mmpdu_header *mpdu = NULL; + const uint8_t *body; + size_t body_len; + + if (l_genl_msg_get_command(msg) != NL80211_CMD_FRAME) + return; + + if (!l_genl_attr_init(&attr, msg)) + return; + + while (l_genl_attr_next(&attr, &type, &len, &data)) { + switch (type) { + case NL80211_ATTR_WDEV: + if (len != 8) + break; + + wdev_id = data; + break; + + case NL80211_ATTR_FRAME: + mpdu = mpdu_validate(data, len); + if (!mpdu) { + l_warn("Frame didn't validate as MMPDU"); + return; + } + + frame_len = len; + break; + } + } + + if (!wdev_id) { + l_warn("Bad wdev attribute"); + return; + } + + dpp = l_queue_find(dpp_list, match_wdev, wdev_id); + if (!dpp) + return; + + if (!mpdu) { + l_warn("Missing frame data"); + return; + } + + body = mmpdu_body(mpdu); + body_len = (const uint8_t *) mpdu + frame_len - body; + + if (body_len < sizeof(dpp_prefix) || + memcmp(body, dpp_prefix, sizeof(dpp_prefix)) != 0) + return; + + dpp_handle_frame(dpp, mpdu, body, body_len); +} + +static void dpp_frame_watch_cb(struct l_genl_msg *msg, void *user_data) +{ + if (l_genl_msg_get_error(msg) < 0) + l_error("Could not register frame watch type %04x: %i", + L_PTR_TO_UINT(user_data), l_genl_msg_get_error(msg)); +} + +/* + * Special case the frame watch which includes the presence frames since they + * require multicast support. This is only supported by ath9k, so adding + * general support to frame-xchg isn't desireable. + */ +static void dpp_frame_watch(struct dpp_sm *dpp, uint16_t frame_type, + const uint8_t *prefix, size_t prefix_len) +{ + struct l_genl_msg *msg; + + msg = l_genl_msg_new_sized(NL80211_CMD_REGISTER_FRAME, 32 + prefix_len); + + l_genl_msg_append_attr(msg, NL80211_ATTR_WDEV, 8, &dpp->wdev_id); + l_genl_msg_append_attr(msg, NL80211_ATTR_FRAME_TYPE, 2, &frame_type); + l_genl_msg_append_attr(msg, NL80211_ATTR_FRAME_MATCH, + prefix_len, prefix); + if (dpp->mcast_support) + l_genl_msg_append_attr(msg, NL80211_ATTR_RECEIVE_MULTICAST, + 0, NULL); + + l_genl_family_send(nl80211, msg, dpp_frame_watch_cb, + L_UINT_TO_PTR(frame_type), NULL); +} + static void dpp_create(struct netdev *netdev) { struct l_dbus *dbus = dbus_get_bus(); struct dpp_sm *dpp = l_new(struct dpp_sm, 1); - uint8_t dpp_prefix[] = { 0x04, 0x09, 0x50, 0x6f, 0x9a, 0x1a, 0x01 }; uint8_t dpp_conf_response_prefix[] = { 0x04, 0x0b }; uint8_t dpp_conf_request_prefix[] = { 0x04, 0x0a }; @@ -1678,6 +1777,9 @@ static void dpp_create(struct netdev *netdev) dpp->curve = l_ecc_curve_from_ike_group(19); dpp->key_len = l_ecc_curve_get_scalar_bytes(dpp->curve); dpp->nonce_len = dpp_nonce_len_from_key_len(dpp->key_len); + dpp->mcast_support = wiphy_has_ext_feature( + wiphy_find_by_wdev(dpp->wdev_id), + NL80211_EXT_FEATURE_MULTICAST_REGISTRATIONS); l_ecdh_generate_key_pair(dpp->curve, &dpp->boot_private, &dpp->boot_public); @@ -1690,9 +1792,8 @@ static void dpp_create(struct netdev *netdev) l_dbus_object_add_interface(dbus, netdev_get_path(netdev), IWD_DPP_INTERFACE, dpp); - frame_watch_add(netdev_get_wdev_id(netdev), 0, 0x00d0, dpp_prefix, - sizeof(dpp_prefix), dpp_handle_frame, - dpp, NULL); + dpp_frame_watch(dpp, 0x00d0, dpp_prefix, sizeof(dpp_prefix)); + frame_watch_add(netdev_get_wdev_id(netdev), 0, 0x00d0, dpp_conf_response_prefix, sizeof(dpp_conf_response_prefix), @@ -1935,6 +2036,11 @@ static int dpp_init(void) mlme_watch = l_genl_family_register(nl80211, "mlme", dpp_mlme_notify, NULL, NULL); + unicast_watch = l_genl_add_unicast_watch(iwd_get_genl(), + NL80211_GENL_NAME, + dpp_unicast_notify, + NULL, NULL); + dpp_list = l_queue_new(); return 0; @@ -1948,6 +2054,8 @@ static void dpp_exit(void) netdev_watch_remove(netdev_watch); + l_genl_remove_unicast_watch(iwd_get_genl(), unicast_watch); + l_genl_family_unregister(nl80211, mlme_watch); mlme_watch = 0;