diff --git a/p2p/handshake.go b/p2p/handshake.go index 031064407f..5a259cd769 100644 --- a/p2p/handshake.go +++ b/p2p/handshake.go @@ -409,7 +409,7 @@ func readProtocolHandshake(r MsgReader, our *protoHandshake) (*protoHandshake, e // spec and we send it ourself if Server.addPeer fails. var reason DiscReason rlp.Decode(msg.Payload, &reason) - return nil, discRequestedError(reason) + return nil, reason } if msg.Code != handshakeMsg { return nil, fmt.Errorf("expected handshake, got %x", msg.Code) diff --git a/p2p/peer.go b/p2p/peer.go index 6b97ea58d9..a82ee4bcad 100644 --- a/p2p/peer.go +++ b/p2p/peer.go @@ -44,7 +44,7 @@ type Peer struct { rw *conn running map[string]*protoRW - protoWG sync.WaitGroup + wg sync.WaitGroup protoErr chan error closed chan struct{} disc chan DiscReason @@ -102,58 +102,50 @@ func (p *Peer) String() string { func newPeer(fd net.Conn, conn *conn, protocols []Protocol) *Peer { logtag := fmt.Sprintf("Peer %.8x %v", conn.ID[:], fd.RemoteAddr()) + protomap := matchProtocols(protocols, conn.Caps, conn) p := &Peer{ Logger: logger.NewLogger(logtag), conn: fd, rw: conn, - running: matchProtocols(protocols, conn.Caps, conn), + running: protomap, disc: make(chan DiscReason), - protoErr: make(chan error), + protoErr: make(chan error, len(protomap)+1), // protocols + pingLoop closed: make(chan struct{}), } return p } func (p *Peer) run() DiscReason { - var readErr = make(chan error, 1) - defer p.closeProtocols() - defer close(p.closed) + readErr := make(chan error, 1) + p.wg.Add(2) + go p.readLoop(readErr) + go p.pingLoop() p.startProtocols() - go func() { readErr <- p.readLoop() }() - - ping := time.NewTicker(pingInterval) - defer ping.Stop() // Wait for an error or disconnect. var reason DiscReason -loop: - for { - select { - case <-ping.C: - go func() { - if err := SendItems(p.rw, pingMsg); err != nil { - p.protoErr <- err - return - } - }() - case err := <-readErr: - // We rely on protocols to abort if there is a write error. It - // might be more robust to handle them here as well. - p.DebugDetailf("Read error: %v\n", err) - p.conn.Close() - return DiscNetworkError - case err := <-p.protoErr: - reason = discReasonForError(err) - break loop - case reason = <-p.disc: - break loop + select { + case err := <-readErr: + if r, ok := err.(DiscReason); ok { + reason = r + break } + // Note: We rely on protocols to abort if there is a write + // error. It might be more robust to handle them here as well. + p.DebugDetailf("Read error: %v\n", err) + p.conn.Close() + reason = DiscNetworkError + case err := <-p.protoErr: + reason = discReasonForError(err) + case reason = <-p.disc: } - p.politeDisconnect(reason) - // Wait for readLoop. It will end because conn is now closed. - <-readErr + close(p.closed) + p.wg.Wait() + if reason != DiscNetworkError { + p.politeDisconnect(reason) + } p.Debugf("Disconnected: %v\n", reason) return reason } @@ -174,18 +166,37 @@ func (p *Peer) politeDisconnect(reason DiscReason) { p.conn.Close() } -func (p *Peer) readLoop() error { +func (p *Peer) pingLoop() { + ping := time.NewTicker(pingInterval) + defer p.wg.Done() + defer ping.Stop() + for { + select { + case <-ping.C: + if err := SendItems(p.rw, pingMsg); err != nil { + p.protoErr <- err + return + } + case <-p.closed: + return + } + } +} + +func (p *Peer) readLoop(errc chan<- error) { + defer p.wg.Done() for { p.conn.SetDeadline(time.Now().Add(frameReadTimeout)) msg, err := p.rw.ReadMsg() if err != nil { - return err + errc <- err + return } if err = p.handle(msg); err != nil { - return err + errc <- err + return } } - return nil } func (p *Peer) handle(msg Msg) error { @@ -195,12 +206,11 @@ func (p *Peer) handle(msg Msg) error { go SendItems(p.rw, pongMsg) case msg.Code == discMsg: var reason [1]DiscReason - // no need to discard or for error checking, we'll close the - // connection after this. + // This is the last message. We don't need to discard or + // check errors because, the connection will be closed after it. rlp.Decode(msg.Payload, &reason) p.Debugf("Disconnect requested: %v\n", reason[0]) - p.Disconnect(DiscRequested) - return discRequestedError(reason[0]) + return DiscRequested case msg.Code < baseProtocolLength: // ignore other base protocol messages return msg.Discard() @@ -210,7 +220,12 @@ func (p *Peer) handle(msg Msg) error { if err != nil { return fmt.Errorf("msg code out of range: %v", msg.Code) } - proto.in <- msg + select { + case proto.in <- msg: + return nil + case <-p.closed: + return io.EOF + } } return nil } @@ -234,10 +249,11 @@ outer: } func (p *Peer) startProtocols() { + p.wg.Add(len(p.running)) for _, proto := range p.running { proto := proto + proto.closed = p.closed p.DebugDetailf("Starting protocol %s/%d\n", proto.Name, proto.Version) - p.protoWG.Add(1) go func() { err := proto.Run(p, proto) if err == nil { @@ -246,11 +262,8 @@ func (p *Peer) startProtocols() { } else { p.DebugDetailf("Protocol %s/%d error: %v\n", proto.Name, proto.Version, err) } - select { - case p.protoErr <- err: - case <-p.closed: - } - p.protoWG.Done() + p.protoErr <- err + p.wg.Done() }() } } @@ -266,13 +279,6 @@ func (p *Peer) getProto(code uint64) (*protoRW, error) { return nil, newPeerError(errInvalidMsgCode, "%d", code) } -func (p *Peer) closeProtocols() { - for _, p := range p.running { - close(p.in) - } - p.protoWG.Wait() -} - // writeProtoMsg sends the given message on behalf of the given named protocol. // this exists because of Server.Broadcast. func (p *Peer) writeProtoMsg(protoName string, msg Msg) error { @@ -289,8 +295,8 @@ func (p *Peer) writeProtoMsg(protoName string, msg Msg) error { type protoRW struct { Protocol - in chan Msg + closed <-chan struct{} offset uint64 w MsgWriter } @@ -304,10 +310,11 @@ func (rw *protoRW) WriteMsg(msg Msg) error { } func (rw *protoRW) ReadMsg() (Msg, error) { - msg, ok := <-rw.in - if !ok { - return msg, io.EOF + select { + case msg := <-rw.in: + msg.Code -= rw.offset + return msg, nil + case <-rw.closed: + return Msg{}, io.EOF } - msg.Code -= rw.offset - return msg, nil } diff --git a/p2p/peer_error.go b/p2p/peer_error.go index 0ff4f4b43e..4021316306 100644 --- a/p2p/peer_error.go +++ b/p2p/peer_error.go @@ -98,15 +98,13 @@ func (d DiscReason) String() string { return discReasonToString[d] } -type discRequestedError DiscReason - -func (err discRequestedError) Error() string { - return fmt.Sprintf("disconnect requested: %v", DiscReason(err)) +func (d DiscReason) Error() string { + return d.String() } func discReasonForError(err error) DiscReason { - if reason, ok := err.(discRequestedError); ok { - return DiscReason(reason) + if reason, ok := err.(DiscReason); ok { + return reason } peerError, ok := err.(*peerError) if !ok { diff --git a/p2p/peer_test.go b/p2p/peer_test.go index 3c4c71c0c0..fb76818a06 100644 --- a/p2p/peer_test.go +++ b/p2p/peer_test.go @@ -2,8 +2,9 @@ package p2p import ( "bytes" + "errors" "fmt" - "io" + "math/rand" "net" "reflect" "testing" @@ -27,7 +28,7 @@ var discard = Protocol{ }, } -func testPeer(protos []Protocol) (io.Closer, *conn, *Peer, <-chan DiscReason) { +func testPeer(protos []Protocol) (func(), *conn, *Peer, <-chan DiscReason) { fd1, _ := net.Pipe() hs1 := &protoHandshake{ID: randomID(), Version: baseProtocolVersion} hs2 := &protoHandshake{ID: randomID(), Version: baseProtocolVersion} @@ -41,7 +42,11 @@ func testPeer(protos []Protocol) (io.Closer, *conn, *Peer, <-chan DiscReason) { errc := make(chan DiscReason, 1) go func() { errc <- peer.run() }() - return p1, &conn{p2, hs2}, peer, errc + closer := func() { + p1.Close() + fd1.Close() + } + return closer, &conn{p2, hs2}, peer, errc } func TestPeerProtoReadMsg(t *testing.T) { @@ -67,7 +72,7 @@ func TestPeerProtoReadMsg(t *testing.T) { } closer, rw, _, errc := testPeer([]Protocol{proto}) - defer closer.Close() + defer closer() Send(rw, baseProtocolLength+2, []uint{1}) Send(rw, baseProtocolLength+3, []uint{2}) @@ -99,7 +104,7 @@ func TestPeerProtoEncodeMsg(t *testing.T) { }, } closer, rw, _, _ := testPeer([]Protocol{proto}) - defer closer.Close() + defer closer() if err := ExpectMsg(rw, 17, []string{"foo", "bar"}); err != nil { t.Error(err) @@ -110,7 +115,7 @@ func TestPeerWriteForBroadcast(t *testing.T) { defer testlog(t).detach() closer, rw, peer, peerErr := testPeer([]Protocol{discard}) - defer closer.Close() + defer closer() emptymsg := func(code uint64) Msg { return Msg{Code: code, Size: 0, Payload: bytes.NewReader(nil)} @@ -150,7 +155,7 @@ func TestPeerPing(t *testing.T) { defer testlog(t).detach() closer, rw, _, _ := testPeer(nil) - defer closer.Close() + defer closer() if err := SendItems(rw, pingMsg); err != nil { t.Fatal(err) } @@ -163,19 +168,70 @@ func TestPeerDisconnect(t *testing.T) { defer testlog(t).detach() closer, rw, _, disc := testPeer(nil) - defer closer.Close() + defer closer() if err := SendItems(rw, discMsg, DiscQuitting); err != nil { t.Fatal(err) } if err := ExpectMsg(rw, discMsg, []interface{}{DiscRequested}); err != nil { t.Error(err) } - closer.Close() // make test end faster + closer() if reason := <-disc; reason != DiscRequested { t.Errorf("run returned wrong reason: got %v, want %v", reason, DiscRequested) } } +// This test is supposed to verify that Peer can reliably handle +// multiple causes of disconnection occurring at the same time. +func TestPeerDisconnectRace(t *testing.T) { + defer testlog(t).detach() + maybe := func() bool { return rand.Intn(1) == 1 } + + for i := 0; i < 1000; i++ { + protoclose := make(chan error) + protodisc := make(chan DiscReason) + closer, rw, p, disc := testPeer([]Protocol{ + { + Name: "closereq", + Run: func(p *Peer, rw MsgReadWriter) error { return <-protoclose }, + Length: 1, + }, + { + Name: "disconnect", + Run: func(p *Peer, rw MsgReadWriter) error { p.Disconnect(<-protodisc); return nil }, + Length: 1, + }, + }) + + // Simulate incoming messages. + go SendItems(rw, baseProtocolLength+1) + go SendItems(rw, baseProtocolLength+2) + // Close the network connection. + go closer() + // Make protocol "closereq" return. + protoclose <- errors.New("protocol closed") + // Make protocol "disconnect" call peer.Disconnect + protodisc <- DiscAlreadyConnected + // In some cases, simulate something else calling peer.Disconnect. + if maybe() { + go p.Disconnect(DiscInvalidIdentity) + } + // In some cases, simulate remote requesting a disconnect. + if maybe() { + go SendItems(rw, discMsg, DiscQuitting) + } + + select { + case <-disc: + case <-time.After(2 * time.Second): + // Peer.run should return quickly. If it doesn't the Peer + // goroutines are probably deadlocked. Call panic in order to + // show the stacks. + panic("Peer.run took to long to return.") + } + } +} + func TestNewPeer(t *testing.T) { name := "nodename" caps := []Cap{{"foo", 2}, {"bar", 3}}