Add POST functionality to add set elements
Signed-off-by: Georg Pfuetzenreuter <mail@georg-pfuetzenreuter.net>
This commit is contained in:
parent
73c788181e
commit
f2821a9293
88
nft.go
88
nft.go
@ -12,6 +12,8 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"github.com/google/nftables"
|
||||
"log"
|
||||
"net"
|
||||
@ -25,22 +27,67 @@ func (nfterr nftError) Error() string {
|
||||
return nfterr.Message
|
||||
}
|
||||
|
||||
func handleNft(task string, set string) (any, error) {
|
||||
func handleNft(task string, givenSet string, givenAddress string) (any, error) {
|
||||
nft, err := nftables.New()
|
||||
if err != nil {
|
||||
log.Println("handleNft():", err)
|
||||
return "", err
|
||||
}
|
||||
|
||||
if task == "get" {
|
||||
nftResult, err := getNftSetElements(nft, set)
|
||||
if err == nil {
|
||||
return nftResult, nil
|
||||
}
|
||||
set, err := getNftSet(nft, givenSet)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return "", nil
|
||||
var element []nftables.SetElement
|
||||
|
||||
if task != "get" {
|
||||
address, _ := parseIPAddress(givenAddress)
|
||||
if address == nil {
|
||||
return nil, errors.New("invalid address")
|
||||
}
|
||||
|
||||
element = []nftables.SetElement{
|
||||
{
|
||||
Key: []byte(address),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
var retmsg string
|
||||
|
||||
switch task {
|
||||
case "get":
|
||||
nftResult, err := getNftSetElements(nft, set)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return nftResult, nil
|
||||
|
||||
case "add":
|
||||
contains, err := containsNftSetElement(nft, set, element)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if contains {
|
||||
return "already", nil
|
||||
} else {
|
||||
err := nft.SetAddElements(set, element)
|
||||
if err != nil {
|
||||
log.Println("handleNft() add failure:", err)
|
||||
return nil, err
|
||||
}
|
||||
retmsg = "added"
|
||||
}
|
||||
}
|
||||
|
||||
fErr := nft.Flush()
|
||||
if fErr != nil {
|
||||
log.Println("nftablesHandler: failed to save changes: %w", fErr)
|
||||
return nil, fErr
|
||||
}
|
||||
|
||||
return retmsg, nil
|
||||
}
|
||||
|
||||
func getNftTable(nft *nftables.Conn) (*nftables.Table, error) {
|
||||
@ -77,7 +124,7 @@ func getNftSet(nft *nftables.Conn, setName string) (*nftables.Set, error) {
|
||||
}
|
||||
foundSet, err := nft.GetSetByName(foundTable, setName)
|
||||
if err != nil || foundSet == nil {
|
||||
log.Printf("Set lookup for %s failed, cannot proceed: %s", setName, err)
|
||||
log.Printf("Set lookup for %s failed: %s", setName, err)
|
||||
return nil, err
|
||||
}
|
||||
log.Printf("Found set %s", foundSet.Name)
|
||||
@ -85,13 +132,7 @@ func getNftSet(nft *nftables.Conn, setName string) (*nftables.Set, error) {
|
||||
return foundSet, nil
|
||||
}
|
||||
|
||||
func getNftSetElements(nft *nftables.Conn, setName string) ([]string, error) {
|
||||
set, err := getNftSet(nft, setName)
|
||||
if err != nil {
|
||||
log.Printf("Could not retrieve set elements")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
func getNftSetElements(nft *nftables.Conn, set *nftables.Set) ([]string, error) {
|
||||
setElements, err := nft.GetSetElements(set)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -107,3 +148,20 @@ func getNftSetElements(nft *nftables.Conn, setName string) ([]string, error) {
|
||||
|
||||
return returnElements, nil
|
||||
}
|
||||
|
||||
func containsNftSetElement(nft *nftables.Conn, set *nftables.Set, element []nftables.SetElement) (bool, error) {
|
||||
|
||||
existingElements, err := nft.GetSetElements(set)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
for _, existingElement := range existingElements {
|
||||
if bytes.Equal(existingElement.Key, element[0].Key) {
|
||||
log.Printf("Existing element found %v", existingElement.Key)
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
@ -12,6 +12,7 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"flag"
|
||||
"github.com/gorilla/mux"
|
||||
"gopkg.in/yaml.v3"
|
||||
@ -66,6 +67,10 @@ func (authMiddleWare *authMiddleWareMap) Middleware(next http.Handler) http.Hand
|
||||
})
|
||||
}
|
||||
|
||||
type addressPayload struct {
|
||||
Address string
|
||||
}
|
||||
|
||||
func main() {
|
||||
flag.Parse()
|
||||
log.Print("Booting ...")
|
||||
@ -86,7 +91,7 @@ func main() {
|
||||
log.Print("Listening on ", listen)
|
||||
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/set/{set}", handleSetRoute).Methods("GET")
|
||||
router.HandleFunc("/set/{set}", handleSetRoute).Methods("GET", "POST")
|
||||
|
||||
authMiddleWare := authMiddleWareMap{make(map[string]string)}
|
||||
router.Use(authMiddleWare.Middleware)
|
||||
@ -101,13 +106,37 @@ func handleSetRoute(w http.ResponseWriter, r *http.Request) {
|
||||
set := params["set"]
|
||||
log.Printf("Processing authorized %s request from %s for set %s", method, r.RemoteAddr, set)
|
||||
|
||||
if method == "GET" {
|
||||
nftResult, err := handleNft("get", set)
|
||||
if err != nil {
|
||||
doReturn(w, http.StatusInternalServerError, "nftables failure")
|
||||
switch method {
|
||||
case "GET":
|
||||
nftResult, err := handleNft("get", set, "")
|
||||
if err != nil || nftResult == nil {
|
||||
doReturn(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
if nftResult != nil {
|
||||
doReturnSet(w, http.StatusOK, "", nftResult.([]string))
|
||||
doReturnSet(w, http.StatusOK, "", nftResult.([]string))
|
||||
|
||||
case "POST":
|
||||
var payload addressPayload
|
||||
decErr := json.NewDecoder(r.Body).Decode(&payload)
|
||||
if decErr != nil {
|
||||
doReturn(w, http.StatusBadRequest, decErr.Error())
|
||||
return
|
||||
}
|
||||
|
||||
nftResult, err := handleNft("add", set, payload.Address)
|
||||
if err != nil || nftResult == nil {
|
||||
doReturn(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
switch nftResult {
|
||||
case "already":
|
||||
doReturn(w, http.StatusOK, "already exists")
|
||||
case "added":
|
||||
doReturn(w, http.StatusCreated, "ok")
|
||||
case nil:
|
||||
doReturn(w, http.StatusInternalServerError, "failure")
|
||||
default:
|
||||
doReturn(w, http.StatusInternalServerError, "unhandled result")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
2
utils.go
2
utils.go
@ -32,7 +32,7 @@ type ResponseSet struct {
|
||||
|
||||
func doReturn(w http.ResponseWriter, status int, text string) {
|
||||
var response any
|
||||
if status == http.StatusOK {
|
||||
if status == http.StatusOK || status == http.StatusCreated {
|
||||
response = Response{RResult: text}
|
||||
} else {
|
||||
response = Response{RError: text}
|
||||
|
Reference in New Issue
Block a user