Fix concatenated key set validation (#83)

This commit is contained in:
Serguei Bezverkhi 2019-12-17 18:02:00 -05:00 committed by Michael Stapelberg
parent 9a6c96795b
commit 756cfa14a8
2 changed files with 102 additions and 8 deletions

39
set.go
View File

@ -26,6 +26,10 @@ import (
"golang.org/x/sys/unix"
)
// SetConcatTypeBits defines concatination bits, originally defined in
// https://git.netfilter.org/iptables/tree/iptables/nft.c?id=26753888720d8e7eb422ae4311348347f5a05cb4#n1002
const SetConcatTypeBits = 6
var allocSetID uint32
// SetDatatype represents a datatype declared by nft.
@ -434,15 +438,10 @@ func setsFromMsg(msg netlink.Message) (*Set, error) {
set.IsMap = (flags & unix.NFTA_SET_TABLE) != 0
case unix.NFTA_SET_KEY_TYPE:
nftMagic := ad.Uint32()
for _, dt := range nftDatatypes {
if nftMagic == dt.nftMagic {
set.KeyType = dt
break
}
}
if set.KeyType.nftMagic == 0 {
return nil, fmt.Errorf("could not determine key type %x", nftMagic)
if invalidMagic, ok := validateKeyType(nftMagic); !ok {
return nil, fmt.Errorf("could not determine key type %+v", invalidMagic)
}
set.KeyType.nftMagic = nftMagic
case unix.NFTA_SET_DATA_TYPE:
nftMagic := ad.Uint32()
// Special case for the data type verdict, in the message it is stored as 0xffffff00 but it is defined as 1
@ -464,6 +463,30 @@ func setsFromMsg(msg netlink.Message) (*Set, error) {
return &set, nil
}
func validateKeyType(bits uint32) ([]uint32, bool) {
var unpackTypes []uint32
var invalidTypes []uint32
found := false
valid := true
for bits != 0 {
unpackTypes = append(unpackTypes, bits&0x3f)
bits = bits >> SetConcatTypeBits
}
for _, t := range unpackTypes {
for _, dt := range nftDatatypes {
if t == dt.nftMagic {
found = true
}
}
if !found {
invalidTypes = append(invalidTypes, t)
valid = false
}
found = false
}
return invalidTypes, valid
}
var elemHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWSETELEM)
func elementsFromMsg(msg netlink.Message) ([]SetElement, error) {

71
set_test.go Normal file
View File

@ -0,0 +1,71 @@
package nftables
import (
"reflect"
"testing"
)
func genSetKeyType(types ...uint32) uint32 {
c := types[0]
for i := 1; i < len(types); i++ {
c = c<<SetConcatTypeBits | types[i]
}
return c
}
func TestValidateNFTMagic(t *testing.T) {
t.Parallel()
tests := []struct {
name string
nftMagicPacked uint32
pass bool
invalid []uint32
}{
{
name: "Single valid nftMagic",
nftMagicPacked: genSetKeyType(7),
pass: true,
invalid: nil,
},
{
name: "Single invalid nftMagic",
nftMagicPacked: genSetKeyType(25),
pass: false,
invalid: []uint32{25},
},
{
name: "Multiple valid nftMagic",
nftMagicPacked: genSetKeyType(7, 13),
pass: true,
invalid: nil,
},
{
name: "Multiple nftMagic with 1 invalid",
nftMagicPacked: genSetKeyType(7, 13, 25),
pass: false,
invalid: []uint32{25},
},
{
name: "Multiple nftMagic with 2 invalid",
nftMagicPacked: genSetKeyType(7, 13, 25, 26),
pass: false,
invalid: []uint32{26, 25},
// Invalid entries will appear in reverse order
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
invalid, pass := validateKeyType(tt.nftMagicPacked)
if pass && !tt.pass {
t.Fatalf("expected to fail but succeeded")
}
if !pass && tt.pass {
t.Fatalf("expected to succeed but failed with invalid nftMagic: %+v", invalid)
}
if !reflect.DeepEqual(tt.invalid, invalid) {
t.Fatalf("Expected invalid: %+v but got: %+v", tt.invalid, invalid)
}
})
}
}