diff --git a/expr/lookup.go b/expr/lookup.go index 6b7440c..e6593ac 100644 --- a/expr/lookup.go +++ b/expr/lookup.go @@ -76,6 +76,7 @@ func (e *Lookup) unmarshal(fam byte, data []byte) error { e.SourceRegister = ad.Uint32() case unix.NFTA_LOOKUP_DREG: e.DestRegister = ad.Uint32() + e.IsDestRegSet = true case unix.NFTA_LOOKUP_FLAGS: e.Invert = (ad.Uint32() & unix.NFT_LOOKUP_F_INV) != 0 } diff --git a/nftables_test.go b/nftables_test.go index 2dc8a2a..9bf60b9 100644 --- a/nftables_test.go +++ b/nftables_test.go @@ -3095,6 +3095,102 @@ func TestFlushTable(t *testing.T) { } } +func TestGetLookupExprDestSet(t *testing.T) { + c, newNS := openSystemNFTConn(t) + defer cleanupSystemNFTConn(t, newNS) + c.FlushRuleset() + defer c.FlushRuleset() + + filter := c.AddTable(&nftables.Table{ + Family: nftables.TableFamilyIPv4, + Name: "filter", + }) + forward := c.AddChain(&nftables.Chain{ + Name: "forward", + Table: filter, + Type: nftables.ChainTypeFilter, + Hooknum: nftables.ChainHookForward, + Priority: nftables.ChainPriorityFilter, + }) + + set := &nftables.Set{ + Table: filter, + Name: "kek", + IsMap: true, + KeyType: nftables.TypeInetService, + DataType: nftables.TypeVerdict, + } + if err := c.AddSet(set, nil); err != nil { + t.Errorf("c.AddSet(set) failed: %v", err) + } + if err := c.Flush(); err != nil { + t.Errorf("c.Flush() failed: %v", err) + } + + c.AddRule(&nftables.Rule{ + Table: filter, + Chain: forward, + Exprs: []expr.Any{ + &expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1}, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: []byte{unix.IPPROTO_TCP}, + }, + &expr.Payload{ + DestRegister: 1, + Base: expr.PayloadBaseTransportHeader, + Offset: 2, + Len: 2, + }, + &expr.Lookup{ + SourceRegister: 1, + SetName: set.Name, + SetID: set.ID, + DestRegister: 0, + IsDestRegSet: true, + }, + }, + }) + + if err := c.Flush(); err != nil { + t.Errorf("c.Flush() failed: %v", err) + } + + rules, err := c.GetRules( + &nftables.Table{ + Family: nftables.TableFamilyIPv4, + Name: "filter", + }, + &nftables.Chain{ + Name: "forward", + }, + ) + if err != nil { + t.Fatal(err) + } + + if got, want := len(rules), 1; got != want { + t.Fatalf("unexpected number of rules: got %d, want %d", got, want) + } + if got, want := len(rules[0].Exprs), 4; got != want { + t.Fatalf("unexpected number of exprs: got %d, want %d", got, want) + } + + lookup, lookupOk := rules[0].Exprs[3].(*expr.Lookup) + if !lookupOk { + t.Fatalf("Exprs[3] is type %T, want *expr.Lookup", rules[0].Exprs[3]) + } + if want := (&expr.Lookup{ + SourceRegister: 1, + SetName: set.Name, + DestRegister: 0, + IsDestRegSet: true, + }); !reflect.DeepEqual(lookup, want) { + t.Errorf("lookup expr = %+v, wanted %+v", lookup, want) + } +} + func TestGetRuleLookupVerdictImmediate(t *testing.T) { // Create a new network namespace to test these operations, // and tear down the namespace at test completion.