IsDestRegSet unmarshaling fix (#178)

Fixes https://github.com/google/nftables/issues/176 | Added test case
This commit is contained in:
turekt 2022-08-30 17:03:33 +00:00 committed by GitHub
parent 2eca001357
commit e4bff45b7f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 97 additions and 0 deletions

View File

@ -76,6 +76,7 @@ func (e *Lookup) unmarshal(fam byte, data []byte) error {
e.SourceRegister = ad.Uint32()
case unix.NFTA_LOOKUP_DREG:
e.DestRegister = ad.Uint32()
e.IsDestRegSet = true
case unix.NFTA_LOOKUP_FLAGS:
e.Invert = (ad.Uint32() & unix.NFT_LOOKUP_F_INV) != 0
}

View File

@ -3095,6 +3095,102 @@ func TestFlushTable(t *testing.T) {
}
}
func TestGetLookupExprDestSet(t *testing.T) {
c, newNS := openSystemNFTConn(t)
defer cleanupSystemNFTConn(t, newNS)
c.FlushRuleset()
defer c.FlushRuleset()
filter := c.AddTable(&nftables.Table{
Family: nftables.TableFamilyIPv4,
Name: "filter",
})
forward := c.AddChain(&nftables.Chain{
Name: "forward",
Table: filter,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookForward,
Priority: nftables.ChainPriorityFilter,
})
set := &nftables.Set{
Table: filter,
Name: "kek",
IsMap: true,
KeyType: nftables.TypeInetService,
DataType: nftables.TypeVerdict,
}
if err := c.AddSet(set, nil); err != nil {
t.Errorf("c.AddSet(set) failed: %v", err)
}
if err := c.Flush(); err != nil {
t.Errorf("c.Flush() failed: %v", err)
}
c.AddRule(&nftables.Rule{
Table: filter,
Chain: forward,
Exprs: []expr.Any{
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: []byte{unix.IPPROTO_TCP},
},
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseTransportHeader,
Offset: 2,
Len: 2,
},
&expr.Lookup{
SourceRegister: 1,
SetName: set.Name,
SetID: set.ID,
DestRegister: 0,
IsDestRegSet: true,
},
},
})
if err := c.Flush(); err != nil {
t.Errorf("c.Flush() failed: %v", err)
}
rules, err := c.GetRules(
&nftables.Table{
Family: nftables.TableFamilyIPv4,
Name: "filter",
},
&nftables.Chain{
Name: "forward",
},
)
if err != nil {
t.Fatal(err)
}
if got, want := len(rules), 1; got != want {
t.Fatalf("unexpected number of rules: got %d, want %d", got, want)
}
if got, want := len(rules[0].Exprs), 4; got != want {
t.Fatalf("unexpected number of exprs: got %d, want %d", got, want)
}
lookup, lookupOk := rules[0].Exprs[3].(*expr.Lookup)
if !lookupOk {
t.Fatalf("Exprs[3] is type %T, want *expr.Lookup", rules[0].Exprs[3])
}
if want := (&expr.Lookup{
SourceRegister: 1,
SetName: set.Name,
DestRegister: 0,
IsDestRegSet: true,
}); !reflect.DeepEqual(lookup, want) {
t.Errorf("lookup expr = %+v, wanted %+v", lookup, want)
}
}
func TestGetRuleLookupVerdictImmediate(t *testing.T) {
// Create a new network namespace to test these operations,
// and tear down the namespace at test completion.