turned panic in ConcatSetType into an error, added MustConcatSetType

This commit is contained in:
Leon Vack 2020-01-21 16:22:52 +01:00
parent dcf82b4d03
commit 5d59b9a7fb
No known key found for this signature in database
GPG Key ID: B66DAB934BCECCB7
2 changed files with 30 additions and 17 deletions

24
set.go
View File

@ -81,11 +81,25 @@ var (
}
)
// ConcatSetType constructs a new SetDatatype which consists of a concatenation of the passed types. It panics, if the
// nftMagic would overflow (more than 5 types)
func ConcatSetType(types ...SetDatatype) SetDatatype {
// ErrTooManyTypes is the error returned by ConcatSetType, if nftMagic would overflow.
var ErrTooManyTypes = errors.New("too many types to concat")
// MustConcatSetType does the same as ConcatSetType, but panics instead of an
// error. It simplifies safe initialization of global variables.
func MustConcatSetType(types ...SetDatatype) SetDatatype {
t, err := ConcatSetType(types...)
if err != nil {
panic(err)
}
return t
}
// ConcatSetType constructs a new SetDatatype which consists of a concatenation
// of the passed types. It returns ErrTooManyTypes, if nftMagic would overflow
// (more than 5 types).
func ConcatSetType(types ...SetDatatype) (SetDatatype, error) {
if len(types) > 32/SetConcatTypeBits {
panic("too many type to concat")
return SetDatatype{}, ErrTooManyTypes
}
var magic, bytes uint32
@ -102,7 +116,7 @@ func ConcatSetType(types ...SetDatatype) SetDatatype {
magic <<= SetConcatTypeBits
magic |= t.nftMagic & SetConcatTypeMask
}
return SetDatatype{Name: strings.Join(names, " . "), Bytes: bytes, nftMagic: magic}
return SetDatatype{Name: strings.Join(names, " . "), Bytes: bytes, nftMagic: magic}, nil
}
// Set represents an nftables set. Anonymous sets are only valid within the

View File

@ -76,7 +76,7 @@ func TestConcatSetType(t *testing.T) {
tests := []struct {
name string
types []SetDatatype
pass bool
err error
concatName string
concatBytes uint32
concatMagic uint32
@ -84,12 +84,12 @@ func TestConcatSetType(t *testing.T) {
{
name: "Concatenate six (too many) IPv4s",
types: []SetDatatype{TypeIPAddr, TypeIPAddr, TypeIPAddr, TypeIPAddr, TypeIPAddr, TypeIPAddr},
pass: false,
err: ErrTooManyTypes,
},
{
name: "Concatenate five IPv4s",
types: []SetDatatype{TypeIPAddr, TypeIPAddr, TypeIPAddr, TypeIPAddr, TypeIPAddr},
pass: true,
err: nil,
concatName: "ipv4_addr . ipv4_addr . ipv4_addr . ipv4_addr . ipv4_addr",
concatBytes: 20,
concatMagic: 0x071c71c7,
@ -97,7 +97,7 @@ func TestConcatSetType(t *testing.T) {
{
name: "Concatenate IPv6 and port",
types: []SetDatatype{TypeIP6Addr, TypeInetService},
pass: true,
err: nil,
concatName: "ipv6_addr . inet_service",
concatBytes: 20,
concatMagic: 0x0000020d,
@ -105,7 +105,7 @@ func TestConcatSetType(t *testing.T) {
{
name: "Concatenate protocol and port",
types: []SetDatatype{TypeInetProto, TypeInetService},
pass: true,
err: nil,
concatName: "inet_proto . inet_service",
concatBytes: 8,
concatMagic: 0x0000030d,
@ -114,14 +114,13 @@ func TestConcatSetType(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if !tt.pass {
defer func() {
if recover() == nil {
t.Fatalf("ConcatSetType() should have paniced but did not")
}
}()
concat, err := ConcatSetType(tt.types...)
if tt.err != err {
t.Errorf("ConcatSetType() returned an incorrect error: expected %v but got %v", tt.err, err)
}
if err != nil {
return
}
concat := ConcatSetType(tt.types...)
if tt.concatName != concat.Name {
t.Errorf("invalid concatinated name: expceted %s but got %s", tt.concatName, concat.Name)
}