diff --git a/src/blacklist.c b/src/blacklist.c index 21f85a75..8ae474b3 100644 --- a/src/blacklist.c +++ b/src/blacklist.c @@ -51,10 +51,27 @@ struct blacklist_entry { uint8_t addr[6]; uint64_t added_time; uint64_t expire_time; + enum blacklist_reason reason; +}; + +struct blacklist_search { + const uint8_t *addr; + enum blacklist_reason reason; }; static struct l_queue *blacklist; +static uint64_t get_reason_timeout(enum blacklist_reason reason) +{ + switch (reason) { + case BLACKLIST_REASON_CONNECT_FAILED: + return blacklist_initial_timeout; + default: + l_warn("Unhandled blacklist reason: %u", reason); + return 0; + } +} + static bool check_if_expired(void *data, void *user_data) { struct blacklist_entry *entry = data; @@ -87,15 +104,31 @@ static bool match_addr(const void *a, const void *b) return false; } -void blacklist_add_bss(const uint8_t *addr) +static bool match_addr_and_reason(const void *a, const void *b) +{ + const struct blacklist_entry *entry = a; + const struct blacklist_search *search = b; + + if (entry->reason != search->reason) + return false; + + if (!memcmp(entry->addr, search->addr, 6)) + return true; + + return false; +} + +void blacklist_add_bss(const uint8_t *addr, enum blacklist_reason reason) { struct blacklist_entry *entry; - - if (!blacklist_initial_timeout) - return; + uint64_t timeout; blacklist_prune(); + timeout = get_reason_timeout(reason); + if (!timeout) + return; + entry = l_queue_find(blacklist, match_addr, addr); if (entry) { @@ -115,22 +148,26 @@ void blacklist_add_bss(const uint8_t *addr) entry = l_new(struct blacklist_entry, 1); entry->added_time = l_time_now(); - entry->expire_time = l_time_offset(entry->added_time, - blacklist_initial_timeout); + entry->expire_time = l_time_offset(entry->added_time, timeout); + entry->reason = reason; memcpy(entry->addr, addr, 6); l_queue_push_tail(blacklist, entry); } -bool blacklist_contains_bss(const uint8_t *addr) +bool blacklist_contains_bss(const uint8_t *addr, enum blacklist_reason reason) { bool ret; uint64_t time_now; struct blacklist_entry *entry; + struct blacklist_search search = { + .addr = addr, + .reason = reason + }; blacklist_prune(); - entry = l_queue_find(blacklist, match_addr, addr); + entry = l_queue_find(blacklist, match_addr_and_reason, &search); if (!entry) return false; @@ -142,13 +179,17 @@ bool blacklist_contains_bss(const uint8_t *addr) return ret; } -void blacklist_remove_bss(const uint8_t *addr) +void blacklist_remove_bss(const uint8_t *addr, enum blacklist_reason reason) { struct blacklist_entry *entry; + struct blacklist_search search = { + .addr = addr, + .reason = reason + }; blacklist_prune(); - entry = l_queue_remove_if(blacklist, match_addr, addr); + entry = l_queue_remove_if(blacklist, match_addr_and_reason, &search); if (!entry) return; diff --git a/src/blacklist.h b/src/blacklist.h index 56260e20..a87e5eca 100644 --- a/src/blacklist.h +++ b/src/blacklist.h @@ -20,6 +20,14 @@ * */ -void blacklist_add_bss(const uint8_t *addr); -bool blacklist_contains_bss(const uint8_t *addr); -void blacklist_remove_bss(const uint8_t *addr); +enum blacklist_reason { + /* + * When a BSS is blacklisted using this reason IWD will refuse to + * connect to it via autoconnect + */ + BLACKLIST_REASON_CONNECT_FAILED, +}; + +void blacklist_add_bss(const uint8_t *addr, enum blacklist_reason reason); +bool blacklist_contains_bss(const uint8_t *addr, enum blacklist_reason reason); +void blacklist_remove_bss(const uint8_t *addr, enum blacklist_reason reason); diff --git a/src/network.c b/src/network.c index 0a40a6c5..4602a110 100644 --- a/src/network.c +++ b/src/network.c @@ -1280,7 +1280,8 @@ struct scan_bss *network_bss_select(struct network *network, if (l_queue_find(network->blacklist, match_bss, bss)) continue; - if (blacklist_contains_bss(bss->addr)) + if (blacklist_contains_bss(bss->addr, + BLACKLIST_REASON_CONNECT_FAILED)) continue; /* OWE Transition BSS */ diff --git a/src/station.c b/src/station.c index 0b20e785..e2ed78f3 100644 --- a/src/station.c +++ b/src/station.c @@ -2880,7 +2880,8 @@ static bool station_roam_scan_notify(int err, struct l_queue *bss_list, if (network_can_connect_bss(network, bss) < 0) goto next; - if (blacklist_contains_bss(bss->addr)) + if (blacklist_contains_bss(bss->addr, + BLACKLIST_REASON_CONNECT_FAILED)) goto next; rank = bss->rank; @@ -3400,7 +3401,8 @@ static bool station_retry_with_reason(struct station *station, break; } - blacklist_add_bss(station->connected_bss->addr); + blacklist_add_bss(station->connected_bss->addr, + BLACKLIST_REASON_CONNECT_FAILED); /* * Network blacklist the BSS as well, since the timeout blacklist could @@ -3471,7 +3473,8 @@ static bool station_retry_with_status(struct station *station, * obtain that IE, but this should be done in the future. */ if (!IS_TEMPORARY_STATUS(status_code)) - blacklist_add_bss(station->connected_bss->addr); + blacklist_add_bss(station->connected_bss->addr, + BLACKLIST_REASON_CONNECT_FAILED); /* * Unconditionally network blacklist the BSS if we are retrying. This @@ -3566,7 +3569,8 @@ static void station_connect_cb(struct netdev *netdev, enum netdev_result result, switch (result) { case NETDEV_RESULT_OK: - blacklist_remove_bss(station->connected_bss->addr); + blacklist_remove_bss(station->connected_bss->addr, + BLACKLIST_REASON_CONNECT_FAILED); station_connect_ok(station); return; case NETDEV_RESULT_DISCONNECTED: