go-ethereum/p2p/discover/table_util_test.go

353 lines
8.4 KiB
Go

// Copyright 2018 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package discover
import (
"bytes"
"crypto/ecdsa"
"encoding/hex"
"errors"
"fmt"
"math/rand"
"net"
"slices"
"sync"
"sync/atomic"
"time"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/p2p/discover/v4wire"
"github.com/ethereum/go-ethereum/p2p/enode"
"github.com/ethereum/go-ethereum/p2p/enr"
)
// nullNode is a shared placeholder node: a null-signed record with the zero ID
// and IP 0.0.0.0. It is built once when the package is initialized.
var nullNode *enode.Node

func init() {
	rec := new(enr.Record)
	rec.Set(enr.IP{0, 0, 0, 0})
	nullNode = enode.SignNull(rec, enode.ID{})
}
// newTestTable creates a table via newInactiveTestTable and then starts its
// main loop in a background goroutine. Both the table and its node database
// are returned.
func newTestTable(t transport, cfg Config) (*Table, *enode.DB) {
	table, nodeDB := newInactiveTestTable(t, cfg)
	go table.loop()
	return table, nodeDB
}
// newInactiveTestTable creates a Table without running the main loop.
//
// The node database is opened with an empty path. The function panics if the
// database cannot be opened or the table cannot be created. Previously both
// errors were silently discarded, so failures surfaced later at an unrelated
// call site.
func newInactiveTestTable(t transport, cfg Config) (*Table, *enode.DB) {
	db, err := enode.OpenDB("")
	if err != nil {
		panic(fmt.Sprintf("newInactiveTestTable: can't open node database: %v", err))
	}
	tab, err := newTable(t, db, cfg)
	if err != nil {
		panic(fmt.Sprintf("newInactiveTestTable: can't create table: %v", err))
	}
	return tab, db
}
// nodeAtDistance returns a null-signed node whose ID satisfies
// enode.LogDist(base, id) == ld. Its record carries the given IP and UDP port
// 30303.
func nodeAtDistance(base enode.ID, ld int, ip net.IP) *enode.Node {
	id := idAtDistance(base, ld)
	rec := new(enr.Record)
	rec.Set(enr.IP(ip))
	rec.Set(enr.UDP(30303))
	return enode.SignNull(rec, id)
}
// nodesAtDistance returns n nodes, each with enode.LogDist(base, node.ID()) == ld.
// The i-th node gets the LAN address intIP(i).
func nodesAtDistance(base enode.ID, ld int, n int) []*enode.Node {
	out := make([]*enode.Node, 0, n)
	for i := 0; i < n; i++ {
		out = append(out, nodeAtDistance(base, ld, intIP(i)))
	}
	return out
}
// nodesToRecords returns the ENR record of every node, in the same order.
func nodesToRecords(nodes []*enode.Node) []*enr.Record {
	out := make([]*enr.Record, 0, len(nodes))
	for _, n := range nodes {
		out = append(out, n.Record())
	}
	return out
}
// idAtDistance returns a random ID b such that enode.LogDist(a, b) == n.
//
// For n == 0 it returns a unchanged. Otherwise b is built from a as follows,
// counting bits from 1 at the least-significant end of the ID:
//   - every bit above bit n is copied from a;
//   - bit n is flipped;
//   - every bit below bit n is chosen uniformly at random.
//
// n must be in [0, 8*len(a)]. Other values cause a runtime panic.
func idAtDistance(a enode.ID, n int) (b enode.ID) {
	if n == 0 {
		return a
	}
	b = a
	// Zero-based index of the bit to flip, counted from the least-significant
	// bit of the last byte. This avoids the byte-wraparound shift trick.
	bitIndex := n - 1
	pos := len(a) - 1 - bitIndex/8
	bit := byte(1) << (bitIndex % 8)

	// Flip the target bit and randomize the lower bits of the same byte.
	// This resolves the former TODO: previously those bits were copied from a,
	// so small distances always produced the same ID.
	low := bit - 1
	b[pos] = (a[pos]^bit)&^low | byte(rand.Intn(256))&low

	// Randomize every less-significant byte. Intn(256) covers the full byte
	// range; the previous Intn(255) could never produce 0xFF.
	for i := pos + 1; i < len(b); i++ {
		b[i] = byte(rand.Intn(256))
	}
	return b
}
// intIP maps i onto the 10.0.0.0/16 LAN range. Bits 8-15 of i become the
// third octet and bits 0-7 the fourth; higher bits are discarded.
func intIP(i int) net.IP {
	hi := byte(i >> 8)
	lo := byte(i)
	return net.IP{10, 0, hi, lo}
}
// fillBucket adds freshly generated nodes to the bucket that holds id until
// that bucket contains bucketSize entries. It then returns the last entry.
// It panics if the table refuses a node.
//
// NOTE(review): every generated node uses the same address intIP(dist). This
// presumably relies on the table not enforcing per-IP limits for these LAN
// addresses; confirm this against the table's IP-limit logic.
func fillBucket(tab *Table, id enode.ID) (last *tableNode) {
	dist := enode.LogDist(tab.self().ID(), id)
	target := tab.bucket(id)
	for len(target.entries) < bucketSize {
		candidate := nodeAtDistance(tab.self().ID(), dist, intIP(dist))
		if added := tab.addFoundNode(candidate, false); !added {
			panic("node not added")
		}
	}
	last = target.entries[bucketSize-1]
	return last
}
// fillTable offers each of nodes to the table. Each one is appended to the end
// of its corresponding bucket if that bucket is not full. The setLive flag is
// passed through to addFoundNode, and its boolean result is ignored.
// The caller must not hold tab.mutex.
func fillTable(tab *Table, nodes []*enode.Node, setLive bool) {
	for i := range nodes {
		_ = tab.addFoundNode(nodes[i], setLive)
	}
}
// pingRecorder is a stub transport for table tests. It records every node it
// is asked to ping, and answers ping and RequestENR from in-memory state.
type pingRecorder struct {
	// mu is held by this file's methods whenever they touch dead, records or pinged.
	mu sync.Mutex
	// cond is bound to mu (see newPingRecorder). ping broadcasts on it after
	// recording a node, and waitPing waits on it.
	cond *sync.Cond
	// dead holds the IDs for which ping and RequestENR fail with errTimeout.
	dead map[enode.ID]bool
	// records holds records set via updateRecord. ping reports their Seq and
	// RequestENR returns them.
	records map[enode.ID]*enode.Node
	// pinged is the queue of pinged nodes, appended by ping and consumed
	// oldest-first by waitPing.
	pinged []*enode.Node
	// n is a null-signed node created by newPingRecorder.
	// NOTE(review): nothing in this file reads it (Self returns nullNode);
	// confirm whether other test files use it.
	n *enode.Node
}
// newPingRecorder returns an empty pingRecorder with initialized maps and a
// condition variable bound to its own mutex. Its n field holds a null-signed
// node with the zero ID and IP 0.0.0.0.
func newPingRecorder() *pingRecorder {
	rec := new(pingRecorder)
	rec.dead = make(map[enode.ID]bool)
	rec.records = make(map[enode.ID]*enode.Node)
	rec.cond = sync.NewCond(&rec.mu)

	var blank enr.Record
	blank.Set(enr.IP{0, 0, 0, 0})
	rec.n = enode.SignNull(&blank, enode.ID{})
	return rec
}
// updateRecord stores n as the current record for its ID. Subsequent ping
// calls report n's sequence number, and RequestENR returns n.
func (t *pingRecorder) updateRecord(n *enode.Node) {
	id := n.ID()
	t.mu.Lock()
	defer t.mu.Unlock()
	t.records[id] = n
}
// The methods below are stubs that exist only to satisfy the transport interface.

// Self returns the package-wide placeholder node.
func (t *pingRecorder) Self() *enode.Node {
	return nullNode
}

// lookupSelf always returns nil.
func (t *pingRecorder) lookupSelf() []*enode.Node {
	return nil
}

// lookupRandom always returns nil.
func (t *pingRecorder) lookupRandom() []*enode.Node {
	return nil
}
// waitPing blocks until a ping has been recorded or timeout elapses. It removes
// and returns the oldest recorded pinged node, or returns nil if the timeout
// fires first.
func (t *pingRecorder) waitPing(timeout time.Duration) *enode.Node {
	t.mu.Lock()
	defer t.mu.Unlock()

	// Wake up the loop on timeout. The flag is set and the broadcast sent while
	// holding t.mu. Without the lock, the timer could fire between the
	// timedout check below and t.cond.Wait(). That wakeup would be lost, and
	// this goroutine would stay blocked past the timeout until an unrelated
	// ping arrived.
	var timedout atomic.Bool
	timer := time.AfterFunc(timeout, func() {
		t.mu.Lock()
		defer t.mu.Unlock()
		timedout.Store(true)
		t.cond.Broadcast()
	})
	defer timer.Stop()

	// Wait for a ping.
	for {
		if timedout.Load() {
			return nil
		}
		if len(t.pinged) > 0 {
			n := t.pinged[0]
			// slices.Delete shifts the queue in place and (Go 1.22+) clears the
			// vacated tail slot, so the dequeued node is not retained.
			t.pinged = slices.Delete(t.pinged, 0, 1)
			return n
		}
		t.cond.Wait()
	}
}
// ping simulates a ping request. Every call, successful or not, appends n to
// the pinged queue and wakes any waitPing callers. Nodes marked dead fail with
// errTimeout. Otherwise the reply carries the Seq of the record stored via
// updateRecord, or 0 if none is stored.
func (t *pingRecorder) ping(n *enode.Node) (seq uint64, err error) {
	t.mu.Lock()
	defer t.mu.Unlock()

	t.pinged = append(t.pinged, n)
	t.cond.Broadcast()

	id := n.ID()
	switch rec := t.records[id]; {
	case t.dead[id]:
		return 0, errTimeout
	case rec != nil:
		return rec.Seq(), nil
	default:
		return 0, nil
	}
}
// RequestENR simulates an ENR request. It returns the record stored via
// updateRecord. It fails with errTimeout when the node is marked dead or no
// record is stored for it.
func (t *pingRecorder) RequestENR(n *enode.Node) (*enode.Node, error) {
	t.mu.Lock()
	defer t.mu.Unlock()

	id := n.ID()
	rec := t.records[id]
	if t.dead[id] || rec == nil {
		return nil, errTimeout
	}
	return rec, nil
}
// hasDuplicates reports whether any two entries of slice share the same ID.
// It panics on a nil entry found before a duplicate, naming that entry's index.
func hasDuplicates(slice []*enode.Node) bool {
	seen := make(map[enode.ID]struct{}, len(slice))
	for idx, node := range slice {
		if node == nil {
			panic(fmt.Sprintf("nil *Node at %d", idx))
		}
		id := node.ID()
		if _, dup := seen[id]; dup {
			return true
		}
		seen[id] = struct{}{}
	}
	return false
}
// checkNodesEqual checks whether the two given node lists contain the same
// nodes in the same order, comparing element-wise with nodeEqual. It returns
// nil when they match. Otherwise it returns an error whose message lists the
// contents of both lists.
//
// Lists of different lengths never match. Previously the `return nil` sat
// outside the length check, so a length mismatch was reported as equal.
func checkNodesEqual(got, want []*enode.Node) error {
	if slices.EqualFunc(got, want, nodeEqual) {
		return nil
	}

	output := new(bytes.Buffer)
	fmt.Fprintf(output, "got %d nodes:\n", len(got))
	for _, n := range got {
		fmt.Fprintf(output, " %v %v\n", n.ID(), n)
	}
	fmt.Fprintf(output, "want %d:\n", len(want))
	for _, n := range want {
		fmt.Fprintf(output, " %v %v\n", n.ID(), n)
	}
	return errors.New(output.String())
}
// nodeEqual reports whether two nodes have the same ID and the same IP
// address. The IP comparison is skipped when the IDs already differ.
func nodeEqual(n1 *enode.Node, n2 *enode.Node) bool {
	if n1.ID() != n2.ID() {
		return false
	}
	return n1.IPAddr() == n2.IPAddr()
}
// sortByID sorts nodes in place, in ascending lexicographic order of their ID bytes.
func sortByID[N nodeType](nodes []N) {
	byID := func(x, y N) int {
		return bytes.Compare(x.ID().Bytes(), y.ID().Bytes())
	}
	slices.SortFunc(nodes, byID)
}
// sortedByDistanceTo reports whether slice is ordered by non-decreasing
// distance from distbase, as defined by enode.DistCmp.
func sortedByDistanceTo(distbase enode.ID, slice []*enode.Node) bool {
	byDist := func(x, y *enode.Node) int {
		return enode.DistCmp(distbase, x.ID(), y.ID())
	}
	return slices.IsSortedFunc(slice, byDist)
}
// hexEncPrivkey decodes the hex string h into an ECDSA private key using
// crypto.ToECDSA. It panics with the underlying error if either the hex
// decoding or the key conversion fails.
func hexEncPrivkey(h string) *ecdsa.PrivateKey {
	must := func(err error) {
		if err != nil {
			panic(err)
		}
	}
	raw, err := hex.DecodeString(h)
	must(err)
	key, err := crypto.ToECDSA(raw)
	must(err)
	return key
}
// hexEncPubkey decodes the hex string h into a v4wire.Pubkey. It panics if h is
// not valid hex or does not decode to exactly len(v4wire.Pubkey) bytes.
func hexEncPubkey(h string) (ret v4wire.Pubkey) {
	raw, err := hex.DecodeString(h)
	switch {
	case err != nil:
		panic(err)
	case len(raw) != len(ret):
		panic("invalid length")
	}
	copy(ret[:], raw)
	return ret
}
// nodeEventRecorder captures node added/removed notifications on a buffered
// channel so that tests can wait for them. Its nodeAdded and nodeRemoved
// methods panic, rather than block, when the buffer is full.
type nodeEventRecorder struct {
	evc chan recordedNodeEvent // event queue, created with a buffer by newNodeEventRecorder
}
// recordedNodeEvent is one notification captured by nodeEventRecorder.
// It is built with positional literals, so the field order matters.
type recordedNodeEvent struct {
	node  *tableNode // node the notification refers to
	added bool       // true for nodeAdded, false for nodeRemoved
}
// newNodeEventRecorder returns a recorder whose event channel can hold up to
// buffer undelivered events.
func newNodeEventRecorder(buffer int) *nodeEventRecorder {
	rec := new(nodeEventRecorder)
	rec.evc = make(chan recordedNodeEvent, buffer)
	return rec
}
// nodeAdded queues an "added" event for n without blocking. It panics if the
// event buffer is full. The bucket argument is ignored.
func (set *nodeEventRecorder) nodeAdded(b *bucket, n *tableNode) {
	ev := recordedNodeEvent{node: n, added: true}
	select {
	case set.evc <- ev:
	default:
		panic("no space in event buffer")
	}
}
// nodeRemoved queues a "removed" event for n without blocking. It panics if the
// event buffer is full. The bucket argument is ignored.
func (set *nodeEventRecorder) nodeRemoved(b *bucket, n *tableNode) {
	ev := recordedNodeEvent{node: n, added: false}
	select {
	case set.evc <- ev:
	default:
		panic("no space in event buffer")
	}
}
// waitNodePresent waits up to timeout for an "added" event for id and
// reports whether one arrived.
func (set *nodeEventRecorder) waitNodePresent(id enode.ID, timeout time.Duration) bool {
	const wantAdded = true
	return set.waitNodeEvent(id, timeout, wantAdded)
}
// waitNodeAbsent waits up to timeout for a "removed" event for id and
// reports whether one arrived.
func (set *nodeEventRecorder) waitNodeAbsent(id enode.ID, timeout time.Duration) bool {
	const wantAdded = false
	return set.waitNodeEvent(id, timeout, wantAdded)
}
// waitNodeEvent consumes events until it sees one whose node ID equals id and
// whose added flag equals added, and then returns true. It returns false if
// timeout elapses first. Events that do not match are discarded.
func (set *nodeEventRecorder) waitNodeEvent(id enode.ID, timeout time.Duration, added bool) bool {
	deadline := time.NewTimer(timeout)
	defer deadline.Stop()
	for {
		var ev recordedNodeEvent
		select {
		case <-deadline.C:
			return false
		case ev = <-set.evc:
		}
		if ev.node.ID() == id && ev.added == added {
			return true
		}
	}
}