Fix for ListChains policy bug (#144)

Fixes https://github.com/google/nftables/issues/130 | Added a test case for ListChains func
This commit is contained in:
turekt 2022-02-06 17:44:06 +00:00 committed by GitHub
parent a46119e592
commit 91d3b4571d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 92 additions and 1 deletions

View File

@ -229,7 +229,7 @@ func chainFromMsg(msg netlink.Message) (*Chain, error) {
case unix.NFTA_CHAIN_TYPE: case unix.NFTA_CHAIN_TYPE:
c.Type = ChainType(ad.String()) c.Type = ChainType(ad.String())
case unix.NFTA_CHAIN_POLICY: case unix.NFTA_CHAIN_POLICY:
policy := ChainPolicy(ad.Uint32()) policy := ChainPolicy(binaryutil.BigEndian.Uint32(ad.Bytes()))
c.Policy = &policy c.Policy = &policy
case unix.NFTA_CHAIN_HOOK: case unix.NFTA_CHAIN_HOOK:
ad.Do(func(b []byte) error { ad.Do(func(b []byte) error {

View File

@ -1170,6 +1170,97 @@ func TestAddRuleWithPosition(t *testing.T) {
} }
} }
func TestListChains(t *testing.T) {
polDrop := nftables.ChainPolicyDrop
polAcpt := nftables.ChainPolicyAccept
reply := [][]byte{
// chain input { type filter hook input priority filter; policy accept; }
[]byte("\x70\x00\x00\x00\x03\x0a\x02\x00\x00\x00\x00\x00\xb8\x76\x02\x00\x01\x00\x00\xc3\x0b\x00\x01\x00\x66\x69\x6c\x74\x65\x72\x00\x00\x0c\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x01\x0a\x00\x03\x00\x69\x6e\x70\x75\x74\x00\x00\x00\x14\x00\x04\x00\x08\x00\x01\x00\x00\x00\x00\x01\x08\x00\x02\x00\x00\x00\x00\x00\x08\x00\x05\x00\x00\x00\x00\x01\x0b\x00\x07\x00\x66\x69\x6c\x74\x65\x72\x00\x00\x08\x00\x0a\x00\x00\x00\x00\x01\x08\x00\x06\x00\x00\x00\x00\x00"),
// chain forward { type filter hook forward priority filter; policy drop; }
[]byte("\x70\x00\x00\x00\x03\x0a\x02\x00\x00\x00\x00\x01\xb8\x76\x02\x00\x01\x00\x00\xc3\x0b\x00\x01\x00\x66\x69\x6c\x74\x65\x72\x00\x00\x0c\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x02\x0c\x00\x03\x00\x66\x6f\x72\x77\x61\x72\x64\x00\x14\x00\x04\x00\x08\x00\x01\x00\x00\x00\x00\x02\x08\x00\x02\x00\x00\x00\x00\x00\x08\x00\x05\x00\x00\x00\x00\x00\x0b\x00\x07\x00\x66\x69\x6c\x74\x65\x72\x00\x00\x08\x00\x0a\x00\x00\x00\x00\x01\x08\x00\x06\x00\x00\x00\x00\x00"),
// chain output { type filter hook output priority filter; policy accept; }
[]byte("\x70\x00\x00\x00\x03\x0a\x02\x00\x00\x00\x00\x02\xb8\x76\x02\x00\x01\x00\x00\xc3\x0b\x00\x01\x00\x66\x69\x6c\x74\x65\x72\x00\x00\x0c\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x03\x0b\x00\x03\x00\x6f\x75\x74\x70\x75\x74\x00\x00\x14\x00\x04\x00\x08\x00\x01\x00\x00\x00\x00\x03\x08\x00\x02\x00\x00\x00\x00\x00\x08\x00\x05\x00\x00\x00\x00\x01\x0b\x00\x07\x00\x66\x69\x6c\x74\x65\x72\x00\x00\x08\x00\x0a\x00\x00\x00\x00\x01\x08\x00\x06\x00\x00\x00\x00\x00"),
// chain undef { counter packets 56235 bytes 175436495 return }
[]byte("\x40\x00\x00\x00\x03\x0a\x02\x00\x00\x00\x00\x03\xb8\x76\x02\x00\x01\x00\x00\xc3\x0b\x00\x01\x00\x66\x69\x6c\x74\x65\x72\x00\x00\x0c\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x04\x0a\x00\x03\x00\x75\x6e\x64\x65\x66\x00\x00\x00\x08\x00\x06\x00\x00\x00\x00\x01"),
[]byte("\x14\x00\x00\x00\x03\x00\x02\x00\x00\x00\x00\x04\xb8\x76\x02\x00\x00\x00\x00\x00"),
}
want := []*nftables.Chain{
{
Name: "input",
Hooknum: nftables.ChainHookInput,
Priority: nftables.ChainPriorityFilter,
Type: nftables.ChainTypeFilter,
Policy: &polAcpt,
},
{
Name: "forward",
Hooknum: nftables.ChainHookForward,
Priority: nftables.ChainPriorityFilter,
Type: nftables.ChainTypeFilter,
Policy: &polDrop,
},
{
Name: "output",
Hooknum: nftables.ChainHookOutput,
Priority: nftables.ChainPriorityFilter,
Type: nftables.ChainTypeFilter,
Policy: &polAcpt,
},
{
Name: "undef",
Hooknum: 0,
Priority: 0,
Policy: nil,
},
}
c := &nftables.Conn{
TestDial: func(req []netlink.Message) ([]netlink.Message, error) {
msgReply := make([]netlink.Message, len(reply))
for i, r := range reply {
nm := &netlink.Message{}
nm.UnmarshalBinary(r)
nm.Header.Sequence = req[0].Header.Sequence
nm.Header.PID = req[0].Header.PID
msgReply[i] = *nm
}
return msgReply, nil
},
}
chains, err := c.ListChains()
if err != nil {
t.Errorf("error returned from TestDial %v", err)
return
}
if len(chains) != len(want) {
t.Errorf("number of chains %d != number of want %d", len(chains), len(want))
return
}
validate := func(got interface{}, want interface{}, name string, index int) {
if got != want {
t.Errorf("chain %d: chain %s mismatch, got %v want %v", index, name, got, want)
}
}
for i, chain := range chains {
validate(chain.Name, want[i].Name, "name", i)
validate(chain.Hooknum, want[i].Hooknum, "hooknum", i)
validate(chain.Priority, want[i].Priority, "priority", i)
validate(chain.Type, want[i].Type, "type", i)
if want[i].Policy != nil && chain.Policy != nil {
validate(*chain.Policy, *want[i].Policy, "policy value", i)
} else {
validate(chain.Policy, want[i].Policy, "policy pointer", i)
}
}
}
func TestAddChain(t *testing.T) { func TestAddChain(t *testing.T) {
tests := []struct { tests := []struct {
name string name string