168 lines
3.8 KiB
Go
168 lines
3.8 KiB
Go
/*
|
|
* This file is part of nftables-http-api.
|
|
* Copyright (C) 2024 Georg Pfuetzenreuter <mail@georg-pfuetzenreuter.net>
|
|
*
|
|
* The nftables-http-api program is free software: you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version.
|
|
|
|
* This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details.
|
|
|
|
* You should have received a copy of the GNU General Public License along with this program. If not, see <https://www.gnu.org/licenses/>.
|
|
*/
|
|
|
|
package main
|
|
|
|
import (
|
|
"bytes"
|
|
"errors"
|
|
"github.com/google/nftables"
|
|
"log"
|
|
"net"
|
|
)
|
|
|
|
// nftError is an error carrying a plain message. It is used for lookups that
// fail without an underlying library error (for example, a missing table).
type nftError struct {
	Message string
}

// Error implements the error interface by returning the stored message verbatim.
func (e nftError) Error() string {
	return e.Message
}
|
|
|
|
func handleNft(task string, givenSet string, givenAddress string) (any, error) {
|
|
nft, err := nftables.New()
|
|
if err != nil {
|
|
log.Println("handleNft():", err)
|
|
return "", err
|
|
}
|
|
|
|
set, err := getNftSet(nft, givenSet)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
var element []nftables.SetElement
|
|
|
|
if task != "get" {
|
|
address, _ := parseIPAddress(givenAddress)
|
|
if address == nil {
|
|
return nil, errors.New("invalid address")
|
|
}
|
|
|
|
element = []nftables.SetElement{
|
|
{
|
|
Key: []byte(address),
|
|
},
|
|
}
|
|
}
|
|
|
|
var retmsg string
|
|
|
|
switch task {
|
|
case "get":
|
|
nftResult, err := getNftSetElements(nft, set)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return nftResult, nil
|
|
|
|
case "add":
|
|
contains, err := containsNftSetElement(nft, set, element)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if contains {
|
|
return "already", nil
|
|
} else {
|
|
err := nft.SetAddElements(set, element)
|
|
if err != nil {
|
|
log.Println("handleNft() add failure:", err)
|
|
return nil, err
|
|
}
|
|
retmsg = "added"
|
|
}
|
|
}
|
|
|
|
fErr := nft.Flush()
|
|
if fErr != nil {
|
|
log.Println("nftablesHandler: failed to save changes: %w", fErr)
|
|
return nil, fErr
|
|
}
|
|
|
|
return retmsg, nil
|
|
}
|
|
|
|
func getNftTable(nft *nftables.Conn) (*nftables.Table, error) {
|
|
targetTable := "filter" // TODO: make table configurable or smarter
|
|
|
|
foundTables, err := nft.ListTables()
|
|
if err != nil {
|
|
log.Printf("getNftTable(): %s", err)
|
|
return nil, err
|
|
}
|
|
|
|
exists := false
|
|
var table *nftables.Table
|
|
for _, foundTable := range foundTables {
|
|
if foundTable.Name == targetTable {
|
|
exists = true
|
|
table = foundTable
|
|
break
|
|
}
|
|
}
|
|
|
|
if !exists {
|
|
log.Printf("Table %s does not exist, cannot proceed", targetTable)
|
|
return nil, nftError{Message: "Table does not exist"}
|
|
}
|
|
|
|
return table, nil
|
|
}
|
|
|
|
func getNftSet(nft *nftables.Conn, setName string) (*nftables.Set, error) {
|
|
foundTable, err := getNftTable(nft)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
foundSet, err := nft.GetSetByName(foundTable, setName)
|
|
if err != nil || foundSet == nil {
|
|
log.Printf("Set lookup for %s failed: %s", setName, err)
|
|
return nil, err
|
|
}
|
|
log.Printf("Found set %s", foundSet.Name)
|
|
|
|
return foundSet, nil
|
|
}
|
|
|
|
func getNftSetElements(nft *nftables.Conn, set *nftables.Set) ([]string, error) {
|
|
setElements, err := nft.GetSetElements(set)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
var returnElements []string
|
|
|
|
for i, element := range setElements {
|
|
ip := net.IP(element.Key)
|
|
log.Printf("Element %d: %s", i, ip)
|
|
returnElements = append(returnElements, ip.String())
|
|
}
|
|
|
|
return returnElements, nil
|
|
}
|
|
|
|
func containsNftSetElement(nft *nftables.Conn, set *nftables.Set, element []nftables.SetElement) (bool, error) {
|
|
|
|
existingElements, err := nft.GetSetElements(set)
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
|
|
for _, existingElement := range existingElements {
|
|
if bytes.Equal(existingElement.Key, element[0].Key) {
|
|
log.Printf("Existing element found %v", existingElement.Key)
|
|
return true, nil
|
|
}
|
|
}
|
|
|
|
return false, nil
|
|
}
|