Support adding addresses with CIDR mask

Correctly parse and add submitted networks to sets to reflect the
behavior of the `nft` command line.

Signed-off-by: Georg Pfuetzenreuter <mail@georg-pfuetzenreuter.net>
This commit is contained in:
Georg Pfuetzenreuter 2024-09-10 22:11:03 +02:00
parent 26a500ac96
commit bad275abe2
Signed by: Georg
GPG Key ID: 1ED2F138E7E6FF57
2 changed files with 87 additions and 31 deletions

43
nft.go
View File

@ -13,7 +13,6 @@ package main
import (
"bytes"
"errors"
"github.com/google/nftables"
"log"
"net"
@ -42,16 +41,42 @@ func handleNft(task string, givenSet string, givenAddress string) (any, error) {
var element []nftables.SetElement
if task != "get" {
address, _ := parseIPAddress(givenAddress)
if address == nil {
return nil, errors.New("invalid address")
address, network, _, err := parseIPAddressOrNetworkString(givenAddress)
if err != nil || address == nil {
return nil, err
}
element = []nftables.SetElement{
{
Key: []byte(address),
},
if network == nil {
element = []nftables.SetElement{
{
Key: []byte(address),
},
}
} else {
first, last, err := nftables.NetFirstAndLastIP(givenAddress)
if err != nil {
return nil, err
}
lastNext, err := incrementIPAddress(last)
if err != nil {
return nil, err
}
element = []nftables.SetElement{
{
Key: []byte(first),
},
{
Key: []byte(lastNext),
IntervalEnd: true,
},
}
}
log.Println(element)
}
var retmsg string
@ -83,7 +108,7 @@ func handleNft(task string, givenSet string, givenAddress string) (any, error) {
fErr := nft.Flush()
if fErr != nil {
log.Println("nftablesHandler: failed to save changes: %w", fErr)
log.Println("nftablesHandler: failed to save changes:", fErr)
return nil, fErr
}

View File

@ -18,6 +18,7 @@ import (
"log"
"net"
"net/http"
"strings"
)
type Response struct {
@ -70,31 +71,51 @@ func doCheckToken(token string, hash string) bool {
}
}
func parseIPAddress(straddress string) (net.IP, string) {
parsedaddress := net.ParseIP(straddress)
var address net.IP
family, err := getIPAddressFamily(straddress)
if err == nil {
if family == "ipv4" {
address = parsedaddress.To4()
} else if family == "ipv6" {
address = parsedaddress.To16()
} else {
log.Println("unknown family, this should not happen")
return nil, family
}
return address, family
func parseIPAddressOrNetworkString(givenAddress string) (net.IP, *net.IPNet, string, error) {
if strings.Contains(givenAddress, "/") {
return parseIPNetworkString(givenAddress)
} else {
address, family, err := parseIPAddressString(givenAddress)
return address, nil, family, err
}
log.Println("address parsing failed:", err)
return nil, ""
}
func getIPAddressFamily(ip string) (string, error) {
if net.ParseIP(ip) == nil {
return "", errors.New("Not an IP address")
func parseIPNetworkString(givenAddress string) (net.IP, *net.IPNet, string, error) {
ipObject, cidrObject, err := net.ParseCIDR(givenAddress)
if err != nil {
return nil, nil, "", err
}
for i := 0; i < len(ip); i++ {
switch ip[i] {
address, family, err := parseIPAddress(ipObject)
return address, cidrObject, family, err
}
func parseIPAddressString(givenAddress string) (net.IP, string, error) {
return parseIPAddress(net.ParseIP(givenAddress))
}
func parseIPAddress(ipObject net.IP) (net.IP, string, error) {
var address net.IP
family, err := getIPAddressFamily(ipObject)
if err == nil {
if family == "ipv4" {
address = ipObject.To4()
} else if family == "ipv6" {
address = ipObject.To16()
} else {
log.Println("unknown family, this should not happen")
return nil, family, errors.New("unknown family")
}
return address, family, nil
}
log.Println("address parsing failed:", err)
return nil, "", errors.New("invalid address")
}
func getIPAddressFamily(ipObject net.IP) (string, error) {
ipAddress := ipObject.String()
for i := 0; i < len(ipAddress); i++ {
switch ipAddress[i] {
case '.':
return "ipv4", nil
case ':':
@ -102,5 +123,15 @@ func getIPAddressFamily(ip string) (string, error) {
}
}
return "", errors.New("unknown error")
return "", errors.New("address family detection failed")
}
func incrementIPAddress(ip net.IP) (net.IP, error) {
for i := len(ip) - 1; i >= 0; i-- {
ip[i]++
if ip[i] != 0 {
break
}
}
return ip, nil
}