diff --git a/chain.go b/chain.go index 74caca5..bcc35de 100644 --- a/chain.go +++ b/chain.go @@ -229,7 +229,7 @@ func chainFromMsg(msg netlink.Message) (*Chain, error) { case unix.NFTA_CHAIN_TYPE: c.Type = ChainType(ad.String()) case unix.NFTA_CHAIN_POLICY: - policy := ChainPolicy(ad.Uint32()) + policy := ChainPolicy(binaryutil.BigEndian.Uint32(ad.Bytes())) c.Policy = &policy case unix.NFTA_CHAIN_HOOK: ad.Do(func(b []byte) error { diff --git a/nftables_test.go b/nftables_test.go index 5986022..51450f5 100644 --- a/nftables_test.go +++ b/nftables_test.go @@ -1170,6 +1170,97 @@ func TestAddRuleWithPosition(t *testing.T) { } } +func TestListChains(t *testing.T) { + polDrop := nftables.ChainPolicyDrop + polAcpt := nftables.ChainPolicyAccept + reply := [][]byte{ + // chain input { type filter hook input priority filter; policy accept; } + []byte("\x70\x00\x00\x00\x03\x0a\x02\x00\x00\x00\x00\x00\xb8\x76\x02\x00\x01\x00\x00\xc3\x0b\x00\x01\x00\x66\x69\x6c\x74\x65\x72\x00\x00\x0c\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x01\x0a\x00\x03\x00\x69\x6e\x70\x75\x74\x00\x00\x00\x14\x00\x04\x00\x08\x00\x01\x00\x00\x00\x00\x01\x08\x00\x02\x00\x00\x00\x00\x00\x08\x00\x05\x00\x00\x00\x00\x01\x0b\x00\x07\x00\x66\x69\x6c\x74\x65\x72\x00\x00\x08\x00\x0a\x00\x00\x00\x00\x01\x08\x00\x06\x00\x00\x00\x00\x00"), + // chain forward { type filter hook forward priority filter; policy drop; } + []byte("\x70\x00\x00\x00\x03\x0a\x02\x00\x00\x00\x00\x01\xb8\x76\x02\x00\x01\x00\x00\xc3\x0b\x00\x01\x00\x66\x69\x6c\x74\x65\x72\x00\x00\x0c\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x02\x0c\x00\x03\x00\x66\x6f\x72\x77\x61\x72\x64\x00\x14\x00\x04\x00\x08\x00\x01\x00\x00\x00\x00\x02\x08\x00\x02\x00\x00\x00\x00\x00\x08\x00\x05\x00\x00\x00\x00\x00\x0b\x00\x07\x00\x66\x69\x6c\x74\x65\x72\x00\x00\x08\x00\x0a\x00\x00\x00\x00\x01\x08\x00\x06\x00\x00\x00\x00\x00"), + // chain output { type filter hook output priority filter; policy accept; } + []byte("\x70\x00\x00\x00\x03\x0a\x02\x00\x00\x00\x00\x02\xb8\x76\x02\x00\x01\x00\x00\xc3\x0b\x00\x01\x00\x66\x69\x6c\x74\x65\x72\x00\x00\x0c\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x03\x0b\x00\x03\x00\x6f\x75\x74\x70\x75\x74\x00\x00\x14\x00\x04\x00\x08\x00\x01\x00\x00\x00\x00\x03\x08\x00\x02\x00\x00\x00\x00\x00\x08\x00\x05\x00\x00\x00\x00\x01\x0b\x00\x07\x00\x66\x69\x6c\x74\x65\x72\x00\x00\x08\x00\x0a\x00\x00\x00\x00\x01\x08\x00\x06\x00\x00\x00\x00\x00"), + // chain undef { counter packets 56235 bytes 175436495 return } + []byte("\x40\x00\x00\x00\x03\x0a\x02\x00\x00\x00\x00\x03\xb8\x76\x02\x00\x01\x00\x00\xc3\x0b\x00\x01\x00\x66\x69\x6c\x74\x65\x72\x00\x00\x0c\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x04\x0a\x00\x03\x00\x75\x6e\x64\x65\x66\x00\x00\x00\x08\x00\x06\x00\x00\x00\x00\x01"), + []byte("\x14\x00\x00\x00\x03\x00\x02\x00\x00\x00\x00\x04\xb8\x76\x02\x00\x00\x00\x00\x00"), + } + + want := []*nftables.Chain{ + { + Name: "input", + Hooknum: nftables.ChainHookInput, + Priority: nftables.ChainPriorityFilter, + Type: nftables.ChainTypeFilter, + Policy: &polAcpt, + }, + { + Name: "forward", + Hooknum: nftables.ChainHookForward, + Priority: nftables.ChainPriorityFilter, + Type: nftables.ChainTypeFilter, + Policy: &polDrop, + }, + { + Name: "output", + Hooknum: nftables.ChainHookOutput, + Priority: nftables.ChainPriorityFilter, + Type: nftables.ChainTypeFilter, + Policy: &polAcpt, + }, + { + Name: "undef", + Hooknum: 0, + Priority: 0, + Policy: nil, + }, + } + + c := &nftables.Conn{ + TestDial: func(req []netlink.Message) ([]netlink.Message, error) { + msgReply := make([]netlink.Message, len(reply)) + for i, r := range reply { + nm := &netlink.Message{} + nm.UnmarshalBinary(r) + nm.Header.Sequence = req[0].Header.Sequence + nm.Header.PID = req[0].Header.PID + msgReply[i] = *nm + } + return msgReply, nil + }, + } + + chains, err := c.ListChains() + if err != nil { + t.Errorf("error returned from TestDial %v", err) + return + } + + if len(chains) != len(want) { + t.Errorf("number of chains %d != number of want %d", len(chains), len(want)) + return + } + + validate := func(got interface{}, want interface{}, name string, index int) { + if got != want { + t.Errorf("chain %d: chain %s mismatch, got %v want %v", index, name, got, want) + } + } + + for i, chain := range chains { + validate(chain.Name, want[i].Name, "name", i) + validate(chain.Hooknum, want[i].Hooknum, "hooknum", i) + validate(chain.Priority, want[i].Priority, "priority", i) + validate(chain.Type, want[i].Type, "type", i) + + if want[i].Policy != nil && chain.Policy != nil { + validate(*chain.Policy, *want[i].Policy, "policy value", i) + } else { + validate(chain.Policy, want[i].Policy, "policy pointer", i) + } + } + +} + func TestAddChain(t *testing.T) { tests := []struct { name string