p2p: disable encryption handshake
The diff is a bit bigger than expected because the protocol handshake logic has moved out of Peer. This is necessary because the protocol handshake will have custom framing in the final protocol.
This commit is contained in:
parent
4322632c59
commit
73f94f3755
|
@ -1,21 +1,20 @@
|
||||||
package p2p
|
package p2p
|
||||||
|
|
||||||
import (
|
import (
|
||||||
// "binary"
|
|
||||||
"crypto/ecdsa"
|
"crypto/ecdsa"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"net"
|
||||||
|
|
||||||
"github.com/ethereum/go-ethereum/crypto"
|
"github.com/ethereum/go-ethereum/crypto"
|
||||||
"github.com/ethereum/go-ethereum/crypto/ecies"
|
"github.com/ethereum/go-ethereum/crypto/ecies"
|
||||||
"github.com/ethereum/go-ethereum/crypto/secp256k1"
|
"github.com/ethereum/go-ethereum/crypto/secp256k1"
|
||||||
ethlogger "github.com/ethereum/go-ethereum/logger"
|
|
||||||
"github.com/ethereum/go-ethereum/p2p/discover"
|
"github.com/ethereum/go-ethereum/p2p/discover"
|
||||||
|
"github.com/ethereum/go-ethereum/rlp"
|
||||||
)
|
)
|
||||||
|
|
||||||
var clogger = ethlogger.NewLogger("CRYPTOID")
|
|
||||||
|
|
||||||
const (
|
const (
|
||||||
sskLen = 16 // ecies.MaxSharedKeyLength(pubKey) / 2
|
sskLen = 16 // ecies.MaxSharedKeyLength(pubKey) / 2
|
||||||
sigLen = 65 // elliptic S256
|
sigLen = 65 // elliptic S256
|
||||||
|
@ -30,26 +29,76 @@ const (
|
||||||
rHSLen = authRespLen + eciesBytes // size of the final ECIES payload sent as receiver's handshake
|
rHSLen = authRespLen + eciesBytes // size of the final ECIES payload sent as receiver's handshake
|
||||||
)
|
)
|
||||||
|
|
||||||
type hexkey []byte
|
type conn struct {
|
||||||
|
*frameRW
|
||||||
func (self hexkey) String() string {
|
*protoHandshake
|
||||||
return fmt.Sprintf("(%d) %x", len(self), []byte(self))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func encHandshake(conn io.ReadWriter, prv *ecdsa.PrivateKey, dial *discover.Node) (
|
func newConn(fd net.Conn, hs *protoHandshake) *conn {
|
||||||
remoteID discover.NodeID,
|
return &conn{newFrameRW(fd, msgWriteTimeout), hs}
|
||||||
sessionToken []byte,
|
}
|
||||||
err error,
|
|
||||||
) {
|
// encHandshake represents information about the remote end
|
||||||
|
// of a connection that is negotiated during the encryption handshake.
|
||||||
|
type encHandshake struct {
|
||||||
|
ID discover.NodeID
|
||||||
|
IngressMAC []byte
|
||||||
|
EgressMAC []byte
|
||||||
|
Token []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
// protoHandshake is the RLP structure of the protocol handshake.
|
||||||
|
type protoHandshake struct {
|
||||||
|
Version uint64
|
||||||
|
Name string
|
||||||
|
Caps []Cap
|
||||||
|
ListenPort uint64
|
||||||
|
ID discover.NodeID
|
||||||
|
}
|
||||||
|
|
||||||
|
// setupConn starts a protocol session on the given connection.
|
||||||
|
// It runs the encryption handshake and the protocol handshake.
|
||||||
|
// If dial is non-nil, the connection the local node is the initiator.
|
||||||
|
func setupConn(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake, dial *discover.Node) (*conn, error) {
|
||||||
if dial == nil {
|
if dial == nil {
|
||||||
var remotePubkey []byte
|
return setupInboundConn(fd, prv, our)
|
||||||
sessionToken, remotePubkey, err = inboundEncHandshake(conn, prv, nil)
|
|
||||||
copy(remoteID[:], remotePubkey)
|
|
||||||
} else {
|
} else {
|
||||||
remoteID = dial.ID
|
return setupOutboundConn(fd, prv, our, dial)
|
||||||
sessionToken, err = outboundEncHandshake(conn, prv, remoteID[:], nil)
|
|
||||||
}
|
}
|
||||||
return remoteID, sessionToken, err
|
}
|
||||||
|
|
||||||
|
func setupInboundConn(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake) (*conn, error) {
|
||||||
|
// var remotePubkey []byte
|
||||||
|
// sessionToken, remotePubkey, err = inboundEncHandshake(fd, prv, nil)
|
||||||
|
// copy(remoteID[:], remotePubkey)
|
||||||
|
|
||||||
|
rw := newFrameRW(fd, msgWriteTimeout)
|
||||||
|
rhs, err := readProtocolHandshake(rw, our)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if err := writeProtocolHandshake(rw, our); err != nil {
|
||||||
|
return nil, fmt.Errorf("protocol write error: %v", err)
|
||||||
|
}
|
||||||
|
return &conn{rw, rhs}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func setupOutboundConn(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake, dial *discover.Node) (*conn, error) {
|
||||||
|
// remoteID = dial.ID
|
||||||
|
// sessionToken, err = outboundEncHandshake(fd, prv, remoteID[:], nil)
|
||||||
|
|
||||||
|
rw := newFrameRW(fd, msgWriteTimeout)
|
||||||
|
if err := writeProtocolHandshake(rw, our); err != nil {
|
||||||
|
return nil, fmt.Errorf("protocol write error: %v", err)
|
||||||
|
}
|
||||||
|
rhs, err := readProtocolHandshake(rw, our)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("protocol handshake read error: %v", err)
|
||||||
|
}
|
||||||
|
if rhs.ID != dial.ID {
|
||||||
|
return nil, errors.New("dialed node id mismatch")
|
||||||
|
}
|
||||||
|
return &conn{rw, rhs}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// outboundEncHandshake negotiates a session token on conn.
|
// outboundEncHandshake negotiates a session token on conn.
|
||||||
|
@ -66,18 +115,9 @@ func outboundEncHandshake(conn io.ReadWriter, prvKey *ecdsa.PrivateKey, remotePu
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if sessionToken != nil {
|
|
||||||
clogger.Debugf("session-token: %v", hexkey(sessionToken))
|
|
||||||
}
|
|
||||||
|
|
||||||
clogger.Debugf("initiator-nonce: %v", hexkey(initNonce))
|
|
||||||
clogger.Debugf("initiator-random-private-key: %v", hexkey(crypto.FromECDSA(randomPrivKey)))
|
|
||||||
randomPublicKeyS, _ := exportPublicKey(&randomPrivKey.PublicKey)
|
|
||||||
clogger.Debugf("initiator-random-public-key: %v", hexkey(randomPublicKeyS))
|
|
||||||
if _, err = conn.Write(auth); err != nil {
|
if _, err = conn.Write(auth); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
clogger.Debugf("initiator handshake: %v", hexkey(auth))
|
|
||||||
|
|
||||||
response := make([]byte, rHSLen)
|
response := make([]byte, rHSLen)
|
||||||
if _, err = io.ReadFull(conn, response); err != nil {
|
if _, err = io.ReadFull(conn, response); err != nil {
|
||||||
|
@ -88,9 +128,6 @@ func outboundEncHandshake(conn io.ReadWriter, prvKey *ecdsa.PrivateKey, remotePu
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
clogger.Debugf("receiver-nonce: %v", hexkey(recNonce))
|
|
||||||
remoteRandomPubKeyS, _ := exportPublicKey(remoteRandomPubKey)
|
|
||||||
clogger.Debugf("receiver-random-public-key: %v", hexkey(remoteRandomPubKeyS))
|
|
||||||
return newSession(initNonce, recNonce, randomPrivKey, remoteRandomPubKey)
|
return newSession(initNonce, recNonce, randomPrivKey, remoteRandomPubKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -221,12 +258,9 @@ func inboundEncHandshake(conn io.ReadWriter, prvKey *ecdsa.PrivateKey, sessionTo
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
clogger.Debugf("receiver-nonce: %v", hexkey(recNonce))
|
|
||||||
clogger.Debugf("receiver-random-priv-key: %v", hexkey(crypto.FromECDSA(randomPrivKey)))
|
|
||||||
if _, err = conn.Write(response); err != nil {
|
if _, err = conn.Write(response); err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
clogger.Debugf("receiver handshake:\n%v", hexkey(response))
|
|
||||||
token, err = newSession(initNonce, recNonce, randomPrivKey, remoteRandomPubKey)
|
token, err = newSession(initNonce, recNonce, randomPrivKey, remoteRandomPubKey)
|
||||||
return token, remotePubKey, err
|
return token, remotePubKey, err
|
||||||
}
|
}
|
||||||
|
@ -361,3 +395,40 @@ func xor(one, other []byte) (xor []byte) {
|
||||||
}
|
}
|
||||||
return xor
|
return xor
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func writeProtocolHandshake(w MsgWriter, our *protoHandshake) error {
|
||||||
|
return EncodeMsg(w, handshakeMsg, our.Version, our.Name, our.Caps, our.ListenPort, our.ID[:])
|
||||||
|
}
|
||||||
|
|
||||||
|
func readProtocolHandshake(r MsgReader, our *protoHandshake) (*protoHandshake, error) {
|
||||||
|
// read and handle remote handshake
|
||||||
|
msg, err := r.ReadMsg()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if msg.Code == discMsg {
|
||||||
|
// disconnect before protocol handshake is valid according to the
|
||||||
|
// spec and we send it ourself if Server.addPeer fails.
|
||||||
|
var reason DiscReason
|
||||||
|
rlp.Decode(msg.Payload, &reason)
|
||||||
|
return nil, discRequestedError(reason)
|
||||||
|
}
|
||||||
|
if msg.Code != handshakeMsg {
|
||||||
|
return nil, fmt.Errorf("expected handshake, got %x", msg.Code)
|
||||||
|
}
|
||||||
|
if msg.Size > baseProtocolMaxMsgSize {
|
||||||
|
return nil, fmt.Errorf("message too big (%d > %d)", msg.Size, baseProtocolMaxMsgSize)
|
||||||
|
}
|
||||||
|
var hs protoHandshake
|
||||||
|
if err := msg.Decode(&hs); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
// validate handshake info
|
||||||
|
if hs.Version != our.Version {
|
||||||
|
return nil, newPeerError(errP2PVersionMismatch, "required version %d, received %d\n", baseProtocolVersion, hs.Version)
|
||||||
|
}
|
||||||
|
if (hs.ID == discover.NodeID{}) {
|
||||||
|
return nil, newPeerError(errPubkeyInvalid, "missing")
|
||||||
|
}
|
||||||
|
return &hs, nil
|
||||||
|
}
|
|
@ -5,10 +5,12 @@ import (
|
||||||
"crypto/ecdsa"
|
"crypto/ecdsa"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"net"
|
"net"
|
||||||
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/ethereum/go-ethereum/crypto"
|
"github.com/ethereum/go-ethereum/crypto"
|
||||||
"github.com/ethereum/go-ethereum/crypto/ecies"
|
"github.com/ethereum/go-ethereum/crypto/ecies"
|
||||||
|
"github.com/ethereum/go-ethereum/p2p/discover"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestPublicKeyEncoding(t *testing.T) {
|
func TestPublicKeyEncoding(t *testing.T) {
|
||||||
|
@ -91,14 +93,14 @@ func testCryptoHandshake(prv0, prv1 *ecdsa.PrivateKey, sessionToken []byte, t *t
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("%v", err)
|
t.Errorf("%v", err)
|
||||||
}
|
}
|
||||||
t.Logf("-> %v", hexkey(auth))
|
// t.Logf("-> %v", hexkey(auth))
|
||||||
|
|
||||||
// receiver reads auth and responds with response
|
// receiver reads auth and responds with response
|
||||||
response, remoteRecNonce, remoteInitNonce, _, remoteRandomPrivKey, remoteInitRandomPubKey, err := authResp(auth, sessionToken, prv1)
|
response, remoteRecNonce, remoteInitNonce, _, remoteRandomPrivKey, remoteInitRandomPubKey, err := authResp(auth, sessionToken, prv1)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("%v", err)
|
t.Errorf("%v", err)
|
||||||
}
|
}
|
||||||
t.Logf("<- %v\n", hexkey(response))
|
// t.Logf("<- %v\n", hexkey(response))
|
||||||
|
|
||||||
// initiator reads receiver's response and the key exchange completes
|
// initiator reads receiver's response and the key exchange completes
|
||||||
recNonce, remoteRandomPubKey, _, err := completeHandshake(response, prv0)
|
recNonce, remoteRandomPubKey, _, err := completeHandshake(response, prv0)
|
||||||
|
@ -132,7 +134,7 @@ func testCryptoHandshake(prv0, prv1 *ecdsa.PrivateKey, sessionToken []byte, t *t
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestHandshake(t *testing.T) {
|
func TestEncHandshake(t *testing.T) {
|
||||||
defer testlog(t).detach()
|
defer testlog(t).detach()
|
||||||
|
|
||||||
prv0, _ := crypto.GenerateKey()
|
prv0, _ := crypto.GenerateKey()
|
||||||
|
@ -165,3 +167,58 @@ func TestHandshake(t *testing.T) {
|
||||||
t.Error("session token mismatch")
|
t.Error("session token mismatch")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSetupConn(t *testing.T) {
|
||||||
|
prv0, _ := crypto.GenerateKey()
|
||||||
|
prv1, _ := crypto.GenerateKey()
|
||||||
|
node0 := &discover.Node{
|
||||||
|
ID: discover.PubkeyID(&prv0.PublicKey),
|
||||||
|
IP: net.IP{1, 2, 3, 4},
|
||||||
|
TCPPort: 33,
|
||||||
|
}
|
||||||
|
node1 := &discover.Node{
|
||||||
|
ID: discover.PubkeyID(&prv1.PublicKey),
|
||||||
|
IP: net.IP{5, 6, 7, 8},
|
||||||
|
TCPPort: 44,
|
||||||
|
}
|
||||||
|
hs0 := &protoHandshake{
|
||||||
|
Version: baseProtocolVersion,
|
||||||
|
ID: node0.ID,
|
||||||
|
Caps: []Cap{{"a", 0}, {"b", 2}},
|
||||||
|
}
|
||||||
|
hs1 := &protoHandshake{
|
||||||
|
Version: baseProtocolVersion,
|
||||||
|
ID: node1.ID,
|
||||||
|
Caps: []Cap{{"c", 1}, {"d", 3}},
|
||||||
|
}
|
||||||
|
fd0, fd1 := net.Pipe()
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
defer close(done)
|
||||||
|
conn0, err := setupConn(fd0, prv0, hs0, node1)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("outbound side error: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if conn0.ID != node1.ID {
|
||||||
|
t.Errorf("outbound conn id mismatch: got %v, want %v", conn0.ID, node1.ID)
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(conn0.Caps, hs1.Caps) {
|
||||||
|
t.Errorf("outbound caps mismatch: got %v, want %v", conn0.Caps, hs1.Caps)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
conn1, err := setupConn(fd1, prv1, hs1, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("inbound side error: %v", err)
|
||||||
|
}
|
||||||
|
if conn1.ID != node0.ID {
|
||||||
|
t.Errorf("inbound conn id mismatch: got %v, want %v", conn1.ID, node0.ID)
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(conn1.Caps, hs0.Caps) {
|
||||||
|
t.Errorf("inbound caps mismatch: got %v, want %v", conn1.Caps, hs0.Caps)
|
||||||
|
}
|
||||||
|
|
||||||
|
<-done
|
||||||
|
}
|
|
@ -197,7 +197,7 @@ func (rw *frameRW) ReadMsg() (msg Msg, err error) {
|
||||||
return msg, err
|
return msg, err
|
||||||
}
|
}
|
||||||
if !bytes.HasPrefix(start, magicToken) {
|
if !bytes.HasPrefix(start, magicToken) {
|
||||||
return msg, fmt.Errorf("bad magic token %x", start[:4], magicToken)
|
return msg, fmt.Errorf("bad magic token %x", start[:4])
|
||||||
}
|
}
|
||||||
size := binary.BigEndian.Uint32(start[4:])
|
size := binary.BigEndian.Uint32(start[4:])
|
||||||
|
|
||||||
|
|
229
p2p/peer.go
229
p2p/peer.go
|
@ -33,37 +33,14 @@ const (
|
||||||
peersMsg = 0x05
|
peersMsg = 0x05
|
||||||
)
|
)
|
||||||
|
|
||||||
// handshake is the RLP structure of the protocol handshake.
|
|
||||||
type handshake struct {
|
|
||||||
Version uint64
|
|
||||||
Name string
|
|
||||||
Caps []Cap
|
|
||||||
ListenPort uint64
|
|
||||||
NodeID discover.NodeID
|
|
||||||
}
|
|
||||||
|
|
||||||
// Peer represents a connected remote node.
|
// Peer represents a connected remote node.
|
||||||
type Peer struct {
|
type Peer struct {
|
||||||
// Peers have all the log methods.
|
// Peers have all the log methods.
|
||||||
// Use them to display messages related to the peer.
|
// Use them to display messages related to the peer.
|
||||||
*logger.Logger
|
*logger.Logger
|
||||||
|
|
||||||
infoMu sync.Mutex
|
rw *conn
|
||||||
name string
|
running map[string]*protoRW
|
||||||
caps []Cap
|
|
||||||
|
|
||||||
ourID, remoteID *discover.NodeID
|
|
||||||
ourName string
|
|
||||||
|
|
||||||
rw *frameRW
|
|
||||||
|
|
||||||
// These fields maintain the running protocols.
|
|
||||||
protocols []Protocol
|
|
||||||
runlock sync.RWMutex // protects running
|
|
||||||
running map[string]*proto
|
|
||||||
|
|
||||||
// disables protocol handshake, for testing
|
|
||||||
noHandshake bool
|
|
||||||
|
|
||||||
protoWG sync.WaitGroup
|
protoWG sync.WaitGroup
|
||||||
protoErr chan error
|
protoErr chan error
|
||||||
|
@ -73,36 +50,27 @@ type Peer struct {
|
||||||
|
|
||||||
// NewPeer returns a peer for testing purposes.
|
// NewPeer returns a peer for testing purposes.
|
||||||
func NewPeer(id discover.NodeID, name string, caps []Cap) *Peer {
|
func NewPeer(id discover.NodeID, name string, caps []Cap) *Peer {
|
||||||
conn, _ := net.Pipe()
|
pipe, _ := net.Pipe()
|
||||||
peer := newPeer(conn, nil, "", nil, &id)
|
conn := newConn(pipe, &protoHandshake{ID: id, Name: name, Caps: caps})
|
||||||
peer.setHandshakeInfo(name, caps)
|
peer := newPeer(conn, nil)
|
||||||
close(peer.closed) // ensures Disconnect doesn't block
|
close(peer.closed) // ensures Disconnect doesn't block
|
||||||
return peer
|
return peer
|
||||||
}
|
}
|
||||||
|
|
||||||
// ID returns the node's public key.
|
// ID returns the node's public key.
|
||||||
func (p *Peer) ID() discover.NodeID {
|
func (p *Peer) ID() discover.NodeID {
|
||||||
return *p.remoteID
|
return p.rw.ID
|
||||||
}
|
}
|
||||||
|
|
||||||
// Name returns the node name that the remote node advertised.
|
// Name returns the node name that the remote node advertised.
|
||||||
func (p *Peer) Name() string {
|
func (p *Peer) Name() string {
|
||||||
// this needs a lock because the information is part of the
|
return p.rw.Name
|
||||||
// protocol handshake.
|
|
||||||
p.infoMu.Lock()
|
|
||||||
name := p.name
|
|
||||||
p.infoMu.Unlock()
|
|
||||||
return name
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Caps returns the capabilities (supported subprotocols) of the remote peer.
|
// Caps returns the capabilities (supported subprotocols) of the remote peer.
|
||||||
func (p *Peer) Caps() []Cap {
|
func (p *Peer) Caps() []Cap {
|
||||||
// this needs a lock because the information is part of the
|
// TODO: maybe return copy
|
||||||
// protocol handshake.
|
return p.rw.Caps
|
||||||
p.infoMu.Lock()
|
|
||||||
caps := p.caps
|
|
||||||
p.infoMu.Unlock()
|
|
||||||
return caps
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// RemoteAddr returns the remote address of the network connection.
|
// RemoteAddr returns the remote address of the network connection.
|
||||||
|
@ -126,30 +94,20 @@ func (p *Peer) Disconnect(reason DiscReason) {
|
||||||
|
|
||||||
// String implements fmt.Stringer.
|
// String implements fmt.Stringer.
|
||||||
func (p *Peer) String() string {
|
func (p *Peer) String() string {
|
||||||
return fmt.Sprintf("Peer %.8x %v", p.remoteID[:], p.RemoteAddr())
|
return fmt.Sprintf("Peer %.8x %v", p.rw.ID[:], p.RemoteAddr())
|
||||||
}
|
}
|
||||||
|
|
||||||
func newPeer(conn net.Conn, protocols []Protocol, ourName string, ourID, remoteID *discover.NodeID) *Peer {
|
func newPeer(conn *conn, protocols []Protocol) *Peer {
|
||||||
logtag := fmt.Sprintf("Peer %.8x %v", remoteID[:], conn.RemoteAddr())
|
logtag := fmt.Sprintf("Peer %.8x %v", conn.ID[:], conn.RemoteAddr())
|
||||||
return &Peer{
|
p := &Peer{
|
||||||
Logger: logger.NewLogger(logtag),
|
Logger: logger.NewLogger(logtag),
|
||||||
rw: newFrameRW(conn, msgWriteTimeout),
|
rw: conn,
|
||||||
ourID: ourID,
|
running: matchProtocols(protocols, conn.Caps, conn),
|
||||||
ourName: ourName,
|
disc: make(chan DiscReason),
|
||||||
remoteID: remoteID,
|
protoErr: make(chan error),
|
||||||
protocols: protocols,
|
closed: make(chan struct{}),
|
||||||
running: make(map[string]*proto),
|
|
||||||
disc: make(chan DiscReason),
|
|
||||||
protoErr: make(chan error),
|
|
||||||
closed: make(chan struct{}),
|
|
||||||
}
|
}
|
||||||
}
|
return p
|
||||||
|
|
||||||
func (p *Peer) setHandshakeInfo(name string, caps []Cap) {
|
|
||||||
p.infoMu.Lock()
|
|
||||||
p.name = name
|
|
||||||
p.caps = caps
|
|
||||||
p.infoMu.Unlock()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Peer) run() DiscReason {
|
func (p *Peer) run() DiscReason {
|
||||||
|
@ -157,16 +115,9 @@ func (p *Peer) run() DiscReason {
|
||||||
defer p.closeProtocols()
|
defer p.closeProtocols()
|
||||||
defer close(p.closed)
|
defer close(p.closed)
|
||||||
|
|
||||||
|
p.startProtocols()
|
||||||
go func() { readErr <- p.readLoop() }()
|
go func() { readErr <- p.readLoop() }()
|
||||||
|
|
||||||
if !p.noHandshake {
|
|
||||||
if err := writeProtocolHandshake(p.rw, p.ourName, *p.ourID, p.protocols); err != nil {
|
|
||||||
p.DebugDetailf("Protocol handshake error: %v\n", err)
|
|
||||||
p.rw.Close()
|
|
||||||
return DiscProtocolError
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Wait for an error or disconnect.
|
// Wait for an error or disconnect.
|
||||||
var reason DiscReason
|
var reason DiscReason
|
||||||
select {
|
select {
|
||||||
|
@ -206,11 +157,6 @@ func (p *Peer) politeDisconnect(reason DiscReason) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Peer) readLoop() error {
|
func (p *Peer) readLoop() error {
|
||||||
if !p.noHandshake {
|
|
||||||
if err := readProtocolHandshake(p, p.rw); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for {
|
for {
|
||||||
msg, err := p.rw.ReadMsg()
|
msg, err := p.rw.ReadMsg()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -249,105 +195,51 @@ func (p *Peer) handle(msg Msg) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func readProtocolHandshake(p *Peer, rw MsgReadWriter) error {
|
// matchProtocols creates structures for matching named subprotocols.
|
||||||
// read and handle remote handshake
|
func matchProtocols(protocols []Protocol, caps []Cap, rw MsgReadWriter) map[string]*protoRW {
|
||||||
msg, err := rw.ReadMsg()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if msg.Code == discMsg {
|
|
||||||
// disconnect before protocol handshake is valid according to the
|
|
||||||
// spec and we send it ourself if Server.addPeer fails.
|
|
||||||
var reason DiscReason
|
|
||||||
rlp.Decode(msg.Payload, &reason)
|
|
||||||
return discRequestedError(reason)
|
|
||||||
}
|
|
||||||
if msg.Code != handshakeMsg {
|
|
||||||
return newPeerError(errProtocolBreach, "expected handshake, got %x", msg.Code)
|
|
||||||
}
|
|
||||||
if msg.Size > baseProtocolMaxMsgSize {
|
|
||||||
return newPeerError(errInvalidMsg, "message too big")
|
|
||||||
}
|
|
||||||
var hs handshake
|
|
||||||
if err := msg.Decode(&hs); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
// validate handshake info
|
|
||||||
if hs.Version != baseProtocolVersion {
|
|
||||||
return newPeerError(errP2PVersionMismatch, "required version %d, received %d\n",
|
|
||||||
baseProtocolVersion, hs.Version)
|
|
||||||
}
|
|
||||||
if hs.NodeID == *p.remoteID {
|
|
||||||
return newPeerError(errPubkeyForbidden, "node ID mismatch")
|
|
||||||
}
|
|
||||||
// TODO: remove Caps with empty name
|
|
||||||
p.setHandshakeInfo(hs.Name, hs.Caps)
|
|
||||||
p.startSubprotocols(hs.Caps)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func writeProtocolHandshake(w MsgWriter, name string, id discover.NodeID, ps []Protocol) error {
|
|
||||||
var caps []interface{}
|
|
||||||
for _, proto := range ps {
|
|
||||||
caps = append(caps, proto.cap())
|
|
||||||
}
|
|
||||||
return EncodeMsg(w, handshakeMsg, baseProtocolVersion, name, caps, 0, id)
|
|
||||||
}
|
|
||||||
|
|
||||||
// startProtocols starts matching named subprotocols.
|
|
||||||
func (p *Peer) startSubprotocols(caps []Cap) {
|
|
||||||
sort.Sort(capsByName(caps))
|
sort.Sort(capsByName(caps))
|
||||||
p.runlock.Lock()
|
|
||||||
defer p.runlock.Unlock()
|
|
||||||
offset := baseProtocolLength
|
offset := baseProtocolLength
|
||||||
|
result := make(map[string]*protoRW)
|
||||||
outer:
|
outer:
|
||||||
for _, cap := range caps {
|
for _, cap := range caps {
|
||||||
for _, proto := range p.protocols {
|
for _, proto := range protocols {
|
||||||
if proto.Name == cap.Name &&
|
if proto.Name == cap.Name && proto.Version == cap.Version && result[cap.Name] == nil {
|
||||||
proto.Version == cap.Version &&
|
result[cap.Name] = &protoRW{Protocol: proto, offset: offset, in: make(chan Msg), w: rw}
|
||||||
p.running[cap.Name] == nil {
|
|
||||||
p.running[cap.Name] = p.startProto(offset, proto)
|
|
||||||
offset += proto.Length
|
offset += proto.Length
|
||||||
continue outer
|
continue outer
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Peer) startProto(offset uint64, impl Protocol) *proto {
|
func (p *Peer) startProtocols() {
|
||||||
p.DebugDetailf("Starting protocol %s/%d\n", impl.Name, impl.Version)
|
for _, proto := range p.running {
|
||||||
rw := &proto{
|
proto := proto
|
||||||
name: impl.Name,
|
p.DebugDetailf("Starting protocol %s/%d\n", proto.Name, proto.Version)
|
||||||
in: make(chan Msg),
|
p.protoWG.Add(1)
|
||||||
offset: offset,
|
go func() {
|
||||||
maxcode: impl.Length,
|
err := proto.Run(p, proto)
|
||||||
w: p.rw,
|
if err == nil {
|
||||||
|
p.DebugDetailf("Protocol %s/%d returned\n", proto.Name, proto.Version)
|
||||||
|
err = errors.New("protocol returned")
|
||||||
|
} else {
|
||||||
|
p.DebugDetailf("Protocol %s/%d error: %v\n", proto.Name, proto.Version, err)
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case p.protoErr <- err:
|
||||||
|
case <-p.closed:
|
||||||
|
}
|
||||||
|
p.protoWG.Done()
|
||||||
|
}()
|
||||||
}
|
}
|
||||||
p.protoWG.Add(1)
|
|
||||||
go func() {
|
|
||||||
err := impl.Run(p, rw)
|
|
||||||
if err == nil {
|
|
||||||
p.DebugDetailf("Protocol %s/%d returned\n", impl.Name, impl.Version)
|
|
||||||
err = errors.New("protocol returned")
|
|
||||||
} else {
|
|
||||||
p.DebugDetailf("Protocol %s/%d error: %v\n", impl.Name, impl.Version, err)
|
|
||||||
}
|
|
||||||
select {
|
|
||||||
case p.protoErr <- err:
|
|
||||||
case <-p.closed:
|
|
||||||
}
|
|
||||||
p.protoWG.Done()
|
|
||||||
}()
|
|
||||||
return rw
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// getProto finds the protocol responsible for handling
|
// getProto finds the protocol responsible for handling
|
||||||
// the given message code.
|
// the given message code.
|
||||||
func (p *Peer) getProto(code uint64) (*proto, error) {
|
func (p *Peer) getProto(code uint64) (*protoRW, error) {
|
||||||
p.runlock.RLock()
|
|
||||||
defer p.runlock.RUnlock()
|
|
||||||
for _, proto := range p.running {
|
for _, proto := range p.running {
|
||||||
if code >= proto.offset && code < proto.offset+proto.maxcode {
|
if code >= proto.offset && code < proto.offset+proto.Length {
|
||||||
return proto, nil
|
return proto, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -355,46 +247,43 @@ func (p *Peer) getProto(code uint64) (*proto, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Peer) closeProtocols() {
|
func (p *Peer) closeProtocols() {
|
||||||
p.runlock.RLock()
|
|
||||||
for _, p := range p.running {
|
for _, p := range p.running {
|
||||||
close(p.in)
|
close(p.in)
|
||||||
}
|
}
|
||||||
p.runlock.RUnlock()
|
|
||||||
p.protoWG.Wait()
|
p.protoWG.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
// writeProtoMsg sends the given message on behalf of the given named protocol.
|
// writeProtoMsg sends the given message on behalf of the given named protocol.
|
||||||
// this exists because of Server.Broadcast.
|
// this exists because of Server.Broadcast.
|
||||||
func (p *Peer) writeProtoMsg(protoName string, msg Msg) error {
|
func (p *Peer) writeProtoMsg(protoName string, msg Msg) error {
|
||||||
p.runlock.RLock()
|
|
||||||
proto, ok := p.running[protoName]
|
proto, ok := p.running[protoName]
|
||||||
p.runlock.RUnlock()
|
|
||||||
if !ok {
|
if !ok {
|
||||||
return fmt.Errorf("protocol %s not handled by peer", protoName)
|
return fmt.Errorf("protocol %s not handled by peer", protoName)
|
||||||
}
|
}
|
||||||
if msg.Code >= proto.maxcode {
|
if msg.Code >= proto.Length {
|
||||||
return newPeerError(errInvalidMsgCode, "code %x is out of range for protocol %q", msg.Code, protoName)
|
return newPeerError(errInvalidMsgCode, "code %x is out of range for protocol %q", msg.Code, protoName)
|
||||||
}
|
}
|
||||||
msg.Code += proto.offset
|
msg.Code += proto.offset
|
||||||
return p.rw.WriteMsg(msg)
|
return p.rw.WriteMsg(msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
type proto struct {
|
type protoRW struct {
|
||||||
name string
|
Protocol
|
||||||
in chan Msg
|
|
||||||
maxcode, offset uint64
|
in chan Msg
|
||||||
w MsgWriter
|
offset uint64
|
||||||
|
w MsgWriter
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rw *proto) WriteMsg(msg Msg) error {
|
func (rw *protoRW) WriteMsg(msg Msg) error {
|
||||||
if msg.Code >= rw.maxcode {
|
if msg.Code >= rw.Length {
|
||||||
return newPeerError(errInvalidMsgCode, "not handled")
|
return newPeerError(errInvalidMsgCode, "not handled")
|
||||||
}
|
}
|
||||||
msg.Code += rw.offset
|
msg.Code += rw.offset
|
||||||
return rw.w.WriteMsg(msg)
|
return rw.w.WriteMsg(msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rw *proto) ReadMsg() (Msg, error) {
|
func (rw *protoRW) ReadMsg() (Msg, error) {
|
||||||
msg, ok := <-rw.in
|
msg, ok := <-rw.in
|
||||||
if !ok {
|
if !ok {
|
||||||
return msg, io.EOF
|
return msg, io.EOF
|
||||||
|
|
105
p2p/peer_test.go
105
p2p/peer_test.go
|
@ -6,11 +6,9 @@ import (
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net"
|
"net"
|
||||||
"reflect"
|
"reflect"
|
||||||
"sort"
|
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/ethereum/go-ethereum/p2p/discover"
|
|
||||||
"github.com/ethereum/go-ethereum/rlp"
|
"github.com/ethereum/go-ethereum/rlp"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -23,6 +21,7 @@ var discard = Protocol{
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
fmt.Printf("discarding %d\n", msg.Code)
|
||||||
if err = msg.Discard(); err != nil {
|
if err = msg.Discard(); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -30,13 +29,20 @@ var discard = Protocol{
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
func testPeer(noHandshake bool, protos []Protocol) (*frameRW, *Peer, <-chan DiscReason) {
|
func testPeer(protos []Protocol) (*conn, *Peer, <-chan DiscReason) {
|
||||||
conn1, conn2 := net.Pipe()
|
fd1, fd2 := net.Pipe()
|
||||||
peer := newPeer(conn1, protos, "name", &discover.NodeID{}, &discover.NodeID{})
|
hs1 := &protoHandshake{ID: randomID(), Version: baseProtocolVersion}
|
||||||
peer.noHandshake = noHandshake
|
hs2 := &protoHandshake{ID: randomID(), Version: baseProtocolVersion}
|
||||||
|
for _, p := range protos {
|
||||||
|
hs1.Caps = append(hs1.Caps, p.cap())
|
||||||
|
hs2.Caps = append(hs2.Caps, p.cap())
|
||||||
|
}
|
||||||
|
|
||||||
|
peer := newPeer(newConn(fd1, hs1), protos)
|
||||||
errc := make(chan DiscReason, 1)
|
errc := make(chan DiscReason, 1)
|
||||||
go func() { errc <- peer.run() }()
|
go func() { errc <- peer.run() }()
|
||||||
return newFrameRW(conn2, msgWriteTimeout), peer, errc
|
|
||||||
|
return newConn(fd2, hs2), peer, errc
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPeerProtoReadMsg(t *testing.T) {
|
func TestPeerProtoReadMsg(t *testing.T) {
|
||||||
|
@ -61,9 +67,8 @@ func TestPeerProtoReadMsg(t *testing.T) {
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
rw, peer, errc := testPeer(true, []Protocol{proto})
|
rw, _, errc := testPeer([]Protocol{proto})
|
||||||
defer rw.Close()
|
defer rw.Close()
|
||||||
peer.startSubprotocols([]Cap{proto.cap()})
|
|
||||||
|
|
||||||
EncodeMsg(rw, baseProtocolLength+2, 1)
|
EncodeMsg(rw, baseProtocolLength+2, 1)
|
||||||
EncodeMsg(rw, baseProtocolLength+3, 2)
|
EncodeMsg(rw, baseProtocolLength+3, 2)
|
||||||
|
@ -100,9 +105,8 @@ func TestPeerProtoReadLargeMsg(t *testing.T) {
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
rw, peer, errc := testPeer(true, []Protocol{proto})
|
rw, _, errc := testPeer([]Protocol{proto})
|
||||||
defer rw.Close()
|
defer rw.Close()
|
||||||
peer.startSubprotocols([]Cap{proto.cap()})
|
|
||||||
|
|
||||||
EncodeMsg(rw, 18, make([]byte, msgsize))
|
EncodeMsg(rw, 18, make([]byte, msgsize))
|
||||||
select {
|
select {
|
||||||
|
@ -130,9 +134,8 @@ func TestPeerProtoEncodeMsg(t *testing.T) {
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
rw, peer, _ := testPeer(true, []Protocol{proto})
|
rw, _, _ := testPeer([]Protocol{proto})
|
||||||
defer rw.Close()
|
defer rw.Close()
|
||||||
peer.startSubprotocols([]Cap{proto.cap()})
|
|
||||||
|
|
||||||
if err := expectMsg(rw, 17, []string{"foo", "bar"}); err != nil {
|
if err := expectMsg(rw, 17, []string{"foo", "bar"}); err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
|
@ -142,9 +145,8 @@ func TestPeerProtoEncodeMsg(t *testing.T) {
|
||||||
func TestPeerWriteForBroadcast(t *testing.T) {
|
func TestPeerWriteForBroadcast(t *testing.T) {
|
||||||
defer testlog(t).detach()
|
defer testlog(t).detach()
|
||||||
|
|
||||||
rw, peer, peerErr := testPeer(true, []Protocol{discard})
|
rw, peer, peerErr := testPeer([]Protocol{discard})
|
||||||
defer rw.Close()
|
defer rw.Close()
|
||||||
peer.startSubprotocols([]Cap{discard.cap()})
|
|
||||||
|
|
||||||
// test write errors
|
// test write errors
|
||||||
if err := peer.writeProtoMsg("b", NewMsg(3)); err == nil {
|
if err := peer.writeProtoMsg("b", NewMsg(3)); err == nil {
|
||||||
|
@ -160,7 +162,7 @@ func TestPeerWriteForBroadcast(t *testing.T) {
|
||||||
read := make(chan struct{})
|
read := make(chan struct{})
|
||||||
go func() {
|
go func() {
|
||||||
if err := expectMsg(rw, 16, nil); err != nil {
|
if err := expectMsg(rw, 16, nil); err != nil {
|
||||||
t.Error()
|
t.Error(err)
|
||||||
}
|
}
|
||||||
close(read)
|
close(read)
|
||||||
}()
|
}()
|
||||||
|
@ -179,7 +181,7 @@ func TestPeerWriteForBroadcast(t *testing.T) {
|
||||||
func TestPeerPing(t *testing.T) {
|
func TestPeerPing(t *testing.T) {
|
||||||
defer testlog(t).detach()
|
defer testlog(t).detach()
|
||||||
|
|
||||||
rw, _, _ := testPeer(true, nil)
|
rw, _, _ := testPeer(nil)
|
||||||
defer rw.Close()
|
defer rw.Close()
|
||||||
if err := EncodeMsg(rw, pingMsg); err != nil {
|
if err := EncodeMsg(rw, pingMsg); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
|
@ -192,7 +194,7 @@ func TestPeerPing(t *testing.T) {
|
||||||
func TestPeerDisconnect(t *testing.T) {
|
func TestPeerDisconnect(t *testing.T) {
|
||||||
defer testlog(t).detach()
|
defer testlog(t).detach()
|
||||||
|
|
||||||
rw, _, disc := testPeer(true, nil)
|
rw, _, disc := testPeer(nil)
|
||||||
defer rw.Close()
|
defer rw.Close()
|
||||||
if err := EncodeMsg(rw, discMsg, DiscQuitting); err != nil {
|
if err := EncodeMsg(rw, discMsg, DiscQuitting); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
|
@ -206,73 +208,6 @@ func TestPeerDisconnect(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPeerHandshake(t *testing.T) {
|
|
||||||
defer testlog(t).detach()
|
|
||||||
|
|
||||||
// remote has two matching protocols: a and c
|
|
||||||
remote := NewPeer(randomID(), "", []Cap{{"a", 1}, {"b", 999}, {"c", 3}})
|
|
||||||
remoteID := randomID()
|
|
||||||
remote.ourID = &remoteID
|
|
||||||
remote.ourName = "remote peer"
|
|
||||||
|
|
||||||
start := make(chan string)
|
|
||||||
stop := make(chan struct{})
|
|
||||||
run := func(p *Peer, rw MsgReadWriter) error {
|
|
||||||
name := rw.(*proto).name
|
|
||||||
if name != "a" && name != "c" {
|
|
||||||
t.Errorf("protocol %q should not be started", name)
|
|
||||||
} else {
|
|
||||||
start <- name
|
|
||||||
}
|
|
||||||
<-stop
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
protocols := []Protocol{
|
|
||||||
{Name: "a", Version: 1, Length: 1, Run: run},
|
|
||||||
{Name: "b", Version: 2, Length: 1, Run: run},
|
|
||||||
{Name: "c", Version: 3, Length: 1, Run: run},
|
|
||||||
{Name: "d", Version: 4, Length: 1, Run: run},
|
|
||||||
}
|
|
||||||
rw, p, disc := testPeer(false, protocols)
|
|
||||||
p.remoteID = remote.ourID
|
|
||||||
defer rw.Close()
|
|
||||||
|
|
||||||
// run the handshake
|
|
||||||
remoteProtocols := []Protocol{protocols[0], protocols[2]}
|
|
||||||
if err := writeProtocolHandshake(rw, "remote peer", remoteID, remoteProtocols); err != nil {
|
|
||||||
t.Fatalf("handshake write error: %v", err)
|
|
||||||
}
|
|
||||||
if err := readProtocolHandshake(remote, rw); err != nil {
|
|
||||||
t.Fatalf("handshake read error: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// check that all protocols have been started
|
|
||||||
var started []string
|
|
||||||
for i := 0; i < 2; i++ {
|
|
||||||
select {
|
|
||||||
case name := <-start:
|
|
||||||
started = append(started, name)
|
|
||||||
case <-time.After(100 * time.Millisecond):
|
|
||||||
}
|
|
||||||
}
|
|
||||||
sort.Strings(started)
|
|
||||||
if !reflect.DeepEqual(started, []string{"a", "c"}) {
|
|
||||||
t.Errorf("wrong protocols started: %v", started)
|
|
||||||
}
|
|
||||||
|
|
||||||
// check that metadata has been set
|
|
||||||
if p.ID() != remoteID {
|
|
||||||
t.Errorf("peer has wrong node ID: got %v, want %v", p.ID(), remoteID)
|
|
||||||
}
|
|
||||||
if p.Name() != remote.ourName {
|
|
||||||
t.Errorf("peer has wrong node name: got %q, want %q", p.Name(), remote.ourName)
|
|
||||||
}
|
|
||||||
|
|
||||||
close(stop)
|
|
||||||
expectMsg(rw, discMsg, nil)
|
|
||||||
t.Logf("disc reason: %v", <-disc)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNewPeer(t *testing.T) {
|
func TestNewPeer(t *testing.T) {
|
||||||
name := "nodename"
|
name := "nodename"
|
||||||
caps := []Cap{{"foo", 2}, {"bar", 3}}
|
caps := []Cap{{"foo", 2}, {"bar", 3}}
|
||||||
|
|
|
@ -5,7 +5,6 @@ import (
|
||||||
"crypto/ecdsa"
|
"crypto/ecdsa"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
|
||||||
"net"
|
"net"
|
||||||
"runtime"
|
"runtime"
|
||||||
"sync"
|
"sync"
|
||||||
|
@ -83,9 +82,11 @@ type Server struct {
|
||||||
|
|
||||||
// Hooks for testing. These are useful because we can inhibit
|
// Hooks for testing. These are useful because we can inhibit
|
||||||
// the whole protocol stack.
|
// the whole protocol stack.
|
||||||
handshakeFunc
|
setupFunc
|
||||||
newPeerHook
|
newPeerHook
|
||||||
|
|
||||||
|
ourHandshake *protoHandshake
|
||||||
|
|
||||||
lock sync.RWMutex
|
lock sync.RWMutex
|
||||||
running bool
|
running bool
|
||||||
listener net.Listener
|
listener net.Listener
|
||||||
|
@ -99,7 +100,7 @@ type Server struct {
|
||||||
peerConnect chan *discover.Node
|
peerConnect chan *discover.Node
|
||||||
}
|
}
|
||||||
|
|
||||||
type handshakeFunc func(io.ReadWriter, *ecdsa.PrivateKey, *discover.Node) (discover.NodeID, []byte, error)
|
type setupFunc func(net.Conn, *ecdsa.PrivateKey, *protoHandshake, *discover.Node) (*conn, error)
|
||||||
type newPeerHook func(*Peer)
|
type newPeerHook func(*Peer)
|
||||||
|
|
||||||
// Peers returns all connected peers.
|
// Peers returns all connected peers.
|
||||||
|
@ -170,8 +171,8 @@ func (srv *Server) Start() (err error) {
|
||||||
srv.peers = make(map[discover.NodeID]*Peer)
|
srv.peers = make(map[discover.NodeID]*Peer)
|
||||||
srv.peerConnect = make(chan *discover.Node)
|
srv.peerConnect = make(chan *discover.Node)
|
||||||
|
|
||||||
if srv.handshakeFunc == nil {
|
if srv.setupFunc == nil {
|
||||||
srv.handshakeFunc = encHandshake
|
srv.setupFunc = setupConn
|
||||||
}
|
}
|
||||||
if srv.Blacklist == nil {
|
if srv.Blacklist == nil {
|
||||||
srv.Blacklist = NewBlacklist()
|
srv.Blacklist = NewBlacklist()
|
||||||
|
@ -183,11 +184,17 @@ func (srv *Server) Start() (err error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// dial stuff
|
// dial stuff
|
||||||
dt, err := discover.ListenUDP(srv.PrivateKey, srv.ListenAddr, srv.NAT)
|
ntab, err := discover.ListenUDP(srv.PrivateKey, srv.ListenAddr, srv.NAT)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
srv.ntab = dt
|
srv.ntab = ntab
|
||||||
|
|
||||||
|
srv.ourHandshake = &protoHandshake{Version: baseProtocolVersion, Name: srv.Name, ID: ntab.Self()}
|
||||||
|
for _, p := range srv.Protocols {
|
||||||
|
srv.ourHandshake.Caps = append(srv.ourHandshake.Caps, p.cap())
|
||||||
|
}
|
||||||
|
|
||||||
if srv.Dialer == nil {
|
if srv.Dialer == nil {
|
||||||
srv.Dialer = &net.Dialer{Timeout: defaultDialTimeout}
|
srv.Dialer = &net.Dialer{Timeout: defaultDialTimeout}
|
||||||
}
|
}
|
||||||
|
@ -347,18 +354,17 @@ func (srv *Server) findPeers() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (srv *Server) startPeer(conn net.Conn, dest *discover.Node) {
|
func (srv *Server) startPeer(fd net.Conn, dest *discover.Node) {
|
||||||
// TODO: handle/store session token
|
// TODO: handle/store session token
|
||||||
conn.SetDeadline(time.Now().Add(handshakeTimeout))
|
fd.SetDeadline(time.Now().Add(handshakeTimeout))
|
||||||
remoteID, _, err := srv.handshakeFunc(conn, srv.PrivateKey, dest)
|
conn, err := srv.setupFunc(fd, srv.PrivateKey, srv.ourHandshake, dest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
conn.Close()
|
fd.Close()
|
||||||
srvlog.Debugf("Encryption Handshake with %v failed: %v", conn.RemoteAddr(), err)
|
srvlog.Debugf("Handshake with %v failed: %v", fd.RemoteAddr(), err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
ourID := srv.ntab.Self()
|
p := newPeer(conn, srv.Protocols)
|
||||||
p := newPeer(conn, srv.Protocols, srv.Name, &ourID, &remoteID)
|
if ok, reason := srv.addPeer(conn.ID, p); !ok {
|
||||||
if ok, reason := srv.addPeer(remoteID, p); !ok {
|
|
||||||
srvlog.DebugDetailf("Not adding %v (%v)\n", p, reason)
|
srvlog.DebugDetailf("Not adding %v (%v)\n", p, reason)
|
||||||
p.politeDisconnect(reason)
|
p.politeDisconnect(reason)
|
||||||
return
|
return
|
||||||
|
@ -394,7 +400,7 @@ func (srv *Server) addPeer(id discover.NodeID, p *Peer) (bool, DiscReason) {
|
||||||
|
|
||||||
func (srv *Server) removePeer(p *Peer) {
|
func (srv *Server) removePeer(p *Peer) {
|
||||||
srv.lock.Lock()
|
srv.lock.Lock()
|
||||||
delete(srv.peers, *p.remoteID)
|
delete(srv.peers, p.ID())
|
||||||
srv.lock.Unlock()
|
srv.lock.Unlock()
|
||||||
srv.peerWG.Done()
|
srv.peerWG.Done()
|
||||||
}
|
}
|
||||||
|
|
|
@ -21,8 +21,12 @@ func startTestServer(t *testing.T, pf newPeerHook) *Server {
|
||||||
ListenAddr: "127.0.0.1:0",
|
ListenAddr: "127.0.0.1:0",
|
||||||
PrivateKey: newkey(),
|
PrivateKey: newkey(),
|
||||||
newPeerHook: pf,
|
newPeerHook: pf,
|
||||||
handshakeFunc: func(io.ReadWriter, *ecdsa.PrivateKey, *discover.Node) (id discover.NodeID, st []byte, err error) {
|
setupFunc: func(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake, dial *discover.Node) (*conn, error) {
|
||||||
return randomID(), nil, err
|
id := randomID()
|
||||||
|
return &conn{
|
||||||
|
frameRW: newFrameRW(fd, msgWriteTimeout),
|
||||||
|
protoHandshake: &protoHandshake{ID: id, Version: baseProtocolVersion},
|
||||||
|
}, nil
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
if err := server.Start(); err != nil {
|
if err := server.Start(); err != nil {
|
||||||
|
@ -116,9 +120,7 @@ func TestServerBroadcast(t *testing.T) {
|
||||||
|
|
||||||
var connected sync.WaitGroup
|
var connected sync.WaitGroup
|
||||||
srv := startTestServer(t, func(p *Peer) {
|
srv := startTestServer(t, func(p *Peer) {
|
||||||
p.protocols = []Protocol{discard}
|
p.running = matchProtocols([]Protocol{discard}, []Cap{discard.cap()}, p.rw)
|
||||||
p.startSubprotocols([]Cap{discard.cap()})
|
|
||||||
p.noHandshake = true
|
|
||||||
connected.Done()
|
connected.Done()
|
||||||
})
|
})
|
||||||
defer srv.Stop()
|
defer srv.Stop()
|
||||||
|
|
Loading…
Reference in New Issue