Support adding addresses with CIDR mask
Correctly parse and add submitted networks to sets to reflect the behavior of the `nft` command line. Signed-off-by: Georg Pfuetzenreuter <mail@georg-pfuetzenreuter.net>
This commit is contained in:
parent
26a500ac96
commit
bad275abe2
43
nft.go
43
nft.go
@ -13,7 +13,6 @@ package main
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"errors"
|
|
||||||
"github.com/google/nftables"
|
"github.com/google/nftables"
|
||||||
"log"
|
"log"
|
||||||
"net"
|
"net"
|
||||||
@ -42,16 +41,42 @@ func handleNft(task string, givenSet string, givenAddress string) (any, error) {
|
|||||||
var element []nftables.SetElement
|
var element []nftables.SetElement
|
||||||
|
|
||||||
if task != "get" {
|
if task != "get" {
|
||||||
address, _ := parseIPAddress(givenAddress)
|
address, network, _, err := parseIPAddressOrNetworkString(givenAddress)
|
||||||
if address == nil {
|
if err != nil || address == nil {
|
||||||
return nil, errors.New("invalid address")
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
element = []nftables.SetElement{
|
if network == nil {
|
||||||
{
|
element = []nftables.SetElement{
|
||||||
Key: []byte(address),
|
{
|
||||||
},
|
Key: []byte(address),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
} else {
|
||||||
|
first, last, err := nftables.NetFirstAndLastIP(givenAddress)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
lastNext, err := incrementIPAddress(last)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
element = []nftables.SetElement{
|
||||||
|
{
|
||||||
|
Key: []byte(first),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Key: []byte(lastNext),
|
||||||
|
IntervalEnd: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
log.Println(element)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var retmsg string
|
var retmsg string
|
||||||
@ -83,7 +108,7 @@ func handleNft(task string, givenSet string, givenAddress string) (any, error) {
|
|||||||
|
|
||||||
fErr := nft.Flush()
|
fErr := nft.Flush()
|
||||||
if fErr != nil {
|
if fErr != nil {
|
||||||
log.Println("nftablesHandler: failed to save changes: %w", fErr)
|
log.Println("nftablesHandler: failed to save changes:", fErr)
|
||||||
return nil, fErr
|
return nil, fErr
|
||||||
}
|
}
|
||||||
|
|
||||||
|
75
utils.go
75
utils.go
@ -18,6 +18,7 @@ import (
|
|||||||
"log"
|
"log"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Response struct {
|
type Response struct {
|
||||||
@ -70,31 +71,51 @@ func doCheckToken(token string, hash string) bool {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseIPAddress(straddress string) (net.IP, string) {
|
func parseIPAddressOrNetworkString(givenAddress string) (net.IP, *net.IPNet, string, error) {
|
||||||
parsedaddress := net.ParseIP(straddress)
|
if strings.Contains(givenAddress, "/") {
|
||||||
var address net.IP
|
return parseIPNetworkString(givenAddress)
|
||||||
family, err := getIPAddressFamily(straddress)
|
} else {
|
||||||
if err == nil {
|
address, family, err := parseIPAddressString(givenAddress)
|
||||||
if family == "ipv4" {
|
return address, nil, family, err
|
||||||
address = parsedaddress.To4()
|
|
||||||
} else if family == "ipv6" {
|
|
||||||
address = parsedaddress.To16()
|
|
||||||
} else {
|
|
||||||
log.Println("unknown family, this should not happen")
|
|
||||||
return nil, family
|
|
||||||
}
|
|
||||||
return address, family
|
|
||||||
}
|
}
|
||||||
log.Println("address parsing failed:", err)
|
|
||||||
return nil, ""
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func getIPAddressFamily(ip string) (string, error) {
|
func parseIPNetworkString(givenAddress string) (net.IP, *net.IPNet, string, error) {
|
||||||
if net.ParseIP(ip) == nil {
|
ipObject, cidrObject, err := net.ParseCIDR(givenAddress)
|
||||||
return "", errors.New("Not an IP address")
|
if err != nil {
|
||||||
|
return nil, nil, "", err
|
||||||
}
|
}
|
||||||
for i := 0; i < len(ip); i++ {
|
address, family, err := parseIPAddress(ipObject)
|
||||||
switch ip[i] {
|
return address, cidrObject, family, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseIPAddressString(givenAddress string) (net.IP, string, error) {
|
||||||
|
return parseIPAddress(net.ParseIP(givenAddress))
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseIPAddress(ipObject net.IP) (net.IP, string, error) {
|
||||||
|
var address net.IP
|
||||||
|
family, err := getIPAddressFamily(ipObject)
|
||||||
|
if err == nil {
|
||||||
|
if family == "ipv4" {
|
||||||
|
address = ipObject.To4()
|
||||||
|
} else if family == "ipv6" {
|
||||||
|
address = ipObject.To16()
|
||||||
|
} else {
|
||||||
|
log.Println("unknown family, this should not happen")
|
||||||
|
return nil, family, errors.New("unknown family")
|
||||||
|
}
|
||||||
|
return address, family, nil
|
||||||
|
}
|
||||||
|
log.Println("address parsing failed:", err)
|
||||||
|
return nil, "", errors.New("invalid address")
|
||||||
|
}
|
||||||
|
|
||||||
|
func getIPAddressFamily(ipObject net.IP) (string, error) {
|
||||||
|
ipAddress := ipObject.String()
|
||||||
|
|
||||||
|
for i := 0; i < len(ipAddress); i++ {
|
||||||
|
switch ipAddress[i] {
|
||||||
case '.':
|
case '.':
|
||||||
return "ipv4", nil
|
return "ipv4", nil
|
||||||
case ':':
|
case ':':
|
||||||
@ -102,5 +123,15 @@ func getIPAddressFamily(ip string) (string, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return "", errors.New("unknown error")
|
return "", errors.New("address family detection failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
func incrementIPAddress(ip net.IP) (net.IP, error) {
|
||||||
|
for i := len(ip) - 1; i >= 0; i-- {
|
||||||
|
ip[i]++
|
||||||
|
if ip[i] != 0 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ip, nil
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user