swarm/network: fix data race warning on TestBzzHandshakeLightNode (#18459)

This commit is contained in:
Elad 2019-01-17 17:38:23 +07:00 committed by Anton Evangelatov
parent ba6349d39a
commit 81e26d5a48
2 changed files with 13 additions and 6 deletions

View File

@ -168,7 +168,7 @@ func (b *Bzz) APIs() []rpc.API {
func (b *Bzz) RunProtocol(spec *protocols.Spec, run func(*BzzPeer) error) func(*p2p.Peer, p2p.MsgReadWriter) error { func (b *Bzz) RunProtocol(spec *protocols.Spec, run func(*BzzPeer) error) func(*p2p.Peer, p2p.MsgReadWriter) error {
return func(p *p2p.Peer, rw p2p.MsgReadWriter) error { return func(p *p2p.Peer, rw p2p.MsgReadWriter) error {
// wait for the bzz protocol to perform the handshake // wait for the bzz protocol to perform the handshake
handshake, _ := b.GetHandshake(p.ID()) handshake, _ := b.GetOrCreateHandshake(p.ID())
defer b.removeHandshake(p.ID()) defer b.removeHandshake(p.ID())
select { select {
case <-handshake.done: case <-handshake.done:
@ -213,7 +213,7 @@ func (b *Bzz) performHandshake(p *protocols.Peer, handshake *HandshakeMsg) error
// runBzz is the p2p protocol run function for the bzz base protocol // runBzz is the p2p protocol run function for the bzz base protocol
// that negotiates the bzz handshake // that negotiates the bzz handshake
func (b *Bzz) runBzz(p *p2p.Peer, rw p2p.MsgReadWriter) error { func (b *Bzz) runBzz(p *p2p.Peer, rw p2p.MsgReadWriter) error {
handshake, _ := b.GetHandshake(p.ID()) handshake, _ := b.GetOrCreateHandshake(p.ID())
if !<-handshake.init { if !<-handshake.init {
return fmt.Errorf("%08x: bzz already started on peer %08x", b.localAddr.Over()[:4], p.ID().Bytes()[:4]) return fmt.Errorf("%08x: bzz already started on peer %08x", b.localAddr.Over()[:4], p.ID().Bytes()[:4])
} }
@ -303,7 +303,7 @@ func (b *Bzz) removeHandshake(peerID enode.ID) {
} }
// GetHandshake returns the bzz handhake that the remote peer with peerID sent // GetHandshake returns the bzz handhake that the remote peer with peerID sent
func (b *Bzz) GetHandshake(peerID enode.ID) (*HandshakeMsg, bool) { func (b *Bzz) GetOrCreateHandshake(peerID enode.ID) (*HandshakeMsg, bool) {
b.mtx.Lock() b.mtx.Lock()
defer b.mtx.Unlock() defer b.mtx.Unlock()
handshake, found := b.handshakes[peerID] handshake, found := b.handshakes[peerID]

View File

@ -21,6 +21,7 @@ import (
"fmt" "fmt"
"os" "os"
"testing" "testing"
"time"
"github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/p2p" "github.com/ethereum/go-ethereum/p2p"
@ -224,7 +225,7 @@ func TestBzzHandshakeLightNode(t *testing.T) {
for _, test := range lightNodeTests { for _, test := range lightNodeTests {
t.Run(test.name, func(t *testing.T) { t.Run(test.name, func(t *testing.T) {
randomAddr := RandomAddr() randomAddr := RandomAddr()
pt := newBzzHandshakeTester(t, 1, randomAddr, false) pt := newBzzHandshakeTester(nil, 1, randomAddr, false) // TODO change signature - t is not used anywhere
node := pt.Nodes[0] node := pt.Nodes[0]
addr := NewAddr(node) addr := NewAddr(node)
@ -237,9 +238,15 @@ func TestBzzHandshakeLightNode(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
select {
case <-pt.bzz.handshakes[node.ID()].done:
if pt.bzz.handshakes[node.ID()].LightNode != test.lightNode { if pt.bzz.handshakes[node.ID()].LightNode != test.lightNode {
t.Fatalf("peer LightNode flag is %v, should be %v", pt.bzz.handshakes[node.ID()].LightNode, test.lightNode) t.Fatalf("peer LightNode flag is %v, should be %v", pt.bzz.handshakes[node.ID()].LightNode, test.lightNode)
} }
case <-time.After(10 * time.Second):
t.Fatal("test timeout")
}
}) })
} }
} }