This commit is contained in:
obscuren 2015-01-05 17:10:42 +01:00
parent b0854fbff5
commit 6abf8ef78f
41 changed files with 2329 additions and 1203 deletions

View File

@ -59,6 +59,8 @@ var (
DumpNumber int DumpNumber int
VmType int VmType int
ImportChain string ImportChain string
SHH bool
Dial bool
) )
// flags specific to cli client // flags specific to cli client
@ -94,6 +96,8 @@ func Init() {
flag.BoolVar(&StartWebSockets, "ws", false, "start websocket server") flag.BoolVar(&StartWebSockets, "ws", false, "start websocket server")
flag.BoolVar(&NonInteractive, "y", false, "non-interactive mode (say yes to confirmations)") flag.BoolVar(&NonInteractive, "y", false, "non-interactive mode (say yes to confirmations)")
flag.BoolVar(&UseSeed, "seed", true, "seed peers") flag.BoolVar(&UseSeed, "seed", true, "seed peers")
flag.BoolVar(&SHH, "shh", true, "whisper protocol (on)")
flag.BoolVar(&Dial, "dial", true, "dial out connections (on)")
flag.BoolVar(&GenAddr, "genaddr", false, "create a new priv/pub key") flag.BoolVar(&GenAddr, "genaddr", false, "create a new priv/pub key")
flag.StringVar(&SecretFile, "import", "", "imports the file given (hex or mnemonic formats)") flag.StringVar(&SecretFile, "import", "", "imports the file given (hex or mnemonic formats)")
flag.StringVar(&ExportDir, "export", "", "exports the session keyring to files in the directory given") flag.StringVar(&ExportDir, "export", "", "exports the session keyring to files in the directory given")
@ -105,7 +109,7 @@ func Init() {
flag.BoolVar(&DiffTool, "difftool", false, "creates output for diff'ing. Sets LogLevel=0") flag.BoolVar(&DiffTool, "difftool", false, "creates output for diff'ing. Sets LogLevel=0")
flag.StringVar(&DiffType, "diff", "all", "sets the level of diff output [vm, all]. Has no effect if difftool=false") flag.StringVar(&DiffType, "diff", "all", "sets the level of diff output [vm, all]. Has no effect if difftool=false")
flag.BoolVar(&ShowGenesis, "genesis", false, "Dump the genesis block") flag.BoolVar(&ShowGenesis, "genesis", false, "Dump the genesis block")
flag.StringVar(&ImportChain, "chain", "", "Imports fiven chain") flag.StringVar(&ImportChain, "chain", "", "Imports given chain")
flag.BoolVar(&Dump, "dump", false, "output the ethereum state in JSON format. Sub args [number, hash]") flag.BoolVar(&Dump, "dump", false, "output the ethereum state in JSON format. Sub args [number, hash]")
flag.StringVar(&DumpHash, "hash", "", "specify arg in hex") flag.StringVar(&DumpHash, "hash", "", "specify arg in hex")

View File

@ -64,10 +64,14 @@ func main() {
NATType: PMPGateway, NATType: PMPGateway,
PMPGateway: PMPGateway, PMPGateway: PMPGateway,
KeyRing: KeyRing, KeyRing: KeyRing,
Shh: SHH,
Dial: Dial,
}) })
if err != nil { if err != nil {
clilogger.Fatalln(err) clilogger.Fatalln(err)
} }
utils.KeyTasks(ethereum.KeyManager(), KeyRing, GenAddr, SecretFile, ExportDir, NonInteractive) utils.KeyTasks(ethereum.KeyManager(), KeyRing, GenAddr, SecretFile, ExportDir, NonInteractive)
if Dump { if Dump {
@ -112,13 +116,6 @@ func main() {
return return
} }
// better reworked as cases
if StartJsConsole {
InitJsConsole(ethereum)
} else if len(InputFile) > 0 {
ExecJsFile(ethereum, InputFile)
}
if StartRpc { if StartRpc {
utils.StartRpc(ethereum, RpcPort) utils.StartRpc(ethereum, RpcPort)
} }
@ -129,6 +126,11 @@ func main() {
utils.StartEthereum(ethereum, UseSeed) utils.StartEthereum(ethereum, UseSeed)
if StartJsConsole {
InitJsConsole(ethereum)
} else if len(InputFile) > 0 {
ExecJsFile(ethereum, InputFile)
}
// this blocks the thread // this blocks the thread
ethereum.WaitForShutdown() ethereum.WaitForShutdown()
} }

View File

@ -7,7 +7,6 @@ import (
"github.com/ethereum/go-ethereum/ethutil" "github.com/ethereum/go-ethereum/ethutil"
"github.com/ethereum/go-ethereum/event" "github.com/ethereum/go-ethereum/event"
"github.com/ethereum/go-ethereum/logger" "github.com/ethereum/go-ethereum/logger"
"gopkg.in/fatih/set.v0"
) )
var txplogger = logger.NewLogger("TXP") var txplogger = logger.NewLogger("TXP")
@ -38,7 +37,7 @@ type TxPool struct {
quit chan bool quit chan bool
// The actual pool // The actual pool
//pool *list.List //pool *list.List
pool *set.Set txs map[string]*types.Transaction
SecondaryProcessor TxProcessor SecondaryProcessor TxProcessor
@ -49,21 +48,19 @@ type TxPool struct {
func NewTxPool(eventMux *event.TypeMux) *TxPool { func NewTxPool(eventMux *event.TypeMux) *TxPool {
return &TxPool{ return &TxPool{
pool: set.New(), txs: make(map[string]*types.Transaction),
queueChan: make(chan *types.Transaction, txPoolQueueSize), queueChan: make(chan *types.Transaction, txPoolQueueSize),
quit: make(chan bool), quit: make(chan bool),
eventMux: eventMux, eventMux: eventMux,
} }
} }
func (pool *TxPool) addTransaction(tx *types.Transaction) {
pool.pool.Add(tx)
// Broadcast the transaction to the rest of the peers
pool.eventMux.Post(TxPreEvent{tx})
}
func (pool *TxPool) ValidateTransaction(tx *types.Transaction) error { func (pool *TxPool) ValidateTransaction(tx *types.Transaction) error {
hash := tx.Hash()
if pool.txs[string(hash)] != nil {
return fmt.Errorf("Known transaction (%x)", hash[0:4])
}
if len(tx.To()) != 0 && len(tx.To()) != 20 { if len(tx.To()) != 0 && len(tx.To()) != 20 {
return fmt.Errorf("Invalid recipient. len = %d", len(tx.To())) return fmt.Errorf("Invalid recipient. len = %d", len(tx.To()))
} }
@ -95,18 +92,17 @@ func (pool *TxPool) ValidateTransaction(tx *types.Transaction) error {
return nil return nil
} }
func (self *TxPool) Add(tx *types.Transaction) error { func (self *TxPool) addTx(tx *types.Transaction) {
hash := tx.Hash() self.txs[string(tx.Hash())] = tx
if self.pool.Has(tx) { }
return fmt.Errorf("Known transaction (%x)", hash[0:4])
}
func (self *TxPool) Add(tx *types.Transaction) error {
err := self.ValidateTransaction(tx) err := self.ValidateTransaction(tx)
if err != nil { if err != nil {
return err return err
} }
self.addTransaction(tx) self.addTx(tx)
var to string var to string
if len(tx.To()) > 0 { if len(tx.To()) > 0 {
@ -124,7 +120,7 @@ func (self *TxPool) Add(tx *types.Transaction) error {
} }
func (self *TxPool) Size() int { func (self *TxPool) Size() int {
return self.pool.Size() return len(self.txs)
} }
func (self *TxPool) AddTransactions(txs []*types.Transaction) { func (self *TxPool) AddTransactions(txs []*types.Transaction) {
@ -137,43 +133,39 @@ func (self *TxPool) AddTransactions(txs []*types.Transaction) {
} }
} }
func (pool *TxPool) GetTransactions() []*types.Transaction { func (self *TxPool) GetTransactions() (txs types.Transactions) {
txList := make([]*types.Transaction, pool.Size()) txs = make(types.Transactions, self.Size())
i := 0 i := 0
pool.pool.Each(func(v interface{}) bool { for _, tx := range self.txs {
txList[i] = v.(*types.Transaction) txs[i] = tx
i++ i++
}
return true return
})
return txList
} }
func (pool *TxPool) RemoveInvalid(query StateQuery) { func (pool *TxPool) RemoveInvalid(query StateQuery) {
var removedTxs types.Transactions var removedTxs types.Transactions
pool.pool.Each(func(v interface{}) bool { for _, tx := range pool.txs {
tx := v.(*types.Transaction)
sender := query.GetAccount(tx.From()) sender := query.GetAccount(tx.From())
err := pool.ValidateTransaction(tx) err := pool.ValidateTransaction(tx)
if err != nil || sender.Nonce >= tx.Nonce() { if err != nil || sender.Nonce >= tx.Nonce() {
removedTxs = append(removedTxs, tx) removedTxs = append(removedTxs, tx)
} }
}
return true
})
pool.RemoveSet(removedTxs) pool.RemoveSet(removedTxs)
} }
func (self *TxPool) RemoveSet(txs types.Transactions) { func (self *TxPool) RemoveSet(txs types.Transactions) {
for _, tx := range txs { for _, tx := range txs {
self.pool.Remove(tx) delete(self.txs, string(tx.Hash()))
} }
} }
func (pool *TxPool) Flush() []*types.Transaction { func (pool *TxPool) Flush() []*types.Transaction {
txList := pool.GetTransactions() txList := pool.GetTransactions()
pool.pool.Clear() pool.txs = make(map[string]*types.Transaction)
return txList return txList
} }

View File

@ -67,6 +67,9 @@ func (self *Header) HashNoNonce() []byte {
} }
type Block struct { type Block struct {
// Preset Hash for mock
HeaderHash []byte
ParentHeaderHash []byte
header *Header header *Header
uncles []*Header uncles []*Header
transactions Transactions transactions Transactions
@ -99,41 +102,19 @@ func NewBlockWithHeader(header *Header) *Block {
} }
func (self *Block) DecodeRLP(s *rlp.Stream) error { func (self *Block) DecodeRLP(s *rlp.Stream) error {
if _, err := s.List(); err != nil { var extblock struct {
Header *Header
Txs []*Transaction
Uncles []*Header
TD *big.Int // optional
}
if err := s.Decode(&extblock); err != nil {
return err return err
} }
self.header = extblock.Header
var header Header self.uncles = extblock.Uncles
if err := s.Decode(&header); err != nil { self.transactions = extblock.Txs
return err self.Td = extblock.TD
}
var transactions []*Transaction
if err := s.Decode(&transactions); err != nil {
return err
}
var uncleHeaders []*Header
if err := s.Decode(&uncleHeaders); err != nil {
return err
}
var tdBytes []byte
if err := s.Decode(&tdBytes); err != nil {
// If this block comes from the network that's fine. If loaded from disk it should be there
// Blocks don't store their Td when propagated over the network
} else {
self.Td = ethutil.BigD(tdBytes)
}
if err := s.ListEnd(); err != nil {
return err
}
self.header = &header
self.uncles = uncleHeaders
self.transactions = transactions
return nil return nil
} }
@ -189,23 +170,35 @@ func (self *Block) RlpDataForStorage() interface{} {
// Header accessors (add as you need them) // Header accessors (add as you need them)
func (self *Block) Number() *big.Int { return self.header.Number } func (self *Block) Number() *big.Int { return self.header.Number }
func (self *Block) NumberU64() uint64 { return self.header.Number.Uint64() } func (self *Block) NumberU64() uint64 { return self.header.Number.Uint64() }
func (self *Block) ParentHash() []byte { return self.header.ParentHash }
func (self *Block) Bloom() []byte { return self.header.Bloom } func (self *Block) Bloom() []byte { return self.header.Bloom }
func (self *Block) Coinbase() []byte { return self.header.Coinbase } func (self *Block) Coinbase() []byte { return self.header.Coinbase }
func (self *Block) Time() int64 { return int64(self.header.Time) } func (self *Block) Time() int64 { return int64(self.header.Time) }
func (self *Block) GasLimit() *big.Int { return self.header.GasLimit } func (self *Block) GasLimit() *big.Int { return self.header.GasLimit }
func (self *Block) GasUsed() *big.Int { return self.header.GasUsed } func (self *Block) GasUsed() *big.Int { return self.header.GasUsed }
func (self *Block) Hash() []byte { return self.header.Hash() }
func (self *Block) Trie() *ptrie.Trie { return ptrie.New(self.header.Root, ethutil.Config.Db) } func (self *Block) Trie() *ptrie.Trie { return ptrie.New(self.header.Root, ethutil.Config.Db) }
func (self *Block) SetRoot(root []byte) { self.header.Root = root }
func (self *Block) State() *state.StateDB { return state.New(self.Trie()) } func (self *Block) State() *state.StateDB { return state.New(self.Trie()) }
func (self *Block) Size() ethutil.StorageSize { return ethutil.StorageSize(len(ethutil.Encode(self))) } func (self *Block) Size() ethutil.StorageSize { return ethutil.StorageSize(len(ethutil.Encode(self))) }
func (self *Block) SetRoot(root []byte) { self.header.Root = root }
// Implement block.Pow // Implement pow.Block
func (self *Block) Difficulty() *big.Int { return self.header.Difficulty } func (self *Block) Difficulty() *big.Int { return self.header.Difficulty }
func (self *Block) N() []byte { return self.header.Nonce } func (self *Block) N() []byte { return self.header.Nonce }
func (self *Block) HashNoNonce() []byte { func (self *Block) HashNoNonce() []byte { return self.header.HashNoNonce() }
return crypto.Sha3(ethutil.Encode(self.header.rlpData(false)))
func (self *Block) Hash() []byte {
if self.HeaderHash != nil {
return self.HeaderHash
} else {
return self.header.Hash()
}
}
func (self *Block) ParentHash() []byte {
if self.ParentHeaderHash != nil {
return self.ParentHeaderHash
} else {
return self.header.ParentHash
}
} }
func (self *Block) String() string { func (self *Block) String() string {

View File

@ -36,6 +36,9 @@ type Config struct {
NATType string NATType string
PMPGateway string PMPGateway string
Shh bool
Dial bool
KeyManager *crypto.KeyManager KeyManager *crypto.KeyManager
} }
@ -130,11 +133,13 @@ func New(config *Config) (*Ethereum, error) {
insertChain := eth.chainManager.InsertChain insertChain := eth.chainManager.InsertChain
eth.blockPool = NewBlockPool(hasBlock, insertChain, ezp.Verify) eth.blockPool = NewBlockPool(hasBlock, insertChain, ezp.Verify)
// Start services
eth.txPool.Start()
ethProto := EthProtocol(eth.txPool, eth.chainManager, eth.blockPool) ethProto := EthProtocol(eth.txPool, eth.chainManager, eth.blockPool)
protocols := []p2p.Protocol{ethProto, eth.whisper.Protocol()} protocols := []p2p.Protocol{ethProto}
if config.Shh {
eth.whisper = whisper.New()
protocols = append(protocols, eth.whisper.Protocol())
}
nat, err := p2p.ParseNAT(config.NATType, config.PMPGateway) nat, err := p2p.ParseNAT(config.NATType, config.PMPGateway)
if err != nil { if err != nil {
@ -145,9 +150,13 @@ func New(config *Config) (*Ethereum, error) {
Identity: clientId, Identity: clientId,
MaxPeers: config.MaxPeers, MaxPeers: config.MaxPeers,
Protocols: protocols, Protocols: protocols,
ListenAddr: ":" + config.Port,
Blacklist: eth.blacklist, Blacklist: eth.blacklist,
NAT: nat, NAT: nat,
NoDial: !config.Dial,
}
if len(config.Port) > 0 {
eth.net.ListenAddr = ":" + config.Port
} }
return eth, nil return eth, nil
@ -219,8 +228,14 @@ func (s *Ethereum) Start(seed bool) error {
if err != nil { if err != nil {
return err return err
} }
// Start services
s.txPool.Start()
s.blockPool.Start() s.blockPool.Start()
if s.whisper != nil {
s.whisper.Start() s.whisper.Start()
}
// broadcast transactions // broadcast transactions
s.txSub = s.eventMux.Subscribe(core.TxPreEvent{}) s.txSub = s.eventMux.Subscribe(core.TxPreEvent{})
@ -268,7 +283,9 @@ func (s *Ethereum) Stop() {
s.txPool.Stop() s.txPool.Stop()
s.eventMux.Stop() s.eventMux.Stop()
s.blockPool.Stop() s.blockPool.Stop()
if s.whisper != nil {
s.whisper.Stop() s.whisper.Stop()
}
logger.Infoln("Server stopped") logger.Infoln("Server stopped")
close(s.shutdownChan) close(s.shutdownChan)
@ -285,16 +302,16 @@ func (self *Ethereum) txBroadcastLoop() {
// automatically stops if unsubscribe // automatically stops if unsubscribe
for obj := range self.txSub.Chan() { for obj := range self.txSub.Chan() {
event := obj.(core.TxPreEvent) event := obj.(core.TxPreEvent)
self.net.Broadcast("eth", TxMsg, []interface{}{event.Tx.RlpData()}) self.net.Broadcast("eth", TxMsg, event.Tx.RlpData())
} }
} }
func (self *Ethereum) blockBroadcastLoop() { func (self *Ethereum) blockBroadcastLoop() {
// automatically stops if unsubscribe // automatically stops if unsubscribe
for obj := range self.txSub.Chan() { for obj := range self.blockSub.Chan() {
switch ev := obj.(type) { switch ev := obj.(type) {
case core.NewMinedBlockEvent: case core.NewMinedBlockEvent:
self.net.Broadcast("eth", NewBlockMsg, ev.Block.RlpData()) self.net.Broadcast("eth", NewBlockMsg, ev.Block.RlpData(), ev.Block.Td)
} }
} }
} }

File diff suppressed because it is too large Load Diff

View File

@ -1,115 +1,65 @@
package eth package eth
import ( import (
"bytes"
"fmt" "fmt"
"log" "log"
"math/big"
"os" "os"
"sync" "sync"
"testing" "testing"
"time"
"github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/ethutil" "github.com/ethereum/go-ethereum/ethutil"
ethlogger "github.com/ethereum/go-ethereum/logger" ethlogger "github.com/ethereum/go-ethereum/logger"
"github.com/ethereum/go-ethereum/pow"
) )
var sys = ethlogger.NewStdLogSystem(os.Stdout, log.LstdFlags, ethlogger.LogLevel(ethlogger.DebugDetailLevel)) const waitTimeout = 60 // seconds
type testChainManager struct { var logsys = ethlogger.NewStdLogSystem(os.Stdout, log.LstdFlags, ethlogger.LogLevel(ethlogger.DebugLevel))
knownBlock func(hash []byte) bool
addBlock func(*types.Block) error var ini = false
checkPoW func(*types.Block) bool
func logInit() {
if !ini {
ethlogger.AddLogSystem(logsys)
ini = true
}
} }
func (self *testChainManager) KnownBlock(hash []byte) bool { // test helpers
if self.knownBlock != nil { func arrayEq(a, b []int) bool {
return self.knownBlock(hash) if len(a) != len(b) {
}
return false
}
func (self *testChainManager) AddBlock(block *types.Block) error {
if self.addBlock != nil {
return self.addBlock(block)
}
return nil
}
func (self *testChainManager) CheckPoW(block *types.Block) bool {
if self.checkPoW != nil {
return self.checkPoW(block)
}
return false
}
func knownBlock(hashes ...[]byte) (f func([]byte) bool) {
f = func(block []byte) bool {
for _, hash := range hashes {
if bytes.Compare(block, hash) == 0 {
return true
}
}
return false return false
} }
return for i := range a {
} if a[i] != b[i] {
func addBlock(hashes ...[]byte) (f func(*types.Block) error) {
f = func(block *types.Block) error {
for _, hash := range hashes {
if bytes.Compare(block.Hash(), hash) == 0 {
return fmt.Errorf("invalid by test")
}
}
return nil
}
return
}
func checkPoW(hashes ...[]byte) (f func(*types.Block) bool) {
f = func(block *types.Block) bool {
for _, hash := range hashes {
if bytes.Compare(block.Hash(), hash) == 0 {
return false return false
} }
} }
return true return true
}
return
}
func newTestChainManager(knownBlocks [][]byte, invalidBlocks [][]byte, invalidPoW [][]byte) *testChainManager {
return &testChainManager{
knownBlock: knownBlock(knownBlocks...),
addBlock: addBlock(invalidBlocks...),
checkPoW: checkPoW(invalidPoW...),
}
} }
type intToHash map[int][]byte type intToHash map[int][]byte
type hashToInt map[string]int type hashToInt map[string]int
// hashPool is a test helper, that allows random hashes to be referred to by integers
type testHashPool struct { type testHashPool struct {
intToHash intToHash
hashToInt hashToInt
lock sync.Mutex
} }
func newHash(i int) []byte { func newHash(i int) []byte {
return crypto.Sha3([]byte(string(i))) return crypto.Sha3([]byte(string(i)))
} }
func newTestBlockPool(knownBlockIndexes []int, invalidBlockIndexes []int, invalidPoWIndexes []int) (hashPool *testHashPool, blockPool *BlockPool) {
hashPool = &testHashPool{make(intToHash), make(hashToInt)}
knownBlocks := hashPool.indexesToHashes(knownBlockIndexes)
invalidBlocks := hashPool.indexesToHashes(invalidBlockIndexes)
invalidPoW := hashPool.indexesToHashes(invalidPoWIndexes)
blockPool = NewBlockPool(newTestChainManager(knownBlocks, invalidBlocks, invalidPoW))
return
}
func (self *testHashPool) indexesToHashes(indexes []int) (hashes [][]byte) { func (self *testHashPool) indexesToHashes(indexes []int) (hashes [][]byte) {
self.lock.Lock()
defer self.lock.Unlock()
for _, i := range indexes { for _, i := range indexes {
hash, found := self.intToHash[i] hash, found := self.intToHash[i]
if !found { if !found {
@ -123,6 +73,8 @@ func (self *testHashPool) indexesToHashes(indexes []int) (hashes [][]byte) {
} }
func (self *testHashPool) hashesToIndexes(hashes [][]byte) (indexes []int) { func (self *testHashPool) hashesToIndexes(hashes [][]byte) (indexes []int) {
self.lock.Lock()
defer self.lock.Unlock()
for _, hash := range hashes { for _, hash := range hashes {
i, found := self.hashToInt[string(hash)] i, found := self.hashToInt[string(hash)]
if !found { if !found {
@ -133,66 +85,812 @@ func (self *testHashPool) hashesToIndexes(hashes [][]byte) (indexes []int) {
return return
} }
type protocolChecker struct { // test blockChain is an integer trie
blockHashesRequests []int type blockChain map[int][]int
blocksRequests [][]int
invalidBlocks []error // blockPoolTester provides the interface between tests and a blockPool
//
// refBlockChain is used to guide which blocks will be accepted as valid
// blockChain gives the current state of the blockchain and
// accumulates inserts so that we can check the resulting chain
type blockPoolTester struct {
hashPool *testHashPool hashPool *testHashPool
lock sync.Mutex lock sync.RWMutex
refBlockChain blockChain
blockChain blockChain
blockPool *BlockPool
t *testing.T
} }
func newTestBlockPool(t *testing.T) (hashPool *testHashPool, blockPool *BlockPool, b *blockPoolTester) {
hashPool = &testHashPool{intToHash: make(intToHash), hashToInt: make(hashToInt)}
b = &blockPoolTester{
t: t,
hashPool: hashPool,
blockChain: make(blockChain),
refBlockChain: make(blockChain),
}
b.blockPool = NewBlockPool(b.hasBlock, b.insertChain, b.verifyPoW)
blockPool = b.blockPool
return
}
func (self *blockPoolTester) Errorf(format string, params ...interface{}) {
fmt.Printf(format+"\n", params...)
self.t.Errorf(format, params...)
}
// blockPoolTester implements the 3 callbacks needed by the blockPool:
// hasBlock, insetChain, verifyPoW
func (self *blockPoolTester) hasBlock(block []byte) (ok bool) {
self.lock.RLock()
defer self.lock.RUnlock()
indexes := self.hashPool.hashesToIndexes([][]byte{block})
i := indexes[0]
_, ok = self.blockChain[i]
fmt.Printf("has block %v (%x...): %v\n", i, block[0:4], ok)
return
}
func (self *blockPoolTester) insertChain(blocks types.Blocks) error {
self.lock.RLock()
defer self.lock.RUnlock()
var parent, child int
var children, refChildren []int
var ok bool
for _, block := range blocks {
child = self.hashPool.hashesToIndexes([][]byte{block.Hash()})[0]
_, ok = self.blockChain[child]
if ok {
fmt.Printf("block %v already in blockchain\n", child)
continue // already in chain
}
parent = self.hashPool.hashesToIndexes([][]byte{block.ParentHeaderHash})[0]
children, ok = self.blockChain[parent]
if !ok {
return fmt.Errorf("parent %v not in blockchain ", parent)
}
ok = false
var found bool
refChildren, found = self.refBlockChain[parent]
if found {
for _, c := range refChildren {
if c == child {
ok = true
}
}
if !ok {
return fmt.Errorf("invalid block %v", child)
}
} else {
ok = true
}
if ok {
// accept any blocks if parent not in refBlockChain
fmt.Errorf("blockchain insert %v -> %v\n", parent, child)
self.blockChain[parent] = append(children, child)
self.blockChain[child] = nil
}
}
return nil
}
func (self *blockPoolTester) verifyPoW(pblock pow.Block) bool {
return true
}
// test helper that compares the resulting blockChain to the desired blockChain
func (self *blockPoolTester) checkBlockChain(blockChain map[int][]int) {
for k, v := range self.blockChain {
fmt.Printf("got: %v -> %v\n", k, v)
}
for k, v := range blockChain {
fmt.Printf("expected: %v -> %v\n", k, v)
}
if len(blockChain) != len(self.blockChain) {
self.Errorf("blockchain incorrect (zlength differ)")
}
for k, v := range blockChain {
vv, ok := self.blockChain[k]
if !ok || !arrayEq(v, vv) {
self.Errorf("blockchain incorrect on %v -> %v (!= %v)", k, vv, v)
}
}
}
//
// peerTester provides the peer callbacks for the blockPool
// it registers actual callbacks so that result can be compared to desired behaviour
// provides helper functions to mock the protocol calls to the blockPool
type peerTester struct {
blockHashesRequests []int
blocksRequests [][]int
blocksRequestsMap map[int]bool
peerErrors []int
blockPool *BlockPool
hashPool *testHashPool
lock sync.RWMutex
id string
td int
currentBlock int
t *testing.T
}
// peerTester constructor takes hashPool and blockPool from the blockPoolTester
func (self *blockPoolTester) newPeer(id string, td int, cb int) *peerTester {
return &peerTester{
id: id,
td: td,
currentBlock: cb,
hashPool: self.hashPool,
blockPool: self.blockPool,
t: self.t,
blocksRequestsMap: make(map[int]bool),
}
}
func (self *peerTester) Errorf(format string, params ...interface{}) {
fmt.Printf(format+"\n", params...)
self.t.Errorf(format, params...)
}
// helper to compare actual and expected block requests
func (self *peerTester) checkBlocksRequests(blocksRequests ...[]int) {
if len(blocksRequests) > len(self.blocksRequests) {
self.Errorf("blocks requests incorrect (length differ)\ngot %v\nexpected %v", self.blocksRequests, blocksRequests)
} else {
for i, rr := range blocksRequests {
r := self.blocksRequests[i]
if !arrayEq(r, rr) {
self.Errorf("blocks requests incorrect\ngot %v\nexpected %v", self.blocksRequests, blocksRequests)
}
}
}
}
// helper to compare actual and expected block hash requests
func (self *peerTester) checkBlockHashesRequests(blocksHashesRequests ...int) {
rr := blocksHashesRequests
self.lock.RLock()
r := self.blockHashesRequests
self.lock.RUnlock()
if len(r) != len(rr) {
self.Errorf("block hashes requests incorrect (length differ)\ngot %v\nexpected %v", r, rr)
} else {
if !arrayEq(r, rr) {
self.Errorf("block hashes requests incorrect\ngot %v\nexpected %v", r, rr)
}
}
}
// waiter function used by peer.AddBlocks
// blocking until requests appear
// since block requests are sent to any random peers
// block request map is shared between peers
// times out after a period
func (self *peerTester) waitBlocksRequests(blocksRequest ...int) {
timeout := time.After(waitTimeout * time.Second)
rr := blocksRequest
for {
self.lock.RLock()
r := self.blocksRequestsMap
fmt.Printf("[%s] blocks request check %v (%v)\n", self.id, rr, r)
i := 0
for i = 0; i < len(rr); i++ {
_, ok := r[rr[i]]
if !ok {
break
}
}
self.lock.RUnlock()
if i == len(rr) {
return
}
time.Sleep(100 * time.Millisecond)
select {
case <-timeout:
default:
}
}
}
// waiter function used by peer.AddBlockHashes
// blocking until requests appear
// times out after a period
func (self *peerTester) waitBlockHashesRequests(blocksHashesRequest int) {
timeout := time.After(waitTimeout * time.Second)
rr := blocksHashesRequest
for i := 0; ; {
self.lock.RLock()
r := self.blockHashesRequests
self.lock.RUnlock()
fmt.Printf("[%s] block hash request check %v (%v)\n", self.id, rr, r)
for ; i < len(r); i++ {
if rr == r[i] {
return
}
}
time.Sleep(100 * time.Millisecond)
select {
case <-timeout:
default:
}
}
}
// mocks a simple blockchain 0 (genesis) ... n (head)
func (self *blockPoolTester) initRefBlockChain(n int) {
for i := 0; i < n; i++ {
self.refBlockChain[i] = []int{i + 1}
}
}
// peerTester functions that mimic protocol calls to the blockpool
// registers the peer with the blockPool
func (self *peerTester) AddPeer() bool {
hash := self.hashPool.indexesToHashes([]int{self.currentBlock})[0]
return self.blockPool.AddPeer(big.NewInt(int64(self.td)), hash, self.id, self.requestBlockHashes, self.requestBlocks, self.peerError)
}
// peer sends blockhashes if and when gets a request
func (self *peerTester) AddBlockHashes(indexes ...int) {
i := 0
fmt.Printf("ready to add block hashes %v\n", indexes)
self.waitBlockHashesRequests(indexes[0])
fmt.Printf("adding block hashes %v\n", indexes)
hashes := self.hashPool.indexesToHashes(indexes)
next := func() (hash []byte, ok bool) {
if i < len(hashes) {
hash = hashes[i]
ok = true
i++
}
return
}
self.blockPool.AddBlockHashes(next, self.id)
}
// peer sends blocks if and when there is a request
// (in the shared request store, not necessarily to a person)
func (self *peerTester) AddBlocks(indexes ...int) {
hashes := self.hashPool.indexesToHashes(indexes)
fmt.Printf("ready to add blocks %v\n", indexes[1:])
self.waitBlocksRequests(indexes[1:]...)
fmt.Printf("adding blocks %v \n", indexes[1:])
for i := 1; i < len(hashes); i++ {
fmt.Printf("adding block %v %x\n", indexes[i], hashes[i][:4])
self.blockPool.AddBlock(&types.Block{HeaderHash: ethutil.Bytes(hashes[i]), ParentHeaderHash: ethutil.Bytes(hashes[i-1])}, self.id)
}
}
// peer callbacks
// -1 is special: not found (a hash never seen) // -1 is special: not found (a hash never seen)
func (self *protocolChecker) requestBlockHashesCallBack() (requestBlockHashesCallBack func([]byte) error) { // records block hashes requests by the blockPool
requestBlockHashesCallBack = func(hash []byte) error { func (self *peerTester) requestBlockHashes(hash []byte) error {
indexes := self.hashPool.hashesToIndexes([][]byte{hash}) indexes := self.hashPool.hashesToIndexes([][]byte{hash})
fmt.Printf("[%s] blocks hash request %v %x\n", self.id, indexes[0], hash[:4])
self.lock.Lock() self.lock.Lock()
defer self.lock.Unlock() defer self.lock.Unlock()
self.blockHashesRequests = append(self.blockHashesRequests, indexes[0]) self.blockHashesRequests = append(self.blockHashesRequests, indexes[0])
return nil return nil
}
return
} }
func (self *protocolChecker) requestBlocksCallBack() (requestBlocksCallBack func([][]byte) error) { // records block requests by the blockPool
requestBlocksCallBack = func(hashes [][]byte) error { func (self *peerTester) requestBlocks(hashes [][]byte) error {
indexes := self.hashPool.hashesToIndexes(hashes) indexes := self.hashPool.hashesToIndexes(hashes)
fmt.Printf("blocks request %v %x...\n", indexes, hashes[0][:4])
self.lock.Lock() self.lock.Lock()
defer self.lock.Unlock() defer self.lock.Unlock()
self.blocksRequests = append(self.blocksRequests, indexes) self.blocksRequests = append(self.blocksRequests, indexes)
for _, i := range indexes {
self.blocksRequestsMap[i] = true
}
return nil return nil
}
return
} }
func (self *protocolChecker) invalidBlockCallBack() (invalidBlockCallBack func(error)) { // records the error codes of all the peerErrors found the blockPool
invalidBlockCallBack = func(err error) { func (self *peerTester) peerError(code int, format string, params ...interface{}) {
self.invalidBlocks = append(self.invalidBlocks, err) self.peerErrors = append(self.peerErrors, code)
}
return
} }
// the actual tests
func TestAddPeer(t *testing.T) { func TestAddPeer(t *testing.T) {
ethlogger.AddLogSystem(sys) logInit()
knownBlockIndexes := []int{0, 1} _, blockPool, blockPoolTester := newTestBlockPool(t)
invalidBlockIndexes := []int{2, 3} peer0 := blockPoolTester.newPeer("peer0", 1, 0)
invalidPoWIndexes := []int{4, 5} peer1 := blockPoolTester.newPeer("peer1", 2, 1)
hashPool, blockPool := newTestBlockPool(knownBlockIndexes, invalidBlockIndexes, invalidPoWIndexes) peer2 := blockPoolTester.newPeer("peer2", 3, 2)
// TODO: var peer *peerInfo
// hashPool, blockPool, blockChainChecker = newTestBlockPool(knownBlockIndexes, invalidBlockIndexes, invalidPoWIndexes)
peer0 := &protocolChecker{ blockPool.Start()
// blockHashesRequests: make([]int),
// blocksRequests: make([][]int), // pool
// invalidBlocks: make([]error), best := peer0.AddPeer()
hashPool: hashPool,
}
best := blockPool.AddPeer(ethutil.Big1, newHash(100), "0",
peer0.requestBlockHashesCallBack(),
peer0.requestBlocksCallBack(),
peer0.invalidBlockCallBack(),
)
if !best { if !best {
t.Errorf("peer not accepted as best") t.Errorf("peer0 (TD=1) not accepted as best")
} }
if blockPool.peer.id != "peer0" {
t.Errorf("peer0 (TD=1) not set as best")
}
peer0.checkBlockHashesRequests(0)
best = peer2.AddPeer()
if !best {
t.Errorf("peer2 (TD=3) not accepted as best")
}
if blockPool.peer.id != "peer2" {
t.Errorf("peer2 (TD=3) not set as best")
}
peer2.checkBlockHashesRequests(2)
best = peer1.AddPeer()
if best {
t.Errorf("peer1 (TD=2) accepted as best")
}
if blockPool.peer.id != "peer2" {
t.Errorf("peer2 (TD=3) not set any more as best")
}
if blockPool.peer.td.Cmp(big.NewInt(int64(3))) != 0 {
t.Errorf("peer1 TD not set")
}
peer2.td = 4
peer2.currentBlock = 3
best = peer2.AddPeer()
if !best {
t.Errorf("peer2 (TD=4) not accepted as best")
}
if blockPool.peer.id != "peer2" {
t.Errorf("peer2 (TD=4) not set as best")
}
if blockPool.peer.td.Cmp(big.NewInt(int64(4))) != 0 {
t.Errorf("peer2 TD not updated")
}
peer2.checkBlockHashesRequests(2, 3)
peer1.td = 3
peer1.currentBlock = 2
best = peer1.AddPeer()
if best {
t.Errorf("peer1 (TD=3) should not be set as best")
}
if blockPool.peer.id == "peer1" {
t.Errorf("peer1 (TD=3) should not be set as best")
}
peer, best = blockPool.getPeer("peer1")
if peer.td.Cmp(big.NewInt(int64(3))) != 0 {
t.Errorf("peer1 TD should be updated")
}
blockPool.RemovePeer("peer2")
peer, best = blockPool.getPeer("peer2")
if peer != nil {
t.Errorf("peer2 not removed")
}
if blockPool.peer.id != "peer1" {
t.Errorf("existing peer1 (TD=3) should be set as best peer")
}
peer1.checkBlockHashesRequests(2)
blockPool.RemovePeer("peer1")
peer, best = blockPool.getPeer("peer1")
if peer != nil {
t.Errorf("peer1 not removed")
}
if blockPool.peer.id != "peer0" {
t.Errorf("existing peer0 (TD=1) should be set as best peer")
}
blockPool.RemovePeer("peer0")
peer, best = blockPool.getPeer("peer0")
if peer != nil {
t.Errorf("peer1 not removed")
}
// adding back earlier peer ok
peer0.currentBlock = 3
best = peer0.AddPeer()
if !best {
t.Errorf("peer0 (TD=1) should be set as best")
}
if blockPool.peer.id != "peer0" {
t.Errorf("peer0 (TD=1) should be set as best")
}
peer0.checkBlockHashesRequests(0, 0, 3)
blockPool.Stop() blockPool.Stop()
} }
func TestPeerWithKnownBlock(t *testing.T) {
logInit()
_, blockPool, blockPoolTester := newTestBlockPool(t)
blockPoolTester.refBlockChain[0] = nil
blockPoolTester.blockChain[0] = nil
// hashPool, blockPool, blockPoolTester := newTestBlockPool()
blockPool.Start()
peer0 := blockPoolTester.newPeer("0", 1, 0)
peer0.AddPeer()
blockPool.Stop()
// no request on known block
peer0.checkBlockHashesRequests()
}
func TestSimpleChain(t *testing.T) {
logInit()
_, blockPool, blockPoolTester := newTestBlockPool(t)
blockPoolTester.blockChain[0] = nil
blockPoolTester.initRefBlockChain(2)
blockPool.Start()
peer1 := blockPoolTester.newPeer("peer1", 1, 2)
peer1.AddPeer()
go peer1.AddBlockHashes(2, 1, 0)
peer1.AddBlocks(0, 1, 2)
blockPool.Wait(waitTimeout * time.Second)
blockPool.Stop()
blockPoolTester.refBlockChain[2] = []int{}
blockPoolTester.checkBlockChain(blockPoolTester.refBlockChain)
}
func TestInvalidBlock(t *testing.T) {
logInit()
_, blockPool, blockPoolTester := newTestBlockPool(t)
blockPoolTester.blockChain[0] = nil
blockPoolTester.initRefBlockChain(2)
blockPoolTester.refBlockChain[2] = []int{}
blockPool.Start()
peer1 := blockPoolTester.newPeer("peer1", 1, 3)
peer1.AddPeer()
go peer1.AddBlockHashes(3, 2, 1, 0)
peer1.AddBlocks(0, 1, 2, 3)
blockPool.Wait(waitTimeout * time.Second)
blockPool.Stop()
blockPoolTester.refBlockChain[2] = []int{}
blockPoolTester.checkBlockChain(blockPoolTester.refBlockChain)
if len(peer1.peerErrors) == 1 {
if peer1.peerErrors[0] != ErrInvalidBlock {
t.Errorf("wrong error, got %v, expected %v", peer1.peerErrors[0], ErrInvalidBlock)
}
} else {
t.Errorf("expected invalid block error, got nothing")
}
}
func TestVerifyPoW(t *testing.T) {
logInit()
_, blockPool, blockPoolTester := newTestBlockPool(t)
blockPoolTester.blockChain[0] = nil
blockPoolTester.initRefBlockChain(3)
first := false
blockPoolTester.blockPool.verifyPoW = func(b pow.Block) bool {
bb, _ := b.(*types.Block)
indexes := blockPoolTester.hashPool.hashesToIndexes([][]byte{bb.Hash()})
if indexes[0] == 1 && !first {
first = true
return false
} else {
return true
}
}
blockPool.Start()
peer1 := blockPoolTester.newPeer("peer1", 1, 2)
peer1.AddPeer()
go peer1.AddBlockHashes(2, 1, 0)
peer1.AddBlocks(0, 1, 2)
peer1.AddBlocks(0, 1)
blockPool.Wait(waitTimeout * time.Second)
blockPool.Stop()
blockPoolTester.refBlockChain[2] = []int{}
blockPoolTester.checkBlockChain(blockPoolTester.refBlockChain)
if len(peer1.peerErrors) == 1 {
if peer1.peerErrors[0] != ErrInvalidPoW {
t.Errorf("wrong error, got %v, expected %v", peer1.peerErrors[0], ErrInvalidPoW)
}
} else {
t.Errorf("expected invalid pow error, got nothing")
}
}
func TestMultiSectionChain(t *testing.T) {
logInit()
_, blockPool, blockPoolTester := newTestBlockPool(t)
blockPoolTester.blockChain[0] = nil
blockPoolTester.initRefBlockChain(5)
blockPool.Start()
peer1 := blockPoolTester.newPeer("peer1", 1, 5)
peer1.AddPeer()
go peer1.AddBlockHashes(5, 4, 3)
go peer1.AddBlocks(2, 3, 4, 5)
go peer1.AddBlockHashes(3, 2, 1, 0)
peer1.AddBlocks(0, 1, 2)
blockPool.Wait(waitTimeout * time.Second)
blockPool.Stop()
blockPoolTester.refBlockChain[5] = []int{}
blockPoolTester.checkBlockChain(blockPoolTester.refBlockChain)
}
func TestNewBlocksOnPartialChain(t *testing.T) {
logInit()
_, blockPool, blockPoolTester := newTestBlockPool(t)
blockPoolTester.blockChain[0] = nil
blockPoolTester.initRefBlockChain(7)
blockPool.Start()
peer1 := blockPoolTester.newPeer("peer1", 1, 5)
peer1.AddPeer()
go peer1.AddBlockHashes(5, 4, 3)
peer1.AddBlocks(2, 3) // partially complete section
// peer1 found new blocks
peer1.td = 2
peer1.currentBlock = 7
peer1.AddPeer()
go peer1.AddBlockHashes(7, 6, 5)
go peer1.AddBlocks(3, 4, 5, 6, 7)
go peer1.AddBlockHashes(3, 2, 1, 0) // tests that hash request from known chain root is remembered
peer1.AddBlocks(0, 1, 2)
blockPool.Wait(waitTimeout * time.Second)
blockPool.Stop()
blockPoolTester.refBlockChain[7] = []int{}
blockPoolTester.checkBlockChain(blockPoolTester.refBlockChain)
}
func TestPeerSwitch(t *testing.T) {
logInit()
_, blockPool, blockPoolTester := newTestBlockPool(t)
blockPoolTester.blockChain[0] = nil
blockPoolTester.initRefBlockChain(6)
blockPool.Start()
peer1 := blockPoolTester.newPeer("peer1", 1, 5)
peer2 := blockPoolTester.newPeer("peer2", 2, 6)
peer2.blocksRequestsMap = peer1.blocksRequestsMap
peer1.AddPeer()
go peer1.AddBlockHashes(5, 4, 3)
peer1.AddBlocks(2, 3) // section partially complete, block 3 will be preserved after peer demoted
peer2.AddPeer() // peer2 is promoted as best peer, peer1 is demoted
go peer2.AddBlockHashes(6, 5) //
go peer2.AddBlocks(4, 5, 6) // tests that block request for earlier section is remembered
go peer1.AddBlocks(3, 4) // tests that connecting section by demoted peer is remembered and blocks are accepted from demoted peer
go peer2.AddBlockHashes(3, 2, 1, 0) // tests that known chain section is activated, hash requests from 3 is remembered
peer2.AddBlocks(0, 1, 2) // final blocks linking to blockchain sent
blockPool.Wait(waitTimeout * time.Second)
blockPool.Stop()
blockPoolTester.refBlockChain[6] = []int{}
blockPoolTester.checkBlockChain(blockPoolTester.refBlockChain)
}
func TestPeerDownSwitch(t *testing.T) {
logInit()
_, blockPool, blockPoolTester := newTestBlockPool(t)
blockPoolTester.blockChain[0] = nil
blockPoolTester.initRefBlockChain(6)
blockPool.Start()
peer1 := blockPoolTester.newPeer("peer1", 1, 4)
peer2 := blockPoolTester.newPeer("peer2", 2, 6)
peer2.blocksRequestsMap = peer1.blocksRequestsMap
peer2.AddPeer()
go peer2.AddBlockHashes(6, 5, 4)
peer2.AddBlocks(5, 6) // partially complete, section will be preserved
blockPool.RemovePeer("peer2") // peer2 disconnects
peer1.AddPeer() // inferior peer1 is promoted as best peer
go peer1.AddBlockHashes(4, 3, 2, 1, 0) //
go peer1.AddBlocks(3, 4, 5) // tests that section set by demoted peer is remembered and blocks are accepted
peer1.AddBlocks(0, 1, 2, 3)
blockPool.Wait(waitTimeout * time.Second)
blockPool.Stop()
blockPoolTester.refBlockChain[6] = []int{}
blockPoolTester.checkBlockChain(blockPoolTester.refBlockChain)
}
func TestPeerSwitchBack(t *testing.T) {
logInit()
_, blockPool, blockPoolTester := newTestBlockPool(t)
blockPoolTester.blockChain[0] = nil
blockPoolTester.initRefBlockChain(8)
blockPool.Start()
peer1 := blockPoolTester.newPeer("peer1", 2, 11)
peer2 := blockPoolTester.newPeer("peer2", 1, 8)
peer2.blocksRequestsMap = peer1.blocksRequestsMap
peer2.AddPeer()
go peer2.AddBlockHashes(8, 7, 6)
go peer2.AddBlockHashes(6, 5, 4)
peer2.AddBlocks(4, 5) // section partially complete
peer1.AddPeer() // peer1 is promoted as best peer
go peer1.AddBlockHashes(11, 10) // only gives useless results
blockPool.RemovePeer("peer1") // peer1 disconnects
go peer2.AddBlockHashes(4, 3, 2, 1, 0) // tests that asking for hashes from 4 is remembered
go peer2.AddBlocks(3, 4, 5, 6, 7, 8) // tests that section 4, 5, 6 and 7, 8 are remembered for missing blocks
peer2.AddBlocks(0, 1, 2, 3)
blockPool.Wait(waitTimeout * time.Second)
blockPool.Stop()
blockPoolTester.refBlockChain[8] = []int{}
blockPoolTester.checkBlockChain(blockPoolTester.refBlockChain)
}
func TestForkSimple(t *testing.T) {
logInit()
_, blockPool, blockPoolTester := newTestBlockPool(t)
blockPoolTester.blockChain[0] = nil
blockPoolTester.initRefBlockChain(9)
blockPoolTester.refBlockChain[3] = []int{4, 7}
delete(blockPoolTester.refBlockChain, 6)
blockPool.Start()
peer1 := blockPoolTester.newPeer("peer1", 1, 9)
peer2 := blockPoolTester.newPeer("peer2", 2, 6)
peer2.blocksRequestsMap = peer1.blocksRequestsMap
peer1.AddPeer()
go peer1.AddBlockHashes(9, 8, 7, 3, 2)
peer1.AddBlocks(1, 2, 3, 7, 8, 9)
peer2.AddPeer() // peer2 is promoted as best peer
go peer2.AddBlockHashes(6, 5, 4, 3, 2) // fork on 3 -> 4 (earlier child: 7)
go peer2.AddBlocks(1, 2, 3, 4, 5, 6)
go peer2.AddBlockHashes(2, 1, 0)
peer2.AddBlocks(0, 1, 2)
blockPool.Wait(waitTimeout * time.Second)
blockPool.Stop()
blockPoolTester.refBlockChain[6] = []int{}
blockPoolTester.refBlockChain[3] = []int{4}
delete(blockPoolTester.refBlockChain, 7)
delete(blockPoolTester.refBlockChain, 8)
delete(blockPoolTester.refBlockChain, 9)
blockPoolTester.checkBlockChain(blockPoolTester.refBlockChain)
}
func TestForkSwitchBackByNewBlocks(t *testing.T) {
logInit()
_, blockPool, blockPoolTester := newTestBlockPool(t)
blockPoolTester.blockChain[0] = nil
blockPoolTester.initRefBlockChain(11)
blockPoolTester.refBlockChain[3] = []int{4, 7}
delete(blockPoolTester.refBlockChain, 6)
blockPool.Start()
peer1 := blockPoolTester.newPeer("peer1", 1, 9)
peer2 := blockPoolTester.newPeer("peer2", 2, 6)
peer2.blocksRequestsMap = peer1.blocksRequestsMap
peer1.AddPeer()
go peer1.AddBlockHashes(9, 8, 7, 3, 2)
peer1.AddBlocks(8, 9) // partial section
peer2.AddPeer() //
go peer2.AddBlockHashes(6, 5, 4, 3, 2) // peer2 forks on block 3
peer2.AddBlocks(1, 2, 3, 4, 5, 6) //
// peer1 finds new blocks
peer1.td = 3
peer1.currentBlock = 11
peer1.AddPeer()
go peer1.AddBlockHashes(11, 10, 9)
peer1.AddBlocks(7, 8, 9, 10, 11)
go peer1.AddBlockHashes(7, 3) // tests that hash request from fork root is remembered
go peer1.AddBlocks(3, 7) // tests that block requests on earlier fork are remembered
// go peer1.AddBlockHashes(1, 0) // tests that hash request from root of connecting chain section (added by demoted peer) is remembered
go peer1.AddBlockHashes(2, 1, 0) // tests that hash request from root of connecting chain section (added by demoted peer) is remembered
peer1.AddBlocks(0, 1, 2, 3)
blockPool.Wait(waitTimeout * time.Second)
blockPool.Stop()
blockPoolTester.refBlockChain[11] = []int{}
blockPoolTester.refBlockChain[3] = []int{7}
delete(blockPoolTester.refBlockChain, 6)
delete(blockPoolTester.refBlockChain, 5)
delete(blockPoolTester.refBlockChain, 4)
blockPoolTester.checkBlockChain(blockPoolTester.refBlockChain)
}
func TestForkSwitchBackByPeerSwitchBack(t *testing.T) {
logInit()
_, blockPool, blockPoolTester := newTestBlockPool(t)
blockPoolTester.blockChain[0] = nil
blockPoolTester.initRefBlockChain(9)
blockPoolTester.refBlockChain[3] = []int{4, 7}
delete(blockPoolTester.refBlockChain, 6)
blockPool.Start()
peer1 := blockPoolTester.newPeer("peer1", 1, 9)
peer2 := blockPoolTester.newPeer("peer2", 2, 6)
peer2.blocksRequestsMap = peer1.blocksRequestsMap
peer1.AddPeer()
go peer1.AddBlockHashes(9, 8, 7, 3, 2)
peer1.AddBlocks(8, 9)
peer2.AddPeer() //
go peer2.AddBlockHashes(6, 5, 4, 3, 2) // peer2 forks on block 3
peer2.AddBlocks(2, 3, 4, 5, 6) //
blockPool.RemovePeer("peer2") // peer2 disconnects, peer1 is promoted again as best peer
peer1.AddBlockHashes(7, 3) // tests that hash request from fork root is remembered
go peer1.AddBlocks(3, 7, 8) // tests that block requests on earlier fork are remembered
go peer1.AddBlockHashes(2, 1, 0) //
peer1.AddBlocks(0, 1, 2, 3)
blockPool.Wait(waitTimeout * time.Second)
blockPool.Stop()
blockPoolTester.refBlockChain[9] = []int{}
blockPoolTester.refBlockChain[3] = []int{7}
delete(blockPoolTester.refBlockChain, 6)
delete(blockPoolTester.refBlockChain, 5)
delete(blockPoolTester.refBlockChain, 4)
blockPoolTester.checkBlockChain(blockPoolTester.refBlockChain)
}
func TestForkCompleteSectionSwitchBackByPeerSwitchBack(t *testing.T) {
logInit()
_, blockPool, blockPoolTester := newTestBlockPool(t)
blockPoolTester.blockChain[0] = nil
blockPoolTester.initRefBlockChain(9)
blockPoolTester.refBlockChain[3] = []int{4, 7}
delete(blockPoolTester.refBlockChain, 6)
blockPool.Start()
peer1 := blockPoolTester.newPeer("peer1", 1, 9)
peer2 := blockPoolTester.newPeer("peer2", 2, 6)
peer2.blocksRequestsMap = peer1.blocksRequestsMap
peer1.AddPeer()
go peer1.AddBlockHashes(9, 8, 7)
peer1.AddBlocks(3, 7, 8, 9) // make sure this section is complete
time.Sleep(1 * time.Second)
go peer1.AddBlockHashes(7, 3, 2) // block 3/7 is section boundary
peer1.AddBlocks(2, 3) // partially complete sections
peer2.AddPeer() //
go peer2.AddBlockHashes(6, 5, 4, 3, 2) // peer2 forks on block 3
peer2.AddBlocks(2, 3, 4, 5, 6) // block 2 still missing.
blockPool.RemovePeer("peer2") // peer2 disconnects, peer1 is promoted again as best peer
peer1.AddBlockHashes(7, 3) // tests that hash request from fork root is remembered even though section process completed
go peer1.AddBlockHashes(2, 1, 0) //
peer1.AddBlocks(0, 1, 2)
blockPool.Wait(waitTimeout * time.Second)
blockPool.Stop()
blockPoolTester.refBlockChain[9] = []int{}
blockPoolTester.refBlockChain[3] = []int{7}
delete(blockPoolTester.refBlockChain, 6)
delete(blockPoolTester.refBlockChain, 5)
delete(blockPoolTester.refBlockChain, 4)
blockPoolTester.checkBlockChain(blockPoolTester.refBlockChain)
}

View File

@ -52,18 +52,17 @@ func ProtocolError(code int, format string, params ...interface{}) (err *protoco
} }
func (self protocolError) Error() (message string) { func (self protocolError) Error() (message string) {
message = self.message if len(message) == 0 {
if message == "" { var ok bool
message, ok := errorToString[self.Code] self.message, ok = errorToString[self.Code]
if !ok { if !ok {
panic("invalid error code") panic("invalid error code")
} }
if self.format != "" { if self.format != "" {
message += ": " + fmt.Sprintf(self.format, self.params...) self.message += ": " + fmt.Sprintf(self.format, self.params...)
} }
self.message = message
} }
return return self.message
} }
func (self *protocolError) Fatal() bool { func (self *protocolError) Fatal() bool {

View File

@ -3,7 +3,7 @@ package eth
import ( import (
"bytes" "bytes"
"fmt" "fmt"
"math" "io"
"math/big" "math/big"
"github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/core/types"
@ -95,14 +95,13 @@ func runEthProtocol(txPool txPool, chainManager chainManager, blockPool blockPoo
blockPool: blockPool, blockPool: blockPool,
rw: rw, rw: rw,
peer: peer, peer: peer,
id: (string)(peer.Identity().Pubkey()), id: fmt.Sprintf("%x", peer.Identity().Pubkey()[:8]),
} }
err = self.handleStatus() err = self.handleStatus()
if err == nil { if err == nil {
for { for {
err = self.handle() err = self.handle()
if err != nil { if err != nil {
fmt.Println(err)
self.blockPool.RemovePeer(self.id) self.blockPool.RemovePeer(self.id)
break break
} }
@ -117,7 +116,7 @@ func (self *ethProtocol) handle() error {
return err return err
} }
if msg.Size > ProtocolMaxMsgSize { if msg.Size > ProtocolMaxMsgSize {
return ProtocolError(ErrMsgTooLarge, "%v > %v", msg.Size, ProtocolMaxMsgSize) return self.protoError(ErrMsgTooLarge, "%v > %v", msg.Size, ProtocolMaxMsgSize)
} }
// make sure that the payload has been fully consumed // make sure that the payload has been fully consumed
defer msg.Discard() defer msg.Discard()
@ -125,76 +124,87 @@ func (self *ethProtocol) handle() error {
switch msg.Code { switch msg.Code {
case StatusMsg: case StatusMsg:
return ProtocolError(ErrExtraStatusMsg, "") return self.protoError(ErrExtraStatusMsg, "")
case TxMsg: case TxMsg:
// TODO: rework using lazy RLP stream // TODO: rework using lazy RLP stream
var txs []*types.Transaction var txs []*types.Transaction
if err := msg.Decode(&txs); err != nil { if err := msg.Decode(&txs); err != nil {
return ProtocolError(ErrDecode, "%v", err) return self.protoError(ErrDecode, "msg %v: %v", msg, err)
} }
self.txPool.AddTransactions(txs) self.txPool.AddTransactions(txs)
case GetBlockHashesMsg: case GetBlockHashesMsg:
var request getBlockHashesMsgData var request getBlockHashesMsgData
if err := msg.Decode(&request); err != nil { if err := msg.Decode(&request); err != nil {
return ProtocolError(ErrDecode, "%v", err) return self.protoError(ErrDecode, "->msg %v: %v", msg, err)
} }
hashes := self.chainManager.GetBlockHashesFromHash(request.Hash, request.Amount) hashes := self.chainManager.GetBlockHashesFromHash(request.Hash, request.Amount)
return self.rw.EncodeMsg(BlockHashesMsg, ethutil.ByteSliceToInterface(hashes)...) return self.rw.EncodeMsg(BlockHashesMsg, ethutil.ByteSliceToInterface(hashes)...)
case BlockHashesMsg: case BlockHashesMsg:
// TODO: redo using lazy decode , this way very inefficient on known chains // TODO: redo using lazy decode , this way very inefficient on known chains
msgStream := rlp.NewListStream(msg.Payload, uint64(msg.Size)) msgStream := rlp.NewStream(msg.Payload)
var err error var err error
var i int
iter := func() (hash []byte, ok bool) { iter := func() (hash []byte, ok bool) {
hash, err = msgStream.Bytes() hash, err = msgStream.Bytes()
if err == nil { if err == nil {
i++
ok = true ok = true
} else {
if err != io.EOF {
self.protoError(ErrDecode, "msg %v: after %v hashes : %v", msg, i, err)
}
} }
return return
} }
self.blockPool.AddBlockHashes(iter, self.id) self.blockPool.AddBlockHashes(iter, self.id)
if err != nil && err != rlp.EOL {
return ProtocolError(ErrDecode, "%v", err)
}
case GetBlocksMsg: case GetBlocksMsg:
var blockHashes [][]byte msgStream := rlp.NewStream(msg.Payload)
if err := msg.Decode(&blockHashes); err != nil {
return ProtocolError(ErrDecode, "%v", err)
}
max := int(math.Min(float64(len(blockHashes)), blockHashesBatchSize))
var blocks []interface{} var blocks []interface{}
for i, hash := range blockHashes { var i int
if i >= max { for {
i++
var hash []byte
if err := msgStream.Decode(&hash); err != nil {
if err == io.EOF {
break break
} else {
return self.protoError(ErrDecode, "msg %v: %v", msg, err)
}
} }
block := self.chainManager.GetBlock(hash) block := self.chainManager.GetBlock(hash)
if block != nil { if block != nil {
blocks = append(blocks, block.RlpData()) blocks = append(blocks, block)
}
if i == blockHashesBatchSize {
break
} }
} }
return self.rw.EncodeMsg(BlocksMsg, blocks...) return self.rw.EncodeMsg(BlocksMsg, blocks...)
case BlocksMsg: case BlocksMsg:
msgStream := rlp.NewListStream(msg.Payload, uint64(msg.Size)) msgStream := rlp.NewStream(msg.Payload)
for { for {
var block *types.Block var block types.Block
if err := msgStream.Decode(&block); err != nil { if err := msgStream.Decode(&block); err != nil {
if err == rlp.EOL { if err == io.EOF {
break break
} else { } else {
return ProtocolError(ErrDecode, "%v", err) return self.protoError(ErrDecode, "msg %v: %v", msg, err)
} }
} }
self.blockPool.AddBlock(block, self.id) self.blockPool.AddBlock(&block, self.id)
} }
case NewBlockMsg: case NewBlockMsg:
var request newBlockMsgData var request newBlockMsgData
if err := msg.Decode(&request); err != nil { if err := msg.Decode(&request); err != nil {
return ProtocolError(ErrDecode, "%v", err) return self.protoError(ErrDecode, "msg %v: %v", msg, err)
} }
hash := request.Block.Hash() hash := request.Block.Hash()
// to simplify backend interface adding a new block // to simplify backend interface adding a new block
@ -202,12 +212,12 @@ func (self *ethProtocol) handle() error {
// (or selected as new best peer) // (or selected as new best peer)
if self.blockPool.AddPeer(request.TD, hash, self.id, self.requestBlockHashes, self.requestBlocks, self.protoErrorDisconnect) { if self.blockPool.AddPeer(request.TD, hash, self.id, self.requestBlockHashes, self.requestBlocks, self.protoErrorDisconnect) {
called := true called := true
iter := func() (hash []byte, ok bool) { iter := func() ([]byte, bool) {
if called { if called {
called = false called = false
return hash, true return hash, true
} else { } else {
return return nil, false
} }
} }
self.blockPool.AddBlockHashes(iter, self.id) self.blockPool.AddBlockHashes(iter, self.id)
@ -215,14 +225,14 @@ func (self *ethProtocol) handle() error {
} }
default: default:
return ProtocolError(ErrInvalidMsgCode, "%v", msg.Code) return self.protoError(ErrInvalidMsgCode, "%v", msg.Code)
} }
return nil return nil
} }
type statusMsgData struct { type statusMsgData struct {
ProtocolVersion uint ProtocolVersion uint32
NetworkId uint NetworkId uint32
TD *big.Int TD *big.Int
CurrentBlock []byte CurrentBlock []byte
GenesisBlock []byte GenesisBlock []byte
@ -253,56 +263,56 @@ func (self *ethProtocol) handleStatus() error {
} }
if msg.Code != StatusMsg { if msg.Code != StatusMsg {
return ProtocolError(ErrNoStatusMsg, "first msg has code %x (!= %x)", msg.Code, StatusMsg) return self.protoError(ErrNoStatusMsg, "first msg has code %x (!= %x)", msg.Code, StatusMsg)
} }
if msg.Size > ProtocolMaxMsgSize { if msg.Size > ProtocolMaxMsgSize {
return ProtocolError(ErrMsgTooLarge, "%v > %v", msg.Size, ProtocolMaxMsgSize) return self.protoError(ErrMsgTooLarge, "%v > %v", msg.Size, ProtocolMaxMsgSize)
} }
var status statusMsgData var status statusMsgData
if err := msg.Decode(&status); err != nil { if err := msg.Decode(&status); err != nil {
return ProtocolError(ErrDecode, "%v", err) return self.protoError(ErrDecode, "msg %v: %v", msg, err)
} }
_, _, genesisBlock := self.chainManager.Status() _, _, genesisBlock := self.chainManager.Status()
if bytes.Compare(status.GenesisBlock, genesisBlock) != 0 { if bytes.Compare(status.GenesisBlock, genesisBlock) != 0 {
return ProtocolError(ErrGenesisBlockMismatch, "%x (!= %x)", status.GenesisBlock, genesisBlock) return self.protoError(ErrGenesisBlockMismatch, "%x (!= %x)", status.GenesisBlock, genesisBlock)
} }
if status.NetworkId != NetworkId { if status.NetworkId != NetworkId {
return ProtocolError(ErrNetworkIdMismatch, "%d (!= %d)", status.NetworkId, NetworkId) return self.protoError(ErrNetworkIdMismatch, "%d (!= %d)", status.NetworkId, NetworkId)
} }
if ProtocolVersion != status.ProtocolVersion { if ProtocolVersion != status.ProtocolVersion {
return ProtocolError(ErrProtocolVersionMismatch, "%d (!= %d)", status.ProtocolVersion, ProtocolVersion) return self.protoError(ErrProtocolVersionMismatch, "%d (!= %d)", status.ProtocolVersion, ProtocolVersion)
} }
self.peer.Infof("Peer is [eth] capable (%d/%d). TD=%v H=%x\n", status.ProtocolVersion, status.NetworkId, status.TD, status.CurrentBlock[:4]) self.peer.Infof("Peer is [eth] capable (%d/%d). TD=%v H=%x\n", status.ProtocolVersion, status.NetworkId, status.TD, status.CurrentBlock[:4])
//self.blockPool.AddPeer(status.TD, status.CurrentBlock, self.id, self.requestBlockHashes, self.requestBlocks, self.protoErrorDisconnect) self.blockPool.AddPeer(status.TD, status.CurrentBlock, self.id, self.requestBlockHashes, self.requestBlocks, self.protoErrorDisconnect)
self.peer.Infoln("AddPeer(IGNORED)")
return nil return nil
} }
func (self *ethProtocol) requestBlockHashes(from []byte) error { func (self *ethProtocol) requestBlockHashes(from []byte) error {
self.peer.Debugf("fetching hashes (%d) %x...\n", blockHashesBatchSize, from[0:4]) self.peer.Debugf("fetching hashes (%d) %x...\n", blockHashesBatchSize, from[0:4])
return self.rw.EncodeMsg(GetBlockHashesMsg, from, blockHashesBatchSize) return self.rw.EncodeMsg(GetBlockHashesMsg, interface{}(from), uint64(blockHashesBatchSize))
} }
func (self *ethProtocol) requestBlocks(hashes [][]byte) error { func (self *ethProtocol) requestBlocks(hashes [][]byte) error {
self.peer.Debugf("fetching %v blocks", len(hashes)) self.peer.Debugf("fetching %v blocks", len(hashes))
return self.rw.EncodeMsg(GetBlocksMsg, ethutil.ByteSliceToInterface(hashes)) return self.rw.EncodeMsg(GetBlocksMsg, ethutil.ByteSliceToInterface(hashes)...)
} }
func (self *ethProtocol) protoError(code int, format string, params ...interface{}) (err *protocolError) { func (self *ethProtocol) protoError(code int, format string, params ...interface{}) (err *protocolError) {
err = ProtocolError(code, format, params...) err = ProtocolError(code, format, params...)
if err.Fatal() { if err.Fatal() {
self.peer.Errorln(err) self.peer.Errorln("err %v", err)
// disconnect
} else { } else {
self.peer.Debugln(err) self.peer.Debugf("fyi %v", err)
} }
return return
} }
@ -310,10 +320,10 @@ func (self *ethProtocol) protoError(code int, format string, params ...interface
func (self *ethProtocol) protoErrorDisconnect(code int, format string, params ...interface{}) { func (self *ethProtocol) protoErrorDisconnect(code int, format string, params ...interface{}) {
err := ProtocolError(code, format, params...) err := ProtocolError(code, format, params...)
if err.Fatal() { if err.Fatal() {
self.peer.Errorln(err) self.peer.Errorln("err %v", err)
// disconnect // disconnect
} else { } else {
self.peer.Debugln(err) self.peer.Debugf("fyi %v", err)
} }
} }

View File

@ -1,35 +1,48 @@
package eth package eth
import ( import (
"bytes"
"io" "io"
"log"
"math/big" "math/big"
"os"
"testing" "testing"
"time"
"github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/ethutil"
ethlogger "github.com/ethereum/go-ethereum/logger"
"github.com/ethereum/go-ethereum/p2p" "github.com/ethereum/go-ethereum/p2p"
) )
var sys = ethlogger.NewStdLogSystem(os.Stdout, log.LstdFlags, ethlogger.LogLevel(ethlogger.DebugDetailLevel))
type testMsgReadWriter struct { type testMsgReadWriter struct {
in chan p2p.Msg in chan p2p.Msg
out chan p2p.Msg out []p2p.Msg
} }
func (self *testMsgReadWriter) In(msg p2p.Msg) { func (self *testMsgReadWriter) In(msg p2p.Msg) {
self.in <- msg self.in <- msg
} }
func (self *testMsgReadWriter) Out(msg p2p.Msg) { func (self *testMsgReadWriter) Out() (msg p2p.Msg, ok bool) {
self.in <- msg if len(self.out) > 0 {
msg = self.out[0]
self.out = self.out[1:]
ok = true
}
return
} }
func (self *testMsgReadWriter) WriteMsg(msg p2p.Msg) error { func (self *testMsgReadWriter) WriteMsg(msg p2p.Msg) error {
self.out <- msg self.out = append(self.out, msg)
return nil return nil
} }
func (self *testMsgReadWriter) EncodeMsg(code uint64, data ...interface{}) error { func (self *testMsgReadWriter) EncodeMsg(code uint64, data ...interface{}) error {
return self.WriteMsg(p2p.NewMsg(code, data)) return self.WriteMsg(p2p.NewMsg(code, data...))
} }
func (self *testMsgReadWriter) ReadMsg() (p2p.Msg, error) { func (self *testMsgReadWriter) ReadMsg() (p2p.Msg, error) {
@ -40,145 +53,83 @@ func (self *testMsgReadWriter) ReadMsg() (p2p.Msg, error) {
return msg, nil return msg, nil
} }
func errorCheck(t *testing.T, expCode int, err error) { type testTxPool struct {
perr, ok := err.(*protocolError)
if ok && perr != nil {
if code := perr.Code; code != expCode {
ok = false
}
}
if !ok {
t.Errorf("expected error code %v, got %v", ErrNoStatusMsg, err)
}
}
type TestBackend struct {
getTransactions func() []*types.Transaction getTransactions func() []*types.Transaction
addTransactions func(txs []*types.Transaction) addTransactions func(txs []*types.Transaction)
getBlockHashes func(hash []byte, amount uint32) (hashes [][]byte) }
addBlockHashes func(next func() ([]byte, bool), peerId string)
type testChainManager struct {
getBlockHashes func(hash []byte, amount uint64) (hashes [][]byte)
getBlock func(hash []byte) *types.Block getBlock func(hash []byte) *types.Block
addBlock func(block *types.Block, peerId string) (err error)
addPeer func(td *big.Int, currentBlock []byte, peerId string, requestHashes func([]byte) error, requestBlocks func([][]byte) error, invalidBlock func(error)) (best bool)
removePeer func(peerId string)
status func() (td *big.Int, currentBlock []byte, genesisBlock []byte) status func() (td *big.Int, currentBlock []byte, genesisBlock []byte)
} }
func (self *TestBackend) GetTransactions() (txs []*types.Transaction) { type testBlockPool struct {
if self.getTransactions != nil { addBlockHashes func(next func() ([]byte, bool), peerId string)
txs = self.getTransactions() addBlock func(block *types.Block, peerId string) (err error)
} addPeer func(td *big.Int, currentBlock []byte, peerId string, requestHashes func([]byte) error, requestBlocks func([][]byte) error, peerError func(int, string, ...interface{})) (best bool)
return removePeer func(peerId string)
} }
func (self *TestBackend) AddTransactions(txs []*types.Transaction) { // func (self *testTxPool) GetTransactions() (txs []*types.Transaction) {
// if self.getTransactions != nil {
// txs = self.getTransactions()
// }
// return
// }
func (self *testTxPool) AddTransactions(txs []*types.Transaction) {
if self.addTransactions != nil { if self.addTransactions != nil {
self.addTransactions(txs) self.addTransactions(txs)
} }
} }
func (self *TestBackend) GetBlockHashes(hash []byte, amount uint32) (hashes [][]byte) { func (self *testChainManager) GetBlockHashesFromHash(hash []byte, amount uint64) (hashes [][]byte) {
if self.getBlockHashes != nil { if self.getBlockHashes != nil {
hashes = self.getBlockHashes(hash, amount) hashes = self.getBlockHashes(hash, amount)
} }
return return
} }
<<<<<<< HEAD func (self *testChainManager) Status() (td *big.Int, currentBlock []byte, genesisBlock []byte) {
<<<<<<< HEAD
func (self *TestBackend) AddBlockHashes(next func() ([]byte, bool), peerId string) {
if self.addBlockHashes != nil {
self.addBlockHashes(next, peerId)
}
}
=======
func (self *TestBackend) AddHash(hash []byte, peer *p2p.Peer) (more bool) {
if self.addHash != nil {
more = self.addHash(hash, peer)
=======
func (self *TestBackend) AddBlockHashes(next func() ([]byte, bool), peerId string) {
if self.addBlockHashes != nil {
self.addBlockHashes(next, peerId)
>>>>>>> eth protocol changes
}
}
<<<<<<< HEAD
>>>>>>> initial commit for eth-p2p integration
=======
>>>>>>> eth protocol changes
func (self *TestBackend) GetBlock(hash []byte) (block *types.Block) {
if self.getBlock != nil {
block = self.getBlock(hash)
}
return
}
<<<<<<< HEAD
<<<<<<< HEAD
func (self *TestBackend) AddBlock(block *types.Block, peerId string) (err error) {
if self.addBlock != nil {
err = self.addBlock(block, peerId)
=======
func (self *TestBackend) AddBlock(td *big.Int, block *types.Block, peer *p2p.Peer) (fetchHashes bool, err error) {
if self.addBlock != nil {
fetchHashes, err = self.addBlock(td, block, peer)
>>>>>>> initial commit for eth-p2p integration
=======
func (self *TestBackend) AddBlock(block *types.Block, peerId string) (err error) {
if self.addBlock != nil {
err = self.addBlock(block, peerId)
>>>>>>> eth protocol changes
}
return
}
<<<<<<< HEAD
<<<<<<< HEAD
func (self *TestBackend) AddPeer(td *big.Int, currentBlock []byte, peerId string, requestBlockHashes func([]byte) error, requestBlocks func([][]byte) error, invalidBlock func(error)) (best bool) {
if self.addPeer != nil {
best = self.addPeer(td, currentBlock, peerId, requestBlockHashes, requestBlocks, invalidBlock)
=======
func (self *TestBackend) AddPeer(td *big.Int, currentBlock []byte, peer *p2p.Peer) (fetchHashes bool) {
if self.addPeer != nil {
fetchHashes = self.addPeer(td, currentBlock, peer)
>>>>>>> initial commit for eth-p2p integration
=======
func (self *TestBackend) AddPeer(td *big.Int, currentBlock []byte, peerId string, requestBlockHashes func([]byte) error, requestBlocks func([][]byte) error, invalidBlock func(error)) (best bool) {
if self.addPeer != nil {
best = self.addPeer(td, currentBlock, peerId, requestBlockHashes, requestBlocks, invalidBlock)
>>>>>>> eth protocol changes
}
return
}
<<<<<<< HEAD
<<<<<<< HEAD
=======
>>>>>>> eth protocol changes
func (self *TestBackend) RemovePeer(peerId string) {
if self.removePeer != nil {
self.removePeer(peerId)
}
}
<<<<<<< HEAD
=======
>>>>>>> initial commit for eth-p2p integration
=======
>>>>>>> eth protocol changes
func (self *TestBackend) Status() (td *big.Int, currentBlock []byte, genesisBlock []byte) {
if self.status != nil { if self.status != nil {
td, currentBlock, genesisBlock = self.status() td, currentBlock, genesisBlock = self.status()
} }
return return
} }
<<<<<<< HEAD func (self *testChainManager) GetBlock(hash []byte) (block *types.Block) {
<<<<<<< HEAD if self.getBlock != nil {
======= block = self.getBlock(hash)
>>>>>>> eth protocol changes }
return
}
func (self *testBlockPool) AddBlockHashes(next func() ([]byte, bool), peerId string) {
if self.addBlockHashes != nil {
self.addBlockHashes(next, peerId)
}
}
func (self *testBlockPool) AddBlock(block *types.Block, peerId string) {
if self.addBlock != nil {
self.addBlock(block, peerId)
}
}
func (self *testBlockPool) AddPeer(td *big.Int, currentBlock []byte, peerId string, requestBlockHashes func([]byte) error, requestBlocks func([][]byte) error, peerError func(int, string, ...interface{})) (best bool) {
if self.addPeer != nil {
best = self.addPeer(td, currentBlock, peerId, requestBlockHashes, requestBlocks, peerError)
}
return
}
func (self *testBlockPool) RemovePeer(peerId string) {
if self.removePeer != nil {
self.removePeer(peerId)
}
}
// TODO: refactor this into p2p/client_identity // TODO: refactor this into p2p/client_identity
type peerId struct { type peerId struct {
pubkey []byte pubkey []byte
@ -201,32 +152,119 @@ func testPeer() *p2p.Peer {
return p2p.NewPeer(&peerId{}, []p2p.Cap{}) return p2p.NewPeer(&peerId{}, []p2p.Cap{})
} }
func TestErrNoStatusMsg(t *testing.T) { type ethProtocolTester struct {
<<<<<<< HEAD quit chan error
======= rw *testMsgReadWriter // p2p.MsgReadWriter
func TestEth(t *testing.T) { txPool *testTxPool // txPool
>>>>>>> initial commit for eth-p2p integration chainManager *testChainManager // chainManager
======= blockPool *testBlockPool // blockPool
>>>>>>> eth protocol changes t *testing.T
quit := make(chan bool) }
rw := &testMsgReadWriter{make(chan p2p.Msg, 10), make(chan p2p.Msg, 10)}
testBackend := &TestBackend{} func newEth(t *testing.T) *ethProtocolTester {
var err error return &ethProtocolTester{
go func() { quit: make(chan error),
<<<<<<< HEAD rw: &testMsgReadWriter{in: make(chan p2p.Msg, 10)},
<<<<<<< HEAD txPool: &testTxPool{},
err = runEthProtocol(testBackend, testPeer(), rw) chainManager: &testChainManager{},
======= blockPool: &testBlockPool{},
err = runEthProtocol(testBackend, nil, rw) t: t,
>>>>>>> initial commit for eth-p2p integration }
======= }
err = runEthProtocol(testBackend, testPeer(), rw)
>>>>>>> eth protocol changes func (self *ethProtocolTester) reset() {
close(quit) self.rw = &testMsgReadWriter{in: make(chan p2p.Msg, 10)}
}() self.quit = make(chan error)
statusMsg := p2p.NewMsg(4) }
rw.In(statusMsg)
<-quit func (self *ethProtocolTester) checkError(expCode int, delay time.Duration) (err error) {
errorCheck(t, ErrNoStatusMsg, err) var timer = time.After(delay)
// read(t, remote, []byte("hello, world"), nil) select {
case err = <-self.quit:
case <-timer:
self.t.Errorf("no error after %v, expected %v", delay, expCode)
return
}
perr, ok := err.(*protocolError)
if ok && perr != nil {
if code := perr.Code; code != expCode {
self.t.Errorf("expected protocol error (code %v), got %v (%v)", expCode, code, err)
}
} else {
self.t.Errorf("expected protocol error (code %v), got %v", expCode, err)
}
return
}
func (self *ethProtocolTester) In(msg p2p.Msg) {
self.rw.In(msg)
}
func (self *ethProtocolTester) Out() (p2p.Msg, bool) {
return self.rw.Out()
}
func (self *ethProtocolTester) checkMsg(i int, code uint64, val interface{}) (msg p2p.Msg) {
if i >= len(self.rw.out) {
self.t.Errorf("expected at least %v msgs, got %v", i, len(self.rw.out))
return
}
msg = self.rw.out[i]
if msg.Code != code {
self.t.Errorf("expected msg code %v, got %v", code, msg.Code)
}
if val != nil {
if err := msg.Decode(val); err != nil {
self.t.Errorf("rlp encoding error: %v", err)
}
}
return
}
func (self *ethProtocolTester) run() {
err := runEthProtocol(self.txPool, self.chainManager, self.blockPool, testPeer(), self.rw)
self.quit <- err
}
func TestStatusMsgErrors(t *testing.T) {
logInit()
eth := newEth(t)
td := ethutil.Big1
currentBlock := []byte{1}
genesis := []byte{2}
eth.chainManager.status = func() (*big.Int, []byte, []byte) { return td, currentBlock, genesis }
go eth.run()
statusMsg := p2p.NewMsg(4)
eth.In(statusMsg)
delay := 1 * time.Second
eth.checkError(ErrNoStatusMsg, delay)
var status statusMsgData
eth.checkMsg(0, StatusMsg, &status) // first outgoing msg should be StatusMsg
if status.TD.Cmp(td) != 0 ||
status.ProtocolVersion != ProtocolVersion ||
status.NetworkId != NetworkId ||
status.TD.Cmp(td) != 0 ||
bytes.Compare(status.CurrentBlock, currentBlock) != 0 ||
bytes.Compare(status.GenesisBlock, genesis) != 0 {
t.Errorf("incorrect outgoing status")
}
eth.reset()
go eth.run()
statusMsg = p2p.NewMsg(0, uint32(48), uint32(0), td, currentBlock, genesis)
eth.In(statusMsg)
eth.checkError(ErrProtocolVersionMismatch, delay)
eth.reset()
go eth.run()
statusMsg = p2p.NewMsg(0, uint32(49), uint32(1), td, currentBlock, genesis)
eth.In(statusMsg)
eth.checkError(ErrNetworkIdMismatch, delay)
eth.reset()
go eth.run()
statusMsg = p2p.NewMsg(0, uint32(49), uint32(0), td, currentBlock, []byte{3})
eth.In(statusMsg)
eth.checkError(ErrGenesisBlockMismatch, delay)
} }

27
eth/test/README.md Normal file
View File

@ -0,0 +1,27 @@
= Integration tests for eth protocol and blockpool
This is a simple suite of tests to fire up a local test node with peers to test blockchain synchronisation and download.
The scripts call ethereum (assumed to be compiled in go-ethereum root).
To run a test:
. run.sh 00 02
Without arguments, all tests are run.
Peers are launched with preloaded imported chains. In order to prevent them from synchronizing with each other they are set with `-dial=false` and `-maxpeer 1` options. They log into `/tmp/eth.test/nodes/XX` where XX is the last two digits of their port.
Chains to import can be bootstrapped by letting nodes mine for some time. This is done with
. bootstrap.sh
Only the relative timing and forks matter so they should work if the bootstrap script is rerun.
The reference blockchain of tests are soft links to these import chains and check at the end of a test run.
Connecting to peers and exporting blockchain is scripted with JS files executed by the JSRE, see `tests/XX.sh`.
Each test is set with a timeout. This may vary on different computers so adjust sensibly.
If you kill a test before it completes, do not forget to kill all the background processes, since they will impact the result. Use:
killall ethereum

9
eth/test/bootstrap.sh Normal file
View File

@ -0,0 +1,9 @@
#!/bin/bash
# bootstrap chains - used to regenerate tests/chains/*.chain
mkdir -p chains
bash ./mine.sh 00 10
bash ./mine.sh 01 5 00
bash ./mine.sh 02 10 00
bash ./mine.sh 03 5 02
bash ./mine.sh 04 10 02

BIN
eth/test/chains/00.chain Executable file

Binary file not shown.

BIN
eth/test/chains/01.chain Executable file

Binary file not shown.

BIN
eth/test/chains/02.chain Executable file

Binary file not shown.

BIN
eth/test/chains/03.chain Executable file

Binary file not shown.

BIN
eth/test/chains/04.chain Executable file

Binary file not shown.

20
eth/test/mine.sh Normal file
View File

@ -0,0 +1,20 @@
#!/bin/bash
# bash ./mine.sh node_id timeout(sec) [basechain]
ETH=../../ethereum
MINE="$ETH -datadir tmp/nodes/$1 -seed=false -port '' -shh=false -id test$1"
rm -rf tmp/nodes/$1
echo "Creating chain $1..."
if [[ "" != "$3" ]]; then
CHAIN="chains/$3.chain"
CHAINARG="-chain $CHAIN"
$MINE -mine $CHAINARG -loglevel 3 | grep 'importing'
fi
$MINE -mine -loglevel 0 &
PID=$!
sleep $2
kill $PID
$MINE -loglevel 3 <(echo "eth.export(\"chains/$1.chain\")") > /tmp/eth.test/mine.tmp &
PID=$!
sleep 1
kill $PID
cat /tmp/eth.test/mine.tmp | grep 'exporting'

53
eth/test/run.sh Normal file
View File

@ -0,0 +1,53 @@
#!/bin/bash
# bash run.sh (testid0 testid1 ...)
# runs tests tests/testid0.sh tests/testid1.sh ...
# without arguments, it runs all tests
. tests/common.sh
TESTS=
if [ "$#" -eq 0 ]; then
for NAME in tests/??.sh; do
i=`basename $NAME .sh`
TESTS="$TESTS $i"
done
else
TESTS=$@
fi
ETH=../../ethereum
DIR="/tmp/eth.test/nodes"
TIMEOUT=10
mkdir -p $DIR/js
echo "running tests $TESTS"
for NAME in $TESTS; do
PIDS=
CHAIN="tests/$NAME.chain"
JSFILE="$DIR/js/$NAME.js"
CHAIN_TEST="$DIR/$NAME/chain"
echo "RUN: test $NAME"
cat tests/common.js > $JSFILE
. tests/$NAME.sh
sleep $TIMEOUT
echo "timeout after $TIMEOUT seconds: killing $PIDS"
kill $PIDS
if [ -r "$CHAIN" ]; then
if diff $CHAIN $CHAIN_TEST >/dev/null ; then
echo "chain ok: $CHAIN=$CHAIN_TEST"
else
echo "FAIL: chains differ: expected $CHAIN ; got $CHAIN_TEST"
continue
fi
fi
ERRORS=$DIR/errors
if [ -r "$ERRORS" ]; then
echo "FAIL: "
cat $ERRORS
else
echo PASS
fi
done

1
eth/test/tests/00.chain Symbolic link
View File

@ -0,0 +1 @@
../chains/01.chain

13
eth/test/tests/00.sh Normal file
View File

@ -0,0 +1,13 @@
#!/bin/bash
TIMEOUT=4
cat >> $JSFILE <<EOF
eth.addPeer("localhost:30311");
sleep(1000)
eth.export("$CHAIN_TEST");
EOF
peer 11 01
test_node $NAME "" -loglevel 5 $JSFILE

1
eth/test/tests/01.chain Symbolic link
View File

@ -0,0 +1 @@
../chains/02.chain

18
eth/test/tests/01.sh Normal file
View File

@ -0,0 +1,18 @@
#!/bin/bash
TIMEOUT=5
cat >> $JSFILE <<EOF
eth.addPeer("localhost:30311");
log("added peer localhost:30311");
sleep(1000);
log("added peer localhost:30312");
eth.addPeer("localhost:30312");
sleep(3000);
eth.export("$CHAIN_TEST");
EOF
peer 11 01
peer 12 02
test_node $NAME "" -loglevel 5 $JSFILE

1
eth/test/tests/02.chain Symbolic link
View File

@ -0,0 +1 @@
../chains/01.chain

19
eth/test/tests/02.sh Normal file
View File

@ -0,0 +1,19 @@
#!/bin/bash
TIMEOUT=6
cat >> $JSFILE <<EOF
eth.addPeer("localhost:30311");
sleep(200);
eth.addPeer("localhost:30312");
sleep(3000);
eth.export("$CHAIN_TEST");
EOF
peer 11 01
peer 12 02
P13ID=$PID
test_node $NAME "" -loglevel 5 $JSFILE
sleep 0.5
kill $P13ID

1
eth/test/tests/03.chain Symbolic link
View File

@ -0,0 +1 @@
../chains/12k.chain

14
eth/test/tests/03.sh Normal file
View File

@ -0,0 +1,14 @@
#!/bin/bash
TIMEOUT=35
cat >> $JSFILE <<EOF
eth.addPeer("localhost:30311");
sleep(30000);
eth.export("$CHAIN_TEST");
EOF
peer 11 12k
sleep 2
test_node $NAME "" -loglevel 5 $JSFILE

17
eth/test/tests/04.sh Normal file
View File

@ -0,0 +1,17 @@
#!/bin/bash
TIMEOUT=15
cat >> $JSFILE <<EOF
eth.addPeer("localhost:30311");
sleep(200);
eth.addPeer("localhost:30312");
sleep(13000);
eth.export("$CHAIN_TEST");
EOF
peer 11 01 -mine
peer 12 02
test_node $NAME "" -loglevel 5 $JSFILE
sleep 6
cat $DIR/$NAME/debug.log | grep 'best peer'

20
eth/test/tests/05.sh Normal file
View File

@ -0,0 +1,20 @@
#!/bin/bash
TIMEOUT=60
cat >> $JSFILE <<EOF
eth.addPeer("localhost:30311");
sleep(200);
eth.addPeer("localhost:30312");
eth.addPeer("localhost:30313");
eth.addPeer("localhost:30314");
sleep(3000);
eth.export("$CHAIN_TEST");
EOF
peer 11 01 -mine
peer 12 02 -mine
peer 13 03
peer 14 04
test_node $NAME "" -loglevel 5 $JSFILE

9
eth/test/tests/common.js Normal file
View File

@ -0,0 +1,9 @@
function log(text) {
console.log("[JS TEST SCRIPT] " + text);
}
function sleep(seconds) {
var now = new Date().getTime();
while(new Date().getTime() < now + seconds){}
}

20
eth/test/tests/common.sh Normal file
View File

@ -0,0 +1,20 @@
#!/bin/bash
# launched by run.sh
function test_node {
rm -rf $DIR/$1
ARGS="-datadir $DIR/$1 -debug debug -seed=false -shh=false -id test$1"
if [ "" != "$2" ]; then
chain="chains/$2.chain"
echo "import chain $chain"
$ETH $ARGS -loglevel 3 -chain $chain | grep CLI |grep import
fi
echo "starting test node $1 with extra args ${@:3}"
$ETH $ARGS -port 303$1 ${@:3} &
PID=$!
PIDS="$PIDS $PID"
}
function peer {
test_node $@ -loglevel 5 -logfile debug.log -maxpeer 1 -dial=false
}

View File

@ -7,7 +7,7 @@ import (
) )
func TestClientIdentity(t *testing.T) { func TestClientIdentity(t *testing.T) {
clientIdentity := NewSimpleClientIdentity("Ethereum(G)", "0.5.16", "test", "pubkey") clientIdentity := NewSimpleClientIdentity("Ethereum(G)", "0.5.16", "test", []byte("pubkey"))
clientString := clientIdentity.String() clientString := clientIdentity.String()
expected := fmt.Sprintf("Ethereum(G)/v0.5.16/test/%s/%s", runtime.GOOS, runtime.Version()) expected := fmt.Sprintf("Ethereum(G)/v0.5.16/test/%s/%s", runtime.GOOS, runtime.Version())
if clientString != expected { if clientString != expected {

View File

@ -4,6 +4,7 @@ import (
"bytes" "bytes"
"encoding/binary" "encoding/binary"
"errors" "errors"
"fmt"
"io" "io"
"io/ioutil" "io/ioutil"
"math/big" "math/big"
@ -49,7 +50,14 @@ func encodePayload(params ...interface{}) []byte {
// For the decoding rules, please see package rlp. // For the decoding rules, please see package rlp.
func (msg Msg) Decode(val interface{}) error { func (msg Msg) Decode(val interface{}) error {
s := rlp.NewListStream(msg.Payload, uint64(msg.Size)) s := rlp.NewListStream(msg.Payload, uint64(msg.Size))
return s.Decode(val) if err := s.Decode(val); err != nil {
return newPeerError(errInvalidMsg, "(code %#x) (size %d) %v", msg.Code, msg.Size, err)
}
return nil
}
func (msg Msg) String() string {
return fmt.Sprintf("msg #%v (%v bytes)", msg.Code, msg.Size)
} }
// Discard reads any remaining payload data into a black hole. // Discard reads any remaining payload data into a black hole.

View File

@ -45,8 +45,8 @@ func (d peerAddr) String() string {
return fmt.Sprintf("%v:%d", d.IP, d.Port) return fmt.Sprintf("%v:%d", d.IP, d.Port)
} }
func (d peerAddr) RlpData() interface{} { func (d *peerAddr) RlpData() interface{} {
return []interface{}{d.IP, d.Port, d.Pubkey} return []interface{}{string(d.IP), d.Port, d.Pubkey}
} }
// Peer represents a remote peer. // Peer represents a remote peer.
@ -426,7 +426,7 @@ func (rw *proto) WriteMsg(msg Msg) error {
} }
func (rw *proto) EncodeMsg(code uint64, data ...interface{}) error { func (rw *proto) EncodeMsg(code uint64, data ...interface{}) error {
return rw.WriteMsg(NewMsg(code, data)) return rw.WriteMsg(NewMsg(code, data...))
} }
func (rw *proto) ReadMsg() (Msg, error) { func (rw *proto) ReadMsg() (Msg, error) {
@ -460,3 +460,25 @@ func (r *eofSignal) Read(buf []byte) (int, error) {
} }
return n, err return n, err
} }
func (peer *Peer) PeerList() []interface{} {
peers := peer.otherPeers()
ds := make([]interface{}, 0, len(peers))
for _, p := range peers {
p.infolock.Lock()
addr := p.listenAddr
p.infolock.Unlock()
// filter out this peer and peers that are not listening or
// have not completed the handshake.
// TODO: track previously sent peers and exclude them as well.
if p == peer || addr == nil {
continue
}
ds = append(ds, addr)
}
ourAddr := peer.ourListenAddr
if ourAddr != nil && !ourAddr.IP.IsLoopback() && !ourAddr.IP.IsUnspecified() {
ds = append(ds, ourAddr)
}
return ds
}

View File

@ -30,9 +30,8 @@ var discard = Protocol{
func testPeer(protos []Protocol) (net.Conn, *Peer, <-chan error) { func testPeer(protos []Protocol) (net.Conn, *Peer, <-chan error) {
conn1, conn2 := net.Pipe() conn1, conn2 := net.Pipe()
id := NewSimpleClientIdentity("test", "0", "0", "public key")
peer := newPeer(conn1, protos, nil) peer := newPeer(conn1, protos, nil)
peer.ourID = id peer.ourID = &peerId{}
peer.pubkeyHook = func(*peerAddr) error { return nil } peer.pubkeyHook = func(*peerAddr) error { return nil }
errc := make(chan error, 1) errc := make(chan error, 1)
go func() { go func() {
@ -130,7 +129,7 @@ func TestPeerProtoEncodeMsg(t *testing.T) {
if err := rw.EncodeMsg(2); err == nil { if err := rw.EncodeMsg(2); err == nil {
t.Error("expected error for out-of-range msg code, got nil") t.Error("expected error for out-of-range msg code, got nil")
} }
if err := rw.EncodeMsg(1); err != nil { if err := rw.EncodeMsg(1, "foo", "bar"); err != nil {
t.Errorf("write error: %v", err) t.Errorf("write error: %v", err)
} }
return nil return nil
@ -148,6 +147,13 @@ func TestPeerProtoEncodeMsg(t *testing.T) {
if msg.Code != 17 { if msg.Code != 17 {
t.Errorf("incorrect message code: got %d, expected %d", msg.Code, 17) t.Errorf("incorrect message code: got %d, expected %d", msg.Code, 17)
} }
var data []string
if err := msg.Decode(&data); err != nil {
t.Errorf("payload decode error: %v", err)
}
if !reflect.DeepEqual(data, []string{"foo", "bar"}) {
t.Errorf("payload RLP mismatch, got %#v, want %#v", data, []string{"foo", "bar"})
}
} }
func TestPeerWrite(t *testing.T) { func TestPeerWrite(t *testing.T) {
@ -226,8 +232,8 @@ func TestPeerActivity(t *testing.T) {
} }
func TestNewPeer(t *testing.T) { func TestNewPeer(t *testing.T) {
id := NewSimpleClientIdentity("clientid", "version", "customid", "pubkey")
caps := []Cap{{"foo", 2}, {"bar", 3}} caps := []Cap{{"foo", 2}, {"bar", 3}}
id := &peerId{}
p := NewPeer(id, caps) p := NewPeer(id, caps)
if !reflect.DeepEqual(p.Caps(), caps) { if !reflect.DeepEqual(p.Caps(), caps) {
t.Errorf("Caps mismatch: got %v, expected %v", p.Caps(), caps) t.Errorf("Caps mismatch: got %v, expected %v", p.Caps(), caps)

View File

@ -3,8 +3,6 @@ package p2p
import ( import (
"bytes" "bytes"
"time" "time"
"github.com/ethereum/go-ethereum/ethutil"
) )
// Protocol represents a P2P subprotocol implementation. // Protocol represents a P2P subprotocol implementation.
@ -89,20 +87,25 @@ type baseProtocol struct {
func runBaseProtocol(peer *Peer, rw MsgReadWriter) error { func runBaseProtocol(peer *Peer, rw MsgReadWriter) error {
bp := &baseProtocol{rw, peer} bp := &baseProtocol{rw, peer}
if err := bp.doHandshake(rw); err != nil { errc := make(chan error, 1)
go func() { errc <- rw.WriteMsg(bp.handshakeMsg()) }()
if err := bp.readHandshake(); err != nil {
return err
}
// handle write error
if err := <-errc; err != nil {
return err return err
} }
// run main loop // run main loop
quit := make(chan error, 1)
go func() { go func() {
for { for {
if err := bp.handle(rw); err != nil { if err := bp.handle(rw); err != nil {
quit <- err errc <- err
break break
} }
} }
}() }()
return bp.loop(quit) return bp.loop(errc)
} }
var pingTimeout = 2 * time.Second var pingTimeout = 2 * time.Second
@ -166,7 +169,7 @@ func (bp *baseProtocol) handle(rw MsgReadWriter) error {
case pongMsg: case pongMsg:
case getPeersMsg: case getPeersMsg:
peers := bp.peerList() peers := bp.peer.PeerList()
// this is dangerous. the spec says that we should _delay_ // this is dangerous. the spec says that we should _delay_
// sending the response if no new information is available. // sending the response if no new information is available.
// this means that would need to send a response later when // this means that would need to send a response later when
@ -174,7 +177,7 @@ func (bp *baseProtocol) handle(rw MsgReadWriter) error {
// //
// TODO: add event mechanism to notify baseProtocol for new peers // TODO: add event mechanism to notify baseProtocol for new peers
if len(peers) > 0 { if len(peers) > 0 {
return bp.rw.EncodeMsg(peersMsg, peers) return bp.rw.EncodeMsg(peersMsg, peers...)
} }
case peersMsg: case peersMsg:
@ -193,14 +196,9 @@ func (bp *baseProtocol) handle(rw MsgReadWriter) error {
return nil return nil
} }
func (bp *baseProtocol) doHandshake(rw MsgReadWriter) error { func (bp *baseProtocol) readHandshake() error {
// send our handshake
if err := rw.WriteMsg(bp.handshakeMsg()); err != nil {
return err
}
// read and handle remote handshake // read and handle remote handshake
msg, err := rw.ReadMsg() msg, err := bp.rw.ReadMsg()
if err != nil { if err != nil {
return err return err
} }
@ -210,12 +208,10 @@ func (bp *baseProtocol) doHandshake(rw MsgReadWriter) error {
if msg.Size > baseProtocolMaxMsgSize { if msg.Size > baseProtocolMaxMsgSize {
return newPeerError(errMisc, "message too big") return newPeerError(errMisc, "message too big")
} }
var hs handshake var hs handshake
if err := msg.Decode(&hs); err != nil { if err := msg.Decode(&hs); err != nil {
return err return err
} }
// validate handshake info // validate handshake info
if hs.Version != baseProtocolVersion { if hs.Version != baseProtocolVersion {
return newPeerError(errP2PVersionMismatch, "Require protocol %d, received %d\n", return newPeerError(errP2PVersionMismatch, "Require protocol %d, received %d\n",
@ -238,9 +234,7 @@ func (bp *baseProtocol) doHandshake(rw MsgReadWriter) error {
if err := bp.peer.pubkeyHook(pa); err != nil { if err := bp.peer.pubkeyHook(pa); err != nil {
return newPeerError(errPubkeyForbidden, "%v", err) return newPeerError(errPubkeyForbidden, "%v", err)
} }
// TODO: remove Caps with empty name // TODO: remove Caps with empty name
var addr *peerAddr var addr *peerAddr
if hs.ListenPort != 0 { if hs.ListenPort != 0 {
addr = newPeerAddr(bp.peer.conn.RemoteAddr(), hs.NodeID) addr = newPeerAddr(bp.peer.conn.RemoteAddr(), hs.NodeID)
@ -270,25 +264,3 @@ func (bp *baseProtocol) handshakeMsg() Msg {
bp.peer.ourID.Pubkey()[1:], bp.peer.ourID.Pubkey()[1:],
) )
} }
func (bp *baseProtocol) peerList() []ethutil.RlpEncodable {
peers := bp.peer.otherPeers()
ds := make([]ethutil.RlpEncodable, 0, len(peers))
for _, p := range peers {
p.infolock.Lock()
addr := p.listenAddr
p.infolock.Unlock()
// filter out this peer and peers that are not listening or
// have not completed the handshake.
// TODO: track previously sent peers and exclude them as well.
if p == bp.peer || addr == nil {
continue
}
ds = append(ds, addr)
}
ourAddr := bp.peer.ourListenAddr
if ourAddr != nil && !ourAddr.IP.IsLoopback() && !ourAddr.IP.IsUnspecified() {
ds = append(ds, ourAddr)
}
return ds
}

View File

@ -2,12 +2,89 @@ package p2p
import ( import (
"fmt" "fmt"
"net"
"reflect"
"testing" "testing"
"github.com/ethereum/go-ethereum/crypto"
) )
type peerId struct {
pubkey []byte
}
func (self *peerId) String() string {
return fmt.Sprintf("test peer %x", self.Pubkey()[:4])
}
func (self *peerId) Pubkey() (pubkey []byte) {
pubkey = self.pubkey
if len(pubkey) == 0 {
pubkey = crypto.GenerateNewKeyPair().PublicKey
self.pubkey = pubkey
}
return
}
func newTestPeer() (peer *Peer) {
peer = NewPeer(&peerId{}, []Cap{})
peer.pubkeyHook = func(*peerAddr) error { return nil }
peer.ourID = &peerId{}
peer.listenAddr = &peerAddr{}
peer.otherPeers = func() []*Peer { return nil }
return
}
func TestBaseProtocolPeers(t *testing.T) {
cannedPeerList := []*peerAddr{
{IP: net.ParseIP("1.2.3.4"), Port: 2222, Pubkey: []byte{}},
{IP: net.ParseIP("5.6.7.8"), Port: 3333, Pubkey: []byte{}},
}
var ownAddr *peerAddr = &peerAddr{IP: net.ParseIP("1.3.5.7"), Port: 1111, Pubkey: []byte{}}
rw1, rw2 := MsgPipe()
// run matcher, close pipe when addresses have arrived
addrChan := make(chan *peerAddr, len(cannedPeerList))
go func() {
for _, want := range cannedPeerList {
got := <-addrChan
t.Logf("got peer: %+v", got)
if !reflect.DeepEqual(want, got) {
t.Errorf("mismatch: got %#v, want %#v", got, want)
}
}
close(addrChan)
var own []*peerAddr
var got *peerAddr
for got = range addrChan {
own = append(own, got)
}
if len(own) != 1 || !reflect.DeepEqual(ownAddr, own[0]) {
t.Errorf("mismatch: peers own address is incorrectly or not given, got %v, want %#v", ownAddr)
}
rw2.Close()
}()
// run first peer
peer1 := newTestPeer()
peer1.ourListenAddr = ownAddr
peer1.otherPeers = func() []*Peer {
pl := make([]*Peer, len(cannedPeerList))
for i, addr := range cannedPeerList {
pl[i] = &Peer{listenAddr: addr}
}
return pl
}
go runBaseProtocol(peer1, rw1)
// run second peer
peer2 := newTestPeer()
peer2.newPeerAddr = addrChan // feed peer suggestions into matcher
if err := runBaseProtocol(peer2, rw2); err != ErrPipeClosed {
t.Errorf("peer2 terminated with unexpected error: %v", err)
}
}
func TestBaseProtocolDisconnect(t *testing.T) { func TestBaseProtocolDisconnect(t *testing.T) {
peer := NewPeer(NewSimpleClientIdentity("p1", "", "", "foo"), nil) peer := NewPeer(&peerId{}, nil)
peer.ourID = NewSimpleClientIdentity("p2", "", "", "bar") peer.ourID = &peerId{}
peer.pubkeyHook = func(*peerAddr) error { return nil } peer.pubkeyHook = func(*peerAddr) error { return nil }
rw1, rw2 := MsgPipe() rw1, rw2 := MsgPipe()
@ -32,6 +109,7 @@ func TestBaseProtocolDisconnect(t *testing.T) {
if err := rw2.EncodeMsg(discMsg, DiscQuitting); err != nil { if err := rw2.EncodeMsg(discMsg, DiscQuitting); err != nil {
t.Error(err) t.Error(err)
} }
close(done) close(done)
}() }()

View File

@ -113,9 +113,11 @@ func (srv *Server) PeerCount() int {
// SuggestPeer injects an address into the outbound address pool. // SuggestPeer injects an address into the outbound address pool.
func (srv *Server) SuggestPeer(ip net.IP, port int, nodeID []byte) { func (srv *Server) SuggestPeer(ip net.IP, port int, nodeID []byte) {
addr := &peerAddr{ip, uint64(port), nodeID}
select { select {
case srv.peerConnect <- &peerAddr{ip, uint64(port), nodeID}: case srv.peerConnect <- addr:
default: // don't block default: // don't block
srvlog.Warnf("peer suggestion %v ignored", addr)
} }
} }
@ -258,6 +260,7 @@ func (srv *Server) listenLoop() {
for { for {
select { select {
case slot := <-srv.peerSlots: case slot := <-srv.peerSlots:
srvlog.Debugf("grabbed slot %v for listening", slot)
conn, err := srv.listener.Accept() conn, err := srv.listener.Accept()
if err != nil { if err != nil {
srv.peerSlots <- slot srv.peerSlots <- slot
@ -330,6 +333,7 @@ func (srv *Server) dialLoop() {
case desc := <-suggest: case desc := <-suggest:
// candidate peer found, will dial out asyncronously // candidate peer found, will dial out asyncronously
// if connection fails slot will be released // if connection fails slot will be released
srvlog.Infof("dial %v (%v)", desc, *slot)
go srv.dialPeer(desc, *slot) go srv.dialPeer(desc, *slot)
// we can watch if more peers needed in the next loop // we can watch if more peers needed in the next loop
slots = srv.peerSlots slots = srv.peerSlots

View File

@ -11,7 +11,7 @@ import (
func startTestServer(t *testing.T, pf peerFunc) *Server { func startTestServer(t *testing.T, pf peerFunc) *Server {
server := &Server{ server := &Server{
Identity: NewSimpleClientIdentity("clientIdentifier", "version", "customIdentifier", "pubkey"), Identity: &peerId{},
MaxPeers: 10, MaxPeers: 10,
ListenAddr: "127.0.0.1:0", ListenAddr: "127.0.0.1:0",
newPeerFunc: pf, newPeerFunc: pf,

View File

@ -76,22 +76,37 @@ func Decode(r io.Reader, val interface{}) error {
type decodeError struct { type decodeError struct {
msg string msg string
typ reflect.Type typ reflect.Type
ctx []string
} }
func (err decodeError) Error() string { func (err *decodeError) Error() string {
return fmt.Sprintf("rlp: %s for %v", err.msg, err.typ) ctx := ""
if len(err.ctx) > 0 {
ctx = ", decoding into "
for i := len(err.ctx) - 1; i >= 0; i-- {
ctx += err.ctx[i]
}
}
return fmt.Sprintf("rlp: %s for %v%s", err.msg, err.typ, ctx)
} }
func wrapStreamError(err error, typ reflect.Type) error { func wrapStreamError(err error, typ reflect.Type) error {
switch err { switch err {
case ErrExpectedList: case ErrExpectedList:
return decodeError{"expected input list", typ} return &decodeError{msg: "expected input list", typ: typ}
case ErrExpectedString: case ErrExpectedString:
return decodeError{"expected input string or byte", typ} return &decodeError{msg: "expected input string or byte", typ: typ}
case errUintOverflow: case errUintOverflow:
return decodeError{"input string too long", typ} return &decodeError{msg: "input string too long", typ: typ}
case errNotAtEOL: case errNotAtEOL:
return decodeError{"input list has too many elements", typ} return &decodeError{msg: "input list has too many elements", typ: typ}
}
return err
}
func addErrorContext(err error, ctx string) error {
if decErr, ok := err.(*decodeError); ok {
decErr.ctx = append(decErr.ctx, ctx)
} }
return err return err
} }
@ -180,13 +195,13 @@ func makeListDecoder(typ reflect.Type) (decoder, error) {
return nil, err return nil, err
} }
if typ.Kind() == reflect.Array { isArray := typ.Kind() == reflect.Array
return func(s *Stream, val reflect.Value) error { return func(s *Stream, val reflect.Value) error {
if isArray {
return decodeListArray(s, val, etypeinfo.decoder) return decodeListArray(s, val, etypeinfo.decoder)
}, nil } else {
}
return func(s *Stream, val reflect.Value) error {
return decodeListSlice(s, val, etypeinfo.decoder) return decodeListSlice(s, val, etypeinfo.decoder)
}
}, nil }, nil
} }
@ -219,7 +234,7 @@ func decodeListSlice(s *Stream, val reflect.Value, elemdec decoder) error {
if err := elemdec(s, val.Index(i)); err == EOL { if err := elemdec(s, val.Index(i)); err == EOL {
break break
} else if err != nil { } else if err != nil {
return err return addErrorContext(err, fmt.Sprint("[", i, "]"))
} }
} }
if i < val.Len() { if i < val.Len() {
@ -248,7 +263,7 @@ func decodeListArray(s *Stream, val reflect.Value, elemdec decoder) error {
if err := elemdec(s, val.Index(i)); err == EOL { if err := elemdec(s, val.Index(i)); err == EOL {
break break
} else if err != nil { } else if err != nil {
return err return addErrorContext(err, fmt.Sprint("[", i, "]"))
} }
} }
if i < vlen { if i < vlen {
@ -280,14 +295,14 @@ func decodeByteArray(s *Stream, val reflect.Value) error {
switch kind { switch kind {
case Byte: case Byte:
if val.Len() == 0 { if val.Len() == 0 {
return decodeError{"input string too long", val.Type()} return &decodeError{msg: "input string too long", typ: val.Type()}
} }
bv, _ := s.Uint() bv, _ := s.Uint()
val.Index(0).SetUint(bv) val.Index(0).SetUint(bv)
zero(val, 1) zero(val, 1)
case String: case String:
if uint64(val.Len()) < size { if uint64(val.Len()) < size {
return decodeError{"input string too long", val.Type()} return &decodeError{msg: "input string too long", typ: val.Type()}
} }
slice := val.Slice(0, int(size)).Interface().([]byte) slice := val.Slice(0, int(size)).Interface().([]byte)
if err := s.readFull(slice); err != nil { if err := s.readFull(slice); err != nil {
@ -334,7 +349,7 @@ func makeStructDecoder(typ reflect.Type) (decoder, error) {
// too few elements. leave the rest at their zero value. // too few elements. leave the rest at their zero value.
break break
} else if err != nil { } else if err != nil {
return err return addErrorContext(err, "."+typ.Field(f.index).Name)
} }
} }
return wrapStreamError(s.ListEnd(), typ) return wrapStreamError(s.ListEnd(), typ)
@ -599,7 +614,13 @@ func (s *Stream) Decode(val interface{}) error {
if err != nil { if err != nil {
return err return err
} }
return info.decoder(s, rval.Elem())
err = info.decoder(s, rval.Elem())
if decErr, ok := err.(*decodeError); ok && len(decErr.ctx) > 0 {
// add decode target type to error so context has more meaning
decErr.ctx = append(decErr.ctx, fmt.Sprint("(", rtyp.Elem(), ")"))
}
return err
} }
// Reset discards any information about the current decoding context // Reset discards any information about the current decoding context

View File

@ -231,7 +231,12 @@ var decodeTests = []decodeTest{
{input: "8D6162636465666768696A6B6C6D", ptr: new([]byte), value: []byte("abcdefghijklm")}, {input: "8D6162636465666768696A6B6C6D", ptr: new([]byte), value: []byte("abcdefghijklm")},
{input: "C0", ptr: new([]byte), value: []byte{}}, {input: "C0", ptr: new([]byte), value: []byte{}},
{input: "C3010203", ptr: new([]byte), value: []byte{1, 2, 3}}, {input: "C3010203", ptr: new([]byte), value: []byte{1, 2, 3}},
{input: "C3820102", ptr: new([]byte), error: "rlp: input string too long for uint8"},
{
input: "C3820102",
ptr: new([]byte),
error: "rlp: input string too long for uint8, decoding into ([]uint8)[0]",
},
// byte arrays // byte arrays
{input: "01", ptr: new([5]byte), value: [5]byte{1}}, {input: "01", ptr: new([5]byte), value: [5]byte{1}},
@ -239,9 +244,22 @@ var decodeTests = []decodeTest{
{input: "850102030405", ptr: new([5]byte), value: [5]byte{1, 2, 3, 4, 5}}, {input: "850102030405", ptr: new([5]byte), value: [5]byte{1, 2, 3, 4, 5}},
{input: "C0", ptr: new([5]byte), value: [5]byte{}}, {input: "C0", ptr: new([5]byte), value: [5]byte{}},
{input: "C3010203", ptr: new([5]byte), value: [5]byte{1, 2, 3, 0, 0}}, {input: "C3010203", ptr: new([5]byte), value: [5]byte{1, 2, 3, 0, 0}},
{input: "C3820102", ptr: new([5]byte), error: "rlp: input string too long for uint8"},
{input: "86010203040506", ptr: new([5]byte), error: "rlp: input string too long for [5]uint8"}, {
{input: "850101", ptr: new([5]byte), error: io.ErrUnexpectedEOF.Error()}, input: "C3820102",
ptr: new([5]byte),
error: "rlp: input string too long for uint8, decoding into ([5]uint8)[0]",
},
{
input: "86010203040506",
ptr: new([5]byte),
error: "rlp: input string too long for [5]uint8",
},
{
input: "850101",
ptr: new([5]byte),
error: io.ErrUnexpectedEOF.Error(),
},
// byte array reuse (should be zeroed) // byte array reuse (should be zeroed)
{input: "850102030405", ptr: &sharedByteArray, value: [5]byte{1, 2, 3, 4, 5}}, {input: "850102030405", ptr: &sharedByteArray, value: [5]byte{1, 2, 3, 4, 5}},
@ -272,13 +290,23 @@ var decodeTests = []decodeTest{
{input: "C0", ptr: new(simplestruct), value: simplestruct{0, ""}}, {input: "C0", ptr: new(simplestruct), value: simplestruct{0, ""}},
{input: "C105", ptr: new(simplestruct), value: simplestruct{5, ""}}, {input: "C105", ptr: new(simplestruct), value: simplestruct{5, ""}},
{input: "C50583343434", ptr: new(simplestruct), value: simplestruct{5, "444"}}, {input: "C50583343434", ptr: new(simplestruct), value: simplestruct{5, "444"}},
{input: "C3010101", ptr: new(simplestruct), error: "rlp: input list has too many elements for rlp.simplestruct"},
{ {
input: "C501C302C103", input: "C501C302C103",
ptr: new(recstruct), ptr: new(recstruct),
value: recstruct{1, &recstruct{2, &recstruct{3, nil}}}, value: recstruct{1, &recstruct{2, &recstruct{3, nil}}},
}, },
{
input: "C3010101",
ptr: new(simplestruct),
error: "rlp: input list has too many elements for rlp.simplestruct",
},
{
input: "C501C3C00000",
ptr: new(recstruct),
error: "rlp: expected input string or byte for uint, decoding into (rlp.recstruct).Child.I",
},
// pointers // pointers
{input: "00", ptr: new(*uint), value: (*uint)(nil)}, {input: "00", ptr: new(*uint), value: (*uint)(nil)},
{input: "80", ptr: new(*uint), value: (*uint)(nil)}, {input: "80", ptr: new(*uint), value: (*uint)(nil)},