diff --git a/src/netdev.c b/src/netdev.c index 1b50a540..d298977a 100644 --- a/src/netdev.c +++ b/src/netdev.c @@ -4033,11 +4033,14 @@ static void netdev_connect_common(struct netdev *netdev, netdev->ap = sae_sm_new(hs, netdev_sae_tx_authenticate, netdev_sae_tx_associate, netdev); - else + else { netdev->ap = sae_sm_new(hs, netdev_external_auth_sae_tx_authenticate, netdev_external_auth_sae_tx_associate, netdev); + sae_sm_force_default_group(netdev->ap); + sae_sm_force_hunt_and_peck(netdev->ap); + } if (sae_sm_is_h2e(netdev->ap)) { uint8_t own_rsnxe[20]; diff --git a/src/sae.c b/src/sae.c index 97c0af05..eb463484 100644 --- a/src/sae.c +++ b/src/sae.c @@ -1550,6 +1550,26 @@ struct auth_proto *sae_sm_new(struct handshake_state *hs, return &sm->ap; } +bool sae_sm_force_hunt_and_peck(struct auth_proto *ap) +{ + struct sae_sm *sm = l_container_of(ap, struct sae_sm, ap); + + sae_debug("Forcing SAE Hunting and Pecking"); + sm->sae_type = CRYPTO_SAE_LOOPING; + + return true; +} + +bool sae_sm_force_default_group(struct auth_proto *ap) +{ + struct sae_sm *sm = l_container_of(ap, struct sae_sm, ap); + + sae_debug("Forcing Default Group"); + sm->force_default_group = true; + + return true; +} + static int sae_init(void) { if (getenv("IWD_SAE_DEBUG")) diff --git a/src/sae.h b/src/sae.h index 668d084f..4a59999b 100644 --- a/src/sae.h +++ b/src/sae.h @@ -34,3 +34,6 @@ struct auth_proto *sae_sm_new(struct handshake_state *hs, sae_tx_authenticate_func_t tx_auth, sae_tx_associate_func_t tx_assoc, void *user_data); + +bool sae_sm_force_hunt_and_peck(struct auth_proto *ap); +bool sae_sm_force_default_group(struct auth_proto *ap);