diff --git a/expr/exthdr.go b/expr/exthdr.go index e47f268..9a3ed8c 100644 --- a/expr/exthdr.go +++ b/expr/exthdr.go @@ -15,7 +15,7 @@ package expr import ( - "fmt" + "encoding/binary" "github.com/google/nftables/binaryutil" "github.com/mdlayher/netlink" @@ -40,16 +40,24 @@ type Exthdr struct { } func (e *Exthdr) marshal() ([]byte, error) { - data, err := netlink.MarshalAttributes([]netlink.Attribute{ - {Type: unix.NFTA_EXTHDR_SREG, Data: binaryutil.BigEndian.PutUint32(e.SourceRegister)}, + // Operations are differentiated by the Op and whether the SourceRegister + // or DestRegister is set. Mixing them results in EINVAL. + attr := []netlink.Attribute{ {Type: unix.NFTA_EXTHDR_TYPE, Data: []byte{e.Type}}, {Type: unix.NFTA_EXTHDR_OFFSET, Data: binaryutil.BigEndian.PutUint32(e.Offset)}, {Type: unix.NFTA_EXTHDR_LEN, Data: binaryutil.BigEndian.PutUint32(e.Len)}, - // TODO: these fields seem to be conditional? - //{Type: unix.NFTA_EXTHDR_DREG, Data: binaryutil.BigEndian.PutUint32(e.DestRegister)}, - //{Type: unix.NFTA_EXTHDR_FLAGS, Data: binaryutil.BigEndian.PutUint32(e.Flags)}, {Type: unix.NFTA_EXTHDR_OP, Data: binaryutil.BigEndian.PutUint32(uint32(e.Op))}, - }) + } + if e.SourceRegister != 0 { + attr = append(attr, + netlink.Attribute{Type: unix.NFTA_EXTHDR_SREG, Data: binaryutil.BigEndian.PutUint32(e.SourceRegister)}) + } else { + attr = append(attr, + netlink.Attribute{Type: unix.NFTA_EXTHDR_DREG, Data: binaryutil.BigEndian.PutUint32(e.DestRegister)}, + netlink.Attribute{Type: unix.NFTA_EXTHDR_FLAGS, Data: binaryutil.BigEndian.PutUint32(e.Flags)}) + } + + data, err := netlink.MarshalAttributes(attr) if err != nil { return nil, err } @@ -60,5 +68,28 @@ func (e *Exthdr) marshal() ([]byte, error) { } func (e *Exthdr) unmarshal(data []byte) error { - return fmt.Errorf("not yet implemented") + ad, err := netlink.NewAttributeDecoder(data) + if err != nil { + return err + } + ad.ByteOrder = binary.BigEndian + for ad.Next() { + switch ad.Type() { + case unix.NFTA_EXTHDR_DREG: + e.DestRegister = ad.Uint32() + case unix.NFTA_EXTHDR_TYPE: + e.Type = ad.Uint8() + case unix.NFTA_EXTHDR_OFFSET: + e.Offset = ad.Uint32() + case unix.NFTA_EXTHDR_LEN: + e.Len = ad.Uint32() + case unix.NFTA_EXTHDR_FLAGS: + e.Flags = ad.Uint32() + case unix.NFTA_EXTHDR_OP: + e.Op = ExthdrOp(ad.Uint32()) + case unix.NFTA_EXTHDR_SREG: + e.SourceRegister = ad.Uint32() + } + } + return ad.Err() } diff --git a/expr/exthdr_test.go b/expr/exthdr_test.go new file mode 100644 index 0000000..a573436 --- /dev/null +++ b/expr/exthdr_test.go @@ -0,0 +1,70 @@ +package expr + +import ( + "encoding/binary" + "reflect" + "testing" + + "github.com/mdlayher/netlink" + "golang.org/x/sys/unix" +) + +func TestExthdr(t *testing.T) { + t.Parallel() + tests := []struct { + name string + eh Exthdr + }{ + { + name: "Unmarshal Exthdr DestRegister case", + eh: Exthdr{ + DestRegister: 1, + Type: 2, + Offset: 3, + Len: 4, + Flags: 5, + Op: ExthdrOpTcpopt, + SourceRegister: 0, + }, + }, + { + name: "Unmarshal Exthdr SourceRegister case", + eh: Exthdr{ + SourceRegister: 1, + Type: 2, + Offset: 3, + Len: 4, + Op: ExthdrOpTcpopt, + DestRegister: 0, + Flags: 0, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + neh := Exthdr{} + data, err := tt.eh.marshal() + if err != nil { + t.Fatalf("marshal error: %+v", err) + + } + ad, err := netlink.NewAttributeDecoder(data) + if err != nil { + t.Fatalf("NewAttributeDecoder() error: %+v", err) + } + ad.ByteOrder = binary.BigEndian + for ad.Next() { + if ad.Type() == unix.NFTA_EXPR_DATA { + if err := neh.unmarshal(ad.Bytes()); err != nil { + t.Errorf("unmarshal error: %+v", err) + break + } + } + } + if !reflect.DeepEqual(tt.eh, neh) { + t.Fatalf("original %+v and recovered %+v Exthdr structs are different", tt.eh, neh) + } + }) + } +} diff --git a/rule.go b/rule.go index a29d4b9..be5473f 100644 --- a/rule.go +++ b/rule.go @@ -260,6 +260,8 @@ func exprsFromMsg(b []byte) ([]expr.Any, error) { e = &expr.Dynset{} case "log": e = &expr.Log{} + case "exthdr": + e = &expr.Exthdr{} } if e == nil { // TODO: introduce an opaque expression type so that users know