diff --git a/expr/dynset.go b/expr/dynset.go index b6087ec..cfaaefd 100644 --- a/expr/dynset.go +++ b/expr/dynset.go @@ -38,7 +38,9 @@ func (e *Dynset) marshal() ([]byte, error) { // See: https://git.netfilter.org/libnftnl/tree/src/expr/dynset.c var opAttrs []netlink.Attribute opAttrs = append(opAttrs, netlink.Attribute{Type: unix.NFTA_DYNSET_SREG_KEY, Data: binaryutil.BigEndian.PutUint32(e.SrcRegKey)}) - opAttrs = append(opAttrs, netlink.Attribute{Type: unix.NFTA_DYNSET_SREG_DATA, Data: binaryutil.BigEndian.PutUint32(e.SrcRegData)}) + if e.SrcRegData != 0 { + opAttrs = append(opAttrs, netlink.Attribute{Type: unix.NFTA_DYNSET_SREG_DATA, Data: binaryutil.BigEndian.PutUint32(e.SrcRegData)}) + } opAttrs = append(opAttrs, netlink.Attribute{Type: unix.NFTA_DYNSET_OP, Data: binaryutil.BigEndian.PutUint32(e.Operation)}) if e.Timeout != 0 { opAttrs = append(opAttrs, netlink.Attribute{Type: unix.NFTA_DYNSET_TIMEOUT, Data: binaryutil.BigEndian.PutUint64(uint64(e.Timeout.Milliseconds()))}) @@ -76,6 +78,8 @@ func (e *Dynset) unmarshal(data []byte) error { e.SrcRegKey = ad.Uint32() case unix.NFTA_DYNSET_SREG_DATA: e.SrcRegData = ad.Uint32() + case unix.NFTA_DYNSET_OP: + e.Operation = ad.Uint32() case unix.NFTA_DYNSET_TIMEOUT: e.Timeout = time.Duration(ad.Uint64() * 1000) case unix.NFTA_DYNSET_FLAGS: diff --git a/nftables_test.go b/nftables_test.go index dcf6b27..5db5672 100644 --- a/nftables_test.go +++ b/nftables_test.go @@ -2530,19 +2530,14 @@ func TestDynset(t *testing.T) { &expr.Payload{ DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, - Offset: 12, - Len: 4, - }, - &expr.Immediate{ - Register: 2, - Data: []byte{0x0, 0x0, 0x0, 0x2}, + Offset: uint32(12), + Len: uint32(4), }, &expr.Dynset{ - SrcRegKey: 1, - SrcRegData: 2, - SetName: set.Name, - SetID: set.ID, - Operation: uint32(unix.NFT_DYNSET_OP_UPDATE), + SrcRegKey: 1, + SetName: set.Name, + SetID: set.ID, + Operation: uint32(unix.NFT_DYNSET_OP_UPDATE), }, }, }) @@ -2567,20 +2562,18 @@ func TestDynset(t *testing.T) { if got, want := len(rules), 1; got != want { t.Fatalf("unexpected number of rules: got %d, want %d", got, want) } - if got, want := len(rules[0].Exprs), 3; got != want { + if got, want := len(rules[0].Exprs), 2; got != want { t.Fatalf("unexpected number of exprs: got %d, want %d", got, want) } - dynset, dynsetOk := rules[0].Exprs[2].(*expr.Dynset) + dynset, dynsetOk := rules[0].Exprs[1].(*expr.Dynset) if !dynsetOk { - t.Fatalf("Exprs[3] is type %T, want *expr.Dynset", rules[0].Exprs[2]) + t.Fatalf("Exprs[0] is type %T, want *expr.Dynset", rules[0].Exprs[1]) } if want := (&expr.Dynset{ - SrcRegKey: 1, - SrcRegData: 2, - SetName: set.Name, - SetID: set.ID, - Operation: uint32(unix.NFT_DYNSET_OP_UPDATE), + SrcRegKey: 1, + SetName: set.Name, + Operation: uint32(unix.NFT_DYNSET_OP_UPDATE), }); !reflect.DeepEqual(dynset, want) { t.Errorf("dynset expr = %+v, wanted %+v", dynset, want) } diff --git a/rule.go b/rule.go index a86ce97..6fda09b 100644 --- a/rule.go +++ b/rule.go @@ -240,6 +240,8 @@ func exprsFromMsg(b []byte) ([]expr.Any, error) { e = &expr.Redir{} case "nat": e = &expr.NAT{} + case "dynset": + e = &expr.Dynset{} } if e == nil { // TODO: introduce an opaque expression type so that users know