diff --git a/chain.go b/chain.go index 39b192c..9928d63 100644 --- a/chain.go +++ b/chain.go @@ -30,38 +30,48 @@ import ( type ChainHook uint32 // Possible ChainHook values. -const ( - ChainHookPrerouting ChainHook = unix.NF_INET_PRE_ROUTING - ChainHookInput ChainHook = unix.NF_INET_LOCAL_IN - ChainHookForward ChainHook = unix.NF_INET_FORWARD - ChainHookOutput ChainHook = unix.NF_INET_LOCAL_OUT - ChainHookPostrouting ChainHook = unix.NF_INET_POST_ROUTING - ChainHookIngress ChainHook = unix.NF_NETDEV_INGRESS +var ( + ChainHookPrerouting *ChainHook = ChainHookRef(unix.NF_INET_PRE_ROUTING) + ChainHookInput *ChainHook = ChainHookRef(unix.NF_INET_LOCAL_IN) + ChainHookForward *ChainHook = ChainHookRef(unix.NF_INET_FORWARD) + ChainHookOutput *ChainHook = ChainHookRef(unix.NF_INET_LOCAL_OUT) + ChainHookPostrouting *ChainHook = ChainHookRef(unix.NF_INET_POST_ROUTING) + ChainHookIngress *ChainHook = ChainHookRef(unix.NF_NETDEV_INGRESS) ) +// ChainHookRef returns a pointer to a ChainHookRef value. +func ChainHookRef(h ChainHook) *ChainHook { + return &h +} + // ChainPriority orders the chain relative to Netfilter internal operations. See // also // https://wiki.nftables.org/wiki-nftables/index.php/Configuring_chains#Base_chain_priority type ChainPriority int32 // Possible ChainPriority values. -const ( // from /usr/include/linux/netfilter_ipv4.h - ChainPriorityFirst ChainPriority = math.MinInt32 - ChainPriorityConntrackDefrag ChainPriority = -400 - ChainPriorityRaw ChainPriority = -300 - ChainPrioritySELinuxFirst ChainPriority = -225 - ChainPriorityConntrack ChainPriority = -200 - ChainPriorityMangle ChainPriority = -150 - ChainPriorityNATDest ChainPriority = -100 - ChainPriorityFilter ChainPriority = 0 - ChainPrioritySecurity ChainPriority = 50 - ChainPriorityNATSource ChainPriority = 100 - ChainPrioritySELinuxLast ChainPriority = 225 - ChainPriorityConntrackHelper ChainPriority = 300 - ChainPriorityConntrackConfirm ChainPriority = math.MaxInt32 - ChainPriorityLast ChainPriority = math.MaxInt32 +var ( // from /usr/include/linux/netfilter_ipv4.h + ChainPriorityFirst *ChainPriority = ChainPriorityRef(math.MinInt32) + ChainPriorityConntrackDefrag *ChainPriority = ChainPriorityRef(-400) + ChainPriorityRaw *ChainPriority = ChainPriorityRef(-300) + ChainPrioritySELinuxFirst *ChainPriority = ChainPriorityRef(-225) + ChainPriorityConntrack *ChainPriority = ChainPriorityRef(-200) + ChainPriorityMangle *ChainPriority = ChainPriorityRef(-150) + ChainPriorityNATDest *ChainPriority = ChainPriorityRef(-100) + ChainPriorityFilter *ChainPriority = ChainPriorityRef(0) + ChainPrioritySecurity *ChainPriority = ChainPriorityRef(50) + ChainPriorityNATSource *ChainPriority = ChainPriorityRef(100) + ChainPrioritySELinuxLast *ChainPriority = ChainPriorityRef(225) + ChainPriorityConntrackHelper *ChainPriority = ChainPriorityRef(300) + ChainPriorityConntrackConfirm *ChainPriority = ChainPriorityRef(math.MaxInt32) + ChainPriorityLast *ChainPriority = ChainPriorityRef(math.MaxInt32) ) +// ChainPriorityRef returns a pointer to a ChainPriority value. +func ChainPriorityRef(p ChainPriority) *ChainPriority { + return &p +} + // ChainType defines what this chain will be used for. See also // https://wiki.nftables.org/wiki-nftables/index.php/Configuring_chains#Base_chain_types type ChainType string @@ -87,8 +97,8 @@ const ( type Chain struct { Name string Table *Table - Hooknum ChainHook - Priority ChainPriority + Hooknum *ChainHook + Priority *ChainPriority Type ChainType Policy *ChainPolicy } @@ -103,10 +113,10 @@ func (cc *Conn) AddChain(c *Chain) *Chain { {Type: unix.NFTA_CHAIN_NAME, Data: []byte(c.Name + "\x00")}, }) - if c.Type != "" { + if c.Hooknum != nil && c.Priority != nil { hookAttr := []netlink.Attribute{ - {Type: unix.NFTA_HOOK_HOOKNUM, Data: binaryutil.BigEndian.PutUint32(uint32(c.Hooknum))}, - {Type: unix.NFTA_HOOK_PRIORITY, Data: binaryutil.BigEndian.PutUint32(uint32(c.Priority))}, + {Type: unix.NFTA_HOOK_HOOKNUM, Data: binaryutil.BigEndian.PutUint32(uint32(*c.Hooknum))}, + {Type: unix.NFTA_HOOK_PRIORITY, Data: binaryutil.BigEndian.PutUint32(uint32(*c.Priority))}, } data = append(data, cc.marshalAttr([]netlink.Attribute{ {Type: unix.NLA_F_NESTED | unix.NFTA_CHAIN_HOOK, Data: cc.marshalAttr(hookAttr)}, @@ -249,10 +259,10 @@ func chainFromMsg(msg netlink.Message) (*Chain, error) { return &c, nil } -func hookFromMsg(b []byte) (ChainHook, ChainPriority, error) { +func hookFromMsg(b []byte) (*ChainHook, *ChainPriority, error) { ad, err := netlink.NewAttributeDecoder(b) if err != nil { - return 0, 0, err + return nil, nil, err } ad.ByteOrder = binary.BigEndian @@ -269,5 +279,5 @@ func hookFromMsg(b []byte) (ChainHook, ChainPriority, error) { } } - return hooknum, prio, nil + return &hooknum, &prio, nil } diff --git a/nftables_test.go b/nftables_test.go index 9bf60b9..8324763 100644 --- a/nftables_test.go +++ b/nftables_test.go @@ -1143,7 +1143,7 @@ func TestTProxy(t *testing.T) { Name: "divert", Type: nftables.ChainTypeFilter, Hooknum: nftables.ChainHookPrerouting, - Priority: -150, + Priority: nftables.ChainPriorityRef(-150), }, Exprs: []expr.Any{ // [ payload load 1b @ network header + 9 => reg 1 ] @@ -1384,7 +1384,7 @@ func TestAddRuleWithPosition(t *testing.T) { Name: "ipv4chain-1", Type: nftables.ChainTypeFilter, Hooknum: nftables.ChainHookPrerouting, - Priority: 0, + Priority: nftables.ChainPriorityRef(0), }, Exprs: []expr.Any{ @@ -1523,8 +1523,8 @@ func TestListChains(t *testing.T) { }, { Name: "undef", - Hooknum: 0, - Priority: 0, + Hooknum: nil, + Priority: nil, Policy: nil, }, } @@ -1564,8 +1564,16 @@ func TestListChains(t *testing.T) { for i, chain := range chains { validate(chain.Name, want[i].Name, "name", i) - validate(chain.Hooknum, want[i].Hooknum, "hooknum", i) - validate(chain.Priority, want[i].Priority, "priority", i) + if want[i].Hooknum != nil && chain.Hooknum != nil { + validate(*chain.Hooknum, *want[i].Hooknum, "hooknum value", i) + } else { + validate(chain.Hooknum, want[i].Hooknum, "hooknum pointer", i) + } + if want[i].Priority != nil && chain.Priority != nil { + validate(*chain.Priority, *want[i].Priority, "priority value", i) + } else { + validate(chain.Priority, want[i].Priority, "priority pointer", i) + } validate(chain.Type, want[i].Type, "type", i) if want[i].Policy != nil && chain.Policy != nil { @@ -1588,7 +1596,7 @@ func TestAddChain(t *testing.T) { chain: &nftables.Chain{ Name: "base-chain", Hooknum: nftables.ChainHookPrerouting, - Priority: 0, + Priority: nftables.ChainPriorityRef(0), Type: nftables.ChainTypeFilter, }, want: [][]byte{ @@ -1671,7 +1679,7 @@ func TestDelChain(t *testing.T) { chain: &nftables.Chain{ Name: "base-chain", Hooknum: nftables.ChainHookPrerouting, - Priority: 0, + Priority: nftables.ChainPriorityRef(0), Type: nftables.ChainTypeFilter, }, want: [][]byte{ @@ -3425,7 +3433,7 @@ func TestDynsetWithOneExpression(t *testing.T) { Name: "forward", Hooknum: nftables.ChainHookForward, Table: table, - Priority: 0, + Priority: nftables.ChainPriorityRef(0), Type: nftables.ChainTypeFilter, } set := &nftables.Set{ @@ -3527,7 +3535,7 @@ func TestDynsetWithMultipleExpressions(t *testing.T) { Name: "forward", Hooknum: nftables.ChainHookForward, Table: table, - Priority: 0, + Priority: nftables.ChainPriorityRef(0), Type: nftables.ChainTypeFilter, } set := &nftables.Set{