This repository has been archived on 2024-09-28. You can view files and clone it, but cannot push or open issues or pull requests.
nftables-http-api-go/nft.go
Georg Pfuetzenreuter bad275abe2
Support adding addresses with CIDR mask
Correctly parse and add submitted networks to sets to reflect the
behavior of the `nft` command line.

Signed-off-by: Georg Pfuetzenreuter <mail@georg-pfuetzenreuter.net>
2024-09-10 22:11:03 +02:00

193 lines
4.2 KiB
Go

/*
* This file is part of nftables-http-api.
* Copyright (C) 2024 Georg Pfuetzenreuter <mail@georg-pfuetzenreuter.net>
*
* The nftables-http-api program is free software: you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version.
* This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details.
* You should have received a copy of the GNU General Public License along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package main
import (
"bytes"
"github.com/google/nftables"
"log"
"net"
)
// nftError is a minimal string-backed error used in this file to report
// nftables lookup failures that have no underlying error value to wrap
// (for example, a table that does not exist).
type nftError struct {
	// Message is returned verbatim by Error.
	Message string
}

// Error implements the error interface by returning the stored message.
func (e nftError) Error() string {
	msg := e.Message
	return msg
}
// handleNft runs task against the nftables set named givenSet, which is
// resolved through getNftSet.
//
// Supported tasks:
//   - "get": returns the set's elements as a []string (see getNftSetElements);
//     givenAddress is not used.
//   - "add": parses givenAddress as a single address or a network and adds it
//     to the set. Returns "already" when containsNftSetElement reports the
//     element as present, otherwise "added" once the change has been flushed.
//
// Any other task is rejected with an nftError. Whenever a non-nil error is
// returned, the first result is nil.
func handleNft(task string, givenSet string, givenAddress string) (any, error) {
	nft, err := nftables.New()
	if err != nil {
		log.Println("handleNft():", err)
		return nil, err
	}
	set, err := getNftSet(nft, givenSet)
	if err != nil {
		return nil, err
	}
	switch task {
	case "get":
		nftResult, err := getNftSetElements(nft, set)
		if err != nil {
			// Return an untyped nil rather than a nil []string boxed in any.
			return nil, err
		}
		return nftResult, nil
	case "add":
		address, network, _, err := parseIPAddressOrNetworkString(givenAddress)
		if err != nil {
			return nil, err
		}
		if address == nil {
			// Never report success for input that did not yield an address.
			return nil, nftError{Message: "invalid address: " + givenAddress}
		}
		var elements []nftables.SetElement
		if network == nil {
			elements = []nftables.SetElement{{Key: []byte(address)}}
		} else {
			// A network is stored as an interval: a start element holding the
			// first address, and an end element holding
			// incrementIPAddress(last) flagged with IntervalEnd.
			first, last, err := nftables.NetFirstAndLastIP(givenAddress)
			if err != nil {
				return nil, err
			}
			lastNext, err := incrementIPAddress(last)
			if err != nil {
				return nil, err
			}
			elements = []nftables.SetElement{
				{Key: []byte(first)},
				{Key: []byte(lastNext), IntervalEnd: true},
			}
		}
		log.Println(elements)
		contains, err := containsNftSetElement(nft, set, elements)
		if err != nil {
			return nil, err
		}
		if contains {
			return "already", nil
		}
		if err := nft.SetAddElements(set, elements); err != nil {
			log.Println("handleNft() add failure:", err)
			return nil, err
		}
		// SetAddElements only queues the change; Flush sends it to the kernel.
		if err := nft.Flush(); err != nil {
			log.Println("nftablesHandler: failed to save changes:", err)
			return nil, err
		}
		return "added", nil
	default:
		return nil, nftError{Message: "unsupported task: " + task}
	}
}
// getNftTable returns the first table listed by nft.ListTables whose name is
// "filter". A listing failure is logged and returned unchanged; if no table
// with that name exists, an nftError is returned.
//
// NOTE(review): ListTables is not filtered by family here, so whichever
// "filter" table is listed first wins — confirm that is intended.
func getNftTable(nft *nftables.Conn) (*nftables.Table, error) {
	const wantName = "filter" // TODO: make table configurable or smarter
	tables, err := nft.ListTables()
	if err != nil {
		log.Printf("getNftTable(): %s", err)
		return nil, err
	}
	for _, candidate := range tables {
		if candidate.Name != wantName {
			continue
		}
		return candidate, nil
	}
	log.Printf("Table %s does not exist, cannot proceed", wantName)
	return nil, nftError{Message: "Table does not exist"}
}
// getNftSet resolves the set named setName inside the table returned by
// getNftTable. It returns an error when the table lookup fails, when
// GetSetByName fails, or when GetSetByName yields no set. A nil set is never
// returned together with a nil error.
func getNftSet(nft *nftables.Conn, setName string) (*nftables.Set, error) {
	foundTable, err := getNftTable(nft)
	if err != nil {
		return nil, err
	}
	foundSet, err := nft.GetSetByName(foundTable, setName)
	if err != nil {
		log.Printf("Set lookup for %s failed: %s", setName, err)
		return nil, err
	}
	if foundSet == nil {
		// Previously this case returned (nil, nil), letting callers proceed
		// with a nil *nftables.Set.
		log.Printf("Set lookup for %s failed: set not found", setName)
		return nil, nftError{Message: "Set does not exist"}
	}
	log.Printf("Found set %s", foundSet.Name)
	return foundSet, nil
}
// getNftSetElements returns the key of every element in set, each interpreted
// as a net.IP and rendered with its String method, in the order returned by
// nft.GetSetElements. Each element is also logged with its index. When the set
// is empty the result is a nil slice.
//
// NOTE(review): no elements are filtered out, so for interval sets any
// interval-end markers returned by GetSetElements are listed like ordinary
// addresses — confirm this is the desired output.
func getNftSetElements(nft *nftables.Conn, set *nftables.Set) ([]string, error) {
	elements, err := nft.GetSetElements(set)
	if err != nil {
		return nil, err
	}
	var addrs []string
	for idx := range elements {
		addr := net.IP(elements[idx].Key)
		log.Printf("Element %d: %s", idx, addr)
		addrs = append(addrs, addr.String())
	}
	return addrs, nil
}
// containsNftSetElement reports whether set already holds an element matching
// element[0]. In handleNft, element[0] is either the single address or the
// start of a network interval.
//
// An existing element matches when its key equals element[0].Key and its
// IntervalEnd flag equals element[0].IntervalEnd. Comparing the flag prevents
// another interval's end marker from being mistaken for a start key with the
// same bytes. An empty element slice is reported as not contained.
//
// NOTE(review): only the first key is compared, so two networks sharing a
// start address but differing in size are treated as the same element.
func containsNftSetElement(nft *nftables.Conn, set *nftables.Set, element []nftables.SetElement) (bool, error) {
	if len(element) == 0 {
		return false, nil
	}
	want := element[0]
	existingElements, err := nft.GetSetElements(set)
	if err != nil {
		return false, err
	}
	for _, existing := range existingElements {
		if existing.IntervalEnd != want.IntervalEnd {
			continue
		}
		if bytes.Equal(existing.Key, want.Key) {
			log.Printf("Existing element found %v", existing.Key)
			return true, nil
		}
	}
	return false, nil
}