Fix concatenated key set validation (#83)
This commit is contained in:
parent
9a6c96795b
commit
756cfa14a8
39
set.go
39
set.go
|
@ -26,6 +26,10 @@ import (
|
||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// SetConcatTypeBits defines concatination bits, originally defined in
|
||||||
|
// https://git.netfilter.org/iptables/tree/iptables/nft.c?id=26753888720d8e7eb422ae4311348347f5a05cb4#n1002
|
||||||
|
const SetConcatTypeBits = 6
|
||||||
|
|
||||||
var allocSetID uint32
|
var allocSetID uint32
|
||||||
|
|
||||||
// SetDatatype represents a datatype declared by nft.
|
// SetDatatype represents a datatype declared by nft.
|
||||||
|
@ -434,15 +438,10 @@ func setsFromMsg(msg netlink.Message) (*Set, error) {
|
||||||
set.IsMap = (flags & unix.NFTA_SET_TABLE) != 0
|
set.IsMap = (flags & unix.NFTA_SET_TABLE) != 0
|
||||||
case unix.NFTA_SET_KEY_TYPE:
|
case unix.NFTA_SET_KEY_TYPE:
|
||||||
nftMagic := ad.Uint32()
|
nftMagic := ad.Uint32()
|
||||||
for _, dt := range nftDatatypes {
|
if invalidMagic, ok := validateKeyType(nftMagic); !ok {
|
||||||
if nftMagic == dt.nftMagic {
|
return nil, fmt.Errorf("could not determine key type %+v", invalidMagic)
|
||||||
set.KeyType = dt
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if set.KeyType.nftMagic == 0 {
|
|
||||||
return nil, fmt.Errorf("could not determine key type %x", nftMagic)
|
|
||||||
}
|
}
|
||||||
|
set.KeyType.nftMagic = nftMagic
|
||||||
case unix.NFTA_SET_DATA_TYPE:
|
case unix.NFTA_SET_DATA_TYPE:
|
||||||
nftMagic := ad.Uint32()
|
nftMagic := ad.Uint32()
|
||||||
// Special case for the data type verdict, in the message it is stored as 0xffffff00 but it is defined as 1
|
// Special case for the data type verdict, in the message it is stored as 0xffffff00 but it is defined as 1
|
||||||
|
@ -464,6 +463,30 @@ func setsFromMsg(msg netlink.Message) (*Set, error) {
|
||||||
return &set, nil
|
return &set, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func validateKeyType(bits uint32) ([]uint32, bool) {
|
||||||
|
var unpackTypes []uint32
|
||||||
|
var invalidTypes []uint32
|
||||||
|
found := false
|
||||||
|
valid := true
|
||||||
|
for bits != 0 {
|
||||||
|
unpackTypes = append(unpackTypes, bits&0x3f)
|
||||||
|
bits = bits >> SetConcatTypeBits
|
||||||
|
}
|
||||||
|
for _, t := range unpackTypes {
|
||||||
|
for _, dt := range nftDatatypes {
|
||||||
|
if t == dt.nftMagic {
|
||||||
|
found = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
invalidTypes = append(invalidTypes, t)
|
||||||
|
valid = false
|
||||||
|
}
|
||||||
|
found = false
|
||||||
|
}
|
||||||
|
return invalidTypes, valid
|
||||||
|
}
|
||||||
|
|
||||||
var elemHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWSETELEM)
|
var elemHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWSETELEM)
|
||||||
|
|
||||||
func elementsFromMsg(msg netlink.Message) ([]SetElement, error) {
|
func elementsFromMsg(msg netlink.Message) ([]SetElement, error) {
|
||||||
|
|
|
@ -0,0 +1,71 @@
|
||||||
|
package nftables
|
||||||
|
|
||||||
|
import (
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func genSetKeyType(types ...uint32) uint32 {
|
||||||
|
c := types[0]
|
||||||
|
for i := 1; i < len(types); i++ {
|
||||||
|
c = c<<SetConcatTypeBits | types[i]
|
||||||
|
}
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateNFTMagic(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
nftMagicPacked uint32
|
||||||
|
pass bool
|
||||||
|
invalid []uint32
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Single valid nftMagic",
|
||||||
|
nftMagicPacked: genSetKeyType(7),
|
||||||
|
pass: true,
|
||||||
|
invalid: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Single invalid nftMagic",
|
||||||
|
nftMagicPacked: genSetKeyType(25),
|
||||||
|
pass: false,
|
||||||
|
invalid: []uint32{25},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Multiple valid nftMagic",
|
||||||
|
nftMagicPacked: genSetKeyType(7, 13),
|
||||||
|
pass: true,
|
||||||
|
invalid: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Multiple nftMagic with 1 invalid",
|
||||||
|
nftMagicPacked: genSetKeyType(7, 13, 25),
|
||||||
|
pass: false,
|
||||||
|
invalid: []uint32{25},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Multiple nftMagic with 2 invalid",
|
||||||
|
nftMagicPacked: genSetKeyType(7, 13, 25, 26),
|
||||||
|
pass: false,
|
||||||
|
invalid: []uint32{26, 25},
|
||||||
|
// Invalid entries will appear in reverse order
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
invalid, pass := validateKeyType(tt.nftMagicPacked)
|
||||||
|
if pass && !tt.pass {
|
||||||
|
t.Fatalf("expected to fail but succeeded")
|
||||||
|
}
|
||||||
|
if !pass && tt.pass {
|
||||||
|
t.Fatalf("expected to succeed but failed with invalid nftMagic: %+v", invalid)
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(tt.invalid, invalid) {
|
||||||
|
t.Fatalf("Expected invalid: %+v but got: %+v", tt.invalid, invalid)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue