/* * This file is part of nftables-http-api. * Copyright (C) 2024 Georg Pfuetzenreuter * * The nftables-http-api program is free software: you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version. * This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details. * You should have received a copy of the GNU General Public License along with this program. If not, see . */ package main import ( "bytes" "errors" "github.com/google/nftables" "log" "net" ) type nftError struct { Message string } func (nfterr nftError) Error() string { return nfterr.Message } func handleNft(task string, givenSet string, givenAddress string) (any, error) { nft, err := nftables.New() if err != nil { log.Println("handleNft():", err) return "", err } set, err := getNftSet(nft, givenSet) if err != nil { return nil, err } var element []nftables.SetElement if task != "get" { address, _ := parseIPAddress(givenAddress) if address == nil { return nil, errors.New("invalid address") } element = []nftables.SetElement{ { Key: []byte(address), }, } } var retmsg string switch task { case "get": nftResult, err := getNftSetElements(nft, set) if err != nil { return nil, err } return nftResult, nil case "add": contains, err := containsNftSetElement(nft, set, element) if err != nil { return nil, err } if contains { return "already", nil } else { err := nft.SetAddElements(set, element) if err != nil { log.Println("handleNft() add failure:", err) return nil, err } retmsg = "added" } } fErr := nft.Flush() if fErr != nil { log.Println("nftablesHandler: failed to save changes: %w", fErr) return nil, fErr } return retmsg, nil } func getNftTable(nft *nftables.Conn) (*nftables.Table, error) { targetTable := "filter" // TODO: make table configurable or smarter foundTables, err := nft.ListTables() if err != nil { log.Printf("getNftTable(): %s", err) return nil, err } exists := false var table *nftables.Table for _, foundTable := range foundTables { if foundTable.Name == targetTable { exists = true table = foundTable break } } if !exists { log.Printf("Table %s does not exist, cannot proceed", targetTable) return nil, nftError{Message: "Table does not exist"} } return table, nil } func getNftSet(nft *nftables.Conn, setName string) (*nftables.Set, error) { foundTable, err := getNftTable(nft) if err != nil { return nil, err } foundSet, err := nft.GetSetByName(foundTable, setName) if err != nil || foundSet == nil { log.Printf("Set lookup for %s failed: %s", setName, err) return nil, err } log.Printf("Found set %s", foundSet.Name) return foundSet, nil } func getNftSetElements(nft *nftables.Conn, set *nftables.Set) ([]string, error) { setElements, err := nft.GetSetElements(set) if err != nil { return nil, err } var returnElements []string for i, element := range setElements { ip := net.IP(element.Key) log.Printf("Element %d: %s", i, ip) returnElements = append(returnElements, ip.String()) } return returnElements, nil } func containsNftSetElement(nft *nftables.Conn, set *nftables.Set, element []nftables.SetElement) (bool, error) { existingElements, err := nft.GetSetElements(set) if err != nil { return false, err } for _, existingElement := range existingElements { if bytes.Equal(existingElement.Key, element[0].Key) { log.Printf("Existing element found %v", existingElement.Key) return true, nil } } return false, nil }