nftables/internal/nf2go/main.go

300 lines
7.3 KiB
Go

package main
import (
"bytes"
"context"
"fmt"
"io"
"io/ioutil"
"log"
"os"
"os/exec"
"path/filepath"
"regexp"
"runtime"
"strings"
"time"
"github.com/google/go-cmp/cmp"
"github.com/google/nftables"
"github.com/vishvananda/netns"
)
func main() {
args := os.Args[1:]
if len(args) != 1 {
log.Fatalf("need to specify the file to read the \"nft list ruleset\" dump")
}
filename := args[0]
runtime.LockOSThread()
defer runtime.UnlockOSThread()
// Create a new network namespace
ns, err := netns.New()
if err != nil {
log.Fatalf("netns.New() failed: %v", err)
}
n, err := nftables.New(nftables.WithNetNSFd(int(ns)))
if err != nil {
log.Fatalf("nftables.New() failed: %v", err)
}
scriptOutput, err := applyNFTRuleset(filename)
if err != nil {
log.Fatalf("Failed to apply nftables script: %v\noutput:%s", err, scriptOutput)
}
var buf bytes.Buffer
// Helper function to print to the file
pf := func(format string, a ...interface{}) {
_, err := fmt.Fprintf(&buf, format, a...)
if err != nil {
log.Fatal(err)
}
}
pf("// Code generated by nft2go. DO NOT EDIT.\n")
pf("package main\n\n")
pf("import (\n")
pf("\t\"fmt\"\n")
pf("\t\"log\"\n")
pf("\t\"github.com/google/nftables\"\n")
pf("\t\"github.com/google/nftables/expr\"\n")
pf(")\n\n")
pf("func main() {\n")
pf("\tn, err:= nftables.New()\n")
pf("\tif err!= nil {\n")
pf("\t\tlog.Fatal(err)\n")
pf("\t}\n\n")
pf("\n")
pf("\tvar expressions []expr.Any\n")
pf("\tvar chain *nftables.Chain\n")
pf("\tvar table *nftables.Table\n")
tables, err := n.ListTables()
if err != nil {
log.Fatalf("ListTables failed: %v", err)
}
chains, err := n.ListChains()
if err != nil {
log.Fatal(err)
}
for _, table := range tables {
log.Printf("processing table: %s", table.Name)
pf("\ttable = n.AddTable(&nftables.Table{Family: %s,Name: \"%s\"})\n", TableFamilyString(table.Family), table.Name)
for _, chain := range chains {
if chain.Table.Name != table.Name {
continue
}
sets, err := n.GetSets(table)
if err != nil {
log.Fatal(err)
}
for _, set := range sets {
// TODO datatype and the other options
pf("\tn.AddSet(&nftables.Set{\n")
pf("\t\tTable: table,\n")
pf("\t\tName: \"%s\",\n", set.Name)
pf("\t}, nil)\n")
}
pf("\tchain = n.AddChain(&nftables.Chain{Name: \"%s\", Table: table, Type: %s, Hooknum: %s, Priority: %s})\n",
chain.Name, ChainTypeString(chain.Type), ChainHookRef(chain.Hooknum), ChainPrioRef(chain.Priority))
rules, err := n.GetRules(table, chain)
if err != nil {
log.Fatal(err)
}
for _, rule := range rules {
pf("\texpressions = []expr.Any{\n")
for _, exp := range rule.Exprs {
pf("\t\t%#v,\n", exp)
}
pf("\t\t}\n")
pf("\tn.AddRule(&nftables.Rule{\n")
pf("\t\tTable: table,\n")
pf("\t\tChain: chain,\n")
pf("\t\tExprs: expressions,\n")
pf("\t})\n")
}
}
}
pf("\n\tif err:= n.Flush(); err!= nil {\n")
pf("\t\tlog.Fatal(err)\n")
pf("\t}\n\n")
pf("\tfmt.Println(\"nft ruleset applied.\")\n")
pf("}\n")
// Program nftables using your Go code
if err := flushNFTRuleset(); err != nil {
log.Fatalf("Failed to flush nftables ruleset: %v", err)
}
// Create the output file
// Create a temporary directory
tempDir, err := ioutil.TempDir("", "nftables_gen")
if err != nil {
log.Fatal(err)
}
defer os.RemoveAll(tempDir) // Clean up the temporary directory
// Create the temporary Go file
tempGoFile := filepath.Join(tempDir, "nftables_recreate.go")
f, err := os.Create(tempGoFile)
if err != nil {
log.Fatal(err)
}
defer f.Close()
mw := io.MultiWriter(f, os.Stdout)
buf.WriteTo(mw)
// Format the generated code
log.Printf("formating file: %s", tempGoFile)
cmd := exec.Command("gofmt", "-w", "-s", tempGoFile)
output, err := cmd.CombinedOutput()
if err != nil {
log.Fatalf("gofmt error: %v\nOutput: %s", err, output)
}
// Run the generated code
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
log.Printf("executing file: %s", tempGoFile)
cmd = exec.CommandContext(ctx, "go", "run", tempGoFile)
output, err = cmd.CombinedOutput()
if err != nil {
log.Fatalf("Execution error: %v\nOutput: %s", err, output)
}
// Retrieve nftables state using nft
log.Printf("obtain current ruleset: %s", tempGoFile)
actualOutput, err := listNFTRuleset()
if err != nil {
log.Fatalf("Failed to list nftables ruleset: %v\noutput:%s", err, actualOutput)
}
expectedOutput, err := os.ReadFile(filename)
if err != nil {
log.Fatalf("Failed to list nftables ruleset: %v\noutput:%s", err, actualOutput)
}
if !compareMultilineStringsIgnoreIndentation(string(expectedOutput), actualOutput) {
log.Printf("Expected output:\n%s", string(expectedOutput))
log.Printf("Actual output:\n%s", actualOutput)
log.Fatalf("nftables ruleset mismatch:\n%s", cmp.Diff(string(expectedOutput), actualOutput))
}
if err := flushNFTRuleset(); err != nil {
log.Fatalf("Failed to flush nftables ruleset: %v", err)
}
}
// applyNFTRuleset loads the nft script at scriptPath with netlink debugging
// enabled and returns nft's combined stdout/stderr. Successful output is
// trimmed of surrounding whitespace; on failure the raw output is returned
// together with the error so the caller can show it.
func applyNFTRuleset(scriptPath string) (string, error) {
	raw, runErr := exec.Command("nft", "--debug=netlink", "-f", scriptPath).CombinedOutput()
	text := string(raw)
	if runErr != nil {
		return text, runErr
	}
	return strings.TrimSpace(text), nil
}
// listNFTRuleset runs "nft list ruleset" and returns its combined stdout and
// stderr. On success the text is trimmed of surrounding whitespace; on failure
// it is returned untouched together with the error.
func listNFTRuleset() (string, error) {
	output, runErr := exec.Command("nft", "list", "ruleset").CombinedOutput()
	if runErr == nil {
		return strings.TrimSpace(string(output)), nil
	}
	return string(output), runErr
}
// flushNFTRuleset runs "nft flush ruleset". nft's combined output is captured
// and included in the returned error, so a failure explains itself instead of
// reporting only an exit status.
func flushNFTRuleset() error {
	out, err := exec.Command("nft", "flush", "ruleset").CombinedOutput()
	if err != nil {
		return fmt.Errorf("nft flush ruleset: %w: %s", err, bytes.TrimSpace(out))
	}
	return nil
}
// ChainHookRef returns the Go source expression that the generated program
// uses for a chain's Hooknum field.
//
// A nil hookNum (a chain without a hook) yields "nil", so the generated chain
// is not turned into a prerouting base chain. Hook numbers 0-4 map to the
// named nftables.ChainHook* variables. Any other number is emitted as
// nftables.ChainHookRef(N), so the generated value always equals the listed
// one and never becomes an empty, uncompilable expression.
// NOTE(review): the library's ChainHookIngress/ChainHookEgress appear to be
// the netdev hook numbers (0 and 1), so mapping 5/6 to them by name would
// change the hook; confirm against github.com/google/nftables.
func ChainHookRef(hookNum *nftables.ChainHook) string {
	if hookNum == nil {
		return "nil"
	}
	switch h := uint32(*hookNum); h {
	case 0:
		return "nftables.ChainHookPrerouting"
	case 1:
		return "nftables.ChainHookInput"
	case 2:
		return "nftables.ChainHookForward"
	case 3:
		return "nftables.ChainHookOutput"
	case 4:
		return "nftables.ChainHookPostrouting"
	default:
		return fmt.Sprintf("nftables.ChainHookRef(%d)", h)
	}
}
// ChainPrioRef returns the Go source expression the generated program uses
// for a chain's Priority field. It is "nil" when the chain has no priority,
// matching a chain without a hook, instead of a fabricated priority 0.
// Otherwise it is an nftables.ChainPriorityRef call carrying the numeric
// priority.
func ChainPrioRef(priority *nftables.ChainPriority) string {
	if priority == nil {
		return "nil"
	}
	return fmt.Sprintf("nftables.ChainPriorityRef(%d)", int32(*priority))
}
// ChainTypeString returns the Go source expression the generated program uses
// for a chain's Type field. Known types map to the named nftables.ChainType*
// constants. Anything else, including an empty type, is emitted verbatim as
// nftables.ChainType("...") rather than being silently turned into a filter
// chain.
func ChainTypeString(chaintype nftables.ChainType) string {
	switch chaintype {
	case nftables.ChainTypeFilter:
		return "nftables.ChainTypeFilter"
	case nftables.ChainTypeRoute:
		return "nftables.ChainTypeRoute"
	case nftables.ChainTypeNAT:
		return "nftables.ChainTypeNAT"
	default:
		return fmt.Sprintf("nftables.ChainType(%q)", string(chaintype))
	}
}
// TableFamilyString returns the Go source expression the generated program
// uses for a table's Family field. Known families map to the named
// nftables.TableFamily* constants. Any other value is emitted numerically as
// nftables.TableFamily(N), so it is preserved instead of being silently
// rewritten to IPv4.
func TableFamilyString(family nftables.TableFamily) string {
	switch family {
	case nftables.TableFamilyUnspecified:
		return "nftables.TableFamilyUnspecified"
	case nftables.TableFamilyINet:
		return "nftables.TableFamilyINet"
	case nftables.TableFamilyIPv4:
		return "nftables.TableFamilyIPv4"
	case nftables.TableFamilyIPv6:
		return "nftables.TableFamilyIPv6"
	case nftables.TableFamilyARP:
		return "nftables.TableFamilyARP"
	case nftables.TableFamilyNetdev:
		return "nftables.TableFamilyNetdev"
	case nftables.TableFamilyBridge:
		return "nftables.TableFamilyBridge"
	default:
		return fmt.Sprintf("nftables.TableFamily(%d)", family)
	}
}
// indentRE matches the whitespace at the start of every line. Because \s also
// matches '\n', a run of blank lines is swallowed along with the indentation.
var indentRE = regexp.MustCompile(`(?m)^\s+`)

// compareMultilineStringsIgnoreIndentation reports whether str1 and str2 are
// equal once leading whitespace on every line, blank lines, and whitespace
// surrounding the whole text (e.g. a trailing newline) are ignored.
func compareMultilineStringsIgnoreIndentation(str1, str2 string) bool {
	normalize := func(s string) string {
		return strings.TrimSpace(indentRE.ReplaceAllString(s, ""))
	}
	return normalize(str1) == normalize(str2)
}