300 lines
7.3 KiB
Go
300 lines
7.3 KiB
Go
package main
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"fmt"
|
|
"io"
|
|
"io/ioutil"
|
|
"log"
|
|
"os"
|
|
"os/exec"
|
|
"path/filepath"
|
|
"regexp"
|
|
"runtime"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/google/go-cmp/cmp"
|
|
"github.com/google/nftables"
|
|
"github.com/vishvananda/netns"
|
|
)
|
|
|
|
func main() {
|
|
args := os.Args[1:]
|
|
if len(args) != 1 {
|
|
log.Fatalf("need to specify the file to read the \"nft list ruleset\" dump")
|
|
}
|
|
|
|
filename := args[0]
|
|
|
|
runtime.LockOSThread()
|
|
defer runtime.UnlockOSThread()
|
|
|
|
// Create a new network namespace
|
|
ns, err := netns.New()
|
|
if err != nil {
|
|
log.Fatalf("netns.New() failed: %v", err)
|
|
}
|
|
n, err := nftables.New(nftables.WithNetNSFd(int(ns)))
|
|
if err != nil {
|
|
log.Fatalf("nftables.New() failed: %v", err)
|
|
}
|
|
|
|
scriptOutput, err := applyNFTRuleset(filename)
|
|
if err != nil {
|
|
log.Fatalf("Failed to apply nftables script: %v\noutput:%s", err, scriptOutput)
|
|
}
|
|
|
|
var buf bytes.Buffer
|
|
// Helper function to print to the file
|
|
pf := func(format string, a ...interface{}) {
|
|
_, err := fmt.Fprintf(&buf, format, a...)
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
}
|
|
|
|
pf("// Code generated by nft2go. DO NOT EDIT.\n")
|
|
pf("package main\n\n")
|
|
pf("import (\n")
|
|
pf("\t\"fmt\"\n")
|
|
pf("\t\"log\"\n")
|
|
pf("\t\"github.com/google/nftables\"\n")
|
|
pf("\t\"github.com/google/nftables/expr\"\n")
|
|
pf(")\n\n")
|
|
pf("func main() {\n")
|
|
pf("\tn, err:= nftables.New()\n")
|
|
pf("\tif err!= nil {\n")
|
|
pf("\t\tlog.Fatal(err)\n")
|
|
pf("\t}\n\n")
|
|
pf("\n")
|
|
pf("\tvar expressions []expr.Any\n")
|
|
pf("\tvar chain *nftables.Chain\n")
|
|
pf("\tvar table *nftables.Table\n")
|
|
|
|
tables, err := n.ListTables()
|
|
if err != nil {
|
|
log.Fatalf("ListTables failed: %v", err)
|
|
}
|
|
|
|
chains, err := n.ListChains()
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
|
|
for _, table := range tables {
|
|
log.Printf("processing table: %s", table.Name)
|
|
|
|
pf("\ttable = n.AddTable(&nftables.Table{Family: %s,Name: \"%s\"})\n", TableFamilyString(table.Family), table.Name)
|
|
for _, chain := range chains {
|
|
if chain.Table.Name != table.Name {
|
|
continue
|
|
}
|
|
|
|
sets, err := n.GetSets(table)
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
for _, set := range sets {
|
|
// TODO datatype and the other options
|
|
pf("\tn.AddSet(&nftables.Set{\n")
|
|
pf("\t\tTable: table,\n")
|
|
pf("\t\tName: \"%s\",\n", set.Name)
|
|
pf("\t}, nil)\n")
|
|
}
|
|
|
|
pf("\tchain = n.AddChain(&nftables.Chain{Name: \"%s\", Table: table, Type: %s, Hooknum: %s, Priority: %s})\n",
|
|
chain.Name, ChainTypeString(chain.Type), ChainHookRef(chain.Hooknum), ChainPrioRef(chain.Priority))
|
|
|
|
rules, err := n.GetRules(table, chain)
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
|
|
for _, rule := range rules {
|
|
pf("\texpressions = []expr.Any{\n")
|
|
for _, exp := range rule.Exprs {
|
|
pf("\t\t%#v,\n", exp)
|
|
}
|
|
pf("\t\t}\n")
|
|
pf("\tn.AddRule(&nftables.Rule{\n")
|
|
pf("\t\tTable: table,\n")
|
|
pf("\t\tChain: chain,\n")
|
|
pf("\t\tExprs: expressions,\n")
|
|
pf("\t})\n")
|
|
}
|
|
}
|
|
}
|
|
|
|
pf("\n\tif err:= n.Flush(); err!= nil {\n")
|
|
pf("\t\tlog.Fatal(err)\n")
|
|
pf("\t}\n\n")
|
|
pf("\tfmt.Println(\"nft ruleset applied.\")\n")
|
|
pf("}\n")
|
|
|
|
// Program nftables using your Go code
|
|
if err := flushNFTRuleset(); err != nil {
|
|
log.Fatalf("Failed to flush nftables ruleset: %v", err)
|
|
}
|
|
|
|
// Create the output file
|
|
// Create a temporary directory
|
|
tempDir, err := ioutil.TempDir("", "nftables_gen")
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
defer os.RemoveAll(tempDir) // Clean up the temporary directory
|
|
|
|
// Create the temporary Go file
|
|
tempGoFile := filepath.Join(tempDir, "nftables_recreate.go")
|
|
f, err := os.Create(tempGoFile)
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
defer f.Close()
|
|
|
|
mw := io.MultiWriter(f, os.Stdout)
|
|
buf.WriteTo(mw)
|
|
|
|
// Format the generated code
|
|
log.Printf("formating file: %s", tempGoFile)
|
|
cmd := exec.Command("gofmt", "-w", "-s", tempGoFile)
|
|
output, err := cmd.CombinedOutput()
|
|
if err != nil {
|
|
log.Fatalf("gofmt error: %v\nOutput: %s", err, output)
|
|
}
|
|
|
|
// Run the generated code
|
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
|
defer cancel()
|
|
|
|
log.Printf("executing file: %s", tempGoFile)
|
|
cmd = exec.CommandContext(ctx, "go", "run", tempGoFile)
|
|
output, err = cmd.CombinedOutput()
|
|
if err != nil {
|
|
log.Fatalf("Execution error: %v\nOutput: %s", err, output)
|
|
}
|
|
|
|
// Retrieve nftables state using nft
|
|
log.Printf("obtain current ruleset: %s", tempGoFile)
|
|
actualOutput, err := listNFTRuleset()
|
|
if err != nil {
|
|
log.Fatalf("Failed to list nftables ruleset: %v\noutput:%s", err, actualOutput)
|
|
}
|
|
|
|
expectedOutput, err := os.ReadFile(filename)
|
|
if err != nil {
|
|
log.Fatalf("Failed to list nftables ruleset: %v\noutput:%s", err, actualOutput)
|
|
}
|
|
|
|
if !compareMultilineStringsIgnoreIndentation(string(expectedOutput), actualOutput) {
|
|
log.Printf("Expected output:\n%s", string(expectedOutput))
|
|
log.Printf("Actual output:\n%s", actualOutput)
|
|
|
|
log.Fatalf("nftables ruleset mismatch:\n%s", cmp.Diff(string(expectedOutput), actualOutput))
|
|
}
|
|
|
|
if err := flushNFTRuleset(); err != nil {
|
|
log.Fatalf("Failed to flush nftables ruleset: %v", err)
|
|
}
|
|
}
|
|
|
|
// applyNFTRuleset loads the nft script at scriptPath by running
// "nft --debug=netlink -f scriptPath". It returns nft's combined
// stdout/stderr: untrimmed together with the error on failure, or trimmed of
// surrounding whitespace on success.
func applyNFTRuleset(scriptPath string) (string, error) {
	nftArgs := []string{"--debug=netlink", "-f", scriptPath}
	raw, runErr := exec.Command("nft", nftArgs...).CombinedOutput()
	text := string(raw)
	if runErr != nil {
		return text, runErr
	}
	return strings.TrimSpace(text), nil
}
|
|
|
|
// listNFTRuleset runs "nft list ruleset" and returns its combined
// stdout/stderr. On success the text is trimmed of surrounding whitespace; on
// failure it is returned untrimmed together with the command's error.
func listNFTRuleset() (string, error) {
	raw, runErr := exec.Command("nft", "list", "ruleset").CombinedOutput()
	if runErr == nil {
		return strings.TrimSpace(string(raw)), nil
	}
	return string(raw), runErr
}
|
|
|
|
// flushNFTRuleset runs "nft flush ruleset". When the command fails, the
// returned error wraps the exec error and includes nft's combined output.
// Plain cmd.Run() would discard that output and report only "exit status N".
func flushNFTRuleset() error {
	out, err := exec.Command("nft", "flush", "ruleset").CombinedOutput()
	if err != nil {
		return fmt.Errorf("nft flush ruleset: %w: %s", err, strings.TrimSpace(string(out)))
	}
	return nil
}
|
|
|
|
func ChainHookRef(hookNum *nftables.ChainHook) string {
|
|
i := uint32(0)
|
|
if hookNum != nil {
|
|
i = uint32(*hookNum)
|
|
}
|
|
switch i {
|
|
case 0:
|
|
return "nftables.ChainHookPrerouting"
|
|
case 1:
|
|
return "nftables.ChainHookInput"
|
|
case 2:
|
|
return "nftables.ChainHookForward"
|
|
case 3:
|
|
return "nftables.ChainHookOutput"
|
|
case 4:
|
|
return "nftables.ChainHookPostrouting"
|
|
case 5:
|
|
return "nftables.ChainHookIngress"
|
|
case 6:
|
|
return "nftables.ChainHookEgress"
|
|
}
|
|
return ""
|
|
}
|
|
|
|
func ChainPrioRef(priority *nftables.ChainPriority) string {
|
|
i := int32(0)
|
|
if priority != nil {
|
|
i = int32(*priority)
|
|
}
|
|
return fmt.Sprintf("nftables.ChainPriorityRef(%d)", i)
|
|
}
|
|
|
|
func ChainTypeString(chaintype nftables.ChainType) string {
|
|
switch chaintype {
|
|
case nftables.ChainTypeFilter:
|
|
return "nftables.ChainTypeFilter"
|
|
case nftables.ChainTypeRoute:
|
|
return "nftables.ChainTypeRoute"
|
|
case nftables.ChainTypeNAT:
|
|
return "nftables.ChainTypeNAT"
|
|
default:
|
|
return "nftables.ChainTypeFilter"
|
|
}
|
|
}
|
|
|
|
func TableFamilyString(family nftables.TableFamily) string {
|
|
switch family {
|
|
case nftables.TableFamilyUnspecified:
|
|
return "nftables.TableFamilyUnspecified"
|
|
case nftables.TableFamilyINet:
|
|
return "nftables.TableFamilyINet"
|
|
case nftables.TableFamilyIPv4:
|
|
return "nftables.TableFamilyIPv4"
|
|
case nftables.TableFamilyIPv6:
|
|
return "nftables.TableFamilyIPv6"
|
|
case nftables.TableFamilyARP:
|
|
return "nftables.TableFamilyARP"
|
|
case nftables.TableFamilyNetdev:
|
|
return "nftables.TableFamilyNetdev"
|
|
case nftables.TableFamilyBridge:
|
|
return "nftables.TableFamilyBridge"
|
|
default:
|
|
return "nftables.TableFamilyIPv4"
|
|
}
|
|
}
|
|
|
|
// layoutWhitespaceRE matches the whitespace that
// compareMultilineStringsIgnoreIndentation treats as insignificant:
//   - whitespace at the start of a line. Because \s also matches '\n', this
//     swallows blank lines as well.
//   - spaces, tabs or carriage returns at the end of a line.
//
// It is compiled once at package level instead of on every call.
var layoutWhitespaceRE = regexp.MustCompile(`(?m)^\s+|[ \t\r]+$`)

// compareMultilineStringsIgnoreIndentation reports whether str1 and str2 are
// equal once layout differences are removed. Ignored differences are:
//   - per-line indentation
//   - trailing whitespace (including CR of CRLF line endings)
//   - blank lines
//   - leading and trailing whitespace of the whole text, e.g. the final
//     newline a dump file normally ends with. Without this, a file ending in
//     '\n' never matched text that had been trimmed.
func compareMultilineStringsIgnoreIndentation(str1, str2 string) bool {
	normalize := func(s string) string {
		return strings.TrimSpace(layoutWhitespaceRE.ReplaceAllString(s, ""))
	}
	return normalize(str1) == normalize(str2)
}
|