p2p: use netip.Addr where possible (#29891)

enode.Node was recently changed to store a cache of endpoint information. The IP address in the cache is a netip.Addr. I chose that type over net.IP because it is just better. netip.Addr is meant to be used as a value type. Copying it does not allocate, it can be compared with ==, and can be used as a map key.

This PR changes most uses of Node.IP() into Node.IPAddr(), which returns the cached value directly without allocating.
While there are still some public APIs left where net.IP is used, I have converted all code used internally by p2p/discover to the new types. So this does change some public Go API, but hopefully not APIs any external code actually uses.

There weren't supposed to be any semantic differences resulting from this refactoring, however it does introduce one: In package p2p/netutil we treated the 0.0.0.0/8 network (addresses 0.x.y.z) as LAN, but netip.Addr.IsPrivate() doesn't. The treatment of this particular IP address range is controversial, with some software supporting it and others not. IANA lists it as special-purpose and invalid as a destination for a long time, so I don't know why I put it into the LAN list. It has now been marked as special in p2p/netutil as well.
This commit is contained in:
Felix Lange 2024-06-05 19:31:04 +02:00 committed by GitHub
parent d09ddac399
commit bc6569462d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
23 changed files with 392 additions and 313 deletions

View File

@ -53,7 +53,8 @@ func (s *Suite) dial() (*Conn, error) {
// dialAs attempts to dial a given node and perform a handshake using the given
// private key.
func (s *Suite) dialAs(key *ecdsa.PrivateKey) (*Conn, error) {
fd, err := net.Dial("tcp", fmt.Sprintf("%v:%d", s.Dest.IP(), s.Dest.TCP()))
tcpEndpoint, _ := s.Dest.TCPEndpoint()
fd, err := net.Dial("tcp", tcpEndpoint.String())
if err != nil {
return nil, err
}

View File

@ -53,10 +53,12 @@ func newTestEnv(remote string, listen1, listen2 string) *testenv {
if err != nil {
panic(err)
}
if node.IP() == nil || node.UDP() == 0 {
if !node.IPAddr().IsValid() || node.UDP() == 0 {
var ip net.IP
var tcpPort, udpPort int
if ip = node.IP(); ip == nil {
if node.IPAddr().IsValid() {
ip = node.IPAddr().AsSlice()
} else {
ip = net.ParseIP("127.0.0.1")
}
if tcpPort = node.TCP(); tcpPort == 0 {

View File

@ -19,7 +19,7 @@ package main
import (
"errors"
"fmt"
"net"
"net/netip"
"sort"
"strconv"
"strings"
@ -205,11 +205,11 @@ func trueFilter(args []string) (nodeFilter, error) {
}
func ipFilter(args []string) (nodeFilter, error) {
_, cidr, err := net.ParseCIDR(args[0])
prefix, err := netip.ParsePrefix(args[0])
if err != nil {
return nil, err
}
f := func(n nodeJSON) bool { return cidr.Contains(n.N.IP()) }
f := func(n nodeJSON) bool { return prefix.Contains(n.N.IPAddr()) }
return f, nil
}

View File

@ -77,7 +77,11 @@ var (
func rlpxPing(ctx *cli.Context) error {
n := getNodeArg(ctx)
fd, err := net.Dial("tcp", fmt.Sprintf("%v:%d", n.IP(), n.TCP()))
tcpEndpoint, ok := n.TCPEndpoint()
if !ok {
return fmt.Errorf("node has no TCP endpoint")
}
fd, err := net.Dial("tcp", tcpEndpoint.String())
if err != nil {
return err
}

View File

@ -65,11 +65,8 @@ type tcpDialer struct {
}
func (t tcpDialer) Dial(ctx context.Context, dest *enode.Node) (net.Conn, error) {
return t.d.DialContext(ctx, "tcp", nodeAddr(dest).String())
}
func nodeAddr(n *enode.Node) net.Addr {
return &net.TCPAddr{IP: n.IP(), Port: n.TCP()}
addr, _ := dest.TCPEndpoint()
return t.d.DialContext(ctx, "tcp", addr.String())
}
// checkDial errors:
@ -243,7 +240,7 @@ loop:
select {
case node := <-nodesCh:
if err := d.checkDial(node); err != nil {
d.log.Trace("Discarding dial candidate", "id", node.ID(), "ip", node.IP(), "reason", err)
d.log.Trace("Discarding dial candidate", "id", node.ID(), "ip", node.IPAddr(), "reason", err)
} else {
d.startDial(newDialTask(node, dynDialedConn))
}
@ -277,7 +274,7 @@ loop:
case node := <-d.addStaticCh:
id := node.ID()
_, exists := d.static[id]
d.log.Trace("Adding static node", "id", id, "ip", node.IP(), "added", !exists)
d.log.Trace("Adding static node", "id", id, "ip", node.IPAddr(), "added", !exists)
if exists {
continue loop
}
@ -376,7 +373,7 @@ func (d *dialScheduler) checkDial(n *enode.Node) error {
if n.ID() == d.self {
return errSelf
}
if n.IP() != nil && n.TCP() == 0 {
if n.IPAddr().IsValid() && n.TCP() == 0 {
// This check can trigger if a non-TCP node is found
// by discovery. If there is no IP, the node is a static
// node and the actual endpoint will be resolved later in dialTask.
@ -388,7 +385,7 @@ func (d *dialScheduler) checkDial(n *enode.Node) error {
if _, ok := d.peers[n.ID()]; ok {
return errAlreadyConnected
}
if d.netRestrict != nil && !d.netRestrict.Contains(n.IP()) {
if d.netRestrict != nil && !d.netRestrict.ContainsAddr(n.IPAddr()) {
return errNetRestrict
}
if d.history.contains(string(n.ID().Bytes())) {
@ -439,7 +436,7 @@ func (d *dialScheduler) removeFromStaticPool(idx int) {
// startDial runs the given dial task in a separate goroutine.
func (d *dialScheduler) startDial(task *dialTask) {
node := task.dest()
d.log.Trace("Starting p2p dial", "id", node.ID(), "ip", node.IP(), "flag", task.flags)
d.log.Trace("Starting p2p dial", "id", node.ID(), "ip", node.IPAddr(), "flag", task.flags)
hkey := string(node.ID().Bytes())
d.history.add(hkey, d.clock.Now().Add(dialHistoryExpiration))
d.dialing[node.ID()] = task
@ -492,7 +489,7 @@ func (t *dialTask) run(d *dialScheduler) {
}
func (t *dialTask) needResolve() bool {
return t.flags&staticDialedConn != 0 && t.dest().IP() == nil
return t.flags&staticDialedConn != 0 && !t.dest().IPAddr().IsValid()
}
// resolve attempts to find the current endpoint for the destination
@ -526,7 +523,8 @@ func (t *dialTask) resolve(d *dialScheduler) bool {
// The node was found.
t.resolveDelay = initialResolveDelay
t.destPtr.Store(resolved)
d.log.Debug("Resolved node", "id", resolved.ID(), "addr", &net.TCPAddr{IP: resolved.IP(), Port: resolved.TCP()})
resAddr, _ := resolved.TCPEndpoint()
d.log.Debug("Resolved node", "id", resolved.ID(), "addr", resAddr)
return true
}
@ -535,7 +533,8 @@ func (t *dialTask) dial(d *dialScheduler, dest *enode.Node) error {
dialMeter.Mark(1)
fd, err := d.dialer.Dial(d.ctx, dest)
if err != nil {
d.log.Trace("Dial error", "id", dest.ID(), "addr", nodeAddr(dest), "conn", t.flags, "err", cleanupDialErr(err))
addr, _ := dest.TCPEndpoint()
d.log.Trace("Dial error", "id", dest.ID(), "addr", addr, "conn", t.flags, "err", cleanupDialErr(err))
dialConnectionError.Mark(1)
return &dialError{err}
}
@ -545,7 +544,7 @@ func (t *dialTask) dial(d *dialScheduler, dest *enode.Node) error {
func (t *dialTask) String() string {
node := t.dest()
id := node.ID()
return fmt.Sprintf("%v %x %v:%d", t.flags, id[:8], node.IP(), node.TCP())
return fmt.Sprintf("%v %x %v:%d", t.flags, id[:8], node.IPAddr(), node.TCP())
}
func cleanupDialErr(err error) error {

View File

@ -25,7 +25,7 @@ package discover
import (
"context"
"fmt"
"net"
"net/netip"
"slices"
"sync"
"time"
@ -207,8 +207,8 @@ func (tab *Table) setFallbackNodes(nodes []*enode.Node) error {
if err := n.ValidateComplete(); err != nil {
return fmt.Errorf("bad bootstrap node %q: %v", n, err)
}
if tab.cfg.NetRestrict != nil && !tab.cfg.NetRestrict.Contains(n.IP()) {
tab.log.Error("Bootstrap node filtered by netrestrict", "id", n.ID(), "ip", n.IP())
if tab.cfg.NetRestrict != nil && !tab.cfg.NetRestrict.ContainsAddr(n.IPAddr()) {
tab.log.Error("Bootstrap node filtered by netrestrict", "id", n.ID(), "ip", n.IPAddr())
continue
}
nursery = append(nursery, n)
@ -448,7 +448,7 @@ func (tab *Table) loadSeedNodes() {
for i := range seeds {
seed := seeds[i]
if tab.log.Enabled(context.Background(), log.LevelTrace) {
age := time.Since(tab.db.LastPongReceived(seed.ID(), seed.IP()))
age := time.Since(tab.db.LastPongReceived(seed.ID(), seed.IPAddr()))
addr, _ := seed.UDPEndpoint()
tab.log.Trace("Found seed node in database", "id", seed.ID(), "addr", addr, "age", age)
}
@ -474,31 +474,31 @@ func (tab *Table) bucketAtDistance(d int) *bucket {
return tab.buckets[d-bucketMinDistance-1]
}
func (tab *Table) addIP(b *bucket, ip net.IP) bool {
if len(ip) == 0 {
func (tab *Table) addIP(b *bucket, ip netip.Addr) bool {
if !ip.IsValid() || ip.IsUnspecified() {
return false // Nodes without IP cannot be added.
}
if netutil.IsLAN(ip) {
if netutil.AddrIsLAN(ip) {
return true
}
if !tab.ips.Add(ip) {
if !tab.ips.AddAddr(ip) {
tab.log.Debug("IP exceeds table limit", "ip", ip)
return false
}
if !b.ips.Add(ip) {
if !b.ips.AddAddr(ip) {
tab.log.Debug("IP exceeds bucket limit", "ip", ip)
tab.ips.Remove(ip)
tab.ips.RemoveAddr(ip)
return false
}
return true
}
func (tab *Table) removeIP(b *bucket, ip net.IP) {
if netutil.IsLAN(ip) {
func (tab *Table) removeIP(b *bucket, ip netip.Addr) {
if netutil.AddrIsLAN(ip) {
return
}
tab.ips.Remove(ip)
b.ips.Remove(ip)
tab.ips.RemoveAddr(ip)
b.ips.RemoveAddr(ip)
}
// handleAddNode adds the node in the request to the table, if there is space.
@ -524,7 +524,7 @@ func (tab *Table) handleAddNode(req addNodeOp) bool {
tab.addReplacement(b, req.node)
return false
}
if !tab.addIP(b, req.node.IP()) {
if !tab.addIP(b, req.node.IPAddr()) {
// Can't add: IP limit reached.
return false
}
@ -547,7 +547,7 @@ func (tab *Table) addReplacement(b *bucket, n *enode.Node) {
// TODO: update ENR
return
}
if !tab.addIP(b, n.IP()) {
if !tab.addIP(b, n.IPAddr()) {
return
}
@ -555,7 +555,7 @@ func (tab *Table) addReplacement(b *bucket, n *enode.Node) {
var removed *tableNode
b.replacements, removed = pushNode(b.replacements, wn, maxReplacements)
if removed != nil {
tab.removeIP(b, removed.IP())
tab.removeIP(b, removed.IPAddr())
}
}
@ -595,12 +595,12 @@ func (tab *Table) deleteInBucket(b *bucket, id enode.ID) *tableNode {
// Remove the node.
n := b.entries[index]
b.entries = slices.Delete(b.entries, index, index+1)
tab.removeIP(b, n.IP())
tab.removeIP(b, n.IPAddr())
tab.nodeRemoved(b, n)
// Add replacement.
if len(b.replacements) == 0 {
tab.log.Debug("Removed dead node", "b", b.index, "id", n.ID(), "ip", n.IP())
tab.log.Debug("Removed dead node", "b", b.index, "id", n.ID(), "ip", n.IPAddr())
return nil
}
rindex := tab.rand.Intn(len(b.replacements))
@ -608,7 +608,7 @@ func (tab *Table) deleteInBucket(b *bucket, id enode.ID) *tableNode {
b.replacements = slices.Delete(b.replacements, rindex, rindex+1)
b.entries = append(b.entries, rep)
tab.nodeAdded(b, rep)
tab.log.Debug("Replaced dead node", "b", b.index, "id", n.ID(), "ip", n.IP(), "r", rep.ID(), "rip", rep.IP())
tab.log.Debug("Replaced dead node", "b", b.index, "id", n.ID(), "ip", n.IPAddr(), "r", rep.ID(), "rip", rep.IPAddr())
return rep
}
@ -635,10 +635,10 @@ func (tab *Table) bumpInBucket(b *bucket, newRecord *enode.Node, isInbound bool)
ipchanged := newRecord.IPAddr() != n.IPAddr()
portchanged := newRecord.UDP() != n.UDP()
if ipchanged {
tab.removeIP(b, n.IP())
if !tab.addIP(b, newRecord.IP()) {
tab.removeIP(b, n.IPAddr())
if !tab.addIP(b, newRecord.IPAddr()) {
// It doesn't fit with the limit, put the previous record back.
tab.addIP(b, n.IP())
tab.addIP(b, n.IPAddr())
return n, false
}
}
@ -657,11 +657,11 @@ func (tab *Table) handleTrackRequest(op trackRequestOp) {
var fails int
if op.success {
// Reset failure counter because it counts _consecutive_ failures.
tab.db.UpdateFindFails(op.node.ID(), op.node.IP(), 0)
tab.db.UpdateFindFails(op.node.ID(), op.node.IPAddr(), 0)
} else {
fails = tab.db.FindFails(op.node.ID(), op.node.IP())
fails = tab.db.FindFails(op.node.ID(), op.node.IPAddr())
fails++
tab.db.UpdateFindFails(op.node.ID(), op.node.IP(), fails)
tab.db.UpdateFindFails(op.node.ID(), op.node.IPAddr(), fails)
}
tab.mutex.Lock()

View File

@ -188,7 +188,7 @@ func checkIPLimitInvariant(t *testing.T, tab *Table) {
tabset := netutil.DistinctNetSet{Subnet: tableSubnet, Limit: tableIPLimit}
for _, b := range tab.buckets {
for _, n := range b.entries {
tabset.Add(n.IP())
tabset.AddAddr(n.IPAddr())
}
}
if tabset.String() != tab.ips.String() {
@ -268,7 +268,7 @@ func (*closeTest) Generate(rand *rand.Rand, size int) reflect.Value {
}
for _, id := range gen([]enode.ID{}, rand).([]enode.ID) {
r := new(enr.Record)
r.Set(enr.IP(genIP(rand)))
r.Set(enr.IPv4Addr(netutil.RandomAddr(rand, true)))
n := enode.SignNull(r, id)
t.All = append(t.All, n)
}
@ -385,11 +385,11 @@ func checkBucketContent(t *testing.T, tab *Table, nodes []*enode.Node) {
}
t.Log("wrong bucket content. have nodes:")
for _, n := range b.entries {
t.Logf(" %v (seq=%v, ip=%v)", n.ID(), n.Seq(), n.IP())
t.Logf(" %v (seq=%v, ip=%v)", n.ID(), n.Seq(), n.IPAddr())
}
t.Log("want nodes:")
for _, n := range nodes {
t.Logf(" %v (seq=%v, ip=%v)", n.ID(), n.Seq(), n.IP())
t.Logf(" %v (seq=%v, ip=%v)", n.ID(), n.Seq(), n.IPAddr())
}
t.FailNow()
@ -483,12 +483,6 @@ func gen(typ interface{}, rand *rand.Rand) interface{} {
return v.Interface()
}
func genIP(rand *rand.Rand) net.IP {
ip := make(net.IP, 4)
rand.Read(ip)
return ip
}
func quickcfg() *quick.Config {
return &quick.Config{
MaxCount: 5000,

View File

@ -100,8 +100,9 @@ func idAtDistance(a enode.ID, n int) (b enode.ID) {
return b
}
// intIP returns a LAN IP address based on i.
func intIP(i int) net.IP {
return net.IP{byte(i), 0, 2, byte(i)}
return net.IP{10, 0, byte(i >> 8), byte(i & 0xFF)}
}
// fillBucket inserts nodes into the given bucket until it is full.
@ -254,7 +255,7 @@ NotEqual:
}
func nodeEqual(n1 *enode.Node, n2 *enode.Node) bool {
return n1.ID() == n2.ID() && n1.IP().Equal(n2.IP())
return n1.ID() == n2.ID() && n1.IPAddr() == n2.IPAddr()
}
func sortByID[N nodeType](nodes []N) {

View File

@ -25,7 +25,6 @@ import (
"errors"
"fmt"
"io"
"net"
"net/netip"
"sync"
"time"
@ -250,8 +249,7 @@ func (t *UDPv4) sendPing(toid enode.ID, toaddr netip.AddrPort, callback func())
return matched, matched
})
// Send the packet.
toUDPAddr := &net.UDPAddr{IP: toaddr.Addr().AsSlice()}
t.localNode.UDPContact(toUDPAddr)
t.localNode.UDPContact(toaddr)
t.write(toaddr, toid, req.Name(), packet)
return rm
}
@ -383,7 +381,7 @@ func (t *UDPv4) RequestENR(n *enode.Node) (*enode.Node, error) {
if respN.Seq() < n.Seq() {
return n, nil // response record is older
}
if err := netutil.CheckRelayIP(addr.Addr().AsSlice(), respN.IP()); err != nil {
if err := netutil.CheckRelayAddr(addr.Addr(), respN.IPAddr()); err != nil {
return nil, fmt.Errorf("invalid IP in response record: %v", err)
}
return respN, nil
@ -578,15 +576,14 @@ func (t *UDPv4) handlePacket(from netip.AddrPort, buf []byte) error {
// checkBond checks if the given node has a recent enough endpoint proof.
func (t *UDPv4) checkBond(id enode.ID, ip netip.AddrPort) bool {
return time.Since(t.db.LastPongReceived(id, ip.Addr().AsSlice())) < bondExpiration
return time.Since(t.db.LastPongReceived(id, ip.Addr())) < bondExpiration
}
// ensureBond solicits a ping from a node if we haven't seen a ping from it for a while.
// This ensures there is a valid endpoint proof on the remote end.
func (t *UDPv4) ensureBond(toid enode.ID, toaddr netip.AddrPort) {
ip := toaddr.Addr().AsSlice()
tooOld := time.Since(t.db.LastPingReceived(toid, ip)) > bondExpiration
if tooOld || t.db.FindFails(toid, ip) > maxFindnodeFailures {
tooOld := time.Since(t.db.LastPingReceived(toid, toaddr.Addr())) > bondExpiration
if tooOld || t.db.FindFails(toid, toaddr.Addr()) > maxFindnodeFailures {
rm := t.sendPing(toid, toaddr, nil)
<-rm.errc
// Wait for them to ping back and process our pong.
@ -687,7 +684,7 @@ func (t *UDPv4) handlePing(h *packetHandlerV4, from netip.AddrPort, fromID enode
// Ping back if our last pong on file is too far in the past.
fromIP := from.Addr().AsSlice()
n := enode.NewV4(h.senderKey, fromIP, int(req.From.TCP), int(from.Port()))
if time.Since(t.db.LastPongReceived(n.ID(), fromIP)) > bondExpiration {
if time.Since(t.db.LastPongReceived(n.ID(), from.Addr())) > bondExpiration {
t.sendPing(fromID, from, func() {
t.tab.addInboundNode(n)
})
@ -696,10 +693,9 @@ func (t *UDPv4) handlePing(h *packetHandlerV4, from netip.AddrPort, fromID enode
}
// Update node database and endpoint predictor.
t.db.UpdateLastPingReceived(n.ID(), fromIP, time.Now())
fromUDPAddr := &net.UDPAddr{IP: fromIP, Port: int(from.Port())}
toUDPAddr := &net.UDPAddr{IP: req.To.IP, Port: int(req.To.UDP)}
t.localNode.UDPEndpointStatement(fromUDPAddr, toUDPAddr)
t.db.UpdateLastPingReceived(n.ID(), from.Addr(), time.Now())
toaddr := netip.AddrPortFrom(netutil.IPToAddr(req.To.IP), req.To.UDP)
t.localNode.UDPEndpointStatement(from, toaddr)
}
// PONG/v4
@ -713,11 +709,9 @@ func (t *UDPv4) verifyPong(h *packetHandlerV4, from netip.AddrPort, fromID enode
if !t.handleReply(fromID, from.Addr(), req) {
return errUnsolicitedReply
}
fromIP := from.Addr().AsSlice()
fromUDPAddr := &net.UDPAddr{IP: fromIP, Port: int(from.Port())}
toUDPAddr := &net.UDPAddr{IP: req.To.IP, Port: int(req.To.UDP)}
t.localNode.UDPEndpointStatement(fromUDPAddr, toUDPAddr)
t.db.UpdateLastPongReceived(fromID, fromIP, time.Now())
toaddr := netip.AddrPortFrom(netutil.IPToAddr(req.To.IP), req.To.UDP)
t.localNode.UDPEndpointStatement(from, toaddr)
t.db.UpdateLastPongReceived(fromID, from.Addr(), time.Now())
return nil
}
@ -753,8 +747,7 @@ func (t *UDPv4) handleFindnode(h *packetHandlerV4, from netip.AddrPort, fromID e
p := v4wire.Neighbors{Expiration: uint64(time.Now().Add(expiration).Unix())}
var sent bool
for _, n := range closest {
fromIP := from.Addr().AsSlice()
if netutil.CheckRelayIP(fromIP, n.IP()) == nil {
if netutil.CheckRelayAddr(from.Addr(), n.IPAddr()) == nil {
p.Nodes = append(p.Nodes, nodeToRPC(n))
}
if len(p.Nodes) == v4wire.MaxNeighbors {

View File

@ -274,7 +274,7 @@ func TestUDPv4_findnode(t *testing.T) {
// ensure there's a bond with the test node,
// findnode won't be accepted otherwise.
remoteID := v4wire.EncodePubkey(&test.remotekey.PublicKey).ID()
test.table.db.UpdateLastPongReceived(remoteID, test.remoteaddr.Addr().AsSlice(), time.Now())
test.table.db.UpdateLastPongReceived(remoteID, test.remoteaddr.Addr(), time.Now())
// check that closest neighbors are returned.
expected := test.table.findnodeByID(testTarget.ID(), bucketSize, true)
@ -309,7 +309,7 @@ func TestUDPv4_findnodeMultiReply(t *testing.T) {
defer test.close()
rid := enode.PubkeyToIDV4(&test.remotekey.PublicKey)
test.table.db.UpdateLastPingReceived(rid, test.remoteaddr.Addr().AsSlice(), time.Now())
test.table.db.UpdateLastPingReceived(rid, test.remoteaddr.Addr(), time.Now())
// queue a pending findnode request
resultc, errc := make(chan []*enode.Node, 1), make(chan error, 1)
@ -437,8 +437,8 @@ func TestUDPv4_successfulPing(t *testing.T) {
if n.ID() != rid {
t.Errorf("node has wrong ID: got %v, want %v", n.ID(), rid)
}
if !n.IP().Equal(test.remoteaddr.Addr().AsSlice()) {
t.Errorf("node has wrong IP: got %v, want: %v", n.IP(), test.remoteaddr.Addr())
if n.IPAddr() != test.remoteaddr.Addr() {
t.Errorf("node has wrong IP: got %v, want: %v", n.IPAddr(), test.remoteaddr.Addr())
}
if n.UDP() != int(test.remoteaddr.Port()) {
t.Errorf("node has wrong UDP port: got %v, want: %v", n.UDP(), test.remoteaddr.Port())

View File

@ -428,10 +428,10 @@ func (t *UDPv5) verifyResponseNode(c *callV5, r *enr.Record, distances []uint, s
if err != nil {
return nil, err
}
if err := netutil.CheckRelayIP(c.addr.Addr().AsSlice(), node.IP()); err != nil {
if err := netutil.CheckRelayAddr(c.addr.Addr(), node.IPAddr()); err != nil {
return nil, err
}
if t.netrestrict != nil && !t.netrestrict.Contains(node.IP()) {
if t.netrestrict != nil && !t.netrestrict.ContainsAddr(node.IPAddr()) {
return nil, errors.New("not contained in netrestrict list")
}
if node.UDP() <= 1024 {
@ -754,9 +754,8 @@ func (t *UDPv5) handle(p v5wire.Packet, fromID enode.ID, fromAddr netip.AddrPort
t.handlePing(p, fromID, fromAddr)
case *v5wire.Pong:
if t.handleCallResponse(fromID, fromAddr, p) {
fromUDPAddr := &net.UDPAddr{IP: fromAddr.Addr().AsSlice(), Port: int(fromAddr.Port())}
toUDPAddr := &net.UDPAddr{IP: p.ToIP, Port: int(p.ToPort)}
t.localNode.UDPEndpointStatement(fromUDPAddr, toUDPAddr)
toAddr := netip.AddrPortFrom(netutil.IPToAddr(p.ToIP), p.ToPort)
t.localNode.UDPEndpointStatement(fromAddr, toAddr)
}
case *v5wire.Findnode:
t.handleFindnode(p, fromID, fromAddr)
@ -848,7 +847,6 @@ func (t *UDPv5) handleFindnode(p *v5wire.Findnode, fromID enode.ID, fromAddr net
// collectTableNodes creates a FINDNODE result set for the given distances.
func (t *UDPv5) collectTableNodes(rip netip.Addr, distances []uint, limit int) []*enode.Node {
ripSlice := rip.AsSlice()
var bn []*enode.Node
var nodes []*enode.Node
var processed = make(map[uint]struct{})
@ -863,7 +861,7 @@ func (t *UDPv5) collectTableNodes(rip netip.Addr, distances []uint, limit int) [
for _, n := range t.tab.appendLiveNodes(dist, bn[:0]) {
// Apply some pre-checks to avoid sending invalid nodes.
// Note liveness is checked by appendLiveNodes.
if netutil.CheckRelayIP(ripSlice, n.IP()) != nil {
if netutil.CheckRelayAddr(rip, n.IPAddr()) != nil {
continue
}
nodes = append(nodes, n)

View File

@ -606,7 +606,7 @@ func (n *handshakeTestNode) n() *enode.Node {
}
func (n *handshakeTestNode) addr() string {
return n.ln.Node().IP().String()
return n.ln.Node().IPAddr().String()
}
func (n *handshakeTestNode) id() enode.ID {

View File

@ -20,8 +20,8 @@ import (
"crypto/ecdsa"
"fmt"
"net"
"net/netip"
"reflect"
"strconv"
"sync"
"sync/atomic"
"time"
@ -175,8 +175,8 @@ func (ln *LocalNode) delete(e enr.Entry) {
}
}
func (ln *LocalNode) endpointForIP(ip net.IP) *lnEndpoint {
if ip.To4() != nil {
func (ln *LocalNode) endpointForIP(ip netip.Addr) *lnEndpoint {
if ip.Is4() {
return &ln.endpoint4
}
return &ln.endpoint6
@ -188,7 +188,7 @@ func (ln *LocalNode) SetStaticIP(ip net.IP) {
ln.mu.Lock()
defer ln.mu.Unlock()
ln.endpointForIP(ip).staticIP = ip
ln.endpointForIP(netutil.IPToAddr(ip)).staticIP = ip
ln.updateEndpoints()
}
@ -198,7 +198,7 @@ func (ln *LocalNode) SetFallbackIP(ip net.IP) {
ln.mu.Lock()
defer ln.mu.Unlock()
ln.endpointForIP(ip).fallbackIP = ip
ln.endpointForIP(netutil.IPToAddr(ip)).fallbackIP = ip
ln.updateEndpoints()
}
@ -215,21 +215,21 @@ func (ln *LocalNode) SetFallbackUDP(port int) {
// UDPEndpointStatement should be called whenever a statement about the local node's
// UDP endpoint is received. It feeds the local endpoint predictor.
func (ln *LocalNode) UDPEndpointStatement(fromaddr, endpoint *net.UDPAddr) {
func (ln *LocalNode) UDPEndpointStatement(fromaddr, endpoint netip.AddrPort) {
ln.mu.Lock()
defer ln.mu.Unlock()
ln.endpointForIP(endpoint.IP).track.AddStatement(fromaddr.String(), endpoint.String())
ln.endpointForIP(endpoint.Addr()).track.AddStatement(fromaddr.Addr(), endpoint)
ln.updateEndpoints()
}
// UDPContact should be called whenever the local node has announced itself to another node
// via UDP. It feeds the local endpoint predictor.
func (ln *LocalNode) UDPContact(toaddr *net.UDPAddr) {
func (ln *LocalNode) UDPContact(toaddr netip.AddrPort) {
ln.mu.Lock()
defer ln.mu.Unlock()
ln.endpointForIP(toaddr.IP).track.AddContact(toaddr.String())
ln.endpointForIP(toaddr.Addr()).track.AddContact(toaddr.Addr())
ln.updateEndpoints()
}
@ -268,29 +268,13 @@ func (e *lnEndpoint) get() (newIP net.IP, newPort uint16) {
}
if e.staticIP != nil {
newIP = e.staticIP
} else if ip, port := predictAddr(e.track); ip != nil {
newIP = ip
newPort = port
} else if ap := e.track.PredictEndpoint(); ap.IsValid() {
newIP = ap.Addr().AsSlice()
newPort = ap.Port()
}
return newIP, newPort
}
// predictAddr wraps IPTracker.PredictEndpoint, converting from its string-based
// endpoint representation to IP and port types.
func predictAddr(t *netutil.IPTracker) (net.IP, uint16) {
ep := t.PredictEndpoint()
if ep == "" {
return nil, 0
}
ipString, portString, _ := net.SplitHostPort(ep)
ip := net.ParseIP(ipString)
port, err := strconv.ParseUint(portString, 10, 16)
if err != nil {
return nil, 0
}
return ip, uint16(port)
}
func (ln *LocalNode) invalidate() {
ln.cur.Store((*Node)(nil))
}
@ -314,7 +298,7 @@ func (ln *LocalNode) sign() {
panic(fmt.Errorf("enode: can't verify local record: %v", err))
}
ln.cur.Store(n)
log.Info("New local node record", "seq", ln.seq, "id", n.ID(), "ip", n.IP(), "udp", n.UDP(), "tcp", n.TCP())
log.Info("New local node record", "seq", ln.seq, "id", n.ID(), "ip", n.IPAddr(), "udp", n.UDP(), "tcp", n.TCP())
}
func (ln *LocalNode) bumpSeq() {

View File

@ -17,12 +17,14 @@
package enode
import (
"crypto/rand"
"math/rand"
"net"
"net/netip"
"testing"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/p2p/enr"
"github.com/ethereum/go-ethereum/p2p/netutil"
"github.com/stretchr/testify/assert"
)
@ -88,6 +90,7 @@ func TestLocalNodeSeqPersist(t *testing.T) {
// This test checks behavior of the endpoint predictor.
func TestLocalNodeEndpoint(t *testing.T) {
var (
rng = rand.New(rand.NewSource(4))
fallback = &net.UDPAddr{IP: net.IP{127, 0, 0, 1}, Port: 80}
predicted = &net.UDPAddr{IP: net.IP{127, 0, 1, 2}, Port: 81}
staticIP = net.IP{127, 0, 1, 2}
@ -96,6 +99,7 @@ func TestLocalNodeEndpoint(t *testing.T) {
defer db.Close()
// Nothing is set initially.
assert.Equal(t, netip.Addr{}, ln.Node().IPAddr())
assert.Equal(t, net.IP(nil), ln.Node().IP())
assert.Equal(t, 0, ln.Node().UDP())
initialSeq := ln.Node().Seq()
@ -103,26 +107,30 @@ func TestLocalNodeEndpoint(t *testing.T) {
// Set up fallback address.
ln.SetFallbackIP(fallback.IP)
ln.SetFallbackUDP(fallback.Port)
assert.Equal(t, netutil.IPToAddr(fallback.IP), ln.Node().IPAddr())
assert.Equal(t, fallback.IP, ln.Node().IP())
assert.Equal(t, fallback.Port, ln.Node().UDP())
assert.Equal(t, initialSeq+1, ln.Node().Seq())
// Add endpoint statements from random hosts.
for i := 0; i < iptrackMinStatements; i++ {
assert.Equal(t, netutil.IPToAddr(fallback.IP), ln.Node().IPAddr())
assert.Equal(t, fallback.IP, ln.Node().IP())
assert.Equal(t, fallback.Port, ln.Node().UDP())
assert.Equal(t, initialSeq+1, ln.Node().Seq())
from := &net.UDPAddr{IP: make(net.IP, 4), Port: 90}
rand.Read(from.IP)
ln.UDPEndpointStatement(from, predicted)
from := netip.AddrPortFrom(netutil.RandomAddr(rng, true), 9000)
endpoint := netip.AddrPortFrom(netutil.IPToAddr(predicted.IP), uint16(predicted.Port))
ln.UDPEndpointStatement(from, endpoint)
}
assert.Equal(t, netutil.IPToAddr(predicted.IP), ln.Node().IPAddr())
assert.Equal(t, predicted.IP, ln.Node().IP())
assert.Equal(t, predicted.Port, ln.Node().UDP())
assert.Equal(t, initialSeq+2, ln.Node().Seq())
// Static IP overrides prediction.
ln.SetStaticIP(staticIP)
assert.Equal(t, netutil.IPToAddr(staticIP), ln.Node().IPAddr())
assert.Equal(t, staticIP, ln.Node().IP())
assert.Equal(t, fallback.Port, ln.Node().UDP())
assert.Equal(t, initialSeq+3, ln.Node().Seq())

View File

@ -21,7 +21,7 @@ import (
"crypto/rand"
"encoding/binary"
"fmt"
"net"
"net/netip"
"os"
"sync"
"time"
@ -66,7 +66,7 @@ var (
errInvalidIP = errors.New("invalid IP")
)
var zeroIP = make(net.IP, 16)
var zeroIP = netip.IPv6Unspecified()
// DB is the node database, storing previously seen nodes and any collected metadata about
// them for QoS purposes.
@ -151,39 +151,37 @@ func splitNodeKey(key []byte) (id ID, rest []byte) {
}
// nodeItemKey returns the database key for a node metadata field.
func nodeItemKey(id ID, ip net.IP, field string) []byte {
ip16 := ip.To16()
if ip16 == nil {
panic(fmt.Errorf("invalid IP (length %d)", len(ip)))
func nodeItemKey(id ID, ip netip.Addr, field string) []byte {
if !ip.IsValid() {
panic("invalid IP")
}
return bytes.Join([][]byte{nodeKey(id), ip16, []byte(field)}, []byte{':'})
ip16 := ip.As16()
return bytes.Join([][]byte{nodeKey(id), ip16[:], []byte(field)}, []byte{':'})
}
// splitNodeItemKey returns the components of a key created by nodeItemKey.
func splitNodeItemKey(key []byte) (id ID, ip net.IP, field string) {
func splitNodeItemKey(key []byte) (id ID, ip netip.Addr, field string) {
id, key = splitNodeKey(key)
// Skip discover root.
if string(key) == dbDiscoverRoot {
return id, nil, ""
return id, netip.Addr{}, ""
}
key = key[len(dbDiscoverRoot)+1:]
// Split out the IP.
ip = key[:16]
if ip4 := ip.To4(); ip4 != nil {
ip = ip4
}
ip, _ = netip.AddrFromSlice(key[:16])
key = key[16+1:]
// Field is the remainder of key.
field = string(key)
return id, ip, field
}
func v5Key(id ID, ip net.IP, field string) []byte {
func v5Key(id ID, ip netip.Addr, field string) []byte {
ip16 := ip.As16()
return bytes.Join([][]byte{
[]byte(dbNodePrefix),
id[:],
[]byte(dbDiscv5Root),
ip.To16(),
ip16[:],
[]byte(field),
}, []byte{':'})
}
@ -364,24 +362,24 @@ func (db *DB) expireNodes() {
// LastPingReceived retrieves the time of the last ping packet received from
// a remote node.
func (db *DB) LastPingReceived(id ID, ip net.IP) time.Time {
if ip = ip.To16(); ip == nil {
func (db *DB) LastPingReceived(id ID, ip netip.Addr) time.Time {
if !ip.IsValid() {
return time.Time{}
}
return time.Unix(db.fetchInt64(nodeItemKey(id, ip, dbNodePing)), 0)
}
// UpdateLastPingReceived updates the last time we tried contacting a remote node.
func (db *DB) UpdateLastPingReceived(id ID, ip net.IP, instance time.Time) error {
if ip = ip.To16(); ip == nil {
func (db *DB) UpdateLastPingReceived(id ID, ip netip.Addr, instance time.Time) error {
if !ip.IsValid() {
return errInvalidIP
}
return db.storeInt64(nodeItemKey(id, ip, dbNodePing), instance.Unix())
}
// LastPongReceived retrieves the time of the last successful pong from remote node.
func (db *DB) LastPongReceived(id ID, ip net.IP) time.Time {
if ip = ip.To16(); ip == nil {
func (db *DB) LastPongReceived(id ID, ip netip.Addr) time.Time {
if !ip.IsValid() {
return time.Time{}
}
// Launch expirer
@ -390,40 +388,40 @@ func (db *DB) LastPongReceived(id ID, ip net.IP) time.Time {
}
// UpdateLastPongReceived updates the last pong time of a node.
func (db *DB) UpdateLastPongReceived(id ID, ip net.IP, instance time.Time) error {
if ip = ip.To16(); ip == nil {
func (db *DB) UpdateLastPongReceived(id ID, ip netip.Addr, instance time.Time) error {
if !ip.IsValid() {
return errInvalidIP
}
return db.storeInt64(nodeItemKey(id, ip, dbNodePong), instance.Unix())
}
// FindFails retrieves the number of findnode failures since bonding.
func (db *DB) FindFails(id ID, ip net.IP) int {
if ip = ip.To16(); ip == nil {
func (db *DB) FindFails(id ID, ip netip.Addr) int {
if !ip.IsValid() {
return 0
}
return int(db.fetchInt64(nodeItemKey(id, ip, dbNodeFindFails)))
}
// UpdateFindFails updates the number of findnode failures since bonding.
func (db *DB) UpdateFindFails(id ID, ip net.IP, fails int) error {
if ip = ip.To16(); ip == nil {
func (db *DB) UpdateFindFails(id ID, ip netip.Addr, fails int) error {
if !ip.IsValid() {
return errInvalidIP
}
return db.storeInt64(nodeItemKey(id, ip, dbNodeFindFails), int64(fails))
}
// FindFailsV5 retrieves the discv5 findnode failure counter.
func (db *DB) FindFailsV5(id ID, ip net.IP) int {
if ip = ip.To16(); ip == nil {
func (db *DB) FindFailsV5(id ID, ip netip.Addr) int {
if !ip.IsValid() {
return 0
}
return int(db.fetchInt64(v5Key(id, ip, dbNodeFindFails)))
}
// UpdateFindFailsV5 stores the discv5 findnode failure counter.
func (db *DB) UpdateFindFailsV5(id ID, ip net.IP, fails int) error {
if ip = ip.To16(); ip == nil {
func (db *DB) UpdateFindFailsV5(id ID, ip netip.Addr, fails int) error {
if !ip.IsValid() {
return errInvalidIP
}
return db.storeInt64(v5Key(id, ip, dbNodeFindFails), int64(fails))
@ -470,7 +468,7 @@ seek:
id[0] = 0
continue seek // iterator exhausted
}
if now.Sub(db.LastPongReceived(n.ID(), n.IP())) > maxAge {
if now.Sub(db.LastPongReceived(n.ID(), n.IPAddr())) > maxAge {
continue seek
}
for i := range nodes {

View File

@ -20,6 +20,7 @@ import (
"bytes"
"fmt"
"net"
"net/netip"
"path/filepath"
"reflect"
"testing"
@ -48,8 +49,10 @@ func TestDBNodeKey(t *testing.T) {
}
func TestDBNodeItemKey(t *testing.T) {
wantIP := net.IP{127, 0, 0, 3}
wantIP := netip.MustParseAddr("127.0.0.3")
wantIP4in6 := netip.AddrFrom16(wantIP.As16())
wantField := "foobar"
enc := nodeItemKey(keytestID, wantIP, wantField)
want := []byte{
'n', ':',
@ -69,7 +72,7 @@ func TestDBNodeItemKey(t *testing.T) {
if id != keytestID {
t.Errorf("splitNodeItemKey returned wrong ID: %v", id)
}
if !ip.Equal(wantIP) {
if ip != wantIP4in6 {
t.Errorf("splitNodeItemKey returned wrong IP: %v", ip)
}
if field != wantField {
@ -123,33 +126,33 @@ func TestDBFetchStore(t *testing.T) {
defer db.Close()
// Check fetch/store operations on a node ping object
if stored := db.LastPingReceived(node.ID(), node.IP()); stored.Unix() != 0 {
if stored := db.LastPingReceived(node.ID(), node.IPAddr()); stored.Unix() != 0 {
t.Errorf("ping: non-existing object: %v", stored)
}
if err := db.UpdateLastPingReceived(node.ID(), node.IP(), inst); err != nil {
if err := db.UpdateLastPingReceived(node.ID(), node.IPAddr(), inst); err != nil {
t.Errorf("ping: failed to update: %v", err)
}
if stored := db.LastPingReceived(node.ID(), node.IP()); stored.Unix() != inst.Unix() {
if stored := db.LastPingReceived(node.ID(), node.IPAddr()); stored.Unix() != inst.Unix() {
t.Errorf("ping: value mismatch: have %v, want %v", stored, inst)
}
// Check fetch/store operations on a node pong object
if stored := db.LastPongReceived(node.ID(), node.IP()); stored.Unix() != 0 {
if stored := db.LastPongReceived(node.ID(), node.IPAddr()); stored.Unix() != 0 {
t.Errorf("pong: non-existing object: %v", stored)
}
if err := db.UpdateLastPongReceived(node.ID(), node.IP(), inst); err != nil {
if err := db.UpdateLastPongReceived(node.ID(), node.IPAddr(), inst); err != nil {
t.Errorf("pong: failed to update: %v", err)
}
if stored := db.LastPongReceived(node.ID(), node.IP()); stored.Unix() != inst.Unix() {
if stored := db.LastPongReceived(node.ID(), node.IPAddr()); stored.Unix() != inst.Unix() {
t.Errorf("pong: value mismatch: have %v, want %v", stored, inst)
}
// Check fetch/store operations on a node findnode-failure object
if stored := db.FindFails(node.ID(), node.IP()); stored != 0 {
if stored := db.FindFails(node.ID(), node.IPAddr()); stored != 0 {
t.Errorf("find-node fails: non-existing object: %v", stored)
}
if err := db.UpdateFindFails(node.ID(), node.IP(), num); err != nil {
if err := db.UpdateFindFails(node.ID(), node.IPAddr(), num); err != nil {
t.Errorf("find-node fails: failed to update: %v", err)
}
if stored := db.FindFails(node.ID(), node.IP()); stored != num {
if stored := db.FindFails(node.ID(), node.IPAddr()); stored != num {
t.Errorf("find-node fails: value mismatch: have %v, want %v", stored, num)
}
// Check fetch/store operations on an actual node object
@ -266,7 +269,7 @@ func testSeedQuery() error {
if err := db.UpdateNode(seed.node); err != nil {
return fmt.Errorf("node %d: failed to insert: %v", i, err)
}
if err := db.UpdateLastPongReceived(seed.node.ID(), seed.node.IP(), seed.pong); err != nil {
if err := db.UpdateLastPongReceived(seed.node.ID(), seed.node.IPAddr(), seed.pong); err != nil {
return fmt.Errorf("node %d: failed to insert bondTime: %v", i, err)
}
}
@ -427,7 +430,7 @@ func TestDBExpiration(t *testing.T) {
t.Fatalf("node %d: failed to insert: %v", i, err)
}
}
if err := db.UpdateLastPongReceived(seed.node.ID(), seed.node.IP(), seed.pong); err != nil {
if err := db.UpdateLastPongReceived(seed.node.ID(), seed.node.IPAddr(), seed.pong); err != nil {
t.Fatalf("node %d: failed to update bondTime: %v", i, err)
}
}
@ -438,13 +441,13 @@ func TestDBExpiration(t *testing.T) {
unixZeroTime := time.Unix(0, 0)
for i, seed := range nodeDBExpirationNodes {
node := db.Node(seed.node.ID())
pong := db.LastPongReceived(seed.node.ID(), seed.node.IP())
pong := db.LastPongReceived(seed.node.ID(), seed.node.IPAddr())
if seed.exp {
if seed.storeNode && node != nil {
t.Errorf("node %d (%s) shouldn't be present after expiration", i, seed.node.ID().TerminalString())
}
if !pong.Equal(unixZeroTime) {
t.Errorf("pong time %d (%s %v) shouldn't be present after expiration", i, seed.node.ID().TerminalString(), seed.node.IP())
t.Errorf("pong time %d (%s %v) shouldn't be present after expiration", i, seed.node.ID().TerminalString(), seed.node.IPAddr())
}
} else {
if seed.storeNode && node == nil {
@ -463,7 +466,7 @@ func TestDBExpireV5(t *testing.T) {
db, _ := OpenDB("")
defer db.Close()
ip := net.IP{127, 0, 0, 1}
ip := netip.MustParseAddr("127.0.0.1")
db.UpdateFindFailsV5(ID{}, ip, 4)
db.expireNodes()
}

View File

@ -16,18 +16,53 @@
package netutil
import "net"
import (
"fmt"
"math/rand"
"net"
"net/netip"
)
// AddrIP gets the IP address contained in addr. It returns nil if no address is present.
func AddrIP(addr net.Addr) net.IP {
// AddrAddr gets the IP address contained in addr. The result will be invalid if the
// address type is unsupported.
func AddrAddr(addr net.Addr) netip.Addr {
switch a := addr.(type) {
case *net.IPAddr:
return a.IP
return IPToAddr(a.IP)
case *net.TCPAddr:
return a.IP
return IPToAddr(a.IP)
case *net.UDPAddr:
return a.IP
return IPToAddr(a.IP)
default:
return nil
return netip.Addr{}
}
}
// IPToAddr converts net.IP to netip.Addr. Note that unlike netip.AddrFromSlice, this
// function will always ensure that the resulting Addr is IPv4 when the input is.
func IPToAddr(ip net.IP) netip.Addr {
if ip4 := ip.To4(); ip4 != nil {
addr, _ := netip.AddrFromSlice(ip4)
return addr
} else if ip6 := ip.To16(); ip6 != nil {
addr, _ := netip.AddrFromSlice(ip6)
return addr
}
return netip.Addr{}
}
// RandomAddr creates a random IP address.
func RandomAddr(rng *rand.Rand, ipv4 bool) netip.Addr {
var bytes []byte
if ipv4 || rng.Intn(2) == 0 {
bytes = make([]byte, 4)
} else {
bytes = make([]byte, 16)
}
rng.Read(bytes)
addr, ok := netip.AddrFromSlice(bytes)
if !ok {
panic(fmt.Errorf("BUG! invalid IP %v", bytes))
}
return addr
}

View File

@ -17,6 +17,7 @@
package netutil
import (
"net/netip"
"time"
"github.com/ethereum/go-ethereum/common/mclock"
@ -29,14 +30,14 @@ type IPTracker struct {
contactWindow time.Duration
minStatements int
clock mclock.Clock
statements map[string]ipStatement
contact map[string]mclock.AbsTime
statements map[netip.Addr]ipStatement
contact map[netip.Addr]mclock.AbsTime
lastStatementGC mclock.AbsTime
lastContactGC mclock.AbsTime
}
type ipStatement struct {
endpoint string
endpoint netip.AddrPort
time mclock.AbsTime
}
@ -51,9 +52,9 @@ func NewIPTracker(window, contactWindow time.Duration, minStatements int) *IPTra
return &IPTracker{
window: window,
contactWindow: contactWindow,
statements: make(map[string]ipStatement),
statements: make(map[netip.Addr]ipStatement),
minStatements: minStatements,
contact: make(map[string]mclock.AbsTime),
contact: make(map[netip.Addr]mclock.AbsTime),
clock: mclock.System{},
}
}
@ -74,12 +75,15 @@ func (it *IPTracker) PredictFullConeNAT() bool {
}
// PredictEndpoint returns the current prediction of the external endpoint.
func (it *IPTracker) PredictEndpoint() string {
func (it *IPTracker) PredictEndpoint() netip.AddrPort {
it.gcStatements(it.clock.Now())
// The current strategy is simple: find the endpoint with most statements.
counts := make(map[string]int, len(it.statements))
maxcount, max := 0, ""
var (
counts = make(map[netip.AddrPort]int, len(it.statements))
maxcount int
max netip.AddrPort
)
for _, s := range it.statements {
c := counts[s.endpoint] + 1
counts[s.endpoint] = c
@ -91,7 +95,7 @@ func (it *IPTracker) PredictEndpoint() string {
}
// AddStatement records that a certain host thinks our external endpoint is the one given.
func (it *IPTracker) AddStatement(host, endpoint string) {
func (it *IPTracker) AddStatement(host netip.Addr, endpoint netip.AddrPort) {
now := it.clock.Now()
it.statements[host] = ipStatement{endpoint, now}
if time.Duration(now-it.lastStatementGC) >= it.window {
@ -101,7 +105,7 @@ func (it *IPTracker) AddStatement(host, endpoint string) {
// AddContact records that a packet containing our endpoint information has been sent to a
// certain host.
func (it *IPTracker) AddContact(host string) {
func (it *IPTracker) AddContact(host netip.Addr) {
now := it.clock.Now()
it.contact[host] = now
if time.Duration(now-it.lastContactGC) >= it.contactWindow {

View File

@ -19,6 +19,7 @@ package netutil
import (
crand "crypto/rand"
"fmt"
"net/netip"
"testing"
"time"
@ -42,37 +43,37 @@ func TestIPTracker(t *testing.T) {
tests := map[string][]iptrackTestEvent{
"minStatements": {
{opPredict, 0, "", ""},
{opStatement, 0, "127.0.0.1", "127.0.0.2"},
{opStatement, 0, "127.0.0.1:8000", "127.0.0.2"},
{opPredict, 1000, "", ""},
{opStatement, 1000, "127.0.0.1", "127.0.0.3"},
{opStatement, 1000, "127.0.0.1:8000", "127.0.0.3"},
{opPredict, 1000, "", ""},
{opStatement, 1000, "127.0.0.1", "127.0.0.4"},
{opPredict, 1000, "127.0.0.1", ""},
{opStatement, 1000, "127.0.0.1:8000", "127.0.0.4"},
{opPredict, 1000, "127.0.0.1:8000", ""},
},
"window": {
{opStatement, 0, "127.0.0.1", "127.0.0.2"},
{opStatement, 2000, "127.0.0.1", "127.0.0.3"},
{opStatement, 3000, "127.0.0.1", "127.0.0.4"},
{opPredict, 10000, "127.0.0.1", ""},
{opStatement, 0, "127.0.0.1:8000", "127.0.0.2"},
{opStatement, 2000, "127.0.0.1:8000", "127.0.0.3"},
{opStatement, 3000, "127.0.0.1:8000", "127.0.0.4"},
{opPredict, 10000, "127.0.0.1:8000", ""},
{opPredict, 10001, "", ""}, // first statement expired
{opStatement, 10100, "127.0.0.1", "127.0.0.2"},
{opPredict, 10200, "127.0.0.1", ""},
{opStatement, 10100, "127.0.0.1:8000", "127.0.0.2"},
{opPredict, 10200, "127.0.0.1:8000", ""},
},
"fullcone": {
{opContact, 0, "", "127.0.0.2"},
{opStatement, 10, "127.0.0.1", "127.0.0.2"},
{opStatement, 10, "127.0.0.1:8000", "127.0.0.2"},
{opContact, 2000, "", "127.0.0.3"},
{opStatement, 2010, "127.0.0.1", "127.0.0.3"},
{opStatement, 2010, "127.0.0.1:8000", "127.0.0.3"},
{opContact, 3000, "", "127.0.0.4"},
{opStatement, 3010, "127.0.0.1", "127.0.0.4"},
{opStatement, 3010, "127.0.0.1:8000", "127.0.0.4"},
{opCheckFullCone, 3500, "false", ""},
},
"fullcone_2": {
{opContact, 0, "", "127.0.0.2"},
{opStatement, 10, "127.0.0.1", "127.0.0.2"},
{opStatement, 10, "127.0.0.1:8000", "127.0.0.2"},
{opContact, 2000, "", "127.0.0.3"},
{opStatement, 2010, "127.0.0.1", "127.0.0.3"},
{opStatement, 3000, "127.0.0.1", "127.0.0.4"},
{opStatement, 2010, "127.0.0.1:8000", "127.0.0.3"},
{opStatement, 3000, "127.0.0.1:8000", "127.0.0.4"},
{opContact, 3010, "", "127.0.0.4"},
{opCheckFullCone, 3500, "true", ""},
},
@ -93,12 +94,19 @@ func runIPTrackerTest(t *testing.T, evs []iptrackTestEvent) {
clock.Run(evtime - time.Duration(clock.Now()))
switch ev.op {
case opStatement:
it.AddStatement(ev.from, ev.ip)
it.AddStatement(netip.MustParseAddr(ev.from), netip.MustParseAddrPort(ev.ip))
case opContact:
it.AddContact(ev.from)
it.AddContact(netip.MustParseAddr(ev.from))
case opPredict:
if pred := it.PredictEndpoint(); pred != ev.ip {
t.Errorf("op %d: wrong prediction %q, want %q", i, pred, ev.ip)
pred := it.PredictEndpoint()
if ev.ip == "" {
if pred.IsValid() {
t.Errorf("op %d: wrong prediction %v, expected invalid", i, pred)
}
} else {
if pred != netip.MustParseAddrPort(ev.ip) {
t.Errorf("op %d: wrong prediction %v, want %q", i, pred, ev.ip)
}
}
case opCheckFullCone:
pred := fmt.Sprintf("%t", it.PredictFullConeNAT())
@ -121,12 +129,11 @@ func TestIPTrackerForceGC(t *testing.T) {
it.clock = &clock
for i := 0; i < 5*max; i++ {
e1 := make([]byte, 4)
e2 := make([]byte, 4)
crand.Read(e1)
crand.Read(e2)
it.AddStatement(string(e1), string(e2))
it.AddContact(string(e1))
var e1, e2 [4]byte
crand.Read(e1[:])
crand.Read(e2[:])
it.AddStatement(netip.AddrFrom4(e1), netip.AddrPortFrom(netip.AddrFrom4(e2), 9000))
it.AddContact(netip.AddrFrom4(e1))
clock.Run(rate)
}
if len(it.contact) > 2*max {

View File

@ -22,21 +22,19 @@ import (
"errors"
"fmt"
"net"
"sort"
"net/netip"
"slices"
"strings"
"golang.org/x/exp/maps"
)
var lan4, lan6, special4, special6 Netlist
var special4, special6 Netlist
func init() {
// Lists from RFC 5735, RFC 5156,
// https://www.iana.org/assignments/iana-ipv4-special-registry/
lan4.Add("0.0.0.0/8") // "This" network
lan4.Add("10.0.0.0/8") // Private Use
lan4.Add("172.16.0.0/12") // Private Use
lan4.Add("192.168.0.0/16") // Private Use
lan6.Add("fe80::/10") // Link-Local
lan6.Add("fc00::/7") // Unique-Local
special4.Add("0.0.0.0/8") // "This" network.
special4.Add("192.0.0.0/29") // IPv4 Service Continuity
special4.Add("192.0.0.9/32") // PCP Anycast
special4.Add("192.0.0.170/32") // NAT64/DNS64 Discovery
@ -66,7 +64,7 @@ func init() {
}
// Netlist is a list of IP networks.
type Netlist []net.IPNet
type Netlist []netip.Prefix
// ParseNetlist parses a comma-separated list of CIDR masks.
// Whitespace and extra commas are ignored.
@ -78,11 +76,11 @@ func ParseNetlist(s string) (*Netlist, error) {
if mask == "" {
continue
}
_, n, err := net.ParseCIDR(mask)
prefix, err := netip.ParsePrefix(mask)
if err != nil {
return nil, err
}
l = append(l, *n)
l = append(l, prefix)
}
return &l, nil
}
@ -103,11 +101,11 @@ func (l *Netlist) UnmarshalTOML(fn func(interface{}) error) error {
return err
}
for _, mask := range masks {
_, n, err := net.ParseCIDR(mask)
prefix, err := netip.ParsePrefix(mask)
if err != nil {
return err
}
*l = append(*l, *n)
*l = append(*l, prefix)
}
return nil
}
@ -115,15 +113,20 @@ func (l *Netlist) UnmarshalTOML(fn func(interface{}) error) error {
// Add parses a CIDR mask and appends it to the list. It panics for invalid masks and is
// intended to be used for setting up static lists.
func (l *Netlist) Add(cidr string) {
_, n, err := net.ParseCIDR(cidr)
prefix, err := netip.ParsePrefix(cidr)
if err != nil {
panic(err)
}
*l = append(*l, *n)
*l = append(*l, prefix)
}
// Contains reports whether the given IP is contained in the list.
func (l *Netlist) Contains(ip net.IP) bool {
return l.ContainsAddr(IPToAddr(ip))
}
// ContainsAddr reports whether the given IP is contained in the list.
func (l *Netlist) ContainsAddr(ip netip.Addr) bool {
if l == nil {
return false
}
@ -137,25 +140,39 @@ func (l *Netlist) Contains(ip net.IP) bool {
// IsLAN reports whether an IP is a local network address.
func IsLAN(ip net.IP) bool {
return AddrIsLAN(IPToAddr(ip))
}
// AddrIsLAN reports whether an IP is a local network address.
func AddrIsLAN(ip netip.Addr) bool {
if ip.Is4In6() {
ip = netip.AddrFrom4(ip.As4())
}
if ip.IsLoopback() {
return true
}
if v4 := ip.To4(); v4 != nil {
return lan4.Contains(v4)
}
return lan6.Contains(ip)
return ip.IsPrivate() || ip.IsLinkLocalUnicast()
}
// IsSpecialNetwork reports whether an IP is located in a special-use network range
// This includes broadcast, multicast and documentation addresses.
func IsSpecialNetwork(ip net.IP) bool {
return AddrIsSpecialNetwork(IPToAddr(ip))
}
// AddrIsSpecialNetwork reports whether an IP is located in a special-use network range
// This includes broadcast, multicast and documentation addresses.
func AddrIsSpecialNetwork(ip netip.Addr) bool {
if ip.Is4In6() {
ip = netip.AddrFrom4(ip.As4())
}
if ip.IsMulticast() {
return true
}
if v4 := ip.To4(); v4 != nil {
return special4.Contains(v4)
if ip.Is4() {
return special4.ContainsAddr(ip)
}
return special6.Contains(ip)
return special6.ContainsAddr(ip)
}
var (
@ -175,19 +192,31 @@ var (
// - LAN addresses are OK if relayed by a LAN host.
// - All other addresses are always acceptable.
func CheckRelayIP(sender, addr net.IP) error {
if len(addr) != net.IPv4len && len(addr) != net.IPv6len {
return CheckRelayAddr(IPToAddr(sender), IPToAddr(addr))
}
// CheckRelayAddr reports whether an IP relayed from the given sender IP
// is a valid connection target.
//
// There are four rules:
// - Special network addresses are never valid.
// - Loopback addresses are OK if relayed by a loopback host.
// - LAN addresses are OK if relayed by a LAN host.
// - All other addresses are always acceptable.
func CheckRelayAddr(sender, addr netip.Addr) error {
if !addr.IsValid() {
return errInvalid
}
if addr.IsUnspecified() {
return errUnspecified
}
if IsSpecialNetwork(addr) {
if AddrIsSpecialNetwork(addr) {
return errSpecial
}
if addr.IsLoopback() && !sender.IsLoopback() {
return errLoopback
}
if IsLAN(addr) && !IsLAN(sender) {
if AddrIsLAN(addr) && !AddrIsLAN(sender) {
return errLAN
}
return nil
@ -221,17 +250,22 @@ type DistinctNetSet struct {
Subnet uint // number of common prefix bits
Limit uint // maximum number of IPs in each subnet
members map[string]uint
buf net.IP
members map[netip.Prefix]uint
}
// Add adds an IP address to the set. It returns false (and doesn't add the IP) if the
// number of existing IPs in the defined range exceeds the limit.
func (s *DistinctNetSet) Add(ip net.IP) bool {
return s.AddAddr(IPToAddr(ip))
}
// AddAddr adds an IP address to the set. It returns false (and doesn't add the IP) if the
// number of existing IPs in the defined range exceeds the limit.
func (s *DistinctNetSet) AddAddr(ip netip.Addr) bool {
key := s.key(ip)
n := s.members[string(key)]
n := s.members[key]
if n < s.Limit {
s.members[string(key)] = n + 1
s.members[key] = n + 1
return true
}
return false
@ -239,20 +273,30 @@ func (s *DistinctNetSet) Add(ip net.IP) bool {
// Remove removes an IP from the set.
func (s *DistinctNetSet) Remove(ip net.IP) {
s.RemoveAddr(IPToAddr(ip))
}
// RemoveAddr removes an IP from the set.
func (s *DistinctNetSet) RemoveAddr(ip netip.Addr) {
key := s.key(ip)
if n, ok := s.members[string(key)]; ok {
if n, ok := s.members[key]; ok {
if n == 1 {
delete(s.members, string(key))
delete(s.members, key)
} else {
s.members[string(key)] = n - 1
s.members[key] = n - 1
}
}
}
// Contains whether the given IP is contained in the set.
func (s DistinctNetSet) Contains(ip net.IP) bool {
return s.ContainsAddr(IPToAddr(ip))
}
// ContainsAddr whether the given IP is contained in the set.
func (s DistinctNetSet) ContainsAddr(ip netip.Addr) bool {
key := s.key(ip)
_, ok := s.members[string(key)]
_, ok := s.members[key]
return ok
}
@ -265,54 +309,30 @@ func (s DistinctNetSet) Len() int {
return int(n)
}
// key encodes the map key for an address into a temporary buffer.
//
// The first byte of key is '4' or '6' to distinguish IPv4/IPv6 address types.
// The remainder of the key is the IP, truncated to the number of bits.
func (s *DistinctNetSet) key(ip net.IP) net.IP {
// key returns the map key for ip.
func (s *DistinctNetSet) key(ip netip.Addr) netip.Prefix {
// Lazily initialize storage.
if s.members == nil {
s.members = make(map[string]uint)
s.buf = make(net.IP, 17)
s.members = make(map[netip.Prefix]uint)
}
// Canonicalize ip and bits.
typ := byte('6')
if ip4 := ip.To4(); ip4 != nil {
typ, ip = '4', ip4
p, err := ip.Prefix(int(s.Subnet))
if err != nil {
panic(err)
}
bits := s.Subnet
if bits > uint(len(ip)*8) {
bits = uint(len(ip) * 8)
}
// Encode the prefix into s.buf.
nb := int(bits / 8)
mask := ^byte(0xFF >> (bits % 8))
s.buf[0] = typ
buf := append(s.buf[:1], ip[:nb]...)
if nb < len(ip) && mask != 0 {
buf = append(buf, ip[nb]&mask)
}
return buf
return p
}
// String implements fmt.Stringer
func (s DistinctNetSet) String() string {
keys := maps.Keys(s.members)
slices.SortFunc(keys, func(a, b netip.Prefix) int {
return strings.Compare(a.String(), b.String())
})
var buf bytes.Buffer
buf.WriteString("{")
keys := make([]string, 0, len(s.members))
for k := range s.members {
keys = append(keys, k)
}
sort.Strings(keys)
for i, k := range keys {
var ip net.IP
if k[0] == '4' {
ip = make(net.IP, 4)
} else {
ip = make(net.IP, 16)
}
copy(ip, k[1:])
fmt.Fprintf(&buf, "%v×%d", ip, s.members[k])
fmt.Fprintf(&buf, "%v×%d", k, s.members[k])
if i != len(keys)-1 {
buf.WriteString(" ")
}

View File

@ -18,7 +18,9 @@ package netutil
import (
"fmt"
"math/rand"
"net"
"net/netip"
"reflect"
"testing"
"testing/quick"
@ -29,7 +31,7 @@ import (
func TestParseNetlist(t *testing.T) {
var tests = []struct {
input string
wantErr error
wantErr string
wantList *Netlist
}{
{
@ -38,25 +40,27 @@ func TestParseNetlist(t *testing.T) {
},
{
input: "127.0.0.0/8",
wantErr: nil,
wantList: &Netlist{{IP: net.IP{127, 0, 0, 0}, Mask: net.CIDRMask(8, 32)}},
wantList: &Netlist{netip.MustParsePrefix("127.0.0.0/8")},
},
{
input: "127.0.0.0/44",
wantErr: &net.ParseError{Type: "CIDR address", Text: "127.0.0.0/44"},
wantErr: `netip.ParsePrefix("127.0.0.0/44"): prefix length out of range`,
},
{
input: "127.0.0.0/16, 23.23.23.23/24,",
wantList: &Netlist{
{IP: net.IP{127, 0, 0, 0}, Mask: net.CIDRMask(16, 32)},
{IP: net.IP{23, 23, 23, 0}, Mask: net.CIDRMask(24, 32)},
netip.MustParsePrefix("127.0.0.0/16"),
netip.MustParsePrefix("23.23.23.23/24"),
},
},
}
for _, test := range tests {
l, err := ParseNetlist(test.input)
if !reflect.DeepEqual(err, test.wantErr) {
if err == nil && test.wantErr != "" {
t.Errorf("%q: got no error, expected %q", test.input, test.wantErr)
continue
} else if err != nil && err.Error() != test.wantErr {
t.Errorf("%q: got error %q, want %q", test.input, err, test.wantErr)
continue
}
@ -70,14 +74,12 @@ func TestParseNetlist(t *testing.T) {
func TestNilNetListContains(t *testing.T) {
var list *Netlist
checkContains(t, list.Contains, nil, []string{"1.2.3.4"})
checkContains(t, list.Contains, list.ContainsAddr, nil, []string{"1.2.3.4"})
}
func TestIsLAN(t *testing.T) {
checkContains(t, IsLAN,
checkContains(t, IsLAN, AddrIsLAN,
[]string{ // included
"0.0.0.0",
"0.2.0.8",
"127.0.0.1",
"10.0.1.1",
"10.22.0.3",
@ -86,25 +88,35 @@ func TestIsLAN(t *testing.T) {
"fe80::f4a1:8eff:fec5:9d9d",
"febf::ab32:2233",
"fc00::4",
// 4-in-6
"::ffff:127.0.0.1",
"::ffff:10.10.0.2",
},
[]string{ // excluded
"192.0.2.1",
"1.0.0.0",
"172.32.0.1",
"fec0::2233",
// 4-in-6
"::ffff:88.99.100.2",
},
)
}
func TestIsSpecialNetwork(t *testing.T) {
checkContains(t, IsSpecialNetwork,
checkContains(t, IsSpecialNetwork, AddrIsSpecialNetwork,
[]string{ // included
"0.0.0.0",
"0.2.0.8",
"192.0.2.1",
"192.0.2.44",
"2001:db8:85a3:8d3:1319:8a2e:370:7348",
"255.255.255.255",
"224.0.0.22", // IPv4 multicast
"ff05::1:3", // IPv6 multicast
// 4-in-6
"::ffff:255.255.255.255",
"::ffff:192.0.2.1",
},
[]string{ // excluded
"192.0.3.1",
@ -115,15 +127,21 @@ func TestIsSpecialNetwork(t *testing.T) {
)
}
func checkContains(t *testing.T, fn func(net.IP) bool, inc, exc []string) {
func checkContains(t *testing.T, fn func(net.IP) bool, fn2 func(netip.Addr) bool, inc, exc []string) {
for _, s := range inc {
if !fn(parseIP(s)) {
t.Error("returned false for included address", s)
t.Error("returned false for included net.IP", s)
}
if !fn2(netip.MustParseAddr(s)) {
t.Error("returned false for included netip.Addr", s)
}
}
for _, s := range exc {
if fn(parseIP(s)) {
t.Error("returned true for excluded address", s)
t.Error("returned true for excluded net.IP", s)
}
if fn2(netip.MustParseAddr(s)) {
t.Error("returned true for excluded netip.Addr", s)
}
}
}
@ -244,14 +262,22 @@ func TestDistinctNetSet(t *testing.T) {
}
func TestDistinctNetSetAddRemove(t *testing.T) {
cfg := &quick.Config{}
fn := func(ips []net.IP) bool {
cfg := &quick.Config{
Values: func(s []reflect.Value, rng *rand.Rand) {
slice := make([]netip.Addr, rng.Intn(20)+1)
for i := range slice {
slice[i] = RandomAddr(rng, false)
}
s[0] = reflect.ValueOf(slice)
},
}
fn := func(ips []netip.Addr) bool {
s := DistinctNetSet{Limit: 3, Subnet: 2}
for _, ip := range ips {
s.Add(ip)
s.AddAddr(ip)
}
for _, ip := range ips {
s.Remove(ip)
s.RemoveAddr(ip)
}
return s.Len() == 0
}

View File

@ -905,14 +905,14 @@ func (srv *Server) listenLoop() {
break
}
remoteIP := netutil.AddrIP(fd.RemoteAddr())
remoteIP := netutil.AddrAddr(fd.RemoteAddr())
if err := srv.checkInboundConn(remoteIP); err != nil {
srv.log.Debug("Rejected inbound connection", "addr", fd.RemoteAddr(), "err", err)
fd.Close()
slots <- struct{}{}
continue
}
if remoteIP != nil {
if remoteIP.IsValid() {
fd = newMeteredConn(fd)
serveMeter.Mark(1)
srv.log.Trace("Accepted connection", "addr", fd.RemoteAddr())
@ -924,18 +924,19 @@ func (srv *Server) listenLoop() {
}
}
func (srv *Server) checkInboundConn(remoteIP net.IP) error {
if remoteIP == nil {
func (srv *Server) checkInboundConn(remoteIP netip.Addr) error {
if !remoteIP.IsValid() {
// This case happens for internal test connections without remote address.
return nil
}
// Reject connections that do not match NetRestrict.
if srv.NetRestrict != nil && !srv.NetRestrict.Contains(remoteIP) {
if srv.NetRestrict != nil && !srv.NetRestrict.ContainsAddr(remoteIP) {
return errors.New("not in netrestrict list")
}
// Reject Internet peers that try too often.
now := srv.clock.Now()
srv.inboundHistory.expire(now, nil)
if !netutil.IsLAN(remoteIP) && srv.inboundHistory.contains(remoteIP.String()) {
if !netutil.AddrIsLAN(remoteIP) && srv.inboundHistory.contains(remoteIP.String()) {
return errors.New("too many attempts")
}
srv.inboundHistory.add(remoteIP.String(), now.Add(inboundThrottleTime))
@ -1108,7 +1109,7 @@ func (srv *Server) NodeInfo() *NodeInfo {
Name: srv.Name,
Enode: node.URLv4(),
ID: node.ID().String(),
IP: node.IP().String(),
IP: node.IPAddr().String(),
ListenAddr: srv.ListenAddr,
Protocols: make(map[string]interface{}),
}

View File

@ -18,6 +18,7 @@ package p2p
import (
"net"
"net/netip"
"sync/atomic"
"testing"
"time"
@ -64,8 +65,8 @@ func TestServerPortMapping(t *testing.T) {
t.Error("wrong request count:", reqCount)
}
enr := srv.LocalNode().Node()
if enr.IP().String() != "192.0.2.0" {
t.Error("wrong IP in ENR:", enr.IP())
if enr.IPAddr() != netip.MustParseAddr("192.0.2.0") {
t.Error("wrong IP in ENR:", enr.IPAddr())
}
if enr.TCP() != 30000 {
t.Error("wrong TCP port in ENR:", enr.TCP())