This commit is contained in:
Francesco Cheinasso 2024-01-11 10:43:17 +01:00
parent fbea8aee17
commit ae1b9c8df1
3 changed files with 77 additions and 97 deletions

View File

@ -454,12 +454,7 @@ func TestConfigureNAT(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
dnatfirstip, err := nftables.GetFirstIPFromCIDR("20.0.0.0/24") dnatfirstip, dnatlastip, err := nftables.GetFirstAndLastIPFromCIDR("20.0.0.0/24")
if err != nil {
t.Fatal(err)
}
dnatlastip, err := nftables.GetLastIPFromCIDR("20.0.0.0/24")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -478,6 +473,8 @@ func TestConfigureNAT(t *testing.T) {
SourceRegister: 1, SourceRegister: 1,
DestRegister: 1, DestRegister: 1,
Len: 4, Len: 4,
// By specifying Xor to 0x0,0x0,0x0,0x0 and Mask to the CIDR mask,
// the rule will match the CIDR of the IP (e.g in this case 10.0.0.0/24).
Xor: []byte{0x0, 0x0, 0x0, 0x0}, Xor: []byte{0x0, 0x0, 0x0, 0x0},
Mask: dstcidrmatch.Mask, Mask: dstcidrmatch.Mask,
}, },
@ -488,11 +485,11 @@ func TestConfigureNAT(t *testing.T) {
}, },
&expr.Immediate{ &expr.Immediate{
Register: 1, Register: 1,
Data: *dnatfirstip, Data: dnatfirstip,
}, },
&expr.Immediate{ &expr.Immediate{
Register: 2, Register: 2,
Data: *dnatlastip, Data: dnatlastip,
}, },
&expr.NAT{ &expr.NAT{
Type: expr.NATTypeDestNAT, Type: expr.NATTypeDestNAT,

42
util.go
View File

@ -46,34 +46,32 @@ func (genmsg *NFGenMsg) Decode(b []byte) {
genmsg.ResourceID = binary.BigEndian.Uint16(b[2:]) genmsg.ResourceID = binary.BigEndian.Uint16(b[2:])
} }
// GetFirstIPFromCIDR returns the first IP address from a CIDR. // GetFirstAndLastIPFromCIDR returns the first and last IP address from a CIDR.
func GetFirstIPFromCIDR(cidr string) (*net.IP, error) { func GetFirstAndLastIPFromCIDR(cidr string) (firstIP, lastIP net.IP, err error) {
_, subnet, err := net.ParseCIDR(cidr) _, subnet, err := net.ParseCIDR(cidr)
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
firstIP = make(net.IP, len(subnet.IP))
lastIP = make(net.IP, len(subnet.IP))
switch len(subnet.IP) {
case net.IPv4len:
mask := binary.BigEndian.Uint32(subnet.Mask) mask := binary.BigEndian.Uint32(subnet.Mask)
ip := binary.BigEndian.Uint32(subnet.IP) ip := binary.BigEndian.Uint32(subnet.IP)
// find the final address
firstIP := make(net.IP, 4)
binary.BigEndian.PutUint32(firstIP, ip&mask) binary.BigEndian.PutUint32(firstIP, ip&mask)
return &firstIP, nil
}
// GetLastIPFromCIDR returns the last IP address from a CIDR.
func GetLastIPFromCIDR(cidr string) (*net.IP, error) {
_, subnet, err := net.ParseCIDR(cidr)
if err != nil {
return nil, err
}
mask := binary.BigEndian.Uint32(subnet.Mask)
ip := binary.BigEndian.Uint32(subnet.IP)
// find the final address
lastIP := make(net.IP, 4)
binary.BigEndian.PutUint32(lastIP, (ip&mask)|(mask^0xffffffff)) binary.BigEndian.PutUint32(lastIP, (ip&mask)|(mask^0xffffffff))
case net.IPv6len:
return &lastIP, nil mask1 := binary.BigEndian.Uint64(subnet.Mask[:8])
mask2 := binary.BigEndian.Uint64(subnet.Mask[8:])
ip1 := binary.BigEndian.Uint64(subnet.IP[:8])
ip2 := binary.BigEndian.Uint64(subnet.IP[8:])
binary.BigEndian.PutUint64(firstIP[:8], ip1&mask1)
binary.BigEndian.PutUint64(firstIP[8:], ip2&mask2)
binary.BigEndian.PutUint64(lastIP[:8], (ip1&mask1)|(mask1^0xffffffffffffffff))
binary.BigEndian.PutUint64(lastIP[8:], (ip2&mask2)|(mask2^0xffffffffffffffff))
}
return firstIP, lastIP, nil
} }

View File

@ -6,87 +6,72 @@ import (
"testing" "testing"
) )
func TestGetFirstIPFromCIDR(t *testing.T) { func TestGetFirstAndLastIPFromCIDR(t *testing.T) {
type args struct { type args struct {
cidr string cidr string
} }
tests := []struct { tests := []struct {
name string name string
args args args args
want *net.IP wantFirstIP net.IP
wantLastIP net.IP
wantErr bool wantErr bool
}{ }{
{ {
name: "Test 0", name: "Test Fake",
args: args{cidr: "fakecidr"}, args: args{cidr: "fakecidr"},
want: nil, wantFirstIP: nil,
wantLastIP: nil,
wantErr: true, wantErr: true,
}, },
{ {
name: "Test 1", name: "Test IPV4 1",
args: args{cidr: "10.0.0.0/24"}, args: args{cidr: "10.0.0.0/24"},
want: &net.IP{10, 0, 0, 0}, wantFirstIP: net.IP{10, 0, 0, 0},
wantLastIP: net.IP{10, 0, 0, 255},
wantErr: false, wantErr: false,
}, },
{ {
name: "Test 2", name: "Test IPV4 2",
args: args{cidr: "10.0.0.20/24"}, args: args{cidr: "10.0.0.20/24"},
want: &net.IP{10, 0, 0, 0}, wantFirstIP: net.IP{10, 0, 0, 0},
wantLastIP: net.IP{10, 0, 0, 255},
wantErr: false,
},
{
name: "Test IPV4 2",
args: args{cidr: "10.0.0.0/19"},
wantFirstIP: net.IP{10, 0, 0, 0},
wantLastIP: net.IP{10, 0, 31, 255},
wantErr: false,
},
{
name: "Test IPV6 1",
args: args{cidr: "ff00::/16"},
wantFirstIP: net.ParseIP("ff00::"),
wantLastIP: net.ParseIP("ff00:ffff:ffff:ffff:ffff:ffff:ffff:ffff"),
wantErr: false,
},
{
name: "Test IPV6 2",
args: args{cidr: "2001:db8::/62"},
wantFirstIP: net.ParseIP("2001:db8::"),
wantLastIP: net.ParseIP("2001:db8:0000:0003:ffff:ffff:ffff:ffff"),
wantErr: false, wantErr: false,
}, },
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
got, err := GetFirstIPFromCIDR(tt.args.cidr) gotFirstIP, gotLastIP, err := GetFirstAndLastIPFromCIDR(tt.args.cidr)
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("GetFirstIPFromCIDR() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("GetFirstAndLastIPFromCIDR() error = %v, wantErr %v", err, tt.wantErr)
return return
} }
if !reflect.DeepEqual(got, tt.want) { if !reflect.DeepEqual(gotFirstIP, tt.wantFirstIP) {
t.Errorf("GetFirstIPFromCIDR() = %v, want %v", got, tt.want) t.Errorf("GetFirstAndLastIPFromCIDR() gotFirstIP = %v, want %v", gotFirstIP, tt.wantFirstIP)
} }
}) if !reflect.DeepEqual(gotLastIP, tt.wantLastIP) {
} t.Errorf("GetFirstAndLastIPFromCIDR() gotLastIP = %v, want %v", gotLastIP, tt.wantLastIP)
}
func TestGetLastIPFromCIDR(t *testing.T) {
type args struct {
cidr string
}
tests := []struct {
name string
args args
want *net.IP
wantErr bool
}{
{
name: "Test 0",
args: args{cidr: "fakecidr"},
want: nil,
wantErr: true,
},
{
name: "Test 1",
args: args{cidr: "10.0.0.0/24"},
want: &net.IP{10, 0, 0, 255},
wantErr: false,
},
{
name: "Test 2",
args: args{cidr: "10.0.0.20/24"},
want: &net.IP{10, 0, 0, 255},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := GetLastIPFromCIDR(tt.args.cidr)
if (err != nil) != tt.wantErr {
t.Errorf("GetLastIPFromCIDR() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("GetLastIPFromCIDR() = %v, want %v", got, tt.want)
} }
}) })
} }