Support adding addresses with CIDR mask
Correctly parse and add submitted networks to sets to reflect the behavior of the `nft` command line. Signed-off-by: Georg Pfuetzenreuter <mail@georg-pfuetzenreuter.net>
This commit is contained in:
parent
26a500ac96
commit
bad275abe2
43
nft.go
43
nft.go
@ -13,7 +13,6 @@ package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"github.com/google/nftables"
|
||||
"log"
|
||||
"net"
|
||||
@ -42,16 +41,42 @@ func handleNft(task string, givenSet string, givenAddress string) (any, error) {
|
||||
var element []nftables.SetElement
|
||||
|
||||
if task != "get" {
|
||||
address, _ := parseIPAddress(givenAddress)
|
||||
if address == nil {
|
||||
return nil, errors.New("invalid address")
|
||||
address, network, _, err := parseIPAddressOrNetworkString(givenAddress)
|
||||
if err != nil || address == nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
element = []nftables.SetElement{
|
||||
{
|
||||
Key: []byte(address),
|
||||
},
|
||||
if network == nil {
|
||||
element = []nftables.SetElement{
|
||||
{
|
||||
Key: []byte(address),
|
||||
},
|
||||
}
|
||||
|
||||
} else {
|
||||
first, last, err := nftables.NetFirstAndLastIP(givenAddress)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
lastNext, err := incrementIPAddress(last)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
element = []nftables.SetElement{
|
||||
{
|
||||
Key: []byte(first),
|
||||
},
|
||||
{
|
||||
Key: []byte(lastNext),
|
||||
IntervalEnd: true,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
log.Println(element)
|
||||
|
||||
}
|
||||
|
||||
var retmsg string
|
||||
@ -83,7 +108,7 @@ func handleNft(task string, givenSet string, givenAddress string) (any, error) {
|
||||
|
||||
fErr := nft.Flush()
|
||||
if fErr != nil {
|
||||
log.Println("nftablesHandler: failed to save changes: %w", fErr)
|
||||
log.Println("nftablesHandler: failed to save changes:", fErr)
|
||||
return nil, fErr
|
||||
}
|
||||
|
||||
|
75
utils.go
75
utils.go
@ -18,6 +18,7 @@ import (
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type Response struct {
|
||||
@ -70,31 +71,51 @@ func doCheckToken(token string, hash string) bool {
|
||||
}
|
||||
}
|
||||
|
||||
func parseIPAddress(straddress string) (net.IP, string) {
|
||||
parsedaddress := net.ParseIP(straddress)
|
||||
var address net.IP
|
||||
family, err := getIPAddressFamily(straddress)
|
||||
if err == nil {
|
||||
if family == "ipv4" {
|
||||
address = parsedaddress.To4()
|
||||
} else if family == "ipv6" {
|
||||
address = parsedaddress.To16()
|
||||
} else {
|
||||
log.Println("unknown family, this should not happen")
|
||||
return nil, family
|
||||
}
|
||||
return address, family
|
||||
func parseIPAddressOrNetworkString(givenAddress string) (net.IP, *net.IPNet, string, error) {
|
||||
if strings.Contains(givenAddress, "/") {
|
||||
return parseIPNetworkString(givenAddress)
|
||||
} else {
|
||||
address, family, err := parseIPAddressString(givenAddress)
|
||||
return address, nil, family, err
|
||||
}
|
||||
log.Println("address parsing failed:", err)
|
||||
return nil, ""
|
||||
}
|
||||
|
||||
func getIPAddressFamily(ip string) (string, error) {
|
||||
if net.ParseIP(ip) == nil {
|
||||
return "", errors.New("Not an IP address")
|
||||
func parseIPNetworkString(givenAddress string) (net.IP, *net.IPNet, string, error) {
|
||||
ipObject, cidrObject, err := net.ParseCIDR(givenAddress)
|
||||
if err != nil {
|
||||
return nil, nil, "", err
|
||||
}
|
||||
for i := 0; i < len(ip); i++ {
|
||||
switch ip[i] {
|
||||
address, family, err := parseIPAddress(ipObject)
|
||||
return address, cidrObject, family, err
|
||||
}
|
||||
|
||||
func parseIPAddressString(givenAddress string) (net.IP, string, error) {
|
||||
return parseIPAddress(net.ParseIP(givenAddress))
|
||||
}
|
||||
|
||||
func parseIPAddress(ipObject net.IP) (net.IP, string, error) {
|
||||
var address net.IP
|
||||
family, err := getIPAddressFamily(ipObject)
|
||||
if err == nil {
|
||||
if family == "ipv4" {
|
||||
address = ipObject.To4()
|
||||
} else if family == "ipv6" {
|
||||
address = ipObject.To16()
|
||||
} else {
|
||||
log.Println("unknown family, this should not happen")
|
||||
return nil, family, errors.New("unknown family")
|
||||
}
|
||||
return address, family, nil
|
||||
}
|
||||
log.Println("address parsing failed:", err)
|
||||
return nil, "", errors.New("invalid address")
|
||||
}
|
||||
|
||||
func getIPAddressFamily(ipObject net.IP) (string, error) {
|
||||
ipAddress := ipObject.String()
|
||||
|
||||
for i := 0; i < len(ipAddress); i++ {
|
||||
switch ipAddress[i] {
|
||||
case '.':
|
||||
return "ipv4", nil
|
||||
case ':':
|
||||
@ -102,5 +123,15 @@ func getIPAddressFamily(ip string) (string, error) {
|
||||
}
|
||||
}
|
||||
|
||||
return "", errors.New("unknown error")
|
||||
return "", errors.New("address family detection failed")
|
||||
}
|
||||
|
||||
func incrementIPAddress(ip net.IP) (net.IP, error) {
|
||||
for i := len(ip) - 1; i >= 0; i-- {
|
||||
ip[i]++
|
||||
if ip[i] != 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
return ip, nil
|
||||
}
|
||||
|
Reference in New Issue
Block a user