Merge pull request #1261 from fjl/p2p-no-writes-at-shutdown
p2p: prevent writes at shutdown time
This commit is contained in:
commit
f475a01326
69
p2p/peer.go
69
p2p/peer.go
|
@ -115,41 +115,60 @@ func newPeer(conn *conn, protocols []Protocol) *Peer {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Peer) run() DiscReason {
|
func (p *Peer) run() DiscReason {
|
||||||
readErr := make(chan error, 1)
|
var (
|
||||||
|
writeStart = make(chan struct{}, 1)
|
||||||
|
writeErr = make(chan error, 1)
|
||||||
|
readErr = make(chan error, 1)
|
||||||
|
reason DiscReason
|
||||||
|
requested bool
|
||||||
|
)
|
||||||
p.wg.Add(2)
|
p.wg.Add(2)
|
||||||
go p.readLoop(readErr)
|
go p.readLoop(readErr)
|
||||||
go p.pingLoop()
|
go p.pingLoop()
|
||||||
|
|
||||||
p.startProtocols()
|
// Start all protocol handlers.
|
||||||
|
writeStart <- struct{}{}
|
||||||
|
p.startProtocols(writeStart, writeErr)
|
||||||
|
|
||||||
// Wait for an error or disconnect.
|
// Wait for an error or disconnect.
|
||||||
var (
|
loop:
|
||||||
reason DiscReason
|
for {
|
||||||
requested bool
|
|
||||||
)
|
|
||||||
select {
|
select {
|
||||||
|
case err := <-writeErr:
|
||||||
|
// A write finished. Allow the next write to start if
|
||||||
|
// there was no error.
|
||||||
|
if err != nil {
|
||||||
|
glog.V(logger.Detail).Infof("%v: write error: %v\n", p, err)
|
||||||
|
reason = DiscNetworkError
|
||||||
|
break loop
|
||||||
|
}
|
||||||
|
writeStart <- struct{}{}
|
||||||
case err := <-readErr:
|
case err := <-readErr:
|
||||||
if r, ok := err.(DiscReason); ok {
|
if r, ok := err.(DiscReason); ok {
|
||||||
|
glog.V(logger.Debug).Infof("%v: remote requested disconnect: %v\n", p, r)
|
||||||
|
requested = true
|
||||||
reason = r
|
reason = r
|
||||||
} else {
|
} else {
|
||||||
// Note: We rely on protocols to abort if there is a write
|
glog.V(logger.Detail).Infof("%v: read error: %v\n", p, err)
|
||||||
// error. It might be more robust to handle them here as well.
|
|
||||||
glog.V(logger.Detail).Infof("%v: Read error: %v\n", p, err)
|
|
||||||
reason = DiscNetworkError
|
reason = DiscNetworkError
|
||||||
}
|
}
|
||||||
|
break loop
|
||||||
case err := <-p.protoErr:
|
case err := <-p.protoErr:
|
||||||
reason = discReasonForError(err)
|
reason = discReasonForError(err)
|
||||||
|
glog.V(logger.Debug).Infof("%v: protocol error: %v (%v)\n", p, err, reason)
|
||||||
|
break loop
|
||||||
case reason = <-p.disc:
|
case reason = <-p.disc:
|
||||||
requested = true
|
glog.V(logger.Debug).Infof("%v: locally requested disconnect: %v\n", p, reason)
|
||||||
|
break loop
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
close(p.closed)
|
close(p.closed)
|
||||||
p.rw.close(reason)
|
p.rw.close(reason)
|
||||||
p.wg.Wait()
|
p.wg.Wait()
|
||||||
|
|
||||||
if requested {
|
if requested {
|
||||||
reason = DiscRequested
|
reason = DiscRequested
|
||||||
}
|
}
|
||||||
glog.V(logger.Debug).Infof("%v: Disconnected: %v\n", p, reason)
|
|
||||||
return reason
|
return reason
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -196,7 +215,6 @@ func (p *Peer) handle(msg Msg) error {
|
||||||
// This is the last message. We don't need to discard or
|
// This is the last message. We don't need to discard or
|
||||||
// check errors because, the connection will be closed after it.
|
// check errors because, the connection will be closed after it.
|
||||||
rlp.Decode(msg.Payload, &reason)
|
rlp.Decode(msg.Payload, &reason)
|
||||||
glog.V(logger.Debug).Infof("%v: Disconnect Requested: %v\n", p, reason[0])
|
|
||||||
return reason[0]
|
return reason[0]
|
||||||
case msg.Code < baseProtocolLength:
|
case msg.Code < baseProtocolLength:
|
||||||
// ignore other base protocol messages
|
// ignore other base protocol messages
|
||||||
|
@ -247,11 +265,13 @@ outer:
|
||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Peer) startProtocols() {
|
func (p *Peer) startProtocols(writeStart <-chan struct{}, writeErr chan<- error) {
|
||||||
p.wg.Add(len(p.running))
|
p.wg.Add(len(p.running))
|
||||||
for _, proto := range p.running {
|
for _, proto := range p.running {
|
||||||
proto := proto
|
proto := proto
|
||||||
proto.closed = p.closed
|
proto.closed = p.closed
|
||||||
|
proto.wstart = writeStart
|
||||||
|
proto.werr = writeErr
|
||||||
glog.V(logger.Detail).Infof("%v: Starting protocol %s/%d\n", p, proto.Name, proto.Version)
|
glog.V(logger.Detail).Infof("%v: Starting protocol %s/%d\n", p, proto.Name, proto.Version)
|
||||||
go func() {
|
go func() {
|
||||||
err := proto.Run(p, proto)
|
err := proto.Run(p, proto)
|
||||||
|
@ -280,18 +300,31 @@ func (p *Peer) getProto(code uint64) (*protoRW, error) {
|
||||||
|
|
||||||
type protoRW struct {
|
type protoRW struct {
|
||||||
Protocol
|
Protocol
|
||||||
in chan Msg
|
in chan Msg // receices read messages
|
||||||
closed <-chan struct{}
|
closed <-chan struct{} // receives when peer is shutting down
|
||||||
|
wstart <-chan struct{} // receives when write may start
|
||||||
|
werr chan<- error // for write results
|
||||||
offset uint64
|
offset uint64
|
||||||
w MsgWriter
|
w MsgWriter
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rw *protoRW) WriteMsg(msg Msg) error {
|
func (rw *protoRW) WriteMsg(msg Msg) (err error) {
|
||||||
if msg.Code >= rw.Length {
|
if msg.Code >= rw.Length {
|
||||||
return newPeerError(errInvalidMsgCode, "not handled")
|
return newPeerError(errInvalidMsgCode, "not handled")
|
||||||
}
|
}
|
||||||
msg.Code += rw.offset
|
msg.Code += rw.offset
|
||||||
return rw.w.WriteMsg(msg)
|
select {
|
||||||
|
case <-rw.wstart:
|
||||||
|
err = rw.w.WriteMsg(msg)
|
||||||
|
// Report write status back to Peer.run. It will initiate
|
||||||
|
// shutdown if the error is non-nil and unblock the next write
|
||||||
|
// otherwise. The calling protocol code should exit for errors
|
||||||
|
// as well but we don't want to rely on that.
|
||||||
|
rw.werr <- err
|
||||||
|
case <-rw.closed:
|
||||||
|
err = fmt.Errorf("shutting down")
|
||||||
|
}
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rw *protoRW) ReadMsg() (Msg, error) {
|
func (rw *protoRW) ReadMsg() (Msg, error) {
|
||||||
|
|
|
@ -121,7 +121,7 @@ func TestPeerDisconnect(t *testing.T) {
|
||||||
}
|
}
|
||||||
select {
|
select {
|
||||||
case reason := <-disc:
|
case reason := <-disc:
|
||||||
if reason != DiscQuitting {
|
if reason != DiscRequested {
|
||||||
t.Errorf("run returned wrong reason: got %v, want %v", reason, DiscRequested)
|
t.Errorf("run returned wrong reason: got %v, want %v", reason, DiscRequested)
|
||||||
}
|
}
|
||||||
case <-time.After(500 * time.Millisecond):
|
case <-time.After(500 * time.Millisecond):
|
||||||
|
|
Loading…
Reference in New Issue