From b9c3feb198d00846ab64308473cfb416a4f162d3 Mon Sep 17 00:00:00 2001 From: James Prestwood Date: Fri, 22 Nov 2024 07:15:37 -0800 Subject: [PATCH] handshake: add ref counting to handshake_state This adds a ref count to the handshake state object (as well as ref/unref APIs). Currently IWD is careful to ensure that netdev holds the root reference to the handshake state. Other modules do track it themselves, but ensure that it doesn't get referenced after netdev frees it. Future work related to PMKSA will require that station holds a references to the handshake state, specifically for retry logic, after netdev is done with it so we need a way to delay the free until station is also done. --- src/adhoc.c | 4 ++-- src/ap.c | 2 +- src/handshake.c | 12 +++++++++++- src/handshake.h | 9 ++++++--- src/netdev.c | 5 +++-- src/p2p.c | 2 +- src/station.c | 8 ++++---- src/wsc.c | 2 +- 8 files changed, 29 insertions(+), 15 deletions(-) diff --git a/src/adhoc.c b/src/adhoc.c index e787dab1..930240ae 100644 --- a/src/adhoc.c +++ b/src/adhoc.c @@ -94,13 +94,13 @@ static void adhoc_sta_free(void *data) eapol_sm_free(sta->sm); if (sta->hs_sta) - handshake_state_free(sta->hs_sta); + handshake_state_unref(sta->hs_sta); if (sta->sm_a) eapol_sm_free(sta->sm_a); if (sta->hs_auth) - handshake_state_free(sta->hs_auth); + handshake_state_unref(sta->hs_auth); end: l_free(sta); diff --git a/src/ap.c b/src/ap.c index 562e00c8..d52b7e55 100644 --- a/src/ap.c +++ b/src/ap.c @@ -230,7 +230,7 @@ static void ap_stop_handshake(struct sta_state *sta) } if (sta->hs) { - handshake_state_free(sta->hs); + handshake_state_unref(sta->hs); sta->hs = NULL; } diff --git a/src/handshake.c b/src/handshake.c index fc1978df..7fb75dc4 100644 --- a/src/handshake.c +++ b/src/handshake.c @@ -103,7 +103,14 @@ void __handshake_set_install_ext_tk_func(handshake_install_ext_tk_func_t func) install_ext_tk = func; } -void handshake_state_free(struct handshake_state *s) +struct handshake_state *handshake_state_ref(struct handshake_state *s) +{ + __sync_fetch_and_add(&s->refcount, 1); + + return s; +} + +void handshake_state_unref(struct handshake_state *s) { __typeof__(s->free) destroy; @@ -117,6 +124,9 @@ void handshake_state_free(struct handshake_state *s) return; } + if (__sync_sub_and_fetch(&s->refcount, 1)) + return; + l_free(s->authenticator_ie); l_free(s->supplicant_ie); l_free(s->authenticator_rsnxe); diff --git a/src/handshake.h b/src/handshake.h index d1116472..6c0946d4 100644 --- a/src/handshake.h +++ b/src/handshake.h @@ -170,6 +170,8 @@ struct handshake_state { bool in_event; handshake_event_func_t event_func; + + int refcount; }; #define HSID(x) UNIQUE_ID(handshake_, x) @@ -186,7 +188,7 @@ struct handshake_state { ##__VA_ARGS__); \ \ if (!HSID(hs)->in_event) { \ - handshake_state_free(HSID(hs)); \ + handshake_state_unref(HSID(hs)); \ HSID(freed) = true; \ } else \ HSID(hs)->in_event = false; \ @@ -194,7 +196,8 @@ struct handshake_state { HSID(freed); \ }) -void handshake_state_free(struct handshake_state *s); +struct handshake_state *handshake_state_ref(struct handshake_state *s); +void handshake_state_unref(struct handshake_state *s); void handshake_state_set_supplicant_address(struct handshake_state *s, const uint8_t *spa); @@ -316,4 +319,4 @@ void handshake_util_build_gtk_kde(enum crypto_cipher cipher, const uint8_t *key, void handshake_util_build_igtk_kde(enum crypto_cipher cipher, const uint8_t *key, unsigned int key_index, uint8_t *to); -DEFINE_CLEANUP_FUNC(handshake_state_free); +DEFINE_CLEANUP_FUNC(handshake_state_unref); diff --git a/src/netdev.c b/src/netdev.c index e86ef1bd..4dccb78a 100644 --- a/src/netdev.c +++ b/src/netdev.c @@ -376,6 +376,7 @@ struct handshake_state *netdev_handshake_state_new(struct netdev *netdev) nhs->super.ifindex = netdev->index; nhs->super.free = netdev_handshake_state_free; + nhs->super.refcount = 1; nhs->netdev = netdev; /* @@ -828,7 +829,7 @@ static void netdev_connect_free(struct netdev *netdev) eapol_preauth_cancel(netdev->index); if (netdev->handshake) { - handshake_state_free(netdev->handshake); + handshake_state_unref(netdev->handshake); netdev->handshake = NULL; } @@ -4239,7 +4240,7 @@ int netdev_reassociate(struct netdev *netdev, const struct scan_bss *target_bss, eapol_sm_free(old_sm); if (old_hs) - handshake_state_free(old_hs); + handshake_state_unref(old_hs); return 0; } diff --git a/src/p2p.c b/src/p2p.c index 676ef146..7d89da21 100644 --- a/src/p2p.c +++ b/src/p2p.c @@ -1497,7 +1497,7 @@ static void p2p_handshake_event(struct handshake_state *hs, static void p2p_try_connect_group(struct p2p_device *dev) { struct scan_bss *bss = dev->conn_wsc_bss; - _auto_(handshake_state_free) struct handshake_state *hs = NULL; + _auto_(handshake_state_unref) struct handshake_state *hs = NULL; struct iovec ie_iov[16]; int ie_num = 0; int r; diff --git a/src/station.c b/src/station.c index 1238734f..c1c7ba9d 100644 --- a/src/station.c +++ b/src/station.c @@ -1394,7 +1394,7 @@ static struct handshake_state *station_handshake_setup(struct station *station, return hs; not_supported: - handshake_state_free(hs); + handshake_state_unref(hs); return NULL; } @@ -2484,7 +2484,7 @@ static void station_preauthenticate_cb(struct netdev *netdev, } if (station_transition_reassociate(station, bss, new_hs) < 0) { - handshake_state_free(new_hs); + handshake_state_unref(new_hs); station_roam_failed(station); } } @@ -2687,7 +2687,7 @@ static bool station_try_next_transition(struct station *station, } if (station_transition_reassociate(station, bss, new_hs) < 0) { - handshake_state_free(new_hs); + handshake_state_unref(new_hs); return false; } @@ -3734,7 +3734,7 @@ int __station_connect_network(struct station *station, struct network *network, station_netdev_event, station_connect_cb, station); if (r < 0) { - handshake_state_free(hs); + handshake_state_unref(hs); return r; } diff --git a/src/wsc.c b/src/wsc.c index f88f5deb..44b8d3de 100644 --- a/src/wsc.c +++ b/src/wsc.c @@ -393,7 +393,7 @@ static int wsc_enrollee_connect(struct wsc_enrollee *wsce, struct scan_bss *bss, return 0; error: - handshake_state_free(hs); + handshake_state_unref(hs); return r; }