diff --git a/nft.go b/nft.go index e228985..583dc02 100644 --- a/nft.go +++ b/nft.go @@ -12,6 +12,8 @@ package main import ( + "bytes" + "errors" "github.com/google/nftables" "log" "net" @@ -25,22 +27,67 @@ func (nfterr nftError) Error() string { return nfterr.Message } -func handleNft(task string, set string) (any, error) { +func handleNft(task string, givenSet string, givenAddress string) (any, error) { nft, err := nftables.New() if err != nil { log.Println("handleNft():", err) return "", err } - if task == "get" { - nftResult, err := getNftSetElements(nft, set) - if err == nil { - return nftResult, nil - } + set, err := getNftSet(nft, givenSet) + if err != nil { return nil, err } - return "", nil + var element []nftables.SetElement + + if task != "get" { + address, _ := parseIPAddress(givenAddress) + if address == nil { + return nil, errors.New("invalid address") + } + + element = []nftables.SetElement{ + { + Key: []byte(address), + }, + } + } + + var retmsg string + + switch task { + case "get": + nftResult, err := getNftSetElements(nft, set) + if err != nil { + return nil, err + } + return nftResult, nil + + case "add": + contains, err := containsNftSetElement(nft, set, element) + if err != nil { + return nil, err + } + if contains { + return "already", nil + } else { + err := nft.SetAddElements(set, element) + if err != nil { + log.Println("handleNft() add failure:", err) + return nil, err + } + retmsg = "added" + } + } + + fErr := nft.Flush() + if fErr != nil { + log.Println("nftablesHandler: failed to save changes: %w", fErr) + return nil, fErr + } + + return retmsg, nil } func getNftTable(nft *nftables.Conn) (*nftables.Table, error) { @@ -77,7 +124,7 @@ func getNftSet(nft *nftables.Conn, setName string) (*nftables.Set, error) { } foundSet, err := nft.GetSetByName(foundTable, setName) if err != nil || foundSet == nil { - log.Printf("Set lookup for %s failed, cannot proceed: %s", setName, err) + log.Printf("Set lookup for %s failed: %s", setName, err) return nil, err } log.Printf("Found set %s", foundSet.Name) @@ -85,13 +132,7 @@ func getNftSet(nft *nftables.Conn, setName string) (*nftables.Set, error) { return foundSet, nil } -func getNftSetElements(nft *nftables.Conn, setName string) ([]string, error) { - set, err := getNftSet(nft, setName) - if err != nil { - log.Printf("Could not retrieve set elements") - return nil, err - } - +func getNftSetElements(nft *nftables.Conn, set *nftables.Set) ([]string, error) { setElements, err := nft.GetSetElements(set) if err != nil { return nil, err @@ -107,3 +148,20 @@ func getNftSetElements(nft *nftables.Conn, setName string) ([]string, error) { return returnElements, nil } + +func containsNftSetElement(nft *nftables.Conn, set *nftables.Set, element []nftables.SetElement) (bool, error) { + + existingElements, err := nft.GetSetElements(set) + if err != nil { + return false, err + } + + for _, existingElement := range existingElements { + if bytes.Equal(existingElement.Key, element[0].Key) { + log.Printf("Existing element found %v", existingElement.Key) + return true, nil + } + } + + return false, nil +} diff --git a/nftables-http-api.go b/nftables-http-api.go index 0d8ca37..cc50618 100644 --- a/nftables-http-api.go +++ b/nftables-http-api.go @@ -12,6 +12,7 @@ package main import ( + "encoding/json" "flag" "github.com/gorilla/mux" "gopkg.in/yaml.v3" @@ -66,6 +67,10 @@ func (authMiddleWare *authMiddleWareMap) Middleware(next http.Handler) http.Hand }) } +type addressPayload struct { + Address string +} + func main() { flag.Parse() log.Print("Booting ...") @@ -86,7 +91,7 @@ func main() { log.Print("Listening on ", listen) router := mux.NewRouter() - router.HandleFunc("/set/{set}", handleSetRoute).Methods("GET") + router.HandleFunc("/set/{set}", handleSetRoute).Methods("GET", "POST") authMiddleWare := authMiddleWareMap{make(map[string]string)} router.Use(authMiddleWare.Middleware) @@ -101,13 +106,37 @@ func handleSetRoute(w http.ResponseWriter, r *http.Request) { set := params["set"] log.Printf("Processing authorized %s request from %s for set %s", method, r.RemoteAddr, set) - if method == "GET" { - nftResult, err := handleNft("get", set) - if err != nil { - doReturn(w, http.StatusInternalServerError, "nftables failure") + switch method { + case "GET": + nftResult, err := handleNft("get", set, "") + if err != nil || nftResult == nil { + doReturn(w, http.StatusInternalServerError, err.Error()) + return } - if nftResult != nil { - doReturnSet(w, http.StatusOK, "", nftResult.([]string)) + doReturnSet(w, http.StatusOK, "", nftResult.([]string)) + + case "POST": + var payload addressPayload + decErr := json.NewDecoder(r.Body).Decode(&payload) + if decErr != nil { + doReturn(w, http.StatusBadRequest, decErr.Error()) + return + } + + nftResult, err := handleNft("add", set, payload.Address) + if err != nil || nftResult == nil { + doReturn(w, http.StatusInternalServerError, err.Error()) + return + } + switch nftResult { + case "already": + doReturn(w, http.StatusOK, "already exists") + case "added": + doReturn(w, http.StatusCreated, "ok") + case nil: + doReturn(w, http.StatusInternalServerError, "failure") + default: + doReturn(w, http.StatusInternalServerError, "unhandled result") } } } diff --git a/utils.go b/utils.go index 4ff1a19..e37bdd9 100644 --- a/utils.go +++ b/utils.go @@ -32,7 +32,7 @@ type ResponseSet struct { func doReturn(w http.ResponseWriter, status int, text string) { var response any - if status == http.StatusOK { + if status == http.StatusOK || status == http.StatusCreated { response = Response{RResult: text} } else { response = Response{RError: text}