diff --git a/set.go b/set.go index 951568b..76c3d25 100644 --- a/set.go +++ b/set.go @@ -81,11 +81,25 @@ var ( } ) -// ConcatSetType constructs a new SetDatatype which consists of a concatenation of the passed types. It panics, if the -// nftMagic would overflow (more than 5 types) -func ConcatSetType(types ...SetDatatype) SetDatatype { +// ErrTooManyTypes is the error returned by ConcatSetType, if nftMagic would overflow. +var ErrTooManyTypes = errors.New("too many types to concat") + +// MustConcatSetType does the same as ConcatSetType, but panics instead of an +// error. It simplifies safe initialization of global variables. +func MustConcatSetType(types ...SetDatatype) SetDatatype { + t, err := ConcatSetType(types...) + if err != nil { + panic(err) + } + return t +} + +// ConcatSetType constructs a new SetDatatype which consists of a concatenation +// of the passed types. It returns ErrTooManyTypes, if nftMagic would overflow +// (more than 5 types). +func ConcatSetType(types ...SetDatatype) (SetDatatype, error) { if len(types) > 32/SetConcatTypeBits { - panic("too many type to concat") + return SetDatatype{}, ErrTooManyTypes } var magic, bytes uint32 @@ -102,7 +116,7 @@ func ConcatSetType(types ...SetDatatype) SetDatatype { magic <<= SetConcatTypeBits magic |= t.nftMagic & SetConcatTypeMask } - return SetDatatype{Name: strings.Join(names, " . "), Bytes: bytes, nftMagic: magic} + return SetDatatype{Name: strings.Join(names, " . "), Bytes: bytes, nftMagic: magic}, nil } // Set represents an nftables set. Anonymous sets are only valid within the diff --git a/set_test.go b/set_test.go index 0c56c03..5417326 100644 --- a/set_test.go +++ b/set_test.go @@ -76,7 +76,7 @@ func TestConcatSetType(t *testing.T) { tests := []struct { name string types []SetDatatype - pass bool + err error concatName string concatBytes uint32 concatMagic uint32 @@ -84,12 +84,12 @@ func TestConcatSetType(t *testing.T) { { name: "Concatenate six (too many) IPv4s", types: []SetDatatype{TypeIPAddr, TypeIPAddr, TypeIPAddr, TypeIPAddr, TypeIPAddr, TypeIPAddr}, - pass: false, + err: ErrTooManyTypes, }, { name: "Concatenate five IPv4s", types: []SetDatatype{TypeIPAddr, TypeIPAddr, TypeIPAddr, TypeIPAddr, TypeIPAddr}, - pass: true, + err: nil, concatName: "ipv4_addr . ipv4_addr . ipv4_addr . ipv4_addr . ipv4_addr", concatBytes: 20, concatMagic: 0x071c71c7, @@ -97,7 +97,7 @@ func TestConcatSetType(t *testing.T) { { name: "Concatenate IPv6 and port", types: []SetDatatype{TypeIP6Addr, TypeInetService}, - pass: true, + err: nil, concatName: "ipv6_addr . inet_service", concatBytes: 20, concatMagic: 0x0000020d, @@ -105,7 +105,7 @@ func TestConcatSetType(t *testing.T) { { name: "Concatenate protocol and port", types: []SetDatatype{TypeInetProto, TypeInetService}, - pass: true, + err: nil, concatName: "inet_proto . inet_service", concatBytes: 8, concatMagic: 0x0000030d, @@ -114,14 +114,13 @@ func TestConcatSetType(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if !tt.pass { - defer func() { - if recover() == nil { - t.Fatalf("ConcatSetType() should have paniced but did not") - } - }() + concat, err := ConcatSetType(tt.types...) + if tt.err != err { + t.Errorf("ConcatSetType() returned an incorrect error: expected %v but got %v", tt.err, err) + } + if err != nil { + return } - concat := ConcatSetType(tt.types...) if tt.concatName != concat.Name { t.Errorf("invalid concatinated name: expceted %s but got %s", tt.concatName, concat.Name) }