p2p: use netip.Addr where possible (#29891)

enode.Node was recently changed to store a cache of endpoint information. The IP address in the cache is a netip.Addr. I chose that type over net.IP because it is just better. netip.Addr is meant to be used as a value type. Copying it does not allocate, it can be compared with ==, and can be used as a map key.

This PR changes most uses of Node.IP() into Node.IPAddr(), which returns the cached value directly without allocating.
While there are still some public APIs left where net.IP is used, I have converted all code used internally by p2p/discover to the new types. So this does change some public Go API, but hopefully not APIs any external code actually uses.

There weren't supposed to be any semantic differences resulting from this refactoring, however it does introduce one: In package p2p/netutil we treated the 0.0.0.0/8 network (addresses 0.x.y.z) as LAN, but netip.Addr.IsPrivate() doesn't. The treatment of this particular IP address range is controversial, with some software supporting it and others not. IANA lists it as special-purpose and invalid as a destination for a long time, so I don't know why I put it into the LAN list. It has now been marked as special in p2p/netutil as well.
This commit is contained in:
Felix Lange 2024-06-05 19:31:04 +02:00 committed by GitHub
parent d09ddac399
commit bc6569462d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
23 changed files with 392 additions and 313 deletions

View File

@ -53,7 +53,8 @@ func (s *Suite) dial() (*Conn, error) {
// dialAs attempts to dial a given node and perform a handshake using the given // dialAs attempts to dial a given node and perform a handshake using the given
// private key. // private key.
func (s *Suite) dialAs(key *ecdsa.PrivateKey) (*Conn, error) { func (s *Suite) dialAs(key *ecdsa.PrivateKey) (*Conn, error) {
fd, err := net.Dial("tcp", fmt.Sprintf("%v:%d", s.Dest.IP(), s.Dest.TCP())) tcpEndpoint, _ := s.Dest.TCPEndpoint()
fd, err := net.Dial("tcp", tcpEndpoint.String())
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -53,10 +53,12 @@ func newTestEnv(remote string, listen1, listen2 string) *testenv {
if err != nil { if err != nil {
panic(err) panic(err)
} }
if node.IP() == nil || node.UDP() == 0 { if !node.IPAddr().IsValid() || node.UDP() == 0 {
var ip net.IP var ip net.IP
var tcpPort, udpPort int var tcpPort, udpPort int
if ip = node.IP(); ip == nil { if node.IPAddr().IsValid() {
ip = node.IPAddr().AsSlice()
} else {
ip = net.ParseIP("127.0.0.1") ip = net.ParseIP("127.0.0.1")
} }
if tcpPort = node.TCP(); tcpPort == 0 { if tcpPort = node.TCP(); tcpPort == 0 {

View File

@ -19,7 +19,7 @@ package main
import ( import (
"errors" "errors"
"fmt" "fmt"
"net" "net/netip"
"sort" "sort"
"strconv" "strconv"
"strings" "strings"
@ -205,11 +205,11 @@ func trueFilter(args []string) (nodeFilter, error) {
} }
func ipFilter(args []string) (nodeFilter, error) { func ipFilter(args []string) (nodeFilter, error) {
_, cidr, err := net.ParseCIDR(args[0]) prefix, err := netip.ParsePrefix(args[0])
if err != nil { if err != nil {
return nil, err return nil, err
} }
f := func(n nodeJSON) bool { return cidr.Contains(n.N.IP()) } f := func(n nodeJSON) bool { return prefix.Contains(n.N.IPAddr()) }
return f, nil return f, nil
} }

View File

@ -77,7 +77,11 @@ var (
func rlpxPing(ctx *cli.Context) error { func rlpxPing(ctx *cli.Context) error {
n := getNodeArg(ctx) n := getNodeArg(ctx)
fd, err := net.Dial("tcp", fmt.Sprintf("%v:%d", n.IP(), n.TCP())) tcpEndpoint, ok := n.TCPEndpoint()
if !ok {
return fmt.Errorf("node has no TCP endpoint")
}
fd, err := net.Dial("tcp", tcpEndpoint.String())
if err != nil { if err != nil {
return err return err
} }

View File

@ -65,11 +65,8 @@ type tcpDialer struct {
} }
func (t tcpDialer) Dial(ctx context.Context, dest *enode.Node) (net.Conn, error) { func (t tcpDialer) Dial(ctx context.Context, dest *enode.Node) (net.Conn, error) {
return t.d.DialContext(ctx, "tcp", nodeAddr(dest).String()) addr, _ := dest.TCPEndpoint()
} return t.d.DialContext(ctx, "tcp", addr.String())
func nodeAddr(n *enode.Node) net.Addr {
return &net.TCPAddr{IP: n.IP(), Port: n.TCP()}
} }
// checkDial errors: // checkDial errors:
@ -243,7 +240,7 @@ loop:
select { select {
case node := <-nodesCh: case node := <-nodesCh:
if err := d.checkDial(node); err != nil { if err := d.checkDial(node); err != nil {
d.log.Trace("Discarding dial candidate", "id", node.ID(), "ip", node.IP(), "reason", err) d.log.Trace("Discarding dial candidate", "id", node.ID(), "ip", node.IPAddr(), "reason", err)
} else { } else {
d.startDial(newDialTask(node, dynDialedConn)) d.startDial(newDialTask(node, dynDialedConn))
} }
@ -277,7 +274,7 @@ loop:
case node := <-d.addStaticCh: case node := <-d.addStaticCh:
id := node.ID() id := node.ID()
_, exists := d.static[id] _, exists := d.static[id]
d.log.Trace("Adding static node", "id", id, "ip", node.IP(), "added", !exists) d.log.Trace("Adding static node", "id", id, "ip", node.IPAddr(), "added", !exists)
if exists { if exists {
continue loop continue loop
} }
@ -376,7 +373,7 @@ func (d *dialScheduler) checkDial(n *enode.Node) error {
if n.ID() == d.self { if n.ID() == d.self {
return errSelf return errSelf
} }
if n.IP() != nil && n.TCP() == 0 { if n.IPAddr().IsValid() && n.TCP() == 0 {
// This check can trigger if a non-TCP node is found // This check can trigger if a non-TCP node is found
// by discovery. If there is no IP, the node is a static // by discovery. If there is no IP, the node is a static
// node and the actual endpoint will be resolved later in dialTask. // node and the actual endpoint will be resolved later in dialTask.
@ -388,7 +385,7 @@ func (d *dialScheduler) checkDial(n *enode.Node) error {
if _, ok := d.peers[n.ID()]; ok { if _, ok := d.peers[n.ID()]; ok {
return errAlreadyConnected return errAlreadyConnected
} }
if d.netRestrict != nil && !d.netRestrict.Contains(n.IP()) { if d.netRestrict != nil && !d.netRestrict.ContainsAddr(n.IPAddr()) {
return errNetRestrict return errNetRestrict
} }
if d.history.contains(string(n.ID().Bytes())) { if d.history.contains(string(n.ID().Bytes())) {
@ -439,7 +436,7 @@ func (d *dialScheduler) removeFromStaticPool(idx int) {
// startDial runs the given dial task in a separate goroutine. // startDial runs the given dial task in a separate goroutine.
func (d *dialScheduler) startDial(task *dialTask) { func (d *dialScheduler) startDial(task *dialTask) {
node := task.dest() node := task.dest()
d.log.Trace("Starting p2p dial", "id", node.ID(), "ip", node.IP(), "flag", task.flags) d.log.Trace("Starting p2p dial", "id", node.ID(), "ip", node.IPAddr(), "flag", task.flags)
hkey := string(node.ID().Bytes()) hkey := string(node.ID().Bytes())
d.history.add(hkey, d.clock.Now().Add(dialHistoryExpiration)) d.history.add(hkey, d.clock.Now().Add(dialHistoryExpiration))
d.dialing[node.ID()] = task d.dialing[node.ID()] = task
@ -492,7 +489,7 @@ func (t *dialTask) run(d *dialScheduler) {
} }
func (t *dialTask) needResolve() bool { func (t *dialTask) needResolve() bool {
return t.flags&staticDialedConn != 0 && t.dest().IP() == nil return t.flags&staticDialedConn != 0 && !t.dest().IPAddr().IsValid()
} }
// resolve attempts to find the current endpoint for the destination // resolve attempts to find the current endpoint for the destination
@ -526,7 +523,8 @@ func (t *dialTask) resolve(d *dialScheduler) bool {
// The node was found. // The node was found.
t.resolveDelay = initialResolveDelay t.resolveDelay = initialResolveDelay
t.destPtr.Store(resolved) t.destPtr.Store(resolved)
d.log.Debug("Resolved node", "id", resolved.ID(), "addr", &net.TCPAddr{IP: resolved.IP(), Port: resolved.TCP()}) resAddr, _ := resolved.TCPEndpoint()
d.log.Debug("Resolved node", "id", resolved.ID(), "addr", resAddr)
return true return true
} }
@ -535,7 +533,8 @@ func (t *dialTask) dial(d *dialScheduler, dest *enode.Node) error {
dialMeter.Mark(1) dialMeter.Mark(1)
fd, err := d.dialer.Dial(d.ctx, dest) fd, err := d.dialer.Dial(d.ctx, dest)
if err != nil { if err != nil {
d.log.Trace("Dial error", "id", dest.ID(), "addr", nodeAddr(dest), "conn", t.flags, "err", cleanupDialErr(err)) addr, _ := dest.TCPEndpoint()
d.log.Trace("Dial error", "id", dest.ID(), "addr", addr, "conn", t.flags, "err", cleanupDialErr(err))
dialConnectionError.Mark(1) dialConnectionError.Mark(1)
return &dialError{err} return &dialError{err}
} }
@ -545,7 +544,7 @@ func (t *dialTask) dial(d *dialScheduler, dest *enode.Node) error {
func (t *dialTask) String() string { func (t *dialTask) String() string {
node := t.dest() node := t.dest()
id := node.ID() id := node.ID()
return fmt.Sprintf("%v %x %v:%d", t.flags, id[:8], node.IP(), node.TCP()) return fmt.Sprintf("%v %x %v:%d", t.flags, id[:8], node.IPAddr(), node.TCP())
} }
func cleanupDialErr(err error) error { func cleanupDialErr(err error) error {

View File

@ -25,7 +25,7 @@ package discover
import ( import (
"context" "context"
"fmt" "fmt"
"net" "net/netip"
"slices" "slices"
"sync" "sync"
"time" "time"
@ -207,8 +207,8 @@ func (tab *Table) setFallbackNodes(nodes []*enode.Node) error {
if err := n.ValidateComplete(); err != nil { if err := n.ValidateComplete(); err != nil {
return fmt.Errorf("bad bootstrap node %q: %v", n, err) return fmt.Errorf("bad bootstrap node %q: %v", n, err)
} }
if tab.cfg.NetRestrict != nil && !tab.cfg.NetRestrict.Contains(n.IP()) { if tab.cfg.NetRestrict != nil && !tab.cfg.NetRestrict.ContainsAddr(n.IPAddr()) {
tab.log.Error("Bootstrap node filtered by netrestrict", "id", n.ID(), "ip", n.IP()) tab.log.Error("Bootstrap node filtered by netrestrict", "id", n.ID(), "ip", n.IPAddr())
continue continue
} }
nursery = append(nursery, n) nursery = append(nursery, n)
@ -448,7 +448,7 @@ func (tab *Table) loadSeedNodes() {
for i := range seeds { for i := range seeds {
seed := seeds[i] seed := seeds[i]
if tab.log.Enabled(context.Background(), log.LevelTrace) { if tab.log.Enabled(context.Background(), log.LevelTrace) {
age := time.Since(tab.db.LastPongReceived(seed.ID(), seed.IP())) age := time.Since(tab.db.LastPongReceived(seed.ID(), seed.IPAddr()))
addr, _ := seed.UDPEndpoint() addr, _ := seed.UDPEndpoint()
tab.log.Trace("Found seed node in database", "id", seed.ID(), "addr", addr, "age", age) tab.log.Trace("Found seed node in database", "id", seed.ID(), "addr", addr, "age", age)
} }
@ -474,31 +474,31 @@ func (tab *Table) bucketAtDistance(d int) *bucket {
return tab.buckets[d-bucketMinDistance-1] return tab.buckets[d-bucketMinDistance-1]
} }
func (tab *Table) addIP(b *bucket, ip net.IP) bool { func (tab *Table) addIP(b *bucket, ip netip.Addr) bool {
if len(ip) == 0 { if !ip.IsValid() || ip.IsUnspecified() {
return false // Nodes without IP cannot be added. return false // Nodes without IP cannot be added.
} }
if netutil.IsLAN(ip) { if netutil.AddrIsLAN(ip) {
return true return true
} }
if !tab.ips.Add(ip) { if !tab.ips.AddAddr(ip) {
tab.log.Debug("IP exceeds table limit", "ip", ip) tab.log.Debug("IP exceeds table limit", "ip", ip)
return false return false
} }
if !b.ips.Add(ip) { if !b.ips.AddAddr(ip) {
tab.log.Debug("IP exceeds bucket limit", "ip", ip) tab.log.Debug("IP exceeds bucket limit", "ip", ip)
tab.ips.Remove(ip) tab.ips.RemoveAddr(ip)
return false return false
} }
return true return true
} }
func (tab *Table) removeIP(b *bucket, ip net.IP) { func (tab *Table) removeIP(b *bucket, ip netip.Addr) {
if netutil.IsLAN(ip) { if netutil.AddrIsLAN(ip) {
return return
} }
tab.ips.Remove(ip) tab.ips.RemoveAddr(ip)
b.ips.Remove(ip) b.ips.RemoveAddr(ip)
} }
// handleAddNode adds the node in the request to the table, if there is space. // handleAddNode adds the node in the request to the table, if there is space.
@ -524,7 +524,7 @@ func (tab *Table) handleAddNode(req addNodeOp) bool {
tab.addReplacement(b, req.node) tab.addReplacement(b, req.node)
return false return false
} }
if !tab.addIP(b, req.node.IP()) { if !tab.addIP(b, req.node.IPAddr()) {
// Can't add: IP limit reached. // Can't add: IP limit reached.
return false return false
} }
@ -547,7 +547,7 @@ func (tab *Table) addReplacement(b *bucket, n *enode.Node) {
// TODO: update ENR // TODO: update ENR
return return
} }
if !tab.addIP(b, n.IP()) { if !tab.addIP(b, n.IPAddr()) {
return return
} }
@ -555,7 +555,7 @@ func (tab *Table) addReplacement(b *bucket, n *enode.Node) {
var removed *tableNode var removed *tableNode
b.replacements, removed = pushNode(b.replacements, wn, maxReplacements) b.replacements, removed = pushNode(b.replacements, wn, maxReplacements)
if removed != nil { if removed != nil {
tab.removeIP(b, removed.IP()) tab.removeIP(b, removed.IPAddr())
} }
} }
@ -595,12 +595,12 @@ func (tab *Table) deleteInBucket(b *bucket, id enode.ID) *tableNode {
// Remove the node. // Remove the node.
n := b.entries[index] n := b.entries[index]
b.entries = slices.Delete(b.entries, index, index+1) b.entries = slices.Delete(b.entries, index, index+1)
tab.removeIP(b, n.IP()) tab.removeIP(b, n.IPAddr())
tab.nodeRemoved(b, n) tab.nodeRemoved(b, n)
// Add replacement. // Add replacement.
if len(b.replacements) == 0 { if len(b.replacements) == 0 {
tab.log.Debug("Removed dead node", "b", b.index, "id", n.ID(), "ip", n.IP()) tab.log.Debug("Removed dead node", "b", b.index, "id", n.ID(), "ip", n.IPAddr())
return nil return nil
} }
rindex := tab.rand.Intn(len(b.replacements)) rindex := tab.rand.Intn(len(b.replacements))
@ -608,7 +608,7 @@ func (tab *Table) deleteInBucket(b *bucket, id enode.ID) *tableNode {
b.replacements = slices.Delete(b.replacements, rindex, rindex+1) b.replacements = slices.Delete(b.replacements, rindex, rindex+1)
b.entries = append(b.entries, rep) b.entries = append(b.entries, rep)
tab.nodeAdded(b, rep) tab.nodeAdded(b, rep)
tab.log.Debug("Replaced dead node", "b", b.index, "id", n.ID(), "ip", n.IP(), "r", rep.ID(), "rip", rep.IP()) tab.log.Debug("Replaced dead node", "b", b.index, "id", n.ID(), "ip", n.IPAddr(), "r", rep.ID(), "rip", rep.IPAddr())
return rep return rep
} }
@ -635,10 +635,10 @@ func (tab *Table) bumpInBucket(b *bucket, newRecord *enode.Node, isInbound bool)
ipchanged := newRecord.IPAddr() != n.IPAddr() ipchanged := newRecord.IPAddr() != n.IPAddr()
portchanged := newRecord.UDP() != n.UDP() portchanged := newRecord.UDP() != n.UDP()
if ipchanged { if ipchanged {
tab.removeIP(b, n.IP()) tab.removeIP(b, n.IPAddr())
if !tab.addIP(b, newRecord.IP()) { if !tab.addIP(b, newRecord.IPAddr()) {
// It doesn't fit with the limit, put the previous record back. // It doesn't fit with the limit, put the previous record back.
tab.addIP(b, n.IP()) tab.addIP(b, n.IPAddr())
return n, false return n, false
} }
} }
@ -657,11 +657,11 @@ func (tab *Table) handleTrackRequest(op trackRequestOp) {
var fails int var fails int
if op.success { if op.success {
// Reset failure counter because it counts _consecutive_ failures. // Reset failure counter because it counts _consecutive_ failures.
tab.db.UpdateFindFails(op.node.ID(), op.node.IP(), 0) tab.db.UpdateFindFails(op.node.ID(), op.node.IPAddr(), 0)
} else { } else {
fails = tab.db.FindFails(op.node.ID(), op.node.IP()) fails = tab.db.FindFails(op.node.ID(), op.node.IPAddr())
fails++ fails++
tab.db.UpdateFindFails(op.node.ID(), op.node.IP(), fails) tab.db.UpdateFindFails(op.node.ID(), op.node.IPAddr(), fails)
} }
tab.mutex.Lock() tab.mutex.Lock()

View File

@ -188,7 +188,7 @@ func checkIPLimitInvariant(t *testing.T, tab *Table) {
tabset := netutil.DistinctNetSet{Subnet: tableSubnet, Limit: tableIPLimit} tabset := netutil.DistinctNetSet{Subnet: tableSubnet, Limit: tableIPLimit}
for _, b := range tab.buckets { for _, b := range tab.buckets {
for _, n := range b.entries { for _, n := range b.entries {
tabset.Add(n.IP()) tabset.AddAddr(n.IPAddr())
} }
} }
if tabset.String() != tab.ips.String() { if tabset.String() != tab.ips.String() {
@ -268,7 +268,7 @@ func (*closeTest) Generate(rand *rand.Rand, size int) reflect.Value {
} }
for _, id := range gen([]enode.ID{}, rand).([]enode.ID) { for _, id := range gen([]enode.ID{}, rand).([]enode.ID) {
r := new(enr.Record) r := new(enr.Record)
r.Set(enr.IP(genIP(rand))) r.Set(enr.IPv4Addr(netutil.RandomAddr(rand, true)))
n := enode.SignNull(r, id) n := enode.SignNull(r, id)
t.All = append(t.All, n) t.All = append(t.All, n)
} }
@ -385,11 +385,11 @@ func checkBucketContent(t *testing.T, tab *Table, nodes []*enode.Node) {
} }
t.Log("wrong bucket content. have nodes:") t.Log("wrong bucket content. have nodes:")
for _, n := range b.entries { for _, n := range b.entries {
t.Logf(" %v (seq=%v, ip=%v)", n.ID(), n.Seq(), n.IP()) t.Logf(" %v (seq=%v, ip=%v)", n.ID(), n.Seq(), n.IPAddr())
} }
t.Log("want nodes:") t.Log("want nodes:")
for _, n := range nodes { for _, n := range nodes {
t.Logf(" %v (seq=%v, ip=%v)", n.ID(), n.Seq(), n.IP()) t.Logf(" %v (seq=%v, ip=%v)", n.ID(), n.Seq(), n.IPAddr())
} }
t.FailNow() t.FailNow()
@ -483,12 +483,6 @@ func gen(typ interface{}, rand *rand.Rand) interface{} {
return v.Interface() return v.Interface()
} }
func genIP(rand *rand.Rand) net.IP {
ip := make(net.IP, 4)
rand.Read(ip)
return ip
}
func quickcfg() *quick.Config { func quickcfg() *quick.Config {
return &quick.Config{ return &quick.Config{
MaxCount: 5000, MaxCount: 5000,

View File

@ -100,8 +100,9 @@ func idAtDistance(a enode.ID, n int) (b enode.ID) {
return b return b
} }
// intIP returns a LAN IP address based on i.
func intIP(i int) net.IP { func intIP(i int) net.IP {
return net.IP{byte(i), 0, 2, byte(i)} return net.IP{10, 0, byte(i >> 8), byte(i & 0xFF)}
} }
// fillBucket inserts nodes into the given bucket until it is full. // fillBucket inserts nodes into the given bucket until it is full.
@ -254,7 +255,7 @@ NotEqual:
} }
func nodeEqual(n1 *enode.Node, n2 *enode.Node) bool { func nodeEqual(n1 *enode.Node, n2 *enode.Node) bool {
return n1.ID() == n2.ID() && n1.IP().Equal(n2.IP()) return n1.ID() == n2.ID() && n1.IPAddr() == n2.IPAddr()
} }
func sortByID[N nodeType](nodes []N) { func sortByID[N nodeType](nodes []N) {

View File

@ -25,7 +25,6 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"net"
"net/netip" "net/netip"
"sync" "sync"
"time" "time"
@ -250,8 +249,7 @@ func (t *UDPv4) sendPing(toid enode.ID, toaddr netip.AddrPort, callback func())
return matched, matched return matched, matched
}) })
// Send the packet. // Send the packet.
toUDPAddr := &net.UDPAddr{IP: toaddr.Addr().AsSlice()} t.localNode.UDPContact(toaddr)
t.localNode.UDPContact(toUDPAddr)
t.write(toaddr, toid, req.Name(), packet) t.write(toaddr, toid, req.Name(), packet)
return rm return rm
} }
@ -383,7 +381,7 @@ func (t *UDPv4) RequestENR(n *enode.Node) (*enode.Node, error) {
if respN.Seq() < n.Seq() { if respN.Seq() < n.Seq() {
return n, nil // response record is older return n, nil // response record is older
} }
if err := netutil.CheckRelayIP(addr.Addr().AsSlice(), respN.IP()); err != nil { if err := netutil.CheckRelayAddr(addr.Addr(), respN.IPAddr()); err != nil {
return nil, fmt.Errorf("invalid IP in response record: %v", err) return nil, fmt.Errorf("invalid IP in response record: %v", err)
} }
return respN, nil return respN, nil
@ -578,15 +576,14 @@ func (t *UDPv4) handlePacket(from netip.AddrPort, buf []byte) error {
// checkBond checks if the given node has a recent enough endpoint proof. // checkBond checks if the given node has a recent enough endpoint proof.
func (t *UDPv4) checkBond(id enode.ID, ip netip.AddrPort) bool { func (t *UDPv4) checkBond(id enode.ID, ip netip.AddrPort) bool {
return time.Since(t.db.LastPongReceived(id, ip.Addr().AsSlice())) < bondExpiration return time.Since(t.db.LastPongReceived(id, ip.Addr())) < bondExpiration
} }
// ensureBond solicits a ping from a node if we haven't seen a ping from it for a while. // ensureBond solicits a ping from a node if we haven't seen a ping from it for a while.
// This ensures there is a valid endpoint proof on the remote end. // This ensures there is a valid endpoint proof on the remote end.
func (t *UDPv4) ensureBond(toid enode.ID, toaddr netip.AddrPort) { func (t *UDPv4) ensureBond(toid enode.ID, toaddr netip.AddrPort) {
ip := toaddr.Addr().AsSlice() tooOld := time.Since(t.db.LastPingReceived(toid, toaddr.Addr())) > bondExpiration
tooOld := time.Since(t.db.LastPingReceived(toid, ip)) > bondExpiration if tooOld || t.db.FindFails(toid, toaddr.Addr()) > maxFindnodeFailures {
if tooOld || t.db.FindFails(toid, ip) > maxFindnodeFailures {
rm := t.sendPing(toid, toaddr, nil) rm := t.sendPing(toid, toaddr, nil)
<-rm.errc <-rm.errc
// Wait for them to ping back and process our pong. // Wait for them to ping back and process our pong.
@ -687,7 +684,7 @@ func (t *UDPv4) handlePing(h *packetHandlerV4, from netip.AddrPort, fromID enode
// Ping back if our last pong on file is too far in the past. // Ping back if our last pong on file is too far in the past.
fromIP := from.Addr().AsSlice() fromIP := from.Addr().AsSlice()
n := enode.NewV4(h.senderKey, fromIP, int(req.From.TCP), int(from.Port())) n := enode.NewV4(h.senderKey, fromIP, int(req.From.TCP), int(from.Port()))
if time.Since(t.db.LastPongReceived(n.ID(), fromIP)) > bondExpiration { if time.Since(t.db.LastPongReceived(n.ID(), from.Addr())) > bondExpiration {
t.sendPing(fromID, from, func() { t.sendPing(fromID, from, func() {
t.tab.addInboundNode(n) t.tab.addInboundNode(n)
}) })
@ -696,10 +693,9 @@ func (t *UDPv4) handlePing(h *packetHandlerV4, from netip.AddrPort, fromID enode
} }
// Update node database and endpoint predictor. // Update node database and endpoint predictor.
t.db.UpdateLastPingReceived(n.ID(), fromIP, time.Now()) t.db.UpdateLastPingReceived(n.ID(), from.Addr(), time.Now())
fromUDPAddr := &net.UDPAddr{IP: fromIP, Port: int(from.Port())} toaddr := netip.AddrPortFrom(netutil.IPToAddr(req.To.IP), req.To.UDP)
toUDPAddr := &net.UDPAddr{IP: req.To.IP, Port: int(req.To.UDP)} t.localNode.UDPEndpointStatement(from, toaddr)
t.localNode.UDPEndpointStatement(fromUDPAddr, toUDPAddr)
} }
// PONG/v4 // PONG/v4
@ -713,11 +709,9 @@ func (t *UDPv4) verifyPong(h *packetHandlerV4, from netip.AddrPort, fromID enode
if !t.handleReply(fromID, from.Addr(), req) { if !t.handleReply(fromID, from.Addr(), req) {
return errUnsolicitedReply return errUnsolicitedReply
} }
fromIP := from.Addr().AsSlice() toaddr := netip.AddrPortFrom(netutil.IPToAddr(req.To.IP), req.To.UDP)
fromUDPAddr := &net.UDPAddr{IP: fromIP, Port: int(from.Port())} t.localNode.UDPEndpointStatement(from, toaddr)
toUDPAddr := &net.UDPAddr{IP: req.To.IP, Port: int(req.To.UDP)} t.db.UpdateLastPongReceived(fromID, from.Addr(), time.Now())
t.localNode.UDPEndpointStatement(fromUDPAddr, toUDPAddr)
t.db.UpdateLastPongReceived(fromID, fromIP, time.Now())
return nil return nil
} }
@ -753,8 +747,7 @@ func (t *UDPv4) handleFindnode(h *packetHandlerV4, from netip.AddrPort, fromID e
p := v4wire.Neighbors{Expiration: uint64(time.Now().Add(expiration).Unix())} p := v4wire.Neighbors{Expiration: uint64(time.Now().Add(expiration).Unix())}
var sent bool var sent bool
for _, n := range closest { for _, n := range closest {
fromIP := from.Addr().AsSlice() if netutil.CheckRelayAddr(from.Addr(), n.IPAddr()) == nil {
if netutil.CheckRelayIP(fromIP, n.IP()) == nil {
p.Nodes = append(p.Nodes, nodeToRPC(n)) p.Nodes = append(p.Nodes, nodeToRPC(n))
} }
if len(p.Nodes) == v4wire.MaxNeighbors { if len(p.Nodes) == v4wire.MaxNeighbors {

View File

@ -274,7 +274,7 @@ func TestUDPv4_findnode(t *testing.T) {
// ensure there's a bond with the test node, // ensure there's a bond with the test node,
// findnode won't be accepted otherwise. // findnode won't be accepted otherwise.
remoteID := v4wire.EncodePubkey(&test.remotekey.PublicKey).ID() remoteID := v4wire.EncodePubkey(&test.remotekey.PublicKey).ID()
test.table.db.UpdateLastPongReceived(remoteID, test.remoteaddr.Addr().AsSlice(), time.Now()) test.table.db.UpdateLastPongReceived(remoteID, test.remoteaddr.Addr(), time.Now())
// check that closest neighbors are returned. // check that closest neighbors are returned.
expected := test.table.findnodeByID(testTarget.ID(), bucketSize, true) expected := test.table.findnodeByID(testTarget.ID(), bucketSize, true)
@ -309,7 +309,7 @@ func TestUDPv4_findnodeMultiReply(t *testing.T) {
defer test.close() defer test.close()
rid := enode.PubkeyToIDV4(&test.remotekey.PublicKey) rid := enode.PubkeyToIDV4(&test.remotekey.PublicKey)
test.table.db.UpdateLastPingReceived(rid, test.remoteaddr.Addr().AsSlice(), time.Now()) test.table.db.UpdateLastPingReceived(rid, test.remoteaddr.Addr(), time.Now())
// queue a pending findnode request // queue a pending findnode request
resultc, errc := make(chan []*enode.Node, 1), make(chan error, 1) resultc, errc := make(chan []*enode.Node, 1), make(chan error, 1)
@ -437,8 +437,8 @@ func TestUDPv4_successfulPing(t *testing.T) {
if n.ID() != rid { if n.ID() != rid {
t.Errorf("node has wrong ID: got %v, want %v", n.ID(), rid) t.Errorf("node has wrong ID: got %v, want %v", n.ID(), rid)
} }
if !n.IP().Equal(test.remoteaddr.Addr().AsSlice()) { if n.IPAddr() != test.remoteaddr.Addr() {
t.Errorf("node has wrong IP: got %v, want: %v", n.IP(), test.remoteaddr.Addr()) t.Errorf("node has wrong IP: got %v, want: %v", n.IPAddr(), test.remoteaddr.Addr())
} }
if n.UDP() != int(test.remoteaddr.Port()) { if n.UDP() != int(test.remoteaddr.Port()) {
t.Errorf("node has wrong UDP port: got %v, want: %v", n.UDP(), test.remoteaddr.Port()) t.Errorf("node has wrong UDP port: got %v, want: %v", n.UDP(), test.remoteaddr.Port())

View File

@ -428,10 +428,10 @@ func (t *UDPv5) verifyResponseNode(c *callV5, r *enr.Record, distances []uint, s
if err != nil { if err != nil {
return nil, err return nil, err
} }
if err := netutil.CheckRelayIP(c.addr.Addr().AsSlice(), node.IP()); err != nil { if err := netutil.CheckRelayAddr(c.addr.Addr(), node.IPAddr()); err != nil {
return nil, err return nil, err
} }
if t.netrestrict != nil && !t.netrestrict.Contains(node.IP()) { if t.netrestrict != nil && !t.netrestrict.ContainsAddr(node.IPAddr()) {
return nil, errors.New("not contained in netrestrict list") return nil, errors.New("not contained in netrestrict list")
} }
if node.UDP() <= 1024 { if node.UDP() <= 1024 {
@ -754,9 +754,8 @@ func (t *UDPv5) handle(p v5wire.Packet, fromID enode.ID, fromAddr netip.AddrPort
t.handlePing(p, fromID, fromAddr) t.handlePing(p, fromID, fromAddr)
case *v5wire.Pong: case *v5wire.Pong:
if t.handleCallResponse(fromID, fromAddr, p) { if t.handleCallResponse(fromID, fromAddr, p) {
fromUDPAddr := &net.UDPAddr{IP: fromAddr.Addr().AsSlice(), Port: int(fromAddr.Port())} toAddr := netip.AddrPortFrom(netutil.IPToAddr(p.ToIP), p.ToPort)
toUDPAddr := &net.UDPAddr{IP: p.ToIP, Port: int(p.ToPort)} t.localNode.UDPEndpointStatement(fromAddr, toAddr)
t.localNode.UDPEndpointStatement(fromUDPAddr, toUDPAddr)
} }
case *v5wire.Findnode: case *v5wire.Findnode:
t.handleFindnode(p, fromID, fromAddr) t.handleFindnode(p, fromID, fromAddr)
@ -848,7 +847,6 @@ func (t *UDPv5) handleFindnode(p *v5wire.Findnode, fromID enode.ID, fromAddr net
// collectTableNodes creates a FINDNODE result set for the given distances. // collectTableNodes creates a FINDNODE result set for the given distances.
func (t *UDPv5) collectTableNodes(rip netip.Addr, distances []uint, limit int) []*enode.Node { func (t *UDPv5) collectTableNodes(rip netip.Addr, distances []uint, limit int) []*enode.Node {
ripSlice := rip.AsSlice()
var bn []*enode.Node var bn []*enode.Node
var nodes []*enode.Node var nodes []*enode.Node
var processed = make(map[uint]struct{}) var processed = make(map[uint]struct{})
@ -863,7 +861,7 @@ func (t *UDPv5) collectTableNodes(rip netip.Addr, distances []uint, limit int) [
for _, n := range t.tab.appendLiveNodes(dist, bn[:0]) { for _, n := range t.tab.appendLiveNodes(dist, bn[:0]) {
// Apply some pre-checks to avoid sending invalid nodes. // Apply some pre-checks to avoid sending invalid nodes.
// Note liveness is checked by appendLiveNodes. // Note liveness is checked by appendLiveNodes.
if netutil.CheckRelayIP(ripSlice, n.IP()) != nil { if netutil.CheckRelayAddr(rip, n.IPAddr()) != nil {
continue continue
} }
nodes = append(nodes, n) nodes = append(nodes, n)

View File

@ -606,7 +606,7 @@ func (n *handshakeTestNode) n() *enode.Node {
} }
func (n *handshakeTestNode) addr() string { func (n *handshakeTestNode) addr() string {
return n.ln.Node().IP().String() return n.ln.Node().IPAddr().String()
} }
func (n *handshakeTestNode) id() enode.ID { func (n *handshakeTestNode) id() enode.ID {

View File

@ -20,8 +20,8 @@ import (
"crypto/ecdsa" "crypto/ecdsa"
"fmt" "fmt"
"net" "net"
"net/netip"
"reflect" "reflect"
"strconv"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
@ -175,8 +175,8 @@ func (ln *LocalNode) delete(e enr.Entry) {
} }
} }
func (ln *LocalNode) endpointForIP(ip net.IP) *lnEndpoint { func (ln *LocalNode) endpointForIP(ip netip.Addr) *lnEndpoint {
if ip.To4() != nil { if ip.Is4() {
return &ln.endpoint4 return &ln.endpoint4
} }
return &ln.endpoint6 return &ln.endpoint6
@ -188,7 +188,7 @@ func (ln *LocalNode) SetStaticIP(ip net.IP) {
ln.mu.Lock() ln.mu.Lock()
defer ln.mu.Unlock() defer ln.mu.Unlock()
ln.endpointForIP(ip).staticIP = ip ln.endpointForIP(netutil.IPToAddr(ip)).staticIP = ip
ln.updateEndpoints() ln.updateEndpoints()
} }
@ -198,7 +198,7 @@ func (ln *LocalNode) SetFallbackIP(ip net.IP) {
ln.mu.Lock() ln.mu.Lock()
defer ln.mu.Unlock() defer ln.mu.Unlock()
ln.endpointForIP(ip).fallbackIP = ip ln.endpointForIP(netutil.IPToAddr(ip)).fallbackIP = ip
ln.updateEndpoints() ln.updateEndpoints()
} }
@ -215,21 +215,21 @@ func (ln *LocalNode) SetFallbackUDP(port int) {
// UDPEndpointStatement should be called whenever a statement about the local node's // UDPEndpointStatement should be called whenever a statement about the local node's
// UDP endpoint is received. It feeds the local endpoint predictor. // UDP endpoint is received. It feeds the local endpoint predictor.
func (ln *LocalNode) UDPEndpointStatement(fromaddr, endpoint *net.UDPAddr) { func (ln *LocalNode) UDPEndpointStatement(fromaddr, endpoint netip.AddrPort) {
ln.mu.Lock() ln.mu.Lock()
defer ln.mu.Unlock() defer ln.mu.Unlock()
ln.endpointForIP(endpoint.IP).track.AddStatement(fromaddr.String(), endpoint.String()) ln.endpointForIP(endpoint.Addr()).track.AddStatement(fromaddr.Addr(), endpoint)
ln.updateEndpoints() ln.updateEndpoints()
} }
// UDPContact should be called whenever the local node has announced itself to another node // UDPContact should be called whenever the local node has announced itself to another node
// via UDP. It feeds the local endpoint predictor. // via UDP. It feeds the local endpoint predictor.
func (ln *LocalNode) UDPContact(toaddr *net.UDPAddr) { func (ln *LocalNode) UDPContact(toaddr netip.AddrPort) {
ln.mu.Lock() ln.mu.Lock()
defer ln.mu.Unlock() defer ln.mu.Unlock()
ln.endpointForIP(toaddr.IP).track.AddContact(toaddr.String()) ln.endpointForIP(toaddr.Addr()).track.AddContact(toaddr.Addr())
ln.updateEndpoints() ln.updateEndpoints()
} }
@ -268,29 +268,13 @@ func (e *lnEndpoint) get() (newIP net.IP, newPort uint16) {
} }
if e.staticIP != nil { if e.staticIP != nil {
newIP = e.staticIP newIP = e.staticIP
} else if ip, port := predictAddr(e.track); ip != nil { } else if ap := e.track.PredictEndpoint(); ap.IsValid() {
newIP = ip newIP = ap.Addr().AsSlice()
newPort = port newPort = ap.Port()
} }
return newIP, newPort return newIP, newPort
} }
// predictAddr wraps IPTracker.PredictEndpoint, converting from its string-based
// endpoint representation to IP and port types.
func predictAddr(t *netutil.IPTracker) (net.IP, uint16) {
ep := t.PredictEndpoint()
if ep == "" {
return nil, 0
}
ipString, portString, _ := net.SplitHostPort(ep)
ip := net.ParseIP(ipString)
port, err := strconv.ParseUint(portString, 10, 16)
if err != nil {
return nil, 0
}
return ip, uint16(port)
}
func (ln *LocalNode) invalidate() { func (ln *LocalNode) invalidate() {
ln.cur.Store((*Node)(nil)) ln.cur.Store((*Node)(nil))
} }
@ -314,7 +298,7 @@ func (ln *LocalNode) sign() {
panic(fmt.Errorf("enode: can't verify local record: %v", err)) panic(fmt.Errorf("enode: can't verify local record: %v", err))
} }
ln.cur.Store(n) ln.cur.Store(n)
log.Info("New local node record", "seq", ln.seq, "id", n.ID(), "ip", n.IP(), "udp", n.UDP(), "tcp", n.TCP()) log.Info("New local node record", "seq", ln.seq, "id", n.ID(), "ip", n.IPAddr(), "udp", n.UDP(), "tcp", n.TCP())
} }
func (ln *LocalNode) bumpSeq() { func (ln *LocalNode) bumpSeq() {

View File

@ -17,12 +17,14 @@
package enode package enode
import ( import (
"crypto/rand" "math/rand"
"net" "net"
"net/netip"
"testing" "testing"
"github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/p2p/enr" "github.com/ethereum/go-ethereum/p2p/enr"
"github.com/ethereum/go-ethereum/p2p/netutil"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
@ -88,6 +90,7 @@ func TestLocalNodeSeqPersist(t *testing.T) {
// This test checks behavior of the endpoint predictor. // This test checks behavior of the endpoint predictor.
func TestLocalNodeEndpoint(t *testing.T) { func TestLocalNodeEndpoint(t *testing.T) {
var ( var (
rng = rand.New(rand.NewSource(4))
fallback = &net.UDPAddr{IP: net.IP{127, 0, 0, 1}, Port: 80} fallback = &net.UDPAddr{IP: net.IP{127, 0, 0, 1}, Port: 80}
predicted = &net.UDPAddr{IP: net.IP{127, 0, 1, 2}, Port: 81} predicted = &net.UDPAddr{IP: net.IP{127, 0, 1, 2}, Port: 81}
staticIP = net.IP{127, 0, 1, 2} staticIP = net.IP{127, 0, 1, 2}
@ -96,6 +99,7 @@ func TestLocalNodeEndpoint(t *testing.T) {
defer db.Close() defer db.Close()
// Nothing is set initially. // Nothing is set initially.
assert.Equal(t, netip.Addr{}, ln.Node().IPAddr())
assert.Equal(t, net.IP(nil), ln.Node().IP()) assert.Equal(t, net.IP(nil), ln.Node().IP())
assert.Equal(t, 0, ln.Node().UDP()) assert.Equal(t, 0, ln.Node().UDP())
initialSeq := ln.Node().Seq() initialSeq := ln.Node().Seq()
@ -103,26 +107,30 @@ func TestLocalNodeEndpoint(t *testing.T) {
// Set up fallback address. // Set up fallback address.
ln.SetFallbackIP(fallback.IP) ln.SetFallbackIP(fallback.IP)
ln.SetFallbackUDP(fallback.Port) ln.SetFallbackUDP(fallback.Port)
assert.Equal(t, netutil.IPToAddr(fallback.IP), ln.Node().IPAddr())
assert.Equal(t, fallback.IP, ln.Node().IP()) assert.Equal(t, fallback.IP, ln.Node().IP())
assert.Equal(t, fallback.Port, ln.Node().UDP()) assert.Equal(t, fallback.Port, ln.Node().UDP())
assert.Equal(t, initialSeq+1, ln.Node().Seq()) assert.Equal(t, initialSeq+1, ln.Node().Seq())
// Add endpoint statements from random hosts. // Add endpoint statements from random hosts.
for i := 0; i < iptrackMinStatements; i++ { for i := 0; i < iptrackMinStatements; i++ {
assert.Equal(t, netutil.IPToAddr(fallback.IP), ln.Node().IPAddr())
assert.Equal(t, fallback.IP, ln.Node().IP()) assert.Equal(t, fallback.IP, ln.Node().IP())
assert.Equal(t, fallback.Port, ln.Node().UDP()) assert.Equal(t, fallback.Port, ln.Node().UDP())
assert.Equal(t, initialSeq+1, ln.Node().Seq()) assert.Equal(t, initialSeq+1, ln.Node().Seq())
from := &net.UDPAddr{IP: make(net.IP, 4), Port: 90} from := netip.AddrPortFrom(netutil.RandomAddr(rng, true), 9000)
rand.Read(from.IP) endpoint := netip.AddrPortFrom(netutil.IPToAddr(predicted.IP), uint16(predicted.Port))
ln.UDPEndpointStatement(from, predicted) ln.UDPEndpointStatement(from, endpoint)
} }
assert.Equal(t, netutil.IPToAddr(predicted.IP), ln.Node().IPAddr())
assert.Equal(t, predicted.IP, ln.Node().IP()) assert.Equal(t, predicted.IP, ln.Node().IP())
assert.Equal(t, predicted.Port, ln.Node().UDP()) assert.Equal(t, predicted.Port, ln.Node().UDP())
assert.Equal(t, initialSeq+2, ln.Node().Seq()) assert.Equal(t, initialSeq+2, ln.Node().Seq())
// Static IP overrides prediction. // Static IP overrides prediction.
ln.SetStaticIP(staticIP) ln.SetStaticIP(staticIP)
assert.Equal(t, netutil.IPToAddr(staticIP), ln.Node().IPAddr())
assert.Equal(t, staticIP, ln.Node().IP()) assert.Equal(t, staticIP, ln.Node().IP())
assert.Equal(t, fallback.Port, ln.Node().UDP()) assert.Equal(t, fallback.Port, ln.Node().UDP())
assert.Equal(t, initialSeq+3, ln.Node().Seq()) assert.Equal(t, initialSeq+3, ln.Node().Seq())

View File

@ -21,7 +21,7 @@ import (
"crypto/rand" "crypto/rand"
"encoding/binary" "encoding/binary"
"fmt" "fmt"
"net" "net/netip"
"os" "os"
"sync" "sync"
"time" "time"
@ -66,7 +66,7 @@ var (
errInvalidIP = errors.New("invalid IP") errInvalidIP = errors.New("invalid IP")
) )
var zeroIP = make(net.IP, 16) var zeroIP = netip.IPv6Unspecified()
// DB is the node database, storing previously seen nodes and any collected metadata about // DB is the node database, storing previously seen nodes and any collected metadata about
// them for QoS purposes. // them for QoS purposes.
@ -151,39 +151,37 @@ func splitNodeKey(key []byte) (id ID, rest []byte) {
} }
// nodeItemKey returns the database key for a node metadata field. // nodeItemKey returns the database key for a node metadata field.
func nodeItemKey(id ID, ip net.IP, field string) []byte { func nodeItemKey(id ID, ip netip.Addr, field string) []byte {
ip16 := ip.To16() if !ip.IsValid() {
if ip16 == nil { panic("invalid IP")
panic(fmt.Errorf("invalid IP (length %d)", len(ip)))
} }
return bytes.Join([][]byte{nodeKey(id), ip16, []byte(field)}, []byte{':'}) ip16 := ip.As16()
return bytes.Join([][]byte{nodeKey(id), ip16[:], []byte(field)}, []byte{':'})
} }
// splitNodeItemKey returns the components of a key created by nodeItemKey. // splitNodeItemKey returns the components of a key created by nodeItemKey.
func splitNodeItemKey(key []byte) (id ID, ip net.IP, field string) { func splitNodeItemKey(key []byte) (id ID, ip netip.Addr, field string) {
id, key = splitNodeKey(key) id, key = splitNodeKey(key)
// Skip discover root. // Skip discover root.
if string(key) == dbDiscoverRoot { if string(key) == dbDiscoverRoot {
return id, nil, "" return id, netip.Addr{}, ""
} }
key = key[len(dbDiscoverRoot)+1:] key = key[len(dbDiscoverRoot)+1:]
// Split out the IP. // Split out the IP.
ip = key[:16] ip, _ = netip.AddrFromSlice(key[:16])
if ip4 := ip.To4(); ip4 != nil {
ip = ip4
}
key = key[16+1:] key = key[16+1:]
// Field is the remainder of key. // Field is the remainder of key.
field = string(key) field = string(key)
return id, ip, field return id, ip, field
} }
func v5Key(id ID, ip net.IP, field string) []byte { func v5Key(id ID, ip netip.Addr, field string) []byte {
ip16 := ip.As16()
return bytes.Join([][]byte{ return bytes.Join([][]byte{
[]byte(dbNodePrefix), []byte(dbNodePrefix),
id[:], id[:],
[]byte(dbDiscv5Root), []byte(dbDiscv5Root),
ip.To16(), ip16[:],
[]byte(field), []byte(field),
}, []byte{':'}) }, []byte{':'})
} }
@ -364,24 +362,24 @@ func (db *DB) expireNodes() {
// LastPingReceived retrieves the time of the last ping packet received from // LastPingReceived retrieves the time of the last ping packet received from
// a remote node. // a remote node.
func (db *DB) LastPingReceived(id ID, ip net.IP) time.Time { func (db *DB) LastPingReceived(id ID, ip netip.Addr) time.Time {
if ip = ip.To16(); ip == nil { if !ip.IsValid() {
return time.Time{} return time.Time{}
} }
return time.Unix(db.fetchInt64(nodeItemKey(id, ip, dbNodePing)), 0) return time.Unix(db.fetchInt64(nodeItemKey(id, ip, dbNodePing)), 0)
} }
// UpdateLastPingReceived updates the last time we tried contacting a remote node. // UpdateLastPingReceived updates the last time we tried contacting a remote node.
func (db *DB) UpdateLastPingReceived(id ID, ip net.IP, instance time.Time) error { func (db *DB) UpdateLastPingReceived(id ID, ip netip.Addr, instance time.Time) error {
if ip = ip.To16(); ip == nil { if !ip.IsValid() {
return errInvalidIP return errInvalidIP
} }
return db.storeInt64(nodeItemKey(id, ip, dbNodePing), instance.Unix()) return db.storeInt64(nodeItemKey(id, ip, dbNodePing), instance.Unix())
} }
// LastPongReceived retrieves the time of the last successful pong from remote node. // LastPongReceived retrieves the time of the last successful pong from remote node.
func (db *DB) LastPongReceived(id ID, ip net.IP) time.Time { func (db *DB) LastPongReceived(id ID, ip netip.Addr) time.Time {
if ip = ip.To16(); ip == nil { if !ip.IsValid() {
return time.Time{} return time.Time{}
} }
// Launch expirer // Launch expirer
@ -390,40 +388,40 @@ func (db *DB) LastPongReceived(id ID, ip net.IP) time.Time {
} }
// UpdateLastPongReceived updates the last pong time of a node. // UpdateLastPongReceived updates the last pong time of a node.
func (db *DB) UpdateLastPongReceived(id ID, ip net.IP, instance time.Time) error { func (db *DB) UpdateLastPongReceived(id ID, ip netip.Addr, instance time.Time) error {
if ip = ip.To16(); ip == nil { if !ip.IsValid() {
return errInvalidIP return errInvalidIP
} }
return db.storeInt64(nodeItemKey(id, ip, dbNodePong), instance.Unix()) return db.storeInt64(nodeItemKey(id, ip, dbNodePong), instance.Unix())
} }
// FindFails retrieves the number of findnode failures since bonding. // FindFails retrieves the number of findnode failures since bonding.
func (db *DB) FindFails(id ID, ip net.IP) int { func (db *DB) FindFails(id ID, ip netip.Addr) int {
if ip = ip.To16(); ip == nil { if !ip.IsValid() {
return 0 return 0
} }
return int(db.fetchInt64(nodeItemKey(id, ip, dbNodeFindFails))) return int(db.fetchInt64(nodeItemKey(id, ip, dbNodeFindFails)))
} }
// UpdateFindFails updates the number of findnode failures since bonding. // UpdateFindFails updates the number of findnode failures since bonding.
func (db *DB) UpdateFindFails(id ID, ip net.IP, fails int) error { func (db *DB) UpdateFindFails(id ID, ip netip.Addr, fails int) error {
if ip = ip.To16(); ip == nil { if !ip.IsValid() {
return errInvalidIP return errInvalidIP
} }
return db.storeInt64(nodeItemKey(id, ip, dbNodeFindFails), int64(fails)) return db.storeInt64(nodeItemKey(id, ip, dbNodeFindFails), int64(fails))
} }
// FindFailsV5 retrieves the discv5 findnode failure counter. // FindFailsV5 retrieves the discv5 findnode failure counter.
func (db *DB) FindFailsV5(id ID, ip net.IP) int { func (db *DB) FindFailsV5(id ID, ip netip.Addr) int {
if ip = ip.To16(); ip == nil { if !ip.IsValid() {
return 0 return 0
} }
return int(db.fetchInt64(v5Key(id, ip, dbNodeFindFails))) return int(db.fetchInt64(v5Key(id, ip, dbNodeFindFails)))
} }
// UpdateFindFailsV5 stores the discv5 findnode failure counter. // UpdateFindFailsV5 stores the discv5 findnode failure counter.
func (db *DB) UpdateFindFailsV5(id ID, ip net.IP, fails int) error { func (db *DB) UpdateFindFailsV5(id ID, ip netip.Addr, fails int) error {
if ip = ip.To16(); ip == nil { if !ip.IsValid() {
return errInvalidIP return errInvalidIP
} }
return db.storeInt64(v5Key(id, ip, dbNodeFindFails), int64(fails)) return db.storeInt64(v5Key(id, ip, dbNodeFindFails), int64(fails))
@ -470,7 +468,7 @@ seek:
id[0] = 0 id[0] = 0
continue seek // iterator exhausted continue seek // iterator exhausted
} }
if now.Sub(db.LastPongReceived(n.ID(), n.IP())) > maxAge { if now.Sub(db.LastPongReceived(n.ID(), n.IPAddr())) > maxAge {
continue seek continue seek
} }
for i := range nodes { for i := range nodes {

View File

@ -20,6 +20,7 @@ import (
"bytes" "bytes"
"fmt" "fmt"
"net" "net"
"net/netip"
"path/filepath" "path/filepath"
"reflect" "reflect"
"testing" "testing"
@ -48,8 +49,10 @@ func TestDBNodeKey(t *testing.T) {
} }
func TestDBNodeItemKey(t *testing.T) { func TestDBNodeItemKey(t *testing.T) {
wantIP := net.IP{127, 0, 0, 3} wantIP := netip.MustParseAddr("127.0.0.3")
wantIP4in6 := netip.AddrFrom16(wantIP.As16())
wantField := "foobar" wantField := "foobar"
enc := nodeItemKey(keytestID, wantIP, wantField) enc := nodeItemKey(keytestID, wantIP, wantField)
want := []byte{ want := []byte{
'n', ':', 'n', ':',
@ -69,7 +72,7 @@ func TestDBNodeItemKey(t *testing.T) {
if id != keytestID { if id != keytestID {
t.Errorf("splitNodeItemKey returned wrong ID: %v", id) t.Errorf("splitNodeItemKey returned wrong ID: %v", id)
} }
if !ip.Equal(wantIP) { if ip != wantIP4in6 {
t.Errorf("splitNodeItemKey returned wrong IP: %v", ip) t.Errorf("splitNodeItemKey returned wrong IP: %v", ip)
} }
if field != wantField { if field != wantField {
@ -123,33 +126,33 @@ func TestDBFetchStore(t *testing.T) {
defer db.Close() defer db.Close()
// Check fetch/store operations on a node ping object // Check fetch/store operations on a node ping object
if stored := db.LastPingReceived(node.ID(), node.IP()); stored.Unix() != 0 { if stored := db.LastPingReceived(node.ID(), node.IPAddr()); stored.Unix() != 0 {
t.Errorf("ping: non-existing object: %v", stored) t.Errorf("ping: non-existing object: %v", stored)
} }
if err := db.UpdateLastPingReceived(node.ID(), node.IP(), inst); err != nil { if err := db.UpdateLastPingReceived(node.ID(), node.IPAddr(), inst); err != nil {
t.Errorf("ping: failed to update: %v", err) t.Errorf("ping: failed to update: %v", err)
} }
if stored := db.LastPingReceived(node.ID(), node.IP()); stored.Unix() != inst.Unix() { if stored := db.LastPingReceived(node.ID(), node.IPAddr()); stored.Unix() != inst.Unix() {
t.Errorf("ping: value mismatch: have %v, want %v", stored, inst) t.Errorf("ping: value mismatch: have %v, want %v", stored, inst)
} }
// Check fetch/store operations on a node pong object // Check fetch/store operations on a node pong object
if stored := db.LastPongReceived(node.ID(), node.IP()); stored.Unix() != 0 { if stored := db.LastPongReceived(node.ID(), node.IPAddr()); stored.Unix() != 0 {
t.Errorf("pong: non-existing object: %v", stored) t.Errorf("pong: non-existing object: %v", stored)
} }
if err := db.UpdateLastPongReceived(node.ID(), node.IP(), inst); err != nil { if err := db.UpdateLastPongReceived(node.ID(), node.IPAddr(), inst); err != nil {
t.Errorf("pong: failed to update: %v", err) t.Errorf("pong: failed to update: %v", err)
} }
if stored := db.LastPongReceived(node.ID(), node.IP()); stored.Unix() != inst.Unix() { if stored := db.LastPongReceived(node.ID(), node.IPAddr()); stored.Unix() != inst.Unix() {
t.Errorf("pong: value mismatch: have %v, want %v", stored, inst) t.Errorf("pong: value mismatch: have %v, want %v", stored, inst)
} }
// Check fetch/store operations on a node findnode-failure object // Check fetch/store operations on a node findnode-failure object
if stored := db.FindFails(node.ID(), node.IP()); stored != 0 { if stored := db.FindFails(node.ID(), node.IPAddr()); stored != 0 {
t.Errorf("find-node fails: non-existing object: %v", stored) t.Errorf("find-node fails: non-existing object: %v", stored)
} }
if err := db.UpdateFindFails(node.ID(), node.IP(), num); err != nil { if err := db.UpdateFindFails(node.ID(), node.IPAddr(), num); err != nil {
t.Errorf("find-node fails: failed to update: %v", err) t.Errorf("find-node fails: failed to update: %v", err)
} }
if stored := db.FindFails(node.ID(), node.IP()); stored != num { if stored := db.FindFails(node.ID(), node.IPAddr()); stored != num {
t.Errorf("find-node fails: value mismatch: have %v, want %v", stored, num) t.Errorf("find-node fails: value mismatch: have %v, want %v", stored, num)
} }
// Check fetch/store operations on an actual node object // Check fetch/store operations on an actual node object
@ -266,7 +269,7 @@ func testSeedQuery() error {
if err := db.UpdateNode(seed.node); err != nil { if err := db.UpdateNode(seed.node); err != nil {
return fmt.Errorf("node %d: failed to insert: %v", i, err) return fmt.Errorf("node %d: failed to insert: %v", i, err)
} }
if err := db.UpdateLastPongReceived(seed.node.ID(), seed.node.IP(), seed.pong); err != nil { if err := db.UpdateLastPongReceived(seed.node.ID(), seed.node.IPAddr(), seed.pong); err != nil {
return fmt.Errorf("node %d: failed to insert bondTime: %v", i, err) return fmt.Errorf("node %d: failed to insert bondTime: %v", i, err)
} }
} }
@ -427,7 +430,7 @@ func TestDBExpiration(t *testing.T) {
t.Fatalf("node %d: failed to insert: %v", i, err) t.Fatalf("node %d: failed to insert: %v", i, err)
} }
} }
if err := db.UpdateLastPongReceived(seed.node.ID(), seed.node.IP(), seed.pong); err != nil { if err := db.UpdateLastPongReceived(seed.node.ID(), seed.node.IPAddr(), seed.pong); err != nil {
t.Fatalf("node %d: failed to update bondTime: %v", i, err) t.Fatalf("node %d: failed to update bondTime: %v", i, err)
} }
} }
@ -438,13 +441,13 @@ func TestDBExpiration(t *testing.T) {
unixZeroTime := time.Unix(0, 0) unixZeroTime := time.Unix(0, 0)
for i, seed := range nodeDBExpirationNodes { for i, seed := range nodeDBExpirationNodes {
node := db.Node(seed.node.ID()) node := db.Node(seed.node.ID())
pong := db.LastPongReceived(seed.node.ID(), seed.node.IP()) pong := db.LastPongReceived(seed.node.ID(), seed.node.IPAddr())
if seed.exp { if seed.exp {
if seed.storeNode && node != nil { if seed.storeNode && node != nil {
t.Errorf("node %d (%s) shouldn't be present after expiration", i, seed.node.ID().TerminalString()) t.Errorf("node %d (%s) shouldn't be present after expiration", i, seed.node.ID().TerminalString())
} }
if !pong.Equal(unixZeroTime) { if !pong.Equal(unixZeroTime) {
t.Errorf("pong time %d (%s %v) shouldn't be present after expiration", i, seed.node.ID().TerminalString(), seed.node.IP()) t.Errorf("pong time %d (%s %v) shouldn't be present after expiration", i, seed.node.ID().TerminalString(), seed.node.IPAddr())
} }
} else { } else {
if seed.storeNode && node == nil { if seed.storeNode && node == nil {
@ -463,7 +466,7 @@ func TestDBExpireV5(t *testing.T) {
db, _ := OpenDB("") db, _ := OpenDB("")
defer db.Close() defer db.Close()
ip := net.IP{127, 0, 0, 1} ip := netip.MustParseAddr("127.0.0.1")
db.UpdateFindFailsV5(ID{}, ip, 4) db.UpdateFindFailsV5(ID{}, ip, 4)
db.expireNodes() db.expireNodes()
} }

View File

@ -16,18 +16,53 @@
package netutil package netutil
import "net" import (
"fmt"
"math/rand"
"net"
"net/netip"
)
// AddrIP gets the IP address contained in addr. It returns nil if no address is present. // AddrAddr gets the IP address contained in addr. The result will be invalid if the
func AddrIP(addr net.Addr) net.IP { // address type is unsupported.
func AddrAddr(addr net.Addr) netip.Addr {
switch a := addr.(type) { switch a := addr.(type) {
case *net.IPAddr: case *net.IPAddr:
return a.IP return IPToAddr(a.IP)
case *net.TCPAddr: case *net.TCPAddr:
return a.IP return IPToAddr(a.IP)
case *net.UDPAddr: case *net.UDPAddr:
return a.IP return IPToAddr(a.IP)
default: default:
return nil return netip.Addr{}
} }
} }
// IPToAddr converts net.IP to netip.Addr. Note that unlike netip.AddrFromSlice, this
// function will always ensure that the resulting Addr is IPv4 when the input is.
func IPToAddr(ip net.IP) netip.Addr {
if ip4 := ip.To4(); ip4 != nil {
addr, _ := netip.AddrFromSlice(ip4)
return addr
} else if ip6 := ip.To16(); ip6 != nil {
addr, _ := netip.AddrFromSlice(ip6)
return addr
}
return netip.Addr{}
}
// RandomAddr creates a random IP address.
func RandomAddr(rng *rand.Rand, ipv4 bool) netip.Addr {
var bytes []byte
if ipv4 || rng.Intn(2) == 0 {
bytes = make([]byte, 4)
} else {
bytes = make([]byte, 16)
}
rng.Read(bytes)
addr, ok := netip.AddrFromSlice(bytes)
if !ok {
panic(fmt.Errorf("BUG! invalid IP %v", bytes))
}
return addr
}

View File

@ -17,6 +17,7 @@
package netutil package netutil
import ( import (
"net/netip"
"time" "time"
"github.com/ethereum/go-ethereum/common/mclock" "github.com/ethereum/go-ethereum/common/mclock"
@ -29,14 +30,14 @@ type IPTracker struct {
contactWindow time.Duration contactWindow time.Duration
minStatements int minStatements int
clock mclock.Clock clock mclock.Clock
statements map[string]ipStatement statements map[netip.Addr]ipStatement
contact map[string]mclock.AbsTime contact map[netip.Addr]mclock.AbsTime
lastStatementGC mclock.AbsTime lastStatementGC mclock.AbsTime
lastContactGC mclock.AbsTime lastContactGC mclock.AbsTime
} }
type ipStatement struct { type ipStatement struct {
endpoint string endpoint netip.AddrPort
time mclock.AbsTime time mclock.AbsTime
} }
@ -51,9 +52,9 @@ func NewIPTracker(window, contactWindow time.Duration, minStatements int) *IPTra
return &IPTracker{ return &IPTracker{
window: window, window: window,
contactWindow: contactWindow, contactWindow: contactWindow,
statements: make(map[string]ipStatement), statements: make(map[netip.Addr]ipStatement),
minStatements: minStatements, minStatements: minStatements,
contact: make(map[string]mclock.AbsTime), contact: make(map[netip.Addr]mclock.AbsTime),
clock: mclock.System{}, clock: mclock.System{},
} }
} }
@ -74,12 +75,15 @@ func (it *IPTracker) PredictFullConeNAT() bool {
} }
// PredictEndpoint returns the current prediction of the external endpoint. // PredictEndpoint returns the current prediction of the external endpoint.
func (it *IPTracker) PredictEndpoint() string { func (it *IPTracker) PredictEndpoint() netip.AddrPort {
it.gcStatements(it.clock.Now()) it.gcStatements(it.clock.Now())
// The current strategy is simple: find the endpoint with most statements. // The current strategy is simple: find the endpoint with most statements.
counts := make(map[string]int, len(it.statements)) var (
maxcount, max := 0, "" counts = make(map[netip.AddrPort]int, len(it.statements))
maxcount int
max netip.AddrPort
)
for _, s := range it.statements { for _, s := range it.statements {
c := counts[s.endpoint] + 1 c := counts[s.endpoint] + 1
counts[s.endpoint] = c counts[s.endpoint] = c
@ -91,7 +95,7 @@ func (it *IPTracker) PredictEndpoint() string {
} }
// AddStatement records that a certain host thinks our external endpoint is the one given. // AddStatement records that a certain host thinks our external endpoint is the one given.
func (it *IPTracker) AddStatement(host, endpoint string) { func (it *IPTracker) AddStatement(host netip.Addr, endpoint netip.AddrPort) {
now := it.clock.Now() now := it.clock.Now()
it.statements[host] = ipStatement{endpoint, now} it.statements[host] = ipStatement{endpoint, now}
if time.Duration(now-it.lastStatementGC) >= it.window { if time.Duration(now-it.lastStatementGC) >= it.window {
@ -101,7 +105,7 @@ func (it *IPTracker) AddStatement(host, endpoint string) {
// AddContact records that a packet containing our endpoint information has been sent to a // AddContact records that a packet containing our endpoint information has been sent to a
// certain host. // certain host.
func (it *IPTracker) AddContact(host string) { func (it *IPTracker) AddContact(host netip.Addr) {
now := it.clock.Now() now := it.clock.Now()
it.contact[host] = now it.contact[host] = now
if time.Duration(now-it.lastContactGC) >= it.contactWindow { if time.Duration(now-it.lastContactGC) >= it.contactWindow {

View File

@ -19,6 +19,7 @@ package netutil
import ( import (
crand "crypto/rand" crand "crypto/rand"
"fmt" "fmt"
"net/netip"
"testing" "testing"
"time" "time"
@ -42,37 +43,37 @@ func TestIPTracker(t *testing.T) {
tests := map[string][]iptrackTestEvent{ tests := map[string][]iptrackTestEvent{
"minStatements": { "minStatements": {
{opPredict, 0, "", ""}, {opPredict, 0, "", ""},
{opStatement, 0, "127.0.0.1", "127.0.0.2"}, {opStatement, 0, "127.0.0.1:8000", "127.0.0.2"},
{opPredict, 1000, "", ""}, {opPredict, 1000, "", ""},
{opStatement, 1000, "127.0.0.1", "127.0.0.3"}, {opStatement, 1000, "127.0.0.1:8000", "127.0.0.3"},
{opPredict, 1000, "", ""}, {opPredict, 1000, "", ""},
{opStatement, 1000, "127.0.0.1", "127.0.0.4"}, {opStatement, 1000, "127.0.0.1:8000", "127.0.0.4"},
{opPredict, 1000, "127.0.0.1", ""}, {opPredict, 1000, "127.0.0.1:8000", ""},
}, },
"window": { "window": {
{opStatement, 0, "127.0.0.1", "127.0.0.2"}, {opStatement, 0, "127.0.0.1:8000", "127.0.0.2"},
{opStatement, 2000, "127.0.0.1", "127.0.0.3"}, {opStatement, 2000, "127.0.0.1:8000", "127.0.0.3"},
{opStatement, 3000, "127.0.0.1", "127.0.0.4"}, {opStatement, 3000, "127.0.0.1:8000", "127.0.0.4"},
{opPredict, 10000, "127.0.0.1", ""}, {opPredict, 10000, "127.0.0.1:8000", ""},
{opPredict, 10001, "", ""}, // first statement expired {opPredict, 10001, "", ""}, // first statement expired
{opStatement, 10100, "127.0.0.1", "127.0.0.2"}, {opStatement, 10100, "127.0.0.1:8000", "127.0.0.2"},
{opPredict, 10200, "127.0.0.1", ""}, {opPredict, 10200, "127.0.0.1:8000", ""},
}, },
"fullcone": { "fullcone": {
{opContact, 0, "", "127.0.0.2"}, {opContact, 0, "", "127.0.0.2"},
{opStatement, 10, "127.0.0.1", "127.0.0.2"}, {opStatement, 10, "127.0.0.1:8000", "127.0.0.2"},
{opContact, 2000, "", "127.0.0.3"}, {opContact, 2000, "", "127.0.0.3"},
{opStatement, 2010, "127.0.0.1", "127.0.0.3"}, {opStatement, 2010, "127.0.0.1:8000", "127.0.0.3"},
{opContact, 3000, "", "127.0.0.4"}, {opContact, 3000, "", "127.0.0.4"},
{opStatement, 3010, "127.0.0.1", "127.0.0.4"}, {opStatement, 3010, "127.0.0.1:8000", "127.0.0.4"},
{opCheckFullCone, 3500, "false", ""}, {opCheckFullCone, 3500, "false", ""},
}, },
"fullcone_2": { "fullcone_2": {
{opContact, 0, "", "127.0.0.2"}, {opContact, 0, "", "127.0.0.2"},
{opStatement, 10, "127.0.0.1", "127.0.0.2"}, {opStatement, 10, "127.0.0.1:8000", "127.0.0.2"},
{opContact, 2000, "", "127.0.0.3"}, {opContact, 2000, "", "127.0.0.3"},
{opStatement, 2010, "127.0.0.1", "127.0.0.3"}, {opStatement, 2010, "127.0.0.1:8000", "127.0.0.3"},
{opStatement, 3000, "127.0.0.1", "127.0.0.4"}, {opStatement, 3000, "127.0.0.1:8000", "127.0.0.4"},
{opContact, 3010, "", "127.0.0.4"}, {opContact, 3010, "", "127.0.0.4"},
{opCheckFullCone, 3500, "true", ""}, {opCheckFullCone, 3500, "true", ""},
}, },
@ -93,12 +94,19 @@ func runIPTrackerTest(t *testing.T, evs []iptrackTestEvent) {
clock.Run(evtime - time.Duration(clock.Now())) clock.Run(evtime - time.Duration(clock.Now()))
switch ev.op { switch ev.op {
case opStatement: case opStatement:
it.AddStatement(ev.from, ev.ip) it.AddStatement(netip.MustParseAddr(ev.from), netip.MustParseAddrPort(ev.ip))
case opContact: case opContact:
it.AddContact(ev.from) it.AddContact(netip.MustParseAddr(ev.from))
case opPredict: case opPredict:
if pred := it.PredictEndpoint(); pred != ev.ip { pred := it.PredictEndpoint()
t.Errorf("op %d: wrong prediction %q, want %q", i, pred, ev.ip) if ev.ip == "" {
if pred.IsValid() {
t.Errorf("op %d: wrong prediction %v, expected invalid", i, pred)
}
} else {
if pred != netip.MustParseAddrPort(ev.ip) {
t.Errorf("op %d: wrong prediction %v, want %q", i, pred, ev.ip)
}
} }
case opCheckFullCone: case opCheckFullCone:
pred := fmt.Sprintf("%t", it.PredictFullConeNAT()) pred := fmt.Sprintf("%t", it.PredictFullConeNAT())
@ -121,12 +129,11 @@ func TestIPTrackerForceGC(t *testing.T) {
it.clock = &clock it.clock = &clock
for i := 0; i < 5*max; i++ { for i := 0; i < 5*max; i++ {
e1 := make([]byte, 4) var e1, e2 [4]byte
e2 := make([]byte, 4) crand.Read(e1[:])
crand.Read(e1) crand.Read(e2[:])
crand.Read(e2) it.AddStatement(netip.AddrFrom4(e1), netip.AddrPortFrom(netip.AddrFrom4(e2), 9000))
it.AddStatement(string(e1), string(e2)) it.AddContact(netip.AddrFrom4(e1))
it.AddContact(string(e1))
clock.Run(rate) clock.Run(rate)
} }
if len(it.contact) > 2*max { if len(it.contact) > 2*max {

View File

@ -22,21 +22,19 @@ import (
"errors" "errors"
"fmt" "fmt"
"net" "net"
"sort" "net/netip"
"slices"
"strings" "strings"
"golang.org/x/exp/maps"
) )
var lan4, lan6, special4, special6 Netlist var special4, special6 Netlist
func init() { func init() {
// Lists from RFC 5735, RFC 5156, // Lists from RFC 5735, RFC 5156,
// https://www.iana.org/assignments/iana-ipv4-special-registry/ // https://www.iana.org/assignments/iana-ipv4-special-registry/
lan4.Add("0.0.0.0/8") // "This" network special4.Add("0.0.0.0/8") // "This" network.
lan4.Add("10.0.0.0/8") // Private Use
lan4.Add("172.16.0.0/12") // Private Use
lan4.Add("192.168.0.0/16") // Private Use
lan6.Add("fe80::/10") // Link-Local
lan6.Add("fc00::/7") // Unique-Local
special4.Add("192.0.0.0/29") // IPv4 Service Continuity special4.Add("192.0.0.0/29") // IPv4 Service Continuity
special4.Add("192.0.0.9/32") // PCP Anycast special4.Add("192.0.0.9/32") // PCP Anycast
special4.Add("192.0.0.170/32") // NAT64/DNS64 Discovery special4.Add("192.0.0.170/32") // NAT64/DNS64 Discovery
@ -66,7 +64,7 @@ func init() {
} }
// Netlist is a list of IP networks. // Netlist is a list of IP networks.
type Netlist []net.IPNet type Netlist []netip.Prefix
// ParseNetlist parses a comma-separated list of CIDR masks. // ParseNetlist parses a comma-separated list of CIDR masks.
// Whitespace and extra commas are ignored. // Whitespace and extra commas are ignored.
@ -78,11 +76,11 @@ func ParseNetlist(s string) (*Netlist, error) {
if mask == "" { if mask == "" {
continue continue
} }
_, n, err := net.ParseCIDR(mask) prefix, err := netip.ParsePrefix(mask)
if err != nil { if err != nil {
return nil, err return nil, err
} }
l = append(l, *n) l = append(l, prefix)
} }
return &l, nil return &l, nil
} }
@ -103,11 +101,11 @@ func (l *Netlist) UnmarshalTOML(fn func(interface{}) error) error {
return err return err
} }
for _, mask := range masks { for _, mask := range masks {
_, n, err := net.ParseCIDR(mask) prefix, err := netip.ParsePrefix(mask)
if err != nil { if err != nil {
return err return err
} }
*l = append(*l, *n) *l = append(*l, prefix)
} }
return nil return nil
} }
@ -115,15 +113,20 @@ func (l *Netlist) UnmarshalTOML(fn func(interface{}) error) error {
// Add parses a CIDR mask and appends it to the list. It panics for invalid masks and is // Add parses a CIDR mask and appends it to the list. It panics for invalid masks and is
// intended to be used for setting up static lists. // intended to be used for setting up static lists.
func (l *Netlist) Add(cidr string) { func (l *Netlist) Add(cidr string) {
_, n, err := net.ParseCIDR(cidr) prefix, err := netip.ParsePrefix(cidr)
if err != nil { if err != nil {
panic(err) panic(err)
} }
*l = append(*l, *n) *l = append(*l, prefix)
} }
// Contains reports whether the given IP is contained in the list. // Contains reports whether the given IP is contained in the list.
func (l *Netlist) Contains(ip net.IP) bool { func (l *Netlist) Contains(ip net.IP) bool {
return l.ContainsAddr(IPToAddr(ip))
}
// ContainsAddr reports whether the given IP is contained in the list.
func (l *Netlist) ContainsAddr(ip netip.Addr) bool {
if l == nil { if l == nil {
return false return false
} }
@ -137,25 +140,39 @@ func (l *Netlist) Contains(ip net.IP) bool {
// IsLAN reports whether an IP is a local network address. // IsLAN reports whether an IP is a local network address.
func IsLAN(ip net.IP) bool { func IsLAN(ip net.IP) bool {
return AddrIsLAN(IPToAddr(ip))
}
// AddrIsLAN reports whether an IP is a local network address.
func AddrIsLAN(ip netip.Addr) bool {
if ip.Is4In6() {
ip = netip.AddrFrom4(ip.As4())
}
if ip.IsLoopback() { if ip.IsLoopback() {
return true return true
} }
if v4 := ip.To4(); v4 != nil { return ip.IsPrivate() || ip.IsLinkLocalUnicast()
return lan4.Contains(v4)
}
return lan6.Contains(ip)
} }
// IsSpecialNetwork reports whether an IP is located in a special-use network range // IsSpecialNetwork reports whether an IP is located in a special-use network range
// This includes broadcast, multicast and documentation addresses. // This includes broadcast, multicast and documentation addresses.
func IsSpecialNetwork(ip net.IP) bool { func IsSpecialNetwork(ip net.IP) bool {
return AddrIsSpecialNetwork(IPToAddr(ip))
}
// AddrIsSpecialNetwork reports whether an IP is located in a special-use network range
// This includes broadcast, multicast and documentation addresses.
func AddrIsSpecialNetwork(ip netip.Addr) bool {
if ip.Is4In6() {
ip = netip.AddrFrom4(ip.As4())
}
if ip.IsMulticast() { if ip.IsMulticast() {
return true return true
} }
if v4 := ip.To4(); v4 != nil { if ip.Is4() {
return special4.Contains(v4) return special4.ContainsAddr(ip)
} }
return special6.Contains(ip) return special6.ContainsAddr(ip)
} }
var ( var (
@ -175,19 +192,31 @@ var (
// - LAN addresses are OK if relayed by a LAN host. // - LAN addresses are OK if relayed by a LAN host.
// - All other addresses are always acceptable. // - All other addresses are always acceptable.
func CheckRelayIP(sender, addr net.IP) error { func CheckRelayIP(sender, addr net.IP) error {
if len(addr) != net.IPv4len && len(addr) != net.IPv6len { return CheckRelayAddr(IPToAddr(sender), IPToAddr(addr))
}
// CheckRelayAddr reports whether an IP relayed from the given sender IP
// is a valid connection target.
//
// There are four rules:
// - Special network addresses are never valid.
// - Loopback addresses are OK if relayed by a loopback host.
// - LAN addresses are OK if relayed by a LAN host.
// - All other addresses are always acceptable.
func CheckRelayAddr(sender, addr netip.Addr) error {
if !addr.IsValid() {
return errInvalid return errInvalid
} }
if addr.IsUnspecified() { if addr.IsUnspecified() {
return errUnspecified return errUnspecified
} }
if IsSpecialNetwork(addr) { if AddrIsSpecialNetwork(addr) {
return errSpecial return errSpecial
} }
if addr.IsLoopback() && !sender.IsLoopback() { if addr.IsLoopback() && !sender.IsLoopback() {
return errLoopback return errLoopback
} }
if IsLAN(addr) && !IsLAN(sender) { if AddrIsLAN(addr) && !AddrIsLAN(sender) {
return errLAN return errLAN
} }
return nil return nil
@ -221,17 +250,22 @@ type DistinctNetSet struct {
Subnet uint // number of common prefix bits Subnet uint // number of common prefix bits
Limit uint // maximum number of IPs in each subnet Limit uint // maximum number of IPs in each subnet
members map[string]uint members map[netip.Prefix]uint
buf net.IP
} }
// Add adds an IP address to the set. It returns false (and doesn't add the IP) if the // Add adds an IP address to the set. It returns false (and doesn't add the IP) if the
// number of existing IPs in the defined range exceeds the limit. // number of existing IPs in the defined range exceeds the limit.
func (s *DistinctNetSet) Add(ip net.IP) bool { func (s *DistinctNetSet) Add(ip net.IP) bool {
return s.AddAddr(IPToAddr(ip))
}
// AddAddr adds an IP address to the set. It returns false (and doesn't add the IP) if the
// number of existing IPs in the defined range exceeds the limit.
func (s *DistinctNetSet) AddAddr(ip netip.Addr) bool {
key := s.key(ip) key := s.key(ip)
n := s.members[string(key)] n := s.members[key]
if n < s.Limit { if n < s.Limit {
s.members[string(key)] = n + 1 s.members[key] = n + 1
return true return true
} }
return false return false
@ -239,20 +273,30 @@ func (s *DistinctNetSet) Add(ip net.IP) bool {
// Remove removes an IP from the set. // Remove removes an IP from the set.
func (s *DistinctNetSet) Remove(ip net.IP) { func (s *DistinctNetSet) Remove(ip net.IP) {
s.RemoveAddr(IPToAddr(ip))
}
// RemoveAddr removes an IP from the set.
func (s *DistinctNetSet) RemoveAddr(ip netip.Addr) {
key := s.key(ip) key := s.key(ip)
if n, ok := s.members[string(key)]; ok { if n, ok := s.members[key]; ok {
if n == 1 { if n == 1 {
delete(s.members, string(key)) delete(s.members, key)
} else { } else {
s.members[string(key)] = n - 1 s.members[key] = n - 1
} }
} }
} }
// Contains whether the given IP is contained in the set. // Contains whether the given IP is contained in the set.
func (s DistinctNetSet) Contains(ip net.IP) bool { func (s DistinctNetSet) Contains(ip net.IP) bool {
return s.ContainsAddr(IPToAddr(ip))
}
// ContainsAddr whether the given IP is contained in the set.
func (s DistinctNetSet) ContainsAddr(ip netip.Addr) bool {
key := s.key(ip) key := s.key(ip)
_, ok := s.members[string(key)] _, ok := s.members[key]
return ok return ok
} }
@ -265,54 +309,30 @@ func (s DistinctNetSet) Len() int {
return int(n) return int(n)
} }
// key encodes the map key for an address into a temporary buffer. // key returns the map key for ip.
// func (s *DistinctNetSet) key(ip netip.Addr) netip.Prefix {
// The first byte of key is '4' or '6' to distinguish IPv4/IPv6 address types.
// The remainder of the key is the IP, truncated to the number of bits.
func (s *DistinctNetSet) key(ip net.IP) net.IP {
// Lazily initialize storage. // Lazily initialize storage.
if s.members == nil { if s.members == nil {
s.members = make(map[string]uint) s.members = make(map[netip.Prefix]uint)
s.buf = make(net.IP, 17)
} }
// Canonicalize ip and bits. p, err := ip.Prefix(int(s.Subnet))
typ := byte('6') if err != nil {
if ip4 := ip.To4(); ip4 != nil { panic(err)
typ, ip = '4', ip4
} }
bits := s.Subnet return p
if bits > uint(len(ip)*8) {
bits = uint(len(ip) * 8)
}
// Encode the prefix into s.buf.
nb := int(bits / 8)
mask := ^byte(0xFF >> (bits % 8))
s.buf[0] = typ
buf := append(s.buf[:1], ip[:nb]...)
if nb < len(ip) && mask != 0 {
buf = append(buf, ip[nb]&mask)
}
return buf
} }
// String implements fmt.Stringer // String implements fmt.Stringer
func (s DistinctNetSet) String() string { func (s DistinctNetSet) String() string {
keys := maps.Keys(s.members)
slices.SortFunc(keys, func(a, b netip.Prefix) int {
return strings.Compare(a.String(), b.String())
})
var buf bytes.Buffer var buf bytes.Buffer
buf.WriteString("{") buf.WriteString("{")
keys := make([]string, 0, len(s.members))
for k := range s.members {
keys = append(keys, k)
}
sort.Strings(keys)
for i, k := range keys { for i, k := range keys {
var ip net.IP fmt.Fprintf(&buf, "%v×%d", k, s.members[k])
if k[0] == '4' {
ip = make(net.IP, 4)
} else {
ip = make(net.IP, 16)
}
copy(ip, k[1:])
fmt.Fprintf(&buf, "%v×%d", ip, s.members[k])
if i != len(keys)-1 { if i != len(keys)-1 {
buf.WriteString(" ") buf.WriteString(" ")
} }

View File

@ -18,7 +18,9 @@ package netutil
import ( import (
"fmt" "fmt"
"math/rand"
"net" "net"
"net/netip"
"reflect" "reflect"
"testing" "testing"
"testing/quick" "testing/quick"
@ -29,7 +31,7 @@ import (
func TestParseNetlist(t *testing.T) { func TestParseNetlist(t *testing.T) {
var tests = []struct { var tests = []struct {
input string input string
wantErr error wantErr string
wantList *Netlist wantList *Netlist
}{ }{
{ {
@ -38,25 +40,27 @@ func TestParseNetlist(t *testing.T) {
}, },
{ {
input: "127.0.0.0/8", input: "127.0.0.0/8",
wantErr: nil, wantList: &Netlist{netip.MustParsePrefix("127.0.0.0/8")},
wantList: &Netlist{{IP: net.IP{127, 0, 0, 0}, Mask: net.CIDRMask(8, 32)}},
}, },
{ {
input: "127.0.0.0/44", input: "127.0.0.0/44",
wantErr: &net.ParseError{Type: "CIDR address", Text: "127.0.0.0/44"}, wantErr: `netip.ParsePrefix("127.0.0.0/44"): prefix length out of range`,
}, },
{ {
input: "127.0.0.0/16, 23.23.23.23/24,", input: "127.0.0.0/16, 23.23.23.23/24,",
wantList: &Netlist{ wantList: &Netlist{
{IP: net.IP{127, 0, 0, 0}, Mask: net.CIDRMask(16, 32)}, netip.MustParsePrefix("127.0.0.0/16"),
{IP: net.IP{23, 23, 23, 0}, Mask: net.CIDRMask(24, 32)}, netip.MustParsePrefix("23.23.23.23/24"),
}, },
}, },
} }
for _, test := range tests { for _, test := range tests {
l, err := ParseNetlist(test.input) l, err := ParseNetlist(test.input)
if !reflect.DeepEqual(err, test.wantErr) { if err == nil && test.wantErr != "" {
t.Errorf("%q: got no error, expected %q", test.input, test.wantErr)
continue
} else if err != nil && err.Error() != test.wantErr {
t.Errorf("%q: got error %q, want %q", test.input, err, test.wantErr) t.Errorf("%q: got error %q, want %q", test.input, err, test.wantErr)
continue continue
} }
@ -70,14 +74,12 @@ func TestParseNetlist(t *testing.T) {
func TestNilNetListContains(t *testing.T) { func TestNilNetListContains(t *testing.T) {
var list *Netlist var list *Netlist
checkContains(t, list.Contains, nil, []string{"1.2.3.4"}) checkContains(t, list.Contains, list.ContainsAddr, nil, []string{"1.2.3.4"})
} }
func TestIsLAN(t *testing.T) { func TestIsLAN(t *testing.T) {
checkContains(t, IsLAN, checkContains(t, IsLAN, AddrIsLAN,
[]string{ // included []string{ // included
"0.0.0.0",
"0.2.0.8",
"127.0.0.1", "127.0.0.1",
"10.0.1.1", "10.0.1.1",
"10.22.0.3", "10.22.0.3",
@ -86,25 +88,35 @@ func TestIsLAN(t *testing.T) {
"fe80::f4a1:8eff:fec5:9d9d", "fe80::f4a1:8eff:fec5:9d9d",
"febf::ab32:2233", "febf::ab32:2233",
"fc00::4", "fc00::4",
// 4-in-6
"::ffff:127.0.0.1",
"::ffff:10.10.0.2",
}, },
[]string{ // excluded []string{ // excluded
"192.0.2.1", "192.0.2.1",
"1.0.0.0", "1.0.0.0",
"172.32.0.1", "172.32.0.1",
"fec0::2233", "fec0::2233",
// 4-in-6
"::ffff:88.99.100.2",
}, },
) )
} }
func TestIsSpecialNetwork(t *testing.T) { func TestIsSpecialNetwork(t *testing.T) {
checkContains(t, IsSpecialNetwork, checkContains(t, IsSpecialNetwork, AddrIsSpecialNetwork,
[]string{ // included []string{ // included
"0.0.0.0",
"0.2.0.8",
"192.0.2.1", "192.0.2.1",
"192.0.2.44", "192.0.2.44",
"2001:db8:85a3:8d3:1319:8a2e:370:7348", "2001:db8:85a3:8d3:1319:8a2e:370:7348",
"255.255.255.255", "255.255.255.255",
"224.0.0.22", // IPv4 multicast "224.0.0.22", // IPv4 multicast
"ff05::1:3", // IPv6 multicast "ff05::1:3", // IPv6 multicast
// 4-in-6
"::ffff:255.255.255.255",
"::ffff:192.0.2.1",
}, },
[]string{ // excluded []string{ // excluded
"192.0.3.1", "192.0.3.1",
@ -115,15 +127,21 @@ func TestIsSpecialNetwork(t *testing.T) {
) )
} }
func checkContains(t *testing.T, fn func(net.IP) bool, inc, exc []string) { func checkContains(t *testing.T, fn func(net.IP) bool, fn2 func(netip.Addr) bool, inc, exc []string) {
for _, s := range inc { for _, s := range inc {
if !fn(parseIP(s)) { if !fn(parseIP(s)) {
t.Error("returned false for included address", s) t.Error("returned false for included net.IP", s)
}
if !fn2(netip.MustParseAddr(s)) {
t.Error("returned false for included netip.Addr", s)
} }
} }
for _, s := range exc { for _, s := range exc {
if fn(parseIP(s)) { if fn(parseIP(s)) {
t.Error("returned true for excluded address", s) t.Error("returned true for excluded net.IP", s)
}
if fn2(netip.MustParseAddr(s)) {
t.Error("returned true for excluded netip.Addr", s)
} }
} }
} }
@ -244,14 +262,22 @@ func TestDistinctNetSet(t *testing.T) {
} }
func TestDistinctNetSetAddRemove(t *testing.T) { func TestDistinctNetSetAddRemove(t *testing.T) {
cfg := &quick.Config{} cfg := &quick.Config{
fn := func(ips []net.IP) bool { Values: func(s []reflect.Value, rng *rand.Rand) {
slice := make([]netip.Addr, rng.Intn(20)+1)
for i := range slice {
slice[i] = RandomAddr(rng, false)
}
s[0] = reflect.ValueOf(slice)
},
}
fn := func(ips []netip.Addr) bool {
s := DistinctNetSet{Limit: 3, Subnet: 2} s := DistinctNetSet{Limit: 3, Subnet: 2}
for _, ip := range ips { for _, ip := range ips {
s.Add(ip) s.AddAddr(ip)
} }
for _, ip := range ips { for _, ip := range ips {
s.Remove(ip) s.RemoveAddr(ip)
} }
return s.Len() == 0 return s.Len() == 0
} }

View File

@ -905,14 +905,14 @@ func (srv *Server) listenLoop() {
break break
} }
remoteIP := netutil.AddrIP(fd.RemoteAddr()) remoteIP := netutil.AddrAddr(fd.RemoteAddr())
if err := srv.checkInboundConn(remoteIP); err != nil { if err := srv.checkInboundConn(remoteIP); err != nil {
srv.log.Debug("Rejected inbound connection", "addr", fd.RemoteAddr(), "err", err) srv.log.Debug("Rejected inbound connection", "addr", fd.RemoteAddr(), "err", err)
fd.Close() fd.Close()
slots <- struct{}{} slots <- struct{}{}
continue continue
} }
if remoteIP != nil { if remoteIP.IsValid() {
fd = newMeteredConn(fd) fd = newMeteredConn(fd)
serveMeter.Mark(1) serveMeter.Mark(1)
srv.log.Trace("Accepted connection", "addr", fd.RemoteAddr()) srv.log.Trace("Accepted connection", "addr", fd.RemoteAddr())
@ -924,18 +924,19 @@ func (srv *Server) listenLoop() {
} }
} }
func (srv *Server) checkInboundConn(remoteIP net.IP) error { func (srv *Server) checkInboundConn(remoteIP netip.Addr) error {
if remoteIP == nil { if !remoteIP.IsValid() {
// This case happens for internal test connections without remote address.
return nil return nil
} }
// Reject connections that do not match NetRestrict. // Reject connections that do not match NetRestrict.
if srv.NetRestrict != nil && !srv.NetRestrict.Contains(remoteIP) { if srv.NetRestrict != nil && !srv.NetRestrict.ContainsAddr(remoteIP) {
return errors.New("not in netrestrict list") return errors.New("not in netrestrict list")
} }
// Reject Internet peers that try too often. // Reject Internet peers that try too often.
now := srv.clock.Now() now := srv.clock.Now()
srv.inboundHistory.expire(now, nil) srv.inboundHistory.expire(now, nil)
if !netutil.IsLAN(remoteIP) && srv.inboundHistory.contains(remoteIP.String()) { if !netutil.AddrIsLAN(remoteIP) && srv.inboundHistory.contains(remoteIP.String()) {
return errors.New("too many attempts") return errors.New("too many attempts")
} }
srv.inboundHistory.add(remoteIP.String(), now.Add(inboundThrottleTime)) srv.inboundHistory.add(remoteIP.String(), now.Add(inboundThrottleTime))
@ -1108,7 +1109,7 @@ func (srv *Server) NodeInfo() *NodeInfo {
Name: srv.Name, Name: srv.Name,
Enode: node.URLv4(), Enode: node.URLv4(),
ID: node.ID().String(), ID: node.ID().String(),
IP: node.IP().String(), IP: node.IPAddr().String(),
ListenAddr: srv.ListenAddr, ListenAddr: srv.ListenAddr,
Protocols: make(map[string]interface{}), Protocols: make(map[string]interface{}),
} }

View File

@ -18,6 +18,7 @@ package p2p
import ( import (
"net" "net"
"net/netip"
"sync/atomic" "sync/atomic"
"testing" "testing"
"time" "time"
@ -64,8 +65,8 @@ func TestServerPortMapping(t *testing.T) {
t.Error("wrong request count:", reqCount) t.Error("wrong request count:", reqCount)
} }
enr := srv.LocalNode().Node() enr := srv.LocalNode().Node()
if enr.IP().String() != "192.0.2.0" { if enr.IPAddr() != netip.MustParseAddr("192.0.2.0") {
t.Error("wrong IP in ENR:", enr.IP()) t.Error("wrong IP in ENR:", enr.IPAddr())
} }
if enr.TCP() != 30000 { if enr.TCP() != 30000 {
t.Error("wrong TCP port in ENR:", enr.TCP()) t.Error("wrong TCP port in ENR:", enr.TCP())