Add POST functionality to add set elements

Signed-off-by: Georg Pfuetzenreuter <mail@georg-pfuetzenreuter.net>
This commit is contained in:
Georg Pfuetzenreuter 2024-08-30 20:12:17 +02:00
parent 73c788181e
commit f2821a9293
Signed by: Georg
GPG Key ID: 1ED2F138E7E6FF57
3 changed files with 110 additions and 23 deletions

88
nft.go
View File

@ -12,6 +12,8 @@
package main
import (
"bytes"
"errors"
"github.com/google/nftables"
"log"
"net"
@ -25,22 +27,67 @@ func (nfterr nftError) Error() string {
return nfterr.Message
}
func handleNft(task string, set string) (any, error) {
func handleNft(task string, givenSet string, givenAddress string) (any, error) {
nft, err := nftables.New()
if err != nil {
log.Println("handleNft():", err)
return "", err
}
if task == "get" {
nftResult, err := getNftSetElements(nft, set)
if err == nil {
return nftResult, nil
}
set, err := getNftSet(nft, givenSet)
if err != nil {
return nil, err
}
return "", nil
var element []nftables.SetElement
if task != "get" {
address, _ := parseIPAddress(givenAddress)
if address == nil {
return nil, errors.New("invalid address")
}
element = []nftables.SetElement{
{
Key: []byte(address),
},
}
}
var retmsg string
switch task {
case "get":
nftResult, err := getNftSetElements(nft, set)
if err != nil {
return nil, err
}
return nftResult, nil
case "add":
contains, err := containsNftSetElement(nft, set, element)
if err != nil {
return nil, err
}
if contains {
return "already", nil
} else {
err := nft.SetAddElements(set, element)
if err != nil {
log.Println("handleNft() add failure:", err)
return nil, err
}
retmsg = "added"
}
}
fErr := nft.Flush()
if fErr != nil {
log.Println("nftablesHandler: failed to save changes: %w", fErr)
return nil, fErr
}
return retmsg, nil
}
func getNftTable(nft *nftables.Conn) (*nftables.Table, error) {
@ -77,7 +124,7 @@ func getNftSet(nft *nftables.Conn, setName string) (*nftables.Set, error) {
}
foundSet, err := nft.GetSetByName(foundTable, setName)
if err != nil || foundSet == nil {
log.Printf("Set lookup for %s failed, cannot proceed: %s", setName, err)
log.Printf("Set lookup for %s failed: %s", setName, err)
return nil, err
}
log.Printf("Found set %s", foundSet.Name)
@ -85,13 +132,7 @@ func getNftSet(nft *nftables.Conn, setName string) (*nftables.Set, error) {
return foundSet, nil
}
func getNftSetElements(nft *nftables.Conn, setName string) ([]string, error) {
set, err := getNftSet(nft, setName)
if err != nil {
log.Printf("Could not retrieve set elements")
return nil, err
}
func getNftSetElements(nft *nftables.Conn, set *nftables.Set) ([]string, error) {
setElements, err := nft.GetSetElements(set)
if err != nil {
return nil, err
@ -107,3 +148,20 @@ func getNftSetElements(nft *nftables.Conn, setName string) ([]string, error) {
return returnElements, nil
}
func containsNftSetElement(nft *nftables.Conn, set *nftables.Set, element []nftables.SetElement) (bool, error) {
existingElements, err := nft.GetSetElements(set)
if err != nil {
return false, err
}
for _, existingElement := range existingElements {
if bytes.Equal(existingElement.Key, element[0].Key) {
log.Printf("Existing element found %v", existingElement.Key)
return true, nil
}
}
return false, nil
}

View File

@ -12,6 +12,7 @@
package main
import (
"encoding/json"
"flag"
"github.com/gorilla/mux"
"gopkg.in/yaml.v3"
@ -66,6 +67,10 @@ func (authMiddleWare *authMiddleWareMap) Middleware(next http.Handler) http.Hand
})
}
type addressPayload struct {
Address string
}
func main() {
flag.Parse()
log.Print("Booting ...")
@ -86,7 +91,7 @@ func main() {
log.Print("Listening on ", listen)
router := mux.NewRouter()
router.HandleFunc("/set/{set}", handleSetRoute).Methods("GET")
router.HandleFunc("/set/{set}", handleSetRoute).Methods("GET", "POST")
authMiddleWare := authMiddleWareMap{make(map[string]string)}
router.Use(authMiddleWare.Middleware)
@ -101,13 +106,37 @@ func handleSetRoute(w http.ResponseWriter, r *http.Request) {
set := params["set"]
log.Printf("Processing authorized %s request from %s for set %s", method, r.RemoteAddr, set)
if method == "GET" {
nftResult, err := handleNft("get", set)
if err != nil {
doReturn(w, http.StatusInternalServerError, "nftables failure")
switch method {
case "GET":
nftResult, err := handleNft("get", set, "")
if err != nil || nftResult == nil {
doReturn(w, http.StatusInternalServerError, err.Error())
return
}
if nftResult != nil {
doReturnSet(w, http.StatusOK, "", nftResult.([]string))
case "POST":
var payload addressPayload
decErr := json.NewDecoder(r.Body).Decode(&payload)
if decErr != nil {
doReturn(w, http.StatusBadRequest, decErr.Error())
return
}
nftResult, err := handleNft("add", set, payload.Address)
if err != nil || nftResult == nil {
doReturn(w, http.StatusInternalServerError, err.Error())
return
}
switch nftResult {
case "already":
doReturn(w, http.StatusOK, "already exists")
case "added":
doReturn(w, http.StatusCreated, "ok")
case nil:
doReturn(w, http.StatusInternalServerError, "failure")
default:
doReturn(w, http.StatusInternalServerError, "unhandled result")
}
}
}

View File

@ -32,7 +32,7 @@ type ResponseSet struct {
func doReturn(w http.ResponseWriter, status int, text string) {
var response any
if status == http.StatusOK {
if status == http.StatusOK || status == http.StatusCreated {
response = Response{RResult: text}
} else {
response = Response{RError: text}