diff --git a/expr/expr.go b/expr/expr.go index bf386d6..2ae57a9 100644 --- a/expr/expr.go +++ b/expr/expr.go @@ -336,6 +336,8 @@ const ( NF_NAT_RANGE_PERSISTENT = unix.NF_NAT_RANGE_PERSISTENT // NF_NAT_RANGE_PREFIX defines flag for a prefix masquerade NF_NAT_RANGE_PREFIX = unix.NF_NAT_RANGE_NETMAP + // NF_NAT_RANGE_PROTO_SPECIFIED defines flag for a specified range + NF_NAT_RANGE_PROTO_SPECIFIED = unix.NF_NAT_RANGE_PROTO_SPECIFIED ) func (e *Masq) marshal(fam byte) ([]byte, error) { diff --git a/expr/nat.go b/expr/nat.go index eded7da..3f28967 100644 --- a/expr/nat.go +++ b/expr/nat.go @@ -41,6 +41,7 @@ type NAT struct { FullyRandom bool Persistent bool Prefix bool + Specified bool } // |00048|N-|00001| |len |flags| type| @@ -97,6 +98,9 @@ func (e *NAT) marshalData(fam byte) ([]byte, error) { if e.Prefix { flags |= NF_NAT_RANGE_PREFIX } + if e.Specified { + flags |= NF_NAT_RANGE_PROTO_SPECIFIED + } if flags != 0 { attrs = append(attrs, netlink.Attribute{Type: unix.NFTA_NAT_FLAGS, Data: binaryutil.BigEndian.PutUint32(flags)}) } @@ -130,6 +134,7 @@ func (e *NAT) unmarshal(fam byte, data []byte) error { e.Random = (flags & NF_NAT_RANGE_PROTO_RANDOM) != 0 e.FullyRandom = (flags & NF_NAT_RANGE_PROTO_RANDOM_FULLY) != 0 e.Prefix = (flags & NF_NAT_RANGE_PREFIX) != 0 + e.Specified = (flags & NF_NAT_RANGE_PROTO_SPECIFIED) != 0 } } return ad.Err() diff --git a/expr/nat_test.go b/expr/nat_test.go new file mode 100644 index 0000000..d2d34aa --- /dev/null +++ b/expr/nat_test.go @@ -0,0 +1,66 @@ +package expr + +import ( + "encoding/binary" + "reflect" + "testing" + + "github.com/mdlayher/netlink" + "golang.org/x/sys/unix" +) + +func TestNat(t *testing.T) { + t.Parallel() + tests := []struct { + name string + nat NAT + }{ + { + name: "Unmarshal DNAT specified case", + nat: NAT{ + Type: NATTypeDestNAT, + Family: unix.NFPROTO_IPV4, + RegAddrMin: 1, + RegProtoMin: 2, + Specified: true, + }, + }, + { + name: "Unmarshal SNAT persistent case", + nat: NAT{ + Type: NATTypeSourceNAT, + Family: unix.NFPROTO_IPV4, + RegAddrMin: 1, + RegProtoMin: 2, + Persistent: true, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + nnat := NAT{} + data, err := tt.nat.marshal(0 /* don't care in this test */) + if err != nil { + t.Fatalf("marshal error: %+v", err) + + } + ad, err := netlink.NewAttributeDecoder(data) + if err != nil { + t.Fatalf("NewAttributeDecoder() error: %+v", err) + } + ad.ByteOrder = binary.BigEndian + for ad.Next() { + if ad.Type() == unix.NFTA_EXPR_DATA { + if err := nnat.unmarshal(0, ad.Bytes()); err != nil { + t.Errorf("unmarshal error: %+v", err) + break + } + } + } + if !reflect.DeepEqual(tt.nat, nnat) { + t.Fatalf("original %+v and recovered %+v Ct structs are different", tt.nat, nnat) + } + }) + } +}