Support adding addresses with CIDR mask

Correctly parse and add submitted networks to sets to reflect the
behavior of the `nft` command line.

Signed-off-by: Georg Pfuetzenreuter <mail@georg-pfuetzenreuter.net>
This commit is contained in:
Georg Pfuetzenreuter 2024-09-10 22:11:03 +02:00
parent 26a500ac96
commit bad275abe2
Signed by: Georg
GPG Key ID: 1ED2F138E7E6FF57
2 changed files with 87 additions and 31 deletions

43
nft.go
View File

@ -13,7 +13,6 @@ package main
import ( import (
"bytes" "bytes"
"errors"
"github.com/google/nftables" "github.com/google/nftables"
"log" "log"
"net" "net"
@ -42,16 +41,42 @@ func handleNft(task string, givenSet string, givenAddress string) (any, error) {
var element []nftables.SetElement var element []nftables.SetElement
if task != "get" { if task != "get" {
address, _ := parseIPAddress(givenAddress) address, network, _, err := parseIPAddressOrNetworkString(givenAddress)
if address == nil { if err != nil || address == nil {
return nil, errors.New("invalid address") return nil, err
} }
element = []nftables.SetElement{ if network == nil {
{ element = []nftables.SetElement{
Key: []byte(address), {
}, Key: []byte(address),
},
}
} else {
first, last, err := nftables.NetFirstAndLastIP(givenAddress)
if err != nil {
return nil, err
}
lastNext, err := incrementIPAddress(last)
if err != nil {
return nil, err
}
element = []nftables.SetElement{
{
Key: []byte(first),
},
{
Key: []byte(lastNext),
IntervalEnd: true,
},
}
} }
log.Println(element)
} }
var retmsg string var retmsg string
@ -83,7 +108,7 @@ func handleNft(task string, givenSet string, givenAddress string) (any, error) {
fErr := nft.Flush() fErr := nft.Flush()
if fErr != nil { if fErr != nil {
log.Println("nftablesHandler: failed to save changes: %w", fErr) log.Println("nftablesHandler: failed to save changes:", fErr)
return nil, fErr return nil, fErr
} }

View File

@ -18,6 +18,7 @@ import (
"log" "log"
"net" "net"
"net/http" "net/http"
"strings"
) )
type Response struct { type Response struct {
@ -70,31 +71,51 @@ func doCheckToken(token string, hash string) bool {
} }
} }
func parseIPAddress(straddress string) (net.IP, string) { func parseIPAddressOrNetworkString(givenAddress string) (net.IP, *net.IPNet, string, error) {
parsedaddress := net.ParseIP(straddress) if strings.Contains(givenAddress, "/") {
var address net.IP return parseIPNetworkString(givenAddress)
family, err := getIPAddressFamily(straddress) } else {
if err == nil { address, family, err := parseIPAddressString(givenAddress)
if family == "ipv4" { return address, nil, family, err
address = parsedaddress.To4()
} else if family == "ipv6" {
address = parsedaddress.To16()
} else {
log.Println("unknown family, this should not happen")
return nil, family
}
return address, family
} }
log.Println("address parsing failed:", err)
return nil, ""
} }
func getIPAddressFamily(ip string) (string, error) { func parseIPNetworkString(givenAddress string) (net.IP, *net.IPNet, string, error) {
if net.ParseIP(ip) == nil { ipObject, cidrObject, err := net.ParseCIDR(givenAddress)
return "", errors.New("Not an IP address") if err != nil {
return nil, nil, "", err
} }
for i := 0; i < len(ip); i++ { address, family, err := parseIPAddress(ipObject)
switch ip[i] { return address, cidrObject, family, err
}
func parseIPAddressString(givenAddress string) (net.IP, string, error) {
return parseIPAddress(net.ParseIP(givenAddress))
}
func parseIPAddress(ipObject net.IP) (net.IP, string, error) {
var address net.IP
family, err := getIPAddressFamily(ipObject)
if err == nil {
if family == "ipv4" {
address = ipObject.To4()
} else if family == "ipv6" {
address = ipObject.To16()
} else {
log.Println("unknown family, this should not happen")
return nil, family, errors.New("unknown family")
}
return address, family, nil
}
log.Println("address parsing failed:", err)
return nil, "", errors.New("invalid address")
}
func getIPAddressFamily(ipObject net.IP) (string, error) {
ipAddress := ipObject.String()
for i := 0; i < len(ipAddress); i++ {
switch ipAddress[i] {
case '.': case '.':
return "ipv4", nil return "ipv4", nil
case ':': case ':':
@ -102,5 +123,15 @@ func getIPAddressFamily(ip string) (string, error) {
} }
} }
return "", errors.New("unknown error") return "", errors.New("address family detection failed")
}
func incrementIPAddress(ip net.IP) (net.IP, error) {
for i := len(ip) - 1; i >= 0; i-- {
ip[i]++
if ip[i] != 0 {
break
}
}
return ip, nil
} }