// Command nft2go converts an "nft list ruleset" dump into a Go program that
// recreates the same ruleset with github.com/google/nftables.
//
// The dump is loaded with nft into a freshly created network namespace, read
// back through the nftables library, and rendered as Go source. The generated
// program is then formatted with gofmt, run with "go run" against an emptied
// ruleset, and the output of "nft list ruleset" is compared with the original
// dump to check that the translation round-trips.
package main

import (
	"bytes"
	"context"
	"fmt"
	"log"
	"os"
	"os/exec"
	"path/filepath"
	"regexp"
	"runtime"
	"strings"
	"time"

	"github.com/google/go-cmp/cmp"
	"github.com/google/nftables"
	"github.com/vishvananda/netns"
)

func main() {
	args := os.Args[1:]
	if len(args) != 1 {
		log.Fatalf("need to specify the file to read the \"nft list ruleset\" dump")
	}
	if err := run(args[0]); err != nil {
		log.Fatal(err)
	}
}

// run performs the whole round trip for the dump stored in filename.
//
// It returns an error instead of calling log.Fatal so that the deferred
// cleanup (temporary directory, namespace handle, thread lock) always runs,
// including on a ruleset mismatch.
func run(filename string) error {
	// netns.New switches only the calling OS thread into the new namespace, so
	// keep this goroutine pinned to it. The nft/gofmt/go child processes
	// started below inherit the namespace of the thread that starts them.
	runtime.LockOSThread()
	defer runtime.UnlockOSThread()

	// Create a new network namespace
	ns, err := netns.New()
	if err != nil {
		return fmt.Errorf("netns.New() failed: %w", err)
	}
	defer ns.Close()

	n, err := nftables.New(nftables.WithNetNSFd(int(ns)))
	if err != nil {
		return fmt.Errorf("nftables.New() failed: %w", err)
	}

	if scriptOutput, err := applyNFTRuleset(filename); err != nil {
		return fmt.Errorf("Failed to apply nftables script: %w\noutput:%s", err, scriptOutput)
	}

	code, err := generateProgram(n)
	if err != nil {
		return err
	}

	// Empty the ruleset so that only the generated program populates it
	// before the comparison below.
	if err := flushNFTRuleset(); err != nil {
		return fmt.Errorf("Failed to flush nftables ruleset: %w", err)
	}

	// Create a temporary directory for the generated program.
	tempDir, err := os.MkdirTemp("", "nftables_gen")
	if err != nil {
		return err
	}
	defer os.RemoveAll(tempDir) // Clean up the temporary directory

	tempGoFile := filepath.Join(tempDir, "nftables_recreate.go")
	// Echo the (unformatted) program so a failing run can be diagnosed.
	if _, err := os.Stdout.Write(code); err != nil {
		return err
	}
	if err := os.WriteFile(tempGoFile, code, 0o644); err != nil {
		return err
	}

	// Format the generated code
	log.Printf("formating file: %s", tempGoFile)
	if output, err := exec.Command("gofmt", "-w", "-s", tempGoFile).CombinedOutput(); err != nil {
		return fmt.Errorf("gofmt error: %w\nOutput: %s", err, output)
	}

	// Run the generated code
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()
	log.Printf("executing file: %s", tempGoFile)
	if output, err := exec.CommandContext(ctx, "go", "run", tempGoFile).CombinedOutput(); err != nil {
		return fmt.Errorf("Execution error: %w\nOutput: %s", err, output)
	}

	// Retrieve nftables state using nft
	log.Printf("obtain current ruleset: %s", tempGoFile)
	actualOutput, err := listNFTRuleset()
	if err != nil {
		return fmt.Errorf("Failed to list nftables ruleset: %w\noutput:%s", err, actualOutput)
	}
	expectedOutput, err := os.ReadFile(filename)
	if err != nil {
		return fmt.Errorf("Failed to read expected ruleset %q: %w", filename, err)
	}
	if !compareMultilineStringsIgnoreIndentation(string(expectedOutput), actualOutput) {
		log.Printf("Expected output:\n%s", string(expectedOutput))
		log.Printf("Actual output:\n%s", actualOutput)
		return fmt.Errorf("nftables ruleset mismatch:\n%s", cmp.Diff(string(expectedOutput), actualOutput))
	}
	if err := flushNFTRuleset(); err != nil {
		return fmt.Errorf("Failed to flush nftables ruleset: %w", err)
	}
	return nil
}

// generateProgram lists the tables, sets, chains and rules visible through n
// and renders a Go program that re-creates them with
// github.com/google/nftables. The returned source is not yet gofmt'ed.
func generateProgram(n *nftables.Conn) ([]byte, error) {
	var buf bytes.Buffer
	// pf appends formatted text to buf. Writes to a bytes.Buffer always
	// return a nil error, so there is nothing to check.
	pf := func(format string, a ...interface{}) {
		fmt.Fprintf(&buf, format, a...)
	}

	pf("// Code generated by nft2go. DO NOT EDIT.\n\n")
	pf("package main\n\n")
	pf("import (\n")
	pf("\t\"fmt\"\n")
	pf("\t\"log\"\n")
	pf("\t\"github.com/google/nftables\"\n")
	pf("\t\"github.com/google/nftables/expr\"\n")
	pf(")\n\n")
	pf("func main() {\n")
	pf("\tn, err := nftables.New()\n")
	pf("\tif err != nil {\n")
	pf("\t\tlog.Fatal(err)\n")
	pf("\t}\n\n")
	pf("\tvar expressions []expr.Any\n")
	pf("\tvar chain *nftables.Chain\n")
	pf("\tvar table *nftables.Table\n")
	// Go rejects locals that are declared (or only assigned) but never read.
	// Read them once so the program compiles even for an empty ruleset or a
	// table without chains/rules.
	pf("\t_, _, _ = expressions, chain, table\n\n")

	tables, err := n.ListTables()
	if err != nil {
		return nil, fmt.Errorf("ListTables failed: %w", err)
	}
	chains, err := n.ListChains()
	if err != nil {
		return nil, fmt.Errorf("ListChains failed: %w", err)
	}

	for _, table := range tables {
		log.Printf("processing table: %s", table.Name)
		pf("\ttable = n.AddTable(&nftables.Table{Family: %s, Name: %q})\n",
			TableFamilyString(table.Family), table.Name)

		// Sets belong to the table, not to a chain: emit each one exactly once,
		// before any chain whose rules may reference it.
		sets, err := n.GetSets(table)
		if err != nil {
			return nil, fmt.Errorf("GetSets(%q) failed: %w", table.Name, err)
		}
		for _, set := range sets {
			// TODO datatype and the other options
			pf("\tif err := n.AddSet(&nftables.Set{\n")
			pf("\t\tTable: table,\n")
			pf("\t\tName: %q,\n", set.Name)
			pf("\t}, nil); err != nil {\n")
			pf("\t\tlog.Fatal(err)\n")
			pf("\t}\n")
		}

		for _, chain := range chains {
			// A table is identified by family AND name; e.g. "ip filter" and
			// "ip6 filter" are distinct tables.
			if chain.Table == nil || chain.Table.Name != table.Name || chain.Table.Family != table.Family {
				continue
			}

			pf("\tchain = n.AddChain(&nftables.Chain{Name: %q, Table: table", chain.Name)
			// Only base chains carry a hook; regular chains must be emitted
			// without Type/Hooknum/Priority or they would become base chains.
			if chain.Hooknum != nil {
				pf(", Type: %s, Hooknum: %s, Priority: %s",
					ChainTypeString(chain.Type), ChainHookRef(chain.Hooknum), ChainPrioRef(chain.Priority))
				if chain.Policy != nil {
					pf(", Policy: chainPolicyRef(%s)", ChainPolicyString(*chain.Policy))
				}
			}
			pf("})\n")

			rules, err := n.GetRules(table, chain)
			if err != nil {
				return nil, fmt.Errorf("GetRules(%q, %q) failed: %w", table.Name, chain.Name, err)
			}
			for _, rule := range rules {
				pf("\texpressions = []expr.Any{\n")
				for _, exp := range rule.Exprs {
					pf("\t\t%#v,\n", exp)
				}
				pf("\t}\n")
				pf("\tn.AddRule(&nftables.Rule{\n")
				pf("\t\tTable: table,\n")
				pf("\t\tChain: chain,\n")
				pf("\t\tExprs: expressions,\n")
				pf("\t})\n")
			}
		}
	}

	pf("\n\tif err := n.Flush(); err != nil {\n")
	pf("\t\tlog.Fatal(err)\n")
	pf("\t}\n\n")
	pf("\tfmt.Println(\"nft ruleset applied.\")\n")
	pf("}\n\n")
	// Chain.Policy is a pointer, so the generated code needs an addressable
	// copy of the policy constant.
	pf("// chainPolicyRef returns a pointer to a copy of p.\n")
	pf("func chainPolicyRef(p nftables.ChainPolicy) *nftables.ChainPolicy {\n")
	pf("\treturn &p\n")
	pf("}\n")

	return buf.Bytes(), nil
}

// applyNFTRuleset loads scriptPath with "nft -f" and returns nft's combined
// output (trimmed on success, raw on failure).
func applyNFTRuleset(scriptPath string) (string, error) {
	cmd := exec.Command("nft", "--debug=netlink", "-f", scriptPath)
	out, err := cmd.CombinedOutput()
	if err != nil {
		return string(out), err
	}
	return strings.TrimSpace(string(out)), nil
}

// listNFTRuleset returns the output of "nft list ruleset" (trimmed on
// success, raw on failure).
func listNFTRuleset() (string, error) {
	cmd := exec.Command("nft", "list", "ruleset")
	out, err := cmd.CombinedOutput()
	if err != nil {
		return string(out), err
	}
	return strings.TrimSpace(string(out)), nil
}

// flushNFTRuleset runs "nft flush ruleset". On failure, nft's output is
// attached to the returned error.
func flushNFTRuleset() error {
	out, err := exec.Command("nft", "flush", "ruleset").CombinedOutput()
	if err != nil {
		return fmt.Errorf("%w\noutput:%s", err, strings.TrimSpace(string(out)))
	}
	return nil
}

// ChainHookRef returns a Go expression of type *nftables.ChainHook for
// hookNum, or "nil" when hookNum is nil (a regular, non-base chain).
//
// Hooks 0-4 use the named nftables variables. Any other value is emitted
// numerically via nftables.ChainHookRef so it is reproduced exactly.
// NOTE(review): in github.com/google/nftables, ChainHookIngress/ChainHookEgress
// are defined from the NF_NETDEV_* numbers (0/1), not the inet hook numbers
// 5/6, so mapping 5/6 to those names would change the hook.
func ChainHookRef(hookNum *nftables.ChainHook) string {
	if hookNum == nil {
		return "nil"
	}
	switch h := uint32(*hookNum); h {
	case 0:
		return "nftables.ChainHookPrerouting"
	case 1:
		return "nftables.ChainHookInput"
	case 2:
		return "nftables.ChainHookForward"
	case 3:
		return "nftables.ChainHookOutput"
	case 4:
		return "nftables.ChainHookPostrouting"
	default:
		return fmt.Sprintf("nftables.ChainHookRef(%d)", h)
	}
}

// ChainPrioRef returns a Go expression of type *nftables.ChainPriority for
// priority, or "nil" when priority is nil.
func ChainPrioRef(priority *nftables.ChainPriority) string {
	if priority == nil {
		return "nil"
	}
	return fmt.Sprintf("nftables.ChainPriorityRef(%d)", int32(*priority))
}

// ChainPolicyString returns a Go expression of type nftables.ChainPolicy for
// policy. Unknown values are emitted numerically.
func ChainPolicyString(policy nftables.ChainPolicy) string {
	switch policy {
	case nftables.ChainPolicyDrop:
		return "nftables.ChainPolicyDrop"
	case nftables.ChainPolicyAccept:
		return "nftables.ChainPolicyAccept"
	default:
		return fmt.Sprintf("nftables.ChainPolicy(%d)", uint32(policy))
	}
}

// ChainTypeString returns a Go expression of type nftables.ChainType for
// chaintype. Unknown types are emitted as a quoted conversion instead of
// silently becoming "filter".
func ChainTypeString(chaintype nftables.ChainType) string {
	switch chaintype {
	case nftables.ChainTypeFilter:
		return "nftables.ChainTypeFilter"
	case nftables.ChainTypeRoute:
		return "nftables.ChainTypeRoute"
	case nftables.ChainTypeNAT:
		return "nftables.ChainTypeNAT"
	default:
		return fmt.Sprintf("nftables.ChainType(%q)", string(chaintype))
	}
}

// TableFamilyString returns a Go expression of type nftables.TableFamily for
// family. Unknown families are emitted numerically instead of silently
// becoming IPv4.
func TableFamilyString(family nftables.TableFamily) string {
	switch family {
	case nftables.TableFamilyUnspecified:
		return "nftables.TableFamilyUnspecified"
	case nftables.TableFamilyINet:
		return "nftables.TableFamilyINet"
	case nftables.TableFamilyIPv4:
		return "nftables.TableFamilyIPv4"
	case nftables.TableFamilyIPv6:
		return "nftables.TableFamilyIPv6"
	case nftables.TableFamilyARP:
		return "nftables.TableFamilyARP"
	case nftables.TableFamilyNetdev:
		return "nftables.TableFamilyNetdev"
	case nftables.TableFamilyBridge:
		return "nftables.TableFamilyBridge"
	default:
		return fmt.Sprintf("nftables.TableFamily(%d)", family)
	}
}

// indentRE matches the whitespace at the start of each line. Because \s also
// matches '\n', a match starting on a blank line also consumes that line.
var indentRE = regexp.MustCompile(`(?m)^\s+`)

// compareMultilineStringsIgnoreIndentation reports whether str1 and str2 are
// equal once leading/trailing whitespace and per-line indentation (and,
// through indentRE, blank lines) are removed. Trimming both ends matters
// because a dump file normally ends with a newline while listNFTRuleset's
// output is trimmed.
func compareMultilineStringsIgnoreIndentation(str1, str2 string) bool {
	str1 = indentRE.ReplaceAllString(strings.TrimSpace(str1), "")
	str2 = indentRE.ReplaceAllString(strings.TrimSpace(str2), "")
	return str1 == str2
}