diff --git a/VERSION b/VERSION
index 428b770e3e..1c99cf0e80 100644
--- a/VERSION
+++ b/VERSION
@@ -1 +1 @@
-1.4.3
+1.4.4
diff --git a/accounts/abi/abi.go b/accounts/abi/abi.go
index 32df6f19d7..1b07b2f68e 100644
--- a/accounts/abi/abi.go
+++ b/accounts/abi/abi.go
@@ -238,8 +238,16 @@ func (abi ABI) Unpack(v interface{}, name string, output []byte) error {
 		return fmt.Errorf("abi: unmarshalling empty output")
 	}
 
-	value := reflect.ValueOf(v).Elem()
-	typ := value.Type()
+	// make sure the passed value is a pointer
+	valueOf := reflect.ValueOf(v)
+	if reflect.Ptr != valueOf.Kind() {
+		return fmt.Errorf("abi: Unpack(non-pointer %T)", v)
+	}
+
+	var (
+		value = valueOf.Elem()
+		typ   = value.Type()
+	)
 
 	if len(method.Outputs) > 1 {
 		switch value.Kind() {
@@ -268,6 +276,25 @@ func (abi ABI) Unpack(v interface{}, name string, output []byte) error {
 			return fmt.Errorf("abi: cannot marshal tuple in to slice %T (only []interface{} is supported)", v)
 		}
 
+		// if the slice already contains values, set those instead of the interface slice itself.
+		if value.Len() > 0 {
+			if len(method.Outputs) > value.Len() {
+				return fmt.Errorf("abi: cannot marshal in to slices of unequal size (require: %v, got: %v)", len(method.Outputs), value.Len())
+			}
+
+			for i := 0; i < len(method.Outputs); i++ {
+				marshalledValue, err := toGoType(i, method.Outputs[i], output)
+				if err != nil {
+					return err
+				}
+				reflectValue := reflect.ValueOf(marshalledValue)
+				if err := set(value.Index(i).Elem(), reflectValue, method.Outputs[i]); err != nil {
+					return err
+				}
+			}
+			return nil
+		}
+
 		// create a new slice and start appending the unmarshalled
 		// values to the new interface slice.
 		z := reflect.MakeSlice(typ, 0, len(method.Outputs))
@@ -296,34 +323,6 @@ func (abi ABI) Unpack(v interface{}, name string, output []byte) error {
 	return nil
 }
 
-// set attempts to assign src to dst by either setting, copying or otherwise.
-//
-// set is a bit more lenient when it comes to assignment and doesn't force an as
-// strict ruleset as bare `reflect` does.
-func set(dst, src reflect.Value, output Argument) error {
-	dstType := dst.Type()
-	srcType := src.Type()
-
-	switch {
-	case dstType.AssignableTo(src.Type()):
-		dst.Set(src)
-	case dstType.Kind() == reflect.Array && srcType.Kind() == reflect.Slice:
-		if !dstType.Elem().AssignableTo(r_byte) {
-			return fmt.Errorf("abi: cannot unmarshal %v in to array of elem %v", src.Type(), dstType.Elem())
-		}
-
-		if dst.Len() < output.Type.SliceSize {
-			return fmt.Errorf("abi: cannot unmarshal src (len=%d) in to dst (len=%d)", output.Type.SliceSize, dst.Len())
-		}
-		reflect.Copy(dst, src)
-	case dstType.Kind() == reflect.Interface:
-		dst.Set(src)
-	default:
-		return fmt.Errorf("abi: cannot unmarshal %v in to %v", src.Type(), dst.Type())
-	}
-	return nil
-}
-
 func (abi *ABI) UnmarshalJSON(data []byte) error {
 	var fields []struct {
 		Type string
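With the abi.go change above, Unpack now rejects non-pointer arguments up front, and when the caller passes a pointer to a pre-populated []interface{}, the decoded outputs are assigned through the existing elements instead of the slice being replaced. A minimal usage sketch under assumed names (`parsed` is an abi.ABI whose "ints" method returns two uint8 values and `ret` is the raw return data; neither comes from this patch):

	// Hedged sketch: decode two uint8 outputs into pre-allocated pointers.
	a, b := new(uint8), new(uint8)
	out := []interface{}{a, b}
	if err := parsed.Unpack(&out, "ints", ret); err != nil { // must pass &out now
		log.Fatal(err)
	}
	fmt.Println(*a, *b) // the decoded values, e.g. 1 2
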
diff --git a/accounts/abi/abi_test.go b/accounts/abi/abi_test.go
index 05535b3b50..df89ba1381 100644
--- a/accounts/abi/abi_test.go
+++ b/accounts/abi/abi_test.go
@@ -289,6 +289,37 @@ func TestSimpleMethodUnpack(t *testing.T) {
 	}
 }
 
+func TestUnpackSetInterfaceSlice(t *testing.T) {
+	var (
+		var1 = new(uint8)
+		var2 = new(uint8)
+	)
+	out := []interface{}{var1, var2}
+	abi, err := JSON(strings.NewReader(`[{"type":"function", "name":"ints", "outputs":[{"type":"uint8"}, {"type":"uint8"}]}]`))
+	if err != nil {
+		t.Fatal(err)
+	}
+	marshalledReturn := append(pad([]byte{1}, 32, true), pad([]byte{2}, 32, true)...)
+	err = abi.Unpack(&out, "ints", marshalledReturn)
+	if err != nil {
+		t.Fatal(err)
+	}
+	if *var1 != 1 {
+		t.Errorf("expected var1 to be 1, got %d", *var1)
+	}
+	if *var2 != 2 {
+		t.Errorf("expected var2 to be 2, got %d", *var2)
+	}
+
+	out = []interface{}{var1}
+	err = abi.Unpack(&out, "ints", marshalledReturn)
+
+	expErr := "abi: cannot marshal in to slices of unequal size (require: 2, got: 1)"
+	if err == nil || err.Error() != expErr {
+		t.Error("expected err:", expErr, "Got:", err)
+	}
+}
+
 func TestPack(t *testing.T) {
 	for i, test := range []struct {
 		typ string
diff --git a/accounts/abi/bind/bind_test.go b/accounts/abi/bind/bind_test.go
index f9cc8aba47..a80560821c 100644
--- a/accounts/abi/bind/bind_test.go
+++ b/accounts/abi/bind/bind_test.go
@@ -194,12 +194,44 @@ var bindTests = []struct {
 		}
 	`,
 	},
+	// Tests that plain values can be properly returned and deserialized
+	{
+		`Getter`,
+		`
+		contract Getter {
+			function getter() constant returns (string, int, bytes32) {
+				return ("Hi", 1, sha3(""));
+			}
+		}
+	`,
+		`606060405260dc8060106000396000f3606060405260e060020a6000350463993a04b78114601a575b005b600060605260c0604052600260809081527f486900000000000000000000000000000000000000000000000000000000000060a05260017fc5d2460186f7233c927e7db2dcc703c0e500b653ca82273b7bfad8045d85a47060e0829052610100819052606060c0908152600261012081905281906101409060a09080838184600060046012f1505081517fffff000000000000000000000000000000000000000000000000000000000000169091525050604051610160819003945092505050f3`,
+		`[{"constant":true,"inputs":[],"name":"getter","outputs":[{"name":"","type":"string"},{"name":"","type":"int256"},{"name":"","type":"bytes32"}],"type":"function"}]`,
+		`
+		// Generate a new random account and a funded simulator
+		key, _ := crypto.GenerateKey()
+		auth := bind.NewKeyedTransactor(key)
+		sim := backends.NewSimulatedBackend(core.GenesisAccount{Address: auth.From, Balance: big.NewInt(10000000000)})
+
+		// Deploy a getter tester contract and execute a plain value call on it
+		_, _, getter, err := DeployGetter(auth, sim)
+		if err != nil {
+			t.Fatalf("Failed to deploy getter contract: %v", err)
+		}
+		sim.Commit()
+
+		if str, num, _, err := getter.Getter(nil); err != nil {
+			t.Fatalf("Failed to call anonymous field retriever: %v", err)
+		} else if str != "Hi" || num.Cmp(big.NewInt(1)) != 0 {
+			t.Fatalf("Retrieved value mismatch: have %v/%v, want %v/%v", str, num, "Hi", 1)
+		}
+	`,
+	},
 	// Tests that tuples can be properly returned and deserialized
 	{
 		`Tupler`,
 		`
 		contract Tupler {
-			function tuple() returns (string a, int b, bytes32 c) {
+			function tuple() constant returns (string a, int b, bytes32 c) {
 				return ("Hi", 1, sha3(""));
 			}
 		}
@@ -219,8 +251,10 @@ var bindTests = []struct {
 		}
 		sim.Commit()
 
-		if _, err := tupler.Tuple(nil); err != nil {
+		if res, err := tupler.Tuple(nil); err != nil {
 			t.Fatalf("Failed to call structure retriever: %v", err)
+		} else if res.A != "Hi" || res.B.Cmp(big.NewInt(1)) != 0 {
+			t.Fatalf("Retrieved value mismatch: have %v/%v, want %v/%v", res.A, res.B, "Hi", 1)
 		}
 	`,
 	},
diff --git a/accounts/abi/bind/template.go b/accounts/abi/bind/template.go
index 36ac1d78d0..72998bb6d8 100644
--- a/accounts/abi/bind/template.go
+++ b/accounts/abi/bind/template.go
@@ -211,7 +211,7 @@ package {{.Package}}
 		{{range $i, $_ := .Normalized.Outputs}}ret{{$i}} = new({{bindtype .Type}})
 		{{end}}
 	){{end}}
-	out := {{if .Structured}}ret{{else}}{{if eq (len .Normalized.Outputs) 1}}ret0{{else}}[]interface{}{
+	out := {{if .Structured}}ret{{else}}{{if eq (len .Normalized.Outputs) 1}}ret0{{else}}&[]interface{}{
 		{{range $i, $_ := .Normalized.Outputs}}ret{{$i}},
 		{{end}}
 	}{{end}}{{end}}
diff --git a/accounts/abi/reflect.go b/accounts/abi/reflect.go
index 780c64c66e..ab5020200e 100644
--- a/accounts/abi/reflect.go
+++ b/accounts/abi/reflect.go
@@ -16,7 +16,10 @@
 
 package abi
 
-import "reflect"
+import (
+	"fmt"
+	"reflect"
+)
 
 // indirect recursively dereferences the value until it either gets the value
 // or finds a big.Int
@@ -62,3 +65,33 @@ func mustArrayToByteSlice(value reflect.Value) reflect.Value {
 	reflect.Copy(slice, value)
 	return slice
 }
+
+// set attempts to assign src to dst by either setting, copying or otherwise.
+//
+// set is a bit more lenient when it comes to assignment and doesn't enforce
+// as strict a ruleset as bare `reflect` does.
+func set(dst, src reflect.Value, output Argument) error {
+	dstType := dst.Type()
+	srcType := src.Type()
+
+	switch {
+	case dstType.AssignableTo(src.Type()):
+		dst.Set(src)
+	case dstType.Kind() == reflect.Array && srcType.Kind() == reflect.Slice:
+		if !dstType.Elem().AssignableTo(r_byte) {
+			return fmt.Errorf("abi: cannot unmarshal %v in to array of elem %v", src.Type(), dstType.Elem())
+		}
+
+		if dst.Len() < output.Type.SliceSize {
+			return fmt.Errorf("abi: cannot unmarshal src (len=%d) in to dst (len=%d)", output.Type.SliceSize, dst.Len())
+		}
+		reflect.Copy(dst, src)
+	case dstType.Kind() == reflect.Interface:
+		dst.Set(src)
+	case dstType.Kind() == reflect.Ptr:
+		return set(dst.Elem(), src, output)
+	default:
+		return fmt.Errorf("abi: cannot unmarshal %v in to %v", src.Type(), dst.Type())
+	}
+	return nil
+}
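The relocated set helper gains a reflect.Ptr case that recurses through pointers until it reaches a settable destination, which is what lets Unpack fill the pre-allocated pointers inside an interface slice. An illustrative, standalone sketch of the mechanics (not the package's code):

	// Hedged sketch: dst wraps a *uint8, src holds a uint8. A plain dst.Set(src)
	// would panic on the kind mismatch; dereferencing the pointer first is what
	// the new reflect.Ptr case effectively performs.
	v := new(uint8)
	dst := reflect.ValueOf(v)         // Kind() == reflect.Ptr, not directly settable
	src := reflect.ValueOf(uint8(42))
	dst.Elem().Set(src)               // assign through the pointer instead
	fmt.Println(*v)                   // 42
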
diff --git a/cmd/geth/main.go b/cmd/geth/main.go
index a6c36582c9..5a2fc62873 100644
--- a/cmd/geth/main.go
+++ b/cmd/geth/main.go
@@ -50,7 +50,7 @@ const (
 	clientIdentifier = "Geth" // Client identifier to advertise over the network
 	versionMajor     = 1      // Major version component of the current release
 	versionMinor     = 4      // Minor version component of the current release
-	versionPatch     = 3      // Patch version component of the current release
+	versionPatch     = 4      // Patch version component of the current release
 	versionMeta      = "stable" // Version metadata to append to the version string
 	versionOracle    = "0xfa7b9770ca4cb04296cac84f37736d4041251cdf" // Ethereum address of the Geth release oracle
diff --git a/core/block_validator.go b/core/block_validator.go
index 555c5ee064..801d2572b6 100644
--- a/core/block_validator.go
+++ b/core/block_validator.go
@@ -292,7 +292,7 @@ func calcDifficultyHomestead(time, parentTime uint64, parentNumber, parentDiff *
 
 	// minimum difficulty can ever be (before exponential factor)
 	if x.Cmp(params.MinimumDifficulty) < 0 {
-		x = params.MinimumDifficulty
+		x.Set(params.MinimumDifficulty)
 	}
 	// for the exponential factor
@@ -325,7 +325,7 @@ func calcDifficultyFrontier(time, parentTime uint64, parentNumber, parentDiff *b
 		diff.Sub(parentDiff, adjust)
 	}
 	if diff.Cmp(params.MinimumDifficulty) < 0 {
-		diff = params.MinimumDifficulty
+		diff.Set(params.MinimumDifficulty)
 	}
 	periodCount := new(big.Int).Add(parentNumber, common.Big1)
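The two block_validator.go fixes swap pointer assignment for a value copy. `x = params.MinimumDifficulty` makes x alias the package-level minimum-difficulty big.Int, so any later in-place arithmetic through x would silently mutate the shared constant; `x.Set(...)` copies the value into x's own storage instead. A standalone sketch of the hazard (the stand-in constant is illustrative, not project code):

	// min stands in for params.MinimumDifficulty (2^17 = 131072 on mainnet).
	min := big.NewInt(131072)
	x := new(big.Int)
	x = min                    // alias: x and min now point at the same Int
	x.Add(x, big.NewInt(1))    // in-place op silently corrupts the "constant"
	fmt.Println(min)           // 131073

	min = big.NewInt(131072)   // reset
	y := new(big.Int).Set(min) // copy: y owns its backing storage
	y.Add(y, big.NewInt(1))
	fmt.Println(min, y)        // 131072 131073
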
MaxStateFetch = 384 // Amount of node state values to allow fetching per request + MaxForkAncestry = 3 * params.EpochDuration.Uint64() // Maximum chain reorganisation + hashTTL = 3 * time.Second // [eth/61] Time it takes for a hash request to time out blockTargetRTT = 3 * time.Second / 2 // [eth/61] Target time for completing a block retrieval request blockTTL = 3 * blockTargetRTT // [eth/61] Maximum time allowance before a block request is considered expired @@ -79,6 +82,7 @@ var ( errEmptyHeaderSet = errors.New("empty header set by peer") errPeersUnavailable = errors.New("no peers available or all tried for download") errAlreadyInPool = errors.New("hash already in pool") + errInvalidAncestor = errors.New("retrieved ancestor is invalid") errInvalidChain = errors.New("retrieved hash chain is invalid") errInvalidBlock = errors.New("retrieved block is invalid") errInvalidBody = errors.New("retrieved block body is invalid") @@ -266,7 +270,7 @@ func (d *Downloader) Synchronise(id string, head common.Hash, td *big.Int, mode case errBusy: glog.V(logger.Detail).Infof("Synchronisation already in progress") - case errTimeout, errBadPeer, errStallingPeer, errEmptyHashSet, errEmptyHeaderSet, errPeersUnavailable, errInvalidChain: + case errTimeout, errBadPeer, errStallingPeer, errEmptyHashSet, errEmptyHeaderSet, errPeersUnavailable, errInvalidAncestor, errInvalidChain: glog.V(logger.Debug).Infof("Removing peer %v: %v", id, err) d.dropPeer(id) @@ -353,7 +357,7 @@ func (d *Downloader) syncWithPeer(p *peer, hash common.Hash, td *big.Int) (err e if err != nil { return err } - origin, err := d.findAncestor61(p) + origin, err := d.findAncestor61(p, latest) if err != nil { return err } @@ -380,7 +384,7 @@ func (d *Downloader) syncWithPeer(p *peer, hash common.Hash, td *big.Int) (err e if err != nil { return err } - origin, err := d.findAncestor(p) + origin, err := d.findAncestor(p, latest) if err != nil { return err } @@ -536,11 +540,19 @@ func (d *Downloader) fetchHeight61(p *peer) (uint64, error) { // on the correct chain, checking the top N blocks should already get us a match. // In the rare scenario when we ended up on a long reorganisation (i.e. none of // the head blocks match), we do a binary search to find the common ancestor. 
-func (d *Downloader) findAncestor61(p *peer) (uint64, error) { +func (d *Downloader) findAncestor61(p *peer, height uint64) (uint64, error) { glog.V(logger.Debug).Infof("%v: looking for common ancestor", p) - // Request out head blocks to short circuit ancestor location - head := d.headBlock().NumberU64() + // Figure out the valid ancestor range to prevent rewrite attacks + floor, ceil := int64(-1), d.headBlock().NumberU64() + if ceil >= MaxForkAncestry { + floor = int64(ceil - MaxForkAncestry) + } + // Request the topmost blocks to short circuit binary ancestor lookup + head := ceil + if head > height { + head = height + } from := int64(head) - int64(MaxHashFetch) + 1 if from < 0 { from = 0 @@ -600,11 +612,18 @@ func (d *Downloader) findAncestor61(p *peer) (uint64, error) { } // If the head fetch already found an ancestor, return if !common.EmptyHash(hash) { + if int64(number) <= floor { + glog.V(logger.Warn).Infof("%v: potential rewrite attack: #%d [%x…] <= #%d limit", p, number, hash[:4], floor) + return 0, errInvalidAncestor + } glog.V(logger.Debug).Infof("%v: common ancestor: #%d [%x…]", p, number, hash[:4]) return number, nil } // Ancestor not found, we need to binary search over our chain start, end := uint64(0), head + if floor > 0 { + start = uint64(floor) + } for start+1 < end { // Split our chain interval in two, and request the hash to cross check check := (start + end) / 2 @@ -660,6 +679,12 @@ func (d *Downloader) findAncestor61(p *peer) (uint64, error) { } } } + // Ensure valid ancestry and return + if int64(start) <= floor { + glog.V(logger.Warn).Infof("%v: potential rewrite attack: #%d [%x…] <= #%d limit", p, start, hash[:4], floor) + return 0, errInvalidAncestor + } + glog.V(logger.Debug).Infof("%v: common ancestor: #%d [%x…]", p, start, hash[:4]) return start, nil } @@ -961,15 +986,23 @@ func (d *Downloader) fetchHeight(p *peer) (uint64, error) { // on the correct chain, checking the top N links should already get us a match. // In the rare scenario when we ended up on a long reorganisation (i.e. none of // the head links match), we do a binary search to find the common ancestor. 
-func (d *Downloader) findAncestor(p *peer) (uint64, error) { +func (d *Downloader) findAncestor(p *peer, height uint64) (uint64, error) { glog.V(logger.Debug).Infof("%v: looking for common ancestor", p) - // Request our head headers to short circuit ancestor location - head := d.headHeader().Number.Uint64() + // Figure out the valid ancestor range to prevent rewrite attacks + floor, ceil := int64(-1), d.headHeader().Number.Uint64() if d.mode == FullSync { - head = d.headBlock().NumberU64() + ceil = d.headBlock().NumberU64() } else if d.mode == FastSync { - head = d.headFastBlock().NumberU64() + ceil = d.headFastBlock().NumberU64() + } + if ceil >= MaxForkAncestry { + floor = int64(ceil - MaxForkAncestry) + } + // Request the topmost blocks to short circuit binary ancestor lookup + head := ceil + if head > height { + head = height } from := int64(head) - int64(MaxHeaderFetch) + 1 if from < 0 { @@ -1040,11 +1073,18 @@ func (d *Downloader) findAncestor(p *peer) (uint64, error) { } // If the head fetch already found an ancestor, return if !common.EmptyHash(hash) { + if int64(number) <= floor { + glog.V(logger.Warn).Infof("%v: potential rewrite attack: #%d [%x…] <= #%d limit", p, number, hash[:4], floor) + return 0, errInvalidAncestor + } glog.V(logger.Debug).Infof("%v: common ancestor: #%d [%x…]", p, number, hash[:4]) return number, nil } // Ancestor not found, we need to binary search over our chain start, end := uint64(0), head + if floor > 0 { + start = uint64(floor) + } for start+1 < end { // Split our chain interval in two, and request the hash to cross check check := (start + end) / 2 @@ -1100,6 +1140,12 @@ func (d *Downloader) findAncestor(p *peer) (uint64, error) { } } } + // Ensure valid ancestry and return + if int64(start) <= floor { + glog.V(logger.Warn).Infof("%v: potential rewrite attack: #%d [%x…] <= #%d limit", p, start, hash[:4], floor) + return 0, errInvalidAncestor + } + glog.V(logger.Debug).Infof("%v: common ancestor: #%d [%x…]", p, start, hash[:4]) return start, nil } diff --git a/eth/downloader/downloader_test.go b/eth/downloader/downloader_test.go index e66a902643..b0b0c2bd32 100644 --- a/eth/downloader/downloader_test.go +++ b/eth/downloader/downloader_test.go @@ -43,8 +43,9 @@ var ( genesis = core.GenesisBlockForTesting(testdb, testAddress, big.NewInt(1000000000)) ) -// Reduce the block cache limit, otherwise the tests will be very heavy. +// Reduce some of the parameters to make the tester faster. func init() { + MaxForkAncestry = uint64(10000) blockCacheLimit = 1024 } @@ -52,11 +53,15 @@ func init() { // the returned hash chain is ordered head->parent. In addition, every 3rd block // contains a transaction and every 5th an uncle to allow testing correct block // reassembly. 
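Both findAncestor variants now bound the ancestor search from two sides: the ceiling is our own chain head (capped by the peer's advertised height, so we never probe above what the peer claims to have), and the floor sits MaxForkAncestry blocks below it. Any agreed ancestor at or below the floor is treated as an attempted deep-history rewrite and fails with errInvalidAncestor. A condensed, hypothetical restatement of the window computation (the names ourHead, peerHeight and maxForkAncestry are stand-ins, not the patch's identifiers):

	// ancestryWindow mirrors the bound computation added to findAncestor/findAncestor61:
	// the floor guards against deep-history rewrites, head caps the initial probe.
	func ancestryWindow(ourHead, peerHeight, maxForkAncestry uint64) (floor int64, head uint64) {
		floor = -1
		if ourHead >= maxForkAncestry {
			floor = int64(ourHead - maxForkAncestry)
		}
		head = ourHead
		if head > peerHeight { // never probe beyond what the peer advertises
			head = peerHeight
		}
		return floor, head // search in (floor, head]; an ancestor <= floor is rejected
	}
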
diff --git a/eth/downloader/downloader_test.go b/eth/downloader/downloader_test.go
index e66a902643..b0b0c2bd32 100644
--- a/eth/downloader/downloader_test.go
+++ b/eth/downloader/downloader_test.go
@@ -43,8 +43,9 @@ var (
 	genesis = core.GenesisBlockForTesting(testdb, testAddress, big.NewInt(1000000000))
 )
 
-// Reduce the block cache limit, otherwise the tests will be very heavy.
+// Reduce some of the parameters to make the tester faster.
 func init() {
+	MaxForkAncestry = uint64(10000)
 	blockCacheLimit = 1024
 }
 
@@ -52,11 +53,15 @@ func init() {
 // the returned hash chain is ordered head->parent. In addition, every 3rd block
 // contains a transaction and every 5th an uncle to allow testing correct block
 // reassembly.
-func makeChain(n int, seed byte, parent *types.Block, parentReceipts types.Receipts) ([]common.Hash, map[common.Hash]*types.Header, map[common.Hash]*types.Block, map[common.Hash]types.Receipts) {
+func makeChain(n int, seed byte, parent *types.Block, parentReceipts types.Receipts, heavy bool) ([]common.Hash, map[common.Hash]*types.Header, map[common.Hash]*types.Block, map[common.Hash]types.Receipts) {
 	// Generate the block chain
 	blocks, receipts := core.GenerateChain(parent, testdb, n, func(i int, block *core.BlockGen) {
 		block.SetCoinbase(common.Address{seed})
 
+		// If a heavy chain is requested, delay blocks to raise difficulty
+		if heavy {
+			block.OffsetTime(-1)
+		}
 		// If the block number is multiple of 3, send a bonus transaction to the miner
 		if parent == genesis && i%3 == 0 {
 			tx, err := types.NewTransaction(block.TxNonce(testAddress), common.Address{seed}, big.NewInt(1000), params.TxGas, nil, nil).SignECDSA(testKey)
@@ -97,15 +102,19 @@ func makeChain(n int, seed byte, parent *types.Block, parentReceipts types.Recei
 
 // makeChainFork creates two chains of length n, such that h1[:f] and
 // h2[:f] are different but have a common suffix of length n-f.
-func makeChainFork(n, f int, parent *types.Block, parentReceipts types.Receipts) ([]common.Hash, []common.Hash, map[common.Hash]*types.Header, map[common.Hash]*types.Header, map[common.Hash]*types.Block, map[common.Hash]*types.Block, map[common.Hash]types.Receipts, map[common.Hash]types.Receipts) {
+func makeChainFork(n, f int, parent *types.Block, parentReceipts types.Receipts, balanced bool) ([]common.Hash, []common.Hash, map[common.Hash]*types.Header, map[common.Hash]*types.Header, map[common.Hash]*types.Block, map[common.Hash]*types.Block, map[common.Hash]types.Receipts, map[common.Hash]types.Receipts) {
 	// Create the common suffix
-	hashes, headers, blocks, receipts := makeChain(n-f, 0, parent, parentReceipts)
+	hashes, headers, blocks, receipts := makeChain(n-f, 0, parent, parentReceipts, false)
 
-	// Create the forks
-	hashes1, headers1, blocks1, receipts1 := makeChain(f, 1, blocks[hashes[0]], receipts[hashes[0]])
+	// Create the forks, making the second heavier if unbalanced forks were requested
+	hashes1, headers1, blocks1, receipts1 := makeChain(f, 1, blocks[hashes[0]], receipts[hashes[0]], false)
 	hashes1 = append(hashes1, hashes[1:]...)
 
-	hashes2, headers2, blocks2, receipts2 := makeChain(f, 2, blocks[hashes[0]], receipts[hashes[0]])
+	heavy := false
+	if !balanced {
+		heavy = true
+	}
+	hashes2, headers2, blocks2, receipts2 := makeChain(f, 2, blocks[hashes[0]], receipts[hashes[0]], heavy)
 	hashes2 = append(hashes2, hashes[1:]...)
 
 	for hash, header := range headers {
@@ -712,7 +721,7 @@ func testCanonicalSynchronisation(t *testing.T, protocol int, mode SyncMode) {
 
 	// Create a small enough block chain to download
 	targetBlocks := blockCacheLimit - 15
-	hashes, headers, blocks, receipts := makeChain(targetBlocks, 0, genesis, nil)
+	hashes, headers, blocks, receipts := makeChain(targetBlocks, 0, genesis, nil, false)
 
 	tester := newTester()
 	tester.newPeer("peer", protocol, hashes, headers, blocks, receipts)
@@ -736,7 +745,7 @@ func TestThrottling64Fast(t *testing.T) { testThrottling(t, 64, FastSync) }
 
 func testThrottling(t *testing.T, protocol int, mode SyncMode) {
 	// Create a long block chain to download and the tester
 	targetBlocks := 8 * blockCacheLimit
-	hashes, headers, blocks, receipts := makeChain(targetBlocks, 0, genesis, nil)
+	hashes, headers, blocks, receipts := makeChain(targetBlocks, 0, genesis, nil, false)
 
 	tester := newTester()
 	tester.newPeer("peer", protocol, hashes, headers, blocks, receipts)
@@ -810,20 +819,20 @@ func testThrottling(t *testing.T, protocol int, mode SyncMode) {
 // Tests that simple synchronization against a forked chain works correctly. In
 // this test common ancestor lookup should *not* be short circuited, and a full
 // binary search should be executed.
-func TestForkedSynchronisation61(t *testing.T)      { testForkedSynchronisation(t, 61, FullSync) }
-func TestForkedSynchronisation62(t *testing.T)      { testForkedSynchronisation(t, 62, FullSync) }
-func TestForkedSynchronisation63Full(t *testing.T)  { testForkedSynchronisation(t, 63, FullSync) }
-func TestForkedSynchronisation63Fast(t *testing.T)  { testForkedSynchronisation(t, 63, FastSync) }
-func TestForkedSynchronisation64Full(t *testing.T)  { testForkedSynchronisation(t, 64, FullSync) }
-func TestForkedSynchronisation64Fast(t *testing.T)  { testForkedSynchronisation(t, 64, FastSync) }
-func TestForkedSynchronisation64Light(t *testing.T) { testForkedSynchronisation(t, 64, LightSync) }
+func TestForkedSync61(t *testing.T)      { testForkedSync(t, 61, FullSync) }
+func TestForkedSync62(t *testing.T)      { testForkedSync(t, 62, FullSync) }
+func TestForkedSync63Full(t *testing.T)  { testForkedSync(t, 63, FullSync) }
+func TestForkedSync63Fast(t *testing.T)  { testForkedSync(t, 63, FastSync) }
+func TestForkedSync64Full(t *testing.T)  { testForkedSync(t, 64, FullSync) }
+func TestForkedSync64Fast(t *testing.T)  { testForkedSync(t, 64, FastSync) }
+func TestForkedSync64Light(t *testing.T) { testForkedSync(t, 64, LightSync) }
 
-func testForkedSynchronisation(t *testing.T, protocol int, mode SyncMode) {
+func testForkedSync(t *testing.T, protocol int, mode SyncMode) {
 	t.Parallel()
 
 	// Create a long enough forked chain
 	common, fork := MaxHashFetch, 2*MaxHashFetch
-	hashesA, hashesB, headersA, headersB, blocksA, blocksB, receiptsA, receiptsB := makeChainFork(common+fork, fork, genesis, nil)
+	hashesA, hashesB, headersA, headersB, blocksA, blocksB, receiptsA, receiptsB := makeChainFork(common+fork, fork, genesis, nil, true)
 
 	tester := newTester()
 	tester.newPeer("fork A", protocol, hashesA, headersA, blocksA, receiptsA)
@@ -842,6 +851,40 @@ func testForkedSynchronisation(t *testing.T, protocol int, mode SyncMode) {
 	assertOwnForkedChain(t, tester, common+1, []int{common + fork + 1, common + fork + 1})
 }
 
+// Tests that synchronising against a much shorter but much heavier fork works
+// correctly and is not dropped.
+func TestHeavyForkedSync61(t *testing.T)      { testHeavyForkedSync(t, 61, FullSync) }
+func TestHeavyForkedSync62(t *testing.T)      { testHeavyForkedSync(t, 62, FullSync) }
+func TestHeavyForkedSync63Full(t *testing.T)  { testHeavyForkedSync(t, 63, FullSync) }
+func TestHeavyForkedSync63Fast(t *testing.T)  { testHeavyForkedSync(t, 63, FastSync) }
+func TestHeavyForkedSync64Full(t *testing.T)  { testHeavyForkedSync(t, 64, FullSync) }
+func TestHeavyForkedSync64Fast(t *testing.T)  { testHeavyForkedSync(t, 64, FastSync) }
+func TestHeavyForkedSync64Light(t *testing.T) { testHeavyForkedSync(t, 64, LightSync) }
+
+func testHeavyForkedSync(t *testing.T, protocol int, mode SyncMode) {
+	t.Parallel()
+
+	// Create a long enough forked chain
+	common, fork := MaxHashFetch, 4*MaxHashFetch
+	hashesA, hashesB, headersA, headersB, blocksA, blocksB, receiptsA, receiptsB := makeChainFork(common+fork, fork, genesis, nil, false)
+
+	tester := newTester()
+	tester.newPeer("light", protocol, hashesA, headersA, blocksA, receiptsA)
+	tester.newPeer("heavy", protocol, hashesB[fork/2:], headersB, blocksB, receiptsB)
+
+	// Synchronise with the peer and make sure all blocks were retrieved
+	if err := tester.sync("light", nil, mode); err != nil {
+		t.Fatalf("failed to synchronise blocks: %v", err)
+	}
+	assertOwnChain(t, tester, common+fork+1)
+
+	// Synchronise with the second peer and make sure that fork is pulled too
+	if err := tester.sync("heavy", nil, mode); err != nil {
+		t.Fatalf("failed to synchronise blocks: %v", err)
+	}
+	assertOwnForkedChain(t, tester, common+1, []int{common + fork + 1, common + fork/2 + 1})
+}
+
 // Tests that an inactive downloader will not accept incoming hashes and blocks.
 func TestInactiveDownloader61(t *testing.T) {
 	t.Parallel()
@@ -856,6 +899,74 @@
 	}
 }
 
+// Tests that chain forks are contained within a certain interval of the current
+// chain head, ensuring that malicious peers cannot waste resources by feeding
+// long dead chains.
+func TestBoundedForkedSync61(t *testing.T)      { testBoundedForkedSync(t, 61, FullSync) }
+func TestBoundedForkedSync62(t *testing.T)      { testBoundedForkedSync(t, 62, FullSync) }
+func TestBoundedForkedSync63Full(t *testing.T)  { testBoundedForkedSync(t, 63, FullSync) }
+func TestBoundedForkedSync63Fast(t *testing.T)  { testBoundedForkedSync(t, 63, FastSync) }
+func TestBoundedForkedSync64Full(t *testing.T)  { testBoundedForkedSync(t, 64, FullSync) }
+func TestBoundedForkedSync64Fast(t *testing.T)  { testBoundedForkedSync(t, 64, FastSync) }
+func TestBoundedForkedSync64Light(t *testing.T) { testBoundedForkedSync(t, 64, LightSync) }
+
+func testBoundedForkedSync(t *testing.T, protocol int, mode SyncMode) {
+	t.Parallel()
+
+	// Create a long enough forked chain
+	common, fork := 13, int(MaxForkAncestry+17)
+	hashesA, hashesB, headersA, headersB, blocksA, blocksB, receiptsA, receiptsB := makeChainFork(common+fork, fork, genesis, nil, true)
+
+	tester := newTester()
+	tester.newPeer("original", protocol, hashesA, headersA, blocksA, receiptsA)
+	tester.newPeer("rewriter", protocol, hashesB, headersB, blocksB, receiptsB)
+
+	// Synchronise with the peer and make sure all blocks were retrieved
+	if err := tester.sync("original", nil, mode); err != nil {
+		t.Fatalf("failed to synchronise blocks: %v", err)
+	}
+	assertOwnChain(t, tester, common+fork+1)
+
+	// Synchronise with the second peer and ensure that the fork is rejected for being too old
+	if err := tester.sync("rewriter", nil, mode); err != errInvalidAncestor {
+		t.Fatalf("sync failure mismatch: have %v, want %v", err, errInvalidAncestor)
+	}
+}
+
+// Tests that chain forks are contained within a certain interval of the current
+// chain head for short but heavy forks too. These are a bit special because they
+// take different ancestor lookup paths.
+func TestBoundedHeavyForkedSync61(t *testing.T)      { testBoundedHeavyForkedSync(t, 61, FullSync) }
+func TestBoundedHeavyForkedSync62(t *testing.T)      { testBoundedHeavyForkedSync(t, 62, FullSync) }
+func TestBoundedHeavyForkedSync63Full(t *testing.T)  { testBoundedHeavyForkedSync(t, 63, FullSync) }
+func TestBoundedHeavyForkedSync63Fast(t *testing.T)  { testBoundedHeavyForkedSync(t, 63, FastSync) }
+func TestBoundedHeavyForkedSync64Full(t *testing.T)  { testBoundedHeavyForkedSync(t, 64, FullSync) }
+func TestBoundedHeavyForkedSync64Fast(t *testing.T)  { testBoundedHeavyForkedSync(t, 64, FastSync) }
+func TestBoundedHeavyForkedSync64Light(t *testing.T) { testBoundedHeavyForkedSync(t, 64, LightSync) }
+
+func testBoundedHeavyForkedSync(t *testing.T, protocol int, mode SyncMode) {
+	t.Parallel()
+
+	// Create a long enough forked chain
+	common, fork := 13, int(MaxForkAncestry+17)
+	hashesA, hashesB, headersA, headersB, blocksA, blocksB, receiptsA, receiptsB := makeChainFork(common+fork, fork, genesis, nil, false)
+
+	tester := newTester()
+	tester.newPeer("original", protocol, hashesA, headersA, blocksA, receiptsA)
+	tester.newPeer("heavy-rewriter", protocol, hashesB[MaxForkAncestry-17:], headersB, blocksB, receiptsB) // Root the fork below the ancestor limit
+
+	// Synchronise with the peer and make sure all blocks were retrieved
+	if err := tester.sync("original", nil, mode); err != nil {
+		t.Fatalf("failed to synchronise blocks: %v", err)
+	}
+	assertOwnChain(t, tester, common+fork+1)
+
+	// Synchronise with the second peer and ensure that the fork is rejected for being too old
+	if err := tester.sync("heavy-rewriter", nil, mode); err != errInvalidAncestor {
+		t.Fatalf("sync failure mismatch: have %v, want %v", err, errInvalidAncestor)
+	}
+}
+
 // Tests that an inactive downloader will not accept incoming block headers and
 // bodies.
 func TestInactiveDownloader62(t *testing.T) {
@@ -909,7 +1020,7 @@ func testCancel(t *testing.T, protocol int, mode SyncMode) {
 	if targetBlocks >= MaxHeaderFetch {
 		targetBlocks = MaxHeaderFetch - 15
 	}
-	hashes, headers, blocks, receipts := makeChain(targetBlocks, 0, genesis, nil)
+	hashes, headers, blocks, receipts := makeChain(targetBlocks, 0, genesis, nil, false)
 
 	tester := newTester()
 	tester.newPeer("peer", protocol, hashes, headers, blocks, receipts)
@@ -944,7 +1055,7 @@ func testMultiSynchronisation(t *testing.T, protocol int, mode SyncMode) {
 	// Create various peers with various parts of the chain
 	targetPeers := 8
 	targetBlocks := targetPeers*blockCacheLimit - 15
-	hashes, headers, blocks, receipts := makeChain(targetBlocks, 0, genesis, nil)
+	hashes, headers, blocks, receipts := makeChain(targetBlocks, 0, genesis, nil, false)
 
 	tester := newTester()
 	for i := 0; i < targetPeers; i++ {
@@ -972,7 +1083,7 @@ func testMultiProtoSync(t *testing.T, protocol int, mode SyncMode) {
 
 	// Create a small enough block chain to download
 	targetBlocks := blockCacheLimit - 15
-	hashes, headers, blocks, receipts := makeChain(targetBlocks, 0, genesis, nil)
+	hashes, headers, blocks, receipts := makeChain(targetBlocks, 0, genesis, nil, false)
 
 	// Create peers of every type
 	tester := newTester()
@@ -1010,7 +1121,7 @@ func testEmptyShortCircuit(t *testing.T, protocol int, mode SyncMode) {
 
 	// Create a block chain to download
 	targetBlocks := 2*blockCacheLimit - 15
-	hashes, headers, blocks, receipts := makeChain(targetBlocks, 0, genesis, nil)
+	hashes, headers, blocks, receipts := makeChain(targetBlocks, 0, genesis, nil, false)
 
 	tester := newTester()
 	tester.newPeer("peer", protocol, hashes, headers, blocks, receipts)
@@ -1063,7 +1174,7 @@ func testMissingHeaderAttack(t *testing.T, protocol int, mode SyncMode) {
 
 	// Create a small enough block chain to download
 	targetBlocks := blockCacheLimit - 15
-	hashes, headers, blocks, receipts := makeChain(targetBlocks, 0, genesis, nil)
+	hashes, headers, blocks, receipts := makeChain(targetBlocks, 0, genesis, nil, false)
 
 	tester := newTester()
@@ -1095,7 +1206,7 @@ func TestShiftedHeaderAttack64Light(t *testing.T) { testShiftedHeaderAttack(t, 6
 
 func testShiftedHeaderAttack(t *testing.T, protocol int, mode SyncMode) {
 	// Create a small enough block chain to download
 	targetBlocks := blockCacheLimit - 15
-	hashes, headers, blocks, receipts := makeChain(targetBlocks, 0, genesis, nil)
+	hashes, headers, blocks, receipts := makeChain(targetBlocks, 0, genesis, nil, false)
 
 	tester := newTester()
@@ -1126,7 +1237,7 @@ func TestInvalidHeaderRollback64Light(t *testing.T) { testInvalidHeaderRollback(
 
 func testInvalidHeaderRollback(t *testing.T, protocol int, mode SyncMode) {
 	// Create a small enough block chain to download
 	targetBlocks := 3*fsHeaderSafetyNet + fsMinFullBlocks
-	hashes, headers, blocks, receipts := makeChain(targetBlocks, 0, genesis, nil)
+	hashes, headers, blocks, receipts := makeChain(targetBlocks, 0, genesis, nil, false)
 
 	tester := newTester()
@@ -1217,7 +1328,7 @@ func testHighTDStarvationAttack(t *testing.T, protocol int, mode SyncMode) {
 	t.Parallel()
 
 	tester := newTester()
-	hashes, headers, blocks, receipts := makeChain(0, 0, genesis, nil)
+	hashes, headers, blocks, receipts := makeChain(0, 0, genesis, nil, false)
 
 	tester.newPeer("attack", protocol, []common.Hash{hashes[0]}, headers, blocks, receipts)
 	if err := tester.sync("attack", big.NewInt(1000000), mode); err != errStallingPeer {
@@ -1247,6 +1358,7 @@ func testBlockHeaderAttackerDropping(t *testing.T, protocol int) {
 		{errEmptyHashSet, true},     // No hashes were returned as a response, drop as it's a dead end
 		{errEmptyHeaderSet, true},   // No headers were returned as a response, drop as it's a dead end
 		{errPeersUnavailable, true}, // Nobody had the advertised blocks, drop the advertiser
+		{errInvalidAncestor, true},  // Agreed upon ancestor is not acceptable, drop the chain rewriter
 		{errInvalidChain, true},     // Hash chain was detected as invalid, definitely drop
 		{errInvalidBlock, false},    // A bad peer was detected, but not the sync origin
 		{errInvalidBody, false},     // A bad peer was detected, but not the sync origin
@@ -1294,7 +1406,7 @@ func testSyncProgress(t *testing.T, protocol int, mode SyncMode) {
 
 	// Create a small enough block chain to download
 	targetBlocks := blockCacheLimit - 15
-	hashes, headers, blocks, receipts := makeChain(targetBlocks, 0, genesis, nil)
+	hashes, headers, blocks, receipts := makeChain(targetBlocks, 0, genesis, nil, false)
 
 	// Set a sync init hook to catch progress changes
 	starting := make(chan struct{})
@@ -1366,7 +1478,7 @@ func testForkedSyncProgress(t *testing.T, protocol int, mode SyncMode) {
 
 	// Create a forked chain to simulate origin revertal
 	common, fork := MaxHashFetch, 2*MaxHashFetch
-	hashesA, hashesB, headersA, headersB, blocksA, blocksB, receiptsA, receiptsB := makeChainFork(common+fork, fork, genesis, nil)
+	hashesA, hashesB, headersA, headersB, blocksA, blocksB, receiptsA, receiptsB := makeChainFork(common+fork, fork, genesis, nil, true)
 
 	// Set a sync init hook to catch progress changes
 	starting := make(chan struct{})
@@ -1441,7 +1553,7 @@ func testFailedSyncProgress(t *testing.T, protocol int, mode SyncMode) {
 
 	// Create a small enough block chain to download
 	targetBlocks := blockCacheLimit - 15
-	hashes, headers, blocks, receipts := makeChain(targetBlocks, 0, genesis, nil)
+	hashes, headers, blocks, receipts := makeChain(targetBlocks, 0, genesis, nil, false)
 
 	// Set a sync init hook to catch progress changes
 	starting := make(chan struct{})
@@ -1517,7 +1629,7 @@ func testFakedSyncProgress(t *testing.T, protocol int, mode SyncMode) {
 
 	// Create a small block chain
 	targetBlocks := blockCacheLimit - 15
-	hashes, headers, blocks, receipts := makeChain(targetBlocks+3, 0, genesis, nil)
+	hashes, headers, blocks, receipts := makeChain(targetBlocks+3, 0, genesis, nil, false)
 
 	// Set a sync init hook to catch progress changes
 	starting := make(chan struct{})
@@ -1590,7 +1702,7 @@ func TestDeliverHeadersHang64Light(t *testing.T) { testDeliverHeadersHang(t, 64,
 
 func testDeliverHeadersHang(t *testing.T, protocol int, mode SyncMode) {
 	t.Parallel()
 
-	hashes, headers, blocks, receipts := makeChain(5, 0, genesis, nil)
+	hashes, headers, blocks, receipts := makeChain(5, 0, genesis, nil, false)
 	fakeHeads := []*types.Header{{}, {}, {}, {}}
 	for i := 0; i < 200; i++ {
 		tester := newTester()
diff --git a/event/event.go b/event/event.go
index 57dd52baa1..fd0bcfbd48 100644
--- a/event/event.go
+++ b/event/event.go
@@ -66,6 +66,9 @@ func (mux *TypeMux) Subscribe(types ...interface{}) Subscription {
 	mux.mutex.Lock()
 	defer mux.mutex.Unlock()
 	if mux.stopped {
+		// set the status to closed so that calling Unsubscribe after this
+		// call will short circuit
+		sub.closed = true
 		close(sub.postC)
 	} else {
 		if mux.subm == nil {
diff --git a/event/event_test.go b/event/event_test.go
index 323cfea49e..3940293013 100644
--- a/event/event_test.go
+++ b/event/event_test.go
@@ -25,6 +25,14 @@ import (
 
 type testEvent int
 
+func TestSubCloseUnsub(t *testing.T) {
+	// the point of this test is **not** to panic
+	var mux TypeMux
+	mux.Stop()
+	sub := mux.Subscribe(int(0))
+	sub.Unsubscribe()
+}
+
 func TestSub(t *testing.T) {
 	mux := new(TypeMux)
 	defer mux.Stop()
diff --git a/rpc/websocket.go b/rpc/websocket.go
index 1303f98db1..fe9354d946 100644
--- a/rpc/websocket.go
+++ b/rpc/websocket.go
@@ -61,22 +61,22 @@ func wsHandshakeValidator(allowedOrigins []string) func(*websocket.Config, *http
 			allowAllOrigins = true
 		}
 		if origin != "" {
-			origins.Add(origin)
+			origins.Add(strings.ToLower(origin))
 		}
 	}
 
-	// allow localhost if no allowedOrigins are specified
+	// allow localhost if no allowedOrigins are specified.
 	if len(origins.List()) == 0 {
 		origins.Add("http://localhost")
 		if hostname, err := os.Hostname(); err == nil {
-			origins.Add("http://" + hostname)
+			origins.Add("http://" + strings.ToLower(hostname))
 		}
 	}
 	glog.V(logger.Debug).Infof("Allowed origin(s) for WS RPC interface %v\n", origins.List())
 
 	f := func(cfg *websocket.Config, req *http.Request) error {
-		origin := req.Header.Get("Origin")
+		origin := strings.ToLower(req.Header.Get("Origin"))
 		if allowAllOrigins || origins.Has(origin) {
 			return nil
 		}
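The websocket change normalizes both sides of the origin comparison to lower case, so a whitelist entry like "http://MyHost" and a browser-sent "http://myhost" now match regardless of how either was capitalized. A tiny standalone sketch of the idea (hypothetical helper, not the rpc package's API):

	// originAllowed compares case-insensitively by lower-casing at both
	// insertion and lookup time, mirroring the patch's approach.
	func originAllowed(allowed map[string]bool, origin string) bool {
		return allowed[strings.ToLower(origin)]
	}
	// setup:  allowed[strings.ToLower("http://MyHost")] = true
	// lookup: originAllowed(allowed, "HTTP://myhost") == true
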