From bad275abe2f0c07555e19a36c606aa91ad9f3af4 Mon Sep 17 00:00:00 2001 From: Georg Pfuetzenreuter Date: Tue, 10 Sep 2024 22:11:03 +0200 Subject: [PATCH] Support adding addresses with CIDR mask Correctly parse and add submitted networks to sets to reflect the behavior of the `nft` command line. Signed-off-by: Georg Pfuetzenreuter --- nft.go | 43 +++++++++++++++++++++++++------- utils.go | 75 +++++++++++++++++++++++++++++++++++++++----------------- 2 files changed, 87 insertions(+), 31 deletions(-) diff --git a/nft.go b/nft.go index 583dc02..8ac1437 100644 --- a/nft.go +++ b/nft.go @@ -13,7 +13,6 @@ package main import ( "bytes" - "errors" "github.com/google/nftables" "log" "net" @@ -42,16 +41,42 @@ func handleNft(task string, givenSet string, givenAddress string) (any, error) { var element []nftables.SetElement if task != "get" { - address, _ := parseIPAddress(givenAddress) - if address == nil { - return nil, errors.New("invalid address") + address, network, _, err := parseIPAddressOrNetworkString(givenAddress) + if err != nil || address == nil { + return nil, err } - element = []nftables.SetElement{ - { - Key: []byte(address), - }, + if network == nil { + element = []nftables.SetElement{ + { + Key: []byte(address), + }, + } + + } else { + first, last, err := nftables.NetFirstAndLastIP(givenAddress) + if err != nil { + return nil, err + } + + lastNext, err := incrementIPAddress(last) + if err != nil { + return nil, err + } + + element = []nftables.SetElement{ + { + Key: []byte(first), + }, + { + Key: []byte(lastNext), + IntervalEnd: true, + }, + } } + + log.Println(element) + } var retmsg string @@ -83,7 +108,7 @@ func handleNft(task string, givenSet string, givenAddress string) (any, error) { fErr := nft.Flush() if fErr != nil { - log.Println("nftablesHandler: failed to save changes: %w", fErr) + log.Println("nftablesHandler: failed to save changes:", fErr) return nil, fErr } diff --git a/utils.go b/utils.go index e37bdd9..8702bf3 100644 --- a/utils.go +++ b/utils.go @@ -18,6 +18,7 @@ import ( "log" "net" "net/http" + "strings" ) type Response struct { @@ -70,31 +71,51 @@ func doCheckToken(token string, hash string) bool { } } -func parseIPAddress(straddress string) (net.IP, string) { - parsedaddress := net.ParseIP(straddress) - var address net.IP - family, err := getIPAddressFamily(straddress) - if err == nil { - if family == "ipv4" { - address = parsedaddress.To4() - } else if family == "ipv6" { - address = parsedaddress.To16() - } else { - log.Println("unknown family, this should not happen") - return nil, family - } - return address, family +func parseIPAddressOrNetworkString(givenAddress string) (net.IP, *net.IPNet, string, error) { + if strings.Contains(givenAddress, "/") { + return parseIPNetworkString(givenAddress) + } else { + address, family, err := parseIPAddressString(givenAddress) + return address, nil, family, err } - log.Println("address parsing failed:", err) - return nil, "" } -func getIPAddressFamily(ip string) (string, error) { - if net.ParseIP(ip) == nil { - return "", errors.New("Not an IP address") +func parseIPNetworkString(givenAddress string) (net.IP, *net.IPNet, string, error) { + ipObject, cidrObject, err := net.ParseCIDR(givenAddress) + if err != nil { + return nil, nil, "", err } - for i := 0; i < len(ip); i++ { - switch ip[i] { + address, family, err := parseIPAddress(ipObject) + return address, cidrObject, family, err +} + +func parseIPAddressString(givenAddress string) (net.IP, string, error) { + return parseIPAddress(net.ParseIP(givenAddress)) +} + +func parseIPAddress(ipObject net.IP) (net.IP, string, error) { + var address net.IP + family, err := getIPAddressFamily(ipObject) + if err == nil { + if family == "ipv4" { + address = ipObject.To4() + } else if family == "ipv6" { + address = ipObject.To16() + } else { + log.Println("unknown family, this should not happen") + return nil, family, errors.New("unknown family") + } + return address, family, nil + } + log.Println("address parsing failed:", err) + return nil, "", errors.New("invalid address") +} + +func getIPAddressFamily(ipObject net.IP) (string, error) { + ipAddress := ipObject.String() + + for i := 0; i < len(ipAddress); i++ { + switch ipAddress[i] { case '.': return "ipv4", nil case ':': @@ -102,5 +123,15 @@ func getIPAddressFamily(ip string) (string, error) { } } - return "", errors.New("unknown error") + return "", errors.New("address family detection failed") +} + +func incrementIPAddress(ip net.IP) (net.IP, error) { + for i := len(ip) - 1; i >= 0; i-- { + ip[i]++ + if ip[i] != 0 { + break + } + } + return ip, nil }