diff --git a/core/types/block.go b/core/types/block.go index 5cdde44620..d5cd8a21ea 100644 --- a/core/types/block.go +++ b/core/types/block.go @@ -148,6 +148,23 @@ func NewBlockWithHeader(header *Header) *Block { return &Block{header: header} } +func (self *Block) ValidateFields() error { + if self.header == nil { + return fmt.Errorf("header is nil") + } + for i, transaction := range self.transactions { + if transaction == nil { + return fmt.Errorf("transaction %d is nil", i) + } + } + for i, uncle := range self.uncles { + if uncle == nil { + return fmt.Errorf("uncle %d is nil", i) + } + } + return nil +} + func (self *Block) DecodeRLP(s *rlp.Stream) error { var eb extblock if err := s.Decode(&eb); err != nil { diff --git a/eth/protocol.go b/eth/protocol.go index e32ea233b0..a0ab177cdc 100644 --- a/eth/protocol.go +++ b/eth/protocol.go @@ -185,7 +185,10 @@ func (self *ethProtocol) handle() error { if err := msg.Decode(&txs); err != nil { return self.protoError(ErrDecode, "msg %v: %v", msg, err) } - for _, tx := range txs { + for i, tx := range txs { + if tx == nil { + return self.protoError(ErrDecode, "transaction %d is nil", i) + } jsonlogger.LogJson(&logger.EthTxReceived{ TxHash: tx.Hash().Hex(), RemoteId: self.peer.ID().String(), @@ -268,6 +271,9 @@ func (self *ethProtocol) handle() error { return self.protoError(ErrDecode, "msg %v: %v", msg, err) } } + if err := block.ValidateFields(); err != nil { + return self.protoError(ErrDecode, "block validation %v: %v", msg, err) + } self.blockPool.AddBlock(&block, self.id) } @@ -276,6 +282,9 @@ func (self *ethProtocol) handle() error { if err := msg.Decode(&request); err != nil { return self.protoError(ErrDecode, "%v: %v", msg, err) } + if err := request.Block.ValidateFields(); err != nil { + return self.protoError(ErrDecode, "block validation %v: %v", msg, err) + } hash := request.Block.Hash() _, chainHead, _ := self.chainManager.Status() diff --git a/eth/protocol_test.go b/eth/protocol_test.go index 8ca6d1be61..d3466326a6 100644 --- a/eth/protocol_test.go +++ b/eth/protocol_test.go @@ -63,6 +63,10 @@ func (self *testChainManager) GetBlockHashesFromHash(hash common.Hash, amount ui func (self *testChainManager) Status() (td *big.Int, currentBlock common.Hash, genesisBlock common.Hash) { if self.status != nil { td, currentBlock, genesisBlock = self.status() + } else { + td = common.Big1 + currentBlock = common.Hash{1} + genesisBlock = common.Hash{2} } return } @@ -163,14 +167,29 @@ func (self *ethProtocolTester) run() { self.quit <- err } +func (self *ethProtocolTester) handshake(t *testing.T, mock bool) { + td, currentBlock, genesis := self.chainManager.Status() + // first outgoing msg should be StatusMsg. + err := p2p.ExpectMsg(self, StatusMsg, &statusMsgData{ + ProtocolVersion: ProtocolVersion, + NetworkId: NetworkId, + TD: td, + CurrentBlock: currentBlock, + GenesisBlock: genesis, + }) + if err != nil { + t.Fatalf("incorrect outgoing status: %v", err) + } + if mock { + go p2p.Send(self, StatusMsg, &statusMsgData{ProtocolVersion, NetworkId, td, currentBlock, genesis}) + } +} + func TestStatusMsgErrors(t *testing.T) { logInit() eth := newEth(t) - td := common.Big1 - currentBlock := common.Hash{1} - genesis := common.Hash{2} - eth.chainManager.status = func() (*big.Int, common.Hash, common.Hash) { return td, currentBlock, genesis } go eth.run() + td, currentBlock, genesis := eth.chainManager.Status() tests := []struct { code uint64 @@ -195,18 +214,7 @@ func TestStatusMsgErrors(t *testing.T) { }, } for _, test := range tests { - // first outgoing msg should be StatusMsg. - err := p2p.ExpectMsg(eth, StatusMsg, &statusMsgData{ - ProtocolVersion: ProtocolVersion, - NetworkId: NetworkId, - TD: td, - CurrentBlock: currentBlock, - GenesisBlock: genesis, - }) - if err != nil { - t.Fatalf("incorrect outgoing status: %v", err) - } - + eth.handshake(t, false) // the send call might hang until reset because // the protocol might not read the payload. go p2p.Send(eth, test.code, test.data) @@ -216,3 +224,177 @@ func TestStatusMsgErrors(t *testing.T) { go eth.run() } } + +func TestNewBlockMsg(t *testing.T) { + // logInit() + eth := newEth(t) + + var disconnected bool + eth.blockPool.removePeer = func(peerId string) { + disconnected = true + } + + go eth.run() + + eth.handshake(t, true) + err := p2p.ExpectMsg(eth, TxMsg, []interface{}{}) + if err != nil { + t.Errorf("transactions expected, got %v", err) + } + + var tds = make(chan *big.Int) + eth.blockPool.addPeer = func(td *big.Int, currentBlock common.Hash, peerId string, requestHashes func(common.Hash) error, requestBlocks func([]common.Hash) error, peerError func(*errs.Error)) (best bool, suspended bool) { + tds <- td + return + } + + var delay = 1 * time.Second + // eth.reset() + block := types.NewBlock(common.Hash{1}, common.Address{1}, common.Hash{1}, common.Big1, 1, "extra") + + go p2p.Send(eth, NewBlockMsg, &newBlockMsgData{Block: block}) + timer := time.After(delay) + + select { + case td := <-tds: + if td.Cmp(common.Big0) != 0 { + t.Errorf("incorrect td %v, expected %v", td, common.Big0) + } + case <-timer: + t.Errorf("no td recorded after %v", delay) + return + case err := <-eth.quit: + t.Errorf("no error expected, got %v", err) + return + } + + go p2p.Send(eth, NewBlockMsg, &newBlockMsgData{block, common.Big2}) + timer = time.After(delay) + + select { + case td := <-tds: + if td.Cmp(common.Big2) != 0 { + t.Errorf("incorrect td %v, expected %v", td, common.Big2) + } + case <-timer: + t.Errorf("no td recorded after %v", delay) + return + case err := <-eth.quit: + t.Errorf("no error expected, got %v", err) + return + } + + go p2p.Send(eth, NewBlockMsg, []interface{}{}) + // Block.DecodeRLP: validation failed: header is nil + eth.checkError(ErrDecode, delay) + +} + +func TestBlockMsg(t *testing.T) { + // logInit() + eth := newEth(t) + blocks := make(chan *types.Block) + eth.blockPool.addBlock = func(block *types.Block, peerId string) (err error) { + blocks <- block + return + } + + var disconnected bool + eth.blockPool.removePeer = func(peerId string) { + disconnected = true + } + + go eth.run() + + eth.handshake(t, true) + err := p2p.ExpectMsg(eth, TxMsg, []interface{}{}) + if err != nil { + t.Errorf("transactions expected, got %v", err) + } + + var delay = 3 * time.Second + // eth.reset() + newblock := func(i int64) *types.Block { + return types.NewBlock(common.Hash{byte(i)}, common.Address{byte(i)}, common.Hash{byte(i)}, big.NewInt(i), uint64(i), string(i)) + } + b := newblock(0) + b.Header().Difficulty = nil // check if nil as *big.Int decodes as 0 + go p2p.Send(eth, BlocksMsg, types.Blocks{b, newblock(1), newblock(2)}) + timer := time.After(delay) + for i := int64(0); i < 3; i++ { + select { + case block := <-blocks: + if (block.ParentHash() != common.Hash{byte(i)}) { + t.Errorf("incorrect block %v, expected %v", block.ParentHash(), common.Hash{byte(i)}) + } + if block.Difficulty().Cmp(big.NewInt(i)) != 0 { + t.Errorf("incorrect block %v, expected %v", block.Difficulty(), big.NewInt(i)) + } + case <-timer: + t.Errorf("no td recorded after %v", delay) + return + case err := <-eth.quit: + t.Errorf("no error expected, got %v", err) + return + } + } + + go p2p.Send(eth, BlocksMsg, []interface{}{[]interface{}{}}) + eth.checkError(ErrDecode, delay) + if !disconnected { + t.Errorf("peer not disconnected after error") + } + + // test empty transaction + eth.reset() + go eth.run() + eth.handshake(t, true) + err = p2p.ExpectMsg(eth, TxMsg, []interface{}{}) + if err != nil { + t.Errorf("transactions expected, got %v", err) + } + b = newblock(0) + b.AddTransaction(nil) + go p2p.Send(eth, BlocksMsg, types.Blocks{b}) + eth.checkError(ErrDecode, delay) + +} + +func TestTransactionsMsg(t *testing.T) { + logInit() + eth := newEth(t) + txs := make(chan *types.Transaction) + + eth.txPool.addTransactions = func(t []*types.Transaction) { + for _, tx := range t { + txs <- tx + } + } + go eth.run() + + eth.handshake(t, true) + err := p2p.ExpectMsg(eth, TxMsg, []interface{}{}) + if err != nil { + t.Errorf("transactions expected, got %v", err) + } + + var delay = 3 * time.Second + tx := &types.Transaction{} + + go p2p.Send(eth, TxMsg, []interface{}{tx, tx}) + timer := time.After(delay) + for i := int64(0); i < 2; i++ { + select { + case <-txs: + case <-timer: + return + case err := <-eth.quit: + t.Errorf("no error expected, got %v", err) + return + } + } + + go p2p.Send(eth, TxMsg, []interface{}{[]interface{}{}}) + eth.checkError(ErrDecode, delay) + +}