p2p/discover: improved node revalidation (#29572)

Node discovery periodically revalidates the nodes in its table by sending PING, checking
if they are still alive. I recently noticed some issues with the implementation of this
process, which can cause strange results such as nodes dropping unexpectedly, certain
nodes not getting revalidated often enough, and bad results being returned to incoming
FINDNODE queries.

In this change, the revalidation process is improved with the following logic:

- We maintain two 'revalidation lists' containing the table nodes, named 'fast' and 'slow'.
- The process chooses random nodes from each list on a randomized interval, the interval being
  faster for the 'fast' list, and performs revalidation for the chosen node.
- Whenever a node is newly inserted into the table, it goes into the 'fast' list.
  Once validation passes, it transfers to the 'slow' list. If a request fails, or the
  node changes endpoint, it transfers back into 'fast'.
- livenessChecks is incremented by one for successful checks. Unlike the old implementation,
  we will not drop the node on the first failing check. We instead quickly decay the
  livenessChecks and give it another chance.
- Order of nodes in bucket doesn't matter anymore.

I am also adding a debug API endpoint to dump the node table content.

Co-authored-by: Martin HS <martin@swende.se>
This commit is contained in:
Felix Lange 2024-05-23 14:26:09 +02:00 committed by GitHub
parent 70bee977d6
commit 6a9158bb1b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
15 changed files with 935 additions and 528 deletions

View File

@ -20,6 +20,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"net" "net"
"net/http"
"strconv" "strconv"
"strings" "strings"
"time" "time"
@ -28,9 +29,11 @@ import (
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/internal/flags" "github.com/ethereum/go-ethereum/internal/flags"
"github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/p2p/discover" "github.com/ethereum/go-ethereum/p2p/discover"
"github.com/ethereum/go-ethereum/p2p/enode" "github.com/ethereum/go-ethereum/p2p/enode"
"github.com/ethereum/go-ethereum/params" "github.com/ethereum/go-ethereum/params"
"github.com/ethereum/go-ethereum/rpc"
"github.com/urfave/cli/v2" "github.com/urfave/cli/v2"
) )
@ -45,6 +48,7 @@ var (
discv4ResolveJSONCommand, discv4ResolveJSONCommand,
discv4CrawlCommand, discv4CrawlCommand,
discv4TestCommand, discv4TestCommand,
discv4ListenCommand,
}, },
} }
discv4PingCommand = &cli.Command{ discv4PingCommand = &cli.Command{
@ -75,6 +79,14 @@ var (
Flags: discoveryNodeFlags, Flags: discoveryNodeFlags,
ArgsUsage: "<nodes.json file>", ArgsUsage: "<nodes.json file>",
} }
discv4ListenCommand = &cli.Command{
Name: "listen",
Usage: "Runs a discovery node",
Action: discv4Listen,
Flags: flags.Merge(discoveryNodeFlags, []cli.Flag{
httpAddrFlag,
}),
}
discv4CrawlCommand = &cli.Command{ discv4CrawlCommand = &cli.Command{
Name: "crawl", Name: "crawl",
Usage: "Updates a nodes.json file with random nodes found in the DHT", Usage: "Updates a nodes.json file with random nodes found in the DHT",
@ -131,6 +143,10 @@ var (
Usage: "Enode of the remote node under test", Usage: "Enode of the remote node under test",
EnvVars: []string{"REMOTE_ENODE"}, EnvVars: []string{"REMOTE_ENODE"},
} }
httpAddrFlag = &cli.StringFlag{
Name: "rpc",
Usage: "HTTP server listening address",
}
) )
var discoveryNodeFlags = []cli.Flag{ var discoveryNodeFlags = []cli.Flag{
@ -154,6 +170,27 @@ func discv4Ping(ctx *cli.Context) error {
return nil return nil
} }
// discv4Listen runs a discovery v4 node. When the --rpc flag is given, it also
// serves the discv4 debug API over HTTP on that address; otherwise it just
// keeps the node running until the process is killed.
func discv4Listen(ctx *cli.Context) error {
	disc, _ := startV4(ctx)
	defer disc.Close()

	fmt.Println(disc.Self())

	httpAddr := ctx.String(httpAddrFlag.Name)
	if httpAddr == "" {
		// Non-HTTP mode: block forever, the node runs in the background.
		select {}
	}

	api := &discv4API{disc}
	log.Info("Starting RPC API server", "addr", httpAddr)
	srv := rpc.NewServer()
	// Propagate registration failures instead of silently serving an empty API.
	if err := srv.RegisterName("discv4", api); err != nil {
		return fmt.Errorf("registering discv4 API: %w", err)
	}
	http.DefaultServeMux.Handle("/", srv)
	httpsrv := http.Server{Addr: httpAddr, Handler: http.DefaultServeMux}
	return httpsrv.ListenAndServe()
}
func discv4RequestRecord(ctx *cli.Context) error { func discv4RequestRecord(ctx *cli.Context) error {
n := getNodeArg(ctx) n := getNodeArg(ctx)
disc, _ := startV4(ctx) disc, _ := startV4(ctx)
@ -362,3 +399,23 @@ func parseBootnodes(ctx *cli.Context) ([]*enode.Node, error) {
} }
return nodes, nil return nodes, nil
} }
// discv4API implements the "discv4" RPC namespace served by the listen command.
type discv4API struct {
	host *discover.UDPv4
}

// LookupRandom walks the DHT from a random starting point and returns up to n
// nodes encountered during the walk.
func (api *discv4API) LookupRandom(n int) (ns []*enode.Node) {
	it := api.host.RandomNodes()
	for len(ns) < n && it.Next() {
		ns = append(ns, it.Node())
	}
	return ns
}

// Buckets dumps the content of the local node table, bucket by bucket.
func (api *discv4API) Buckets() [][]discover.BucketNode {
	return api.host.TableBuckets()
}

// Self returns the local node record.
func (api *discv4API) Self() *enode.Node {
	return api.host.Self()
}

View File

@ -58,7 +58,7 @@ func (h *bufHandler) Handle(_ context.Context, r slog.Record) error {
} }
func (h *bufHandler) Enabled(_ context.Context, lvl slog.Level) bool { func (h *bufHandler) Enabled(_ context.Context, lvl slog.Level) bool {
return lvl <= h.level return lvl >= h.level
} }
func (h *bufHandler) WithAttrs(attrs []slog.Attr) slog.Handler { func (h *bufHandler) WithAttrs(attrs []slog.Attr) slog.Handler {

View File

@ -26,6 +26,7 @@ import (
"github.com/ethereum/go-ethereum/internal/debug" "github.com/ethereum/go-ethereum/internal/debug"
"github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/p2p" "github.com/ethereum/go-ethereum/p2p"
"github.com/ethereum/go-ethereum/p2p/discover"
"github.com/ethereum/go-ethereum/p2p/enode" "github.com/ethereum/go-ethereum/p2p/enode"
"github.com/ethereum/go-ethereum/rpc" "github.com/ethereum/go-ethereum/rpc"
) )
@ -39,6 +40,9 @@ func (n *Node) apis() []rpc.API {
}, { }, {
Namespace: "debug", Namespace: "debug",
Service: debug.Handler, Service: debug.Handler,
}, {
Namespace: "debug",
Service: &p2pDebugAPI{n},
}, { }, {
Namespace: "web3", Namespace: "web3",
Service: &web3API{n}, Service: &web3API{n},
@ -333,3 +337,16 @@ func (s *web3API) ClientVersion() string {
func (s *web3API) Sha3(input hexutil.Bytes) hexutil.Bytes { func (s *web3API) Sha3(input hexutil.Bytes) hexutil.Bytes {
return crypto.Keccak256(input) return crypto.Keccak256(input)
} }
// p2pDebugAPI provides access to p2p internals for debugging.
type p2pDebugAPI struct {
	stack *Node
}

// DiscoveryV4Table dumps the discovery v4 node table. It returns nil when
// discovery v4 is not running.
func (s *p2pDebugAPI) DiscoveryV4Table() [][]discover.BucketNode {
	if disc := s.stack.server.DiscoveryV4(); disc != nil {
		return disc.TableBuckets()
	}
	return nil
}

View File

@ -18,7 +18,11 @@ package discover
import ( import (
"crypto/ecdsa" "crypto/ecdsa"
crand "crypto/rand"
"encoding/binary"
"math/rand"
"net" "net"
"sync"
"time" "time"
"github.com/ethereum/go-ethereum/common/mclock" "github.com/ethereum/go-ethereum/common/mclock"
@ -62,7 +66,7 @@ type Config struct {
func (cfg Config) withDefaults() Config { func (cfg Config) withDefaults() Config {
// Node table configuration: // Node table configuration:
if cfg.PingInterval == 0 { if cfg.PingInterval == 0 {
cfg.PingInterval = 10 * time.Second cfg.PingInterval = 3 * time.Second
} }
if cfg.RefreshInterval == 0 { if cfg.RefreshInterval == 0 {
cfg.RefreshInterval = 30 * time.Minute cfg.RefreshInterval = 30 * time.Minute
@ -92,3 +96,44 @@ type ReadPacket struct {
Data []byte Data []byte
Addr *net.UDPAddr Addr *net.UDPAddr
} }
// randomSource is the subset of math/rand functionality used by the table.
type randomSource interface {
	Intn(int) int
	Int63n(int64) int64
	Shuffle(int, func(int, int))
}

// reseedingRandom is a concurrency-safe random number generator that can be
// re-seeded from crypto/rand at any time. seed must be called at least once
// before any of the other methods are used.
type reseedingRandom struct {
	mu  sync.Mutex
	cur *rand.Rand
}

// seed swaps in a fresh generator seeded from crypto/rand.
func (r *reseedingRandom) seed() {
	var b [8]byte
	crand.Read(b[:]) // error intentionally ignored: crypto/rand.Read never fails
	seed := binary.BigEndian.Uint64(b[:])
	// Note: don't name this 'new' — that would shadow the builtin.
	rng := rand.New(rand.NewSource(int64(seed)))
	r.mu.Lock()
	r.cur = rng
	r.mu.Unlock()
}

// Intn returns a uniform random int in [0, n).
func (r *reseedingRandom) Intn(n int) int {
	r.mu.Lock()
	defer r.mu.Unlock()
	return r.cur.Intn(n)
}

// Int63n returns a uniform random int64 in [0, n).
func (r *reseedingRandom) Int63n(n int64) int64 {
	r.mu.Lock()
	defer r.mu.Unlock()
	return r.cur.Int63n(n)
}

// Shuffle pseudo-randomizes the order of n elements using the swap function.
func (r *reseedingRandom) Shuffle(n int, swap func(i, j int)) {
	r.mu.Lock()
	defer r.mu.Unlock()
	r.cur.Shuffle(n, swap)
}

View File

@ -140,32 +140,13 @@ func (it *lookup) slowdown() {
} }
func (it *lookup) query(n *node, reply chan<- []*node) { func (it *lookup) query(n *node, reply chan<- []*node) {
fails := it.tab.db.FindFails(n.ID(), n.IP())
r, err := it.queryfunc(n) r, err := it.queryfunc(n)
if errors.Is(err, errClosed) { if !errors.Is(err, errClosed) { // avoid recording failures on shutdown.
// Avoid recording failures on shutdown. success := len(r) > 0
reply <- nil it.tab.trackRequest(n, success, r)
return if err != nil {
} else if len(r) == 0 { it.tab.log.Trace("FINDNODE failed", "id", n.ID(), "err", err)
fails++
it.tab.db.UpdateFindFails(n.ID(), n.IP(), fails)
// Remove the node from the local table if it fails to return anything useful too
// many times, but only if there are enough other nodes in the bucket.
dropped := false
if fails >= maxFindnodeFailures && it.tab.bucketLen(n.ID()) >= bucketSize/2 {
dropped = true
it.tab.delete(n)
} }
it.tab.log.Trace("FINDNODE failed", "id", n.ID(), "failcount", fails, "dropped", dropped, "err", err)
} else if fails > 0 {
// Reset failure counter because it counts _consecutive_ failures.
it.tab.db.UpdateFindFails(n.ID(), n.IP(), 0)
}
// Grab as many nodes as possible. Some of them might not be alive anymore, but we'll
// just remove those again during revalidation.
for _, n := range r {
it.tab.addSeenNode(n)
} }
reply <- r reply <- r
} }

View File

@ -29,12 +29,22 @@ import (
"github.com/ethereum/go-ethereum/p2p/enode" "github.com/ethereum/go-ethereum/p2p/enode"
) )
// BucketNode is a JSON-serializable view of one table entry, exposed by the
// debug API for dumping the node table.
type BucketNode struct {
	Node          *enode.Node `json:"node"`          // the node record
	AddedToTable  time.Time   `json:"addedToTable"`  // first time added to a bucket or replacement list
	AddedToBucket time.Time   `json:"addedToBucket"` // time it was added in the actual bucket
	Checks        int         `json:"checks"`        // number of liveness checks passed
	Live          bool        `json:"live"`          // whether the node is currently considered validated live
}
// node represents a host on the network. // node represents a host on the network.
// The fields of Node may not be modified. // The fields of Node may not be modified.
type node struct { type node struct {
enode.Node *enode.Node
addedAt time.Time // time when the node was added to the table addedToTable time.Time // first time node was added to bucket or replacement list
addedToBucket time.Time // time it was added in the actual bucket
livenessChecks uint // how often liveness was checked livenessChecks uint // how often liveness was checked
isValidatedLive bool // true if existence of node is considered validated right now
} }
type encPubkey [64]byte type encPubkey [64]byte
@ -65,7 +75,7 @@ func (e encPubkey) id() enode.ID {
} }
func wrapNode(n *enode.Node) *node { func wrapNode(n *enode.Node) *node {
return &node{Node: *n} return &node{Node: n}
} }
func wrapNodes(ns []*enode.Node) []*node { func wrapNodes(ns []*enode.Node) []*node {
@ -77,7 +87,7 @@ func wrapNodes(ns []*enode.Node) []*node {
} }
func unwrapNode(n *node) *enode.Node { func unwrapNode(n *node) *enode.Node {
return &n.Node return n.Node
} }
func unwrapNodes(ns []*node) []*enode.Node { func unwrapNodes(ns []*node) []*enode.Node {

View File

@ -24,16 +24,15 @@ package discover
import ( import (
"context" "context"
crand "crypto/rand"
"encoding/binary"
"fmt" "fmt"
mrand "math/rand"
"net" "net"
"slices"
"sort" "sort"
"sync" "sync"
"time" "time"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/common/mclock"
"github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/metrics" "github.com/ethereum/go-ethereum/metrics"
"github.com/ethereum/go-ethereum/p2p/enode" "github.com/ethereum/go-ethereum/p2p/enode"
@ -55,7 +54,6 @@ const (
bucketIPLimit, bucketSubnet = 2, 24 // at most 2 addresses from the same /24 bucketIPLimit, bucketSubnet = 2, 24 // at most 2 addresses from the same /24
tableIPLimit, tableSubnet = 10, 24 tableIPLimit, tableSubnet = 10, 24
copyNodesInterval = 30 * time.Second
seedMinTableTime = 5 * time.Minute seedMinTableTime = 5 * time.Minute
seedCount = 30 seedCount = 30
seedMaxAge = 5 * 24 * time.Hour seedMaxAge = 5 * 24 * time.Hour
@ -68,8 +66,9 @@ type Table struct {
mutex sync.Mutex // protects buckets, bucket content, nursery, rand mutex sync.Mutex // protects buckets, bucket content, nursery, rand
buckets [nBuckets]*bucket // index of known nodes by distance buckets [nBuckets]*bucket // index of known nodes by distance
nursery []*node // bootstrap nodes nursery []*node // bootstrap nodes
rand *mrand.Rand // source of randomness, periodically reseeded rand reseedingRandom // source of randomness, periodically reseeded
ips netutil.DistinctNetSet ips netutil.DistinctNetSet
revalidation tableRevalidation
db *enode.DB // database of known nodes db *enode.DB // database of known nodes
net transport net transport
@ -78,6 +77,10 @@ type Table struct {
// loop channels // loop channels
refreshReq chan chan struct{} refreshReq chan chan struct{}
revalResponseCh chan revalidationResponse
addNodeCh chan addNodeOp
addNodeHandled chan bool
trackRequestCh chan trackRequestOp
initDone chan struct{} initDone chan struct{}
closeReq chan struct{} closeReq chan struct{}
closed chan struct{} closed chan struct{}
@ -104,6 +107,17 @@ type bucket struct {
index int index int
} }
// addNodeOp is a request to add a node to the table, handled by the table loop.
type addNodeOp struct {
	node      *node
	isInbound bool // true for nodes that contacted us, false for discovered nodes
}

// trackRequestOp reports the outcome of a request made to a node, handled by
// the table loop.
type trackRequestOp struct {
	node       *node
	foundNodes []*node
	success    bool
}
func newTable(t transport, db *enode.DB, cfg Config) (*Table, error) { func newTable(t transport, db *enode.DB, cfg Config) (*Table, error) {
cfg = cfg.withDefaults() cfg = cfg.withDefaults()
tab := &Table{ tab := &Table{
@ -112,56 +126,49 @@ func newTable(t transport, db *enode.DB, cfg Config) (*Table, error) {
cfg: cfg, cfg: cfg,
log: cfg.Log, log: cfg.Log,
refreshReq: make(chan chan struct{}), refreshReq: make(chan chan struct{}),
revalResponseCh: make(chan revalidationResponse),
addNodeCh: make(chan addNodeOp),
addNodeHandled: make(chan bool),
trackRequestCh: make(chan trackRequestOp),
initDone: make(chan struct{}), initDone: make(chan struct{}),
closeReq: make(chan struct{}), closeReq: make(chan struct{}),
closed: make(chan struct{}), closed: make(chan struct{}),
rand: mrand.New(mrand.NewSource(0)),
ips: netutil.DistinctNetSet{Subnet: tableSubnet, Limit: tableIPLimit}, ips: netutil.DistinctNetSet{Subnet: tableSubnet, Limit: tableIPLimit},
} }
if err := tab.setFallbackNodes(cfg.Bootnodes); err != nil {
return nil, err
}
for i := range tab.buckets { for i := range tab.buckets {
tab.buckets[i] = &bucket{ tab.buckets[i] = &bucket{
index: i, index: i,
ips: netutil.DistinctNetSet{Subnet: bucketSubnet, Limit: bucketIPLimit}, ips: netutil.DistinctNetSet{Subnet: bucketSubnet, Limit: bucketIPLimit},
} }
} }
tab.seedRand() tab.rand.seed()
tab.revalidation.init(&cfg)
// initial table content
if err := tab.setFallbackNodes(cfg.Bootnodes); err != nil {
return nil, err
}
tab.loadSeedNodes() tab.loadSeedNodes()
return tab, nil return tab, nil
} }
func newMeteredTable(t transport, db *enode.DB, cfg Config) (*Table, error) {
tab, err := newTable(t, db, cfg)
if err != nil {
return nil, err
}
if metrics.Enabled {
tab.nodeAddedHook = func(b *bucket, n *node) {
bucketsCounter[b.index].Inc(1)
}
tab.nodeRemovedHook = func(b *bucket, n *node) {
bucketsCounter[b.index].Dec(1)
}
}
return tab, nil
}
// Nodes returns all nodes contained in the table. // Nodes returns all nodes contained in the table.
func (tab *Table) Nodes() []*enode.Node { func (tab *Table) Nodes() [][]BucketNode {
if !tab.isInitDone() {
return nil
}
tab.mutex.Lock() tab.mutex.Lock()
defer tab.mutex.Unlock() defer tab.mutex.Unlock()
var nodes []*enode.Node nodes := make([][]BucketNode, len(tab.buckets))
for _, b := range &tab.buckets { for i, b := range &tab.buckets {
for _, n := range b.entries { nodes[i] = make([]BucketNode, len(b.entries))
nodes = append(nodes, unwrapNode(n)) for j, n := range b.entries {
nodes[i][j] = BucketNode{
Node: n.Node,
Checks: int(n.livenessChecks),
Live: n.isValidatedLive,
AddedToTable: n.addedToTable,
AddedToBucket: n.addedToBucket,
}
} }
} }
return nodes return nodes
@ -171,15 +178,6 @@ func (tab *Table) self() *enode.Node {
return tab.net.Self() return tab.net.Self()
} }
func (tab *Table) seedRand() {
var b [8]byte
crand.Read(b[:])
tab.mutex.Lock()
tab.rand.Seed(int64(binary.BigEndian.Uint64(b[:])))
tab.mutex.Unlock()
}
// getNode returns the node with the given ID or nil if it isn't in the table. // getNode returns the node with the given ID or nil if it isn't in the table.
func (tab *Table) getNode(id enode.ID) *enode.Node { func (tab *Table) getNode(id enode.ID) *enode.Node {
tab.mutex.Lock() tab.mutex.Lock()
@ -239,52 +237,173 @@ func (tab *Table) refresh() <-chan struct{} {
return done return done
} }
// loop schedules runs of doRefresh, doRevalidate and copyLiveNodes. // findnodeByID returns the n nodes in the table that are closest to the given id.
// findnodeByID returns the nresults nodes in the table that are closest to the
// given id. This is used by the FINDNODE/v4 handler.
//
// The preferLive parameter says whether the caller wants liveness-checked results. If
// preferLive is true and the table contains any verified nodes, the result will not
// contain unverified nodes. However, if there are no verified nodes at all, the result
// will contain unverified nodes.
func (tab *Table) findnodeByID(target enode.ID, nresults int, preferLive bool) *nodesByDistance {
	tab.mutex.Lock()
	defer tab.mutex.Unlock()

	// Scan all buckets. There might be a better way to do this, but there aren't that many
	// buckets, so this solution should be fine. The worst-case complexity of this loop
	// is O(tab.len() * nresults).
	nodes := &nodesByDistance{target: target}
	liveNodes := &nodesByDistance{target: target}
	for _, b := range &tab.buckets {
		for _, n := range b.entries {
			nodes.push(n, nresults)
			// Track validated-live nodes separately so they can be preferred.
			if preferLive && n.isValidatedLive {
				liveNodes.push(n, nresults)
			}
		}
	}
	if preferLive && len(liveNodes.entries) > 0 {
		return liveNodes
	}
	return nodes
}
// appendLiveNodes adds nodes at the given distance to the result slice.
// This is used by the FINDNODE/v5 handler.
func (tab *Table) appendLiveNodes(dist uint, result []*enode.Node) []*enode.Node {
	switch {
	case dist > 256:
		return result
	case dist == 0:
		return append(result, tab.self())
	}

	tab.mutex.Lock()
	for _, entry := range tab.bucketAtDistance(int(dist)).entries {
		if entry.isValidatedLive {
			result = append(result, entry.Node)
		}
	}
	tab.mutex.Unlock()

	// Shuffle result to avoid always returning same nodes in FINDNODE/v5.
	tab.rand.Shuffle(len(result), func(i, j int) {
		result[i], result[j] = result[j], result[i]
	})
	return result
}
// len returns the number of nodes in the table.
func (tab *Table) len() (n int) {
	tab.mutex.Lock()
	defer tab.mutex.Unlock()

	for i := range tab.buckets {
		n += len(tab.buckets[i].entries)
	}
	return n
}
// addFoundNode adds a node which may not be live. If the bucket has space available,
// adding the node succeeds immediately. Otherwise, the node is added to the replacements
// list.
//
// The caller must not hold tab.mutex.
func (tab *Table) addFoundNode(n *node) bool {
	select {
	case tab.addNodeCh <- addNodeOp{node: n, isInbound: false}:
		return <-tab.addNodeHandled
	case <-tab.closeReq:
		// Table is shutting down; the node was not added.
		return false
	}
}
// addInboundNode adds a node from an inbound contact. If the bucket has no space, the
// node is added to the replacements list.
//
// There is an additional safety measure: if the table is still initializing the node is
// not added. This prevents an attack where the table could be filled by just sending ping
// repeatedly.
//
// The caller must not hold tab.mutex.
func (tab *Table) addInboundNode(n *node) bool {
	select {
	case tab.addNodeCh <- addNodeOp{node: n, isInbound: true}:
		return <-tab.addNodeHandled
	case <-tab.closeReq:
		// Table is shutting down; the node was not added.
		return false
	}
}
// trackRequest reports the outcome of a request made to node n to the table
// loop. It is a no-op when the table is shutting down.
func (tab *Table) trackRequest(n *node, success bool, foundNodes []*node) {
	op := trackRequestOp{node: n, foundNodes: foundNodes, success: success}
	select {
	case tab.trackRequestCh <- op:
	case <-tab.closeReq:
	}
}
// loop is the main loop of Table.
func (tab *Table) loop() { func (tab *Table) loop() {
var ( var (
revalidate = time.NewTimer(tab.nextRevalidateTime())
refresh = time.NewTimer(tab.nextRefreshTime()) refresh = time.NewTimer(tab.nextRefreshTime())
copyNodes = time.NewTicker(copyNodesInterval)
refreshDone = make(chan struct{}) // where doRefresh reports completion refreshDone = make(chan struct{}) // where doRefresh reports completion
revalidateDone chan struct{} // where doRevalidate reports completion
waiting = []chan struct{}{tab.initDone} // holds waiting callers while doRefresh runs waiting = []chan struct{}{tab.initDone} // holds waiting callers while doRefresh runs
revalTimer = mclock.NewAlarm(tab.cfg.Clock)
reseedRandTimer = time.NewTicker(10 * time.Minute)
) )
defer refresh.Stop() defer refresh.Stop()
defer revalidate.Stop() defer revalTimer.Stop()
defer copyNodes.Stop() defer reseedRandTimer.Stop()
// Start initial refresh. // Start initial refresh.
go tab.doRefresh(refreshDone) go tab.doRefresh(refreshDone)
loop: loop:
for { for {
nextTime := tab.revalidation.run(tab, tab.cfg.Clock.Now())
revalTimer.Schedule(nextTime)
select { select {
case <-reseedRandTimer.C:
tab.rand.seed()
case <-revalTimer.C():
case r := <-tab.revalResponseCh:
tab.revalidation.handleResponse(tab, r)
case op := <-tab.addNodeCh:
tab.mutex.Lock()
ok := tab.handleAddNode(op)
tab.mutex.Unlock()
tab.addNodeHandled <- ok
case op := <-tab.trackRequestCh:
tab.handleTrackRequest(op)
case <-refresh.C: case <-refresh.C:
tab.seedRand()
if refreshDone == nil { if refreshDone == nil {
refreshDone = make(chan struct{}) refreshDone = make(chan struct{})
go tab.doRefresh(refreshDone) go tab.doRefresh(refreshDone)
} }
case req := <-tab.refreshReq: case req := <-tab.refreshReq:
waiting = append(waiting, req) waiting = append(waiting, req)
if refreshDone == nil { if refreshDone == nil {
refreshDone = make(chan struct{}) refreshDone = make(chan struct{})
go tab.doRefresh(refreshDone) go tab.doRefresh(refreshDone)
} }
case <-refreshDone: case <-refreshDone:
for _, ch := range waiting { for _, ch := range waiting {
close(ch) close(ch)
} }
waiting, refreshDone = nil, nil waiting, refreshDone = nil, nil
refresh.Reset(tab.nextRefreshTime()) refresh.Reset(tab.nextRefreshTime())
case <-revalidate.C:
revalidateDone = make(chan struct{})
go tab.doRevalidate(revalidateDone)
case <-revalidateDone:
revalidate.Reset(tab.nextRevalidateTime())
revalidateDone = nil
case <-copyNodes.C:
go tab.copyLiveNodes()
case <-tab.closeReq: case <-tab.closeReq:
break loop break loop
} }
@ -296,9 +415,6 @@ loop:
for _, ch := range waiting { for _, ch := range waiting {
close(ch) close(ch)
} }
if revalidateDone != nil {
<-revalidateDone
}
close(tab.closed) close(tab.closed)
} }
@ -335,169 +451,15 @@ func (tab *Table) loadSeedNodes() {
age := time.Since(tab.db.LastPongReceived(seed.ID(), seed.IP())) age := time.Since(tab.db.LastPongReceived(seed.ID(), seed.IP()))
tab.log.Trace("Found seed node in database", "id", seed.ID(), "addr", seed.addr(), "age", age) tab.log.Trace("Found seed node in database", "id", seed.ID(), "addr", seed.addr(), "age", age)
} }
tab.addSeenNode(seed) tab.handleAddNode(addNodeOp{node: seed, isInbound: false})
} }
} }
// doRevalidate checks that the last node in a random bucket is still live and replaces or
// deletes the node if it isn't.
func (tab *Table) doRevalidate(done chan<- struct{}) {
defer func() { done <- struct{}{} }()
last, bi := tab.nodeToRevalidate()
if last == nil {
// No non-empty bucket found.
return
}
// Ping the selected node and wait for a pong.
remoteSeq, err := tab.net.ping(unwrapNode(last))
// Also fetch record if the node replied and returned a higher sequence number.
if last.Seq() < remoteSeq {
n, err := tab.net.RequestENR(unwrapNode(last))
if err != nil {
tab.log.Debug("ENR request failed", "id", last.ID(), "addr", last.addr(), "err", err)
} else {
last = &node{Node: *n, addedAt: last.addedAt, livenessChecks: last.livenessChecks}
}
}
tab.mutex.Lock()
defer tab.mutex.Unlock()
b := tab.buckets[bi]
if err == nil {
// The node responded, move it to the front.
last.livenessChecks++
tab.log.Debug("Revalidated node", "b", bi, "id", last.ID(), "checks", last.livenessChecks)
tab.bumpInBucket(b, last)
return
}
// No reply received, pick a replacement or delete the node if there aren't
// any replacements.
if r := tab.replace(b, last); r != nil {
tab.log.Debug("Replaced dead node", "b", bi, "id", last.ID(), "ip", last.IP(), "checks", last.livenessChecks, "r", r.ID(), "rip", r.IP())
} else {
tab.log.Debug("Removed dead node", "b", bi, "id", last.ID(), "ip", last.IP(), "checks", last.livenessChecks)
}
}
// nodeToRevalidate returns the last node in a random, non-empty bucket.
func (tab *Table) nodeToRevalidate() (n *node, bi int) {
tab.mutex.Lock()
defer tab.mutex.Unlock()
for _, bi = range tab.rand.Perm(len(tab.buckets)) {
b := tab.buckets[bi]
if len(b.entries) > 0 {
last := b.entries[len(b.entries)-1]
return last, bi
}
}
return nil, 0
}
func (tab *Table) nextRevalidateTime() time.Duration {
tab.mutex.Lock()
defer tab.mutex.Unlock()
return time.Duration(tab.rand.Int63n(int64(tab.cfg.PingInterval)))
}
func (tab *Table) nextRefreshTime() time.Duration { func (tab *Table) nextRefreshTime() time.Duration {
tab.mutex.Lock()
defer tab.mutex.Unlock()
half := tab.cfg.RefreshInterval / 2 half := tab.cfg.RefreshInterval / 2
return half + time.Duration(tab.rand.Int63n(int64(half))) return half + time.Duration(tab.rand.Int63n(int64(half)))
} }
// copyLiveNodes adds nodes from the table to the database if they have been in the table
// longer than seedMinTableTime.
func (tab *Table) copyLiveNodes() {
tab.mutex.Lock()
defer tab.mutex.Unlock()
now := time.Now()
for _, b := range &tab.buckets {
for _, n := range b.entries {
if n.livenessChecks > 0 && now.Sub(n.addedAt) >= seedMinTableTime {
tab.db.UpdateNode(unwrapNode(n))
}
}
}
}
// findnodeByID returns the n nodes in the table that are closest to the given id.
// This is used by the FINDNODE/v4 handler.
//
// The preferLive parameter says whether the caller wants liveness-checked results. If
// preferLive is true and the table contains any verified nodes, the result will not
// contain unverified nodes. However, if there are no verified nodes at all, the result
// will contain unverified nodes.
func (tab *Table) findnodeByID(target enode.ID, nresults int, preferLive bool) *nodesByDistance {
tab.mutex.Lock()
defer tab.mutex.Unlock()
// Scan all buckets. There might be a better way to do this, but there aren't that many
// buckets, so this solution should be fine. The worst-case complexity of this loop
// is O(tab.len() * nresults).
nodes := &nodesByDistance{target: target}
liveNodes := &nodesByDistance{target: target}
for _, b := range &tab.buckets {
for _, n := range b.entries {
nodes.push(n, nresults)
if preferLive && n.livenessChecks > 0 {
liveNodes.push(n, nresults)
}
}
}
if preferLive && len(liveNodes.entries) > 0 {
return liveNodes
}
return nodes
}
// appendLiveNodes adds nodes at the given distance to the result slice.
func (tab *Table) appendLiveNodes(dist uint, result []*enode.Node) []*enode.Node {
if dist > 256 {
return result
}
if dist == 0 {
return append(result, tab.self())
}
tab.mutex.Lock()
defer tab.mutex.Unlock()
for _, n := range tab.bucketAtDistance(int(dist)).entries {
if n.livenessChecks >= 1 {
node := n.Node // avoid handing out pointer to struct field
result = append(result, &node)
}
}
return result
}
// len returns the number of nodes in the table.
func (tab *Table) len() (n int) {
tab.mutex.Lock()
defer tab.mutex.Unlock()
for _, b := range &tab.buckets {
n += len(b.entries)
}
return n
}
// bucketLen returns the number of nodes in the bucket for the given ID.
func (tab *Table) bucketLen(id enode.ID) int {
tab.mutex.Lock()
defer tab.mutex.Unlock()
return len(tab.bucket(id).entries)
}
// bucket returns the bucket for the given node ID hash. // bucket returns the bucket for the given node ID hash.
func (tab *Table) bucket(id enode.ID) *bucket { func (tab *Table) bucket(id enode.ID) *bucket {
d := enode.LogDist(tab.self().ID(), id) d := enode.LogDist(tab.self().ID(), id)
@ -511,95 +473,6 @@ func (tab *Table) bucketAtDistance(d int) *bucket {
return tab.buckets[d-bucketMinDistance-1] return tab.buckets[d-bucketMinDistance-1]
} }
// addSeenNode adds a node which may or may not be live to the end of a bucket. If the
// bucket has space available, adding the node succeeds immediately. Otherwise, the node is
// added to the replacements list.
//
// The caller must not hold tab.mutex.
func (tab *Table) addSeenNode(n *node) {
if n.ID() == tab.self().ID() {
return
}
tab.mutex.Lock()
defer tab.mutex.Unlock()
b := tab.bucket(n.ID())
if contains(b.entries, n.ID()) {
// Already in bucket, don't add.
return
}
if len(b.entries) >= bucketSize {
// Bucket full, maybe add as replacement.
tab.addReplacement(b, n)
return
}
if !tab.addIP(b, n.IP()) {
// Can't add: IP limit reached.
return
}
// Add to end of bucket:
b.entries = append(b.entries, n)
b.replacements = deleteNode(b.replacements, n)
n.addedAt = time.Now()
if tab.nodeAddedHook != nil {
tab.nodeAddedHook(b, n)
}
}
// addVerifiedNode adds a node whose existence has been verified recently to the front of a
// bucket. If the node is already in the bucket, it is moved to the front. If the bucket
// has no space, the node is added to the replacements list.
//
// There is an additional safety measure: if the table is still initializing the node
// is not added. This prevents an attack where the table could be filled by just sending
// ping repeatedly.
//
// The caller must not hold tab.mutex.
func (tab *Table) addVerifiedNode(n *node) {
if !tab.isInitDone() {
return
}
if n.ID() == tab.self().ID() {
return
}
tab.mutex.Lock()
defer tab.mutex.Unlock()
b := tab.bucket(n.ID())
if tab.bumpInBucket(b, n) {
// Already in bucket, moved to front.
return
}
if len(b.entries) >= bucketSize {
// Bucket full, maybe add as replacement.
tab.addReplacement(b, n)
return
}
if !tab.addIP(b, n.IP()) {
// Can't add: IP limit reached.
return
}
// Add to front of bucket.
b.entries, _ = pushNode(b.entries, n, bucketSize)
b.replacements = deleteNode(b.replacements, n)
n.addedAt = time.Now()
if tab.nodeAddedHook != nil {
tab.nodeAddedHook(b, n)
}
}
// delete removes an entry from the node table. It is used to evacuate dead nodes.
func (tab *Table) delete(node *node) {
tab.mutex.Lock()
defer tab.mutex.Unlock()
tab.deleteInBucket(tab.bucket(node.ID()), node)
}
func (tab *Table) addIP(b *bucket, ip net.IP) bool { func (tab *Table) addIP(b *bucket, ip net.IP) bool {
if len(ip) == 0 { if len(ip) == 0 {
return false // Nodes without IP cannot be added. return false // Nodes without IP cannot be added.
@ -627,15 +500,51 @@ func (tab *Table) removeIP(b *bucket, ip net.IP) {
b.ips.Remove(ip) b.ips.Remove(ip)
} }
func (tab *Table) addReplacement(b *bucket, n *node) { // handleAddNode adds the node in the request to the table, if there is space.
for _, e := range b.replacements { // The caller must hold tab.mutex.
if e.ID() == n.ID() { func (tab *Table) handleAddNode(req addNodeOp) bool {
return // already in list if req.node.ID() == tab.self().ID() {
return false
} }
// For nodes from inbound contact, there is an additional safety measure: if the table
// is still initializing the node is not added.
if req.isInbound && !tab.isInitDone() {
return false
}
b := tab.bucket(req.node.ID())
if tab.bumpInBucket(b, req.node.Node) {
// Already in bucket, update record.
return false
}
if len(b.entries) >= bucketSize {
// Bucket full, maybe add as replacement.
tab.addReplacement(b, req.node)
return false
}
if !tab.addIP(b, req.node.IP()) {
// Can't add: IP limit reached.
return false
}
// Add to bucket.
b.entries = append(b.entries, req.node)
b.replacements = deleteNode(b.replacements, req.node)
tab.nodeAdded(b, req.node)
return true
}
// addReplacement adds n to the replacement cache of bucket b.
func (tab *Table) addReplacement(b *bucket, n *node) {
if contains(b.replacements, n.ID()) {
// TODO: update ENR
return
} }
if !tab.addIP(b, n.IP()) { if !tab.addIP(b, n.IP()) {
return return
} }
n.addedToTable = time.Now()
var removed *node var removed *node
b.replacements, removed = pushNode(b.replacements, n, maxReplacements) b.replacements, removed = pushNode(b.replacements, n, maxReplacements)
if removed != nil { if removed != nil {
@ -643,59 +552,107 @@ func (tab *Table) addReplacement(b *bucket, n *node) {
} }
} }
// replace removes n from the replacement list and replaces 'last' with it if it is the func (tab *Table) nodeAdded(b *bucket, n *node) {
// last entry in the bucket. If 'last' isn't the last entry, it has either been replaced if n.addedToTable == (time.Time{}) {
// with someone else or became active. n.addedToTable = time.Now()
func (tab *Table) replace(b *bucket, last *node) *node {
if len(b.entries) == 0 || b.entries[len(b.entries)-1].ID() != last.ID() {
// Entry has moved, don't replace it.
return nil
} }
// Still the last entry. n.addedToBucket = time.Now()
if len(b.replacements) == 0 { tab.revalidation.nodeAdded(tab, n)
tab.deleteInBucket(b, last) if tab.nodeAddedHook != nil {
return nil tab.nodeAddedHook(b, n)
}
if metrics.Enabled {
bucketsCounter[b.index].Inc(1)
} }
r := b.replacements[tab.rand.Intn(len(b.replacements))]
b.replacements = deleteNode(b.replacements, r)
b.entries[len(b.entries)-1] = r
tab.removeIP(b, last.IP())
return r
} }
// bumpInBucket moves the given node to the front of the bucket entry list func (tab *Table) nodeRemoved(b *bucket, n *node) {
// if it is contained in that list. tab.revalidation.nodeRemoved(n)
func (tab *Table) bumpInBucket(b *bucket, n *node) bool { if tab.nodeRemovedHook != nil {
for i := range b.entries { tab.nodeRemovedHook(b, n)
if b.entries[i].ID() == n.ID() { }
if !n.IP().Equal(b.entries[i].IP()) { if metrics.Enabled {
bucketsCounter[b.index].Dec(1)
}
}
// deleteInBucket removes node n from the table.
// If there are replacement nodes in the bucket, the node is replaced.
func (tab *Table) deleteInBucket(b *bucket, id enode.ID) *node {
	i := slices.IndexFunc(b.entries, func(e *node) bool { return e.ID() == id })
	if i == -1 {
		// Entry has been removed already.
		return nil
	}

	// Remove the entry and release its IP quota.
	removed := b.entries[i]
	b.entries = slices.Delete(b.entries, i, i+1)
	tab.removeIP(b, removed.IP())
	tab.nodeRemoved(b, removed)

	// Promote a randomly chosen replacement into the bucket, if one exists.
	if len(b.replacements) == 0 {
		tab.log.Debug("Removed dead node", "b", b.index, "id", removed.ID(), "ip", removed.IP())
		return nil
	}
	ri := tab.rand.Intn(len(b.replacements))
	promoted := b.replacements[ri]
	b.replacements = slices.Delete(b.replacements, ri, ri+1)
	b.entries = append(b.entries, promoted)
	tab.nodeAdded(b, promoted)
	tab.log.Debug("Replaced dead node", "b", b.index, "id", removed.ID(), "ip", removed.IP(), "r", promoted.ID(), "rip", promoted.IP())
	return promoted
}
// bumpInBucket updates the node record of n in the bucket.
func (tab *Table) bumpInBucket(b *bucket, newRecord *enode.Node) bool {
i := slices.IndexFunc(b.entries, func(elem *node) bool {
return elem.ID() == newRecord.ID()
})
if i == -1 {
return false
}
if !newRecord.IP().Equal(b.entries[i].IP()) {
// Endpoint has changed, ensure that the new IP fits into table limits. // Endpoint has changed, ensure that the new IP fits into table limits.
tab.removeIP(b, b.entries[i].IP()) tab.removeIP(b, b.entries[i].IP())
if !tab.addIP(b, n.IP()) { if !tab.addIP(b, newRecord.IP()) {
// It doesn't, put the previous one back. // It doesn't, put the previous one back.
tab.addIP(b, b.entries[i].IP()) tab.addIP(b, b.entries[i].IP())
return false return false
} }
} }
// Move it to the front. b.entries[i].Node = newRecord
copy(b.entries[1:], b.entries[:i])
b.entries[0] = n
return true return true
}
}
return false
} }
func (tab *Table) deleteInBucket(b *bucket, n *node) { func (tab *Table) handleTrackRequest(op trackRequestOp) {
// Check if the node is actually in the bucket so the removed hook var fails int
// isn't called multiple times for the same node. if op.success {
if !contains(b.entries, n.ID()) { // Reset failure counter because it counts _consecutive_ failures.
return tab.db.UpdateFindFails(op.node.ID(), op.node.IP(), 0)
} else {
fails = tab.db.FindFails(op.node.ID(), op.node.IP())
fails++
tab.db.UpdateFindFails(op.node.ID(), op.node.IP(), fails)
} }
b.entries = deleteNode(b.entries, n)
tab.removeIP(b, n.IP()) tab.mutex.Lock()
if tab.nodeRemovedHook != nil { defer tab.mutex.Unlock()
tab.nodeRemovedHook(b, n)
b := tab.bucket(op.node.ID())
// Remove the node from the local table if it fails to return anything useful too
// many times, but only if there are enough other nodes in the bucket. This latter
// condition specifically exists to make bootstrapping in smaller test networks more
// reliable.
if fails >= maxFindnodeFailures && len(b.entries) >= bucketSize/4 {
tab.deleteInBucket(b, op.node.ID())
}
// Add found nodes.
for _, n := range op.foundNodes {
tab.handleAddNode(addNodeOp{n, false})
} }
} }

223
p2p/discover/table_reval.go Normal file
View File

@ -0,0 +1,223 @@
// Copyright 2024 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package discover
import (
"fmt"
"math"
"slices"
"time"
"github.com/ethereum/go-ethereum/common/mclock"
"github.com/ethereum/go-ethereum/p2p/enode"
)
const never = mclock.AbsTime(math.MaxInt64)
// tableRevalidation implements the node revalidation process.
// It tracks all nodes contained in Table, and schedules sending PING to them.
type tableRevalidation struct {
	// fast holds nodes that are new or recently changed endpoint; nodeAdded
	// always pushes here, and failed checks move nodes back here.
	fast revalidationList
	// slow holds nodes that have passed at least one revalidation; it is
	// checked on a longer interval (see init).
	slow revalidationList
	// activeReq tracks nodes with an in-flight revalidation request, so the
	// same node is never checked twice concurrently.
	activeReq map[enode.ID]struct{}
}
// revalidationResponse carries the outcome of one revalidation request from the
// doRevalidate goroutine back to the Table main loop.
type revalidationResponse struct {
	// n is the table entry being checked.
	n *node
	// newRecord is a fresher ENR, set when the node returned a higher sequence number.
	newRecord *enode.Node
	// list is the revalidation list the node was chosen from.
	list *revalidationList
	// didRespond reports whether the PING got a response.
	didRespond bool
}
// init configures the two revalidation lists. Nodes on the 'slow' list are
// checked three times less frequently than nodes on the 'fast' list.
func (tr *tableRevalidation) init(cfg *Config) {
	tr.activeReq = make(map[enode.ID]struct{})
	tr.fast = revalidationList{
		name:     "fast",
		interval: cfg.PingInterval,
		nextTime: never,
	}
	tr.slow = revalidationList{
		name:     "slow",
		interval: cfg.PingInterval * 3,
		nextTime: never,
	}
}
// nodeAdded is called when the table receives a new node.
// New nodes always start out on the fast list until they pass a check.
func (tr *tableRevalidation) nodeAdded(tab *Table, n *node) {
	now := tab.cfg.Clock.Now()
	tr.fast.push(n, now, &tab.rand)
}
// nodeRemoved is called when a node was removed from the table.
// The node lives on exactly one of the two lists, so stop after the first hit.
func (tr *tableRevalidation) nodeRemoved(n *node) {
	if tr.fast.remove(n) {
		return
	}
	tr.slow.remove(n)
}
// run performs node revalidation.
// It returns the next time it should be invoked, which is used in the Table main loop
// to schedule a timer. However, run can be called at any time.
func (tr *tableRevalidation) run(tab *Table, now mclock.AbsTime) (nextTime mclock.AbsTime) {
	// Process the fast list first, then the slow list, starting one request
	// per list when its deadline has arrived.
	for _, list := range []*revalidationList{&tr.fast, &tr.slow} {
		n := list.get(now, &tab.rand, tr.activeReq)
		if n == nil {
			continue
		}
		tr.startRequest(tab, list, n)
		list.schedule(now, &tab.rand)
	}
	return min(tr.fast.nextTime, tr.slow.nextTime)
}
// startRequest spawns a revalidation request for node n.
// The chosen node is registered in activeReq so it cannot be selected again
// while the request is in flight; handleResponse clears the entry.
func (tr *tableRevalidation) startRequest(tab *Table, list *revalidationList, n *node) {
	// Double-starting a request for the same node is a programming error:
	// get() must exclude entries of activeReq.
	if _, ok := tr.activeReq[n.ID()]; ok {
		panic(fmt.Errorf("duplicate startRequest (list %q, node %v)", list.name, n.ID()))
	}
	tr.activeReq[n.ID()] = struct{}{}
	resp := revalidationResponse{n: n, list: list}

	// Fetch the node while holding lock.
	// The inner *enode.Node can be swapped by bumpInBucket, so take a snapshot
	// under tab.mutex before handing it to the goroutine.
	tab.mutex.Lock()
	node := n.Node
	tab.mutex.Unlock()

	go tab.doRevalidate(resp, node)
}
// doRevalidate pings the given node, optionally fetching an updated record,
// and delivers the outcome to the Table main loop. It runs on its own goroutine.
func (tab *Table) doRevalidate(resp revalidationResponse, node *enode.Node) {
	// Ping the selected node and wait for a pong response.
	remoteSeq, err := tab.net.ping(node)
	resp.didRespond = err == nil

	// Also fetch record if the node replied and returned a higher sequence number.
	if remoteSeq > node.Seq() {
		if newrec, reqErr := tab.net.RequestENR(node); reqErr != nil {
			tab.log.Debug("ENR request failed", "id", node.ID(), "err", reqErr)
		} else {
			resp.newRecord = newrec
		}
	}

	// Deliver the result, unless the table is shutting down.
	select {
	case tab.revalResponseCh <- resp:
	case <-tab.closed:
	}
}
// handleResponse processes the result of a revalidation request.
// It updates the node's liveness score, applies a changed ENR, and moves the
// node between the fast and slow lists.
func (tr *tableRevalidation) handleResponse(tab *Table, resp revalidationResponse) {
	now := tab.cfg.Clock.Now()
	n := resp.n
	delete(tr.activeReq, n.ID())

	tab.mutex.Lock()
	defer tab.mutex.Unlock()

	// The node may have been removed from the table (and thus from its
	// revalidation list) while the request was in flight, e.g. through
	// handleTrackRequest after repeated findnode failures. In that case
	// moveToList below would panic because n is no longer in resp.list,
	// so stop processing here.
	if !slices.Contains(resp.list.nodes, n) {
		return
	}
	// Resolve the bucket while holding the mutex: table/bucket state is
	// guarded by tab.mutex everywhere else.
	b := tab.bucket(n.ID())

	if !resp.didRespond {
		// Revalidation failed. The node is not dropped on the first miss;
		// its liveness score decays quickly and it gets another chance on
		// the fast list until the score reaches zero.
		n.livenessChecks /= 3
		if n.livenessChecks <= 0 {
			tab.deleteInBucket(b, n.ID())
		} else {
			tr.moveToList(&tr.fast, resp.list, n, now, &tab.rand)
		}
		return
	}

	// The node responded.
	n.livenessChecks++
	n.isValidatedLive = true
	var endpointChanged bool
	if resp.newRecord != nil {
		endpointChanged = tab.bumpInBucket(b, resp.newRecord)
		if endpointChanged {
			// If the node changed its advertised endpoint, the updated ENR is not served
			// until it has been revalidated.
			n.isValidatedLive = false
		}
	}
	tab.log.Debug("Revalidated node", "b", b.index, "id", n.ID(), "checks", n.livenessChecks, "q", resp.list.name)

	// Move node over to slow queue after first validation. An endpoint change
	// sends it back to the fast queue instead.
	if !endpointChanged {
		tr.moveToList(&tr.slow, resp.list, n, now, &tab.rand)
	} else {
		tr.moveToList(&tr.fast, resp.list, n, now, &tab.rand)
	}

	// Store potential seeds in database.
	if n.isValidatedLive && n.livenessChecks > 5 {
		tab.db.UpdateNode(resp.n.Node)
	}
}
// moveToList transfers node n from the source list to dest. Moving a node to
// the list it is already on is a no-op; n being absent from source is a
// programming error.
func (tr *tableRevalidation) moveToList(dest, source *revalidationList, n *node, now mclock.AbsTime, rand randomSource) {
	if dest == source {
		return
	}
	removed := source.remove(n)
	if !removed {
		panic(fmt.Errorf("moveToList(%q -> %q): node %v not in source list", source.name, dest.name, n.ID()))
	}
	dest.push(n, now, rand)
}
// revalidationList holds a list of nodes and the next revalidation time.
type revalidationList struct {
	// nodes are the table entries tracked by this list.
	nodes []*node
	// nextTime is when the next node should be checked; 'never' while empty.
	nextTime mclock.AbsTime
	// interval bounds the random delay between checks.
	interval time.Duration
	// name labels the list ("fast"/"slow") in logs and panics.
	name string
}
// get returns a random node from the queue. Nodes in the 'exclude' map are not returned.
// Selection is probabilistic: after 3*len(nodes) random draws without finding a
// non-excluded node, it gives up and returns nil.
func (list *revalidationList) get(now mclock.AbsTime, rand randomSource, exclude map[enode.ID]struct{}) *node {
	if now < list.nextTime || len(list.nodes) == 0 {
		return nil
	}
	attempts := len(list.nodes) * 3
	for a := 0; a < attempts; a++ {
		candidate := list.nodes[rand.Intn(len(list.nodes))]
		if _, skip := exclude[candidate.ID()]; !skip {
			return candidate
		}
	}
	return nil
}
// schedule sets the next check time to a random point within the list's interval.
func (list *revalidationList) schedule(now mclock.AbsTime, rand randomSource) {
	offset := time.Duration(rand.Int63n(int64(list.interval)))
	list.nextTime = now.Add(offset)
}
// push appends n to the list. If the list was idle (nextTime == never), a
// first check is scheduled.
func (list *revalidationList) push(n *node, now mclock.AbsTime, rand randomSource) {
	list.nodes = append(list.nodes, n)
	if list.nextTime != never {
		return
	}
	list.schedule(now, rand)
}
// remove deletes n from the list, reporting whether it was present.
func (list *revalidationList) remove(n *node) bool {
	idx := slices.Index(list.nodes, n)
	if idx < 0 {
		return false
	}
	list.nodes = slices.Delete(list.nodes, idx, idx+1)
	if len(list.nodes) == 0 {
		// Nothing left to check; park the schedule.
		list.nextTime = never
	}
	return true
}

View File

@ -20,14 +20,16 @@ import (
"crypto/ecdsa" "crypto/ecdsa"
"fmt" "fmt"
"math/rand" "math/rand"
"net" "net"
"reflect" "reflect"
"testing" "testing"
"testing/quick" "testing/quick"
"time" "time"
"github.com/ethereum/go-ethereum/common/mclock"
"github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/internal/testlog"
"github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/p2p/enode" "github.com/ethereum/go-ethereum/p2p/enode"
"github.com/ethereum/go-ethereum/p2p/enr" "github.com/ethereum/go-ethereum/p2p/enr"
"github.com/ethereum/go-ethereum/p2p/netutil" "github.com/ethereum/go-ethereum/p2p/netutil"
@ -49,106 +51,109 @@ func TestTable_pingReplace(t *testing.T) {
} }
func testPingReplace(t *testing.T, newNodeIsResponding, lastInBucketIsResponding bool) { func testPingReplace(t *testing.T, newNodeIsResponding, lastInBucketIsResponding bool) {
simclock := new(mclock.Simulated)
transport := newPingRecorder() transport := newPingRecorder()
tab, db := newTestTable(transport) tab, db := newTestTable(transport, Config{
Clock: simclock,
Log: testlog.Logger(t, log.LevelTrace),
})
defer db.Close() defer db.Close()
defer tab.close() defer tab.close()
<-tab.initDone <-tab.initDone
// Fill up the sender's bucket. // Fill up the sender's bucket.
pingKey, _ := crypto.HexToECDSA("45a915e4d060149eb4365960e6a7a45f334393093061116b197e3240065ff2d8") replacementNodeKey, _ := crypto.HexToECDSA("45a915e4d060149eb4365960e6a7a45f334393093061116b197e3240065ff2d8")
pingSender := wrapNode(enode.NewV4(&pingKey.PublicKey, net.IP{127, 0, 0, 1}, 99, 99)) replacementNode := wrapNode(enode.NewV4(&replacementNodeKey.PublicKey, net.IP{127, 0, 0, 1}, 99, 99))
last := fillBucket(tab, pingSender) last := fillBucket(tab, replacementNode.ID())
tab.mutex.Lock()
nodeEvents := newNodeEventRecorder(128)
tab.nodeAddedHook = nodeEvents.nodeAdded
tab.nodeRemovedHook = nodeEvents.nodeRemoved
tab.mutex.Unlock()
// Add the sender as if it just pinged us. Revalidate should replace the last node in // The revalidation process should replace
// its bucket if it is unresponsive. Revalidate again to ensure that // this node in the bucket if it is unresponsive.
transport.dead[last.ID()] = !lastInBucketIsResponding transport.dead[last.ID()] = !lastInBucketIsResponding
transport.dead[pingSender.ID()] = !newNodeIsResponding transport.dead[replacementNode.ID()] = !newNodeIsResponding
tab.addSeenNode(pingSender)
tab.doRevalidate(make(chan struct{}, 1))
tab.doRevalidate(make(chan struct{}, 1))
if !transport.pinged[last.ID()] { // Add replacement node to table.
// Oldest node in bucket is pinged to see whether it is still alive. tab.addFoundNode(replacementNode)
t.Error("table did not ping last node in bucket")
t.Log("last:", last.ID())
t.Log("replacement:", replacementNode.ID())
// Wait until the last node was pinged.
waitForRevalidationPing(t, transport, tab, last.ID())
if !lastInBucketIsResponding {
if !nodeEvents.waitNodeAbsent(last.ID(), 2*time.Second) {
t.Error("last node was not removed")
}
if !nodeEvents.waitNodePresent(replacementNode.ID(), 2*time.Second) {
t.Error("replacement node was not added")
} }
// If a replacement is expected, we also need to wait until the replacement node
// was pinged and added/removed.
waitForRevalidationPing(t, transport, tab, replacementNode.ID())
if !newNodeIsResponding {
if !nodeEvents.waitNodeAbsent(replacementNode.ID(), 2*time.Second) {
t.Error("replacement node was not removed")
}
}
}
// Check bucket content.
tab.mutex.Lock() tab.mutex.Lock()
defer tab.mutex.Unlock() defer tab.mutex.Unlock()
wantSize := bucketSize wantSize := bucketSize
if !lastInBucketIsResponding && !newNodeIsResponding { if !lastInBucketIsResponding && !newNodeIsResponding {
wantSize-- wantSize--
} }
if l := len(tab.bucket(pingSender.ID()).entries); l != wantSize { bucket := tab.bucket(replacementNode.ID())
t.Errorf("wrong bucket size after bond: got %d, want %d", l, wantSize) if l := len(bucket.entries); l != wantSize {
t.Errorf("wrong bucket size after revalidation: got %d, want %d", l, wantSize)
} }
if found := contains(tab.bucket(pingSender.ID()).entries, last.ID()); found != lastInBucketIsResponding { if ok := contains(bucket.entries, last.ID()); ok != lastInBucketIsResponding {
t.Errorf("last entry found: %t, want: %t", found, lastInBucketIsResponding) t.Errorf("revalidated node found: %t, want: %t", ok, lastInBucketIsResponding)
} }
wantNewEntry := newNodeIsResponding && !lastInBucketIsResponding wantNewEntry := newNodeIsResponding && !lastInBucketIsResponding
if found := contains(tab.bucket(pingSender.ID()).entries, pingSender.ID()); found != wantNewEntry { if ok := contains(bucket.entries, replacementNode.ID()); ok != wantNewEntry {
t.Errorf("new entry found: %t, want: %t", found, wantNewEntry) t.Errorf("replacement node found: %t, want: %t", ok, wantNewEntry)
} }
} }
func TestBucket_bumpNoDuplicates(t *testing.T) { // waitForRevalidationPing waits until a PING message is sent to a node with the given id.
t.Parallel() func waitForRevalidationPing(t *testing.T, transport *pingRecorder, tab *Table, id enode.ID) *enode.Node {
cfg := &quick.Config{ t.Helper()
MaxCount: 1000,
Rand: rand.New(rand.NewSource(time.Now().Unix())),
Values: func(args []reflect.Value, rand *rand.Rand) {
// generate a random list of nodes. this will be the content of the bucket.
n := rand.Intn(bucketSize-1) + 1
nodes := make([]*node, n)
for i := range nodes {
nodes[i] = nodeAtDistance(enode.ID{}, 200, intIP(200))
}
args[0] = reflect.ValueOf(nodes)
// generate random bump positions.
bumps := make([]int, rand.Intn(100))
for i := range bumps {
bumps[i] = rand.Intn(len(nodes))
}
args[1] = reflect.ValueOf(bumps)
},
}
prop := func(nodes []*node, bumps []int) (ok bool) { simclock := tab.cfg.Clock.(*mclock.Simulated)
tab, db := newTestTable(newPingRecorder()) maxAttempts := tab.len() * 8
defer db.Close() for i := 0; i < maxAttempts; i++ {
defer tab.close() simclock.Run(tab.cfg.PingInterval)
p := transport.waitPing(2 * time.Second)
b := &bucket{entries: make([]*node, len(nodes))} if p == nil {
copy(b.entries, nodes) t.Fatal("Table did not send revalidation ping")
for i, pos := range bumps {
tab.bumpInBucket(b, b.entries[pos])
if hasDuplicates(b.entries) {
t.Logf("bucket has duplicates after %d/%d bumps:", i+1, len(bumps))
for _, n := range b.entries {
t.Logf(" %p", n)
} }
return false if id == (enode.ID{}) || p.ID() == id {
return p
} }
} }
checkIPLimitInvariant(t, tab) t.Fatalf("Table did not ping node %v (%d attempts)", id, maxAttempts)
return true return nil
}
if err := quick.Check(prop, cfg); err != nil {
t.Error(err)
}
} }
// This checks that the table-wide IP limit is applied correctly. // This checks that the table-wide IP limit is applied correctly.
func TestTable_IPLimit(t *testing.T) { func TestTable_IPLimit(t *testing.T) {
transport := newPingRecorder() transport := newPingRecorder()
tab, db := newTestTable(transport) tab, db := newTestTable(transport, Config{})
defer db.Close() defer db.Close()
defer tab.close() defer tab.close()
for i := 0; i < tableIPLimit+1; i++ { for i := 0; i < tableIPLimit+1; i++ {
n := nodeAtDistance(tab.self().ID(), i, net.IP{172, 0, 1, byte(i)}) n := nodeAtDistance(tab.self().ID(), i, net.IP{172, 0, 1, byte(i)})
tab.addSeenNode(n) tab.addFoundNode(n)
} }
if tab.len() > tableIPLimit { if tab.len() > tableIPLimit {
t.Errorf("too many nodes in table") t.Errorf("too many nodes in table")
@ -159,14 +164,14 @@ func TestTable_IPLimit(t *testing.T) {
// This checks that the per-bucket IP limit is applied correctly. // This checks that the per-bucket IP limit is applied correctly.
func TestTable_BucketIPLimit(t *testing.T) { func TestTable_BucketIPLimit(t *testing.T) {
transport := newPingRecorder() transport := newPingRecorder()
tab, db := newTestTable(transport) tab, db := newTestTable(transport, Config{})
defer db.Close() defer db.Close()
defer tab.close() defer tab.close()
d := 3 d := 3
for i := 0; i < bucketIPLimit+1; i++ { for i := 0; i < bucketIPLimit+1; i++ {
n := nodeAtDistance(tab.self().ID(), d, net.IP{172, 0, 1, byte(i)}) n := nodeAtDistance(tab.self().ID(), d, net.IP{172, 0, 1, byte(i)})
tab.addSeenNode(n) tab.addFoundNode(n)
} }
if tab.len() > bucketIPLimit { if tab.len() > bucketIPLimit {
t.Errorf("too many nodes in table") t.Errorf("too many nodes in table")
@ -196,7 +201,7 @@ func TestTable_findnodeByID(t *testing.T) {
test := func(test *closeTest) bool { test := func(test *closeTest) bool {
// for any node table, Target and N // for any node table, Target and N
transport := newPingRecorder() transport := newPingRecorder()
tab, db := newTestTable(transport) tab, db := newTestTable(transport, Config{})
defer db.Close() defer db.Close()
defer tab.close() defer tab.close()
fillTable(tab, test.All, true) fillTable(tab, test.All, true)
@ -271,7 +276,7 @@ func (*closeTest) Generate(rand *rand.Rand, size int) reflect.Value {
} }
func TestTable_addVerifiedNode(t *testing.T) { func TestTable_addVerifiedNode(t *testing.T) {
tab, db := newTestTable(newPingRecorder()) tab, db := newTestTable(newPingRecorder(), Config{})
<-tab.initDone <-tab.initDone
defer db.Close() defer db.Close()
defer tab.close() defer tab.close()
@ -279,31 +284,32 @@ func TestTable_addVerifiedNode(t *testing.T) {
// Insert two nodes. // Insert two nodes.
n1 := nodeAtDistance(tab.self().ID(), 256, net.IP{88, 77, 66, 1}) n1 := nodeAtDistance(tab.self().ID(), 256, net.IP{88, 77, 66, 1})
n2 := nodeAtDistance(tab.self().ID(), 256, net.IP{88, 77, 66, 2}) n2 := nodeAtDistance(tab.self().ID(), 256, net.IP{88, 77, 66, 2})
tab.addSeenNode(n1) tab.addFoundNode(n1)
tab.addSeenNode(n2) tab.addFoundNode(n2)
bucket := tab.bucket(n1.ID())
// Verify bucket content: // Verify bucket content:
bcontent := []*node{n1, n2} bcontent := []*node{n1, n2}
if !reflect.DeepEqual(tab.bucket(n1.ID()).entries, bcontent) { if !reflect.DeepEqual(unwrapNodes(bucket.entries), unwrapNodes(bcontent)) {
t.Fatalf("wrong bucket content: %v", tab.bucket(n1.ID()).entries) t.Fatalf("wrong bucket content: %v", bucket.entries)
} }
// Add a changed version of n2. // Add a changed version of n2.
newrec := n2.Record() newrec := n2.Record()
newrec.Set(enr.IP{99, 99, 99, 99}) newrec.Set(enr.IP{99, 99, 99, 99})
newn2 := wrapNode(enode.SignNull(newrec, n2.ID())) newn2 := wrapNode(enode.SignNull(newrec, n2.ID()))
tab.addVerifiedNode(newn2) tab.addInboundNode(newn2)
// Check that bucket is updated correctly. // Check that bucket is updated correctly.
newBcontent := []*node{newn2, n1} newBcontent := []*node{n1, newn2}
if !reflect.DeepEqual(tab.bucket(n1.ID()).entries, newBcontent) { if !reflect.DeepEqual(unwrapNodes(bucket.entries), unwrapNodes(newBcontent)) {
t.Fatalf("wrong bucket content after update: %v", tab.bucket(n1.ID()).entries) t.Fatalf("wrong bucket content after update: %v", bucket.entries)
} }
checkIPLimitInvariant(t, tab) checkIPLimitInvariant(t, tab)
} }
func TestTable_addSeenNode(t *testing.T) { func TestTable_addSeenNode(t *testing.T) {
tab, db := newTestTable(newPingRecorder()) tab, db := newTestTable(newPingRecorder(), Config{})
<-tab.initDone <-tab.initDone
defer db.Close() defer db.Close()
defer tab.close() defer tab.close()
@ -311,8 +317,8 @@ func TestTable_addSeenNode(t *testing.T) {
// Insert two nodes. // Insert two nodes.
n1 := nodeAtDistance(tab.self().ID(), 256, net.IP{88, 77, 66, 1}) n1 := nodeAtDistance(tab.self().ID(), 256, net.IP{88, 77, 66, 1})
n2 := nodeAtDistance(tab.self().ID(), 256, net.IP{88, 77, 66, 2}) n2 := nodeAtDistance(tab.self().ID(), 256, net.IP{88, 77, 66, 2})
tab.addSeenNode(n1) tab.addFoundNode(n1)
tab.addSeenNode(n2) tab.addFoundNode(n2)
// Verify bucket content: // Verify bucket content:
bcontent := []*node{n1, n2} bcontent := []*node{n1, n2}
@ -324,7 +330,7 @@ func TestTable_addSeenNode(t *testing.T) {
newrec := n2.Record() newrec := n2.Record()
newrec.Set(enr.IP{99, 99, 99, 99}) newrec.Set(enr.IP{99, 99, 99, 99})
newn2 := wrapNode(enode.SignNull(newrec, n2.ID())) newn2 := wrapNode(enode.SignNull(newrec, n2.ID()))
tab.addSeenNode(newn2) tab.addFoundNode(newn2)
// Check that bucket content is unchanged. // Check that bucket content is unchanged.
if !reflect.DeepEqual(tab.bucket(n1.ID()).entries, bcontent) { if !reflect.DeepEqual(tab.bucket(n1.ID()).entries, bcontent) {
@ -337,7 +343,10 @@ func TestTable_addSeenNode(t *testing.T) {
// announces a new sequence number, the new record should be pulled. // announces a new sequence number, the new record should be pulled.
func TestTable_revalidateSyncRecord(t *testing.T) { func TestTable_revalidateSyncRecord(t *testing.T) {
transport := newPingRecorder() transport := newPingRecorder()
tab, db := newTestTable(transport) tab, db := newTestTable(transport, Config{
Clock: new(mclock.Simulated),
Log: testlog.Logger(t, log.LevelTrace),
})
<-tab.initDone <-tab.initDone
defer db.Close() defer db.Close()
defer tab.close() defer tab.close()
@ -347,14 +356,18 @@ func TestTable_revalidateSyncRecord(t *testing.T) {
r.Set(enr.IP(net.IP{127, 0, 0, 1})) r.Set(enr.IP(net.IP{127, 0, 0, 1}))
id := enode.ID{1} id := enode.ID{1}
n1 := wrapNode(enode.SignNull(&r, id)) n1 := wrapNode(enode.SignNull(&r, id))
tab.addSeenNode(n1) tab.addFoundNode(n1)
// Update the node record. // Update the node record.
r.Set(enr.WithEntry("foo", "bar")) r.Set(enr.WithEntry("foo", "bar"))
n2 := enode.SignNull(&r, id) n2 := enode.SignNull(&r, id)
transport.updateRecord(n2) transport.updateRecord(n2)
tab.doRevalidate(make(chan struct{}, 1)) // Wait for revalidation. We wait for the node to be revalidated two times
// in order to synchronize with the update in the table.
waitForRevalidationPing(t, transport, tab, n2.ID())
waitForRevalidationPing(t, transport, tab, n2.ID())
intable := tab.getNode(id) intable := tab.getNode(id)
if !reflect.DeepEqual(intable, n2) { if !reflect.DeepEqual(intable, n2) {
t.Fatalf("table contains old record with seq %d, want seq %d", intable.Seq(), n2.Seq()) t.Fatalf("table contains old record with seq %d, want seq %d", intable.Seq(), n2.Seq())

View File

@ -26,6 +26,8 @@ import (
"net" "net"
"slices" "slices"
"sync" "sync"
"sync/atomic"
"time"
"github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/p2p/enode" "github.com/ethereum/go-ethereum/p2p/enode"
@ -40,8 +42,7 @@ func init() {
nullNode = enode.SignNull(&r, enode.ID{}) nullNode = enode.SignNull(&r, enode.ID{})
} }
func newTestTable(t transport) (*Table, *enode.DB) { func newTestTable(t transport, cfg Config) (*Table, *enode.DB) {
cfg := Config{}
db, _ := enode.OpenDB("") db, _ := enode.OpenDB("")
tab, _ := newTable(t, db, cfg) tab, _ := newTable(t, db, cfg)
go tab.loop() go tab.loop()
@ -98,11 +99,14 @@ func intIP(i int) net.IP {
} }
// fillBucket inserts nodes into the given bucket until it is full. // fillBucket inserts nodes into the given bucket until it is full.
func fillBucket(tab *Table, n *node) (last *node) { func fillBucket(tab *Table, id enode.ID) (last *node) {
ld := enode.LogDist(tab.self().ID(), n.ID()) ld := enode.LogDist(tab.self().ID(), id)
b := tab.bucket(n.ID()) b := tab.bucket(id)
for len(b.entries) < bucketSize { for len(b.entries) < bucketSize {
b.entries = append(b.entries, nodeAtDistance(tab.self().ID(), ld, intIP(ld))) node := nodeAtDistance(tab.self().ID(), ld, intIP(ld))
if !tab.addFoundNode(node) {
panic("node not added")
}
} }
return b.entries[bucketSize-1] return b.entries[bucketSize-1]
} }
@ -113,15 +117,18 @@ func fillTable(tab *Table, nodes []*node, setLive bool) {
for _, n := range nodes { for _, n := range nodes {
if setLive { if setLive {
n.livenessChecks = 1 n.livenessChecks = 1
n.isValidatedLive = true
} }
tab.addSeenNode(n) tab.addFoundNode(n)
} }
} }
type pingRecorder struct { type pingRecorder struct {
mu sync.Mutex mu sync.Mutex
dead, pinged map[enode.ID]bool cond *sync.Cond
dead map[enode.ID]bool
records map[enode.ID]*enode.Node records map[enode.ID]*enode.Node
pinged []*enode.Node
n *enode.Node n *enode.Node
} }
@ -130,12 +137,13 @@ func newPingRecorder() *pingRecorder {
r.Set(enr.IP{0, 0, 0, 0}) r.Set(enr.IP{0, 0, 0, 0})
n := enode.SignNull(&r, enode.ID{}) n := enode.SignNull(&r, enode.ID{})
return &pingRecorder{ t := &pingRecorder{
dead: make(map[enode.ID]bool), dead: make(map[enode.ID]bool),
pinged: make(map[enode.ID]bool),
records: make(map[enode.ID]*enode.Node), records: make(map[enode.ID]*enode.Node),
n: n, n: n,
} }
t.cond = sync.NewCond(&t.mu)
return t
} }
// updateRecord updates a node record. Future calls to ping and // updateRecord updates a node record. Future calls to ping and
@ -151,12 +159,40 @@ func (t *pingRecorder) Self() *enode.Node { return nullNode }
func (t *pingRecorder) lookupSelf() []*enode.Node { return nil } func (t *pingRecorder) lookupSelf() []*enode.Node { return nil }
func (t *pingRecorder) lookupRandom() []*enode.Node { return nil } func (t *pingRecorder) lookupRandom() []*enode.Node { return nil }
// waitPing blocks until the recorder has observed at least one ping, returning
// the pinged node, or nil if the timeout elapses first. It pops pings in FIFO
// order, so repeated calls drain the queue.
func (t *pingRecorder) waitPing(timeout time.Duration) *enode.Node {
	t.mu.Lock()
	defer t.mu.Unlock()

	// Wake up the loop on timeout.
	// The timer goroutine sets the flag before broadcasting so the waiter
	// below cannot miss the wakeup.
	var timedout atomic.Bool
	timer := time.AfterFunc(timeout, func() {
		timedout.Store(true)
		t.cond.Broadcast()
	})
	defer timer.Stop()

	// Wait for a ping.
	// cond.Wait releases t.mu while blocked, letting ping() append and broadcast.
	for {
		if timedout.Load() {
			return nil
		}
		if len(t.pinged) > 0 {
			// Pop the oldest ping, shifting the remainder into the same backing array.
			n := t.pinged[0]
			t.pinged = append(t.pinged[:0], t.pinged[1:]...)
			return n
		}
		t.cond.Wait()
	}
}
// ping simulates a ping request. // ping simulates a ping request.
func (t *pingRecorder) ping(n *enode.Node) (seq uint64, err error) { func (t *pingRecorder) ping(n *enode.Node) (seq uint64, err error) {
t.mu.Lock() t.mu.Lock()
defer t.mu.Unlock() defer t.mu.Unlock()
t.pinged[n.ID()] = true t.pinged = append(t.pinged, n)
t.cond.Broadcast()
if t.dead[n.ID()] { if t.dead[n.ID()] {
return 0, errTimeout return 0, errTimeout
} }
@ -256,3 +292,57 @@ func hexEncPubkey(h string) (ret encPubkey) {
copy(ret[:], b) copy(ret[:], b)
return ret return ret
} }
// nodeEventRecorder captures table add/remove notifications for use in tests.
// Its nodeAdded/nodeRemoved methods are installed as table hooks.
type nodeEventRecorder struct {
	// evc buffers events; the hooks panic if the buffer overflows.
	evc chan recordedNodeEvent
}

// recordedNodeEvent is one table change: added is true for an insertion and
// false for a removal.
type recordedNodeEvent struct {
	node *node
	added bool
}

// newNodeEventRecorder creates a recorder whose event channel can buffer up to
// 'buffer' pending events.
func newNodeEventRecorder(buffer int) *nodeEventRecorder {
	return &nodeEventRecorder{
		evc: make(chan recordedNodeEvent, buffer),
	}
}
// nodeAdded records a table insertion. It must never block, so a full event
// buffer is treated as a test setup error.
func (set *nodeEventRecorder) nodeAdded(b *bucket, n *node) {
	select {
	case set.evc <- recordedNodeEvent{node: n, added: true}:
	default:
		panic("no space in event buffer")
	}
}
// nodeRemoved records a table removal. It must never block, so a full event
// buffer is treated as a test setup error.
func (set *nodeEventRecorder) nodeRemoved(b *bucket, n *node) {
	select {
	case set.evc <- recordedNodeEvent{node: n, added: false}:
	default:
		panic("no space in event buffer")
	}
}
// waitNodePresent blocks until an 'added' event for id is observed, or the
// timeout elapses. It returns whether the event arrived in time.
func (set *nodeEventRecorder) waitNodePresent(id enode.ID, timeout time.Duration) bool {
	return set.waitNodeEvent(id, timeout, true)
}

// waitNodeAbsent blocks until a 'removed' event for id is observed, or the
// timeout elapses. It returns whether the event arrived in time.
func (set *nodeEventRecorder) waitNodeAbsent(id enode.ID, timeout time.Duration) bool {
	return set.waitNodeEvent(id, timeout, false)
}
// waitNodeEvent blocks until an event for id with the given direction (added or
// removed) is seen, returning false if the timeout expires first. Events for
// other nodes or the opposite direction are discarded while waiting.
func (set *nodeEventRecorder) waitNodeEvent(id enode.ID, timeout time.Duration, added bool) bool {
	deadline := time.NewTimer(timeout)
	defer deadline.Stop()
	for {
		select {
		case <-deadline.C:
			return false
		case ev := <-set.evc:
			if ev.added == added && ev.node.ID() == id {
				return true
			}
		}
	}
}

View File

@ -142,7 +142,7 @@ func ListenV4(c UDPConn, ln *enode.LocalNode, cfg Config) (*UDPv4, error) {
log: cfg.Log, log: cfg.Log,
} }
tab, err := newMeteredTable(t, ln.Database(), cfg) tab, err := newTable(t, ln.Database(), cfg)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -375,6 +375,10 @@ func (t *UDPv4) RequestENR(n *enode.Node) (*enode.Node, error) {
return respN, nil return respN, nil
} }
func (t *UDPv4) TableBuckets() [][]BucketNode {
return t.tab.Nodes()
}
// pending adds a reply matcher to the pending reply queue. // pending adds a reply matcher to the pending reply queue.
// see the documentation of type replyMatcher for a detailed explanation. // see the documentation of type replyMatcher for a detailed explanation.
func (t *UDPv4) pending(id enode.ID, ip net.IP, ptype byte, callback replyMatchFunc) *replyMatcher { func (t *UDPv4) pending(id enode.ID, ip net.IP, ptype byte, callback replyMatchFunc) *replyMatcher {
@ -669,10 +673,10 @@ func (t *UDPv4) handlePing(h *packetHandlerV4, from *net.UDPAddr, fromID enode.I
n := wrapNode(enode.NewV4(h.senderKey, from.IP, int(req.From.TCP), from.Port)) n := wrapNode(enode.NewV4(h.senderKey, from.IP, int(req.From.TCP), from.Port))
if time.Since(t.db.LastPongReceived(n.ID(), from.IP)) > bondExpiration { if time.Since(t.db.LastPongReceived(n.ID(), from.IP)) > bondExpiration {
t.sendPing(fromID, from, func() { t.sendPing(fromID, from, func() {
t.tab.addVerifiedNode(n) t.tab.addInboundNode(n)
}) })
} else { } else {
t.tab.addVerifiedNode(n) t.tab.addInboundNode(n)
} }
// Update node database and endpoint predictor. // Update node database and endpoint predictor.

View File

@ -264,7 +264,7 @@ func TestUDPv4_findnode(t *testing.T) {
n := wrapNode(enode.NewV4(&key.PublicKey, ip, 0, 2000)) n := wrapNode(enode.NewV4(&key.PublicKey, ip, 0, 2000))
// Ensure half of table content isn't verified live yet. // Ensure half of table content isn't verified live yet.
if i > numCandidates/2 { if i > numCandidates/2 {
n.livenessChecks = 1 n.isValidatedLive = true
live[n.ID()] = true live[n.ID()] = true
} }
nodes.push(n, numCandidates) nodes.push(n, numCandidates)

View File

@ -175,7 +175,7 @@ func newUDPv5(conn UDPConn, ln *enode.LocalNode, cfg Config) (*UDPv5, error) {
cancelCloseCtx: cancelCloseCtx, cancelCloseCtx: cancelCloseCtx,
} }
t.talk = newTalkSystem(t) t.talk = newTalkSystem(t)
tab, err := newMeteredTable(t, t.db, cfg) tab, err := newTable(t, t.db, cfg)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -699,7 +699,7 @@ func (t *UDPv5) handlePacket(rawpacket []byte, fromAddr *net.UDPAddr) error {
} }
if fromNode != nil { if fromNode != nil {
// Handshake succeeded, add to table. // Handshake succeeded, add to table.
t.tab.addSeenNode(wrapNode(fromNode)) t.tab.addInboundNode(wrapNode(fromNode))
} }
if packet.Kind() != v5wire.WhoareyouPacket { if packet.Kind() != v5wire.WhoareyouPacket {
// WHOAREYOU logged separately to report errors. // WHOAREYOU logged separately to report errors.

View File

@ -141,7 +141,7 @@ func TestUDPv5_unknownPacket(t *testing.T) {
// Make node known. // Make node known.
n := test.getNode(test.remotekey, test.remoteaddr).Node() n := test.getNode(test.remotekey, test.remoteaddr).Node()
test.table.addSeenNode(wrapNode(n)) test.table.addFoundNode(wrapNode(n))
test.packetIn(&v5wire.Unknown{Nonce: nonce}) test.packetIn(&v5wire.Unknown{Nonce: nonce})
test.waitPacketOut(func(p *v5wire.Whoareyou, addr *net.UDPAddr, _ v5wire.Nonce) { test.waitPacketOut(func(p *v5wire.Whoareyou, addr *net.UDPAddr, _ v5wire.Nonce) {

View File

@ -190,8 +190,8 @@ type Server struct {
nodedb *enode.DB nodedb *enode.DB
localnode *enode.LocalNode localnode *enode.LocalNode
ntab *discover.UDPv4 discv4 *discover.UDPv4
DiscV5 *discover.UDPv5 discv5 *discover.UDPv5
discmix *enode.FairMix discmix *enode.FairMix
dialsched *dialScheduler dialsched *dialScheduler
@ -400,6 +400,16 @@ func (srv *Server) Self() *enode.Node {
return ln.Node() return ln.Node()
} }
// DiscoveryV4 returns the discovery v4 instance, if configured.
func (srv *Server) DiscoveryV4() *discover.UDPv4 {
return srv.discv4
}
// DiscoveryV4 returns the discovery v5 instance, if configured.
func (srv *Server) DiscoveryV5() *discover.UDPv5 {
return srv.discv5
}
// Stop terminates the server and all active peer connections. // Stop terminates the server and all active peer connections.
// It blocks until all active connections have been closed. // It blocks until all active connections have been closed.
func (srv *Server) Stop() { func (srv *Server) Stop() {
@ -547,13 +557,13 @@ func (srv *Server) setupDiscovery() error {
) )
// If both versions of discovery are running, setup a shared // If both versions of discovery are running, setup a shared
// connection, so v5 can read unhandled messages from v4. // connection, so v5 can read unhandled messages from v4.
if srv.DiscoveryV4 && srv.DiscoveryV5 { if srv.Config.DiscoveryV4 && srv.Config.DiscoveryV5 {
unhandled = make(chan discover.ReadPacket, 100) unhandled = make(chan discover.ReadPacket, 100)
sconn = &sharedUDPConn{conn, unhandled} sconn = &sharedUDPConn{conn, unhandled}
} }
// Start discovery services. // Start discovery services.
if srv.DiscoveryV4 { if srv.Config.DiscoveryV4 {
cfg := discover.Config{ cfg := discover.Config{
PrivateKey: srv.PrivateKey, PrivateKey: srv.PrivateKey,
NetRestrict: srv.NetRestrict, NetRestrict: srv.NetRestrict,
@ -565,17 +575,17 @@ func (srv *Server) setupDiscovery() error {
if err != nil { if err != nil {
return err return err
} }
srv.ntab = ntab srv.discv4 = ntab
srv.discmix.AddSource(ntab.RandomNodes()) srv.discmix.AddSource(ntab.RandomNodes())
} }
if srv.DiscoveryV5 { if srv.Config.DiscoveryV5 {
cfg := discover.Config{ cfg := discover.Config{
PrivateKey: srv.PrivateKey, PrivateKey: srv.PrivateKey,
NetRestrict: srv.NetRestrict, NetRestrict: srv.NetRestrict,
Bootnodes: srv.BootstrapNodesV5, Bootnodes: srv.BootstrapNodesV5,
Log: srv.log, Log: srv.log,
} }
srv.DiscV5, err = discover.ListenV5(sconn, srv.localnode, cfg) srv.discv5, err = discover.ListenV5(sconn, srv.localnode, cfg)
if err != nil { if err != nil {
return err return err
} }
@ -602,8 +612,8 @@ func (srv *Server) setupDialScheduler() {
dialer: srv.Dialer, dialer: srv.Dialer,
clock: srv.clock, clock: srv.clock,
} }
if srv.ntab != nil { if srv.discv4 != nil {
config.resolver = srv.ntab config.resolver = srv.discv4
} }
if config.dialer == nil { if config.dialer == nil {
config.dialer = tcpDialer{&net.Dialer{Timeout: defaultDialTimeout}} config.dialer = tcpDialer{&net.Dialer{Timeout: defaultDialTimeout}}
@ -799,11 +809,11 @@ running:
srv.log.Trace("P2P networking is spinning down") srv.log.Trace("P2P networking is spinning down")
// Terminate discovery. If there is a running lookup it will terminate soon. // Terminate discovery. If there is a running lookup it will terminate soon.
if srv.ntab != nil { if srv.discv4 != nil {
srv.ntab.Close() srv.discv4.Close()
} }
if srv.DiscV5 != nil { if srv.discv5 != nil {
srv.DiscV5.Close() srv.discv5.Close()
} }
// Disconnect all peers. // Disconnect all peers.
for _, p := range peers { for _, p := range peers {