Add POST functionality to add set elements
Signed-off-by: Georg Pfuetzenreuter <mail@georg-pfuetzenreuter.net>
This commit is contained in:
parent
73c788181e
commit
f2821a9293
88
nft.go
88
nft.go
@ -12,6 +12,8 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
|
"errors"
|
||||||
"github.com/google/nftables"
|
"github.com/google/nftables"
|
||||||
"log"
|
"log"
|
||||||
"net"
|
"net"
|
||||||
@ -25,22 +27,67 @@ func (nfterr nftError) Error() string {
|
|||||||
return nfterr.Message
|
return nfterr.Message
|
||||||
}
|
}
|
||||||
|
|
||||||
func handleNft(task string, set string) (any, error) {
|
func handleNft(task string, givenSet string, givenAddress string) (any, error) {
|
||||||
nft, err := nftables.New()
|
nft, err := nftables.New()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Println("handleNft():", err)
|
log.Println("handleNft():", err)
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
if task == "get" {
|
set, err := getNftSet(nft, givenSet)
|
||||||
nftResult, err := getNftSetElements(nft, set)
|
if err != nil {
|
||||||
if err == nil {
|
|
||||||
return nftResult, nil
|
|
||||||
}
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return "", nil
|
var element []nftables.SetElement
|
||||||
|
|
||||||
|
if task != "get" {
|
||||||
|
address, _ := parseIPAddress(givenAddress)
|
||||||
|
if address == nil {
|
||||||
|
return nil, errors.New("invalid address")
|
||||||
|
}
|
||||||
|
|
||||||
|
element = []nftables.SetElement{
|
||||||
|
{
|
||||||
|
Key: []byte(address),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var retmsg string
|
||||||
|
|
||||||
|
switch task {
|
||||||
|
case "get":
|
||||||
|
nftResult, err := getNftSetElements(nft, set)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return nftResult, nil
|
||||||
|
|
||||||
|
case "add":
|
||||||
|
contains, err := containsNftSetElement(nft, set, element)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if contains {
|
||||||
|
return "already", nil
|
||||||
|
} else {
|
||||||
|
err := nft.SetAddElements(set, element)
|
||||||
|
if err != nil {
|
||||||
|
log.Println("handleNft() add failure:", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
retmsg = "added"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fErr := nft.Flush()
|
||||||
|
if fErr != nil {
|
||||||
|
log.Println("nftablesHandler: failed to save changes: %w", fErr)
|
||||||
|
return nil, fErr
|
||||||
|
}
|
||||||
|
|
||||||
|
return retmsg, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func getNftTable(nft *nftables.Conn) (*nftables.Table, error) {
|
func getNftTable(nft *nftables.Conn) (*nftables.Table, error) {
|
||||||
@ -77,7 +124,7 @@ func getNftSet(nft *nftables.Conn, setName string) (*nftables.Set, error) {
|
|||||||
}
|
}
|
||||||
foundSet, err := nft.GetSetByName(foundTable, setName)
|
foundSet, err := nft.GetSetByName(foundTable, setName)
|
||||||
if err != nil || foundSet == nil {
|
if err != nil || foundSet == nil {
|
||||||
log.Printf("Set lookup for %s failed, cannot proceed: %s", setName, err)
|
log.Printf("Set lookup for %s failed: %s", setName, err)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
log.Printf("Found set %s", foundSet.Name)
|
log.Printf("Found set %s", foundSet.Name)
|
||||||
@ -85,13 +132,7 @@ func getNftSet(nft *nftables.Conn, setName string) (*nftables.Set, error) {
|
|||||||
return foundSet, nil
|
return foundSet, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func getNftSetElements(nft *nftables.Conn, setName string) ([]string, error) {
|
func getNftSetElements(nft *nftables.Conn, set *nftables.Set) ([]string, error) {
|
||||||
set, err := getNftSet(nft, setName)
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("Could not retrieve set elements")
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
setElements, err := nft.GetSetElements(set)
|
setElements, err := nft.GetSetElements(set)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -107,3 +148,20 @@ func getNftSetElements(nft *nftables.Conn, setName string) ([]string, error) {
|
|||||||
|
|
||||||
return returnElements, nil
|
return returnElements, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func containsNftSetElement(nft *nftables.Conn, set *nftables.Set, element []nftables.SetElement) (bool, error) {
|
||||||
|
|
||||||
|
existingElements, err := nft.GetSetElements(set)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, existingElement := range existingElements {
|
||||||
|
if bytes.Equal(existingElement.Key, element[0].Key) {
|
||||||
|
log.Printf("Existing element found %v", existingElement.Key)
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
@ -12,6 +12,7 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/json"
|
||||||
"flag"
|
"flag"
|
||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
"gopkg.in/yaml.v3"
|
"gopkg.in/yaml.v3"
|
||||||
@ -66,6 +67,10 @@ func (authMiddleWare *authMiddleWareMap) Middleware(next http.Handler) http.Hand
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type addressPayload struct {
|
||||||
|
Address string
|
||||||
|
}
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
flag.Parse()
|
flag.Parse()
|
||||||
log.Print("Booting ...")
|
log.Print("Booting ...")
|
||||||
@ -86,7 +91,7 @@ func main() {
|
|||||||
log.Print("Listening on ", listen)
|
log.Print("Listening on ", listen)
|
||||||
|
|
||||||
router := mux.NewRouter()
|
router := mux.NewRouter()
|
||||||
router.HandleFunc("/set/{set}", handleSetRoute).Methods("GET")
|
router.HandleFunc("/set/{set}", handleSetRoute).Methods("GET", "POST")
|
||||||
|
|
||||||
authMiddleWare := authMiddleWareMap{make(map[string]string)}
|
authMiddleWare := authMiddleWareMap{make(map[string]string)}
|
||||||
router.Use(authMiddleWare.Middleware)
|
router.Use(authMiddleWare.Middleware)
|
||||||
@ -101,13 +106,37 @@ func handleSetRoute(w http.ResponseWriter, r *http.Request) {
|
|||||||
set := params["set"]
|
set := params["set"]
|
||||||
log.Printf("Processing authorized %s request from %s for set %s", method, r.RemoteAddr, set)
|
log.Printf("Processing authorized %s request from %s for set %s", method, r.RemoteAddr, set)
|
||||||
|
|
||||||
if method == "GET" {
|
switch method {
|
||||||
nftResult, err := handleNft("get", set)
|
case "GET":
|
||||||
if err != nil {
|
nftResult, err := handleNft("get", set, "")
|
||||||
doReturn(w, http.StatusInternalServerError, "nftables failure")
|
if err != nil || nftResult == nil {
|
||||||
|
doReturn(w, http.StatusInternalServerError, err.Error())
|
||||||
|
return
|
||||||
}
|
}
|
||||||
if nftResult != nil {
|
doReturnSet(w, http.StatusOK, "", nftResult.([]string))
|
||||||
doReturnSet(w, http.StatusOK, "", nftResult.([]string))
|
|
||||||
|
case "POST":
|
||||||
|
var payload addressPayload
|
||||||
|
decErr := json.NewDecoder(r.Body).Decode(&payload)
|
||||||
|
if decErr != nil {
|
||||||
|
doReturn(w, http.StatusBadRequest, decErr.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
nftResult, err := handleNft("add", set, payload.Address)
|
||||||
|
if err != nil || nftResult == nil {
|
||||||
|
doReturn(w, http.StatusInternalServerError, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
switch nftResult {
|
||||||
|
case "already":
|
||||||
|
doReturn(w, http.StatusOK, "already exists")
|
||||||
|
case "added":
|
||||||
|
doReturn(w, http.StatusCreated, "ok")
|
||||||
|
case nil:
|
||||||
|
doReturn(w, http.StatusInternalServerError, "failure")
|
||||||
|
default:
|
||||||
|
doReturn(w, http.StatusInternalServerError, "unhandled result")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
2
utils.go
2
utils.go
@ -32,7 +32,7 @@ type ResponseSet struct {
|
|||||||
|
|
||||||
func doReturn(w http.ResponseWriter, status int, text string) {
|
func doReturn(w http.ResponseWriter, status int, text string) {
|
||||||
var response any
|
var response any
|
||||||
if status == http.StatusOK {
|
if status == http.StatusOK || status == http.StatusCreated {
|
||||||
response = Response{RResult: text}
|
response = Response{RResult: text}
|
||||||
} else {
|
} else {
|
||||||
response = Response{RError: text}
|
response = Response{RError: text}
|
||||||
|
Reference in New Issue
Block a user