p2p: fixes for actual connections
The unit test hooks were turned on 'in production'.
This commit is contained in:
parent
8564eb9f7e
commit
e34d134102
|
@ -174,10 +174,10 @@ func (rw *frameRW) ReadMsg() (msg Msg, err error) {
|
||||||
// read magic and payload size
|
// read magic and payload size
|
||||||
start := make([]byte, 8)
|
start := make([]byte, 8)
|
||||||
if _, err = io.ReadFull(rw.bufconn, start); err != nil {
|
if _, err = io.ReadFull(rw.bufconn, start); err != nil {
|
||||||
return msg, newPeerError(errRead, "%v", err)
|
return msg, err
|
||||||
}
|
}
|
||||||
if !bytes.HasPrefix(start, magicToken) {
|
if !bytes.HasPrefix(start, magicToken) {
|
||||||
return msg, newPeerError(errMagicTokenMismatch, "got %x, want %x", start[:4], magicToken)
|
return msg, fmt.Errorf("bad magic token %x", start[:4], magicToken)
|
||||||
}
|
}
|
||||||
size := binary.BigEndian.Uint32(start[4:])
|
size := binary.BigEndian.Uint32(start[4:])
|
||||||
|
|
||||||
|
|
37
p2p/peer.go
37
p2p/peer.go
|
@ -1,6 +1,7 @@
|
||||||
package p2p
|
package p2p
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
|
@ -71,7 +72,8 @@ type Peer struct {
|
||||||
runlock sync.RWMutex // protects running
|
runlock sync.RWMutex // protects running
|
||||||
running map[string]*proto
|
running map[string]*proto
|
||||||
|
|
||||||
protocolHandshakeEnabled bool
|
// disables protocol handshake, for testing
|
||||||
|
noHandshake bool
|
||||||
|
|
||||||
protoWG sync.WaitGroup
|
protoWG sync.WaitGroup
|
||||||
protoErr chan error
|
protoErr chan error
|
||||||
|
@ -134,11 +136,11 @@ func (p *Peer) Disconnect(reason DiscReason) {
|
||||||
|
|
||||||
// String implements fmt.Stringer.
|
// String implements fmt.Stringer.
|
||||||
func (p *Peer) String() string {
|
func (p *Peer) String() string {
|
||||||
return fmt.Sprintf("Peer %.8x %v", p.remoteID, p.RemoteAddr())
|
return fmt.Sprintf("Peer %.8x %v", p.remoteID[:], p.RemoteAddr())
|
||||||
}
|
}
|
||||||
|
|
||||||
func newPeer(conn net.Conn, protocols []Protocol, ourName string, ourID, remoteID *discover.NodeID) *Peer {
|
func newPeer(conn net.Conn, protocols []Protocol, ourName string, ourID, remoteID *discover.NodeID) *Peer {
|
||||||
logtag := fmt.Sprintf("Peer %.8x %v", remoteID, conn.RemoteAddr())
|
logtag := fmt.Sprintf("Peer %.8x %v", remoteID[:], conn.RemoteAddr())
|
||||||
return &Peer{
|
return &Peer{
|
||||||
Logger: logger.NewLogger(logtag),
|
Logger: logger.NewLogger(logtag),
|
||||||
rw: newFrameRW(conn, msgWriteTimeout),
|
rw: newFrameRW(conn, msgWriteTimeout),
|
||||||
|
@ -164,33 +166,35 @@ func (p *Peer) run() DiscReason {
|
||||||
var readErr = make(chan error, 1)
|
var readErr = make(chan error, 1)
|
||||||
defer p.closeProtocols()
|
defer p.closeProtocols()
|
||||||
defer close(p.closed)
|
defer close(p.closed)
|
||||||
defer p.rw.Close()
|
|
||||||
|
|
||||||
// start the read loop
|
|
||||||
go func() { readErr <- p.readLoop() }()
|
go func() { readErr <- p.readLoop() }()
|
||||||
|
|
||||||
if p.protocolHandshakeEnabled {
|
if !p.noHandshake {
|
||||||
if err := writeProtocolHandshake(p.rw, p.ourName, *p.ourID, p.protocols); err != nil {
|
if err := writeProtocolHandshake(p.rw, p.ourName, *p.ourID, p.protocols); err != nil {
|
||||||
p.DebugDetailf("Protocol handshake error: %v\n", err)
|
p.DebugDetailf("Protocol handshake error: %v\n", err)
|
||||||
|
p.rw.Close()
|
||||||
return DiscProtocolError
|
return DiscProtocolError
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// wait for an error or disconnect
|
// Wait for an error or disconnect.
|
||||||
var reason DiscReason
|
var reason DiscReason
|
||||||
select {
|
select {
|
||||||
case err := <-readErr:
|
case err := <-readErr:
|
||||||
// We rely on protocols to abort if there is a write error. It
|
// We rely on protocols to abort if there is a write error. It
|
||||||
// might be more robust to handle them here as well.
|
// might be more robust to handle them here as well.
|
||||||
p.DebugDetailf("Read error: %v\n", err)
|
p.DebugDetailf("Read error: %v\n", err)
|
||||||
reason = DiscNetworkError
|
p.rw.Close()
|
||||||
|
return DiscNetworkError
|
||||||
|
|
||||||
case err := <-p.protoErr:
|
case err := <-p.protoErr:
|
||||||
reason = discReasonForError(err)
|
reason = discReasonForError(err)
|
||||||
case reason = <-p.disc:
|
case reason = <-p.disc:
|
||||||
}
|
}
|
||||||
if reason != DiscNetworkError {
|
p.politeDisconnect(reason)
|
||||||
p.politeDisconnect(reason)
|
|
||||||
}
|
// Wait for readLoop. It will end because conn is now closed.
|
||||||
|
<-readErr
|
||||||
p.Debugf("Disconnected: %v\n", reason)
|
p.Debugf("Disconnected: %v\n", reason)
|
||||||
return reason
|
return reason
|
||||||
}
|
}
|
||||||
|
@ -198,9 +202,9 @@ func (p *Peer) run() DiscReason {
|
||||||
func (p *Peer) politeDisconnect(reason DiscReason) {
|
func (p *Peer) politeDisconnect(reason DiscReason) {
|
||||||
done := make(chan struct{})
|
done := make(chan struct{})
|
||||||
go func() {
|
go func() {
|
||||||
// send reason
|
|
||||||
EncodeMsg(p.rw, discMsg, uint(reason))
|
EncodeMsg(p.rw, discMsg, uint(reason))
|
||||||
// discard any data that might arrive
|
// Wait for the other side to close the connection.
|
||||||
|
// Discard any data that they send until then.
|
||||||
io.Copy(ioutil.Discard, p.rw)
|
io.Copy(ioutil.Discard, p.rw)
|
||||||
close(done)
|
close(done)
|
||||||
}()
|
}()
|
||||||
|
@ -208,10 +212,11 @@ func (p *Peer) politeDisconnect(reason DiscReason) {
|
||||||
case <-done:
|
case <-done:
|
||||||
case <-time.After(disconnectGracePeriod):
|
case <-time.After(disconnectGracePeriod):
|
||||||
}
|
}
|
||||||
|
p.rw.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Peer) readLoop() error {
|
func (p *Peer) readLoop() error {
|
||||||
if p.protocolHandshakeEnabled {
|
if !p.noHandshake {
|
||||||
if err := readProtocolHandshake(p, p.rw); err != nil {
|
if err := readProtocolHandshake(p, p.rw); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -264,7 +269,7 @@ func readProtocolHandshake(p *Peer, rw MsgReadWriter) error {
|
||||||
return newPeerError(errProtocolBreach, "expected handshake, got %x", msg.Code)
|
return newPeerError(errProtocolBreach, "expected handshake, got %x", msg.Code)
|
||||||
}
|
}
|
||||||
if msg.Size > baseProtocolMaxMsgSize {
|
if msg.Size > baseProtocolMaxMsgSize {
|
||||||
return newPeerError(errMisc, "message too big")
|
return newPeerError(errInvalidMsg, "message too big")
|
||||||
}
|
}
|
||||||
var hs handshake
|
var hs handshake
|
||||||
if err := msg.Decode(&hs); err != nil {
|
if err := msg.Decode(&hs); err != nil {
|
||||||
|
@ -326,7 +331,7 @@ func (p *Peer) startProto(offset uint64, impl Protocol) *proto {
|
||||||
err := impl.Run(p, rw)
|
err := impl.Run(p, rw)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
p.DebugDetailf("Protocol %s/%d returned\n", impl.Name, impl.Version)
|
p.DebugDetailf("Protocol %s/%d returned\n", impl.Name, impl.Version)
|
||||||
err = newPeerError(errMisc, "protocol returned")
|
err = errors.New("protocol returned")
|
||||||
} else {
|
} else {
|
||||||
p.DebugDetailf("Protocol %s/%d error: %v\n", impl.Name, impl.Version, err)
|
p.DebugDetailf("Protocol %s/%d error: %v\n", impl.Name, impl.Version, err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -123,7 +123,7 @@ func discReasonForError(err error) DiscReason {
|
||||||
return DiscProtocolError
|
return DiscProtocolError
|
||||||
case errPingTimeout:
|
case errPingTimeout:
|
||||||
return DiscReadTimeout
|
return DiscReadTimeout
|
||||||
case errRead, errWrite, errMisc:
|
case errRead, errWrite:
|
||||||
return DiscNetworkError
|
return DiscNetworkError
|
||||||
default:
|
default:
|
||||||
return DiscSubprotocolError
|
return DiscSubprotocolError
|
||||||
|
|
|
@ -30,10 +30,10 @@ var discard = Protocol{
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
func testPeer(handshake bool, protos []Protocol) (*frameRW, *Peer, <-chan DiscReason) {
|
func testPeer(noHandshake bool, protos []Protocol) (*frameRW, *Peer, <-chan DiscReason) {
|
||||||
conn1, conn2 := net.Pipe()
|
conn1, conn2 := net.Pipe()
|
||||||
peer := newPeer(conn1, protos, "name", &discover.NodeID{}, &discover.NodeID{})
|
peer := newPeer(conn1, protos, "name", &discover.NodeID{}, &discover.NodeID{})
|
||||||
peer.protocolHandshakeEnabled = handshake
|
peer.noHandshake = noHandshake
|
||||||
errc := make(chan DiscReason, 1)
|
errc := make(chan DiscReason, 1)
|
||||||
go func() { errc <- peer.run() }()
|
go func() { errc <- peer.run() }()
|
||||||
return newFrameRW(conn2, msgWriteTimeout), peer, errc
|
return newFrameRW(conn2, msgWriteTimeout), peer, errc
|
||||||
|
@ -61,7 +61,7 @@ func TestPeerProtoReadMsg(t *testing.T) {
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
rw, peer, errc := testPeer(false, []Protocol{proto})
|
rw, peer, errc := testPeer(true, []Protocol{proto})
|
||||||
defer rw.Close()
|
defer rw.Close()
|
||||||
peer.startSubprotocols([]Cap{proto.cap()})
|
peer.startSubprotocols([]Cap{proto.cap()})
|
||||||
|
|
||||||
|
@ -100,7 +100,7 @@ func TestPeerProtoReadLargeMsg(t *testing.T) {
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
rw, peer, errc := testPeer(false, []Protocol{proto})
|
rw, peer, errc := testPeer(true, []Protocol{proto})
|
||||||
defer rw.Close()
|
defer rw.Close()
|
||||||
peer.startSubprotocols([]Cap{proto.cap()})
|
peer.startSubprotocols([]Cap{proto.cap()})
|
||||||
|
|
||||||
|
@ -130,7 +130,7 @@ func TestPeerProtoEncodeMsg(t *testing.T) {
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
rw, peer, _ := testPeer(false, []Protocol{proto})
|
rw, peer, _ := testPeer(true, []Protocol{proto})
|
||||||
defer rw.Close()
|
defer rw.Close()
|
||||||
peer.startSubprotocols([]Cap{proto.cap()})
|
peer.startSubprotocols([]Cap{proto.cap()})
|
||||||
|
|
||||||
|
@ -142,7 +142,7 @@ func TestPeerProtoEncodeMsg(t *testing.T) {
|
||||||
func TestPeerWriteForBroadcast(t *testing.T) {
|
func TestPeerWriteForBroadcast(t *testing.T) {
|
||||||
defer testlog(t).detach()
|
defer testlog(t).detach()
|
||||||
|
|
||||||
rw, peer, peerErr := testPeer(false, []Protocol{discard})
|
rw, peer, peerErr := testPeer(true, []Protocol{discard})
|
||||||
defer rw.Close()
|
defer rw.Close()
|
||||||
peer.startSubprotocols([]Cap{discard.cap()})
|
peer.startSubprotocols([]Cap{discard.cap()})
|
||||||
|
|
||||||
|
@ -179,7 +179,7 @@ func TestPeerWriteForBroadcast(t *testing.T) {
|
||||||
func TestPeerPing(t *testing.T) {
|
func TestPeerPing(t *testing.T) {
|
||||||
defer testlog(t).detach()
|
defer testlog(t).detach()
|
||||||
|
|
||||||
rw, _, _ := testPeer(false, nil)
|
rw, _, _ := testPeer(true, nil)
|
||||||
defer rw.Close()
|
defer rw.Close()
|
||||||
if err := EncodeMsg(rw, pingMsg); err != nil {
|
if err := EncodeMsg(rw, pingMsg); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
|
@ -192,7 +192,7 @@ func TestPeerPing(t *testing.T) {
|
||||||
func TestPeerDisconnect(t *testing.T) {
|
func TestPeerDisconnect(t *testing.T) {
|
||||||
defer testlog(t).detach()
|
defer testlog(t).detach()
|
||||||
|
|
||||||
rw, _, disc := testPeer(false, nil)
|
rw, _, disc := testPeer(true, nil)
|
||||||
defer rw.Close()
|
defer rw.Close()
|
||||||
if err := EncodeMsg(rw, discMsg, DiscQuitting); err != nil {
|
if err := EncodeMsg(rw, discMsg, DiscQuitting); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
|
@ -233,7 +233,7 @@ func TestPeerHandshake(t *testing.T) {
|
||||||
{Name: "c", Version: 3, Length: 1, Run: run},
|
{Name: "c", Version: 3, Length: 1, Run: run},
|
||||||
{Name: "d", Version: 4, Length: 1, Run: run},
|
{Name: "d", Version: 4, Length: 1, Run: run},
|
||||||
}
|
}
|
||||||
rw, p, disc := testPeer(true, protocols)
|
rw, p, disc := testPeer(false, protocols)
|
||||||
p.remoteID = remote.ourID
|
p.remoteID = remote.ourID
|
||||||
defer rw.Close()
|
defer rw.Close()
|
||||||
|
|
||||||
|
@ -269,6 +269,7 @@ func TestPeerHandshake(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
close(stop)
|
close(stop)
|
||||||
|
expectMsg(rw, discMsg, nil)
|
||||||
t.Logf("disc reason: %v", <-disc)
|
t.Logf("disc reason: %v", <-disc)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -408,7 +408,9 @@ func (srv *Server) startPeer(conn net.Conn, dest *discover.Node) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
srv.newPeerHook(p)
|
if srv.newPeerHook != nil {
|
||||||
|
srv.newPeerHook(p)
|
||||||
|
}
|
||||||
p.run()
|
p.run()
|
||||||
srv.removePeer(p)
|
srv.removePeer(p)
|
||||||
}
|
}
|
||||||
|
|
|
@ -118,6 +118,7 @@ func TestServerBroadcast(t *testing.T) {
|
||||||
srv := startTestServer(t, func(p *Peer) {
|
srv := startTestServer(t, func(p *Peer) {
|
||||||
p.protocols = []Protocol{discard}
|
p.protocols = []Protocol{discard}
|
||||||
p.startSubprotocols([]Cap{discard.cap()})
|
p.startSubprotocols([]Cap{discard.cap()})
|
||||||
|
p.noHandshake = true
|
||||||
connected.Done()
|
connected.Done()
|
||||||
})
|
})
|
||||||
defer srv.Stop()
|
defer srv.Stop()
|
||||||
|
|
Loading…
Reference in New Issue