p2p, p2p/discover: add signed ENR generation (#17753)
This PR adds enode.LocalNode and integrates it into the p2p subsystem. This new object is the keeper of the local node record. For now, a new version of the record is produced every time the client restarts. We'll make it smarter to avoid that in the future. There are a couple of other changes in this commit: discovery now waits for all of its goroutines at shutdown and the p2p server now closes the node database after discovery has shut down. This fixes a leveldb crash in tests. p2p server startup is faster because it doesn't need to wait for the external IP query anymore.
This commit is contained in:
parent
dcae0d348b
commit
6f607de5d5
|
@ -119,16 +119,17 @@ func main() {
|
|||
}
|
||||
|
||||
if *runv5 {
|
||||
if _, err := discv5.ListenUDP(nodeKey, conn, realaddr, "", restrictList); err != nil {
|
||||
if _, err := discv5.ListenUDP(nodeKey, conn, "", restrictList); err != nil {
|
||||
utils.Fatalf("%v", err)
|
||||
}
|
||||
} else {
|
||||
db, _ := enode.OpenDB("")
|
||||
ln := enode.NewLocalNode(db, nodeKey)
|
||||
cfg := discover.Config{
|
||||
PrivateKey: nodeKey,
|
||||
AnnounceAddr: realaddr,
|
||||
NetRestrict: restrictList,
|
||||
}
|
||||
if _, err := discover.ListenUDP(conn, cfg); err != nil {
|
||||
if _, err := discover.ListenUDP(conn, ln, cfg); err != nil {
|
||||
utils.Fatalf("%v", err)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -454,9 +454,9 @@ func TestProtocolGather(t *testing.T) {
|
|||
Count int
|
||||
Maker InstrumentingWrapper
|
||||
}{
|
||||
"Zero Protocols": {0, InstrumentedServiceMakerA},
|
||||
"Single Protocol": {1, InstrumentedServiceMakerB},
|
||||
"Many Protocols": {25, InstrumentedServiceMakerC},
|
||||
"zero": {0, InstrumentedServiceMakerA},
|
||||
"one": {1, InstrumentedServiceMakerB},
|
||||
"many": {10, InstrumentedServiceMakerC},
|
||||
}
|
||||
for id, config := range services {
|
||||
protocols := make([]p2p.Protocol, config.Count)
|
||||
|
@ -480,7 +480,7 @@ func TestProtocolGather(t *testing.T) {
|
|||
defer stack.Stop()
|
||||
|
||||
protocols := stack.Server().Protocols
|
||||
if len(protocols) != 26 {
|
||||
if len(protocols) != 11 {
|
||||
t.Fatalf("mismatching number of protocols launched: have %d, want %d", len(protocols), 26)
|
||||
}
|
||||
for id, config := range services {
|
||||
|
|
|
@ -71,6 +71,7 @@ type dialstate struct {
|
|||
maxDynDials int
|
||||
ntab discoverTable
|
||||
netrestrict *netutil.Netlist
|
||||
self enode.ID
|
||||
|
||||
lookupRunning bool
|
||||
dialing map[enode.ID]connFlag
|
||||
|
@ -84,7 +85,6 @@ type dialstate struct {
|
|||
}
|
||||
|
||||
type discoverTable interface {
|
||||
Self() *enode.Node
|
||||
Close()
|
||||
Resolve(*enode.Node) *enode.Node
|
||||
LookupRandom() []*enode.Node
|
||||
|
@ -126,10 +126,11 @@ type waitExpireTask struct {
|
|||
time.Duration
|
||||
}
|
||||
|
||||
func newDialState(static []*enode.Node, bootnodes []*enode.Node, ntab discoverTable, maxdyn int, netrestrict *netutil.Netlist) *dialstate {
|
||||
func newDialState(self enode.ID, static []*enode.Node, bootnodes []*enode.Node, ntab discoverTable, maxdyn int, netrestrict *netutil.Netlist) *dialstate {
|
||||
s := &dialstate{
|
||||
maxDynDials: maxdyn,
|
||||
ntab: ntab,
|
||||
self: self,
|
||||
netrestrict: netrestrict,
|
||||
static: make(map[enode.ID]*dialTask),
|
||||
dialing: make(map[enode.ID]connFlag),
|
||||
|
@ -266,7 +267,7 @@ func (s *dialstate) checkDial(n *enode.Node, peers map[enode.ID]*Peer) error {
|
|||
return errAlreadyDialing
|
||||
case peers[n.ID()] != nil:
|
||||
return errAlreadyConnected
|
||||
case s.ntab != nil && n.ID() == s.ntab.Self().ID():
|
||||
case n.ID() == s.self:
|
||||
return errSelf
|
||||
case s.netrestrict != nil && !s.netrestrict.Contains(n.IP()):
|
||||
return errNotWhitelisted
|
||||
|
|
|
@ -89,7 +89,7 @@ func (t fakeTable) ReadRandomNodes(buf []*enode.Node) int { return copy(buf, t)
|
|||
// This test checks that dynamic dials are launched from discovery results.
|
||||
func TestDialStateDynDial(t *testing.T) {
|
||||
runDialTest(t, dialtest{
|
||||
init: newDialState(nil, nil, fakeTable{}, 5, nil),
|
||||
init: newDialState(enode.ID{}, nil, nil, fakeTable{}, 5, nil),
|
||||
rounds: []round{
|
||||
// A discovery query is launched.
|
||||
{
|
||||
|
@ -236,7 +236,7 @@ func TestDialStateDynDialBootnode(t *testing.T) {
|
|||
newNode(uintID(8), nil),
|
||||
}
|
||||
runDialTest(t, dialtest{
|
||||
init: newDialState(nil, bootnodes, table, 5, nil),
|
||||
init: newDialState(enode.ID{}, nil, bootnodes, table, 5, nil),
|
||||
rounds: []round{
|
||||
// 2 dynamic dials attempted, bootnodes pending fallback interval
|
||||
{
|
||||
|
@ -324,7 +324,7 @@ func TestDialStateDynDialFromTable(t *testing.T) {
|
|||
}
|
||||
|
||||
runDialTest(t, dialtest{
|
||||
init: newDialState(nil, nil, table, 10, nil),
|
||||
init: newDialState(enode.ID{}, nil, nil, table, 10, nil),
|
||||
rounds: []round{
|
||||
// 5 out of 8 of the nodes returned by ReadRandomNodes are dialed.
|
||||
{
|
||||
|
@ -430,7 +430,7 @@ func TestDialStateNetRestrict(t *testing.T) {
|
|||
restrict.Add("127.0.2.0/24")
|
||||
|
||||
runDialTest(t, dialtest{
|
||||
init: newDialState(nil, nil, table, 10, restrict),
|
||||
init: newDialState(enode.ID{}, nil, nil, table, 10, restrict),
|
||||
rounds: []round{
|
||||
{
|
||||
new: []task{
|
||||
|
@ -453,7 +453,7 @@ func TestDialStateStaticDial(t *testing.T) {
|
|||
}
|
||||
|
||||
runDialTest(t, dialtest{
|
||||
init: newDialState(wantStatic, nil, fakeTable{}, 0, nil),
|
||||
init: newDialState(enode.ID{}, wantStatic, nil, fakeTable{}, 0, nil),
|
||||
rounds: []round{
|
||||
// Static dials are launched for the nodes that
|
||||
// aren't yet connected.
|
||||
|
@ -557,7 +557,7 @@ func TestDialStaticAfterReset(t *testing.T) {
|
|||
},
|
||||
}
|
||||
dTest := dialtest{
|
||||
init: newDialState(wantStatic, nil, fakeTable{}, 0, nil),
|
||||
init: newDialState(enode.ID{}, wantStatic, nil, fakeTable{}, 0, nil),
|
||||
rounds: rounds,
|
||||
}
|
||||
runDialTest(t, dTest)
|
||||
|
@ -578,7 +578,7 @@ func TestDialStateCache(t *testing.T) {
|
|||
}
|
||||
|
||||
runDialTest(t, dialtest{
|
||||
init: newDialState(wantStatic, nil, fakeTable{}, 0, nil),
|
||||
init: newDialState(enode.ID{}, wantStatic, nil, fakeTable{}, 0, nil),
|
||||
rounds: []round{
|
||||
// Static dials are launched for the nodes that
|
||||
// aren't yet connected.
|
||||
|
@ -640,7 +640,7 @@ func TestDialStateCache(t *testing.T) {
|
|||
func TestDialResolve(t *testing.T) {
|
||||
resolved := newNode(uintID(1), net.IP{127, 0, 55, 234})
|
||||
table := &resolveMock{answer: resolved}
|
||||
state := newDialState(nil, nil, table, 0, nil)
|
||||
state := newDialState(enode.ID{}, nil, nil, table, 0, nil)
|
||||
|
||||
// Check that the task is generated with an incomplete ID.
|
||||
dest := newNode(uintID(1), nil)
|
||||
|
|
|
@ -72,21 +72,20 @@ type Table struct {
|
|||
ips netutil.DistinctNetSet
|
||||
|
||||
db *enode.DB // database of known nodes
|
||||
net transport
|
||||
refreshReq chan chan struct{}
|
||||
initDone chan struct{}
|
||||
closeReq chan struct{}
|
||||
closed chan struct{}
|
||||
|
||||
nodeAddedHook func(*node) // for testing
|
||||
|
||||
net transport
|
||||
self *node // metadata of the local node
|
||||
}
|
||||
|
||||
// transport is implemented by the UDP transport.
|
||||
// it is an interface so we can test without opening lots of UDP
|
||||
// sockets and without generating a private key.
|
||||
type transport interface {
|
||||
self() *enode.Node
|
||||
ping(enode.ID, *net.UDPAddr) error
|
||||
findnode(toid enode.ID, addr *net.UDPAddr, target encPubkey) ([]*node, error)
|
||||
close()
|
||||
|
@ -100,11 +99,10 @@ type bucket struct {
|
|||
ips netutil.DistinctNetSet
|
||||
}
|
||||
|
||||
func newTable(t transport, self *enode.Node, db *enode.DB, bootnodes []*enode.Node) (*Table, error) {
|
||||
func newTable(t transport, db *enode.DB, bootnodes []*enode.Node) (*Table, error) {
|
||||
tab := &Table{
|
||||
net: t,
|
||||
db: db,
|
||||
self: wrapNode(self),
|
||||
refreshReq: make(chan chan struct{}),
|
||||
initDone: make(chan struct{}),
|
||||
closeReq: make(chan struct{}),
|
||||
|
@ -127,6 +125,10 @@ func newTable(t transport, self *enode.Node, db *enode.DB, bootnodes []*enode.No
|
|||
return tab, nil
|
||||
}
|
||||
|
||||
func (tab *Table) self() *enode.Node {
|
||||
return tab.net.self()
|
||||
}
|
||||
|
||||
func (tab *Table) seedRand() {
|
||||
var b [8]byte
|
||||
crand.Read(b[:])
|
||||
|
@ -136,11 +138,6 @@ func (tab *Table) seedRand() {
|
|||
tab.mutex.Unlock()
|
||||
}
|
||||
|
||||
// Self returns the local node.
|
||||
func (tab *Table) Self() *enode.Node {
|
||||
return unwrapNode(tab.self)
|
||||
}
|
||||
|
||||
// ReadRandomNodes fills the given slice with random nodes from the table. The results
|
||||
// are guaranteed to be unique for a single invocation, no node will appear twice.
|
||||
func (tab *Table) ReadRandomNodes(buf []*enode.Node) (n int) {
|
||||
|
@ -183,6 +180,10 @@ func (tab *Table) ReadRandomNodes(buf []*enode.Node) (n int) {
|
|||
|
||||
// Close terminates the network listener and flushes the node database.
|
||||
func (tab *Table) Close() {
|
||||
if tab.net != nil {
|
||||
tab.net.close()
|
||||
}
|
||||
|
||||
select {
|
||||
case <-tab.closed:
|
||||
// already closed.
|
||||
|
@ -257,7 +258,7 @@ func (tab *Table) lookup(targetKey encPubkey, refreshIfEmpty bool) []*node {
|
|||
)
|
||||
// don't query further if we hit ourself.
|
||||
// unlikely to happen often in practice.
|
||||
asked[tab.self.ID()] = true
|
||||
asked[tab.self().ID()] = true
|
||||
|
||||
for {
|
||||
tab.mutex.Lock()
|
||||
|
@ -340,8 +341,8 @@ func (tab *Table) loop() {
|
|||
revalidate = time.NewTimer(tab.nextRevalidateTime())
|
||||
refresh = time.NewTicker(refreshInterval)
|
||||
copyNodes = time.NewTicker(copyNodesInterval)
|
||||
revalidateDone = make(chan struct{})
|
||||
refreshDone = make(chan struct{}) // where doRefresh reports completion
|
||||
revalidateDone chan struct{} // where doRevalidate reports completion
|
||||
waiting = []chan struct{}{tab.initDone} // holds waiting callers while doRefresh runs
|
||||
)
|
||||
defer refresh.Stop()
|
||||
|
@ -372,9 +373,11 @@ loop:
|
|||
}
|
||||
waiting, refreshDone = nil, nil
|
||||
case <-revalidate.C:
|
||||
revalidateDone = make(chan struct{})
|
||||
go tab.doRevalidate(revalidateDone)
|
||||
case <-revalidateDone:
|
||||
revalidate.Reset(tab.nextRevalidateTime())
|
||||
revalidateDone = nil
|
||||
case <-copyNodes.C:
|
||||
go tab.copyLiveNodes()
|
||||
case <-tab.closeReq:
|
||||
|
@ -382,15 +385,15 @@ loop:
|
|||
}
|
||||
}
|
||||
|
||||
if tab.net != nil {
|
||||
tab.net.close()
|
||||
}
|
||||
if refreshDone != nil {
|
||||
<-refreshDone
|
||||
}
|
||||
for _, ch := range waiting {
|
||||
close(ch)
|
||||
}
|
||||
if revalidateDone != nil {
|
||||
<-revalidateDone
|
||||
}
|
||||
close(tab.closed)
|
||||
}
|
||||
|
||||
|
@ -408,7 +411,7 @@ func (tab *Table) doRefresh(done chan struct{}) {
|
|||
// Run self lookup to discover new neighbor nodes.
|
||||
// We can only do this if we have a secp256k1 identity.
|
||||
var key ecdsa.PublicKey
|
||||
if err := tab.self.Load((*enode.Secp256k1)(&key)); err == nil {
|
||||
if err := tab.self().Load((*enode.Secp256k1)(&key)); err == nil {
|
||||
tab.lookup(encodePubkey(&key), false)
|
||||
}
|
||||
|
||||
|
@ -530,7 +533,7 @@ func (tab *Table) len() (n int) {
|
|||
|
||||
// bucket returns the bucket for the given node ID hash.
|
||||
func (tab *Table) bucket(id enode.ID) *bucket {
|
||||
d := enode.LogDist(tab.self.ID(), id)
|
||||
d := enode.LogDist(tab.self().ID(), id)
|
||||
if d <= bucketMinDistance {
|
||||
return tab.buckets[0]
|
||||
}
|
||||
|
@ -543,7 +546,7 @@ func (tab *Table) bucket(id enode.ID) *bucket {
|
|||
//
|
||||
// The caller must not hold tab.mutex.
|
||||
func (tab *Table) add(n *node) {
|
||||
if n.ID() == tab.self.ID() {
|
||||
if n.ID() == tab.self().ID() {
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -576,7 +579,7 @@ func (tab *Table) stuff(nodes []*node) {
|
|||
defer tab.mutex.Unlock()
|
||||
|
||||
for _, n := range nodes {
|
||||
if n.ID() == tab.self.ID() {
|
||||
if n.ID() == tab.self().ID() {
|
||||
continue // don't add self
|
||||
}
|
||||
b := tab.bucket(n.ID())
|
||||
|
|
|
@ -141,7 +141,7 @@ func TestTable_IPLimit(t *testing.T) {
|
|||
defer db.Close()
|
||||
|
||||
for i := 0; i < tableIPLimit+1; i++ {
|
||||
n := nodeAtDistance(tab.self.ID(), i, net.IP{172, 0, 1, byte(i)})
|
||||
n := nodeAtDistance(tab.self().ID(), i, net.IP{172, 0, 1, byte(i)})
|
||||
tab.add(n)
|
||||
}
|
||||
if tab.len() > tableIPLimit {
|
||||
|
@ -158,7 +158,7 @@ func TestTable_BucketIPLimit(t *testing.T) {
|
|||
|
||||
d := 3
|
||||
for i := 0; i < bucketIPLimit+1; i++ {
|
||||
n := nodeAtDistance(tab.self.ID(), d, net.IP{172, 0, 1, byte(i)})
|
||||
n := nodeAtDistance(tab.self().ID(), d, net.IP{172, 0, 1, byte(i)})
|
||||
tab.add(n)
|
||||
}
|
||||
if tab.len() > bucketIPLimit {
|
||||
|
@ -240,7 +240,7 @@ func TestTable_ReadRandomNodesGetAll(t *testing.T) {
|
|||
|
||||
for i := 0; i < len(buf); i++ {
|
||||
ld := cfg.Rand.Intn(len(tab.buckets))
|
||||
tab.stuff([]*node{nodeAtDistance(tab.self.ID(), ld, intIP(ld))})
|
||||
tab.stuff([]*node{nodeAtDistance(tab.self().ID(), ld, intIP(ld))})
|
||||
}
|
||||
gotN := tab.ReadRandomNodes(buf)
|
||||
if gotN != tab.len() {
|
||||
|
@ -510,6 +510,10 @@ type preminedTestnet struct {
|
|||
dists [hashBits + 1][]encPubkey
|
||||
}
|
||||
|
||||
func (tn *preminedTestnet) self() *enode.Node {
|
||||
return nullNode
|
||||
}
|
||||
|
||||
func (tn *preminedTestnet) findnode(toid enode.ID, toaddr *net.UDPAddr, target encPubkey) ([]*node, error) {
|
||||
// current log distance is encoded in port number
|
||||
// fmt.Println("findnode query at dist", toaddr.Port)
|
||||
|
|
|
@ -28,12 +28,17 @@ import (
|
|||
"github.com/ethereum/go-ethereum/p2p/enr"
|
||||
)
|
||||
|
||||
func newTestTable(t transport) (*Table, *enode.DB) {
|
||||
var nullNode *enode.Node
|
||||
|
||||
func init() {
|
||||
var r enr.Record
|
||||
r.Set(enr.IP{0, 0, 0, 0})
|
||||
n := enode.SignNull(&r, enode.ID{})
|
||||
nullNode = enode.SignNull(&r, enode.ID{})
|
||||
}
|
||||
|
||||
func newTestTable(t transport) (*Table, *enode.DB) {
|
||||
db, _ := enode.OpenDB("")
|
||||
tab, _ := newTable(t, n, db, nil)
|
||||
tab, _ := newTable(t, db, nil)
|
||||
return tab, db
|
||||
}
|
||||
|
||||
|
@ -70,10 +75,10 @@ func intIP(i int) net.IP {
|
|||
|
||||
// fillBucket inserts nodes into the given bucket until it is full.
|
||||
func fillBucket(tab *Table, n *node) (last *node) {
|
||||
ld := enode.LogDist(tab.self.ID(), n.ID())
|
||||
ld := enode.LogDist(tab.self().ID(), n.ID())
|
||||
b := tab.bucket(n.ID())
|
||||
for len(b.entries) < bucketSize {
|
||||
b.entries = append(b.entries, nodeAtDistance(tab.self.ID(), ld, intIP(ld)))
|
||||
b.entries = append(b.entries, nodeAtDistance(tab.self().ID(), ld, intIP(ld)))
|
||||
}
|
||||
return b.entries[bucketSize-1]
|
||||
}
|
||||
|
@ -81,15 +86,25 @@ func fillBucket(tab *Table, n *node) (last *node) {
|
|||
type pingRecorder struct {
|
||||
mu sync.Mutex
|
||||
dead, pinged map[enode.ID]bool
|
||||
n *enode.Node
|
||||
}
|
||||
|
||||
func newPingRecorder() *pingRecorder {
|
||||
var r enr.Record
|
||||
r.Set(enr.IP{0, 0, 0, 0})
|
||||
n := enode.SignNull(&r, enode.ID{})
|
||||
|
||||
return &pingRecorder{
|
||||
dead: make(map[enode.ID]bool),
|
||||
pinged: make(map[enode.ID]bool),
|
||||
n: n,
|
||||
}
|
||||
}
|
||||
|
||||
func (t *pingRecorder) self() *enode.Node {
|
||||
return nullNode
|
||||
}
|
||||
|
||||
func (t *pingRecorder) findnode(toid enode.ID, toaddr *net.UDPAddr, target encPubkey) ([]*node, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
|
|
@ -23,12 +23,12 @@ import (
|
|||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/ethereum/go-ethereum/crypto"
|
||||
"github.com/ethereum/go-ethereum/log"
|
||||
"github.com/ethereum/go-ethereum/p2p/enode"
|
||||
"github.com/ethereum/go-ethereum/p2p/nat"
|
||||
"github.com/ethereum/go-ethereum/p2p/netutil"
|
||||
"github.com/ethereum/go-ethereum/rlp"
|
||||
)
|
||||
|
@ -118,9 +118,11 @@ type (
|
|||
)
|
||||
|
||||
func makeEndpoint(addr *net.UDPAddr, tcpPort uint16) rpcEndpoint {
|
||||
ip := addr.IP.To4()
|
||||
if ip == nil {
|
||||
ip = addr.IP.To16()
|
||||
ip := net.IP{}
|
||||
if ip4 := addr.IP.To4(); ip4 != nil {
|
||||
ip = ip4
|
||||
} else if ip6 := addr.IP.To16(); ip6 != nil {
|
||||
ip = ip6
|
||||
}
|
||||
return rpcEndpoint{IP: ip, UDP: uint16(addr.Port), TCP: tcpPort}
|
||||
}
|
||||
|
@ -165,20 +167,19 @@ type conn interface {
|
|||
LocalAddr() net.Addr
|
||||
}
|
||||
|
||||
// udp implements the RPC protocol.
|
||||
// udp implements the discovery v4 UDP wire protocol.
|
||||
type udp struct {
|
||||
conn conn
|
||||
netrestrict *netutil.Netlist
|
||||
priv *ecdsa.PrivateKey
|
||||
ourEndpoint rpcEndpoint
|
||||
localNode *enode.LocalNode
|
||||
db *enode.DB
|
||||
tab *Table
|
||||
wg sync.WaitGroup
|
||||
|
||||
addpending chan *pending
|
||||
gotreply chan reply
|
||||
|
||||
closing chan struct{}
|
||||
nat nat.Interface
|
||||
|
||||
*Table
|
||||
}
|
||||
|
||||
// pending represents a pending reply.
|
||||
|
@ -230,60 +231,57 @@ type Config struct {
|
|||
PrivateKey *ecdsa.PrivateKey
|
||||
|
||||
// These settings are optional:
|
||||
AnnounceAddr *net.UDPAddr // local address announced in the DHT
|
||||
NodeDBPath string // if set, the node database is stored at this filesystem location
|
||||
NetRestrict *netutil.Netlist // network whitelist
|
||||
Bootnodes []*enode.Node // list of bootstrap nodes
|
||||
Unhandled chan<- ReadPacket // unhandled packets are sent on this channel
|
||||
}
|
||||
|
||||
// ListenUDP returns a new table that listens for UDP packets on laddr.
|
||||
func ListenUDP(c conn, cfg Config) (*Table, error) {
|
||||
tab, _, err := newUDP(c, cfg)
|
||||
func ListenUDP(c conn, ln *enode.LocalNode, cfg Config) (*Table, error) {
|
||||
tab, _, err := newUDP(c, ln, cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
log.Info("UDP listener up", "self", tab.self)
|
||||
return tab, nil
|
||||
}
|
||||
|
||||
func newUDP(c conn, cfg Config) (*Table, *udp, error) {
|
||||
realaddr := c.LocalAddr().(*net.UDPAddr)
|
||||
if cfg.AnnounceAddr != nil {
|
||||
realaddr = cfg.AnnounceAddr
|
||||
}
|
||||
self := enode.NewV4(&cfg.PrivateKey.PublicKey, realaddr.IP, realaddr.Port, realaddr.Port)
|
||||
db, err := enode.OpenDB(cfg.NodeDBPath)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
func newUDP(c conn, ln *enode.LocalNode, cfg Config) (*Table, *udp, error) {
|
||||
udp := &udp{
|
||||
conn: c,
|
||||
priv: cfg.PrivateKey,
|
||||
netrestrict: cfg.NetRestrict,
|
||||
localNode: ln,
|
||||
db: ln.Database(),
|
||||
closing: make(chan struct{}),
|
||||
gotreply: make(chan reply),
|
||||
addpending: make(chan *pending),
|
||||
}
|
||||
// TODO: separate TCP port
|
||||
udp.ourEndpoint = makeEndpoint(realaddr, uint16(realaddr.Port))
|
||||
tab, err := newTable(udp, self, db, cfg.Bootnodes)
|
||||
tab, err := newTable(udp, ln.Database(), cfg.Bootnodes)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
udp.Table = tab
|
||||
udp.tab = tab
|
||||
|
||||
udp.wg.Add(2)
|
||||
go udp.loop()
|
||||
go udp.readLoop(cfg.Unhandled)
|
||||
return udp.Table, udp, nil
|
||||
return udp.tab, udp, nil
|
||||
}
|
||||
|
||||
func (t *udp) self() *enode.Node {
|
||||
return t.localNode.Node()
|
||||
}
|
||||
|
||||
func (t *udp) close() {
|
||||
close(t.closing)
|
||||
t.conn.Close()
|
||||
t.db.Close()
|
||||
// TODO: wait for the loops to end.
|
||||
t.wg.Wait()
|
||||
}
|
||||
|
||||
func (t *udp) ourEndpoint() rpcEndpoint {
|
||||
n := t.self()
|
||||
a := &net.UDPAddr{IP: n.IP(), Port: n.UDP()}
|
||||
return makeEndpoint(a, uint16(n.TCP()))
|
||||
}
|
||||
|
||||
// ping sends a ping message to the given node and waits for a reply.
|
||||
|
@ -296,7 +294,7 @@ func (t *udp) ping(toid enode.ID, toaddr *net.UDPAddr) error {
|
|||
func (t *udp) sendPing(toid enode.ID, toaddr *net.UDPAddr, callback func()) <-chan error {
|
||||
req := &ping{
|
||||
Version: 4,
|
||||
From: t.ourEndpoint,
|
||||
From: t.ourEndpoint(),
|
||||
To: makeEndpoint(toaddr, 0), // TODO: maybe use known TCP port from DB
|
||||
Expiration: uint64(time.Now().Add(expiration).Unix()),
|
||||
}
|
||||
|
@ -313,6 +311,7 @@ func (t *udp) sendPing(toid enode.ID, toaddr *net.UDPAddr, callback func()) <-ch
|
|||
}
|
||||
return ok
|
||||
})
|
||||
t.localNode.UDPContact(toaddr)
|
||||
t.write(toaddr, req.name(), packet)
|
||||
return errc
|
||||
}
|
||||
|
@ -381,6 +380,8 @@ func (t *udp) handleReply(from enode.ID, ptype byte, req packet) bool {
|
|||
// loop runs in its own goroutine. it keeps track of
|
||||
// the refresh timer and the pending reply queue.
|
||||
func (t *udp) loop() {
|
||||
defer t.wg.Done()
|
||||
|
||||
var (
|
||||
plist = list.New()
|
||||
timeout = time.NewTimer(0)
|
||||
|
@ -542,10 +543,11 @@ func encodePacket(priv *ecdsa.PrivateKey, ptype byte, req interface{}) (packet,
|
|||
|
||||
// readLoop runs in its own goroutine. it handles incoming UDP packets.
|
||||
func (t *udp) readLoop(unhandled chan<- ReadPacket) {
|
||||
defer t.conn.Close()
|
||||
defer t.wg.Done()
|
||||
if unhandled != nil {
|
||||
defer close(unhandled)
|
||||
}
|
||||
|
||||
// Discovery packets are defined to be no larger than 1280 bytes.
|
||||
// Packets larger than this size will be cut at the end and treated
|
||||
// as invalid because their hash won't match.
|
||||
|
@ -629,10 +631,11 @@ func (req *ping) handle(t *udp, from *net.UDPAddr, fromKey encPubkey, mac []byte
|
|||
n := wrapNode(enode.NewV4(key, from.IP, int(req.From.TCP), from.Port))
|
||||
t.handleReply(n.ID(), pingPacket, req)
|
||||
if time.Since(t.db.LastPongReceived(n.ID())) > bondExpiration {
|
||||
t.sendPing(n.ID(), from, func() { t.addThroughPing(n) })
|
||||
t.sendPing(n.ID(), from, func() { t.tab.addThroughPing(n) })
|
||||
} else {
|
||||
t.addThroughPing(n)
|
||||
t.tab.addThroughPing(n)
|
||||
}
|
||||
t.localNode.UDPEndpointStatement(from, &net.UDPAddr{IP: req.To.IP, Port: int(req.To.UDP)})
|
||||
t.db.UpdateLastPingReceived(n.ID(), time.Now())
|
||||
return nil
|
||||
}
|
||||
|
@ -647,6 +650,7 @@ func (req *pong) handle(t *udp, from *net.UDPAddr, fromKey encPubkey, mac []byte
|
|||
if !t.handleReply(fromID, pongPacket, req) {
|
||||
return errUnsolicitedReply
|
||||
}
|
||||
t.localNode.UDPEndpointStatement(from, &net.UDPAddr{IP: req.To.IP, Port: int(req.To.UDP)})
|
||||
t.db.UpdateLastPongReceived(fromID, time.Now())
|
||||
return nil
|
||||
}
|
||||
|
@ -668,9 +672,9 @@ func (req *findnode) handle(t *udp, from *net.UDPAddr, fromKey encPubkey, mac []
|
|||
return errUnknownNode
|
||||
}
|
||||
target := enode.ID(crypto.Keccak256Hash(req.Target[:]))
|
||||
t.mutex.Lock()
|
||||
closest := t.closest(target, bucketSize).entries
|
||||
t.mutex.Unlock()
|
||||
t.tab.mutex.Lock()
|
||||
closest := t.tab.closest(target, bucketSize).entries
|
||||
t.tab.mutex.Unlock()
|
||||
|
||||
p := neighbors{Expiration: uint64(time.Now().Add(expiration).Unix())}
|
||||
var sent bool
|
||||
|
|
|
@ -71,7 +71,9 @@ func newUDPTest(t *testing.T) *udpTest {
|
|||
remotekey: newkey(),
|
||||
remoteaddr: &net.UDPAddr{IP: net.IP{10, 0, 1, 99}, Port: 30303},
|
||||
}
|
||||
test.table, test.udp, _ = newUDP(test.pipe, Config{PrivateKey: test.localkey})
|
||||
db, _ := enode.OpenDB("")
|
||||
ln := enode.NewLocalNode(db, test.localkey)
|
||||
test.table, test.udp, _ = newUDP(test.pipe, ln, Config{PrivateKey: test.localkey})
|
||||
// Wait for initial refresh so the table doesn't send unexpected findnode.
|
||||
<-test.table.initDone
|
||||
return test
|
||||
|
@ -355,12 +357,13 @@ func TestUDP_successfulPing(t *testing.T) {
|
|||
|
||||
// remote is unknown, the table pings back.
|
||||
hash, _ := test.waitPacketOut(func(p *ping) error {
|
||||
if !reflect.DeepEqual(p.From, test.udp.ourEndpoint) {
|
||||
t.Errorf("got ping.From %v, want %v", p.From, test.udp.ourEndpoint)
|
||||
if !reflect.DeepEqual(p.From, test.udp.ourEndpoint()) {
|
||||
t.Errorf("got ping.From %#v, want %#v", p.From, test.udp.ourEndpoint())
|
||||
}
|
||||
wantTo := rpcEndpoint{
|
||||
// The mirrored UDP address is the UDP packet sender.
|
||||
IP: test.remoteaddr.IP, UDP: uint16(test.remoteaddr.Port),
|
||||
IP: test.remoteaddr.IP,
|
||||
UDP: uint16(test.remoteaddr.Port),
|
||||
TCP: 0,
|
||||
}
|
||||
if !reflect.DeepEqual(p.To, wantTo) {
|
||||
|
|
|
@ -230,7 +230,8 @@ type udp struct {
|
|||
}
|
||||
|
||||
// ListenUDP returns a new table that listens for UDP packets on laddr.
|
||||
func ListenUDP(priv *ecdsa.PrivateKey, conn conn, realaddr *net.UDPAddr, nodeDBPath string, netrestrict *netutil.Netlist) (*Network, error) {
|
||||
func ListenUDP(priv *ecdsa.PrivateKey, conn conn, nodeDBPath string, netrestrict *netutil.Netlist) (*Network, error) {
|
||||
realaddr := conn.LocalAddr().(*net.UDPAddr)
|
||||
transport, err := listenUDP(priv, conn, realaddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
|
@ -0,0 +1,246 @@
|
|||
// Copyright 2018 The go-ethereum Authors
|
||||
// This file is part of the go-ethereum library.
|
||||
//
|
||||
// The go-ethereum library is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU Lesser General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// The go-ethereum library is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU Lesser General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU Lesser General Public License
|
||||
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
package enode
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"fmt"
|
||||
"net"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/ethereum/go-ethereum/log"
|
||||
"github.com/ethereum/go-ethereum/p2p/enr"
|
||||
"github.com/ethereum/go-ethereum/p2p/netutil"
|
||||
)
|
||||
|
||||
const (
|
||||
// IP tracker configuration
|
||||
iptrackMinStatements = 10
|
||||
iptrackWindow = 5 * time.Minute
|
||||
iptrackContactWindow = 10 * time.Minute
|
||||
)
|
||||
|
||||
// LocalNode produces the signed node record of a local node, i.e. a node run in the
|
||||
// current process. Setting ENR entries via the Set method updates the record. A new version
|
||||
// of the record is signed on demand when the Node method is called.
|
||||
type LocalNode struct {
|
||||
cur atomic.Value // holds a non-nil node pointer while the record is up-to-date.
|
||||
id ID
|
||||
key *ecdsa.PrivateKey
|
||||
db *DB
|
||||
|
||||
// everything below is protected by a lock
|
||||
mu sync.Mutex
|
||||
seq uint64
|
||||
entries map[string]enr.Entry
|
||||
udpTrack *netutil.IPTracker // predicts external UDP endpoint
|
||||
staticIP net.IP
|
||||
fallbackIP net.IP
|
||||
fallbackUDP int
|
||||
}
|
||||
|
||||
// NewLocalNode creates a local node.
|
||||
func NewLocalNode(db *DB, key *ecdsa.PrivateKey) *LocalNode {
|
||||
ln := &LocalNode{
|
||||
id: PubkeyToIDV4(&key.PublicKey),
|
||||
db: db,
|
||||
key: key,
|
||||
udpTrack: netutil.NewIPTracker(iptrackWindow, iptrackContactWindow, iptrackMinStatements),
|
||||
entries: make(map[string]enr.Entry),
|
||||
}
|
||||
ln.seq = db.localSeq(ln.id)
|
||||
ln.invalidate()
|
||||
return ln
|
||||
}
|
||||
|
||||
// Database returns the node database associated with the local node.
|
||||
func (ln *LocalNode) Database() *DB {
|
||||
return ln.db
|
||||
}
|
||||
|
||||
// Node returns the current version of the local node record.
|
||||
func (ln *LocalNode) Node() *Node {
|
||||
n := ln.cur.Load().(*Node)
|
||||
if n != nil {
|
||||
return n
|
||||
}
|
||||
// Record was invalidated, sign a new copy.
|
||||
ln.mu.Lock()
|
||||
defer ln.mu.Unlock()
|
||||
ln.sign()
|
||||
return ln.cur.Load().(*Node)
|
||||
}
|
||||
|
||||
// ID returns the local node ID.
|
||||
func (ln *LocalNode) ID() ID {
|
||||
return ln.id
|
||||
}
|
||||
|
||||
// Set puts the given entry into the local record, overwriting
|
||||
// any existing value.
|
||||
func (ln *LocalNode) Set(e enr.Entry) {
|
||||
ln.mu.Lock()
|
||||
defer ln.mu.Unlock()
|
||||
|
||||
ln.set(e)
|
||||
}
|
||||
|
||||
func (ln *LocalNode) set(e enr.Entry) {
|
||||
val, exists := ln.entries[e.ENRKey()]
|
||||
if !exists || !reflect.DeepEqual(val, e) {
|
||||
ln.entries[e.ENRKey()] = e
|
||||
ln.invalidate()
|
||||
}
|
||||
}
|
||||
|
||||
// Delete removes the given entry from the local record.
|
||||
func (ln *LocalNode) Delete(e enr.Entry) {
|
||||
ln.mu.Lock()
|
||||
defer ln.mu.Unlock()
|
||||
|
||||
ln.delete(e)
|
||||
}
|
||||
|
||||
func (ln *LocalNode) delete(e enr.Entry) {
|
||||
_, exists := ln.entries[e.ENRKey()]
|
||||
if exists {
|
||||
delete(ln.entries, e.ENRKey())
|
||||
ln.invalidate()
|
||||
}
|
||||
}
|
||||
|
||||
// SetStaticIP sets the local IP to the given one unconditionally.
|
||||
// This disables endpoint prediction.
|
||||
func (ln *LocalNode) SetStaticIP(ip net.IP) {
|
||||
ln.mu.Lock()
|
||||
defer ln.mu.Unlock()
|
||||
|
||||
ln.staticIP = ip
|
||||
ln.updateEndpoints()
|
||||
}
|
||||
|
||||
// SetFallbackIP sets the last-resort IP address. This address is used
|
||||
// if no endpoint prediction can be made and no static IP is set.
|
||||
func (ln *LocalNode) SetFallbackIP(ip net.IP) {
|
||||
ln.mu.Lock()
|
||||
defer ln.mu.Unlock()
|
||||
|
||||
ln.fallbackIP = ip
|
||||
ln.updateEndpoints()
|
||||
}
|
||||
|
||||
// SetFallbackUDP sets the last-resort UDP port. This port is used
|
||||
// if no endpoint prediction can be made.
|
||||
func (ln *LocalNode) SetFallbackUDP(port int) {
|
||||
ln.mu.Lock()
|
||||
defer ln.mu.Unlock()
|
||||
|
||||
ln.fallbackUDP = port
|
||||
ln.updateEndpoints()
|
||||
}
|
||||
|
||||
// UDPEndpointStatement should be called whenever a statement about the local node's
|
||||
// UDP endpoint is received. It feeds the local endpoint predictor.
|
||||
func (ln *LocalNode) UDPEndpointStatement(fromaddr, endpoint *net.UDPAddr) {
|
||||
ln.mu.Lock()
|
||||
defer ln.mu.Unlock()
|
||||
|
||||
ln.udpTrack.AddStatement(fromaddr.String(), endpoint.String())
|
||||
ln.updateEndpoints()
|
||||
}
|
||||
|
||||
// UDPContact should be called whenever the local node has announced itself to another node
|
||||
// via UDP. It feeds the local endpoint predictor.
|
||||
func (ln *LocalNode) UDPContact(toaddr *net.UDPAddr) {
|
||||
ln.mu.Lock()
|
||||
defer ln.mu.Unlock()
|
||||
|
||||
ln.udpTrack.AddContact(toaddr.String())
|
||||
ln.updateEndpoints()
|
||||
}
|
||||
|
||||
func (ln *LocalNode) updateEndpoints() {
|
||||
// Determine the endpoints.
|
||||
newIP := ln.fallbackIP
|
||||
newUDP := ln.fallbackUDP
|
||||
if ln.staticIP != nil {
|
||||
newIP = ln.staticIP
|
||||
} else if ip, port := predictAddr(ln.udpTrack); ip != nil {
|
||||
newIP = ip
|
||||
newUDP = port
|
||||
}
|
||||
|
||||
// Update the record.
|
||||
if newIP != nil && !newIP.IsUnspecified() {
|
||||
ln.set(enr.IP(newIP))
|
||||
if newUDP != 0 {
|
||||
ln.set(enr.UDP(newUDP))
|
||||
} else {
|
||||
ln.delete(enr.UDP(0))
|
||||
}
|
||||
} else {
|
||||
ln.delete(enr.IP{})
|
||||
}
|
||||
}
|
||||
|
||||
// predictAddr wraps IPTracker.PredictEndpoint, converting from its string-based
|
||||
// endpoint representation to IP and port types.
|
||||
func predictAddr(t *netutil.IPTracker) (net.IP, int) {
|
||||
ep := t.PredictEndpoint()
|
||||
if ep == "" {
|
||||
return nil, 0
|
||||
}
|
||||
ipString, portString, _ := net.SplitHostPort(ep)
|
||||
ip := net.ParseIP(ipString)
|
||||
port, _ := strconv.Atoi(portString)
|
||||
return ip, port
|
||||
}
|
||||
|
||||
func (ln *LocalNode) invalidate() {
|
||||
ln.cur.Store((*Node)(nil))
|
||||
}
|
||||
|
||||
func (ln *LocalNode) sign() {
|
||||
if n := ln.cur.Load().(*Node); n != nil {
|
||||
return // no changes
|
||||
}
|
||||
|
||||
var r enr.Record
|
||||
for _, e := range ln.entries {
|
||||
r.Set(e)
|
||||
}
|
||||
ln.bumpSeq()
|
||||
r.SetSeq(ln.seq)
|
||||
if err := SignV4(&r, ln.key); err != nil {
|
||||
panic(fmt.Errorf("enode: can't sign record: %v", err))
|
||||
}
|
||||
n, err := New(ValidSchemes, &r)
|
||||
if err != nil {
|
||||
panic(fmt.Errorf("enode: can't verify local record: %v", err))
|
||||
}
|
||||
ln.cur.Store(n)
|
||||
log.Info("New local node record", "seq", ln.seq, "id", n.ID(), "ip", n.IP(), "udp", n.UDP(), "tcp", n.TCP())
|
||||
}
|
||||
|
||||
func (ln *LocalNode) bumpSeq() {
|
||||
ln.seq++
|
||||
ln.db.storeLocalSeq(ln.id, ln.seq)
|
||||
}
|
|
@ -0,0 +1,76 @@
|
|||
// Copyright 2018 The go-ethereum Authors
|
||||
// This file is part of the go-ethereum library.
|
||||
//
|
||||
// The go-ethereum library is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU Lesser General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// The go-ethereum library is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU Lesser General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU Lesser General Public License
|
||||
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
package enode
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/ethereum/go-ethereum/crypto"
|
||||
"github.com/ethereum/go-ethereum/p2p/enr"
|
||||
)
|
||||
|
||||
func newLocalNodeForTesting() (*LocalNode, *DB) {
|
||||
db, _ := OpenDB("")
|
||||
key, _ := crypto.GenerateKey()
|
||||
return NewLocalNode(db, key), db
|
||||
}
|
||||
|
||||
func TestLocalNode(t *testing.T) {
|
||||
ln, db := newLocalNodeForTesting()
|
||||
defer db.Close()
|
||||
|
||||
if ln.Node().ID() != ln.ID() {
|
||||
t.Fatal("inconsistent ID")
|
||||
}
|
||||
|
||||
ln.Set(enr.WithEntry("x", uint(3)))
|
||||
var x uint
|
||||
if err := ln.Node().Load(enr.WithEntry("x", &x)); err != nil {
|
||||
t.Fatal("can't load entry 'x':", err)
|
||||
} else if x != 3 {
|
||||
t.Fatal("wrong value for entry 'x':", x)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLocalNodeSeqPersist(t *testing.T) {
|
||||
ln, db := newLocalNodeForTesting()
|
||||
defer db.Close()
|
||||
|
||||
if s := ln.Node().Seq(); s != 1 {
|
||||
t.Fatalf("wrong initial seq %d, want 1", s)
|
||||
}
|
||||
ln.Set(enr.WithEntry("x", uint(1)))
|
||||
if s := ln.Node().Seq(); s != 2 {
|
||||
t.Fatalf("wrong seq %d after set, want 2", s)
|
||||
}
|
||||
|
||||
// Create a new instance, it should reload the sequence number.
|
||||
// The number increases just after that because a new record is
|
||||
// created without the "x" entry.
|
||||
ln2 := NewLocalNode(db, ln.key)
|
||||
if s := ln2.Node().Seq(); s != 3 {
|
||||
t.Fatalf("wrong seq %d on new instance, want 3", s)
|
||||
}
|
||||
|
||||
// Create a new instance with a different node key on the same database.
|
||||
// This should reset the sequence number.
|
||||
key, _ := crypto.GenerateKey()
|
||||
ln3 := NewLocalNode(db, key)
|
||||
if s := ln3.Node().Seq(); s != 1 {
|
||||
t.Fatalf("wrong seq %d on instance with changed key, want 1", s)
|
||||
}
|
||||
}
|
|
@ -98,6 +98,13 @@ func (n *Node) Pubkey() *ecdsa.PublicKey {
|
|||
return &key
|
||||
}
|
||||
|
||||
// Record returns the node's record. The return value is a copy and may
|
||||
// be modified by the caller.
|
||||
func (n *Node) Record() *enr.Record {
|
||||
cpy := n.r
|
||||
return &cpy
|
||||
}
|
||||
|
||||
// checks whether n is a valid complete node.
|
||||
func (n *Node) ValidateComplete() error {
|
||||
if n.Incomplete() {
|
||||
|
|
|
@ -35,11 +35,24 @@ import (
|
|||
"github.com/syndtr/goleveldb/leveldb/util"
|
||||
)
|
||||
|
||||
// Keys in the node database.
|
||||
const (
|
||||
dbVersionKey = "version" // Version of the database to flush if changes
|
||||
dbItemPrefix = "n:" // Identifier to prefix node entries with
|
||||
|
||||
dbDiscoverRoot = ":discover"
|
||||
dbDiscoverSeq = dbDiscoverRoot + ":seq"
|
||||
dbDiscoverPing = dbDiscoverRoot + ":lastping"
|
||||
dbDiscoverPong = dbDiscoverRoot + ":lastpong"
|
||||
dbDiscoverFindFails = dbDiscoverRoot + ":findfail"
|
||||
dbLocalRoot = ":local"
|
||||
dbLocalSeq = dbLocalRoot + ":seq"
|
||||
)
|
||||
|
||||
var (
|
||||
nodeDBNilID = ID{} // Special node ID to use as a nil element.
|
||||
nodeDBNodeExpiration = 24 * time.Hour // Time after which an unseen node should be dropped.
|
||||
nodeDBCleanupCycle = time.Hour // Time period for running the expiration task.
|
||||
nodeDBVersion = 6
|
||||
dbNodeExpiration = 24 * time.Hour // Time after which an unseen node should be dropped.
|
||||
dbCleanupCycle = time.Hour // Time period for running the expiration task.
|
||||
dbVersion = 7
|
||||
)
|
||||
|
||||
// DB is the node database, storing previously seen nodes and any collected metadata about
|
||||
|
@ -50,17 +63,6 @@ type DB struct {
|
|||
quit chan struct{} // Channel to signal the expiring thread to stop
|
||||
}
|
||||
|
||||
// Schema layout for the node database
|
||||
var (
|
||||
nodeDBVersionKey = []byte("version") // Version of the database to flush if changes
|
||||
nodeDBItemPrefix = []byte("n:") // Identifier to prefix node entries with
|
||||
|
||||
nodeDBDiscoverRoot = ":discover"
|
||||
nodeDBDiscoverPing = nodeDBDiscoverRoot + ":lastping"
|
||||
nodeDBDiscoverPong = nodeDBDiscoverRoot + ":lastpong"
|
||||
nodeDBDiscoverFindFails = nodeDBDiscoverRoot + ":findfail"
|
||||
)
|
||||
|
||||
// OpenDB opens a node database for storing and retrieving infos about known peers in the
|
||||
// network. If no path is given an in-memory, temporary database is constructed.
|
||||
func OpenDB(path string) (*DB, error) {
|
||||
|
@ -93,13 +95,13 @@ func newPersistentDB(path string) (*DB, error) {
|
|||
// The nodes contained in the cache correspond to a certain protocol version.
|
||||
// Flush all nodes if the version doesn't match.
|
||||
currentVer := make([]byte, binary.MaxVarintLen64)
|
||||
currentVer = currentVer[:binary.PutVarint(currentVer, int64(nodeDBVersion))]
|
||||
currentVer = currentVer[:binary.PutVarint(currentVer, int64(dbVersion))]
|
||||
|
||||
blob, err := db.Get(nodeDBVersionKey, nil)
|
||||
blob, err := db.Get([]byte(dbVersionKey), nil)
|
||||
switch err {
|
||||
case leveldb.ErrNotFound:
|
||||
// Version not found (i.e. empty cache), insert it
|
||||
if err := db.Put(nodeDBVersionKey, currentVer, nil); err != nil {
|
||||
if err := db.Put([]byte(dbVersionKey), currentVer, nil); err != nil {
|
||||
db.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
@ -120,28 +122,27 @@ func newPersistentDB(path string) (*DB, error) {
|
|||
// makeKey generates the leveldb key-blob from a node id and its particular
|
||||
// field of interest.
|
||||
func makeKey(id ID, field string) []byte {
|
||||
if bytes.Equal(id[:], nodeDBNilID[:]) {
|
||||
if (id == ID{}) {
|
||||
return []byte(field)
|
||||
}
|
||||
return append(nodeDBItemPrefix, append(id[:], field...)...)
|
||||
return append([]byte(dbItemPrefix), append(id[:], field...)...)
|
||||
}
|
||||
|
||||
// splitKey tries to split a database key into a node id and a field part.
|
||||
func splitKey(key []byte) (id ID, field string) {
|
||||
// If the key is not of a node, return it plainly
|
||||
if !bytes.HasPrefix(key, nodeDBItemPrefix) {
|
||||
if !bytes.HasPrefix(key, []byte(dbItemPrefix)) {
|
||||
return ID{}, string(key)
|
||||
}
|
||||
// Otherwise split the id and field
|
||||
item := key[len(nodeDBItemPrefix):]
|
||||
item := key[len(dbItemPrefix):]
|
||||
copy(id[:], item[:len(id)])
|
||||
field = string(item[len(id):])
|
||||
|
||||
return id, field
|
||||
}
|
||||
|
||||
// fetchInt64 retrieves an integer instance associated with a particular
|
||||
// database key.
|
||||
// fetchInt64 retrieves an integer associated with a particular key.
|
||||
func (db *DB) fetchInt64(key []byte) int64 {
|
||||
blob, err := db.lvl.Get(key, nil)
|
||||
if err != nil {
|
||||
|
@ -154,18 +155,33 @@ func (db *DB) fetchInt64(key []byte) int64 {
|
|||
return val
|
||||
}
|
||||
|
||||
// storeInt64 update a specific database entry to the current time instance as a
|
||||
// unix timestamp.
|
||||
// storeInt64 stores an integer in the given key.
|
||||
func (db *DB) storeInt64(key []byte, n int64) error {
|
||||
blob := make([]byte, binary.MaxVarintLen64)
|
||||
blob = blob[:binary.PutVarint(blob, n)]
|
||||
return db.lvl.Put(key, blob, nil)
|
||||
}
|
||||
|
||||
// fetchUint64 retrieves an integer associated with a particular key.
|
||||
func (db *DB) fetchUint64(key []byte) uint64 {
|
||||
blob, err := db.lvl.Get(key, nil)
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
val, _ := binary.Uvarint(blob)
|
||||
return val
|
||||
}
|
||||
|
||||
// storeUint64 stores an integer in the given key.
|
||||
func (db *DB) storeUint64(key []byte, n uint64) error {
|
||||
blob := make([]byte, binary.MaxVarintLen64)
|
||||
blob = blob[:binary.PutUvarint(blob, n)]
|
||||
return db.lvl.Put(key, blob, nil)
|
||||
}
|
||||
|
||||
// Node retrieves a node with a given id from the database.
|
||||
func (db *DB) Node(id ID) *Node {
|
||||
blob, err := db.lvl.Get(makeKey(id, nodeDBDiscoverRoot), nil)
|
||||
blob, err := db.lvl.Get(makeKey(id, dbDiscoverRoot), nil)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
@ -184,11 +200,31 @@ func mustDecodeNode(id, data []byte) *Node {
|
|||
|
||||
// UpdateNode inserts - potentially overwriting - a node into the peer database.
|
||||
func (db *DB) UpdateNode(node *Node) error {
|
||||
if node.Seq() < db.NodeSeq(node.ID()) {
|
||||
return nil
|
||||
}
|
||||
blob, err := rlp.EncodeToBytes(&node.r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return db.lvl.Put(makeKey(node.ID(), nodeDBDiscoverRoot), blob, nil)
|
||||
if err := db.lvl.Put(makeKey(node.ID(), dbDiscoverRoot), blob, nil); err != nil {
|
||||
return err
|
||||
}
|
||||
return db.storeUint64(makeKey(node.ID(), dbDiscoverSeq), node.Seq())
|
||||
}
|
||||
|
||||
// NodeSeq returns the stored record sequence number of the given node.
|
||||
func (db *DB) NodeSeq(id ID) uint64 {
|
||||
return db.fetchUint64(makeKey(id, dbDiscoverSeq))
|
||||
}
|
||||
|
||||
// Resolve returns the stored record of the node if it has a larger sequence
|
||||
// number than n.
|
||||
func (db *DB) Resolve(n *Node) *Node {
|
||||
if n.Seq() > db.NodeSeq(n.ID()) {
|
||||
return n
|
||||
}
|
||||
return db.Node(n.ID())
|
||||
}
|
||||
|
||||
// DeleteNode deletes all information/keys associated with a node.
|
||||
|
@ -218,7 +254,7 @@ func (db *DB) ensureExpirer() {
|
|||
// expirer should be started in a go routine, and is responsible for looping ad
|
||||
// infinitum and dropping stale data from the database.
|
||||
func (db *DB) expirer() {
|
||||
tick := time.NewTicker(nodeDBCleanupCycle)
|
||||
tick := time.NewTicker(dbCleanupCycle)
|
||||
defer tick.Stop()
|
||||
for {
|
||||
select {
|
||||
|
@ -235,7 +271,7 @@ func (db *DB) expirer() {
|
|||
// expireNodes iterates over the database and deletes all nodes that have not
|
||||
// been seen (i.e. received a pong from) for some allotted time.
|
||||
func (db *DB) expireNodes() error {
|
||||
threshold := time.Now().Add(-nodeDBNodeExpiration)
|
||||
threshold := time.Now().Add(-dbNodeExpiration)
|
||||
|
||||
// Find discovered nodes that are older than the allowance
|
||||
it := db.lvl.NewIterator(nil, nil)
|
||||
|
@ -244,7 +280,7 @@ func (db *DB) expireNodes() error {
|
|||
for it.Next() {
|
||||
// Skip the item if not a discovery node
|
||||
id, field := splitKey(it.Key())
|
||||
if field != nodeDBDiscoverRoot {
|
||||
if field != dbDiscoverRoot {
|
||||
continue
|
||||
}
|
||||
// Skip the node if not expired yet (and not self)
|
||||
|
@ -260,34 +296,44 @@ func (db *DB) expireNodes() error {
|
|||
// LastPingReceived retrieves the time of the last ping packet received from
|
||||
// a remote node.
|
||||
func (db *DB) LastPingReceived(id ID) time.Time {
|
||||
return time.Unix(db.fetchInt64(makeKey(id, nodeDBDiscoverPing)), 0)
|
||||
return time.Unix(db.fetchInt64(makeKey(id, dbDiscoverPing)), 0)
|
||||
}
|
||||
|
||||
// UpdateLastPingReceived updates the last time we tried contacting a remote node.
|
||||
func (db *DB) UpdateLastPingReceived(id ID, instance time.Time) error {
|
||||
return db.storeInt64(makeKey(id, nodeDBDiscoverPing), instance.Unix())
|
||||
return db.storeInt64(makeKey(id, dbDiscoverPing), instance.Unix())
|
||||
}
|
||||
|
||||
// LastPongReceived retrieves the time of the last successful pong from remote node.
|
||||
func (db *DB) LastPongReceived(id ID) time.Time {
|
||||
// Launch expirer
|
||||
db.ensureExpirer()
|
||||
return time.Unix(db.fetchInt64(makeKey(id, nodeDBDiscoverPong)), 0)
|
||||
return time.Unix(db.fetchInt64(makeKey(id, dbDiscoverPong)), 0)
|
||||
}
|
||||
|
||||
// UpdateLastPongReceived updates the last pong time of a node.
|
||||
func (db *DB) UpdateLastPongReceived(id ID, instance time.Time) error {
|
||||
return db.storeInt64(makeKey(id, nodeDBDiscoverPong), instance.Unix())
|
||||
return db.storeInt64(makeKey(id, dbDiscoverPong), instance.Unix())
|
||||
}
|
||||
|
||||
// FindFails retrieves the number of findnode failures since bonding.
|
||||
func (db *DB) FindFails(id ID) int {
|
||||
return int(db.fetchInt64(makeKey(id, nodeDBDiscoverFindFails)))
|
||||
return int(db.fetchInt64(makeKey(id, dbDiscoverFindFails)))
|
||||
}
|
||||
|
||||
// UpdateFindFails updates the number of findnode failures since bonding.
|
||||
func (db *DB) UpdateFindFails(id ID, fails int) error {
|
||||
return db.storeInt64(makeKey(id, nodeDBDiscoverFindFails), int64(fails))
|
||||
return db.storeInt64(makeKey(id, dbDiscoverFindFails), int64(fails))
|
||||
}
|
||||
|
||||
// LocalSeq retrieves the local record sequence counter.
|
||||
func (db *DB) localSeq(id ID) uint64 {
|
||||
return db.fetchUint64(makeKey(id, dbLocalSeq))
|
||||
}
|
||||
|
||||
// storeLocalSeq stores the local record sequence counter.
|
||||
func (db *DB) storeLocalSeq(id ID, n uint64) {
|
||||
db.storeUint64(makeKey(id, dbLocalSeq), n)
|
||||
}
|
||||
|
||||
// QuerySeeds retrieves random nodes to be used as potential seed nodes
|
||||
|
@ -309,7 +355,7 @@ seek:
|
|||
ctr := id[0]
|
||||
rand.Read(id[:])
|
||||
id[0] = ctr + id[0]%16
|
||||
it.Seek(makeKey(id, nodeDBDiscoverRoot))
|
||||
it.Seek(makeKey(id, dbDiscoverRoot))
|
||||
|
||||
n := nextNode(it)
|
||||
if n == nil {
|
||||
|
@ -334,7 +380,7 @@ seek:
|
|||
func nextNode(it iterator.Iterator) *Node {
|
||||
for end := false; !end; end = !it.Next() {
|
||||
id, field := splitKey(it.Key())
|
||||
if field != nodeDBDiscoverRoot {
|
||||
if field != dbDiscoverRoot {
|
||||
continue
|
||||
}
|
||||
return mustDecodeNode(id[:], it.Value())
|
||||
|
|
|
@ -332,7 +332,7 @@ var nodeDBExpirationNodes = []struct {
|
|||
30303,
|
||||
30303,
|
||||
),
|
||||
pong: time.Now().Add(-nodeDBNodeExpiration + time.Minute),
|
||||
pong: time.Now().Add(-dbNodeExpiration + time.Minute),
|
||||
exp: false,
|
||||
}, {
|
||||
node: NewV4(
|
||||
|
@ -341,7 +341,7 @@ var nodeDBExpirationNodes = []struct {
|
|||
30303,
|
||||
30303,
|
||||
),
|
||||
pong: time.Now().Add(-nodeDBNodeExpiration - time.Minute),
|
||||
pong: time.Now().Add(-dbNodeExpiration - time.Minute),
|
||||
exp: true,
|
||||
},
|
||||
}
|
||||
|
|
|
@ -156,7 +156,7 @@ func (r *Record) Set(e Entry) {
|
|||
}
|
||||
|
||||
func (r *Record) invalidate() {
|
||||
if r.signature == nil {
|
||||
if r.signature != nil {
|
||||
r.seq++
|
||||
}
|
||||
r.signature = nil
|
||||
|
|
|
@ -169,6 +169,18 @@ func TestDirty(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestSeq(t *testing.T) {
|
||||
var r Record
|
||||
|
||||
assert.Equal(t, uint64(0), r.Seq())
|
||||
r.Set(UDP(1))
|
||||
assert.Equal(t, uint64(0), r.Seq())
|
||||
signTest([]byte{5}, &r)
|
||||
assert.Equal(t, uint64(0), r.Seq())
|
||||
r.Set(UDP(2))
|
||||
assert.Equal(t, uint64(1), r.Seq())
|
||||
}
|
||||
|
||||
// TestGetSetOverwrite tests value overwrite when setting a new value with an existing key in record.
|
||||
func TestGetSetOverwrite(t *testing.T) {
|
||||
var r Record
|
||||
|
|
|
@ -129,21 +129,15 @@ func Map(m Interface, c chan struct{}, protocol string, extport, intport int, na
|
|||
// ExtIP assumes that the local machine is reachable on the given
|
||||
// external IP address, and that any required ports were mapped manually.
|
||||
// Mapping operations will not return an error but won't actually do anything.
|
||||
func ExtIP(ip net.IP) Interface {
|
||||
if ip == nil {
|
||||
panic("IP must not be nil")
|
||||
}
|
||||
return extIP(ip)
|
||||
}
|
||||
type ExtIP net.IP
|
||||
|
||||
type extIP net.IP
|
||||
|
||||
func (n extIP) ExternalIP() (net.IP, error) { return net.IP(n), nil }
|
||||
func (n extIP) String() string { return fmt.Sprintf("ExtIP(%v)", net.IP(n)) }
|
||||
func (n ExtIP) ExternalIP() (net.IP, error) { return net.IP(n), nil }
|
||||
func (n ExtIP) String() string { return fmt.Sprintf("ExtIP(%v)", net.IP(n)) }
|
||||
|
||||
// These do nothing.
|
||||
func (extIP) AddMapping(string, int, int, string, time.Duration) error { return nil }
|
||||
func (extIP) DeleteMapping(string, int, int) error { return nil }
|
||||
|
||||
func (ExtIP) AddMapping(string, int, int, string, time.Duration) error { return nil }
|
||||
func (ExtIP) DeleteMapping(string, int, int) error { return nil }
|
||||
|
||||
// Any returns a port mapper that tries to discover any supported
|
||||
// mechanism on the local network.
|
||||
|
|
|
@ -28,7 +28,7 @@ import (
|
|||
func TestAutoDiscRace(t *testing.T) {
|
||||
ad := startautodisc("thing", func() Interface {
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
return extIP{33, 44, 55, 66}
|
||||
return ExtIP{33, 44, 55, 66}
|
||||
})
|
||||
|
||||
// Spawn a few concurrent calls to ad.ExternalIP.
|
||||
|
|
|
@ -0,0 +1,130 @@
|
|||
// Copyright 2018 The go-ethereum Authors
|
||||
// This file is part of the go-ethereum library.
|
||||
//
|
||||
// The go-ethereum library is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU Lesser General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// The go-ethereum library is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU Lesser General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU Lesser General Public License
|
||||
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
package netutil
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/ethereum/go-ethereum/common/mclock"
|
||||
)
|
||||
|
||||
// IPTracker predicts the external endpoint, i.e. IP address and port, of the local host
|
||||
// based on statements made by other hosts.
|
||||
type IPTracker struct {
|
||||
window time.Duration
|
||||
contactWindow time.Duration
|
||||
minStatements int
|
||||
clock mclock.Clock
|
||||
statements map[string]ipStatement
|
||||
contact map[string]mclock.AbsTime
|
||||
lastStatementGC mclock.AbsTime
|
||||
lastContactGC mclock.AbsTime
|
||||
}
|
||||
|
||||
type ipStatement struct {
|
||||
endpoint string
|
||||
time mclock.AbsTime
|
||||
}
|
||||
|
||||
// NewIPTracker creates an IP tracker.
|
||||
//
|
||||
// The window parameters configure the amount of past network events which are kept. The
|
||||
// minStatements parameter enforces a minimum number of statements which must be recorded
|
||||
// before any prediction is made. Higher values for these parameters decrease 'flapping' of
|
||||
// predictions as network conditions change. Window duration values should typically be in
|
||||
// the range of minutes.
|
||||
func NewIPTracker(window, contactWindow time.Duration, minStatements int) *IPTracker {
|
||||
return &IPTracker{
|
||||
window: window,
|
||||
contactWindow: contactWindow,
|
||||
statements: make(map[string]ipStatement),
|
||||
minStatements: minStatements,
|
||||
contact: make(map[string]mclock.AbsTime),
|
||||
clock: mclock.System{},
|
||||
}
|
||||
}
|
||||
|
||||
// PredictFullConeNAT checks whether the local host is behind full cone NAT. It predicts by
|
||||
// checking whether any statement has been received from a node we didn't contact before
|
||||
// the statement was made.
|
||||
func (it *IPTracker) PredictFullConeNAT() bool {
|
||||
now := it.clock.Now()
|
||||
it.gcContact(now)
|
||||
it.gcStatements(now)
|
||||
for host, st := range it.statements {
|
||||
if c, ok := it.contact[host]; !ok || c > st.time {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// PredictEndpoint returns the current prediction of the external endpoint.
|
||||
func (it *IPTracker) PredictEndpoint() string {
|
||||
it.gcStatements(it.clock.Now())
|
||||
|
||||
// The current strategy is simple: find the endpoint with most statements.
|
||||
counts := make(map[string]int)
|
||||
maxcount, max := 0, ""
|
||||
for _, s := range it.statements {
|
||||
c := counts[s.endpoint] + 1
|
||||
counts[s.endpoint] = c
|
||||
if c > maxcount && c >= it.minStatements {
|
||||
maxcount, max = c, s.endpoint
|
||||
}
|
||||
}
|
||||
return max
|
||||
}
|
||||
|
||||
// AddStatement records that a certain host thinks our external endpoint is the one given.
|
||||
func (it *IPTracker) AddStatement(host, endpoint string) {
|
||||
now := it.clock.Now()
|
||||
it.statements[host] = ipStatement{endpoint, now}
|
||||
if time.Duration(now-it.lastStatementGC) >= it.window {
|
||||
it.gcStatements(now)
|
||||
}
|
||||
}
|
||||
|
||||
// AddContact records that a packet containing our endpoint information has been sent to a
|
||||
// certain host.
|
||||
func (it *IPTracker) AddContact(host string) {
|
||||
now := it.clock.Now()
|
||||
it.contact[host] = now
|
||||
if time.Duration(now-it.lastContactGC) >= it.contactWindow {
|
||||
it.gcContact(now)
|
||||
}
|
||||
}
|
||||
|
||||
func (it *IPTracker) gcStatements(now mclock.AbsTime) {
|
||||
it.lastStatementGC = now
|
||||
cutoff := now.Add(-it.window)
|
||||
for host, s := range it.statements {
|
||||
if s.time < cutoff {
|
||||
delete(it.statements, host)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (it *IPTracker) gcContact(now mclock.AbsTime) {
|
||||
it.lastContactGC = now
|
||||
cutoff := now.Add(-it.contactWindow)
|
||||
for host, ct := range it.contact {
|
||||
if ct < cutoff {
|
||||
delete(it.contact, host)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,138 @@
|
|||
// Copyright 2018 The go-ethereum Authors
|
||||
// This file is part of the go-ethereum library.
|
||||
//
|
||||
// The go-ethereum library is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU Lesser General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// The go-ethereum library is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU Lesser General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU Lesser General Public License
|
||||
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
package netutil
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
mrand "math/rand"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ethereum/go-ethereum/common/mclock"
|
||||
)
|
||||
|
||||
const (
|
||||
opStatement = iota
|
||||
opContact
|
||||
opPredict
|
||||
opCheckFullCone
|
||||
)
|
||||
|
||||
type iptrackTestEvent struct {
|
||||
op int
|
||||
time int // absolute, in milliseconds
|
||||
ip, from string
|
||||
}
|
||||
|
||||
func TestIPTracker(t *testing.T) {
|
||||
tests := map[string][]iptrackTestEvent{
|
||||
"minStatements": {
|
||||
{opPredict, 0, "", ""},
|
||||
{opStatement, 0, "127.0.0.1", "127.0.0.2"},
|
||||
{opPredict, 1000, "", ""},
|
||||
{opStatement, 1000, "127.0.0.1", "127.0.0.3"},
|
||||
{opPredict, 1000, "", ""},
|
||||
{opStatement, 1000, "127.0.0.1", "127.0.0.4"},
|
||||
{opPredict, 1000, "127.0.0.1", ""},
|
||||
},
|
||||
"window": {
|
||||
{opStatement, 0, "127.0.0.1", "127.0.0.2"},
|
||||
{opStatement, 2000, "127.0.0.1", "127.0.0.3"},
|
||||
{opStatement, 3000, "127.0.0.1", "127.0.0.4"},
|
||||
{opPredict, 10000, "127.0.0.1", ""},
|
||||
{opPredict, 10001, "", ""}, // first statement expired
|
||||
{opStatement, 10100, "127.0.0.1", "127.0.0.2"},
|
||||
{opPredict, 10200, "127.0.0.1", ""},
|
||||
},
|
||||
"fullcone": {
|
||||
{opContact, 0, "", "127.0.0.2"},
|
||||
{opStatement, 10, "127.0.0.1", "127.0.0.2"},
|
||||
{opContact, 2000, "", "127.0.0.3"},
|
||||
{opStatement, 2010, "127.0.0.1", "127.0.0.3"},
|
||||
{opContact, 3000, "", "127.0.0.4"},
|
||||
{opStatement, 3010, "127.0.0.1", "127.0.0.4"},
|
||||
{opCheckFullCone, 3500, "false", ""},
|
||||
},
|
||||
"fullcone_2": {
|
||||
{opContact, 0, "", "127.0.0.2"},
|
||||
{opStatement, 10, "127.0.0.1", "127.0.0.2"},
|
||||
{opContact, 2000, "", "127.0.0.3"},
|
||||
{opStatement, 2010, "127.0.0.1", "127.0.0.3"},
|
||||
{opStatement, 3000, "127.0.0.1", "127.0.0.4"},
|
||||
{opContact, 3010, "", "127.0.0.4"},
|
||||
{opCheckFullCone, 3500, "true", ""},
|
||||
},
|
||||
}
|
||||
for name, test := range tests {
|
||||
t.Run(name, func(t *testing.T) { runIPTrackerTest(t, test) })
|
||||
}
|
||||
}
|
||||
|
||||
func runIPTrackerTest(t *testing.T, evs []iptrackTestEvent) {
|
||||
var (
|
||||
clock mclock.Simulated
|
||||
it = NewIPTracker(10*time.Second, 10*time.Second, 3)
|
||||
)
|
||||
it.clock = &clock
|
||||
for i, ev := range evs {
|
||||
evtime := time.Duration(ev.time) * time.Millisecond
|
||||
clock.Run(evtime - time.Duration(clock.Now()))
|
||||
switch ev.op {
|
||||
case opStatement:
|
||||
it.AddStatement(ev.from, ev.ip)
|
||||
case opContact:
|
||||
it.AddContact(ev.from)
|
||||
case opPredict:
|
||||
if pred := it.PredictEndpoint(); pred != ev.ip {
|
||||
t.Errorf("op %d: wrong prediction %q, want %q", i, pred, ev.ip)
|
||||
}
|
||||
case opCheckFullCone:
|
||||
pred := fmt.Sprintf("%t", it.PredictFullConeNAT())
|
||||
if pred != ev.ip {
|
||||
t.Errorf("op %d: wrong prediction %s, want %s", i, pred, ev.ip)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// This checks that old statements and contacts are GCed even if Predict* isn't called.
|
||||
func TestIPTrackerForceGC(t *testing.T) {
|
||||
var (
|
||||
clock mclock.Simulated
|
||||
window = 10 * time.Second
|
||||
rate = 50 * time.Millisecond
|
||||
max = int(window/rate) + 1
|
||||
it = NewIPTracker(window, window, 3)
|
||||
)
|
||||
it.clock = &clock
|
||||
|
||||
for i := 0; i < 5*max; i++ {
|
||||
e1 := make([]byte, 4)
|
||||
e2 := make([]byte, 4)
|
||||
mrand.Read(e1)
|
||||
mrand.Read(e2)
|
||||
it.AddStatement(string(e1), string(e2))
|
||||
it.AddContact(string(e1))
|
||||
clock.Run(rate)
|
||||
}
|
||||
if len(it.contact) > 2*max {
|
||||
t.Errorf("contacts not GCed, have %d", len(it.contact))
|
||||
}
|
||||
if len(it.statements) > 2*max {
|
||||
t.Errorf("statements not GCed, have %d", len(it.statements))
|
||||
}
|
||||
}
|
|
@ -20,6 +20,7 @@ import (
|
|||
"fmt"
|
||||
|
||||
"github.com/ethereum/go-ethereum/p2p/enode"
|
||||
"github.com/ethereum/go-ethereum/p2p/enr"
|
||||
)
|
||||
|
||||
// Protocol represents a P2P subprotocol implementation.
|
||||
|
@ -52,6 +53,9 @@ type Protocol struct {
|
|||
// about a certain peer in the network. If an info retrieval function is set,
|
||||
// but returns nil, it is assumed that the protocol handshake is still running.
|
||||
PeerInfo func(id enode.ID) interface{}
|
||||
|
||||
// Attributes contains protocol specific information for the node record.
|
||||
Attributes []enr.Entry
|
||||
}
|
||||
|
||||
func (p Protocol) cap() Cap {
|
||||
|
@ -64,10 +68,6 @@ type Cap struct {
|
|||
Version uint
|
||||
}
|
||||
|
||||
func (cap Cap) RlpData() interface{} {
|
||||
return []interface{}{cap.Name, cap.Version}
|
||||
}
|
||||
|
||||
func (cap Cap) String() string {
|
||||
return fmt.Sprintf("%s/%d", cap.Name, cap.Version)
|
||||
}
|
||||
|
@ -79,3 +79,5 @@ func (cs capsByNameAndVersion) Swap(i, j int) { cs[i], cs[j] = cs[j], cs[i] }
|
|||
func (cs capsByNameAndVersion) Less(i, j int) bool {
|
||||
return cs[i].Name < cs[j].Name || (cs[i].Name == cs[j].Name && cs[i].Version < cs[j].Version)
|
||||
}
|
||||
|
||||
func (capsByNameAndVersion) ENRKey() string { return "cap" }
|
||||
|
|
208
p2p/server.go
208
p2p/server.go
|
@ -20,9 +20,11 @@ package p2p
|
|||
import (
|
||||
"bytes"
|
||||
"crypto/ecdsa"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"sort"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
@ -35,8 +37,10 @@ import (
|
|||
"github.com/ethereum/go-ethereum/p2p/discover"
|
||||
"github.com/ethereum/go-ethereum/p2p/discv5"
|
||||
"github.com/ethereum/go-ethereum/p2p/enode"
|
||||
"github.com/ethereum/go-ethereum/p2p/enr"
|
||||
"github.com/ethereum/go-ethereum/p2p/nat"
|
||||
"github.com/ethereum/go-ethereum/p2p/netutil"
|
||||
"github.com/ethereum/go-ethereum/rlp"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -160,6 +164,8 @@ type Server struct {
|
|||
lock sync.Mutex // protects running
|
||||
running bool
|
||||
|
||||
nodedb *enode.DB
|
||||
localnode *enode.LocalNode
|
||||
ntab discoverTable
|
||||
listener net.Listener
|
||||
ourHandshake *protoHandshake
|
||||
|
@ -347,43 +353,13 @@ func (srv *Server) SubscribeEvents(ch chan *PeerEvent) event.Subscription {
|
|||
// Self returns the local node's endpoint information.
|
||||
func (srv *Server) Self() *enode.Node {
|
||||
srv.lock.Lock()
|
||||
running, listener, ntab := srv.running, srv.listener, srv.ntab
|
||||
ln := srv.localnode
|
||||
srv.lock.Unlock()
|
||||
|
||||
if !running {
|
||||
if ln == nil {
|
||||
return enode.NewV4(&srv.PrivateKey.PublicKey, net.ParseIP("0.0.0.0"), 0, 0)
|
||||
}
|
||||
return srv.makeSelf(listener, ntab)
|
||||
}
|
||||
|
||||
func (srv *Server) makeSelf(listener net.Listener, ntab discoverTable) *enode.Node {
|
||||
// If the node is running but discovery is off, manually assemble the node infos.
|
||||
if ntab == nil {
|
||||
addr := srv.tcpAddr(listener)
|
||||
return enode.NewV4(&srv.PrivateKey.PublicKey, addr.IP, addr.Port, 0)
|
||||
}
|
||||
// Otherwise return the discovery node.
|
||||
return ntab.Self()
|
||||
}
|
||||
|
||||
func (srv *Server) tcpAddr(listener net.Listener) net.TCPAddr {
|
||||
addr := net.TCPAddr{IP: net.IP{0, 0, 0, 0}}
|
||||
if listener == nil {
|
||||
return addr // Inbound connections disabled, use zero address.
|
||||
}
|
||||
// Otherwise inject the listener address too.
|
||||
if a, ok := listener.Addr().(*net.TCPAddr); ok {
|
||||
addr = *a
|
||||
}
|
||||
if srv.NAT != nil {
|
||||
if ip, err := srv.NAT.ExternalIP(); err == nil {
|
||||
addr.IP = ip
|
||||
}
|
||||
}
|
||||
if addr.IP.IsUnspecified() {
|
||||
addr.IP = net.IP{127, 0, 0, 1}
|
||||
}
|
||||
return addr
|
||||
return ln.Node()
|
||||
}
|
||||
|
||||
// Stop terminates the server and all active peer connections.
|
||||
|
@ -443,7 +419,9 @@ func (srv *Server) Start() (err error) {
|
|||
if srv.log == nil {
|
||||
srv.log = log.New()
|
||||
}
|
||||
srv.log.Info("Starting P2P networking")
|
||||
if srv.NoDial && srv.ListenAddr == "" {
|
||||
srv.log.Warn("P2P server will be useless, neither dialing nor listening")
|
||||
}
|
||||
|
||||
// static fields
|
||||
if srv.PrivateKey == nil {
|
||||
|
@ -466,65 +444,120 @@ func (srv *Server) Start() (err error) {
|
|||
srv.peerOp = make(chan peerOpFunc)
|
||||
srv.peerOpDone = make(chan struct{})
|
||||
|
||||
var (
|
||||
conn *net.UDPConn
|
||||
sconn *sharedUDPConn
|
||||
realaddr *net.UDPAddr
|
||||
unhandled chan discover.ReadPacket
|
||||
)
|
||||
if err := srv.setupLocalNode(); err != nil {
|
||||
return err
|
||||
}
|
||||
if srv.ListenAddr != "" {
|
||||
if err := srv.setupListening(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if err := srv.setupDiscovery(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dynPeers := srv.maxDialedConns()
|
||||
dialer := newDialState(srv.localnode.ID(), srv.StaticNodes, srv.BootstrapNodes, srv.ntab, dynPeers, srv.NetRestrict)
|
||||
srv.loopWG.Add(1)
|
||||
go srv.run(dialer)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (srv *Server) setupLocalNode() error {
|
||||
// Create the devp2p handshake.
|
||||
pubkey := crypto.FromECDSAPub(&srv.PrivateKey.PublicKey)
|
||||
srv.ourHandshake = &protoHandshake{Version: baseProtocolVersion, Name: srv.Name, ID: pubkey[1:]}
|
||||
for _, p := range srv.Protocols {
|
||||
srv.ourHandshake.Caps = append(srv.ourHandshake.Caps, p.cap())
|
||||
}
|
||||
sort.Sort(capsByNameAndVersion(srv.ourHandshake.Caps))
|
||||
|
||||
// Create the local node.
|
||||
db, err := enode.OpenDB(srv.Config.NodeDatabase)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
srv.nodedb = db
|
||||
srv.localnode = enode.NewLocalNode(db, srv.PrivateKey)
|
||||
srv.localnode.SetFallbackIP(net.IP{127, 0, 0, 1})
|
||||
srv.localnode.Set(capsByNameAndVersion(srv.ourHandshake.Caps))
|
||||
// TODO: check conflicts
|
||||
for _, p := range srv.Protocols {
|
||||
for _, e := range p.Attributes {
|
||||
srv.localnode.Set(e)
|
||||
}
|
||||
}
|
||||
switch srv.NAT.(type) {
|
||||
case nil:
|
||||
// No NAT interface, do nothing.
|
||||
case nat.ExtIP:
|
||||
// ExtIP doesn't block, set the IP right away.
|
||||
ip, _ := srv.NAT.ExternalIP()
|
||||
srv.localnode.SetStaticIP(ip)
|
||||
default:
|
||||
// Ask the router about the IP. This takes a while and blocks startup,
|
||||
// do it in the background.
|
||||
srv.loopWG.Add(1)
|
||||
go func() {
|
||||
defer srv.loopWG.Done()
|
||||
if ip, err := srv.NAT.ExternalIP(); err == nil {
|
||||
srv.localnode.SetStaticIP(ip)
|
||||
}
|
||||
}()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (srv *Server) setupDiscovery() error {
|
||||
if srv.NoDiscovery && !srv.DiscoveryV5 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if !srv.NoDiscovery || srv.DiscoveryV5 {
|
||||
addr, err := net.ResolveUDPAddr("udp", srv.ListenAddr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
conn, err = net.ListenUDP("udp", addr)
|
||||
conn, err := net.ListenUDP("udp", addr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
realaddr = conn.LocalAddr().(*net.UDPAddr)
|
||||
realaddr := conn.LocalAddr().(*net.UDPAddr)
|
||||
srv.log.Debug("UDP listener up", "addr", realaddr)
|
||||
if srv.NAT != nil {
|
||||
if !realaddr.IP.IsLoopback() {
|
||||
go nat.Map(srv.NAT, srv.quit, "udp", realaddr.Port, realaddr.Port, "ethereum discovery")
|
||||
}
|
||||
// TODO: react to external IP changes over time.
|
||||
if ext, err := srv.NAT.ExternalIP(); err == nil {
|
||||
realaddr = &net.UDPAddr{IP: ext, Port: realaddr.Port}
|
||||
}
|
||||
}
|
||||
}
|
||||
srv.localnode.SetFallbackUDP(realaddr.Port)
|
||||
|
||||
if !srv.NoDiscovery && srv.DiscoveryV5 {
|
||||
// Discovery V4
|
||||
var unhandled chan discover.ReadPacket
|
||||
var sconn *sharedUDPConn
|
||||
if !srv.NoDiscovery {
|
||||
if srv.DiscoveryV5 {
|
||||
unhandled = make(chan discover.ReadPacket, 100)
|
||||
sconn = &sharedUDPConn{conn, unhandled}
|
||||
}
|
||||
|
||||
// node table
|
||||
if !srv.NoDiscovery {
|
||||
cfg := discover.Config{
|
||||
PrivateKey: srv.PrivateKey,
|
||||
AnnounceAddr: realaddr,
|
||||
NodeDBPath: srv.NodeDatabase,
|
||||
NetRestrict: srv.NetRestrict,
|
||||
Bootnodes: srv.BootstrapNodes,
|
||||
Unhandled: unhandled,
|
||||
}
|
||||
ntab, err := discover.ListenUDP(conn, cfg)
|
||||
ntab, err := discover.ListenUDP(conn, srv.localnode, cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
srv.ntab = ntab
|
||||
}
|
||||
|
||||
// Discovery V5
|
||||
if srv.DiscoveryV5 {
|
||||
var (
|
||||
ntab *discv5.Network
|
||||
err error
|
||||
)
|
||||
var ntab *discv5.Network
|
||||
var err error
|
||||
if sconn != nil {
|
||||
ntab, err = discv5.ListenUDP(srv.PrivateKey, sconn, realaddr, "", srv.NetRestrict) //srv.NodeDatabase)
|
||||
ntab, err = discv5.ListenUDP(srv.PrivateKey, sconn, "", srv.NetRestrict)
|
||||
} else {
|
||||
ntab, err = discv5.ListenUDP(srv.PrivateKey, conn, realaddr, "", srv.NetRestrict) //srv.NodeDatabase)
|
||||
ntab, err = discv5.ListenUDP(srv.PrivateKey, conn, "", srv.NetRestrict)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -534,32 +567,10 @@ func (srv *Server) Start() (err error) {
|
|||
}
|
||||
srv.DiscV5 = ntab
|
||||
}
|
||||
|
||||
dynPeers := srv.maxDialedConns()
|
||||
dialer := newDialState(srv.StaticNodes, srv.BootstrapNodes, srv.ntab, dynPeers, srv.NetRestrict)
|
||||
|
||||
// handshake
|
||||
pubkey := crypto.FromECDSAPub(&srv.PrivateKey.PublicKey)
|
||||
srv.ourHandshake = &protoHandshake{Version: baseProtocolVersion, Name: srv.Name, ID: pubkey[1:]}
|
||||
for _, p := range srv.Protocols {
|
||||
srv.ourHandshake.Caps = append(srv.ourHandshake.Caps, p.cap())
|
||||
}
|
||||
// listen/dial
|
||||
if srv.ListenAddr != "" {
|
||||
if err := srv.startListening(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if srv.NoDial && srv.ListenAddr == "" {
|
||||
srv.log.Warn("P2P server will be useless, neither dialing nor listening")
|
||||
}
|
||||
|
||||
srv.loopWG.Add(1)
|
||||
go srv.run(dialer)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (srv *Server) startListening() error {
|
||||
func (srv *Server) setupListening() error {
|
||||
// Launch the TCP listener.
|
||||
listener, err := net.Listen("tcp", srv.ListenAddr)
|
||||
if err != nil {
|
||||
|
@ -568,8 +579,11 @@ func (srv *Server) startListening() error {
|
|||
laddr := listener.Addr().(*net.TCPAddr)
|
||||
srv.ListenAddr = laddr.String()
|
||||
srv.listener = listener
|
||||
srv.localnode.Set(enr.TCP(laddr.Port))
|
||||
|
||||
srv.loopWG.Add(1)
|
||||
go srv.listenLoop()
|
||||
|
||||
// Map the TCP listening port if NAT is configured.
|
||||
if !laddr.IP.IsLoopback() && srv.NAT != nil {
|
||||
srv.loopWG.Add(1)
|
||||
|
@ -589,7 +603,10 @@ type dialer interface {
|
|||
}
|
||||
|
||||
func (srv *Server) run(dialstate dialer) {
|
||||
srv.log.Info("Started P2P networking", "self", srv.localnode.Node())
|
||||
defer srv.loopWG.Done()
|
||||
defer srv.nodedb.Close()
|
||||
|
||||
var (
|
||||
peers = make(map[enode.ID]*Peer)
|
||||
inboundCount = 0
|
||||
|
@ -781,7 +798,7 @@ func (srv *Server) encHandshakeChecks(peers map[enode.ID]*Peer, inboundCount int
|
|||
return DiscTooManyPeers
|
||||
case peers[c.node.ID()] != nil:
|
||||
return DiscAlreadyConnected
|
||||
case c.node.ID() == srv.Self().ID():
|
||||
case c.node.ID() == srv.localnode.ID():
|
||||
return DiscSelf
|
||||
default:
|
||||
return nil
|
||||
|
@ -802,15 +819,11 @@ func (srv *Server) maxDialedConns() int {
|
|||
return srv.MaxPeers / r
|
||||
}
|
||||
|
||||
type tempError interface {
|
||||
Temporary() bool
|
||||
}
|
||||
|
||||
// listenLoop runs in its own goroutine and accepts
|
||||
// inbound connections.
|
||||
func (srv *Server) listenLoop() {
|
||||
defer srv.loopWG.Done()
|
||||
srv.log.Info("RLPx listener up", "self", srv.Self())
|
||||
srv.log.Debug("TCP listener up", "addr", srv.listener.Addr())
|
||||
|
||||
tokens := defaultMaxPendingPeers
|
||||
if srv.MaxPendingPeers > 0 {
|
||||
|
@ -831,7 +844,7 @@ func (srv *Server) listenLoop() {
|
|||
)
|
||||
for {
|
||||
fd, err = srv.listener.Accept()
|
||||
if tempErr, ok := err.(tempError); ok && tempErr.Temporary() {
|
||||
if netutil.IsTemporaryError(err) {
|
||||
srv.log.Debug("Temporary read error", "err", err)
|
||||
continue
|
||||
} else if err != nil {
|
||||
|
@ -864,10 +877,6 @@ func (srv *Server) listenLoop() {
|
|||
// as a peer. It returns when the connection has been added as a peer
|
||||
// or the handshakes have failed.
|
||||
func (srv *Server) SetupConn(fd net.Conn, flags connFlag, dialDest *enode.Node) error {
|
||||
self := srv.Self()
|
||||
if self == nil {
|
||||
return errors.New("shutdown")
|
||||
}
|
||||
c := &conn{fd: fd, transport: srv.newTransport(fd), flags: flags, cont: make(chan error)}
|
||||
err := srv.setupConn(c, flags, dialDest)
|
||||
if err != nil {
|
||||
|
@ -1003,6 +1012,7 @@ type NodeInfo struct {
|
|||
ID string `json:"id"` // Unique node identifier (also the encryption key)
|
||||
Name string `json:"name"` // Name of the node, including client type, version, OS, custom data
|
||||
Enode string `json:"enode"` // Enode URL for adding this peer from remote peers
|
||||
ENR string `json:"enr"` // Ethereum Node Record
|
||||
IP string `json:"ip"` // IP address of the node
|
||||
Ports struct {
|
||||
Discovery int `json:"discovery"` // UDP listening port for discovery protocol
|
||||
|
@ -1014,9 +1024,8 @@ type NodeInfo struct {
|
|||
|
||||
// NodeInfo gathers and returns a collection of metadata known about the host.
|
||||
func (srv *Server) NodeInfo() *NodeInfo {
|
||||
node := srv.Self()
|
||||
|
||||
// Gather and assemble the generic node infos
|
||||
node := srv.Self()
|
||||
info := &NodeInfo{
|
||||
Name: srv.Name,
|
||||
Enode: node.String(),
|
||||
|
@ -1027,6 +1036,9 @@ func (srv *Server) NodeInfo() *NodeInfo {
|
|||
}
|
||||
info.Ports.Discovery = node.UDP()
|
||||
info.Ports.Listener = node.TCP()
|
||||
if enc, err := rlp.EncodeToBytes(node.Record()); err == nil {
|
||||
info.ENR = "0x" + hex.EncodeToString(enc)
|
||||
}
|
||||
|
||||
// Gather all the running protocol infos (only once per protocol type)
|
||||
for _, proto := range srv.Protocols {
|
||||
|
|
|
@ -225,8 +225,11 @@ func TestServerTaskScheduling(t *testing.T) {
|
|||
|
||||
// The Server in this test isn't actually running
|
||||
// because we're only interested in what run does.
|
||||
db, _ := enode.OpenDB("")
|
||||
srv := &Server{
|
||||
Config: Config{MaxPeers: 10},
|
||||
localnode: enode.NewLocalNode(db, newkey()),
|
||||
nodedb: db,
|
||||
quit: make(chan struct{}),
|
||||
ntab: fakeTable{},
|
||||
running: true,
|
||||
|
@ -271,8 +274,11 @@ func TestServerManyTasks(t *testing.T) {
|
|||
}
|
||||
|
||||
var (
|
||||
db, _ = enode.OpenDB("")
|
||||
srv = &Server{
|
||||
quit: make(chan struct{}),
|
||||
localnode: enode.NewLocalNode(db, newkey()),
|
||||
nodedb: db,
|
||||
ntab: fakeTable{},
|
||||
running: true,
|
||||
log: log.New(),
|
||||
|
|
Loading…
Reference in New Issue