This repository has been archived on 2024-09-28. You can view files and clone it, but cannot push or open issues or pull requests.
Georg Pfuetzenreuter bad275abe2
Support adding addresses with CIDR mask
Correctly parse and add submitted networks to sets to reflect the
behavior of the `nft` command line.

Signed-off-by: Georg Pfuetzenreuter <mail@georg-pfuetzenreuter.net>
2024-09-10 22:11:03 +02:00

193 lines
4.2 KiB
Go

/*
* This file is part of nftables-http-api.
* Copyright (C) 2024 Georg Pfuetzenreuter <mail@georg-pfuetzenreuter.net>
*
* The nftables-http-api program is free software: you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version.
* This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details.
* You should have received a copy of the GNU General Public License along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package main
import (
"bytes"
"github.com/google/nftables"
"log"
"net"
)
// nftError is the error type this file uses for nftables lookup failures that
// are not reported by the nftables library itself, e.g. when the target table
// does not exist.
type nftError struct {
	// Message is the human-readable description returned by Error.
	Message string
}

// Error implements the error interface by returning Message unchanged.
func (e nftError) Error() string {
	msg := e.Message
	return msg
}
// handleNft runs task against the set named givenSet in the table selected by
// getNftTable.
//
// Supported tasks:
//   - "get": givenAddress is ignored; the set's element keys are returned as
//     the []string produced by getNftSetElements.
//   - "add": givenAddress (a single address or a CIDR network) is converted
//     into set elements and added. The result is "already" when
//     containsNftSetElement reports them as present, otherwise "added" once
//     the change has been flushed.
//
// Any other task is rejected with an nftError instead of silently succeeding.
// On error the returned value is always nil.
func handleNft(task string, givenSet string, givenAddress string) (any, error) {
	if task != "get" && task != "add" {
		log.Printf("handleNft(): unsupported task %s", task)
		return nil, nftError{Message: "Unsupported task"}
	}

	nft, err := nftables.New()
	if err != nil {
		log.Println("handleNft():", err)
		return nil, err
	}

	set, err := getNftSet(nft, givenSet)
	if err != nil {
		return nil, err
	}

	if task == "get" {
		nftResult, err := getNftSetElements(nft, set)
		if err != nil {
			// Return an untyped nil: wrapping a nil []string in `any` would
			// produce a non-nil interface value.
			return nil, err
		}
		return nftResult, nil
	}

	// task == "add"
	element, err := nftSetElementsForAddress(givenAddress)
	if err != nil {
		return nil, err
	}
	log.Println(element)

	contains, err := containsNftSetElement(nft, set, element)
	if err != nil {
		return nil, err
	}
	if contains {
		// Nothing was queued, so there is nothing to flush.
		return "already", nil
	}

	if err := nft.SetAddElements(set, element); err != nil {
		log.Println("handleNft() add failure:", err)
		return nil, err
	}
	// SetAddElements only queues the change; Flush sends it to the kernel.
	if err := nft.Flush(); err != nil {
		log.Println("nftablesHandler: failed to save changes:", err)
		return nil, err
	}
	return "added", nil
}

// nftSetElementsForAddress converts givenAddress into the elements to add to
// a set. A plain address becomes a single element keyed by that address. A
// network becomes a start element keyed by its first address plus an element
// flagged IntervalEnd keyed by incrementIPAddress applied to its last address.
// A parse result without an address is reported as an error rather than being
// treated as success.
func nftSetElementsForAddress(givenAddress string) ([]nftables.SetElement, error) {
	address, network, _, err := parseIPAddressOrNetworkString(givenAddress)
	if err != nil {
		return nil, err
	}
	if address == nil {
		log.Printf("nftSetElementsForAddress(): no address parsed from %s", givenAddress)
		return nil, nftError{Message: "Invalid address"}
	}

	if network == nil {
		return []nftables.SetElement{
			{Key: []byte(address)},
		}, nil
	}

	first, last, err := nftables.NetFirstAndLastIP(givenAddress)
	if err != nil {
		return nil, err
	}
	lastNext, err := incrementIPAddress(last)
	if err != nil {
		return nil, err
	}
	return []nftables.SetElement{
		{Key: []byte(first)},
		{Key: []byte(lastNext), IntervalEnd: true},
	}, nil
}
// getNftTable returns the first table reported by nft.ListTables whose name
// is "filter". It returns an nftError if no such table exists, or the
// ListTables error if listing fails.
//
// NOTE(review): only the name is compared, not the table family, so when
// several families have a "filter" table the first one listed wins. Confirm
// this is intended.
func getNftTable(nft *nftables.Conn) (*nftables.Table, error) {
	const targetTable = "filter" // TODO: make table configurable or smarter

	tables, err := nft.ListTables()
	if err != nil {
		log.Printf("getNftTable(): %s", err)
		return nil, err
	}

	for _, candidate := range tables {
		if candidate.Name == targetTable {
			return candidate, nil
		}
	}

	log.Printf("Table %s does not exist, cannot proceed", targetTable)
	return nil, nftError{Message: "Table does not exist"}
}
// getNftSet returns the set named setName from the table selected by
// getNftTable.
//
// The function returns an error if:
//   - the table cannot be found;
//   - the GetSetByName lookup fails;
//   - the lookup yields no set. In this case the error is an nftError, so
//     callers never receive a nil set together with a nil error.
func getNftSet(nft *nftables.Conn, setName string) (*nftables.Set, error) {
	foundTable, err := getNftTable(nft)
	if err != nil {
		return nil, err
	}

	foundSet, err := nft.GetSetByName(foundTable, setName)
	if err != nil {
		log.Printf("Set lookup for %s failed: %s", setName, err)
		return nil, err
	}
	if foundSet == nil {
		// Previously this returned (nil, nil), letting callers continue
		// with a nil set.
		log.Printf("Set %s does not exist, cannot proceed", setName)
		return nil, nftError{Message: "Set does not exist"}
	}

	log.Printf("Found set %s", foundSet.Name)
	return foundSet, nil
}
// getNftSetElements returns the key of every element currently in set,
// rendered with net.IP.String, and logs each key with its index. The
// returned slice is nil when the set has no elements.
//
// NOTE(review): handleNft stores networks as a start key plus an IntervalEnd
// key. If GetSetElements reports those end markers too, they are listed here
// like ordinary addresses and no prefix length is shown. Confirm what API
// consumers expect.
func getNftSetElements(nft *nftables.Conn, set *nftables.Set) ([]string, error) {
	elements, err := nft.GetSetElements(set)
	if err != nil {
		return nil, err
	}

	var addresses []string
	for idx := range elements {
		addr := net.IP(elements[idx].Key)
		log.Printf("Element %d: %s", idx, addr)
		addresses = append(addresses, addr.String())
	}
	return addresses, nil
}
// containsNftSetElement reports whether every element in element is already
// present in set. An element matches only if both its key bytes and its
// IntervalEnd flag match.
//
// Matching IntervalEnd matters for interval sets. handleNft stores a network
// as a start key plus an IntervalEnd key, and an end marker can carry the
// same key as the start of an adjacent network. Requiring every element to
// match, not just the first, means a larger network is not reported as
// present merely because a smaller one with the same first address exists.
//
// An empty element slice is reported as not contained rather than panicking.
// Errors from listing the set's elements are returned unchanged.
func containsNftSetElement(nft *nftables.Conn, set *nftables.Set, element []nftables.SetElement) (bool, error) {
	if len(element) == 0 {
		return false, nil
	}

	existingElements, err := nft.GetSetElements(set)
	if err != nil {
		return false, err
	}

	for _, wanted := range element {
		found := false
		for _, existing := range existingElements {
			if existing.IntervalEnd == wanted.IntervalEnd && bytes.Equal(existing.Key, wanted.Key) {
				found = true
				break
			}
		}
		if !found {
			return false, nil
		}
	}

	log.Printf("Existing element found %v", element[0].Key)
	return true, nil
}