diff --git a/src/dpp.c b/src/dpp.c index 06ae2929..a95f93e2 100644 --- a/src/dpp.c +++ b/src/dpp.c @@ -853,13 +853,42 @@ static bool dpp_scan_results(int err, struct l_queue *bss_list, { struct dpp_sm *dpp = userdata; struct station *station = station_find(netdev_get_ifindex(dpp->netdev)); + struct scan_bss *bss; + char ssid[33]; + struct network *network; if (err < 0) - return false; + goto reset; - station_set_scan_results(station, bss_list, freqs, true); + if (!bss_list || l_queue_length(bss_list) == 0) + goto reset; + + /* + * The station watch _should_ detect this and reset, which cancels the + * scan. But just in case... + */ + if (L_WARN_ON(station_get_connected_network(station))) + goto reset; + + /* Purely for grabbing the SSID */ + bss = l_queue_peek_head(bss_list); + + memcpy(ssid, bss->ssid, bss->ssid_len); + ssid[bss->ssid_len] = '\0'; + + station_set_scan_results(station, bss_list, freqs, false); + + network = station_network_find(station, ssid, SECURITY_PSK); + + dpp_reset(dpp); + + bss = network_bss_select(network, true); + network_autoconnect(network, bss); return true; + +reset: + return false; } static void dpp_scan_destroy(void *userdata) @@ -898,6 +927,7 @@ static void dpp_handle_config_response_frame(const struct mmpdu_header *frame, struct network *network = NULL; struct scan_bss *bss = NULL; char ssid[33]; + size_t ssid_len; if (dpp->state != DPP_STATE_CONFIGURING) return; @@ -1027,6 +1057,7 @@ static void dpp_handle_config_response_frame(const struct mmpdu_header *frame, */ if (station) { memcpy(ssid, config->ssid, config->ssid_len); + ssid_len = config->ssid_len; ssid[config->ssid_len] = '\0'; network = station_network_find(station, ssid, SECURITY_PSK); @@ -1045,7 +1076,14 @@ static void dpp_handle_config_response_frame(const struct mmpdu_header *frame, __station_connect_network(station, network, bss, STATION_STATE_CONNECTING); else if (station) { - dpp->connect_scan_id = scan_active(dpp->wdev_id, NULL, 0, + struct scan_parameters params = {0}; + + params.ssid = (void *) ssid; + params.ssid_len = ssid_len; + + l_debug("Scanning for %s", ssid); + + dpp->connect_scan_id = scan_active_full(dpp->wdev_id, ¶ms, dpp_scan_triggered, dpp_scan_results, dpp, dpp_scan_destroy);