From 48d2a8b8ee9621810a988e3561e4213749c54da7 Mon Sep 17 00:00:00 2001 From: obscuren Date: Fri, 2 Jan 2015 12:09:38 +0100 Subject: [PATCH] Refactored tx pool and added extra fields to block * chain manager sets td on block + td output w/ String * added tx pool tests for removing/adding/validating * tx pool now uses a set for txs instead of list.List --- core/transaction_pool.go | 130 +++++++++------------------------- core/transaction_pool_test.go | 82 +++++++++++++++++++++ core/types/block.go | 4 +- core/types/transaction.go | 5 ++ 4 files changed, 123 insertions(+), 98 deletions(-) create mode 100644 core/transaction_pool_test.go diff --git a/core/transaction_pool.go b/core/transaction_pool.go index 1149d4cfb8..4f91f35755 100644 --- a/core/transaction_pool.go +++ b/core/transaction_pool.go @@ -1,16 +1,13 @@ package core import ( - "bytes" - "container/list" "fmt" "math/big" - "sync" "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/event" "github.com/ethereum/go-ethereum/logger" - "github.com/ethereum/go-ethereum/state" + "gopkg.in/fatih/set.v0" ) var txplogger = logger.NewLogger("TXP") @@ -26,86 +23,50 @@ const ( minGasPrice = 1000000 ) -var MinGasPrice = big.NewInt(10000000000000) - -func EachTx(pool *list.List, it func(*types.Transaction, *list.Element) bool) { - for e := pool.Front(); e != nil; e = e.Next() { - if it(e.Value.(*types.Transaction), e) { - break - } - } -} - -func FindTx(pool *list.List, finder func(*types.Transaction, *list.Element) bool) *types.Transaction { - for e := pool.Front(); e != nil; e = e.Next() { - if tx, ok := e.Value.(*types.Transaction); ok { - if finder(tx, e) { - return tx - } - } - } - - return nil -} - type TxProcessor interface { ProcessTransaction(tx *types.Transaction) } // The tx pool a thread safe transaction pool handler. In order to // guarantee a non blocking pool we use a queue channel which can be -// independently read without needing access to the actual pool. If the -// pool is being drained or synced for whatever reason the transactions -// will simple queue up and handled when the mutex is freed. +// independently read without needing access to the actual pool. type TxPool struct { - // The mutex for accessing the Tx pool. - mutex sync.Mutex // Queueing channel for reading and writing incoming // transactions to queueChan chan *types.Transaction // Quiting channel quit chan bool // The actual pool - pool *list.List + //pool *list.List + pool *set.Set SecondaryProcessor TxProcessor subscribers []chan TxMsg - chainManager *ChainManager - eventMux *event.TypeMux + stateQuery StateQuery + eventMux *event.TypeMux } -func NewTxPool(chainManager *ChainManager, eventMux *event.TypeMux) *TxPool { +func NewTxPool(stateQuery StateQuery, eventMux *event.TypeMux) *TxPool { return &TxPool{ - pool: list.New(), - queueChan: make(chan *types.Transaction, txPoolQueueSize), - quit: make(chan bool), - chainManager: chainManager, - eventMux: eventMux, + pool: set.New(), + queueChan: make(chan *types.Transaction, txPoolQueueSize), + quit: make(chan bool), + stateQuery: stateQuery, + eventMux: eventMux, } } -// Blocking function. Don't use directly. Use QueueTransaction instead func (pool *TxPool) addTransaction(tx *types.Transaction) { - pool.mutex.Lock() - defer pool.mutex.Unlock() - pool.pool.PushBack(tx) + pool.pool.Add(tx) // Broadcast the transaction to the rest of the peers pool.eventMux.Post(TxPreEvent{tx}) } func (pool *TxPool) ValidateTransaction(tx *types.Transaction) error { - // Get the last block so we can retrieve the sender and receiver from - // the merkle trie - block := pool.chainManager.CurrentBlock - // Something has gone horribly wrong if this happens - if block == nil { - return fmt.Errorf("No last block on the block chain") - } - if len(tx.To()) != 0 && len(tx.To()) != 20 { return fmt.Errorf("Invalid recipient. len = %d", len(tx.To())) } @@ -120,7 +81,7 @@ func (pool *TxPool) ValidateTransaction(tx *types.Transaction) error { if senderAddr == nil { return fmt.Errorf("invalid sender") } - sender := pool.chainManager.State().GetAccount(senderAddr) + sender := pool.stateQuery.GetAccount(senderAddr) totAmount := new(big.Int).Set(tx.Value()) // Make sure there's enough in the sender's account. Having insufficient @@ -129,19 +90,12 @@ func (pool *TxPool) ValidateTransaction(tx *types.Transaction) error { return fmt.Errorf("Insufficient amount in sender's (%x) account", tx.From()) } - // Increment the nonce making each tx valid only once to prevent replay - // attacks - return nil } func (self *TxPool) Add(tx *types.Transaction) error { hash := tx.Hash() - foundTx := FindTx(self.pool, func(tx *types.Transaction, e *list.Element) bool { - return bytes.Compare(tx.Hash(), hash) == 0 - }) - - if foundTx != nil { + if self.pool.Has(tx) { return fmt.Errorf("Known transaction (%x)", hash[0:4]) } @@ -161,7 +115,7 @@ func (self *TxPool) Add(tx *types.Transaction) error { } func (self *TxPool) Size() int { - return self.pool.Len() + return self.pool.Size() } func (self *TxPool) AddTransactions(txs []*types.Transaction) { @@ -175,63 +129,47 @@ func (self *TxPool) AddTransactions(txs []*types.Transaction) { } func (pool *TxPool) GetTransactions() []*types.Transaction { - pool.mutex.Lock() - defer pool.mutex.Unlock() - - txList := make([]*types.Transaction, pool.pool.Len()) + txList := make([]*types.Transaction, pool.Size()) i := 0 - for e := pool.pool.Front(); e != nil; e = e.Next() { - tx := e.Value.(*types.Transaction) - - txList[i] = tx - + pool.pool.Each(func(v interface{}) bool { + txList[i] = v.(*types.Transaction) i++ - } + + return true + }) return txList } -func (pool *TxPool) RemoveInvalid(state *state.StateDB) { - pool.mutex.Lock() - defer pool.mutex.Unlock() - - for e := pool.pool.Front(); e != nil; e = e.Next() { - tx := e.Value.(*types.Transaction) - sender := state.GetAccount(tx.From()) +func (pool *TxPool) RemoveInvalid(query StateQuery) { + var removedTxs types.Transactions + pool.pool.Each(func(v interface{}) bool { + tx := v.(*types.Transaction) + sender := query.GetAccount(tx.From()) err := pool.ValidateTransaction(tx) if err != nil || sender.Nonce >= tx.Nonce() { - pool.pool.Remove(e) + removedTxs = append(removedTxs, tx) } - } + + return true + }) + pool.RemoveSet(removedTxs) } func (self *TxPool) RemoveSet(txs types.Transactions) { - self.mutex.Lock() - defer self.mutex.Unlock() - for _, tx := range txs { - EachTx(self.pool, func(t *types.Transaction, element *list.Element) bool { - if t == tx { - self.pool.Remove(element) - return true // To stop the loop - } - return false - }) + self.pool.Remove(tx) } } func (pool *TxPool) Flush() []*types.Transaction { txList := pool.GetTransactions() - - // Recreate a new list all together - // XXX Is this the fastest way? - pool.pool = list.New() + pool.pool.Clear() return txList } func (pool *TxPool) Start() { - //go pool.queueHandler() } func (pool *TxPool) Stop() { diff --git a/core/transaction_pool_test.go b/core/transaction_pool_test.go new file mode 100644 index 0000000000..296c6bd8a6 --- /dev/null +++ b/core/transaction_pool_test.go @@ -0,0 +1,82 @@ +package core + +import ( + "crypto/ecdsa" + "testing" + + "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/ethutil" + "github.com/ethereum/go-ethereum/event" + "github.com/ethereum/go-ethereum/state" +) + +// State query interface +type stateQuery struct{} + +func (self stateQuery) GetAccount(addr []byte) *state.StateObject { + return state.NewStateObject(addr) +} + +// State query interface +type invalidStateQuery struct{} + +func (self invalidStateQuery) GetAccount(addr []byte) *state.StateObject { + o := state.NewStateObject(addr) + o.Nonce++ + return o +} + +func transaction() *types.Transaction { + return types.NewTransactionMessage(make([]byte, 20), ethutil.Big0, ethutil.Big0, ethutil.Big0, nil) +} + +func setup() (*TxPool, *ecdsa.PrivateKey) { + var m event.TypeMux + key, _ := crypto.GenerateKey() + return NewTxPool(stateQuery{}, &m), key +} + +func TestTxAdding(t *testing.T) { + pool, key := setup() + tx1 := transaction() + tx1.SignECDSA(key) + err := pool.Add(tx1) + if err != nil { + t.Error(err) + } + + err = pool.Add(tx1) + if err == nil { + t.Error("added tx twice") + } +} + +func TestAddInvalidTx(t *testing.T) { + pool, _ := setup() + tx1 := transaction() + err := pool.Add(tx1) + if err == nil { + t.Error("expected error") + } +} + +func TestRemoveSet(t *testing.T) { + pool, _ := setup() + tx1 := transaction() + pool.pool.Add(tx1) + pool.RemoveSet(types.Transactions{tx1}) + if pool.Size() > 0 { + t.Error("expected pool size to be 0") + } +} + +func TestRemoveInvalid(t *testing.T) { + pool, _ := setup() + tx1 := transaction() + pool.pool.Add(tx1) + pool.RemoveInvalid(invalidStateQuery{}) + if pool.Size() > 0 { + t.Error("expected pool size to be 0") + } +} diff --git a/core/types/block.go b/core/types/block.go index 7b4695f733..b59044bfce 100644 --- a/core/types/block.go +++ b/core/types/block.go @@ -209,7 +209,7 @@ func (self *Block) HashNoNonce() []byte { } func (self *Block) String() string { - return fmt.Sprintf(`BLOCK(%x): Size: %v { + return fmt.Sprintf(`BLOCK(%x): Size: %v TD: %v { Header: [ %v @@ -219,7 +219,7 @@ Transactions: Uncles: %v } -`, self.header.Hash(), self.Size(), self.header, self.transactions, self.uncles) +`, self.header.Hash(), self.Size(), self.Td, self.header, self.transactions, self.uncles) } func (self *Header) String() string { diff --git a/core/types/transaction.go b/core/types/transaction.go index 59244adc3d..83d76648fd 100644 --- a/core/types/transaction.go +++ b/core/types/transaction.go @@ -2,6 +2,7 @@ package types import ( "bytes" + "crypto/ecdsa" "fmt" "math/big" @@ -139,6 +140,10 @@ func (tx *Transaction) Sign(privk []byte) error { return nil } +func (tx *Transaction) SignECDSA(key *ecdsa.PrivateKey) error { + return tx.Sign(crypto.FromECDSA(key)) +} + func (tx *Transaction) RlpData() interface{} { data := []interface{}{tx.AccountNonce, tx.Price, tx.GasLimit, tx.Recipient, tx.Amount, tx.Payload}