diff --git a/.github/CONTRIBUTING.md b/.github/CONTRIBUTING.md index 829bf5d43b..87801c29f5 100644 --- a/.github/CONTRIBUTING.md +++ b/.github/CONTRIBUTING.md @@ -8,7 +8,7 @@ and help. ## Contributing If you'd like to contribute to go-ethereum please fork, fix, commit and -send a pull request. Commits who do not comply with the coding standards +send a pull request. Commits which do not comply with the coding standards are ignored (use gofmt!). If you send pull requests make absolute sure that you commit on the `develop` branch and that you do not merge to master. Commits that are directly based on master are simply ignored. diff --git a/.travis.yml b/.travis.yml index c1d545c548..24486d4a0a 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,10 +1,12 @@ language: go go: - 1.4.2 + - 1.5.4 + - 1.6.2 install: # - go get code.google.com/p/go.tools/cmd/goimports # - go get github.com/golang/lint/golint - # - go get golang.org/x/tools/cmd/vet + # - go get golang.org/x/tools/cmd/vet - go get golang.org/x/tools/cmd/cover before_script: # - gofmt -l -w . @@ -24,6 +26,6 @@ notifications: webhooks: urls: - https://webhooks.gitter.im/e/e09ccdce1048c5e03445 - on_success: change + on_success: change on_failure: always - on_start: false + on_start: false diff --git a/README.md b/README.md index 70f90975e3..4acb0ff73d 100644 --- a/README.md +++ b/README.md @@ -36,6 +36,10 @@ Once the dependencies are installed, run make geth +or, to build the full suite of utilities: + + make all + ## Executables The go-ethereum project comes with several wrappers/executables found in the `cmd` directory. @@ -58,14 +62,14 @@ anyone on the internet, and are grateful for even the smallest of fixes! If you'd like to contribute to go-ethereum, please fork, fix, commit and send a pull request for the maintainers to review and merge into the main code base. If you wish to submit more complex changes though, please check up with the core devs first on [our gitter channel](https://gitter.im/ethereum/go-ethereum) -to ensure those changes are in line with the general philosopy of the project and/or get some +to ensure those changes are in line with the general philosophy of the project and/or get some early feedback which can make both your efforts much lighter as well as our review and merge procedures quick and simple. -Please make sure your contributions adhere to our coding guidlines: +Please make sure your contributions adhere to our coding guidelines: * Code must adhere to the official Go [formatting](https://golang.org/doc/effective_go.html#formatting) guidelines (i.e. uses [gofmt](https://golang.org/cmd/gofmt/)). - * Code must be documented adherign to the official Go [commentary](https://golang.org/doc/effective_go.html#commentary) guidelines. + * Code must be documented adhering to the official Go [commentary](https://golang.org/doc/effective_go.html#commentary) guidelines. * Pull requests need to be based on and opened against the `develop` branch. * Commit messages should be prefixed with the package(s) they modify. * E.g. "eth, rpc: make trace configs optional" diff --git a/VERSION b/VERSION index e516bb9d96..c514bd85c2 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -1.4.5 +1.4.6 diff --git a/cmd/geth/accountcmd.go b/cmd/geth/accountcmd.go index bf754c72f1..0f9d95c2c5 100644 --- a/cmd/geth/accountcmd.go +++ b/cmd/geth/accountcmd.go @@ -23,6 +23,7 @@ import ( "github.com/codegangsta/cli" "github.com/ethereum/go-ethereum/accounts" "github.com/ethereum/go-ethereum/cmd/utils" + "github.com/ethereum/go-ethereum/console" "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/logger" "github.com/ethereum/go-ethereum/logger/glog" @@ -215,12 +216,12 @@ func getPassPhrase(prompt string, confirmation bool, i int, passwords []string) if prompt != "" { fmt.Println(prompt) } - password, err := utils.Stdin.PasswordPrompt("Passphrase: ") + password, err := console.Stdin.PromptPassword("Passphrase: ") if err != nil { utils.Fatalf("Failed to read passphrase: %v", err) } if confirmation { - confirm, err := utils.Stdin.PasswordPrompt("Repeat passphrase: ") + confirm, err := console.Stdin.PromptPassword("Repeat passphrase: ") if err != nil { utils.Fatalf("Failed to read passphrase confirmation: %v", err) } diff --git a/cmd/geth/chaincmd.go b/cmd/geth/chaincmd.go index 32eacc99ee..4f47de5d70 100644 --- a/cmd/geth/chaincmd.go +++ b/cmd/geth/chaincmd.go @@ -26,6 +26,7 @@ import ( "github.com/codegangsta/cli" "github.com/ethereum/go-ethereum/cmd/utils" "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/console" "github.com/ethereum/go-ethereum/core" "github.com/ethereum/go-ethereum/core/state" "github.com/ethereum/go-ethereum/core/types" @@ -116,7 +117,7 @@ func exportChain(ctx *cli.Context) { } func removeDB(ctx *cli.Context) { - confirm, err := utils.Stdin.ConfirmPrompt("Remove local database?") + confirm, err := console.Stdin.PromptConfirm("Remove local database?") if err != nil { utils.Fatalf("%v", err) } diff --git a/cmd/geth/consolecmd.go b/cmd/geth/consolecmd.go new file mode 100644 index 0000000000..8bfe27fef3 --- /dev/null +++ b/cmd/geth/consolecmd.go @@ -0,0 +1,167 @@ +// Copyright 2016 The go-ethereum Authors +// This file is part of go-ethereum. +// +// go-ethereum is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// go-ethereum is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with go-ethereum. If not, see . + +package main + +import ( + "os" + "os/signal" + + "github.com/codegangsta/cli" + "github.com/ethereum/go-ethereum/cmd/utils" + "github.com/ethereum/go-ethereum/console" +) + +var ( + consoleCommand = cli.Command{ + Action: localConsole, + Name: "console", + Usage: `Geth Console: interactive JavaScript environment`, + Description: ` +The Geth console is an interactive shell for the JavaScript runtime environment +which exposes a node admin interface as well as the Ðapp JavaScript API. +See https://github.com/ethereum/go-ethereum/wiki/Javascipt-Console +`, + } + attachCommand = cli.Command{ + Action: remoteConsole, + Name: "attach", + Usage: `Geth Console: interactive JavaScript environment (connect to node)`, + Description: ` +The Geth console is an interactive shell for the JavaScript runtime environment +which exposes a node admin interface as well as the Ðapp JavaScript API. +See https://github.com/ethereum/go-ethereum/wiki/Javascipt-Console. +This command allows to open a console on a running geth node. + `, + } + javascriptCommand = cli.Command{ + Action: ephemeralConsole, + Name: "js", + Usage: `executes the given JavaScript files in the Geth JavaScript VM`, + Description: ` +The JavaScript VM exposes a node admin interface as well as the Ðapp +JavaScript API. See https://github.com/ethereum/go-ethereum/wiki/Javascipt-Console +`, + } +) + +// localConsole starts a new geth node, attaching a JavaScript console to it at the +// same time. +func localConsole(ctx *cli.Context) { + // Create and start the node based on the CLI flags + node := utils.MakeSystemNode(clientIdentifier, verString, relConfig, makeDefaultExtra(), ctx) + startNode(ctx, node) + defer node.Stop() + + // Attach to the newly started node and start the JavaScript console + client, err := node.Attach() + if err != nil { + utils.Fatalf("Failed to attach to the inproc geth: %v", err) + } + config := console.Config{ + DataDir: node.DataDir(), + DocRoot: ctx.GlobalString(utils.JSpathFlag.Name), + Client: client, + Preload: utils.MakeConsolePreloads(ctx), + } + console, err := console.New(config) + if err != nil { + utils.Fatalf("Failed to start the JavaScript console: %v", err) + } + defer console.Stop(false) + + // If only a short execution was requested, evaluate and return + if script := ctx.GlobalString(utils.ExecFlag.Name); script != "" { + console.Evaluate(script) + return + } + // Otherwise print the welcome screen and enter interactive mode + console.Welcome() + console.Interactive() +} + +// remoteConsole will connect to a remote geth instance, attaching a JavaScript +// console to it. +func remoteConsole(ctx *cli.Context) { + // Attach to a remotely running geth instance and start the JavaScript console + client, err := utils.NewRemoteRPCClient(ctx) + if err != nil { + utils.Fatalf("Unable to attach to remote geth: %v", err) + } + config := console.Config{ + DataDir: utils.MustMakeDataDir(ctx), + DocRoot: ctx.GlobalString(utils.JSpathFlag.Name), + Client: client, + Preload: utils.MakeConsolePreloads(ctx), + } + console, err := console.New(config) + if err != nil { + utils.Fatalf("Failed to start the JavaScript console: %v", err) + } + defer console.Stop(false) + + // If only a short execution was requested, evaluate and return + if script := ctx.GlobalString(utils.ExecFlag.Name); script != "" { + console.Evaluate(script) + return + } + // Otherwise print the welcome screen and enter interactive mode + console.Welcome() + console.Interactive() +} + +// ephemeralConsole starts a new geth node, attaches an ephemeral JavaScript +// console to it, and each of the files specified as arguments and tears the +// everything down. +func ephemeralConsole(ctx *cli.Context) { + // Create and start the node based on the CLI flags + node := utils.MakeSystemNode(clientIdentifier, verString, relConfig, makeDefaultExtra(), ctx) + startNode(ctx, node) + defer node.Stop() + + // Attach to the newly started node and start the JavaScript console + client, err := node.Attach() + if err != nil { + utils.Fatalf("Failed to attach to the inproc geth: %v", err) + } + config := console.Config{ + DataDir: node.DataDir(), + DocRoot: ctx.GlobalString(utils.JSpathFlag.Name), + Client: client, + Preload: utils.MakeConsolePreloads(ctx), + } + console, err := console.New(config) + if err != nil { + utils.Fatalf("Failed to start the JavaScript console: %v", err) + } + defer console.Stop(false) + + // Evaluate each of the specified JavaScript files + for _, file := range ctx.Args() { + if err = console.Execute(file); err != nil { + utils.Fatalf("Failed to execute %s: %v", file, err) + } + } + // Wait for pending callbacks, but stop for Ctrl-C. + abort := make(chan os.Signal, 1) + signal.Notify(abort, os.Interrupt) + + go func() { + <-abort + os.Exit(0) + }() + console.Stop(true) +} diff --git a/cmd/geth/consolecmd_test.go b/cmd/geth/consolecmd_test.go new file mode 100644 index 0000000000..e59fe1415b --- /dev/null +++ b/cmd/geth/consolecmd_test.go @@ -0,0 +1,162 @@ +// Copyright 2016 The go-ethereum Authors +// This file is part of go-ethereum. +// +// go-ethereum is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// go-ethereum is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with go-ethereum. If not, see . + +package main + +import ( + "math/rand" + "os" + "path/filepath" + "runtime" + "sort" + "strconv" + "strings" + "testing" + "time" + + "github.com/ethereum/go-ethereum/rpc" +) + +// Tests that a node embedded within a console can be started up properly and +// then terminated by closing the input stream. +func TestConsoleWelcome(t *testing.T) { + coinbase := "0x8605cdbbdb6d264aa742e77020dcbc58fcdce182" + + // Start a geth console, make sure it's cleaned up and terminate the console + geth := runGeth(t, + "--port", "0", "--maxpeers", "0", "--nodiscover", "--nat", "none", + "--etherbase", coinbase, "--shh", + "console") + + // Gather all the infos the welcome message needs to contain + geth.setTemplateFunc("goos", func() string { return runtime.GOOS }) + geth.setTemplateFunc("gover", runtime.Version) + geth.setTemplateFunc("gethver", func() string { return verString }) + geth.setTemplateFunc("niltime", func() string { return time.Unix(0, 0).Format(time.RFC1123) }) + geth.setTemplateFunc("apis", func() []string { + apis := append(strings.Split(rpc.DefaultIPCApis, ","), rpc.MetadataApi) + sort.Strings(apis) + return apis + }) + + // Verify the actual welcome message to the required template + geth.expect(` +Welcome to the Geth JavaScript console! + +instance: Geth/v{{gethver}}/{{goos}}/{{gover}} +coinbase: {{.Etherbase}} +at block: 0 ({{niltime}}) + datadir: {{.Datadir}} + modules:{{range apis}} {{.}}:1.0{{end}} + +> {{.InputLine "exit"}} +`) + geth.expectExit() +} + +// Tests that a console can be attached to a running node via various means. +func TestIPCAttachWelcome(t *testing.T) { + // Configure the instance for IPC attachement + coinbase := "0x8605cdbbdb6d264aa742e77020dcbc58fcdce182" + var ipc string + if runtime.GOOS == "windows" { + ipc = `\\.\pipe\geth` + strconv.Itoa(rand.Int()) + } else { + ws := tmpdir(t) + defer os.RemoveAll(ws) + ipc = filepath.Join(ws, "geth.ipc") + } + // Note: we need --shh because testAttachWelcome checks for default + // list of ipc modules and shh is included there. + geth := runGeth(t, + "--port", "0", "--maxpeers", "0", "--nodiscover", "--nat", "none", + "--etherbase", coinbase, "--shh", "--ipcpath", ipc) + + time.Sleep(2 * time.Second) // Simple way to wait for the RPC endpoint to open + testAttachWelcome(t, geth, "ipc:"+ipc) + + geth.interrupt() + geth.expectExit() +} + +func TestHTTPAttachWelcome(t *testing.T) { + coinbase := "0x8605cdbbdb6d264aa742e77020dcbc58fcdce182" + port := strconv.Itoa(rand.Intn(65535-1024) + 1024) // Yeah, sometimes this will fail, sorry :P + geth := runGeth(t, + "--port", "0", "--maxpeers", "0", "--nodiscover", "--nat", "none", + "--etherbase", coinbase, "--rpc", "--rpcport", port) + + time.Sleep(2 * time.Second) // Simple way to wait for the RPC endpoint to open + testAttachWelcome(t, geth, "http://localhost:"+port) + + geth.interrupt() + geth.expectExit() +} + +func TestWSAttachWelcome(t *testing.T) { + coinbase := "0x8605cdbbdb6d264aa742e77020dcbc58fcdce182" + port := strconv.Itoa(rand.Intn(65535-1024) + 1024) // Yeah, sometimes this will fail, sorry :P + + geth := runGeth(t, + "--port", "0", "--maxpeers", "0", "--nodiscover", "--nat", "none", + "--etherbase", coinbase, "--ws", "--wsport", port) + + time.Sleep(2 * time.Second) // Simple way to wait for the RPC endpoint to open + testAttachWelcome(t, geth, "ws://localhost:"+port) + + geth.interrupt() + geth.expectExit() +} + +func testAttachWelcome(t *testing.T, geth *testgeth, endpoint string) { + // Attach to a running geth note and terminate immediately + attach := runGeth(t, "attach", endpoint) + defer attach.expectExit() + attach.stdin.Close() + + // Gather all the infos the welcome message needs to contain + attach.setTemplateFunc("goos", func() string { return runtime.GOOS }) + attach.setTemplateFunc("gover", runtime.Version) + attach.setTemplateFunc("gethver", func() string { return verString }) + attach.setTemplateFunc("etherbase", func() string { return geth.Etherbase }) + attach.setTemplateFunc("niltime", func() string { return time.Unix(0, 0).Format(time.RFC1123) }) + attach.setTemplateFunc("ipc", func() bool { return strings.HasPrefix(endpoint, "ipc") }) + attach.setTemplateFunc("datadir", func() string { return geth.Datadir }) + attach.setTemplateFunc("apis", func() []string { + var apis []string + if strings.HasPrefix(endpoint, "ipc") { + apis = append(strings.Split(rpc.DefaultIPCApis, ","), rpc.MetadataApi) + } else { + apis = append(strings.Split(rpc.DefaultHTTPApis, ","), rpc.MetadataApi) + } + sort.Strings(apis) + return apis + }) + + // Verify the actual welcome message to the required template + attach.expect(` +Welcome to the Geth JavaScript console! + +instance: Geth/v{{gethver}}/{{goos}}/{{gover}} +coinbase: {{etherbase}} +at block: 0 ({{niltime}}){{if ipc}} + datadir: {{datadir}}{{end}} + modules:{{range apis}} {{.}}:1.0{{end}} + +> {{.InputLine "exit" }} +`) + attach.expectExit() +} diff --git a/cmd/geth/js.go b/cmd/geth/js.go deleted file mode 100644 index 729cc2fd71..0000000000 --- a/cmd/geth/js.go +++ /dev/null @@ -1,424 +0,0 @@ -// Copyright 2015 The go-ethereum Authors -// This file is part of go-ethereum. -// -// go-ethereum is free software: you can redistribute it and/or modify -// it under the terms of the GNU General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// go-ethereum is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU General Public License for more details. -// -// You should have received a copy of the GNU General Public License -// along with go-ethereum. If not, see . - -package main - -import ( - "fmt" - "math/big" - "os" - "os/signal" - "path/filepath" - "regexp" - "sort" - "strings" - - "github.com/codegangsta/cli" - "github.com/ethereum/go-ethereum/accounts" - "github.com/ethereum/go-ethereum/cmd/utils" - "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/common/registrar" - "github.com/ethereum/go-ethereum/eth" - "github.com/ethereum/go-ethereum/internal/web3ext" - re "github.com/ethereum/go-ethereum/jsre" - "github.com/ethereum/go-ethereum/node" - "github.com/ethereum/go-ethereum/rpc" - "github.com/peterh/liner" - "github.com/robertkrimen/otto" -) - -var ( - passwordRegexp = regexp.MustCompile("personal.[nus]") - onlyws = regexp.MustCompile("^\\s*$") - exit = regexp.MustCompile("^\\s*exit\\s*;*\\s*$") -) - -type jsre struct { - re *re.JSRE - stack *node.Node - wait chan *big.Int - ps1 string - atexit func() - corsDomain string - client rpc.Client -} - -func makeCompleter(re *jsre) liner.WordCompleter { - return func(line string, pos int) (head string, completions []string, tail string) { - if len(line) == 0 || pos == 0 { - return "", nil, "" - } - // chuck data to relevant part for autocompletion, e.g. in case of nested lines eth.getBalance(eth.coinb - i := 0 - for i = pos - 1; i > 0; i-- { - if line[i] == '.' || (line[i] >= 'a' && line[i] <= 'z') || (line[i] >= 'A' && line[i] <= 'Z') { - continue - } - if i >= 3 && line[i] == '3' && line[i-3] == 'w' && line[i-2] == 'e' && line[i-1] == 'b' { - continue - } - i += 1 - break - } - return line[:i], re.re.CompleteKeywords(line[i:pos]), line[pos:] - } -} - -func newLightweightJSRE(docRoot string, client rpc.Client, datadir string, interactive bool) *jsre { - js := &jsre{ps1: "> "} - js.wait = make(chan *big.Int) - js.client = client - js.re = re.New(docRoot) - if err := js.apiBindings(); err != nil { - utils.Fatalf("Unable to initialize console - %v", err) - } - js.setupInput(datadir) - return js -} - -func newJSRE(stack *node.Node, docRoot, corsDomain string, client rpc.Client, interactive bool) *jsre { - js := &jsre{stack: stack, ps1: "> "} - // set default cors domain used by startRpc from CLI flag - js.corsDomain = corsDomain - js.wait = make(chan *big.Int) - js.client = client - js.re = re.New(docRoot) - if err := js.apiBindings(); err != nil { - utils.Fatalf("Unable to connect - %v", err) - } - js.setupInput(stack.DataDir()) - return js -} - -func (self *jsre) setupInput(datadir string) { - self.withHistory(datadir, func(hist *os.File) { utils.Stdin.ReadHistory(hist) }) - utils.Stdin.SetCtrlCAborts(true) - utils.Stdin.SetWordCompleter(makeCompleter(self)) - utils.Stdin.SetTabCompletionStyle(liner.TabPrints) - self.atexit = func() { - self.withHistory(datadir, func(hist *os.File) { - hist.Truncate(0) - utils.Stdin.WriteHistory(hist) - }) - utils.Stdin.Close() - close(self.wait) - } -} - -func (self *jsre) batch(statement string) { - err := self.re.EvalAndPrettyPrint(statement) - - if err != nil { - fmt.Printf("%v", jsErrorString(err)) - } - - if self.atexit != nil { - self.atexit() - } - - self.re.Stop(false) -} - -// show summary of current geth instance -func (self *jsre) welcome() { - self.re.Run(` - (function () { - console.log('instance: ' + web3.version.node); - console.log("coinbase: " + eth.coinbase); - var ts = 1000 * eth.getBlock(eth.blockNumber).timestamp; - console.log("at block: " + eth.blockNumber + " (" + new Date(ts) + ")"); - console.log(' datadir: ' + admin.datadir); - })(); - `) - if modules, err := self.supportedApis(); err == nil { - loadedModules := make([]string, 0) - for api, version := range modules { - loadedModules = append(loadedModules, fmt.Sprintf("%s:%s", api, version)) - } - sort.Strings(loadedModules) - } -} - -func (self *jsre) supportedApis() (map[string]string, error) { - return self.client.SupportedModules() -} - -func (js *jsre) apiBindings() error { - apis, err := js.supportedApis() - if err != nil { - return err - } - - apiNames := make([]string, 0, len(apis)) - for a, _ := range apis { - apiNames = append(apiNames, a) - } - - jeth := utils.NewJeth(js.re, js.client) - js.re.Set("jeth", struct{}{}) - t, _ := js.re.Get("jeth") - jethObj := t.Object() - - jethObj.Set("send", jeth.Send) - jethObj.Set("sendAsync", jeth.Send) - - err = js.re.Compile("bignumber.js", re.BigNumber_JS) - if err != nil { - utils.Fatalf("Error loading bignumber.js: %v", err) - } - - err = js.re.Compile("web3.js", re.Web3_JS) - if err != nil { - utils.Fatalf("Error loading web3.js: %v", err) - } - - _, err = js.re.Run("var Web3 = require('web3');") - if err != nil { - utils.Fatalf("Error requiring web3: %v", err) - } - - _, err = js.re.Run("var web3 = new Web3(jeth);") - if err != nil { - utils.Fatalf("Error setting web3 provider: %v", err) - } - - // load only supported API's in javascript runtime - shortcuts := "var eth = web3.eth; var personal = web3.personal; " - for _, apiName := range apiNames { - if apiName == "web3" || apiName == "rpc" { - continue // manually mapped or ignore - } - - if jsFile, ok := web3ext.Modules[apiName]; ok { - if err = js.re.Compile(fmt.Sprintf("%s.js", apiName), jsFile); err == nil { - shortcuts += fmt.Sprintf("var %s = web3.%s; ", apiName, apiName) - } else { - utils.Fatalf("Error loading %s.js: %v", apiName, err) - } - } - } - - _, err = js.re.Run(shortcuts) - if err != nil { - utils.Fatalf("Error setting namespaces: %v", err) - } - - js.re.Run(`var GlobalRegistrar = eth.contract(` + registrar.GlobalRegistrarAbi + `); registrar = GlobalRegistrar.at("` + registrar.GlobalRegistrarAddr + `");`) - - // overrule some of the methods that require password as input and ask for it interactively - p, err := js.re.Get("personal") - if err != nil { - fmt.Println("Unable to overrule sensitive methods in personal module") - return nil - } - - // Override the unlockAccount and newAccount methods on the personal object since these require user interaction. - // Assign the jeth.unlockAccount and jeth.newAccount in the jsre the original web3 callbacks. These will be called - // by the jeth.* methods after they got the password from the user and send the original web3 request to the backend. - if persObj := p.Object(); persObj != nil { // make sure the personal api is enabled over the interface - js.re.Run(`jeth.unlockAccount = personal.unlockAccount;`) - persObj.Set("unlockAccount", jeth.UnlockAccount) - js.re.Run(`jeth.newAccount = personal.newAccount;`) - persObj.Set("newAccount", jeth.NewAccount) - } - - // The admin.sleep and admin.sleepBlocks are offered by the console and not by the RPC layer. - // Bind these if the admin module is available. - if a, err := js.re.Get("admin"); err == nil { - if adminObj := a.Object(); adminObj != nil { - adminObj.Set("sleepBlocks", jeth.SleepBlocks) - adminObj.Set("sleep", jeth.Sleep) - } - } - - return nil -} - -func (self *jsre) AskPassword() (string, bool) { - pass, err := utils.Stdin.PasswordPrompt("Passphrase: ") - if err != nil { - return "", false - } - return pass, true -} - -func (self *jsre) ConfirmTransaction(tx string) bool { - // Retrieve the Ethereum instance from the node - var ethereum *eth.Ethereum - if err := self.stack.Service(ðereum); err != nil { - return false - } - // If natspec is enabled, ask for permission - if ethereum.NatSpec && false /* disabled for now */ { - // notice := natspec.GetNotice(self.xeth, tx, ethereum.HTTPClient()) - // fmt.Println(notice) - // answer, _ := self.Prompt("Confirm Transaction [y/n]") - // return strings.HasPrefix(strings.Trim(answer, " "), "y") - } - return true -} - -func (self *jsre) UnlockAccount(addr []byte) bool { - fmt.Printf("Please unlock account %x.\n", addr) - pass, err := utils.Stdin.PasswordPrompt("Passphrase: ") - if err != nil { - return false - } - // TODO: allow retry - var ethereum *eth.Ethereum - if err := self.stack.Service(ðereum); err != nil { - return false - } - a := accounts.Account{Address: common.BytesToAddress(addr)} - if err := ethereum.AccountManager().Unlock(a, pass); err != nil { - return false - } else { - fmt.Println("Account is now unlocked for this session.") - return true - } -} - -// preloadJSFiles loads JS files that the user has specified with ctx.PreLoadJSFlag into -// the JSRE. If not all files could be loaded it will return an error describing the error. -func (self *jsre) preloadJSFiles(ctx *cli.Context) error { - if ctx.GlobalString(utils.PreLoadJSFlag.Name) != "" { - assetPath := ctx.GlobalString(utils.JSpathFlag.Name) - jsFiles := strings.Split(ctx.GlobalString(utils.PreLoadJSFlag.Name), ",") - for _, file := range jsFiles { - filename := common.AbsolutePath(assetPath, strings.TrimSpace(file)) - if err := self.re.Exec(filename); err != nil { - return fmt.Errorf("%s: %v", file, jsErrorString(err)) - } - } - } - return nil -} - -// jsErrorString adds a backtrace to errors generated by otto. -func jsErrorString(err error) string { - if ottoErr, ok := err.(*otto.Error); ok { - return ottoErr.String() - } - return err.Error() -} - -func (self *jsre) interactive() { - // Read input lines. - prompt := make(chan string) - inputln := make(chan string) - go func() { - defer close(inputln) - for { - line, err := utils.Stdin.Prompt(<-prompt) - if err != nil { - if err == liner.ErrPromptAborted { // ctrl-C - self.resetPrompt() - inputln <- "" - continue - } - return - } - inputln <- line - } - }() - // Wait for Ctrl-C, too. - sig := make(chan os.Signal, 1) - signal.Notify(sig, os.Interrupt) - - defer func() { - if self.atexit != nil { - self.atexit() - } - self.re.Stop(false) - }() - for { - prompt <- self.ps1 - select { - case <-sig: - fmt.Println("caught interrupt, exiting") - return - case input, ok := <-inputln: - if !ok || indentCount <= 0 && exit.MatchString(input) { - return - } - if onlyws.MatchString(input) { - continue - } - str += input + "\n" - self.setIndent() - if indentCount <= 0 { - if !excludeFromHistory(str) { - utils.Stdin.AppendHistory(str[:len(str)-1]) - } - self.parseInput(str) - str = "" - } - } - } -} - -func excludeFromHistory(input string) bool { - return len(input) == 0 || input[0] == ' ' || passwordRegexp.MatchString(input) -} - -func (self *jsre) withHistory(datadir string, op func(*os.File)) { - hist, err := os.OpenFile(filepath.Join(datadir, "history"), os.O_RDWR|os.O_CREATE, os.ModePerm) - if err != nil { - fmt.Printf("unable to open history file: %v\n", err) - return - } - op(hist) - hist.Close() -} - -func (self *jsre) parseInput(code string) { - defer func() { - if r := recover(); r != nil { - fmt.Println("[native] error", r) - } - }() - if err := self.re.EvalAndPrettyPrint(code); err != nil { - if ottoErr, ok := err.(*otto.Error); ok { - fmt.Println(ottoErr.String()) - } else { - fmt.Println(err) - } - return - } -} - -var indentCount = 0 -var str = "" - -func (self *jsre) resetPrompt() { - indentCount = 0 - str = "" - self.ps1 = "> " -} - -func (self *jsre) setIndent() { - open := strings.Count(str, "{") - open += strings.Count(str, "(") - closed := strings.Count(str, "}") - closed += strings.Count(str, ")") - indentCount = open - closed - if indentCount <= 0 { - self.ps1 = "> " - } else { - self.ps1 = strings.Join(make([]string, indentCount*2), "..") - self.ps1 += " " - } -} diff --git a/cmd/geth/js_test.go b/cmd/geth/js_test.go deleted file mode 100644 index ddfe0d4000..0000000000 --- a/cmd/geth/js_test.go +++ /dev/null @@ -1,500 +0,0 @@ -// Copyright 2015 The go-ethereum Authors -// This file is part of go-ethereum. -// -// go-ethereum is free software: you can redistribute it and/or modify -// it under the terms of the GNU General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// go-ethereum is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU General Public License for more details. -// -// You should have received a copy of the GNU General Public License -// along with go-ethereum. If not, see . - -package main - -import ( - "fmt" - "io/ioutil" - "math/big" - "os" - "path/filepath" - "regexp" - "runtime" - "strconv" - "testing" - "time" - - "github.com/ethereum/go-ethereum/accounts" - "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/common/compiler" - "github.com/ethereum/go-ethereum/common/httpclient" - "github.com/ethereum/go-ethereum/core" - "github.com/ethereum/go-ethereum/crypto" - "github.com/ethereum/go-ethereum/eth" - "github.com/ethereum/go-ethereum/ethdb" - "github.com/ethereum/go-ethereum/node" -) - -const ( - testSolcPath = "" - solcVersion = "0.9.23" - testAddress = "0x8605cdbbdb6d264aa742e77020dcbc58fcdce182" - testBalance = "10000000000000000000" - // of empty string - testHash = "0xc5d2460186f7233c927e7db2dcc703c0e500b653ca82273b7bfad8045d85a470" -) - -var ( - versionRE = regexp.MustCompile(strconv.Quote(`"compilerVersion":"` + solcVersion + `"`)) - testNodeKey, _ = crypto.HexToECDSA("4b50fa71f5c3eeb8fdc452224b2395af2fcc3d125e06c32c82e048c0559db03f") - testAccount, _ = crypto.HexToECDSA("e6fab74a43941f82d89cb7faa408e227cdad3153c4720e540e855c19b15e6674") - testGenesis = `{"` + testAddress[2:] + `": {"balance": "` + testBalance + `"}}` -) - -type testjethre struct { - *jsre - lastConfirm string - client *httpclient.HTTPClient -} - -// Temporary disabled while natspec hasn't been migrated -//func (self *testjethre) ConfirmTransaction(tx string) bool { -// var ethereum *eth.Ethereum -// self.stack.Service(ðereum) -// -// if ethereum.NatSpec { -// self.lastConfirm = natspec.GetNotice(self.xeth, tx, self.client) -// } -// return true -//} - -func testJEthRE(t *testing.T) (string, *testjethre, *node.Node) { - return testREPL(t, nil) -} - -func testREPL(t *testing.T, config func(*eth.Config)) (string, *testjethre, *node.Node) { - tmp, err := ioutil.TempDir("", "geth-test") - if err != nil { - t.Fatal(err) - } - // Create a networkless protocol stack - stack, err := node.New(&node.Config{DataDir: tmp, PrivateKey: testNodeKey, Name: "test", NoDiscovery: true}) - if err != nil { - t.Fatalf("failed to create node: %v", err) - } - // Initialize and register the Ethereum protocol - accman := accounts.NewPlaintextManager(filepath.Join(tmp, "keystore")) - db, _ := ethdb.NewMemDatabase() - core.WriteGenesisBlockForTesting(db, core.GenesisAccount{ - Address: common.HexToAddress(testAddress), - Balance: common.String2Big(testBalance), - }) - ethConf := ð.Config{ - ChainConfig: &core.ChainConfig{HomesteadBlock: new(big.Int)}, - TestGenesisState: db, - AccountManager: accman, - DocRoot: "/", - SolcPath: testSolcPath, - PowTest: true, - } - if config != nil { - config(ethConf) - } - if err := stack.Register(func(ctx *node.ServiceContext) (node.Service, error) { - return eth.New(ctx, ethConf) - }); err != nil { - t.Fatalf("failed to register ethereum protocol: %v", err) - } - // Initialize all the keys for testing - a, err := accman.ImportECDSA(testAccount, "") - if err != nil { - t.Fatal(err) - } - if err := accman.Unlock(a, ""); err != nil { - t.Fatal(err) - } - // Start the node and assemble the REPL tester - if err := stack.Start(); err != nil { - t.Fatalf("failed to start test stack: %v", err) - } - var ethereum *eth.Ethereum - stack.Service(ðereum) - - assetPath := filepath.Join(os.Getenv("GOPATH"), "src", "github.com", "ethereum", "go-ethereum", "cmd", "mist", "assets", "ext") - client, err := stack.Attach() - if err != nil { - t.Fatalf("failed to attach to node: %v", err) - } - tf := &testjethre{client: ethereum.HTTPClient()} - repl := newJSRE(stack, assetPath, "", client, false) - tf.jsre = repl - return tmp, tf, stack -} - -func TestNodeInfo(t *testing.T) { - t.Skip("broken after p2p update") - tmp, repl, ethereum := testJEthRE(t) - defer ethereum.Stop() - defer os.RemoveAll(tmp) - - want := `{"DiscPort":0,"IP":"0.0.0.0","ListenAddr":"","Name":"test","NodeID":"4cb2fc32924e94277bf94b5e4c983beedb2eabd5a0bc941db32202735c6625d020ca14a5963d1738af43b6ac0a711d61b1a06de931a499fe2aa0b1a132a902b5","NodeUrl":"enode://4cb2fc32924e94277bf94b5e4c983beedb2eabd5a0bc941db32202735c6625d020ca14a5963d1738af43b6ac0a711d61b1a06de931a499fe2aa0b1a132a902b5@0.0.0.0:0","TCPPort":0,"Td":"131072"}` - checkEvalJSON(t, repl, `admin.nodeInfo`, want) -} - -func TestAccounts(t *testing.T) { - tmp, repl, node := testJEthRE(t) - defer node.Stop() - defer os.RemoveAll(tmp) - - checkEvalJSON(t, repl, `eth.accounts`, `["`+testAddress+`"]`) - checkEvalJSON(t, repl, `eth.coinbase`, `"`+testAddress+`"`) - val, err := repl.re.Run(`jeth.newAccount("password")`) - if err != nil { - t.Errorf("expected no error, got %v", err) - } - addr := val.String() - if !regexp.MustCompile(`0x[0-9a-f]{40}`).MatchString(addr) { - t.Errorf("address not hex: %q", addr) - } - - checkEvalJSON(t, repl, `eth.accounts`, `["`+testAddress+`","`+addr+`"]`) - -} - -func TestBlockChain(t *testing.T) { - tmp, repl, node := testJEthRE(t) - defer node.Stop() - defer os.RemoveAll(tmp) - // get current block dump before export/import. - val, err := repl.re.Run("JSON.stringify(debug.dumpBlock(eth.blockNumber))") - if err != nil { - t.Errorf("expected no error, got %v", err) - } - beforeExport := val.String() - - // do the export - extmp, err := ioutil.TempDir("", "geth-test-export") - if err != nil { - t.Fatal(err) - } - defer os.RemoveAll(extmp) - tmpfile := filepath.Join(extmp, "export.chain") - tmpfileq := strconv.Quote(tmpfile) - - var ethereum *eth.Ethereum - node.Service(ðereum) - ethereum.BlockChain().Reset() - - checkEvalJSON(t, repl, `admin.exportChain(`+tmpfileq+`)`, `true`) - if _, err := os.Stat(tmpfile); err != nil { - t.Fatal(err) - } - - // check import, verify that dumpBlock gives the same result. - checkEvalJSON(t, repl, `admin.importChain(`+tmpfileq+`)`, `true`) - checkEvalJSON(t, repl, `debug.dumpBlock(eth.blockNumber)`, beforeExport) -} - -func TestMining(t *testing.T) { - tmp, repl, node := testJEthRE(t) - defer node.Stop() - defer os.RemoveAll(tmp) - checkEvalJSON(t, repl, `eth.mining`, `false`) -} - -func TestRPC(t *testing.T) { - tmp, repl, node := testJEthRE(t) - defer node.Stop() - defer os.RemoveAll(tmp) - - checkEvalJSON(t, repl, `admin.startRPC("127.0.0.1", 5004, "*", "web3,eth,net")`, `true`) -} - -func TestCheckTestAccountBalance(t *testing.T) { - t.Skip() // i don't think it tests the correct behaviour here. it's actually testing - // internals which shouldn't be tested. This now fails because of a change in the core - // and i have no means to fix this, sorry - @obscuren - tmp, repl, node := testJEthRE(t) - defer node.Stop() - defer os.RemoveAll(tmp) - - repl.re.Run(`primary = "` + testAddress + `"`) - checkEvalJSON(t, repl, `eth.getBalance(primary)`, `"`+testBalance+`"`) -} - -func TestSignature(t *testing.T) { - tmp, repl, node := testJEthRE(t) - defer node.Stop() - defer os.RemoveAll(tmp) - - val, err := repl.re.Run(`eth.sign("` + testAddress + `", "` + testHash + `")`) - - // This is a very preliminary test, lacking actual signature verification - if err != nil { - t.Errorf("Error running js: %v", err) - return - } - output := val.String() - t.Logf("Output: %v", output) - - regex := regexp.MustCompile(`^0x[0-9a-f]{130}$`) - if !regex.MatchString(output) { - t.Errorf("Signature is not 65 bytes represented in hexadecimal.") - return - } -} - -func TestContract(t *testing.T) { - t.Skip("contract testing is implemented with mining in ethash test mode. This takes about 7seconds to run. Unskip and run on demand") - coinbase := common.HexToAddress(testAddress) - tmp, repl, ethereum := testREPL(t, func(conf *eth.Config) { - conf.Etherbase = coinbase - conf.PowTest = true - }) - if err := ethereum.Start(); err != nil { - t.Errorf("error starting ethereum: %v", err) - return - } - defer ethereum.Stop() - defer os.RemoveAll(tmp) - - // Temporary disabled while registrar isn't migrated - //reg := registrar.New(repl.xeth) - //_, err := reg.SetGlobalRegistrar("", coinbase) - //if err != nil { - // t.Errorf("error setting HashReg: %v", err) - //} - //_, err = reg.SetHashReg("", coinbase) - //if err != nil { - // t.Errorf("error setting HashReg: %v", err) - //} - //_, err = reg.SetUrlHint("", coinbase) - //if err != nil { - // t.Errorf("error setting HashReg: %v", err) - //} - /* TODO: - * lookup receipt and contract addresses by tx hash - * name registration for HashReg and UrlHint addresses - * mine those transactions - * then set once more SetHashReg SetUrlHint - */ - - source := `contract test {\n` + - " /// @notice Will multiply `a` by 7." + `\n` + - ` function multiply(uint a) returns(uint d) {\n` + - ` return a * 7;\n` + - ` }\n` + - `}\n` - - if checkEvalJSON(t, repl, `admin.stopNatSpec()`, `true`) != nil { - return - } - - contractInfo, err := ioutil.ReadFile("info_test.json") - if err != nil { - t.Fatalf("%v", err) - } - if checkEvalJSON(t, repl, `primary = eth.accounts[0]`, `"`+testAddress+`"`) != nil { - return - } - if checkEvalJSON(t, repl, `source = "`+source+`"`, `"`+source+`"`) != nil { - return - } - - // if solc is found with right version, test it, otherwise read from file - sol, err := compiler.New("") - if err != nil { - t.Logf("solc not found: mocking contract compilation step") - } else if sol.Version() != solcVersion { - t.Logf("WARNING: solc different version found (%v, test written for %v, may need to update)", sol.Version(), solcVersion) - } - - if err != nil { - info, err := ioutil.ReadFile("info_test.json") - if err != nil { - t.Fatalf("%v", err) - } - _, err = repl.re.Run(`contract = JSON.parse(` + strconv.Quote(string(info)) + `)`) - if err != nil { - t.Errorf("%v", err) - } - } else { - if checkEvalJSON(t, repl, `contract = eth.compile.solidity(source).test`, string(contractInfo)) != nil { - return - } - } - - if checkEvalJSON(t, repl, `contract.code`, `"0x605880600c6000396000f3006000357c010000000000000000000000000000000000000000000000000000000090048063c6888fa114602e57005b603d6004803590602001506047565b8060005260206000f35b60006007820290506053565b91905056"`) != nil { - return - } - - if checkEvalJSON( - t, repl, - `contractaddress = eth.sendTransaction({from: primary, data: contract.code})`, - `"0x46d69d55c3c4b86a924a92c9fc4720bb7bce1d74"`, - ) != nil { - return - } - - if !processTxs(repl, t, 8) { - return - } - - callSetup := `abiDef = JSON.parse('[{"constant":false,"inputs":[{"name":"a","type":"uint256"}],"name":"multiply","outputs":[{"name":"d","type":"uint256"}],"type":"function"}]'); -Multiply7 = eth.contract(abiDef); -multiply7 = Multiply7.at(contractaddress); -` - _, err = repl.re.Run(callSetup) - if err != nil { - t.Errorf("unexpected error setting up contract, got %v", err) - return - } - - expNotice := "" - if repl.lastConfirm != expNotice { - t.Errorf("incorrect confirmation message: expected %v, got %v", expNotice, repl.lastConfirm) - return - } - - if checkEvalJSON(t, repl, `admin.startNatSpec()`, `true`) != nil { - return - } - if checkEvalJSON(t, repl, `multiply7.multiply.sendTransaction(6, { from: primary })`, `"0x4ef9088431a8033e4580d00e4eb2487275e031ff4163c7529df0ef45af17857b"`) != nil { - return - } - - if !processTxs(repl, t, 1) { - return - } - - expNotice = `About to submit transaction (no NatSpec info found for contract: content hash not found for '0x87e2802265838c7f14bb69eecd2112911af6767907a702eeaa445239fb20711b'): {"params":[{"to":"0x46d69d55c3c4b86a924a92c9fc4720bb7bce1d74","data": "0xc6888fa10000000000000000000000000000000000000000000000000000000000000006"}]}` - if repl.lastConfirm != expNotice { - t.Errorf("incorrect confirmation message: expected\n%v, got\n%v", expNotice, repl.lastConfirm) - return - } - - var contentHash = `"0x86d2b7cf1e72e9a7a3f8d96601f0151742a2f780f1526414304fbe413dc7f9bd"` - if sol != nil && solcVersion != sol.Version() { - modContractInfo := versionRE.ReplaceAll(contractInfo, []byte(`"compilerVersion":"`+sol.Version()+`"`)) - fmt.Printf("modified contractinfo:\n%s\n", modContractInfo) - contentHash = `"` + common.ToHex(crypto.Keccak256([]byte(modContractInfo))) + `"` - } - if checkEvalJSON(t, repl, `filename = "/tmp/info.json"`, `"/tmp/info.json"`) != nil { - return - } - if checkEvalJSON(t, repl, `contentHash = admin.saveInfo(contract.info, filename)`, contentHash) != nil { - return - } - if checkEvalJSON(t, repl, `admin.register(primary, contractaddress, contentHash)`, `true`) != nil { - return - } - if checkEvalJSON(t, repl, `admin.registerUrl(primary, contentHash, "file://"+filename)`, `true`) != nil { - return - } - - if checkEvalJSON(t, repl, `admin.startNatSpec()`, `true`) != nil { - return - } - - if !processTxs(repl, t, 3) { - return - } - - if checkEvalJSON(t, repl, `multiply7.multiply.sendTransaction(6, { from: primary })`, `"0x66d7635c12ad0b231e66da2f987ca3dfdca58ffe49c6442aa55960858103fd0c"`) != nil { - return - } - - if !processTxs(repl, t, 1) { - return - } - - expNotice = "Will multiply 6 by 7." - if repl.lastConfirm != expNotice { - t.Errorf("incorrect confirmation message: expected\n%v, got\n%v", expNotice, repl.lastConfirm) - return - } -} - -func pendingTransactions(repl *testjethre, t *testing.T) (txc int64, err error) { - var ethereum *eth.Ethereum - repl.stack.Service(ðereum) - - txs := ethereum.TxPool().GetTransactions() - return int64(len(txs)), nil -} - -func processTxs(repl *testjethre, t *testing.T, expTxc int) bool { - var txc int64 - var err error - for i := 0; i < 50; i++ { - txc, err = pendingTransactions(repl, t) - if err != nil { - t.Errorf("unexpected error checking pending transactions: %v", err) - return false - } - if expTxc < int(txc) { - t.Errorf("too many pending transactions: expected %v, got %v", expTxc, txc) - return false - } else if expTxc == int(txc) { - break - } - time.Sleep(100 * time.Millisecond) - } - if int(txc) != expTxc { - t.Errorf("incorrect number of pending transactions, expected %v, got %v", expTxc, txc) - return false - } - var ethereum *eth.Ethereum - repl.stack.Service(ðereum) - - err = ethereum.StartMining(runtime.NumCPU(), "") - if err != nil { - t.Errorf("unexpected error mining: %v", err) - return false - } - defer ethereum.StopMining() - - timer := time.NewTimer(100 * time.Second) - blockNr := ethereum.BlockChain().CurrentBlock().Number() - height := new(big.Int).Add(blockNr, big.NewInt(1)) - repl.wait <- height - select { - case <-timer.C: - // if times out make sure the xeth loop does not block - go func() { - select { - case repl.wait <- nil: - case <-repl.wait: - } - }() - case <-repl.wait: - } - txc, err = pendingTransactions(repl, t) - if err != nil { - t.Errorf("unexpected error checking pending transactions: %v", err) - return false - } - if txc != 0 { - t.Errorf("%d trasactions were not mined", txc) - return false - } - return true -} - -func checkEvalJSON(t *testing.T, re *testjethre, expr, want string) error { - val, err := re.re.Run("JSON.stringify(" + expr + ")") - if err == nil && val.String() != want { - err = fmt.Errorf("Output mismatch for `%s`:\ngot: %s\nwant: %s", expr, val.String(), want) - } - if err != nil { - _, file, line, _ := runtime.Caller(1) - file = filepath.Base(file) - fmt.Printf("\t%s:%d: %v\n", file, line, err) - t.Fail() - } - return err -} diff --git a/cmd/geth/main.go b/cmd/geth/main.go index e94b76594d..2639147c44 100644 --- a/cmd/geth/main.go +++ b/cmd/geth/main.go @@ -22,7 +22,6 @@ import ( "fmt" "io/ioutil" "os" - "os/signal" "path/filepath" "runtime" "strconv" @@ -33,6 +32,7 @@ import ( "github.com/ethereum/ethash" "github.com/ethereum/go-ethereum/cmd/utils" "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/console" "github.com/ethereum/go-ethereum/core" "github.com/ethereum/go-ethereum/eth" "github.com/ethereum/go-ethereum/ethdb" @@ -50,7 +50,7 @@ const ( clientIdentifier = "Geth" // Client identifier to advertise over the network versionMajor = 1 // Major version component of the current release versionMinor = 4 // Minor version component of the current release - versionPatch = 5 // Patch version component of the current release + versionPatch = 6 // Patch version component of the current release versionMeta = "stable" // Version metadata to append to the version string versionOracle = "0xfa7b9770ca4cb04296cac84f37736d4041251cdf" // Ethereum address of the Geth release oracle @@ -95,6 +95,9 @@ func init() { monitorCommand, accountCommand, walletCommand, + consoleCommand, + attachCommand, + javascriptCommand, { Action: makedag, Name: "makedag", @@ -138,36 +141,6 @@ The output of this command is supposed to be machine-readable. The init command initialises a new genesis block and definition for the network. This is a destructive action and changes the network in which you will be participating. -`, - }, - { - Action: console, - Name: "console", - Usage: `Geth Console: interactive JavaScript environment`, - Description: ` -The Geth console is an interactive shell for the JavaScript runtime environment -which exposes a node admin interface as well as the Ðapp JavaScript API. -See https://github.com/ethereum/go-ethereum/wiki/Javascipt-Console -`, - }, - { - Action: attach, - Name: "attach", - Usage: `Geth Console: interactive JavaScript environment (connect to node)`, - Description: ` - The Geth console is an interactive shell for the JavaScript runtime environment - which exposes a node admin interface as well as the Ðapp JavaScript API. - See https://github.com/ethereum/go-ethereum/wiki/Javascipt-Console. - This command allows to open a console on a running geth node. - `, - }, - { - Action: execScripts, - Name: "js", - Usage: `executes the given JavaScript files in the Geth JavaScript VM`, - Description: ` -The JavaScript VM exposes a node admin interface as well as the Ðapp -JavaScript API. See https://github.com/ethereum/go-ethereum/wiki/Javascipt-Console `, }, } @@ -214,7 +187,7 @@ JavaScript API. See https://github.com/ethereum/go-ethereum/wiki/Javascipt-Conso utils.IPCApiFlag, utils.IPCPathFlag, utils.ExecFlag, - utils.PreLoadJSFlag, + utils.PreloadJSFlag, utils.WhisperEnabledFlag, utils.DevModeFlag, utils.TestNetFlag, @@ -244,6 +217,12 @@ JavaScript API. See https://github.com/ethereum/go-ethereum/wiki/Javascipt-Conso // Start system runtime metrics collection go metrics.CollectProcessMetrics(3 * time.Second) + // This should be the only place where reporting is enabled + // because it is not intended to run while testing. + // In addition to this check, bad block reports are sent only + // for chains with the main network genesis block and network id 1. + eth.EnableBadBlockReporting = true + utils.SetupNetwork(ctx) // Deprecation warning. @@ -257,7 +236,7 @@ JavaScript API. See https://github.com/ethereum/go-ethereum/wiki/Javascipt-Conso app.After = func(ctx *cli.Context) error { logger.Flush() debug.Exit() - utils.Stdin.Close() // Resets terminal mode. + console.Stdin.Close() // Resets terminal mode. return nil } } @@ -298,36 +277,6 @@ func geth(ctx *cli.Context) { node.Wait() } -// attach will connect to a running geth instance attaching a JavaScript console and to it. -func attach(ctx *cli.Context) { - // attach to a running geth instance - client, err := utils.NewRemoteRPCClient(ctx) - if err != nil { - utils.Fatalf("Unable to attach to geth: %v", err) - } - - repl := newLightweightJSRE( - ctx.GlobalString(utils.JSpathFlag.Name), - client, - ctx.GlobalString(utils.DataDirFlag.Name), - true, - ) - - // preload user defined JS files into the console - err = repl.preloadJSFiles(ctx) - if err != nil { - utils.Fatalf("unable to preload JS file %v", err) - } - - // in case the exec flag holds a JS statement execute it and return - if ctx.GlobalString(utils.ExecFlag.Name) != "" { - repl.batch(ctx.GlobalString(utils.ExecFlag.Name)) - } else { - repl.welcome() - repl.interactive() - } -} - // initGenesis will initialise the given JSON format genesis file and writes it as // the zero'd block (i.e. genesis) or will fail hard if it can't succeed. func initGenesis(ctx *cli.Context) { @@ -353,77 +302,6 @@ func initGenesis(ctx *cli.Context) { glog.V(logger.Info).Infof("successfully wrote genesis block and/or chain rule set: %x", block.Hash()) } -// console starts a new geth node, attaching a JavaScript console to it at the -// same time. -func console(ctx *cli.Context) { - // Create and start the node based on the CLI flags - node := utils.MakeSystemNode(clientIdentifier, verString, relConfig, makeDefaultExtra(), ctx) - startNode(ctx, node) - - // Attach to the newly started node, and either execute script or become interactive - client, err := node.Attach() - if err != nil { - utils.Fatalf("Failed to attach to the inproc geth: %v", err) - } - repl := newJSRE(node, - ctx.GlobalString(utils.JSpathFlag.Name), - ctx.GlobalString(utils.RPCCORSDomainFlag.Name), - client, true) - - // preload user defined JS files into the console - err = repl.preloadJSFiles(ctx) - if err != nil { - utils.Fatalf("%v", err) - } - - // in case the exec flag holds a JS statement execute it and return - if script := ctx.GlobalString(utils.ExecFlag.Name); script != "" { - repl.batch(script) - } else { - repl.welcome() - repl.interactive() - } - node.Stop() -} - -// execScripts starts a new geth node based on the CLI flags, and executes each -// of the JavaScript files specified as command arguments. -func execScripts(ctx *cli.Context) { - // Create and start the node based on the CLI flags - node := utils.MakeSystemNode(clientIdentifier, verString, relConfig, makeDefaultExtra(), ctx) - startNode(ctx, node) - defer node.Stop() - - // Attach to the newly started node and execute the given scripts - client, err := node.Attach() - if err != nil { - utils.Fatalf("Failed to attach to the inproc geth: %v", err) - } - repl := newJSRE(node, - ctx.GlobalString(utils.JSpathFlag.Name), - ctx.GlobalString(utils.RPCCORSDomainFlag.Name), - client, false) - - // Run all given files. - for _, file := range ctx.Args() { - if err = repl.re.Exec(file); err != nil { - break - } - } - if err != nil { - utils.Fatalf("JavaScript Error: %v", jsErrorString(err)) - } - // JS files loaded successfully. - // Wait for pending callbacks, but stop for Ctrl-C. - abort := make(chan os.Signal, 1) - signal.Notify(abort, os.Interrupt) - go func() { - <-abort - repl.re.Stop(false) - }() - repl.re.Stop(true) -} - // startNode boots up the system node and all registered protocols, after which // it unlocks any requested accounts, and starts the RPC/IPC interfaces and the // miner. diff --git a/cmd/geth/run_test.go b/cmd/geth/run_test.go index a82eb9d68a..f6bc3f869d 100644 --- a/cmd/geth/run_test.go +++ b/cmd/geth/run_test.go @@ -20,7 +20,6 @@ import ( "bufio" "bytes" "fmt" - "html/template" "io" "io/ioutil" "os" @@ -28,6 +27,7 @@ import ( "regexp" "sync" "testing" + "text/template" "time" ) @@ -45,6 +45,7 @@ type testgeth struct { // template variables for expect Datadir string Executable string + Etherbase string Func template.FuncMap removeDatadir bool @@ -67,11 +68,15 @@ func init() { func runGeth(t *testing.T, args ...string) *testgeth { tt := &testgeth{T: t, Executable: os.Args[0]} for i, arg := range args { - if arg == "-datadir" || arg == "--datadir" { + switch { + case arg == "-datadir" || arg == "--datadir": if i < len(args)-1 { tt.Datadir = args[i+1] } - break + case arg == "-etherbase" || arg == "--etherbase": + if i < len(args)-1 { + tt.Etherbase = args[i+1] + } } } if tt.Datadir == "" { diff --git a/cmd/geth/usage.go b/cmd/geth/usage.go index 90019d7b97..01a71c1f65 100644 --- a/cmd/geth/usage.go +++ b/cmd/geth/usage.go @@ -101,7 +101,7 @@ var AppHelpFlagGroups = []flagGroup{ utils.RPCCORSDomainFlag, utils.JSpathFlag, utils.ExecFlag, - utils.PreLoadJSFlag, + utils.PreloadJSFlag, }, }, { diff --git a/cmd/utils/flags.go b/cmd/utils/flags.go index 43dbc37f74..c476e1c779 100644 --- a/cmd/utils/flags.go +++ b/cmd/utils/flags.go @@ -302,7 +302,7 @@ var ( Name: "exec", Usage: "Execute JavaScript statement (only in combination with console/attach)", } - PreLoadJSFlag = cli.StringFlag{ + PreloadJSFlag = cli.StringFlag{ Name: "preload", Usage: "Comma separated list of JavaScript files to preload into the console", } @@ -864,3 +864,20 @@ func MakeChain(ctx *cli.Context) (chain *core.BlockChain, chainDb ethdb.Database } return chain, chainDb } + +// MakeConsolePreloads retrieves the absolute paths for the console JavaScript +// scripts to preload before starting. +func MakeConsolePreloads(ctx *cli.Context) []string { + // Skip preloading if there's nothing to preload + if ctx.GlobalString(PreloadJSFlag.Name) == "" { + return nil + } + // Otherwise resolve absolute paths and return them + preloads := []string{} + + assets := ctx.GlobalString(JSpathFlag.Name) + for _, file := range strings.Split(ctx.GlobalString(PreloadJSFlag.Name), ",") { + preloads = append(preloads, common.AbsolutePath(assets, strings.TrimSpace(file))) + } + return preloads +} diff --git a/cmd/utils/input.go b/cmd/utils/input.go deleted file mode 100644 index 523d5a5870..0000000000 --- a/cmd/utils/input.go +++ /dev/null @@ -1,98 +0,0 @@ -// Copyright 2016 The go-ethereum Authors -// This file is part of go-ethereum. -// -// go-ethereum is free software: you can redistribute it and/or modify -// it under the terms of the GNU General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// go-ethereum is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU General Public License for more details. -// -// You should have received a copy of the GNU General Public License -// along with go-ethereum. If not, see . - -package utils - -import ( - "fmt" - "strings" - - "github.com/peterh/liner" -) - -// Holds the stdin line reader. -// Only this reader may be used for input because it keeps -// an internal buffer. -var Stdin = newUserInputReader() - -type userInputReader struct { - *liner.State - warned bool - supported bool - normalMode liner.ModeApplier - rawMode liner.ModeApplier -} - -func newUserInputReader() *userInputReader { - r := new(userInputReader) - // Get the original mode before calling NewLiner. - // This is usually regular "cooked" mode where characters echo. - normalMode, _ := liner.TerminalMode() - // Turn on liner. It switches to raw mode. - r.State = liner.NewLiner() - rawMode, err := liner.TerminalMode() - if err != nil || !liner.TerminalSupported() { - r.supported = false - } else { - r.supported = true - r.normalMode = normalMode - r.rawMode = rawMode - // Switch back to normal mode while we're not prompting. - normalMode.ApplyMode() - } - return r -} - -func (r *userInputReader) Prompt(prompt string) (string, error) { - if r.supported { - r.rawMode.ApplyMode() - defer r.normalMode.ApplyMode() - } else { - // liner tries to be smart about printing the prompt - // and doesn't print anything if input is redirected. - // Un-smart it by printing the prompt always. - fmt.Print(prompt) - prompt = "" - defer fmt.Println() - } - return r.State.Prompt(prompt) -} - -func (r *userInputReader) PasswordPrompt(prompt string) (passwd string, err error) { - if r.supported { - r.rawMode.ApplyMode() - defer r.normalMode.ApplyMode() - return r.State.PasswordPrompt(prompt) - } - if !r.warned { - fmt.Println("!! Unsupported terminal, password will be echoed.") - r.warned = true - } - // Just as in Prompt, handle printing the prompt here instead of relying on liner. - fmt.Print(prompt) - passwd, err = r.State.Prompt("") - fmt.Println() - return passwd, err -} - -func (r *userInputReader) ConfirmPrompt(prompt string) (bool, error) { - prompt = prompt + " [y/N] " - input, err := r.Prompt(prompt) - if len(input) > 0 && strings.ToUpper(input[:1]) == "Y" { - return true, nil - } - return false, err -} diff --git a/cmd/utils/jeth.go b/cmd/utils/jeth.go deleted file mode 100644 index 9410180b01..0000000000 --- a/cmd/utils/jeth.go +++ /dev/null @@ -1,301 +0,0 @@ -// Copyright 2015 The go-ethereum Authors -// This file is part of go-ethereum. -// -// go-ethereum is free software: you can redistribute it and/or modify -// it under the terms of the GNU General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// go-ethereum is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU General Public License for more details. -// -// You should have received a copy of the GNU General Public License -// along with go-ethereum. If not, see . - -package utils - -import ( - "encoding/json" - "fmt" - "time" - - "github.com/ethereum/go-ethereum/jsre" - "github.com/ethereum/go-ethereum/rpc" - - "github.com/robertkrimen/otto" -) - -type Jeth struct { - re *jsre.JSRE - client rpc.Client -} - -// NewJeth create a new backend for the JSRE console -func NewJeth(re *jsre.JSRE, client rpc.Client) *Jeth { - return &Jeth{re, client} -} - -// err returns an error object for the given error code and message. -func (self *Jeth) err(call otto.FunctionCall, code int, msg string, id interface{}) (response otto.Value) { - m := rpc.JSONErrResponse{ - Version: "2.0", - Id: id, - Error: rpc.JSONError{ - Code: code, - Message: msg, - }, - } - - errObj, _ := json.Marshal(m.Error) - errRes, _ := json.Marshal(m) - - call.Otto.Run("ret_error = " + string(errObj)) - res, _ := call.Otto.Run("ret_response = " + string(errRes)) - - return res -} - -// UnlockAccount asks the user for the password and than executes the jeth.UnlockAccount callback in the jsre. -// It will need the public address for the account to unlock as first argument. -// The second argument is an optional string with the password. If not given the user is prompted for the password. -// The third argument is an optional integer which specifies for how long the account will be unlocked (in seconds). -func (self *Jeth) UnlockAccount(call otto.FunctionCall) (response otto.Value) { - var account, passwd otto.Value - duration := otto.NullValue() - - if !call.Argument(0).IsString() { - fmt.Println("first argument must be the account to unlock") - return otto.FalseValue() - } - - account = call.Argument(0) - - // if password is not given or as null value -> ask user for password - if call.Argument(1).IsUndefined() || call.Argument(1).IsNull() { - fmt.Printf("Unlock account %s\n", account) - if input, err := Stdin.PasswordPrompt("Passphrase: "); err != nil { - throwJSExeception(err.Error()) - } else { - passwd, _ = otto.ToValue(input) - } - } else { - if !call.Argument(1).IsString() { - throwJSExeception("password must be a string") - } - passwd = call.Argument(1) - } - - // third argument is the duration how long the account must be unlocked. - // verify that its a number. - if call.Argument(2).IsDefined() && !call.Argument(2).IsNull() { - if !call.Argument(2).IsNumber() { - throwJSExeception("unlock duration must be a number") - } - duration = call.Argument(2) - } - - // jeth.unlockAccount will send the request to the backend. - if val, err := call.Otto.Call("jeth.unlockAccount", nil, account, passwd, duration); err == nil { - return val - } else { - throwJSExeception(err.Error()) - } - - return otto.FalseValue() -} - -// NewAccount asks the user for the password and than executes the jeth.newAccount callback in the jsre -func (self *Jeth) NewAccount(call otto.FunctionCall) (response otto.Value) { - var passwd string - if len(call.ArgumentList) == 0 { - var err error - passwd, err = Stdin.PasswordPrompt("Passphrase: ") - if err != nil { - return otto.FalseValue() - } - passwd2, err := Stdin.PasswordPrompt("Repeat passphrase: ") - if err != nil { - return otto.FalseValue() - } - - if passwd != passwd2 { - fmt.Println("Passphrases don't match") - return otto.FalseValue() - } - } else if len(call.ArgumentList) == 1 && call.Argument(0).IsString() { - passwd, _ = call.Argument(0).ToString() - } else { - fmt.Println("expected 0 or 1 string argument") - return otto.FalseValue() - } - - ret, err := call.Otto.Call("jeth.newAccount", nil, passwd) - if err == nil { - return ret - } - fmt.Println(err) - return otto.FalseValue() -} - -// Send will serialize the first argument, send it to the node and returns the response. -func (self *Jeth) Send(call otto.FunctionCall) (response otto.Value) { - // verify we got a batch request (array) or a single request (object) - ro := call.Argument(0).Object() - if ro == nil || (ro.Class() != "Array" && ro.Class() != "Object") { - throwJSExeception("Internal Error: request must be an object or array") - } - - // convert otto vm arguments to go values by JSON serialising and parsing. - data, err := call.Otto.Call("JSON.stringify", nil, ro) - if err != nil { - throwJSExeception(err.Error()) - } - - jsonreq, _ := data.ToString() - - // parse arguments to JSON rpc requests, either to an array (batch) or to a single request. - var reqs []rpc.JSONRequest - batch := true - if err = json.Unmarshal([]byte(jsonreq), &reqs); err != nil { - // single request? - reqs = make([]rpc.JSONRequest, 1) - if err = json.Unmarshal([]byte(jsonreq), &reqs[0]); err != nil { - throwJSExeception("invalid request") - } - batch = false - } - - call.Otto.Set("response_len", len(reqs)) - call.Otto.Run("var ret_response = new Array(response_len);") - - for i, req := range reqs { - if err := self.client.Send(&req); err != nil { - return self.err(call, -32603, err.Error(), req.Id) - } - - result := make(map[string]interface{}) - if err = self.client.Recv(&result); err != nil { - return self.err(call, -32603, err.Error(), req.Id) - } - - id, _ := result["id"] - jsonver, _ := result["jsonrpc"] - - call.Otto.Set("ret_id", id) - call.Otto.Set("ret_jsonrpc", jsonver) - call.Otto.Set("response_idx", i) - - // call was successful - if res, ok := result["result"]; ok { - payload, _ := json.Marshal(res) - call.Otto.Set("ret_result", string(payload)) - response, err = call.Otto.Run(` - ret_response[response_idx] = { jsonrpc: ret_jsonrpc, id: ret_id, result: JSON.parse(ret_result) }; - `) - continue - } - - // request returned an error - if res, ok := result["error"]; ok { - payload, _ := json.Marshal(res) - call.Otto.Set("ret_result", string(payload)) - response, err = call.Otto.Run(` - ret_response[response_idx] = { jsonrpc: ret_jsonrpc, id: ret_id, error: JSON.parse(ret_result) }; - `) - continue - } - - return self.err(call, -32603, fmt.Sprintf("Invalid response"), new(int64)) - } - - if !batch { - call.Otto.Run("ret_response = ret_response[0];") - } - - // if a callback was given execute it. - if call.Argument(1).IsObject() { - call.Otto.Set("callback", call.Argument(1)) - call.Otto.Run(` - if (Object.prototype.toString.call(callback) == '[object Function]') { - callback(null, ret_response); - } - `) - } - - return -} - -// throwJSExeception panics on an otto value, the Otto VM will then throw msg as a javascript error. -func throwJSExeception(msg interface{}) otto.Value { - p, _ := otto.ToValue(msg) - panic(p) -} - -// Sleep will halt the console for arg[0] seconds. -func (self *Jeth) Sleep(call otto.FunctionCall) (response otto.Value) { - if len(call.ArgumentList) >= 1 { - if call.Argument(0).IsNumber() { - sleep, _ := call.Argument(0).ToInteger() - time.Sleep(time.Duration(sleep) * time.Second) - return otto.TrueValue() - } - } - return throwJSExeception("usage: sleep()") -} - -// SleepBlocks will wait for a specified number of new blocks or max for a -// given of seconds. sleepBlocks(nBlocks[, maxSleep]). -func (self *Jeth) SleepBlocks(call otto.FunctionCall) (response otto.Value) { - nBlocks := int64(0) - maxSleep := int64(9999999999999999) // indefinitely - - nArgs := len(call.ArgumentList) - - if nArgs == 0 { - throwJSExeception("usage: sleepBlocks([, max sleep in seconds])") - } - - if nArgs >= 1 { - if call.Argument(0).IsNumber() { - nBlocks, _ = call.Argument(0).ToInteger() - } else { - throwJSExeception("expected number as first argument") - } - } - - if nArgs >= 2 { - if call.Argument(1).IsNumber() { - maxSleep, _ = call.Argument(1).ToInteger() - } else { - throwJSExeception("expected number as second argument") - } - } - - // go through the console, this will allow web3 to call the appropriate - // callbacks if a delayed response or notification is received. - currentBlockNr := func() int64 { - result, err := call.Otto.Run("eth.blockNumber") - if err != nil { - throwJSExeception(err.Error()) - } - blockNr, err := result.ToInteger() - if err != nil { - throwJSExeception(err.Error()) - } - return blockNr - } - - targetBlockNr := currentBlockNr() + nBlocks - deadline := time.Now().Add(time.Duration(maxSleep) * time.Second) - - for time.Now().Before(deadline) { - if currentBlockNr() >= targetBlockNr { - return otto.TrueValue() - } - time.Sleep(time.Second) - } - - return otto.FalseValue() -} diff --git a/common/compiler/solidity.go b/common/compiler/solidity.go index ddf7a1ac96..6a5bfecd86 100644 --- a/common/compiler/solidity.go +++ b/common/compiler/solidity.go @@ -149,7 +149,6 @@ func (sol *Solidity) Compile(source string) (map[string]*Contract, error) { compilerOptions := strings.Join(params, " ") cmd := exec.Command(sol.solcPath, params...) - cmd.Dir = wd cmd.Stdin = strings.NewReader(source) cmd.Stderr = stderr diff --git a/console/bridge.go b/console/bridge.go new file mode 100644 index 0000000000..b23e06837d --- /dev/null +++ b/console/bridge.go @@ -0,0 +1,317 @@ +// Copyright 2015 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package console + +import ( + "encoding/json" + "fmt" + "io" + "time" + + "github.com/ethereum/go-ethereum/logger" + "github.com/ethereum/go-ethereum/logger/glog" + "github.com/ethereum/go-ethereum/rpc" + "github.com/robertkrimen/otto" +) + +// bridge is a collection of JavaScript utility methods to bride the .js runtime +// environment and the Go RPC connection backing the remote method calls. +type bridge struct { + client rpc.Client // RPC client to execute Ethereum requests through + prompter UserPrompter // Input prompter to allow interactive user feedback + printer io.Writer // Output writer to serialize any display strings to +} + +// newBridge creates a new JavaScript wrapper around an RPC client. +func newBridge(client rpc.Client, prompter UserPrompter, printer io.Writer) *bridge { + return &bridge{ + client: client, + prompter: prompter, + printer: printer, + } +} + +// NewAccount is a wrapper around the personal.newAccount RPC method that uses a +// non-echoing password prompt to aquire the passphrase and executes the original +// RPC method (saved in jeth.newAccount) with it to actually execute the RPC call. +func (b *bridge) NewAccount(call otto.FunctionCall) (response otto.Value) { + var ( + password string + confirm string + err error + ) + switch { + // No password was specified, prompt the user for it + case len(call.ArgumentList) == 0: + if password, err = b.prompter.PromptPassword("Passphrase: "); err != nil { + throwJSException(err.Error()) + } + if confirm, err = b.prompter.PromptPassword("Repeat passphrase: "); err != nil { + throwJSException(err.Error()) + } + if password != confirm { + throwJSException("passphrases don't match!") + } + + // A single string password was specified, use that + case len(call.ArgumentList) == 1 && call.Argument(0).IsString(): + password, _ = call.Argument(0).ToString() + + // Otherwise fail with some error + default: + throwJSException("expected 0 or 1 string argument") + } + // Password aquired, execute the call and return + ret, err := call.Otto.Call("jeth.newAccount", nil, password) + if err != nil { + throwJSException(err.Error()) + } + return ret +} + +// UnlockAccount is a wrapper around the personal.unlockAccount RPC method that +// uses a non-echoing password prompt to aquire the passphrase and executes the +// original RPC method (saved in jeth.unlockAccount) with it to actually execute +// the RPC call. +func (b *bridge) UnlockAccount(call otto.FunctionCall) (response otto.Value) { + // Make sure we have an account specified to unlock + if !call.Argument(0).IsString() { + throwJSException("first argument must be the account to unlock") + } + account := call.Argument(0) + + // If password is not given or is the null value, prompt the user for it + var passwd otto.Value + + if call.Argument(1).IsUndefined() || call.Argument(1).IsNull() { + fmt.Fprintf(b.printer, "Unlock account %s\n", account) + if input, err := b.prompter.PromptPassword("Passphrase: "); err != nil { + throwJSException(err.Error()) + } else { + passwd, _ = otto.ToValue(input) + } + } else { + if !call.Argument(1).IsString() { + throwJSException("password must be a string") + } + passwd = call.Argument(1) + } + // Third argument is the duration how long the account must be unlocked. + duration := otto.NullValue() + if call.Argument(2).IsDefined() && !call.Argument(2).IsNull() { + if !call.Argument(2).IsNumber() { + throwJSException("unlock duration must be a number") + } + duration = call.Argument(2) + } + // Send the request to the backend and return + val, err := call.Otto.Call("jeth.unlockAccount", nil, account, passwd, duration) + if err != nil { + throwJSException(err.Error()) + } + return val +} + +// Sleep will block the console for the specified number of seconds. +func (b *bridge) Sleep(call otto.FunctionCall) (response otto.Value) { + if call.Argument(0).IsNumber() { + sleep, _ := call.Argument(0).ToInteger() + time.Sleep(time.Duration(sleep) * time.Second) + return otto.TrueValue() + } + return throwJSException("usage: sleep()") +} + +// SleepBlocks will block the console for a specified number of new blocks optionally +// until the given timeout is reached. +func (b *bridge) SleepBlocks(call otto.FunctionCall) (response otto.Value) { + var ( + blocks = int64(0) + sleep = int64(9999999999999999) // indefinitely + ) + // Parse the input parameters for the sleep + nArgs := len(call.ArgumentList) + if nArgs == 0 { + throwJSException("usage: sleepBlocks([, max sleep in seconds])") + } + if nArgs >= 1 { + if call.Argument(0).IsNumber() { + blocks, _ = call.Argument(0).ToInteger() + } else { + throwJSException("expected number as first argument") + } + } + if nArgs >= 2 { + if call.Argument(1).IsNumber() { + sleep, _ = call.Argument(1).ToInteger() + } else { + throwJSException("expected number as second argument") + } + } + // go through the console, this will allow web3 to call the appropriate + // callbacks if a delayed response or notification is received. + blockNumber := func() int64 { + result, err := call.Otto.Run("eth.blockNumber") + if err != nil { + throwJSException(err.Error()) + } + block, err := result.ToInteger() + if err != nil { + throwJSException(err.Error()) + } + return block + } + // Poll the current block number until either it ot a timeout is reached + targetBlockNr := blockNumber() + blocks + deadline := time.Now().Add(time.Duration(sleep) * time.Second) + + for time.Now().Before(deadline) { + if blockNumber() >= targetBlockNr { + return otto.TrueValue() + } + time.Sleep(time.Second) + } + return otto.FalseValue() +} + +// Send will serialize the first argument, send it to the node and returns the response. +func (b *bridge) Send(call otto.FunctionCall) (response otto.Value) { + // Ensure that we've got a batch request (array) or a single request (object) + arg := call.Argument(0).Object() + if arg == nil || (arg.Class() != "Array" && arg.Class() != "Object") { + throwJSException("request must be an object or array") + } + // Convert the otto VM arguments to Go values + data, err := call.Otto.Call("JSON.stringify", nil, arg) + if err != nil { + throwJSException(err.Error()) + } + reqjson, err := data.ToString() + if err != nil { + throwJSException(err.Error()) + } + + var ( + reqs []rpc.JSONRequest + batch = true + ) + if err = json.Unmarshal([]byte(reqjson), &reqs); err != nil { + // single request? + reqs = make([]rpc.JSONRequest, 1) + if err = json.Unmarshal([]byte(reqjson), &reqs[0]); err != nil { + throwJSException("invalid request") + } + batch = false + } + // Iteratively execute the requests + call.Otto.Set("response_len", len(reqs)) + call.Otto.Run("var ret_response = new Array(response_len);") + + for i, req := range reqs { + // Execute the RPC request and parse the reply + if err = b.client.Send(&req); err != nil { + return newErrorResponse(call, -32603, err.Error(), req.Id) + } + result := make(map[string]interface{}) + if err = b.client.Recv(&result); err != nil { + return newErrorResponse(call, -32603, err.Error(), req.Id) + } + // Feed the reply back into the JavaScript runtime environment + id, _ := result["id"] + jsonver, _ := result["jsonrpc"] + + call.Otto.Set("ret_id", id) + call.Otto.Set("ret_jsonrpc", jsonver) + call.Otto.Set("response_idx", i) + + if res, ok := result["result"]; ok { + payload, _ := json.Marshal(res) + call.Otto.Set("ret_result", string(payload)) + response, err = call.Otto.Run(` + ret_response[response_idx] = { jsonrpc: ret_jsonrpc, id: ret_id, result: JSON.parse(ret_result) }; + `) + continue + } + if res, ok := result["error"]; ok { + payload, _ := json.Marshal(res) + call.Otto.Set("ret_result", string(payload)) + response, err = call.Otto.Run(` + ret_response[response_idx] = { jsonrpc: ret_jsonrpc, id: ret_id, error: JSON.parse(ret_result) }; + `) + continue + } + return newErrorResponse(call, -32603, fmt.Sprintf("Invalid response"), new(int64)) + } + // Convert single requests back from batch ones + if !batch { + call.Otto.Run("ret_response = ret_response[0];") + } + // Execute any registered callbacks + if call.Argument(1).IsObject() { + call.Otto.Set("callback", call.Argument(1)) + call.Otto.Run(` + if (Object.prototype.toString.call(callback) == '[object Function]') { + callback(null, ret_response); + } + `) + } + return +} + +// throwJSException panics on an otto.Value. The Otto VM will recover from the +// Go panic and throw msg as a JavaScript error. +func throwJSException(msg interface{}) otto.Value { + val, err := otto.ToValue(msg) + if err != nil { + glog.V(logger.Error).Infof("Failed to serialize JavaScript exception %v: %v", msg, err) + } + panic(val) +} + +// newErrorResponse creates a JSON RPC error response for a specific request id, +// containing the specified error code and error message. Beside returning the +// error to the caller, it also sets the ret_error and ret_response JavaScript +// variables. +func newErrorResponse(call otto.FunctionCall, code int, msg string, id interface{}) (response otto.Value) { + // Bundle the error into a JSON RPC call response + res := rpc.JSONErrResponse{ + Version: rpc.JSONRPCVersion, + Id: id, + Error: rpc.JSONError{ + Code: code, + Message: msg, + }, + } + // Serialize the error response into JavaScript variables + errObj, err := json.Marshal(res.Error) + if err != nil { + glog.V(logger.Error).Infof("Failed to serialize JSON RPC error: %v", err) + } + resObj, err := json.Marshal(res) + if err != nil { + glog.V(logger.Error).Infof("Failed to serialize JSON RPC error response: %v", err) + } + + if _, err = call.Otto.Run("ret_error = " + string(errObj)); err != nil { + glog.V(logger.Error).Infof("Failed to set `ret_error` to the occurred error: %v", err) + } + resVal, err := call.Otto.Run("ret_response = " + string(resObj)) + if err != nil { + glog.V(logger.Error).Infof("Failed to set `ret_response` to the JSON RPC response: %v", err) + } + return resVal +} diff --git a/console/console.go b/console/console.go new file mode 100644 index 0000000000..baa9cf5457 --- /dev/null +++ b/console/console.go @@ -0,0 +1,371 @@ +// Copyright 2015 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package console + +import ( + "fmt" + "io" + "io/ioutil" + "os" + "os/signal" + "path/filepath" + "regexp" + "sort" + "strings" + + "github.com/ethereum/go-ethereum/internal/jsre" + "github.com/ethereum/go-ethereum/internal/web3ext" + "github.com/ethereum/go-ethereum/rpc" + "github.com/peterh/liner" + "github.com/robertkrimen/otto" +) + +var ( + passwordRegexp = regexp.MustCompile("personal.[nus]") + onlyWhitespace = regexp.MustCompile("^\\s*$") + exit = regexp.MustCompile("^\\s*exit\\s*;*\\s*$") +) + +// HistoryFile is the file within the data directory to store input scrollback. +const HistoryFile = "history" + +// DefaultPrompt is the default prompt line prefix to use for user input querying. +const DefaultPrompt = "> " + +// Config is te collection of configurations to fine tune the behavior of the +// JavaScript console. +type Config struct { + DataDir string // Data directory to store the console history at + DocRoot string // Filesystem path from where to load JavaScript files from + Client rpc.Client // RPC client to execute Ethereum requests through + Prompt string // Input prompt prefix string (defaults to DefaultPrompt) + Prompter UserPrompter // Input prompter to allow interactive user feedback (defaults to TerminalPrompter) + Printer io.Writer // Output writer to serialize any display strings to (defaults to os.Stdout) + Preload []string // Absolute paths to JavaScript files to preload +} + +// Console is a JavaScript interpreted runtime environment. It is a fully fleged +// JavaScript console attached to a running node via an external or in-process RPC +// client. +type Console struct { + client rpc.Client // RPC client to execute Ethereum requests through + jsre *jsre.JSRE // JavaScript runtime environment running the interpreter + prompt string // Input prompt prefix string + prompter UserPrompter // Input prompter to allow interactive user feedback + histPath string // Absolute path to the console scrollback history + history []string // Scroll history maintained by the console + printer io.Writer // Output writer to serialize any display strings to +} + +func New(config Config) (*Console, error) { + // Handle unset config values gracefully + if config.Prompter == nil { + config.Prompter = Stdin + } + if config.Prompt == "" { + config.Prompt = DefaultPrompt + } + if config.Printer == nil { + config.Printer = os.Stdout + } + // Initialize the console and return + console := &Console{ + client: config.Client, + jsre: jsre.New(config.DocRoot, config.Printer), + prompt: config.Prompt, + prompter: config.Prompter, + printer: config.Printer, + histPath: filepath.Join(config.DataDir, HistoryFile), + } + if err := console.init(config.Preload); err != nil { + return nil, err + } + return console, nil +} + +// init retrieves the available APIs from the remote RPC provider and initializes +// the console's JavaScript namespaces based on the exposed modules. +func (c *Console) init(preload []string) error { + // Initialize the JavaScript <-> Go RPC bridge + bridge := newBridge(c.client, c.prompter, c.printer) + c.jsre.Set("jeth", struct{}{}) + + jethObj, _ := c.jsre.Get("jeth") + jethObj.Object().Set("send", bridge.Send) + jethObj.Object().Set("sendAsync", bridge.Send) + + consoleObj, _ := c.jsre.Get("console") + consoleObj.Object().Set("log", c.consoleOutput) + consoleObj.Object().Set("error", c.consoleOutput) + + // Load all the internal utility JavaScript libraries + if err := c.jsre.Compile("bignumber.js", jsre.BigNumber_JS); err != nil { + return fmt.Errorf("bignumber.js: %v", err) + } + if err := c.jsre.Compile("web3.js", jsre.Web3_JS); err != nil { + return fmt.Errorf("web3.js: %v", err) + } + if _, err := c.jsre.Run("var Web3 = require('web3');"); err != nil { + return fmt.Errorf("web3 require: %v", err) + } + if _, err := c.jsre.Run("var web3 = new Web3(jeth);"); err != nil { + return fmt.Errorf("web3 provider: %v", err) + } + // Load the supported APIs into the JavaScript runtime environment + apis, err := c.client.SupportedModules() + if err != nil { + return fmt.Errorf("api modules: %v", err) + } + flatten := "var eth = web3.eth; var personal = web3.personal; " + for api := range apis { + if api == "web3" { + continue // manually mapped or ignore + } + if file, ok := web3ext.Modules[api]; ok { + if err = c.jsre.Compile(fmt.Sprintf("%s.js", api), file); err != nil { + return fmt.Errorf("%s.js: %v", api, err) + } + flatten += fmt.Sprintf("var %s = web3.%s; ", api, api) + } + } + if _, err = c.jsre.Run(flatten); err != nil { + return fmt.Errorf("namespace flattening: %v", err) + } + // Initialize the global name register (disabled for now) + //c.jsre.Run(`var GlobalRegistrar = eth.contract(` + registrar.GlobalRegistrarAbi + `); registrar = GlobalRegistrar.at("` + registrar.GlobalRegistrarAddr + `");`) + + // If the console is in interactive mode, instrument password related methods to query the user + if c.prompter != nil { + // Retrieve the account management object to instrument + personal, err := c.jsre.Get("personal") + if err != nil { + return err + } + // Override the unlockAccount and newAccount methods since these require user interaction. + // Assign the jeth.unlockAccount and jeth.newAccount in the Console the original web3 callbacks. + // These will be called by the jeth.* methods after they got the password from the user and send + // the original web3 request to the backend. + if obj := personal.Object(); obj != nil { // make sure the personal api is enabled over the interface + if _, err = c.jsre.Run(`jeth.unlockAccount = personal.unlockAccount;`); err != nil { + return fmt.Errorf("personal.unlockAccount: %v", err) + } + if _, err = c.jsre.Run(`jeth.newAccount = personal.newAccount;`); err != nil { + return fmt.Errorf("personal.newAccount: %v", err) + } + obj.Set("unlockAccount", bridge.UnlockAccount) + obj.Set("newAccount", bridge.NewAccount) + } + } + // The admin.sleep and admin.sleepBlocks are offered by the console and not by the RPC layer. + admin, err := c.jsre.Get("admin") + if err != nil { + return err + } + if obj := admin.Object(); obj != nil { // make sure the admin api is enabled over the interface + obj.Set("sleepBlocks", bridge.SleepBlocks) + obj.Set("sleep", bridge.Sleep) + } + // Preload any JavaScript files before starting the console + for _, path := range preload { + if err := c.jsre.Exec(path); err != nil { + failure := err.Error() + if ottoErr, ok := err.(*otto.Error); ok { + failure = ottoErr.String() + } + return fmt.Errorf("%s: %v", path, failure) + } + } + // Configure the console's input prompter for scrollback and tab completion + if c.prompter != nil { + if content, err := ioutil.ReadFile(c.histPath); err != nil { + c.prompter.SetHistory(nil) + } else { + c.history = strings.Split(string(content), "\n") + c.prompter.SetHistory(c.history) + } + c.prompter.SetWordCompleter(c.AutoCompleteInput) + } + return nil +} + +// consoleOutput is an override for the console.log and console.error methods to +// stream the output into the configured output stream instead of stdout. +func (c *Console) consoleOutput(call otto.FunctionCall) otto.Value { + output := []string{} + for _, argument := range call.ArgumentList { + output = append(output, fmt.Sprintf("%v", argument)) + } + fmt.Fprintln(c.printer, strings.Join(output, " ")) + return otto.Value{} +} + +// AutoCompleteInput is a pre-assembled word completer to be used by the user +// input prompter to provide hints to the user about the methods available. +func (c *Console) AutoCompleteInput(line string, pos int) (string, []string, string) { + // No completions can be provided for empty inputs + if len(line) == 0 || pos == 0 { + return "", nil, "" + } + // Chunck data to relevant part for autocompletion + // E.g. in case of nested lines eth.getBalance(eth.coinb + start := 0 + for start = pos - 1; start > 0; start-- { + // Skip all methods and namespaces (i.e. including te dot) + if line[start] == '.' || (line[start] >= 'a' && line[start] <= 'z') || (line[start] >= 'A' && line[start] <= 'Z') { + continue + } + // Handle web3 in a special way (i.e. other numbers aren't auto completed) + if start >= 3 && line[start-3:start] == "web3" { + start -= 3 + continue + } + // We've hit an unexpected character, autocomplete form here + start++ + break + } + return line[:start], c.jsre.CompleteKeywords(line[start:pos]), line[pos:] +} + +// Welcome show summary of current Geth instance and some metadata about the +// console's available modules. +func (c *Console) Welcome() { + // Print some generic Geth metadata + fmt.Fprintf(c.printer, "Welcome to the Geth JavaScript console!\n\n") + c.jsre.Run(` + console.log("instance: " + web3.version.node); + console.log("coinbase: " + eth.coinbase); + console.log("at block: " + eth.blockNumber + " (" + new Date(1000 * eth.getBlock(eth.blockNumber).timestamp) + ")"); + console.log(" datadir: " + admin.datadir); + `) + // List all the supported modules for the user to call + if apis, err := c.client.SupportedModules(); err == nil { + modules := make([]string, 0, len(apis)) + for api, version := range apis { + modules = append(modules, fmt.Sprintf("%s:%s", api, version)) + } + sort.Strings(modules) + fmt.Fprintln(c.printer, " modules:", strings.Join(modules, " ")) + } + fmt.Fprintln(c.printer) +} + +// Evaluate executes code and pretty prints the result to the specified output +// stream. +func (c *Console) Evaluate(statement string) error { + defer func() { + if r := recover(); r != nil { + fmt.Fprintf(c.printer, "[native] error: %v\n", r) + } + }() + if err := c.jsre.Evaluate(statement, c.printer); err != nil { + return err + } + return nil +} + +// Interactive starts an interactive user session, where input is propted from +// the configured user prompter. +func (c *Console) Interactive() { + var ( + prompt = c.prompt // Current prompt line (used for multi-line inputs) + indents = 0 // Current number of input indents (used for multi-line inputs) + input = "" // Current user input + scheduler = make(chan string) // Channel to send the next prompt on and receive the input + ) + // Start a goroutine to listen for promt requests and send back inputs + go func() { + for { + // Read the next user input + line, err := c.prompter.PromptInput(<-scheduler) + if err != nil { + // In case of an error, either clear the prompt or fail + if err == liner.ErrPromptAborted { // ctrl-C + prompt, indents, input = c.prompt, 0, "" + scheduler <- "" + continue + } + close(scheduler) + return + } + // User input retrieved, send for interpretation and loop + scheduler <- line + } + }() + // Monitor Ctrl-C too in case the input is empty and we need to bail + abort := make(chan os.Signal, 1) + signal.Notify(abort, os.Interrupt) + + // Start sending prompts to the user and reading back inputs + for { + // Send the next prompt, triggering an input read and process the result + scheduler <- prompt + select { + case <-abort: + // User forcefully quite the console + fmt.Fprintln(c.printer, "caught interrupt, exiting") + return + + case line, ok := <-scheduler: + // User input was returned by the prompter, handle special cases + if !ok || (indents <= 0 && exit.MatchString(line)) { + return + } + if onlyWhitespace.MatchString(line) { + continue + } + // Append the line to the input and check for multi-line interpretation + input += line + "\n" + + indents = strings.Count(input, "{") + strings.Count(input, "(") - strings.Count(input, "}") - strings.Count(input, ")") + if indents <= 0 { + prompt = c.prompt + } else { + prompt = strings.Repeat("..", indents*2) + " " + } + // If all the needed lines are present, save the command and run + if indents <= 0 { + if len(input) > 0 && input[0] != ' ' && !passwordRegexp.MatchString(input) { + if command := strings.TrimSpace(input); len(c.history) == 0 || command != c.history[len(c.history)-1] { + c.history = append(c.history, command) + if c.prompter != nil { + c.prompter.AppendHistory(command) + } + } + } + c.Evaluate(input) + input = "" + } + } + } +} + +// Execute runs the JavaScript file specified as the argument. +func (c *Console) Execute(path string) error { + return c.jsre.Exec(path) +} + +// Stop cleans up the console and terminates the runtime envorinment. +func (c *Console) Stop(graceful bool) error { + if err := ioutil.WriteFile(c.histPath, []byte(strings.Join(c.history, "\n")), 0600); err != nil { + return err + } + if err := os.Chmod(c.histPath, 0600); err != nil { // Force 0600, even if it was different previously + return err + } + c.jsre.Stop(graceful) + return nil +} diff --git a/console/console_test.go b/console/console_test.go new file mode 100644 index 0000000000..9110878242 --- /dev/null +++ b/console/console_test.go @@ -0,0 +1,296 @@ +// Copyright 2015 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package console + +import ( + "bytes" + "errors" + "fmt" + "io/ioutil" + "math/big" + "os" + "path/filepath" + "strings" + "testing" + "time" + + "github.com/ethereum/go-ethereum/accounts" + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core" + "github.com/ethereum/go-ethereum/eth" + "github.com/ethereum/go-ethereum/internal/jsre" + "github.com/ethereum/go-ethereum/node" +) + +const ( + testInstance = "console-tester" + testAddress = "0x8605cdbbdb6d264aa742e77020dcbc58fcdce182" +) + +// hookedPrompter implements UserPrompter to simulate use input via channels. +type hookedPrompter struct { + scheduler chan string +} + +func (p *hookedPrompter) PromptInput(prompt string) (string, error) { + // Send the prompt to the tester + select { + case p.scheduler <- prompt: + case <-time.After(time.Second): + return "", errors.New("prompt timeout") + } + // Retrieve the response and feed to the console + select { + case input := <-p.scheduler: + return input, nil + case <-time.After(time.Second): + return "", errors.New("input timeout") + } +} + +func (p *hookedPrompter) PromptPassword(prompt string) (string, error) { + return "", errors.New("not implemented") +} +func (p *hookedPrompter) PromptConfirm(prompt string) (bool, error) { + return false, errors.New("not implemented") +} +func (p *hookedPrompter) SetHistory(history []string) {} +func (p *hookedPrompter) AppendHistory(command string) {} +func (p *hookedPrompter) SetWordCompleter(completer WordCompleter) {} + +// tester is a console test environment for the console tests to operate on. +type tester struct { + workspace string + stack *node.Node + ethereum *eth.Ethereum + console *Console + input *hookedPrompter + output *bytes.Buffer + + lastConfirm string +} + +// newTester creates a test environment based on which the console can operate. +// Please ensure you call Close() on the returned tester to avoid leaks. +func newTester(t *testing.T, confOverride func(*eth.Config)) *tester { + // Create a temporary storage for the node keys and initialize it + workspace, err := ioutil.TempDir("", "console-tester-") + if err != nil { + t.Fatalf("failed to create temporary keystore: %v", err) + } + accman := accounts.NewPlaintextManager(filepath.Join(workspace, "keystore")) + + // Create a networkless protocol stack and start an Ethereum service within + stack, err := node.New(&node.Config{DataDir: workspace, Name: testInstance, NoDiscovery: true}) + if err != nil { + t.Fatalf("failed to create node: %v", err) + } + ethConf := ð.Config{ + ChainConfig: &core.ChainConfig{HomesteadBlock: new(big.Int)}, + Etherbase: common.HexToAddress(testAddress), + AccountManager: accman, + PowTest: true, + } + if confOverride != nil { + confOverride(ethConf) + } + if err = stack.Register(func(ctx *node.ServiceContext) (node.Service, error) { return eth.New(ctx, ethConf) }); err != nil { + t.Fatalf("failed to register Ethereum protocol: %v", err) + } + // Start the node and assemble the JavaScript console around it + if err = stack.Start(); err != nil { + t.Fatalf("failed to start test stack: %v", err) + } + client, err := stack.Attach() + if err != nil { + t.Fatalf("failed to attach to node: %v", err) + } + prompter := &hookedPrompter{scheduler: make(chan string)} + printer := new(bytes.Buffer) + + console, err := New(Config{ + DataDir: stack.DataDir(), + DocRoot: "testdata", + Client: client, + Prompter: prompter, + Printer: printer, + Preload: []string{"preload.js"}, + }) + if err != nil { + t.Fatalf("failed to create JavaScript console: %v", err) + } + // Create the final tester and return + var ethereum *eth.Ethereum + stack.Service(ðereum) + + return &tester{ + workspace: workspace, + stack: stack, + ethereum: ethereum, + console: console, + input: prompter, + output: printer, + } +} + +// Close cleans up any temporary data folders and held resources. +func (env *tester) Close(t *testing.T) { + if err := env.console.Stop(false); err != nil { + t.Errorf("failed to stop embedded console: %v", err) + } + if err := env.stack.Stop(); err != nil { + t.Errorf("failed to stop embedded node: %v", err) + } + os.RemoveAll(env.workspace) +} + +// Tests that the node lists the correct welcome message, notably that it contains +// the instance name, coinbase account, block number, data directory and supported +// console modules. +func TestWelcome(t *testing.T) { + tester := newTester(t, nil) + defer tester.Close(t) + + tester.console.Welcome() + + output := string(tester.output.Bytes()) + if want := "Welcome"; !strings.Contains(output, want) { + t.Fatalf("console output missing welcome message: have\n%s\nwant also %s", output, want) + } + if want := fmt.Sprintf("instance: %s", testInstance); !strings.Contains(output, want) { + t.Fatalf("console output missing instance: have\n%s\nwant also %s", output, want) + } + if want := fmt.Sprintf("coinbase: %s", testAddress); !strings.Contains(output, want) { + t.Fatalf("console output missing coinbase: have\n%s\nwant also %s", output, want) + } + if want := "at block: 0"; !strings.Contains(output, want) { + t.Fatalf("console output missing sync status: have\n%s\nwant also %s", output, want) + } + if want := fmt.Sprintf("datadir: %s", tester.workspace); !strings.Contains(output, want) { + t.Fatalf("console output missing coinbase: have\n%s\nwant also %s", output, want) + } +} + +// Tests that JavaScript statement evaluation works as intended. +func TestEvaluate(t *testing.T) { + tester := newTester(t, nil) + defer tester.Close(t) + + tester.console.Evaluate("2 + 2") + if output := string(tester.output.Bytes()); !strings.Contains(output, "4") { + t.Fatalf("statement evaluation failed: have %s, want %s", output, "4") + } +} + +// Tests that the console can be used in interactive mode. +func TestInteractive(t *testing.T) { + // Create a tester and run an interactive console in the background + tester := newTester(t, nil) + defer tester.Close(t) + + go tester.console.Interactive() + + // Wait for a promt and send a statement back + select { + case <-tester.input.scheduler: + case <-time.After(time.Second): + t.Fatalf("initial prompt timeout") + } + select { + case tester.input.scheduler <- "2+2": + case <-time.After(time.Second): + t.Fatalf("input feedback timeout") + } + // Wait for the second promt and ensure first statement was evaluated + select { + case <-tester.input.scheduler: + case <-time.After(time.Second): + t.Fatalf("secondary prompt timeout") + } + if output := string(tester.output.Bytes()); !strings.Contains(output, "4") { + t.Fatalf("statement evaluation failed: have %s, want %s", output, "4") + } +} + +// Tests that preloaded JavaScript files have been executed before user is given +// input. +func TestPreload(t *testing.T) { + tester := newTester(t, nil) + defer tester.Close(t) + + tester.console.Evaluate("preloaded") + if output := string(tester.output.Bytes()); !strings.Contains(output, "some-preloaded-string") { + t.Fatalf("preloaded variable missing: have %s, want %s", output, "some-preloaded-string") + } +} + +// Tests that JavaScript scripts can be executes from the configured asset path. +func TestExecute(t *testing.T) { + tester := newTester(t, nil) + defer tester.Close(t) + + tester.console.Execute("exec.js") + + tester.console.Evaluate("execed") + if output := string(tester.output.Bytes()); !strings.Contains(output, "some-executed-string") { + t.Fatalf("execed variable missing: have %s, want %s", output, "some-executed-string") + } +} + +// Tests that the JavaScript objects returned by statement executions are properly +// pretty printed instead of just displaing "[object]". +func TestPrettyPrint(t *testing.T) { + tester := newTester(t, nil) + defer tester.Close(t) + + tester.console.Evaluate("obj = {int: 1, string: 'two', list: [3, 3, 3], obj: {null: null, func: function(){}}}") + + // Define some specially formatted fields + var ( + one = jsre.NumberColor("1") + two = jsre.StringColor("\"two\"") + three = jsre.NumberColor("3") + null = jsre.SpecialColor("null") + fun = jsre.FunctionColor("function()") + ) + // Assemble the actual output we're after and verify + want := `{ + int: ` + one + `, + list: [` + three + `, ` + three + `, ` + three + `], + obj: { + null: ` + null + `, + func: ` + fun + ` + }, + string: ` + two + ` +} +` + if output := string(tester.output.Bytes()); output != want { + t.Fatalf("pretty print mismatch: have %s, want %s", output, want) + } +} + +// Tests that the JavaScript exceptions are properly formatted and colored. +func TestPrettyError(t *testing.T) { + tester := newTester(t, nil) + defer tester.Close(t) + tester.console.Evaluate("throw 'hello'") + + want := jsre.ErrorColor("hello") + "\n" + if output := string(tester.output.Bytes()); output != want { + t.Fatalf("pretty error mismatch: have %s, want %s", output, want) + } +} diff --git a/console/prompter.go b/console/prompter.go new file mode 100644 index 0000000000..0e4a8a53ec --- /dev/null +++ b/console/prompter.go @@ -0,0 +1,165 @@ +// Copyright 2016 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package console + +import ( + "fmt" + "strings" + + "github.com/peterh/liner" +) + +// Stdin holds the stdin line reader (also using stdout for printing prompts). +// Only this reader may be used for input because it keeps an internal buffer. +var Stdin = newTerminalPrompter() + +// UserPrompter defines the methods needed by the console to promt the user for +// various types of inputs. +type UserPrompter interface { + // PromptInput displays the given prompt to the user and requests some textual + // data to be entered, returning the input of the user. + PromptInput(prompt string) (string, error) + + // PromptPassword displays the given prompt to the user and requests some textual + // data to be entered, but one which must not be echoed out into the terminal. + // The method returns the input provided by the user. + PromptPassword(prompt string) (string, error) + + // PromptConfirm displays the given prompt to the user and requests a boolean + // choice to be made, returning that choice. + PromptConfirm(prompt string) (bool, error) + + // SetHistory sets the the input scrollback history that the prompter will allow + // the user to scoll back to. + SetHistory(history []string) + + // AppendHistory appends an entry to the scrollback history. It should be called + // if and only if the prompt to append was a valid command. + AppendHistory(command string) + + // SetWordCompleter sets the completion function that the prompter will call to + // fetch completion candidates when the user presses tab. + SetWordCompleter(completer WordCompleter) +} + +// WordCompleter takes the currently edited line with the cursor position and +// returns the completion candidates for the partial word to be completed. If +// the line is "Hello, wo!!!" and the cursor is before the first '!', ("Hello, +// wo!!!", 9) is passed to the completer which may returns ("Hello, ", {"world", +// "Word"}, "!!!") to have "Hello, world!!!". +type WordCompleter func(line string, pos int) (string, []string, string) + +// terminalPrompter is a UserPrompter backed by the liner package. It supports +// prompting the user for various input, among others for non-echoing password +// input. +type terminalPrompter struct { + *liner.State + warned bool + supported bool + normalMode liner.ModeApplier + rawMode liner.ModeApplier +} + +// newTerminalPrompter creates a liner based user input prompter working off the +// standard input and output streams. +func newTerminalPrompter() *terminalPrompter { + p := new(terminalPrompter) + // Get the original mode before calling NewLiner. + // This is usually regular "cooked" mode where characters echo. + normalMode, _ := liner.TerminalMode() + // Turn on liner. It switches to raw mode. + p.State = liner.NewLiner() + rawMode, err := liner.TerminalMode() + if err != nil || !liner.TerminalSupported() { + p.supported = false + } else { + p.supported = true + p.normalMode = normalMode + p.rawMode = rawMode + // Switch back to normal mode while we're not prompting. + normalMode.ApplyMode() + } + p.SetCtrlCAborts(true) + p.SetTabCompletionStyle(liner.TabPrints) + + return p +} + +// PromptInput displays the given prompt to the user and requests some textual +// data to be entered, returning the input of the user. +func (p *terminalPrompter) PromptInput(prompt string) (string, error) { + if p.supported { + p.rawMode.ApplyMode() + defer p.normalMode.ApplyMode() + } else { + // liner tries to be smart about printing the prompt + // and doesn't print anything if input is redirected. + // Un-smart it by printing the prompt always. + fmt.Print(prompt) + prompt = "" + defer fmt.Println() + } + return p.State.Prompt(prompt) +} + +// PromptPassword displays the given prompt to the user and requests some textual +// data to be entered, but one which must not be echoed out into the terminal. +// The method returns the input provided by the user. +func (p *terminalPrompter) PromptPassword(prompt string) (passwd string, err error) { + if p.supported { + p.rawMode.ApplyMode() + defer p.normalMode.ApplyMode() + return p.State.PasswordPrompt(prompt) + } + if !p.warned { + fmt.Println("!! Unsupported terminal, password will be echoed.") + p.warned = true + } + // Just as in Prompt, handle printing the prompt here instead of relying on liner. + fmt.Print(prompt) + passwd, err = p.State.Prompt("") + fmt.Println() + return passwd, err +} + +// PromptConfirm displays the given prompt to the user and requests a boolean +// choice to be made, returning that choice. +func (p *terminalPrompter) PromptConfirm(prompt string) (bool, error) { + input, err := p.Prompt(prompt + " [y/N] ") + if len(input) > 0 && strings.ToUpper(input[:1]) == "Y" { + return true, nil + } + return false, err +} + +// SetHistory sets the the input scrollback history that the prompter will allow +// the user to scoll back to. +func (p *terminalPrompter) SetHistory(history []string) { + p.State.ReadHistory(strings.NewReader(strings.Join(history, "\n"))) +} + +// AppendHistory appends an entry to the scrollback history. It should be called +// if and only if the prompt to append was a valid command. +func (p *terminalPrompter) AppendHistory(command string) { + p.State.AppendHistory(command) +} + +// SetWordCompleter sets the completion function that the prompter will call to +// fetch completion candidates when the user presses tab. +func (p *terminalPrompter) SetWordCompleter(completer WordCompleter) { + p.State.SetWordCompleter(liner.WordCompleter(completer)) +} diff --git a/console/testdata/exec.js b/console/testdata/exec.js new file mode 100644 index 0000000000..59e34d7c40 --- /dev/null +++ b/console/testdata/exec.js @@ -0,0 +1 @@ +var execed = "some-executed-string"; diff --git a/console/testdata/preload.js b/console/testdata/preload.js new file mode 100644 index 0000000000..556793970f --- /dev/null +++ b/console/testdata/preload.js @@ -0,0 +1 @@ +var preloaded = "some-preloaded-string"; diff --git a/core/bad_block.go b/core/bad_block.go deleted file mode 100644 index cd3fb575a8..0000000000 --- a/core/bad_block.go +++ /dev/null @@ -1,72 +0,0 @@ -// Copyright 2015 The go-ethereum Authors -// This file is part of the go-ethereum library. -// -// The go-ethereum library is free software: you can redistribute it and/or modify -// it under the terms of the GNU Lesser General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// The go-ethereum library is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Lesser General Public License for more details. -// -// You should have received a copy of the GNU Lesser General Public License -// along with the go-ethereum library. If not, see . - -package core - -import ( - "bytes" - "encoding/json" - "io/ioutil" - "net/http" - - "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/core/types" - "github.com/ethereum/go-ethereum/logger" - "github.com/ethereum/go-ethereum/logger/glog" - "github.com/ethereum/go-ethereum/rlp" -) - -// DisabledBadBlockReporting can be set to prevent blocks being reported. -var DisableBadBlockReporting = true - -// ReportBlock reports the block to the block reporting tool found at -// badblocks.ethdev.com -func ReportBlock(block *types.Block, err error) { - if DisableBadBlockReporting { - return - } - - const url = "https://badblocks.ethdev.com" - - blockRlp, _ := rlp.EncodeToBytes(block) - data := map[string]interface{}{ - "block": common.Bytes2Hex(blockRlp), - "errortype": err.Error(), - "hints": map[string]interface{}{ - "receipts": "NYI", - "vmtrace": "NYI", - }, - } - jsonStr, _ := json.Marshal(map[string]interface{}{"method": "eth_badBlock", "params": []interface{}{data}, "id": "1", "jsonrpc": "2.0"}) - - req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonStr)) - req.Header.Set("Content-Type", "application/json") - - client := &http.Client{} - resp, err := client.Do(req) - if err != nil { - glog.V(logger.Error).Infoln("POST err:", err) - return - } - defer resp.Body.Close() - - if glog.V(logger.Debug) { - glog.Infoln("response Status:", resp.Status) - glog.Infoln("response Headers:", resp.Header) - body, _ := ioutil.ReadAll(resp.Body) - glog.Infoln("response Body:", string(body)) - } -} diff --git a/core/blockchain.go b/core/blockchain.go index 4598800d54..bd84adfe9a 100644 --- a/core/blockchain.go +++ b/core/blockchain.go @@ -819,6 +819,7 @@ func (self *BlockChain) InsertChain(chain types.Blocks) (int, error) { tstart = time.Now() nonceChecked = make([]bool, len(chain)) + statedb *state.StateDB ) // Start the parallel nonce verifier. @@ -885,7 +886,11 @@ func (self *BlockChain) InsertChain(chain types.Blocks) (int, error) { // Create a new statedb using the parent block and report an // error if it fails. - statedb, err := state.New(self.GetBlock(block.ParentHash()).Root(), self.chainDb) + if statedb == nil { + statedb, err = state.New(self.GetBlock(block.ParentHash()).Root(), self.chainDb) + } else { + err = statedb.Reset(chain[i-1].Root()) + } if err != nil { reportBlock(block, err) return i, err @@ -1117,15 +1122,12 @@ func (self *BlockChain) update() { } } -// reportBlock reports the given block and error using the canonical block -// reporting tool. Reporting the block to the service is handled in a separate -// goroutine. +// reportBlock logs a bad block error. func reportBlock(block *types.Block, err error) { if glog.V(logger.Error) { glog.Errorf("Bad block #%v (%s)\n", block.Number(), block.Hash().Hex()) glog.Errorf(" %v", err) } - go ReportBlock(block, err) } // InsertHeaderChain attempts to insert the given header chain in to the local diff --git a/core/state/statedb.go b/core/state/statedb.go index 22ffa36a06..70673799ed 100644 --- a/core/state/statedb.go +++ b/core/state/statedb.go @@ -68,6 +68,28 @@ func New(root common.Hash, db ethdb.Database) (*StateDB, error) { }, nil } +// Reset clears out all emphemeral state objects from the state db, but keeps +// the underlying state trie to avoid reloading data for the next operations. +func (self *StateDB) Reset(root common.Hash) error { + var ( + err error + tr = self.trie + ) + if self.trie.Hash() != root { + if tr, err = trie.NewSecure(root, self.db); err != nil { + return err + } + } + *self = StateDB{ + db: self.db, + trie: tr, + stateObjects: make(map[string]*StateObject), + refund: new(big.Int), + logs: make(map[common.Hash]vm.Logs), + } + return nil +} + func (self *StateDB) StartRecord(thash, bhash common.Hash, ti int) { self.thash = thash self.bhash = bhash @@ -127,7 +149,7 @@ func (self *StateDB) GetNonce(addr common.Address) uint64 { return stateObject.nonce } - return 0 + return StartingNonce } func (self *StateDB) GetCode(addr common.Address) []byte { diff --git a/core/tx_pool.go b/core/tx_pool.go index f2eb2bbdd3..5963563770 100644 --- a/core/tx_pool.go +++ b/core/tx_pool.go @@ -368,6 +368,9 @@ func (self *TxPool) AddTransactions(txs []*types.Transaction) { // GetTransaction returns a transaction if it is contained in the pool // and nil otherwise. func (tp *TxPool) GetTransaction(hash common.Hash) *types.Transaction { + tp.mu.RLock() + defer tp.mu.RUnlock() + // check the txs first if tx, ok := tp.pending[hash]; ok { return tx @@ -421,12 +424,18 @@ func (self *TxPool) RemoveTransactions(txs types.Transactions) { self.mu.Lock() defer self.mu.Unlock() for _, tx := range txs { - self.RemoveTx(tx.Hash()) + self.removeTx(tx.Hash()) } } // RemoveTx removes the transaction with the given hash from the pool. func (pool *TxPool) RemoveTx(hash common.Hash) { + pool.mu.Lock() + defer pool.mu.Unlock() + pool.removeTx(hash) +} + +func (pool *TxPool) removeTx(hash common.Hash) { // delete from pending pool delete(pool.pending, hash) // delete from queue diff --git a/core/types/block.go b/core/types/block.go index 387a063aeb..37b6f3ec17 100644 --- a/core/types/block.go +++ b/core/types/block.go @@ -141,8 +141,10 @@ type Block struct { // of the chain up to and including the block. td *big.Int - // ReceivedAt is used by package eth to track block propagation time. - ReceivedAt time.Time + // These fields are used by package eth to track + // inter-peer block relay. + ReceivedAt time.Time + ReceivedFrom interface{} } // DeprecatedTd is an old relic for extracting the TD of a block. It is in the diff --git a/core/types/bloom9_test.go b/core/types/bloom9_test.go index 58e8f70731..a28ac0e7af 100644 --- a/core/types/bloom9_test.go +++ b/core/types/bloom9_test.go @@ -39,12 +39,12 @@ func TestBloom(t *testing.T) { } for _, data := range positive { - if !bloom.Test(new(big.Int).SetBytes([]byte(data))) { + if !bloom.TestBytes([]byte(data)) { t.Error("expected", data, "to test true") } } for _, data := range negative { - if bloom.Test(new(big.Int).SetBytes([]byte(data))) { + if bloom.TestBytes([]byte(data)) { t.Error("did not expect", data, "to test true") } } diff --git a/eth/api.go b/eth/api.go index d048904f36..f5f942c27d 100644 --- a/eth/api.go +++ b/eth/api.go @@ -113,7 +113,7 @@ func (s *PublicEthereumAPI) GasPrice() *big.Int { // GetCompilers returns the collection of available smart contract compilers func (s *PublicEthereumAPI) GetCompilers() ([]string, error) { solc, err := s.e.Solc() - if err != nil && solc != nil { + if err == nil && solc != nil { return []string{"Solidity"}, nil } diff --git a/eth/bad_block.go b/eth/bad_block.go new file mode 100644 index 0000000000..3a6c3d85cb --- /dev/null +++ b/eth/bad_block.go @@ -0,0 +1,74 @@ +// Copyright 2015 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package eth + +import ( + "bytes" + "encoding/json" + "fmt" + "net/http" + "time" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/logger" + "github.com/ethereum/go-ethereum/logger/glog" + "github.com/ethereum/go-ethereum/rlp" +) + +const ( + // The Ethereum main network genesis block. + defaultGenesisHash = "0xd4e56740f876aef8c010b86a40d5f56745a118d0906a34e69aec8c0db1cb8fa3" + badBlocksURL = "https://badblocks.ethdev.com" +) + +var EnableBadBlockReporting = false + +func sendBadBlockReport(block *types.Block, err error) { + if !EnableBadBlockReporting { + return + } + + var ( + blockRLP, _ = rlp.EncodeToBytes(block) + params = map[string]interface{}{ + "block": common.Bytes2Hex(blockRLP), + "blockHash": block.Hash().Hex(), + "errortype": err.Error(), + "client": "go", + } + ) + if !block.ReceivedAt.IsZero() { + params["receivedAt"] = block.ReceivedAt.UTC().String() + } + if p, ok := block.ReceivedFrom.(*peer); ok { + params["receivedFrom"] = map[string]interface{}{ + "enode": fmt.Sprintf("enode://%x@%v", p.ID(), p.RemoteAddr()), + "name": p.Name(), + "protocolVersion": p.version, + } + } + jsonStr, _ := json.Marshal(map[string]interface{}{"method": "eth_badBlock", "id": "1", "jsonrpc": "2.0", "params": []interface{}{params}}) + client := http.Client{Timeout: 8 * time.Second} + resp, err := client.Post(badBlocksURL, "application/json", bytes.NewReader(jsonStr)) + if err != nil { + glog.V(logger.Debug).Infoln(err) + return + } + glog.V(logger.Debug).Infof("Bad Block Report posted (%d)", resp.StatusCode) + resp.Body.Close() +} diff --git a/eth/downloader/downloader.go b/eth/downloader/downloader.go index 0f76357cba..92124cfeb2 100644 --- a/eth/downloader/downloader.go +++ b/eth/downloader/downloader.go @@ -35,6 +35,7 @@ import ( "github.com/ethereum/go-ethereum/logger" "github.com/ethereum/go-ethereum/logger/glog" "github.com/ethereum/go-ethereum/params" + "github.com/ethereum/go-ethereum/trie" "github.com/rcrowley/go-metrics" ) @@ -42,6 +43,7 @@ var ( MaxHashFetch = 512 // Amount of hashes to be fetched per retrieval request MaxBlockFetch = 128 // Amount of blocks to be fetched per retrieval request MaxHeaderFetch = 192 // Amount of block headers to be fetched per retrieval request + MaxSkeletonSize = 128 // Number of header fetches to need for a skeleton assembly MaxBodyFetch = 128 // Amount of block bodies to be fetched per retrieval request MaxReceiptFetch = 256 // Amount of transaction receipts to allow fetching per request MaxStateFetch = 384 // Amount of node state values to allow fetching per request @@ -52,65 +54,72 @@ var ( blockTargetRTT = 3 * time.Second / 2 // [eth/61] Target time for completing a block retrieval request blockTTL = 3 * blockTargetRTT // [eth/61] Maximum time allowance before a block request is considered expired - headerTTL = 3 * time.Second // [eth/62] Time it takes for a header request to time out - bodyTargetRTT = 3 * time.Second / 2 // [eth/62] Target time for completing a block body retrieval request - bodyTTL = 3 * bodyTargetRTT // [eth/62] Maximum time allowance before a block body request is considered expired - receiptTargetRTT = 3 * time.Second / 2 // [eth/63] Target time for completing a receipt retrieval request - receiptTTL = 3 * receiptTargetRTT // [eth/63] Maximum time allowance before a receipt request is considered expired - stateTargetRTT = 2 * time.Second / 2 // [eth/63] Target time for completing a state trie retrieval request - stateTTL = 3 * stateTargetRTT // [eth/63] Maximum time allowance before a node data request is considered expired + rttMinEstimate = 2 * time.Second // Minimum round-trip time to target for download requests + rttMaxEstimate = 20 * time.Second // Maximum rount-trip time to target for download requests + rttMinConfidence = 0.1 // Worse confidence factor in our estimated RTT value + ttlScaling = 3 // Constant scaling factor for RTT -> TTL conversion + ttlLimit = time.Minute // Maximum TTL allowance to prevent reaching crazy timeouts - maxQueuedHashes = 256 * 1024 // [eth/61] Maximum number of hashes to queue for import (DOS protection) - maxQueuedHeaders = 256 * 1024 // [eth/62] Maximum number of headers to queue for import (DOS protection) - maxResultsProcess = 256 // Number of download results to import at once into the chain + qosTuningPeers = 5 // Number of peers to tune based on (best peers) + qosConfidenceCap = 10 // Number of peers above which not to modify RTT confidence + qosTuningImpact = 0.25 // Impact that a new tuning target has on the previous value + + maxQueuedHashes = 32 * 1024 // [eth/61] Maximum number of hashes to queue for import (DOS protection) + maxQueuedHeaders = 32 * 1024 // [eth/62] Maximum number of headers to queue for import (DOS protection) + maxHeadersProcess = 2048 // Number of header download results to import at once into the chain + maxResultsProcess = 2048 // Number of content download results to import at once into the chain fsHeaderCheckFrequency = 100 // Verification frequency of the downloaded headers during fast sync fsHeaderSafetyNet = 2048 // Number of headers to discard in case a chain violation is detected fsHeaderForceVerify = 24 // Number of headers to verify before and after the pivot to accept it fsPivotInterval = 512 // Number of headers out of which to randomize the pivot point fsMinFullBlocks = 1024 // Number of blocks to retrieve fully even in fast sync + fsCriticalTrials = 10 // Number of times to retry in the cricical section before bailing ) var ( - errBusy = errors.New("busy") - errUnknownPeer = errors.New("peer is unknown or unhealthy") - errBadPeer = errors.New("action from bad peer ignored") - errStallingPeer = errors.New("peer is stalling") - errNoPeers = errors.New("no peers to keep download active") - errTimeout = errors.New("timeout") - errEmptyHashSet = errors.New("empty hash set by peer") - errEmptyHeaderSet = errors.New("empty header set by peer") - errPeersUnavailable = errors.New("no peers available or all tried for download") - errAlreadyInPool = errors.New("hash already in pool") - errInvalidAncestor = errors.New("retrieved ancestor is invalid") - errInvalidChain = errors.New("retrieved hash chain is invalid") - errInvalidBlock = errors.New("retrieved block is invalid") - errInvalidBody = errors.New("retrieved block body is invalid") - errInvalidReceipt = errors.New("retrieved receipt is invalid") - errCancelHashFetch = errors.New("hash download canceled (requested)") - errCancelBlockFetch = errors.New("block download canceled (requested)") - errCancelHeaderFetch = errors.New("block header download canceled (requested)") - errCancelBodyFetch = errors.New("block body download canceled (requested)") - errCancelReceiptFetch = errors.New("receipt download canceled (requested)") - errCancelStateFetch = errors.New("state data download canceled (requested)") - errCancelProcessing = errors.New("processing canceled (requested)") - errNoSyncActive = errors.New("no sync active") + errBusy = errors.New("busy") + errUnknownPeer = errors.New("peer is unknown or unhealthy") + errBadPeer = errors.New("action from bad peer ignored") + errStallingPeer = errors.New("peer is stalling") + errNoPeers = errors.New("no peers to keep download active") + errTimeout = errors.New("timeout") + errEmptyHashSet = errors.New("empty hash set by peer") + errEmptyHeaderSet = errors.New("empty header set by peer") + errPeersUnavailable = errors.New("no peers available or all tried for download") + errAlreadyInPool = errors.New("hash already in pool") + errInvalidAncestor = errors.New("retrieved ancestor is invalid") + errInvalidChain = errors.New("retrieved hash chain is invalid") + errInvalidBlock = errors.New("retrieved block is invalid") + errInvalidBody = errors.New("retrieved block body is invalid") + errInvalidReceipt = errors.New("retrieved receipt is invalid") + errCancelHashFetch = errors.New("hash download canceled (requested)") + errCancelBlockFetch = errors.New("block download canceled (requested)") + errCancelHeaderFetch = errors.New("block header download canceled (requested)") + errCancelBodyFetch = errors.New("block body download canceled (requested)") + errCancelReceiptFetch = errors.New("receipt download canceled (requested)") + errCancelStateFetch = errors.New("state data download canceled (requested)") + errCancelHeaderProcessing = errors.New("header processing canceled (requested)") + errCancelContentProcessing = errors.New("content processing canceled (requested)") + errNoSyncActive = errors.New("no sync active") ) type Downloader struct { - mode SyncMode // Synchronisation mode defining the strategy used (per sync cycle) - noFast bool // Flag to disable fast syncing in case of a security error - mux *event.TypeMux // Event multiplexer to announce sync operation events + mode SyncMode // Synchronisation mode defining the strategy used (per sync cycle) + mux *event.TypeMux // Event multiplexer to announce sync operation events queue *queue // Scheduler for selecting the hashes to download peers *peerSet // Set of active peers from which download can proceed - interrupt int32 // Atomic boolean to signal termination + fsPivotLock *types.Header // Pivot header on critical section entry (cannot change between retries) + fsPivotFails int // Number of fast sync failures in the critical section + + rttEstimate uint64 // Round trip time to target for download requests + rttConfidence uint64 // Confidence in the estimated RTT (unit: millionths to allow atomic ops) // Statistics syncStatsChainOrigin uint64 // Origin block number where syncing started at syncStatsChainHeight uint64 // Highest block number known when syncing started - syncStatsStateTotal uint64 // Total number of node state entries known so far syncStatsStateDone uint64 // Number of state trie entries already pulled syncStatsLock sync.RWMutex // Lock protecting the sync stats fields @@ -137,20 +146,24 @@ type Downloader struct { // Channels newPeerCh chan *peer - hashCh chan dataPack // [eth/61] Channel receiving inbound hashes - blockCh chan dataPack // [eth/61] Channel receiving inbound blocks - headerCh chan dataPack // [eth/62] Channel receiving inbound block headers - bodyCh chan dataPack // [eth/62] Channel receiving inbound block bodies - receiptCh chan dataPack // [eth/63] Channel receiving inbound receipts - stateCh chan dataPack // [eth/63] Channel receiving inbound node state data - blockWakeCh chan bool // [eth/61] Channel to signal the block fetcher of new tasks - bodyWakeCh chan bool // [eth/62] Channel to signal the block body fetcher of new tasks - receiptWakeCh chan bool // [eth/63] Channel to signal the receipt fetcher of new tasks - stateWakeCh chan bool // [eth/63] Channel to signal the state fetcher of new tasks + hashCh chan dataPack // [eth/61] Channel receiving inbound hashes + blockCh chan dataPack // [eth/61] Channel receiving inbound blocks + headerCh chan dataPack // [eth/62] Channel receiving inbound block headers + bodyCh chan dataPack // [eth/62] Channel receiving inbound block bodies + receiptCh chan dataPack // [eth/63] Channel receiving inbound receipts + stateCh chan dataPack // [eth/63] Channel receiving inbound node state data + blockWakeCh chan bool // [eth/61] Channel to signal the block fetcher of new tasks + bodyWakeCh chan bool // [eth/62] Channel to signal the block body fetcher of new tasks + receiptWakeCh chan bool // [eth/63] Channel to signal the receipt fetcher of new tasks + stateWakeCh chan bool // [eth/63] Channel to signal the state fetcher of new tasks + headerProcCh chan []*types.Header // [eth/62] Channel to feed the header processor new tasks cancelCh chan struct{} // Channel to cancel mid-flight syncs cancelLock sync.RWMutex // Lock to protect the cancel channel in delivers + quitCh chan struct{} // Quit channel to signal termination + quitLock sync.RWMutex // Lock to prevent double closes + // Testing hooks syncInitHook func(uint64, uint64) // Method to call upon initiating a new sync run bodyFetchHook func([]*types.Header) // Method to call upon starting a block body fetch @@ -164,11 +177,13 @@ func New(stateDb ethdb.Database, mux *event.TypeMux, hasHeader headerCheckFn, ha headFastBlock headFastBlockRetrievalFn, commitHeadBlock headBlockCommitterFn, getTd tdRetrievalFn, insertHeaders headerChainInsertFn, insertBlocks blockChainInsertFn, insertReceipts receiptChainInsertFn, rollback chainRollbackFn, dropPeer peerDropFn) *Downloader { - return &Downloader{ + dl := &Downloader{ mode: FullSync, mux: mux, queue: newQueue(stateDb), peers: newPeerSet(), + rttEstimate: uint64(rttMaxEstimate), + rttConfidence: uint64(1000000), hasHeader: hasHeader, hasBlockAndState: hasBlockAndState, getHeader: getHeader, @@ -194,7 +209,11 @@ func New(stateDb ethdb.Database, mux *event.TypeMux, hasHeader headerCheckFn, ha bodyWakeCh: make(chan bool, 1), receiptWakeCh: make(chan bool, 1), stateWakeCh: make(chan bool, 1), + headerProcCh: make(chan []*types.Header, 1), + quitCh: make(chan struct{}), } + go dl.qosTuner() + return dl } // Progress retrieves the synchronisation boundaries, specifically the origin @@ -241,6 +260,8 @@ func (d *Downloader) RegisterPeer(id string, version int, head common.Hash, glog.V(logger.Error).Infoln("Register failed:", err) return err } + d.qosReduceConfidence() + return nil } @@ -308,20 +329,32 @@ func (d *Downloader) synchronise(id string, hash common.Hash, td *big.Int, mode default: } } - // Reset any ephemeral sync statistics - d.syncStatsLock.Lock() - d.syncStatsStateTotal = 0 - d.syncStatsStateDone = 0 - d.syncStatsLock.Unlock() - + for _, ch := range []chan dataPack{d.hashCh, d.blockCh, d.headerCh, d.bodyCh, d.receiptCh, d.stateCh} { + for empty := false; !empty; { + select { + case <-ch: + default: + empty = true + } + } + } + for empty := false; !empty; { + select { + case <-d.headerProcCh: + default: + empty = true + } + } // Create cancel channel for aborting mid-flight d.cancelLock.Lock() d.cancelCh = make(chan struct{}) d.cancelLock.Unlock() + defer d.cancel() // No matter what, we can't leave the cancel channel open + // Set the requested sync mode, unless it's forbidden d.mode = mode - if d.mode == FastSync && d.noFast { + if d.mode == FastSync && d.fsPivotFails >= fsCriticalTrials { d.mode = FullSync } // Retrieve the origin peer and initiate the downloading process @@ -369,11 +402,11 @@ func (d *Downloader) syncWithPeer(p *peer, hash common.Hash, td *big.Int) (err e d.syncStatsLock.Unlock() // Initiate the sync using a concurrent hash and block retrieval algorithm - d.queue.Prepare(origin+1, d.mode, 0) + d.queue.Prepare(origin+1, d.mode, 0, nil) if d.syncInitHook != nil { d.syncInitHook(origin, latest) } - return d.spawnSync( + return d.spawnSync(origin+1, func() error { return d.fetchHashes61(p, td, origin+1) }, func() error { return d.fetchBlocks61(origin + 1) }, ) @@ -384,7 +417,9 @@ func (d *Downloader) syncWithPeer(p *peer, hash common.Hash, td *big.Int) (err e if err != nil { return err } - origin, err := d.findAncestor(p, latest) + height := latest.Number.Uint64() + + origin, err := d.findAncestor(p, height) if err != nil { return err } @@ -392,22 +427,27 @@ func (d *Downloader) syncWithPeer(p *peer, hash common.Hash, td *big.Int) (err e if d.syncStatsChainHeight <= origin || d.syncStatsChainOrigin > origin { d.syncStatsChainOrigin = origin } - d.syncStatsChainHeight = latest + d.syncStatsChainHeight = height d.syncStatsLock.Unlock() // Initiate the sync using a concurrent header and content retrieval algorithm pivot := uint64(0) switch d.mode { case LightSync: - pivot = latest + pivot = height case FastSync: // Calculate the new fast/slow sync pivot point - pivotOffset, err := rand.Int(rand.Reader, big.NewInt(int64(fsPivotInterval))) - if err != nil { - panic(fmt.Sprintf("Failed to access crypto random source: %v", err)) - } - if latest > uint64(fsMinFullBlocks)+pivotOffset.Uint64() { - pivot = latest - uint64(fsMinFullBlocks) - pivotOffset.Uint64() + if d.fsPivotLock == nil { + pivotOffset, err := rand.Int(rand.Reader, big.NewInt(int64(fsPivotInterval))) + if err != nil { + panic(fmt.Sprintf("Failed to access crypto random source: %v", err)) + } + if height > uint64(fsMinFullBlocks)+pivotOffset.Uint64() { + pivot = height - uint64(fsMinFullBlocks) - pivotOffset.Uint64() + } + } else { + // Pivot point locked in, use this and do not pick a new one! + pivot = d.fsPivotLock.Number.Uint64() } // If the point is below the origin, move origin back to ensure state download if pivot < origin { @@ -419,15 +459,16 @@ func (d *Downloader) syncWithPeer(p *peer, hash common.Hash, td *big.Int) (err e } glog.V(logger.Debug).Infof("Fast syncing until pivot block #%d", pivot) } - d.queue.Prepare(origin+1, d.mode, pivot) + d.queue.Prepare(origin+1, d.mode, pivot, latest) if d.syncInitHook != nil { - d.syncInitHook(origin, latest) + d.syncInitHook(origin, height) } - return d.spawnSync( - func() error { return d.fetchHeaders(p, td, origin+1) }, // Headers are always retrieved - func() error { return d.fetchBodies(origin + 1) }, // Bodies are retrieved during normal and fast sync - func() error { return d.fetchReceipts(origin + 1) }, // Receipts are retrieved during fast sync - func() error { return d.fetchNodeData() }, // Node state data is retrieved during fast sync + return d.spawnSync(origin+1, + func() error { return d.fetchHeaders(p, origin+1) }, // Headers are always retrieved + func() error { return d.processHeaders(origin+1, td) }, // Headers are always retrieved + func() error { return d.fetchBodies(origin + 1) }, // Bodies are retrieved during normal and fast sync + func() error { return d.fetchReceipts(origin + 1) }, // Receipts are retrieved during fast sync + func() error { return d.fetchNodeData() }, // Node state data is retrieved during fast sync ) default: @@ -439,11 +480,11 @@ func (d *Downloader) syncWithPeer(p *peer, hash common.Hash, td *big.Int) (err e // spawnSync runs d.process and all given fetcher functions to completion in // separate goroutines, returning the first error that appears. -func (d *Downloader) spawnSync(fetchers ...func() error) error { +func (d *Downloader) spawnSync(origin uint64, fetchers ...func() error) error { var wg sync.WaitGroup errc := make(chan error, len(fetchers)+1) wg.Add(len(fetchers) + 1) - go func() { defer wg.Done(); errc <- d.process() }() + go func() { defer wg.Done(); errc <- d.processContent() }() for _, fn := range fetchers { fn := fn go func() { defer wg.Done(); errc <- fn() }() @@ -486,7 +527,16 @@ func (d *Downloader) cancel() { // Terminate interrupts the downloader, canceling all pending operations. // The downloader cannot be reused after calling Terminate. func (d *Downloader) Terminate() { - atomic.StoreInt32(&d.interrupt, 1) + // Close the termination channel (make sure double close is allowed) + d.quitLock.Lock() + select { + case <-d.quitCh: + default: + close(d.quitCh) + } + d.quitLock.Unlock() + + // Cancel any pending download requests d.cancel() } @@ -702,9 +752,9 @@ func (d *Downloader) fetchHashes61(p *peer, td *big.Int, from uint64) error { getHashes := func(from uint64) { glog.V(logger.Detail).Infof("%v: fetching %d hashes from #%d", p, MaxHashFetch, from) - go p.getAbsHashes(from, MaxHashFetch) request = time.Now() timeout.Reset(hashTTL) + go p.getAbsHashes(from, MaxHashFetch) } // Start pulling hashes, until all are exhausted getHashes(from) @@ -903,7 +953,7 @@ func (d *Downloader) fetchBlocks61(from uint64) error { // Reserve a chunk of hashes for a peer. A nil can mean either that // no more hashes are available, or that the peer is known not to // have them. - request := d.queue.ReserveBlocks(peer, peer.BlockCapacity()) + request := d.queue.ReserveBlocks(peer, peer.BlockCapacity(blockTargetRTT)) if request == nil { continue } @@ -938,17 +988,17 @@ func (d *Downloader) fetchBlocks61(from uint64) error { // fetchHeight retrieves the head header of the remote peer to aid in estimating // the total time a pending synchronisation would take. -func (d *Downloader) fetchHeight(p *peer) (uint64, error) { +func (d *Downloader) fetchHeight(p *peer) (*types.Header, error) { glog.V(logger.Debug).Infof("%v: retrieving remote chain height", p) // Request the advertised remote head block and wait for the response go p.getRelHeaders(p.head, 1, 0, false) - timeout := time.After(headerTTL) + timeout := time.After(d.requestTTL()) for { select { case <-d.cancelCh: - return 0, errCancelBlockFetch + return nil, errCancelBlockFetch case packet := <-d.headerCh: // Discard anything not from the origin peer @@ -960,13 +1010,13 @@ func (d *Downloader) fetchHeight(p *peer) (uint64, error) { headers := packet.(*headerPack).headers if len(headers) != 1 { glog.V(logger.Debug).Infof("%v: invalid number of head headers: %d != 1", p, len(headers)) - return 0, errBadPeer + return nil, errBadPeer } - return headers[0].Number.Uint64(), nil + return headers[0], nil case <-timeout: glog.V(logger.Debug).Infof("%v: head header timeout", p) - return 0, errTimeout + return nil, errTimeout case <-d.bodyCh: case <-d.stateCh: @@ -1012,7 +1062,7 @@ func (d *Downloader) findAncestor(p *peer, height uint64) (uint64, error) { // Wait for the remote response to the head fetch number, hash := uint64(0), common.Hash{} - timeout := time.After(hashTTL) + timeout := time.After(d.requestTTL()) for finished := false; !finished; { select { @@ -1050,7 +1100,7 @@ func (d *Downloader) findAncestor(p *peer, height uint64) (uint64, error) { continue } // Otherwise check if we already know the header or not - if (d.mode != LightSync && d.hasBlockAndState(headers[i].Hash())) || (d.mode == LightSync && d.hasHeader(headers[i].Hash())) { + if (d.mode == FullSync && d.hasBlockAndState(headers[i].Hash())) || (d.mode != FullSync && d.hasHeader(headers[i].Hash())) { number, hash = headers[i].Number.Uint64(), headers[i].Hash() break } @@ -1089,7 +1139,7 @@ func (d *Downloader) findAncestor(p *peer, height uint64) (uint64, error) { // Split our chain interval in two, and request the hash to cross check check := (start + end) / 2 - timeout := time.After(hashTTL) + timeout := time.After(d.requestTTL()) go p.getAbsHeaders(uint64(check), 1, 0, false) // Wait until a reply arrives to this request @@ -1149,55 +1199,39 @@ func (d *Downloader) findAncestor(p *peer, height uint64) (uint64, error) { return start, nil } -// fetchHeaders keeps retrieving headers from the requested number, until no more -// are returned, potentially throttling on the way. -// -// The queue parameter can be used to switch between queuing headers for block -// body download too, or directly import as pure header chains. -func (d *Downloader) fetchHeaders(p *peer, td *big.Int, from uint64) error { - glog.V(logger.Debug).Infof("%v: downloading headers from #%d", p, from) +// fetchHeaders keeps retrieving headers concurrently from the number +// requested, until no more are returned, potentially throttling on the way. To +// facilitate concurrency but still protect against malicious nodes sending bad +// headers, we construct a header chain skeleton using the "origin" peer we are +// syncing with, and fill in the missing headers using anyone else. Headers from +// other peers are only accepted if they map cleanly to the skeleton. If no one +// can fill in the skeleton - not even the origin peer - it's assumed invalid and +// the origin is dropped. +func (d *Downloader) fetchHeaders(p *peer, from uint64) error { + glog.V(logger.Debug).Infof("%v: directing header downloads from #%d", p, from) defer glog.V(logger.Debug).Infof("%v: header download terminated", p) - // Calculate the pivoting point for switching from fast to slow sync - pivot := d.queue.FastSyncPivot() - - // Keep a count of uncertain headers to roll back - rollback := []*types.Header{} - defer func() { - if len(rollback) > 0 { - // Flatten the headers and roll them back - hashes := make([]common.Hash, len(rollback)) - for i, header := range rollback { - hashes[i] = header.Hash() - } - lh, lfb, lb := d.headHeader().Number, d.headFastBlock().Number(), d.headBlock().Number() - d.rollback(hashes) - glog.V(logger.Warn).Infof("Rolled back %d headers (LH: %d->%d, FB: %d->%d, LB: %d->%d)", - len(hashes), lh, d.headHeader().Number, lfb, d.headFastBlock().Number(), lb, d.headBlock().Number()) - - // If we're already past the pivot point, this could be an attack, disable fast sync - if rollback[len(rollback)-1].Number.Uint64() > pivot { - d.noFast = true - } - } - }() - - // Create a timeout timer, and the associated hash fetcher - request := time.Now() // time of the last fetch request + // Create a timeout timer, and the associated header fetcher + skeleton := true // Skeleton assembly phase or finishing up + request := time.Now() // time of the last skeleton fetch request timeout := time.NewTimer(0) // timer to dump a non-responsive active peer <-timeout.C // timeout channel should be initially empty defer timeout.Stop() getHeaders := func(from uint64) { - glog.V(logger.Detail).Infof("%v: fetching %d headers from #%d", p, MaxHeaderFetch, from) - - go p.getAbsHeaders(from, MaxHeaderFetch, 0, false) request = time.Now() - timeout.Reset(headerTTL) + timeout.Reset(d.requestTTL()) + + if skeleton { + glog.V(logger.Detail).Infof("%v: fetching %d skeleton headers from #%d", p, MaxHeaderFetch, from) + go p.getAbsHeaders(from+uint64(MaxHeaderFetch)-1, MaxSkeletonSize, MaxHeaderFetch-1, false) + } else { + glog.V(logger.Detail).Infof("%v: fetching %d full headers from #%d", p, MaxHeaderFetch, from) + go p.getAbsHeaders(from, MaxHeaderFetch, 0, false) + } } - // Start pulling headers, until all are exhausted + // Start pulling the header chain skeleton until all is done getHeaders(from) - gotHeaders := false for { select { @@ -1205,116 +1239,52 @@ func (d *Downloader) fetchHeaders(p *peer, td *big.Int, from uint64) error { return errCancelHeaderFetch case packet := <-d.headerCh: - // Make sure the active peer is giving us the headers + // Make sure the active peer is giving us the skeleton headers if packet.PeerId() != p.id { - glog.V(logger.Debug).Infof("Received headers from incorrect peer (%s)", packet.PeerId()) + glog.V(logger.Debug).Infof("Received skeleton headers from incorrect peer (%s)", packet.PeerId()) break } headerReqTimer.UpdateSince(request) timeout.Stop() + // If the skeleton's finished, pull any remaining head headers directly from the origin + if packet.Items() == 0 && skeleton { + skeleton = false + getHeaders(from) + continue + } // If no more headers are inbound, notify the content fetchers and return if packet.Items() == 0 { glog.V(logger.Debug).Infof("%v: no available headers", p) - - for _, ch := range []chan bool{d.bodyWakeCh, d.receiptWakeCh, d.stateWakeCh} { - select { - case ch <- false: - case <-d.cancelCh: - } + select { + case d.headerProcCh <- nil: + return nil + case <-d.cancelCh: + return errCancelHeaderFetch } - // If no headers were retrieved at all, the peer violated it's TD promise that it had a - // better chain compared to ours. The only exception is if it's promised blocks were - // already imported by other means (e.g. fetcher): - // - // R , L : Both at block 10 - // R: Mine block 11, and propagate it to L - // L: Queue block 11 for import - // L: Notice that R's head and TD increased compared to ours, start sync - // L: Import of block 11 finishes - // L: Sync begins, and finds common ancestor at 11 - // L: Request new headers up from 11 (R's TD was higher, it must have something) - // R: Nothing to give - if !gotHeaders && td.Cmp(d.getTd(d.headBlock().Hash())) > 0 { - return errStallingPeer - } - // If fast or light syncing, ensure promised headers are indeed delivered. This is - // needed to detect scenarios where an attacker feeds a bad pivot and then bails out - // of delivering the post-pivot blocks that would flag the invalid content. - // - // This check cannot be executed "as is" for full imports, since blocks may still be - // queued for processing when the header download completes. However, as long as the - // peer gave us something useful, we're already happy/progressed (above check). - if d.mode == FastSync || d.mode == LightSync { - if td.Cmp(d.getTd(d.headHeader().Hash())) > 0 { - return errStallingPeer - } - } - rollback = nil - return nil } - gotHeaders = true headers := packet.(*headerPack).headers - // Otherwise insert all the new headers, aborting in case of junk - glog.V(logger.Detail).Infof("%v: schedule %d headers from #%d", p, len(headers), from) - - if d.mode == FastSync || d.mode == LightSync { - // Collect the yet unknown headers to mark them as uncertain - unknown := make([]*types.Header, 0, len(headers)) - for _, header := range headers { - if !d.hasHeader(header.Hash()) { - unknown = append(unknown, header) - } - } - // If we're importing pure headers, verify based on their recentness - frequency := fsHeaderCheckFrequency - if headers[len(headers)-1].Number.Uint64()+uint64(fsHeaderForceVerify) > pivot { - frequency = 1 - } - if n, err := d.insertHeaders(headers, frequency); err != nil { - // If some headers were inserted, add them too to the rollback list - if n > 0 { - rollback = append(rollback, headers[:n]...) - } - glog.V(logger.Debug).Infof("%v: invalid header #%d [%x…]: %v", p, headers[n].Number, headers[n].Hash().Bytes()[:4], err) + // If we received a skeleton batch, resolve internals concurrently + if skeleton { + filled, proced, err := d.fillHeaderSkeleton(from, headers) + if err != nil { + glog.V(logger.Debug).Infof("%v: skeleton chain invalid: %v", p, err) return errInvalidChain } - // All verifications passed, store newly found uncertain headers - rollback = append(rollback, unknown...) - if len(rollback) > fsHeaderSafetyNet { - rollback = append(rollback[:0], rollback[len(rollback)-fsHeaderSafetyNet:]...) + headers = filled[proced:] + from += uint64(proced) + } + // Insert all the new headers and fetch the next batch + if len(headers) > 0 { + glog.V(logger.Detail).Infof("%v: schedule %d headers from #%d", p, len(headers), from) + select { + case d.headerProcCh <- headers: + case <-d.cancelCh: + return errCancelHeaderFetch } + from += uint64(len(headers)) } - if d.mode == FullSync || d.mode == FastSync { - inserts := d.queue.Schedule(headers, from) - if len(inserts) != len(headers) { - glog.V(logger.Debug).Infof("%v: stale headers", p) - return errBadPeer - } - } - // Notify the content fetchers of new headers, but stop if queue is full - cont := d.queue.PendingBlocks() < maxQueuedHeaders && d.queue.PendingReceipts() < maxQueuedHeaders - for _, ch := range []chan bool{d.bodyWakeCh, d.receiptWakeCh, d.stateWakeCh} { - if cont { - // We still have headers to fetch, send continuation wake signal (potential) - select { - case ch <- true: - default: - } - } else { - // Header limit reached, send a termination wake signal (enforced) - select { - case ch <- false: - case <-d.cancelCh: - } - } - } - if !cont { - return nil - } - // Queue not yet full, fetch the next batch - from += uint64(len(headers)) getHeaders(from) case <-timeout.C: @@ -1330,7 +1300,11 @@ func (d *Downloader) fetchHeaders(p *peer, td *big.Int, from uint64) error { case <-d.cancelCh: } } - return nil + select { + case d.headerProcCh <- nil: + case <-d.cancelCh: + } + return errBadPeer case <-d.hashCh: case <-d.blockCh: @@ -1340,6 +1314,43 @@ func (d *Downloader) fetchHeaders(p *peer, td *big.Int, from uint64) error { } } +// fillHeaderSkeleton concurrently retrieves headers from all our available peers +// and maps them to the provided skeleton header chain. +// +// Any partial results from the beginning of the skeleton is (if possible) forwarded +// immediately to the header processor to keep the rest of the pipeline full even +// in the case of header stalls. +// +// The method returs the entire filled skeleton and also the number of headers +// already forwarded for processing. +func (d *Downloader) fillHeaderSkeleton(from uint64, skeleton []*types.Header) ([]*types.Header, int, error) { + glog.V(logger.Debug).Infof("Filling up skeleton from #%d", from) + d.queue.ScheduleSkeleton(from, skeleton) + + var ( + deliver = func(packet dataPack) (int, error) { + pack := packet.(*headerPack) + return d.queue.DeliverHeaders(pack.peerId, pack.headers, d.headerProcCh) + } + expire = func() map[string]int { return d.queue.ExpireHeaders(d.requestTTL()) } + throttle = func() bool { return false } + reserve = func(p *peer, count int) (*fetchRequest, bool, error) { + return d.queue.ReserveHeaders(p, count), false, nil + } + fetch = func(p *peer, req *fetchRequest) error { return p.FetchHeaders(req.From, MaxHeaderFetch) } + capacity = func(p *peer) int { return p.HeaderCapacity(d.requestRTT()) } + setIdle = func(p *peer, accepted int) { p.SetHeadersIdle(accepted) } + ) + err := d.fetchParts(errCancelHeaderFetch, d.headerCh, deliver, d.queue.headerContCh, expire, + d.queue.PendingHeaders, d.queue.InFlightHeaders, throttle, reserve, + nil, fetch, d.queue.CancelHeaders, capacity, d.peers.HeaderIdlePeers, setIdle, "Header") + + glog.V(logger.Debug).Infof("Skeleton fill terminated: %v", err) + + filled, proced := d.queue.RetrieveHeaders() + return filled, proced, err +} + // fetchBodies iteratively downloads the scheduled block bodies, taking any // available peers, reserving a chunk of blocks for each, waiting for delivery // and also periodically checking for timeouts. @@ -1351,9 +1362,9 @@ func (d *Downloader) fetchBodies(from uint64) error { pack := packet.(*bodyPack) return d.queue.DeliverBodies(pack.peerId, pack.transactions, pack.uncles) } - expire = func() map[string]int { return d.queue.ExpireBodies(bodyTTL) } + expire = func() map[string]int { return d.queue.ExpireBodies(d.requestTTL()) } fetch = func(p *peer, req *fetchRequest) error { return p.FetchBodies(req) } - capacity = func(p *peer) int { return p.BlockCapacity() } + capacity = func(p *peer) int { return p.BlockCapacity(d.requestRTT()) } setIdle = func(p *peer, accepted int) { p.SetBodiesIdle(accepted) } ) err := d.fetchParts(errCancelBodyFetch, d.bodyCh, deliver, d.bodyWakeCh, expire, @@ -1375,9 +1386,9 @@ func (d *Downloader) fetchReceipts(from uint64) error { pack := packet.(*receiptPack) return d.queue.DeliverReceipts(pack.peerId, pack.receipts) } - expire = func() map[string]int { return d.queue.ExpireReceipts(receiptTTL) } + expire = func() map[string]int { return d.queue.ExpireReceipts(d.requestTTL()) } fetch = func(p *peer, req *fetchRequest) error { return p.FetchReceipts(req) } - capacity = func(p *peer) int { return p.ReceiptCapacity() } + capacity = func(p *peer) int { return p.ReceiptCapacity(d.requestRTT()) } setIdle = func(p *peer, accepted int) { p.SetReceiptsIdle(accepted) } ) err := d.fetchParts(errCancelReceiptFetch, d.receiptCh, deliver, d.receiptWakeCh, expire, @@ -1398,6 +1409,11 @@ func (d *Downloader) fetchNodeData() error { deliver = func(packet dataPack) (int, error) { start := time.Now() return d.queue.DeliverNodeData(packet.PeerId(), packet.(*statePack).states, func(err error, delivered int) { + // If the peer returned old-requested data, forgive + if err == trie.ErrNotRequested { + glog.V(logger.Info).Infof("peer %s: replied to stale state request, forgiving", packet.PeerId()) + return + } if err != nil { // If the node data processing failed, the root hash is very wrong, abort glog.V(logger.Error).Infof("peer %d: state processing failed: %v", packet.PeerId(), err) @@ -1405,26 +1421,30 @@ func (d *Downloader) fetchNodeData() error { return } // Processing succeeded, notify state fetcher of continuation - if d.queue.PendingNodeData() > 0 { + pending := d.queue.PendingNodeData() + if pending > 0 { select { case d.stateWakeCh <- true: default: } } - // Log a message to the user and return d.syncStatsLock.Lock() - defer d.syncStatsLock.Unlock() d.syncStatsStateDone += uint64(delivered) - glog.V(logger.Info).Infof("imported %d state entries in %v: processed %d in total", delivered, time.Since(start), d.syncStatsStateDone) + d.syncStatsLock.Unlock() + + // Log a message to the user and return + if delivered > 0 { + glog.V(logger.Info).Infof("imported %d state entries in %v: processed %d, pending at least %d", delivered, time.Since(start), d.syncStatsStateDone, pending) + } }) } - expire = func() map[string]int { return d.queue.ExpireNodeData(stateTTL) } + expire = func() map[string]int { return d.queue.ExpireNodeData(d.requestTTL()) } throttle = func() bool { return false } reserve = func(p *peer, count int) (*fetchRequest, bool, error) { return d.queue.ReserveNodeData(p, count), false, nil } fetch = func(p *peer, req *fetchRequest) error { return p.FetchNodeData(req) } - capacity = func(p *peer) int { return p.NodeDataCapacity() } + capacity = func(p *peer) int { return p.NodeDataCapacity(d.requestRTT()) } setIdle = func(p *peer, accepted int) { p.SetNodeDataIdle(accepted) } ) err := d.fetchParts(errCancelStateFetch, d.stateCh, deliver, d.stateWakeCh, expire, @@ -1438,6 +1458,28 @@ func (d *Downloader) fetchNodeData() error { // fetchParts iteratively downloads scheduled block parts, taking any available // peers, reserving a chunk of fetch requests for each, waiting for delivery and // also periodically checking for timeouts. +// +// As the scheduling/timeout logic mostly is the same for all downloaded data +// types, this method is used by each for data gathering and is instrumented with +// various callbacks to handle the slight differences between processing them. +// +// The instrumentation parameters: +// - errCancel: error type to return if the fetch operation is cancelled (mostly makes logging nicer) +// - deliveryCh: channel from which to retrieve downloaded data packets (merged from all concurrent peers) +// - deliver: processing callback to deliver data packets into type specific download queues (usually within `queue`) +// - wakeCh: notification channel for waking the fetcher when new tasks are available (or sync completed) +// - expire: task callback method to abort requests that took too long and return the faulty peers (traffic shaping) +// - pending: task callback for the number of requests still needing download (detect completion/non-completability) +// - inFlight: task callback for the number of in-progress requests (wait for all active downloads to finish) +// - throttle: task callback to check if the processing queue is full and activate throttling (bound memory use) +// - reserve: task callback to reserve new download tasks to a particular peer (also signals partial completions) +// - fetchHook: tester callback to notify of new tasks being initiated (allows testing the scheduling logic) +// - fetch: network callback to actually send a particular download request to a physical remote peer +// - cancel: task callback to abort an in-flight download request and allow rescheduling it (in case of lost peer) +// - capacity: network callback to retreive the estimated type-specific bandwidth capacity of a peer (traffic shaping) +// - idle: network callback to retrieve the currently (type specific) idle peers that can be assigned tasks +// - setIdle: network callback to set a peer back to idle and update its estimated capacity (traffic shaping) +// - kind: textual label of the type being downloaded to display in log mesages func (d *Downloader) fetchParts(errCancel error, deliveryCh chan dataPack, deliver func(dataPack) (int, error), wakeCh chan bool, expire func() map[string]int, pending func() int, inFlight func() bool, throttle func() bool, reserve func(*peer, int) (*fetchRequest, bool, error), fetchHook func([]*types.Header), fetch func(*peer, *fetchRequest) error, cancel func(*fetchRequest), capacity func(*peer) int, @@ -1554,7 +1596,9 @@ func (d *Downloader) fetchParts(errCancel error, deliveryCh chan dataPack, deliv continue } if glog.V(logger.Detail) { - if len(request.Headers) > 0 { + if request.From > 0 { + glog.Infof("%s: requesting %s(s) from #%d", peer, strings.ToLower(kind), request.From) + } else if len(request.Headers) > 0 { glog.Infof("%s: requesting %d %s(s), first at #%d", peer, len(request.Headers), strings.ToLower(kind), request.Headers[0].Number) } else { glog.Infof("%s: requesting %d %s(s)", peer, len(request.Hashes), strings.ToLower(kind)) @@ -1588,9 +1632,178 @@ func (d *Downloader) fetchParts(errCancel error, deliveryCh chan dataPack, deliv } } -// process takes fetch results from the queue and tries to import them into the -// chain. The type of import operation will depend on the result contents. -func (d *Downloader) process() error { +// processHeaders takes batches of retrieved headers from an input channel and +// keeps processing and scheduling them into the header chain and downloader's +// queue until the stream ends or a failure occurs. +func (d *Downloader) processHeaders(origin uint64, td *big.Int) error { + // Calculate the pivoting point for switching from fast to slow sync + pivot := d.queue.FastSyncPivot() + + // Keep a count of uncertain headers to roll back + rollback := []*types.Header{} + defer func() { + if len(rollback) > 0 { + // Flatten the headers and roll them back + hashes := make([]common.Hash, len(rollback)) + for i, header := range rollback { + hashes[i] = header.Hash() + } + lastHeader, lastFastBlock, lastBlock := d.headHeader().Number, d.headFastBlock().Number(), d.headBlock().Number() + d.rollback(hashes) + glog.V(logger.Warn).Infof("Rolled back %d headers (LH: %d->%d, FB: %d->%d, LB: %d->%d)", + len(hashes), lastHeader, d.headHeader().Number, lastFastBlock, d.headFastBlock().Number(), lastBlock, d.headBlock().Number()) + + // If we're already past the pivot point, this could be an attack, thread carefully + if rollback[len(rollback)-1].Number.Uint64() > pivot { + // If we didn't ever fail, lock in te pivot header (must! not! change!) + if d.fsPivotFails == 0 { + for _, header := range rollback { + if header.Number.Uint64() == pivot { + glog.V(logger.Warn).Infof("Fast-sync critical section failure, locked pivot to header #%d [%x…]", pivot, header.Hash().Bytes()[:4]) + d.fsPivotLock = header + } + } + } + d.fsPivotFails++ + } + } + }() + + // Wait for batches of headers to process + gotHeaders := false + + for { + select { + case <-d.cancelCh: + return errCancelHeaderProcessing + + case headers := <-d.headerProcCh: + // Terminate header processing if we synced up + if len(headers) == 0 { + // Notify everyone that headers are fully processed + for _, ch := range []chan bool{d.bodyWakeCh, d.receiptWakeCh, d.stateWakeCh} { + select { + case ch <- false: + case <-d.cancelCh: + } + } + // If no headers were retrieved at all, the peer violated it's TD promise that it had a + // better chain compared to ours. The only exception is if it's promised blocks were + // already imported by other means (e.g. fecher): + // + // R , L : Both at block 10 + // R: Mine block 11, and propagate it to L + // L: Queue block 11 for import + // L: Notice that R's head and TD increased compared to ours, start sync + // L: Import of block 11 finishes + // L: Sync begins, and finds common ancestor at 11 + // L: Request new headers up from 11 (R's TD was higher, it must have something) + // R: Nothing to give + if !gotHeaders && td.Cmp(d.getTd(d.headBlock().Hash())) > 0 { + return errStallingPeer + } + // If fast or light syncing, ensure promised headers are indeed delivered. This is + // needed to detect scenarios where an attacker feeds a bad pivot and then bails out + // of delivering the post-pivot blocks that would flag the invalid content. + // + // This check cannot be executed "as is" for full imports, since blocks may still be + // queued for processing when the header download completes. However, as long as the + // peer gave us something useful, we're already happy/progressed (above check). + if d.mode == FastSync || d.mode == LightSync { + if td.Cmp(d.getTd(d.headHeader().Hash())) > 0 { + return errStallingPeer + } + } + // Disable any rollback and return + rollback = nil + return nil + } + // Otherwise split the chunk of headers into batches and process them + gotHeaders = true + + for len(headers) > 0 { + // Terminate if something failed in between processing chunks + select { + case <-d.cancelCh: + return errCancelHeaderProcessing + default: + } + // Select the next chunk of headers to import + limit := maxHeadersProcess + if limit > len(headers) { + limit = len(headers) + } + chunk := headers[:limit] + + // In case of header only syncing, validate the chunk immediately + if d.mode == FastSync || d.mode == LightSync { + // Collect the yet unknown headers to mark them as uncertain + unknown := make([]*types.Header, 0, len(headers)) + for _, header := range chunk { + if !d.hasHeader(header.Hash()) { + unknown = append(unknown, header) + } + } + // If we're importing pure headers, verify based on their recentness + frequency := fsHeaderCheckFrequency + if chunk[len(chunk)-1].Number.Uint64()+uint64(fsHeaderForceVerify) > pivot { + frequency = 1 + } + if n, err := d.insertHeaders(chunk, frequency); err != nil { + // If some headers were inserted, add them too to the rollback list + if n > 0 { + rollback = append(rollback, chunk[:n]...) + } + glog.V(logger.Debug).Infof("invalid header #%d [%x…]: %v", chunk[n].Number, chunk[n].Hash().Bytes()[:4], err) + return errInvalidChain + } + // All verifications passed, store newly found uncertain headers + rollback = append(rollback, unknown...) + if len(rollback) > fsHeaderSafetyNet { + rollback = append(rollback[:0], rollback[len(rollback)-fsHeaderSafetyNet:]...) + } + } + // If we're fast syncing and just pulled in the pivot, make sure it's the one locked in + if d.mode == FastSync && d.fsPivotLock != nil && chunk[0].Number.Uint64() <= pivot && chunk[len(chunk)-1].Number.Uint64() >= pivot { + if pivot := chunk[int(pivot-chunk[0].Number.Uint64())]; pivot.Hash() != d.fsPivotLock.Hash() { + glog.V(logger.Warn).Infof("Pivot doesn't match locked in version: have #%v [%x…], want #%v [%x…]", pivot.Number, pivot.Hash().Bytes()[:4], d.fsPivotLock.Number, d.fsPivotLock.Hash().Bytes()[:4]) + return errInvalidChain + } + } + // Unless we're doing light chains, schedule the headers for associated content retrieval + if d.mode == FullSync || d.mode == FastSync { + // If we've reached the allowed number of pending headers, stall a bit + for d.queue.PendingBlocks() >= maxQueuedHeaders || d.queue.PendingReceipts() >= maxQueuedHeaders { + select { + case <-d.cancelCh: + return errCancelHeaderProcessing + case <-time.After(time.Second): + } + } + // Otherwise insert the headers for content retrieval + inserts := d.queue.Schedule(chunk, origin) + if len(inserts) != len(chunk) { + glog.V(logger.Debug).Infof("stale headers") + return errBadPeer + } + } + headers = headers[limit:] + origin += uint64(limit) + } + // Signal the content downloaders of the availablility of new tasks + for _, ch := range []chan bool{d.bodyWakeCh, d.receiptWakeCh, d.stateWakeCh} { + select { + case ch <- true: + default: + } + } + } + } +} + +// processContent takes fetch results from the queue and tries to import them +// into the chain. The type of import operation will depend on the result contents. +func (d *Downloader) processContent() error { pivot := d.queue.FastSyncPivot() for { results := d.queue.WaitResults() @@ -1607,8 +1820,10 @@ func (d *Downloader) process() error { } for len(results) != 0 { // Check for any termination requests - if atomic.LoadInt32(&d.interrupt) == 1 { - return errCancelProcessing + select { + case <-d.quitCh: + return errCancelContentProcessing + default: } // Retrieve the a batch of results to import var ( @@ -1709,3 +1924,74 @@ func (d *Downloader) deliver(id string, destCh chan dataPack, packet dataPack, i return errNoSyncActive } } + +// qosTuner is the quality of service tuning loop that occasionally gathers the +// peer latency statistics and updates the estimated request round trip time. +func (d *Downloader) qosTuner() { + for { + // Retrieve the current median RTT and integrate into the previoust target RTT + rtt := time.Duration(float64(1-qosTuningImpact)*float64(atomic.LoadUint64(&d.rttEstimate)) + qosTuningImpact*float64(d.peers.medianRTT())) + atomic.StoreUint64(&d.rttEstimate, uint64(rtt)) + + // A new RTT cycle passed, increase our confidence in the estimated RTT + conf := atomic.LoadUint64(&d.rttConfidence) + conf = conf + (1000000-conf)/2 + atomic.StoreUint64(&d.rttConfidence, conf) + + // Log the new QoS values and sleep until the next RTT + glog.V(logger.Debug).Infof("Quality of service: rtt %v, conf %.3f, ttl %v", rtt, float64(conf)/1000000.0, d.requestTTL()) + select { + case <-d.quitCh: + return + case <-time.After(rtt): + } + } +} + +// qosReduceConfidence is meant to be called when a new peer joins the downloader's +// peer set, needing to reduce the confidence we have in out QoS estimates. +func (d *Downloader) qosReduceConfidence() { + // If we have a single peer, confidence is always 1 + peers := uint64(d.peers.Len()) + if peers == 1 { + atomic.StoreUint64(&d.rttConfidence, 1000000) + return + } + // If we have a ton of peers, don't drop confidence) + if peers >= uint64(qosConfidenceCap) { + return + } + // Otherwise drop the confidence factor + conf := atomic.LoadUint64(&d.rttConfidence) * (peers - 1) / peers + if float64(conf)/1000000 < rttMinConfidence { + conf = uint64(rttMinConfidence * 1000000) + } + atomic.StoreUint64(&d.rttConfidence, conf) + + rtt := time.Duration(atomic.LoadUint64(&d.rttEstimate)) + glog.V(logger.Debug).Infof("Quality of service: rtt %v, conf %.3f, ttl %v", rtt, float64(conf)/1000000.0, d.requestTTL()) +} + +// requestRTT returns the current target round trip time for a download request +// to complete in. +// +// Note, the returned RTT is .9 of the actually estimated RTT. The reason is that +// the downloader tries to adapt queries to the RTT, so multiple RTT values can +// be adapted to, but smaller ones are preffered (stabler download stream). +func (d *Downloader) requestRTT() time.Duration { + return time.Duration(atomic.LoadUint64(&d.rttEstimate)) * 9 / 10 +} + +// requestTTL returns the current timeout allowance for a single download request +// to finish under. +func (d *Downloader) requestTTL() time.Duration { + var ( + rtt = time.Duration(atomic.LoadUint64(&d.rttEstimate)) + conf = float64(atomic.LoadUint64(&d.rttConfidence)) / 1000000.0 + ) + ttl := time.Duration(ttlScaling) * time.Duration(float64(rtt)/conf) + if ttl > ttlLimit { + ttl = ttlLimit + } + return ttl +} diff --git a/eth/downloader/downloader_test.go b/eth/downloader/downloader_test.go index b0b0c2bd32..a9c069a926 100644 --- a/eth/downloader/downloader_test.go +++ b/eth/downloader/downloader_test.go @@ -149,22 +149,25 @@ type downloadTester struct { peerReceipts map[string]map[common.Hash]types.Receipts // Receipts belonging to different test peers peerChainTds map[string]map[common.Hash]*big.Int // Total difficulties of the blocks in the peer chains + peerMissingStates map[string]map[common.Hash]bool // State entries that fast sync should not return + lock sync.RWMutex } // newTester creates a new downloader test mocker. func newTester() *downloadTester { tester := &downloadTester{ - ownHashes: []common.Hash{genesis.Hash()}, - ownHeaders: map[common.Hash]*types.Header{genesis.Hash(): genesis.Header()}, - ownBlocks: map[common.Hash]*types.Block{genesis.Hash(): genesis}, - ownReceipts: map[common.Hash]types.Receipts{genesis.Hash(): nil}, - ownChainTd: map[common.Hash]*big.Int{genesis.Hash(): genesis.Difficulty()}, - peerHashes: make(map[string][]common.Hash), - peerHeaders: make(map[string]map[common.Hash]*types.Header), - peerBlocks: make(map[string]map[common.Hash]*types.Block), - peerReceipts: make(map[string]map[common.Hash]types.Receipts), - peerChainTds: make(map[string]map[common.Hash]*big.Int), + ownHashes: []common.Hash{genesis.Hash()}, + ownHeaders: map[common.Hash]*types.Header{genesis.Hash(): genesis.Header()}, + ownBlocks: map[common.Hash]*types.Block{genesis.Hash(): genesis}, + ownReceipts: map[common.Hash]types.Receipts{genesis.Hash(): nil}, + ownChainTd: map[common.Hash]*big.Int{genesis.Hash(): genesis.Difficulty()}, + peerHashes: make(map[string][]common.Hash), + peerHeaders: make(map[string]map[common.Hash]*types.Header), + peerBlocks: make(map[string]map[common.Hash]*types.Block), + peerReceipts: make(map[string]map[common.Hash]types.Receipts), + peerChainTds: make(map[string]map[common.Hash]*big.Int), + peerMissingStates: make(map[string]map[common.Hash]bool), } tester.stateDb, _ = ethdb.NewMemDatabase() tester.stateDb.Put(genesis.Root().Bytes(), []byte{0x00}) @@ -176,6 +179,12 @@ func newTester() *downloadTester { return tester } +// terminate aborts any operations on the embedded downloader and releases all +// held resources. +func (dl *downloadTester) terminate() { + dl.downloader.Terminate() +} + // sync starts synchronizing with a remote peer, blocking until it completes. func (dl *downloadTester) sync(id string, td *big.Int, mode SyncMode) error { dl.lock.RLock() @@ -188,7 +197,17 @@ func (dl *downloadTester) sync(id string, td *big.Int, mode SyncMode) error { } } dl.lock.RUnlock() - return dl.downloader.synchronise(id, hash, td, mode) + + // Synchronise with the chosen peer and ensure proper cleanup afterwards + err := dl.downloader.synchronise(id, hash, td, mode) + select { + case <-dl.downloader.cancelCh: + // Ok, downloader fully cancelled after sync cycle + default: + // Downloader is still accepting packets, can block a peer up + panic("downloader active post sync cycle") // panic will be caught by tester + } + return err } // hasHeader checks if a header is present in the testers canonical chain. @@ -398,6 +417,7 @@ func (dl *downloadTester) newSlowPeer(id string, version int, hashes []common.Ha dl.peerBlocks[id] = make(map[common.Hash]*types.Block) dl.peerReceipts[id] = make(map[common.Hash]types.Receipts) dl.peerChainTds[id] = make(map[common.Hash]*big.Int) + dl.peerMissingStates[id] = make(map[common.Hash]bool) genesis := hashes[len(hashes)-1] if header := headers[genesis]; header != nil { @@ -560,8 +580,8 @@ func (dl *downloadTester) peerGetAbsHeadersFn(id string, delay time.Duration) fu hashes := dl.peerHashes[id] headers := dl.peerHeaders[id] result := make([]*types.Header, 0, amount) - for i := 0; i < amount && len(hashes)-int(origin)-1-i >= 0; i++ { - if header, ok := headers[hashes[len(hashes)-int(origin)-1-i]]; ok { + for i := 0; i < amount && len(hashes)-int(origin)-1-i*(skip+1) >= 0; i++ { + if header, ok := headers[hashes[len(hashes)-int(origin)-1-i*(skip+1)]]; ok { result = append(result, header) } } @@ -638,7 +658,9 @@ func (dl *downloadTester) peerGetNodeDataFn(id string, delay time.Duration) func results := make([][]byte, 0, len(hashes)) for _, hash := range hashes { if data, err := testdb.Get(hash.Bytes()); err == nil { - results = append(results, data) + if !dl.peerMissingStates[id][hash] { + results = append(results, data) + } } } go dl.downloader.DeliverNodeData(id, results) @@ -724,6 +746,8 @@ func testCanonicalSynchronisation(t *testing.T, protocol int, mode SyncMode) { hashes, headers, blocks, receipts := makeChain(targetBlocks, 0, genesis, nil, false) tester := newTester() + defer tester.terminate() + tester.newPeer("peer", protocol, hashes, headers, blocks, receipts) // Synchronise with the peer and make sure all relevant data was retrieved @@ -748,6 +772,8 @@ func testThrottling(t *testing.T, protocol int, mode SyncMode) { hashes, headers, blocks, receipts := makeChain(targetBlocks, 0, genesis, nil, false) tester := newTester() + defer tester.terminate() + tester.newPeer("peer", protocol, hashes, headers, blocks, receipts) // Wrap the importer to allow stepping @@ -835,6 +861,8 @@ func testForkedSync(t *testing.T, protocol int, mode SyncMode) { hashesA, hashesB, headersA, headersB, blocksA, blocksB, receiptsA, receiptsB := makeChainFork(common+fork, fork, genesis, nil, true) tester := newTester() + defer tester.terminate() + tester.newPeer("fork A", protocol, hashesA, headersA, blocksA, receiptsA) tester.newPeer("fork B", protocol, hashesB, headersB, blocksB, receiptsB) @@ -869,6 +897,8 @@ func testHeavyForkedSync(t *testing.T, protocol int, mode SyncMode) { hashesA, hashesB, headersA, headersB, blocksA, blocksB, receiptsA, receiptsB := makeChainFork(common+fork, fork, genesis, nil, false) tester := newTester() + defer tester.terminate() + tester.newPeer("light", protocol, hashesA, headersA, blocksA, receiptsA) tester.newPeer("heavy", protocol, hashesB[fork/2:], headersB, blocksB, receiptsB) @@ -918,6 +948,8 @@ func testBoundedForkedSync(t *testing.T, protocol int, mode SyncMode) { hashesA, hashesB, headersA, headersB, blocksA, blocksB, receiptsA, receiptsB := makeChainFork(common+fork, fork, genesis, nil, true) tester := newTester() + defer tester.terminate() + tester.newPeer("original", protocol, hashesA, headersA, blocksA, receiptsA) tester.newPeer("rewriter", protocol, hashesB, headersB, blocksB, receiptsB) @@ -952,6 +984,8 @@ func testBoundedHeavyForkedSync(t *testing.T, protocol int, mode SyncMode) { hashesA, hashesB, headersA, headersB, blocksA, blocksB, receiptsA, receiptsB := makeChainFork(common+fork, fork, genesis, nil, false) tester := newTester() + defer tester.terminate() + tester.newPeer("original", protocol, hashesA, headersA, blocksA, receiptsA) tester.newPeer("heavy-rewriter", protocol, hashesB[MaxForkAncestry-17:], headersB, blocksB, receiptsB) // Root the fork below the ancestor limit @@ -971,7 +1005,9 @@ func testBoundedHeavyForkedSync(t *testing.T, protocol int, mode SyncMode) { // bodies. func TestInactiveDownloader62(t *testing.T) { t.Parallel() + tester := newTester() + defer tester.terminate() // Check that neither block headers nor bodies are accepted if err := tester.downloader.DeliverHeaders("bad peer", []*types.Header{}); err != errNoSyncActive { @@ -986,7 +1022,9 @@ func TestInactiveDownloader62(t *testing.T) { // bodies and receipts. func TestInactiveDownloader63(t *testing.T) { t.Parallel() + tester := newTester() + defer tester.terminate() // Check that neither block headers nor bodies are accepted if err := tester.downloader.DeliverHeaders("bad peer", []*types.Header{}); err != errNoSyncActive { @@ -1023,6 +1061,8 @@ func testCancel(t *testing.T, protocol int, mode SyncMode) { hashes, headers, blocks, receipts := makeChain(targetBlocks, 0, genesis, nil, false) tester := newTester() + defer tester.terminate() + tester.newPeer("peer", protocol, hashes, headers, blocks, receipts) // Make sure canceling works with a pristine downloader @@ -1058,6 +1098,8 @@ func testMultiSynchronisation(t *testing.T, protocol int, mode SyncMode) { hashes, headers, blocks, receipts := makeChain(targetBlocks, 0, genesis, nil, false) tester := newTester() + defer tester.terminate() + for i := 0; i < targetPeers; i++ { id := fmt.Sprintf("peer #%d", i) tester.newPeer(id, protocol, hashes[i*blockCacheLimit:], headers, blocks, receipts) @@ -1087,6 +1129,8 @@ func testMultiProtoSync(t *testing.T, protocol int, mode SyncMode) { // Create peers of every type tester := newTester() + defer tester.terminate() + tester.newPeer("peer 61", 61, hashes, nil, blocks, nil) tester.newPeer("peer 62", 62, hashes, headers, blocks, nil) tester.newPeer("peer 63", 63, hashes, headers, blocks, receipts) @@ -1124,6 +1168,8 @@ func testEmptyShortCircuit(t *testing.T, protocol int, mode SyncMode) { hashes, headers, blocks, receipts := makeChain(targetBlocks, 0, genesis, nil, false) tester := newTester() + defer tester.terminate() + tester.newPeer("peer", protocol, hashes, headers, blocks, receipts) // Instrument the downloader to signal body requests @@ -1177,6 +1223,7 @@ func testMissingHeaderAttack(t *testing.T, protocol int, mode SyncMode) { hashes, headers, blocks, receipts := makeChain(targetBlocks, 0, genesis, nil, false) tester := newTester() + defer tester.terminate() // Attempt a full sync with an attacker feeding gapped headers tester.newPeer("attack", protocol, hashes, headers, blocks, receipts) @@ -1209,6 +1256,7 @@ func testShiftedHeaderAttack(t *testing.T, protocol int, mode SyncMode) { hashes, headers, blocks, receipts := makeChain(targetBlocks, 0, genesis, nil, false) tester := newTester() + defer tester.terminate() // Attempt a full sync with an attacker feeding shifted headers tester.newPeer("attack", protocol, hashes, headers, blocks, receipts) @@ -1240,6 +1288,7 @@ func testInvalidHeaderRollback(t *testing.T, protocol int, mode SyncMode) { hashes, headers, blocks, receipts := makeChain(targetBlocks, 0, genesis, nil, false) tester := newTester() + defer tester.terminate() // Attempt to sync with an attacker that feeds junk during the fast sync phase. // This should result in the last fsHeaderSafetyNet headers being rolled back. @@ -1258,6 +1307,7 @@ func testInvalidHeaderRollback(t *testing.T, protocol int, mode SyncMode) { // rolled back, and also the pivot point being reverted to a non-block status. tester.newPeer("block-attack", protocol, hashes, headers, blocks, receipts) missing = 3*fsHeaderSafetyNet + MaxHeaderFetch + 1 + delete(tester.peerHeaders["fast-attack"], hashes[len(hashes)-missing]) // Make sure the fast-attacker doesn't fill in delete(tester.peerHeaders["block-attack"], hashes[len(hashes)-missing]) if err := tester.sync("block-attack", nil, mode); err == nil { @@ -1277,7 +1327,7 @@ func testInvalidHeaderRollback(t *testing.T, protocol int, mode SyncMode) { tester.newPeer("withhold-attack", protocol, hashes, headers, blocks, receipts) missing = 3*fsHeaderSafetyNet + MaxHeaderFetch + 1 - tester.downloader.noFast = false + tester.downloader.fsPivotFails = 0 tester.downloader.syncInitHook = func(uint64, uint64) { for i := missing; i <= len(hashes); i++ { delete(tester.peerHeaders["withhold-attack"], hashes[len(hashes)-i]) @@ -1296,6 +1346,8 @@ func testInvalidHeaderRollback(t *testing.T, protocol int, mode SyncMode) { t.Errorf("fast sync pivot block #%d not rolled back", head) } } + tester.downloader.fsPivotFails = fsCriticalTrials + // Synchronise with the valid peer and make sure sync succeeds. Since the last // rollback should also disable fast syncing for this process, verify that we // did a fresh full sync. Note, we can't assert anything about the receipts @@ -1328,9 +1380,11 @@ func testHighTDStarvationAttack(t *testing.T, protocol int, mode SyncMode) { t.Parallel() tester := newTester() - hashes, headers, blocks, receipts := makeChain(0, 0, genesis, nil, false) + defer tester.terminate() + hashes, headers, blocks, receipts := makeChain(0, 0, genesis, nil, false) tester.newPeer("attack", protocol, []common.Hash{hashes[0]}, headers, blocks, receipts) + if err := tester.sync("attack", big.NewInt(1000000), mode); err != errStallingPeer { t.Fatalf("synchronisation error mismatch: have %v, want %v", err, errStallingPeer) } @@ -1348,30 +1402,33 @@ func testBlockHeaderAttackerDropping(t *testing.T, protocol int) { result error drop bool }{ - {nil, false}, // Sync succeeded, all is well - {errBusy, false}, // Sync is already in progress, no problem - {errUnknownPeer, false}, // Peer is unknown, was already dropped, don't double drop - {errBadPeer, true}, // Peer was deemed bad for some reason, drop it - {errStallingPeer, true}, // Peer was detected to be stalling, drop it - {errNoPeers, false}, // No peers to download from, soft race, no issue - {errTimeout, true}, // No hashes received in due time, drop the peer - {errEmptyHashSet, true}, // No hashes were returned as a response, drop as it's a dead end - {errEmptyHeaderSet, true}, // No headers were returned as a response, drop as it's a dead end - {errPeersUnavailable, true}, // Nobody had the advertised blocks, drop the advertiser - {errInvalidAncestor, true}, // Agreed upon ancestor is not acceptable, drop the chain rewriter - {errInvalidChain, true}, // Hash chain was detected as invalid, definitely drop - {errInvalidBlock, false}, // A bad peer was detected, but not the sync origin - {errInvalidBody, false}, // A bad peer was detected, but not the sync origin - {errInvalidReceipt, false}, // A bad peer was detected, but not the sync origin - {errCancelHashFetch, false}, // Synchronisation was canceled, origin may be innocent, don't drop - {errCancelBlockFetch, false}, // Synchronisation was canceled, origin may be innocent, don't drop - {errCancelHeaderFetch, false}, // Synchronisation was canceled, origin may be innocent, don't drop - {errCancelBodyFetch, false}, // Synchronisation was canceled, origin may be innocent, don't drop - {errCancelReceiptFetch, false}, // Synchronisation was canceled, origin may be innocent, don't drop - {errCancelProcessing, false}, // Synchronisation was canceled, origin may be innocent, don't drop + {nil, false}, // Sync succeeded, all is well + {errBusy, false}, // Sync is already in progress, no problem + {errUnknownPeer, false}, // Peer is unknown, was already dropped, don't double drop + {errBadPeer, true}, // Peer was deemed bad for some reason, drop it + {errStallingPeer, true}, // Peer was detected to be stalling, drop it + {errNoPeers, false}, // No peers to download from, soft race, no issue + {errTimeout, true}, // No hashes received in due time, drop the peer + {errEmptyHashSet, true}, // No hashes were returned as a response, drop as it's a dead end + {errEmptyHeaderSet, true}, // No headers were returned as a response, drop as it's a dead end + {errPeersUnavailable, true}, // Nobody had the advertised blocks, drop the advertiser + {errInvalidAncestor, true}, // Agreed upon ancestor is not acceptable, drop the chain rewriter + {errInvalidChain, true}, // Hash chain was detected as invalid, definitely drop + {errInvalidBlock, false}, // A bad peer was detected, but not the sync origin + {errInvalidBody, false}, // A bad peer was detected, but not the sync origin + {errInvalidReceipt, false}, // A bad peer was detected, but not the sync origin + {errCancelHashFetch, false}, // Synchronisation was canceled, origin may be innocent, don't drop + {errCancelBlockFetch, false}, // Synchronisation was canceled, origin may be innocent, don't drop + {errCancelHeaderFetch, false}, // Synchronisation was canceled, origin may be innocent, don't drop + {errCancelBodyFetch, false}, // Synchronisation was canceled, origin may be innocent, don't drop + {errCancelReceiptFetch, false}, // Synchronisation was canceled, origin may be innocent, don't drop + {errCancelHeaderProcessing, false}, // Synchronisation was canceled, origin may be innocent, don't drop + {errCancelContentProcessing, false}, // Synchronisation was canceled, origin may be innocent, don't drop } // Run the tests and check disconnection status tester := newTester() + defer tester.terminate() + for i, tt := range tests { // Register a new peer and ensure it's presence id := fmt.Sprintf("test %d", i) @@ -1413,6 +1470,8 @@ func testSyncProgress(t *testing.T, protocol int, mode SyncMode) { progress := make(chan struct{}) tester := newTester() + defer tester.terminate() + tester.downloader.syncInitHook = func(origin, latest uint64) { starting <- struct{}{} <-progress @@ -1485,6 +1544,8 @@ func testForkedSyncProgress(t *testing.T, protocol int, mode SyncMode) { progress := make(chan struct{}) tester := newTester() + defer tester.terminate() + tester.downloader.syncInitHook = func(origin, latest uint64) { starting <- struct{}{} <-progress @@ -1560,6 +1621,8 @@ func testFailedSyncProgress(t *testing.T, protocol int, mode SyncMode) { progress := make(chan struct{}) tester := newTester() + defer tester.terminate() + tester.downloader.syncInitHook = func(origin, latest uint64) { starting <- struct{}{} <-progress @@ -1636,6 +1699,8 @@ func testFakedSyncProgress(t *testing.T, protocol int, mode SyncMode) { progress := make(chan struct{}) tester := newTester() + defer tester.terminate() + tester.downloader.syncInitHook = func(origin, latest uint64) { starting <- struct{}{} <-progress @@ -1722,7 +1787,7 @@ func testDeliverHeadersHang(t *testing.T, protocol int, mode SyncMode) { impl := tester.peerGetAbsHeadersFn("peer", 0) go impl(from, count, skip, reverse) // None of the extra deliveries should block. - timeout := time.After(5 * time.Second) + timeout := time.After(15 * time.Second) for i := 0; i < cap(deliveriesDone); i++ { select { case <-deliveriesDone: @@ -1735,5 +1800,48 @@ func testDeliverHeadersHang(t *testing.T, protocol int, mode SyncMode) { if err := tester.sync("peer", nil, mode); err != nil { t.Errorf("sync failed: %v", err) } + tester.terminate() } } + +// Tests that if fast sync aborts in the critical section, it can restart a few +// times before giving up. +func TestFastCriticalRestarts63(t *testing.T) { testFastCriticalRestarts(t, 63) } +func TestFastCriticalRestarts64(t *testing.T) { testFastCriticalRestarts(t, 64) } + +func testFastCriticalRestarts(t *testing.T, protocol int) { + t.Parallel() + + // Create a large enough blockchin to actually fast sync on + targetBlocks := fsMinFullBlocks + 2*fsPivotInterval - 15 + hashes, headers, blocks, receipts := makeChain(targetBlocks, 0, genesis, nil, false) + + // Create a tester peer with the critical section state roots missing (force failures) + tester := newTester() + defer tester.terminate() + + tester.newPeer("peer", protocol, hashes, headers, blocks, receipts) + for i := 0; i < fsPivotInterval; i++ { + tester.peerMissingStates["peer"][headers[hashes[fsMinFullBlocks+i]].Root] = true + } + // Synchronise with the peer a few times and make sure they fail until the retry limit + for i := 0; i < fsCriticalTrials; i++ { + // Attempt a sync and ensure it fails properly + if err := tester.sync("peer", nil, FastSync); err == nil { + t.Fatalf("failing fast sync succeeded: %v", err) + } + time.Sleep(500 * time.Millisecond) // Make sure no in-flight requests remain + + // If it's the first failure, pivot should be locked => reenable all others to detect pivot changes + if i == 0 { + tester.lock.Lock() + tester.peerMissingStates["peer"] = map[common.Hash]bool{tester.downloader.fsPivotLock.Root: true} + tester.lock.Unlock() + } + } + // Retry limit exhausted, downloader will switch to full sync, should succeed + if err := tester.sync("peer", nil, FastSync); err != nil { + t.Fatalf("failed to synchronise blocks in slow sync: %v", err) + } + assertOwnChain(t, tester, targetBlocks+1) +} diff --git a/eth/downloader/peer.go b/eth/downloader/peer.go index c4846194ba..94d44fca46 100644 --- a/eth/downloader/peer.go +++ b/eth/downloader/peer.go @@ -23,6 +23,8 @@ import ( "errors" "fmt" "math" + "sort" + "strings" "sync" "sync/atomic" "time" @@ -31,8 +33,8 @@ import ( ) const ( - maxLackingHashes = 4096 // Maximum number of entries allowed on the list or lacking items - throughputImpact = 0.1 // The impact a single measurement has on a peer's final throughput value. + maxLackingHashes = 4096 // Maximum number of entries allowed on the list or lacking items + measurementImpact = 0.1 // The impact a single measurement has on a peer's final throughput value. ) // Hash and block fetchers belonging to eth/61 and below @@ -58,15 +60,20 @@ type peer struct { id string // Unique identifier of the peer head common.Hash // Hash of the peers latest known block + headerIdle int32 // Current header activity state of the peer (idle = 0, active = 1) blockIdle int32 // Current block activity state of the peer (idle = 0, active = 1) receiptIdle int32 // Current receipt activity state of the peer (idle = 0, active = 1) stateIdle int32 // Current node data activity state of the peer (idle = 0, active = 1) + headerThroughput float64 // Number of headers measured to be retrievable per second blockThroughput float64 // Number of blocks (bodies) measured to be retrievable per second receiptThroughput float64 // Number of receipts measured to be retrievable per second stateThroughput float64 // Number of node data pieces measured to be retrievable per second - blockStarted time.Time // Time instance when the last block (body)fetch was started + rtt time.Duration // Request round trip time to track responsiveness (QoS) + + headerStarted time.Time // Time instance when the last header fetch was started + blockStarted time.Time // Time instance when the last block (body) fetch was started receiptStarted time.Time // Time instance when the last receipt fetch was started stateStarted time.Time // Time instance when the last node data fetch was started @@ -118,10 +125,12 @@ func (p *peer) Reset() { p.lock.Lock() defer p.lock.Unlock() + atomic.StoreInt32(&p.headerIdle, 0) atomic.StoreInt32(&p.blockIdle, 0) atomic.StoreInt32(&p.receiptIdle, 0) atomic.StoreInt32(&p.stateIdle, 0) + p.headerThroughput = 0 p.blockThroughput = 0 p.receiptThroughput = 0 p.stateThroughput = 0 @@ -151,6 +160,24 @@ func (p *peer) Fetch61(request *fetchRequest) error { return nil } +// FetchHeaders sends a header retrieval request to the remote peer. +func (p *peer) FetchHeaders(from uint64, count int) error { + // Sanity check the protocol version + if p.version < 62 { + panic(fmt.Sprintf("header fetch [eth/62+] requested on eth/%d", p.version)) + } + // Short circuit if the peer is already fetching + if !atomic.CompareAndSwapInt32(&p.headerIdle, 0, 1) { + return errAlreadyFetching + } + p.headerStarted = time.Now() + + // Issue the header retrieval request (absolut upwards without gaps) + go p.getAbsHeaders(from, count, 0, false) + + return nil +} + // FetchBodies sends a block body retrieval request to the remote peer. func (p *peer) FetchBodies(request *fetchRequest) error { // Sanity check the protocol version @@ -217,6 +244,13 @@ func (p *peer) FetchNodeData(request *fetchRequest) error { return nil } +// SetHeadersIdle sets the peer to idle, allowing it to execute new header retrieval +// requests. Its estimated header retrieval throughput is updated with that measured +// just now. +func (p *peer) SetHeadersIdle(delivered int) { + p.setIdle(p.headerStarted, delivered, &p.headerThroughput, &p.headerIdle) +} + // SetBlocksIdle sets the peer to idle, allowing it to execute new block retrieval // requests. Its estimated block retrieval throughput is updated with that measured // just now. @@ -260,35 +294,47 @@ func (p *peer) setIdle(started time.Time, delivered int, throughput *float64, id return } // Otherwise update the throughput with a new measurement - measured := float64(delivered) / (float64(time.Since(started)+1) / float64(time.Second)) // +1 (ns) to ensure non-zero divisor - *throughput = (1-throughputImpact)*(*throughput) + throughputImpact*measured + elapsed := time.Since(started) + 1 // +1 (ns) to ensure non-zero divisor + measured := float64(delivered) / (float64(elapsed) / float64(time.Second)) + + *throughput = (1-measurementImpact)*(*throughput) + measurementImpact*measured + p.rtt = time.Duration((1-measurementImpact)*float64(p.rtt) + measurementImpact*float64(elapsed)) +} + +// HeaderCapacity retrieves the peers header download allowance based on its +// previously discovered throughput. +func (p *peer) HeaderCapacity(targetRTT time.Duration) int { + p.lock.RLock() + defer p.lock.RUnlock() + + return int(math.Min(1+math.Max(1, p.headerThroughput*float64(targetRTT)/float64(time.Second)), float64(MaxHeaderFetch))) } // BlockCapacity retrieves the peers block download allowance based on its // previously discovered throughput. -func (p *peer) BlockCapacity() int { +func (p *peer) BlockCapacity(targetRTT time.Duration) int { p.lock.RLock() defer p.lock.RUnlock() - return int(math.Max(1, math.Min(p.blockThroughput*float64(blockTargetRTT)/float64(time.Second), float64(MaxBlockFetch)))) + return int(math.Min(1+math.Max(1, p.blockThroughput*float64(targetRTT)/float64(time.Second)), float64(MaxBlockFetch))) } // ReceiptCapacity retrieves the peers receipt download allowance based on its // previously discovered throughput. -func (p *peer) ReceiptCapacity() int { +func (p *peer) ReceiptCapacity(targetRTT time.Duration) int { p.lock.RLock() defer p.lock.RUnlock() - return int(math.Max(1, math.Min(p.receiptThroughput*float64(receiptTargetRTT)/float64(time.Second), float64(MaxReceiptFetch)))) + return int(math.Min(1+math.Max(1, p.receiptThroughput*float64(targetRTT)/float64(time.Second)), float64(MaxReceiptFetch))) } // NodeDataCapacity retrieves the peers state download allowance based on its // previously discovered throughput. -func (p *peer) NodeDataCapacity() int { +func (p *peer) NodeDataCapacity(targetRTT time.Duration) int { p.lock.RLock() defer p.lock.RUnlock() - return int(math.Max(1, math.Min(p.stateThroughput*float64(stateTargetRTT)/float64(time.Second), float64(MaxStateFetch)))) + return int(math.Min(1+math.Max(1, p.stateThroughput*float64(targetRTT)/float64(time.Second)), float64(MaxStateFetch))) } // MarkLacking appends a new entity to the set of items (blocks, receipts, states) @@ -322,15 +368,17 @@ func (p *peer) String() string { p.lock.RLock() defer p.lock.RUnlock() - return fmt.Sprintf("Peer %s [%s]", p.id, - fmt.Sprintf("blocks %3.2f/s, ", p.blockThroughput)+ - fmt.Sprintf("receipts %3.2f/s, ", p.receiptThroughput)+ - fmt.Sprintf("states %3.2f/s, ", p.stateThroughput)+ - fmt.Sprintf("lacking %4d", len(p.lacking)), - ) + return fmt.Sprintf("Peer %s [%s]", p.id, strings.Join([]string{ + fmt.Sprintf("hs %3.2f/s", p.headerThroughput), + fmt.Sprintf("bs %3.2f/s", p.blockThroughput), + fmt.Sprintf("rs %3.2f/s", p.receiptThroughput), + fmt.Sprintf("ss %3.2f/s", p.stateThroughput), + fmt.Sprintf("miss %4d", len(p.lacking)), + fmt.Sprintf("rtt %v", p.rtt), + }, ", ")) } -// peerSet represents the collection of active peer participating in the block +// peerSet represents the collection of active peer participating in the chain // download procedure. type peerSet struct { peers map[string]*peer @@ -359,9 +407,13 @@ func (ps *peerSet) Reset() { // peer is already known. // // The method also sets the starting throughput values of the new peer to the -// average of all existing peers, to give it a realistic change of being used +// average of all existing peers, to give it a realistic chance of being used // for data retrievals. func (ps *peerSet) Register(p *peer) error { + // Retrieve the current median RTT as a sane default + p.rtt = ps.medianRTT() + + // Register the new peer with some meaningful defaults ps.lock.Lock() defer ps.lock.Unlock() @@ -369,15 +421,17 @@ func (ps *peerSet) Register(p *peer) error { return errAlreadyRegistered } if len(ps.peers) > 0 { - p.blockThroughput, p.receiptThroughput, p.stateThroughput = 0, 0, 0 + p.headerThroughput, p.blockThroughput, p.receiptThroughput, p.stateThroughput = 0, 0, 0, 0 for _, peer := range ps.peers { peer.lock.RLock() + p.headerThroughput += peer.headerThroughput p.blockThroughput += peer.blockThroughput p.receiptThroughput += peer.receiptThroughput p.stateThroughput += peer.stateThroughput peer.lock.RUnlock() } + p.headerThroughput /= float64(len(ps.peers)) p.blockThroughput /= float64(len(ps.peers)) p.receiptThroughput /= float64(len(ps.peers)) p.stateThroughput /= float64(len(ps.peers)) @@ -441,6 +495,20 @@ func (ps *peerSet) BlockIdlePeers() ([]*peer, int) { return ps.idlePeers(61, 61, idle, throughput) } +// HeaderIdlePeers retrieves a flat list of all the currently header-idle peers +// within the active peer set, ordered by their reputation. +func (ps *peerSet) HeaderIdlePeers() ([]*peer, int) { + idle := func(p *peer) bool { + return atomic.LoadInt32(&p.headerIdle) == 0 + } + throughput := func(p *peer) float64 { + p.lock.RLock() + defer p.lock.RUnlock() + return p.headerThroughput + } + return ps.idlePeers(62, 64, idle, throughput) +} + // BodyIdlePeers retrieves a flat list of all the currently body-idle peers within // the active peer set, ordered by their reputation. func (ps *peerSet) BodyIdlePeers() ([]*peer, int) { @@ -508,3 +576,34 @@ func (ps *peerSet) idlePeers(minProtocol, maxProtocol int, idleCheck func(*peer) } return idle, total } + +// medianRTT returns the median RTT of te peerset, considering only the tuning +// peers if there are more peers available. +func (ps *peerSet) medianRTT() time.Duration { + // Gather all the currnetly measured round trip times + ps.lock.RLock() + defer ps.lock.RUnlock() + + rtts := make([]float64, 0, len(ps.peers)) + for _, p := range ps.peers { + p.lock.RLock() + rtts = append(rtts, float64(p.rtt)) + p.lock.RUnlock() + } + sort.Float64s(rtts) + + median := rttMaxEstimate + if qosTuningPeers <= len(rtts) { + median = time.Duration(rtts[qosTuningPeers/2]) // Median of our tuning peers + } else if len(rtts) > 0 { + median = time.Duration(rtts[len(rtts)/2]) // Median of our connected peers (maintain even like this some baseline qos) + } + // Restrict the RTT into some QoS defaults, irrelevant of true RTT + if median < rttMinEstimate { + median = rttMinEstimate + } + if median > rttMaxEstimate { + median = rttMaxEstimate + } + return median +} diff --git a/eth/downloader/queue.go b/eth/downloader/queue.go index d8d1bddcee..01897af6d4 100644 --- a/eth/downloader/queue.go +++ b/eth/downloader/queue.go @@ -40,7 +40,7 @@ import ( var ( blockCacheLimit = 8192 // Maximum number of blocks to cache before throttling the download - maxInFlightStates = 4096 // Maximum number of state downloads to allow concurrently + maxInFlightStates = 8192 // Maximum number of state downloads to allow concurrently ) var ( @@ -52,6 +52,7 @@ var ( // fetchRequest is a currently running data retrieval operation. type fetchRequest struct { Peer *peer // Peer to which the request was sent + From uint64 // [eth/62] Requested chain element index (used for skeleton fills only) Hashes map[common.Hash]int // [eth/61] Requested hashes with their insertion index (priority) Headers []*types.Header // [eth/62] Requested headers, sorted by request order Time time.Time // Time when the request was made @@ -79,6 +80,18 @@ type queue struct { headerHead common.Hash // [eth/62] Hash of the last queued header to verify order + // Headers are "special", they download in batches, supported by a skeleton chain + headerTaskPool map[uint64]*types.Header // [eth/62] Pending header retrieval tasks, mapping starting indexes to skeleton headers + headerTaskQueue *prque.Prque // [eth/62] Priority queue of the skeleton indexes to fetch the filling headers for + headerPeerMiss map[string]map[uint64]struct{} // [eth/62] Set of per-peer header batches known to be unavailable + headerPendPool map[string]*fetchRequest // [eth/62] Currently pending header retrieval operations + headerDonePool map[uint64]struct{} // [eth/62] Set of the completed header fetches + headerResults []*types.Header // [eth/62] Result cache accumulating the completed headers + headerProced int // [eth/62] Number of headers already processed from the results + headerOffset uint64 // [eth/62] Number of the first header in the result cache + headerContCh chan bool // [eth/62] Channel to notify when header download finishes + + // All data retrievals below are based on an already assembles header chain blockTaskPool map[common.Hash]*types.Header // [eth/62] Pending block (body) retrieval tasks, mapping hashes to headers blockTaskQueue *prque.Prque // [eth/62] Priority queue of the headers to fetch the blocks (bodies) for blockPendPool map[string]*fetchRequest // [eth/62] Currently pending block (body) retrieval operations @@ -113,6 +126,8 @@ func newQueue(stateDb ethdb.Database) *queue { return &queue{ hashPool: make(map[common.Hash]int), hashQueue: prque.New(), + headerPendPool: make(map[string]*fetchRequest), + headerContCh: make(chan bool), blockTaskPool: make(map[common.Hash]*types.Header), blockTaskQueue: prque.New(), blockPendPool: make(map[string]*fetchRequest), @@ -149,6 +164,8 @@ func (q *queue) Reset() { q.headerHead = common.Hash{} + q.headerPendPool = make(map[string]*fetchRequest) + q.blockTaskPool = make(map[common.Hash]*types.Header) q.blockTaskQueue.Reset() q.blockPendPool = make(map[string]*fetchRequest) @@ -178,6 +195,14 @@ func (q *queue) Close() { q.active.Broadcast() } +// PendingHeaders retrieves the number of header requests pending for retrieval. +func (q *queue) PendingHeaders() int { + q.lock.Lock() + defer q.lock.Unlock() + + return q.headerTaskQueue.Size() +} + // PendingBlocks retrieves the number of block (body) requests pending for retrieval. func (q *queue) PendingBlocks() int { q.lock.Lock() @@ -205,6 +230,15 @@ func (q *queue) PendingNodeData() int { return 0 } +// InFlightHeaders retrieves whether there are header fetch requests currently +// in flight. +func (q *queue) InFlightHeaders() bool { + q.lock.Lock() + defer q.lock.Unlock() + + return len(q.headerPendPool) > 0 +} + // InFlightBlocks retrieves whether there are block fetch requests currently in // flight. func (q *queue) InFlightBlocks() bool { @@ -317,6 +351,45 @@ func (q *queue) Schedule61(hashes []common.Hash, fifo bool) []common.Hash { return inserts } +// ScheduleSkeleton adds a batch of header retrieval tasks to the queue to fill +// up an already retrieved header skeleton. +func (q *queue) ScheduleSkeleton(from uint64, skeleton []*types.Header) { + q.lock.Lock() + defer q.lock.Unlock() + + // No skeleton retrieval can be in progress, fail hard if so (huge implementation bug) + if q.headerResults != nil { + panic("skeleton assembly already in progress") + } + // Shedule all the header retrieval tasks for the skeleton assembly + q.headerTaskPool = make(map[uint64]*types.Header) + q.headerTaskQueue = prque.New() + q.headerPeerMiss = make(map[string]map[uint64]struct{}) // Reset availability to correct invalid chains + q.headerResults = make([]*types.Header, len(skeleton)*MaxHeaderFetch) + q.headerProced = 0 + q.headerOffset = from + q.headerContCh = make(chan bool, 1) + + for i, header := range skeleton { + index := from + uint64(i*MaxHeaderFetch) + + q.headerTaskPool[index] = header + q.headerTaskQueue.Push(index, -float32(index)) + } +} + +// RetrieveHeaders retrieves the header chain assemble based on the scheduled +// skeleton. +func (q *queue) RetrieveHeaders() ([]*types.Header, int) { + q.lock.Lock() + defer q.lock.Unlock() + + headers, proced := q.headerResults, q.headerProced + q.headerResults, q.headerProced = nil, 0 + + return headers, proced +} + // Schedule adds a set of headers for the download queue for scheduling, returning // the new headers encountered. func (q *queue) Schedule(headers []*types.Header, from uint64) []*types.Header { @@ -437,6 +510,46 @@ func (q *queue) countProcessableItems() int { return len(q.resultCache) } +// ReserveHeaders reserves a set of headers for the given peer, skipping any +// previously failed batches. +func (q *queue) ReserveHeaders(p *peer, count int) *fetchRequest { + q.lock.Lock() + defer q.lock.Unlock() + + // Short circuit if the peer's already downloading something (sanity check to + // not corrupt state) + if _, ok := q.headerPendPool[p.id]; ok { + return nil + } + // Retrieve a batch of hashes, skipping previously failed ones + send, skip := uint64(0), []uint64{} + for send == 0 && !q.headerTaskQueue.Empty() { + from, _ := q.headerTaskQueue.Pop() + if q.headerPeerMiss[p.id] != nil { + if _, ok := q.headerPeerMiss[p.id][from.(uint64)]; ok { + skip = append(skip, from.(uint64)) + continue + } + } + send = from.(uint64) + } + // Merge all the skipped batches back + for _, from := range skip { + q.headerTaskQueue.Push(from, -float32(from)) + } + // Assemble and return the block download request + if send == 0 { + return nil + } + request := &fetchRequest{ + Peer: p, + From: send, + Time: time.Now(), + } + q.headerPendPool[p.id] = request + return request +} + // ReserveBlocks reserves a set of block hashes for the given peer, skipping any // previously failed download. func (q *queue) ReserveBlocks(p *peer, count int) *fetchRequest { @@ -635,6 +748,11 @@ func (q *queue) reserveHeaders(p *peer, count int, taskPool map[common.Hash]*typ return request, progress, nil } +// CancelHeaders aborts a fetch request, returning all pending skeleton indexes to the queue. +func (q *queue) CancelHeaders(request *fetchRequest) { + q.cancel(request, q.headerTaskQueue, q.headerPendPool) +} + // CancelBlocks aborts a fetch request, returning all pending hashes to the queue. func (q *queue) CancelBlocks(request *fetchRequest) { q.cancel(request, q.hashQueue, q.blockPendPool) @@ -663,6 +781,9 @@ func (q *queue) cancel(request *fetchRequest, taskQueue *prque.Prque, pendPool m q.lock.Lock() defer q.lock.Unlock() + if request.From > 0 { + taskQueue.Push(request.From, -float32(request.From)) + } for hash, index := range request.Hashes { taskQueue.Push(hash, float32(index)) } @@ -702,6 +823,15 @@ func (q *queue) Revoke(peerId string) { } } +// ExpireHeaders checks for in flight requests that exceeded a timeout allowance, +// canceling them and returning the responsible peers for penalisation. +func (q *queue) ExpireHeaders(timeout time.Duration) map[string]int { + q.lock.Lock() + defer q.lock.Unlock() + + return q.expire(timeout, q.headerPendPool, q.headerTaskQueue, headerTimeoutMeter) +} + // ExpireBlocks checks for in flight requests that exceeded a timeout allowance, // canceling them and returning the responsible peers for penalisation. func (q *queue) ExpireBlocks(timeout time.Duration) map[string]int { @@ -753,6 +883,9 @@ func (q *queue) expire(timeout time.Duration, pendPool map[string]*fetchRequest, timeoutMeter.Mark(1) // Return any non satisfied requests to the pool + if request.From > 0 { + taskQueue.Push(request.From, -float32(request.From)) + } for hash, index := range request.Hashes { taskQueue.Push(hash, float32(index)) } @@ -842,6 +975,94 @@ func (q *queue) DeliverBlocks(id string, blocks []*types.Block) (int, error) { } } +// DeliverHeaders injects a header retrieval response into the header results +// cache. This method either accepts all headers it received, or none of them +// if they do not map correctly to the skeleton. +// +// If the headers are accepted, the method makes an attempt to deliver the set +// of ready headers to the processor to keep the pipeline full. However it will +// not block to prevent stalling other pending deliveries. +func (q *queue) DeliverHeaders(id string, headers []*types.Header, headerProcCh chan []*types.Header) (int, error) { + q.lock.Lock() + defer q.lock.Unlock() + + // Short circuit if the data was never requested + request := q.headerPendPool[id] + if request == nil { + return 0, errNoFetchesPending + } + headerReqTimer.UpdateSince(request.Time) + delete(q.headerPendPool, id) + + // Ensure headers can be mapped onto the skeleton chain + target := q.headerTaskPool[request.From].Hash() + + accepted := len(headers) == MaxHeaderFetch + if accepted { + if headers[0].Number.Uint64() != request.From { + glog.V(logger.Detail).Infof("Peer %s: first header #%v [%x] broke chain ordering, expected %d", id, headers[0].Number, headers[0].Hash().Bytes()[:4], request.From) + accepted = false + } else if headers[len(headers)-1].Hash() != target { + glog.V(logger.Detail).Infof("Peer %s: last header #%v [%x] broke skeleton structure, expected %x", id, headers[len(headers)-1].Number, headers[len(headers)-1].Hash().Bytes()[:4], target[:4]) + accepted = false + } + } + if accepted { + for i, header := range headers[1:] { + hash := header.Hash() + if want := request.From + 1 + uint64(i); header.Number.Uint64() != want { + glog.V(logger.Warn).Infof("Peer %s: header #%v [%x] broke chain ordering, expected %d", id, header.Number, hash[:4], want) + accepted = false + break + } + if headers[i].Hash() != header.ParentHash { + glog.V(logger.Warn).Infof("Peer %s: header #%v [%x] broke chain ancestry", id, header.Number, hash[:4]) + accepted = false + break + } + } + } + // If the batch of headers wasn't accepted, mark as unavailable + if !accepted { + glog.V(logger.Detail).Infof("Peer %s: skeleton filling from header #%d not accepted", id, request.From) + + miss := q.headerPeerMiss[id] + if miss == nil { + q.headerPeerMiss[id] = make(map[uint64]struct{}) + miss = q.headerPeerMiss[id] + } + miss[request.From] = struct{}{} + + q.headerTaskQueue.Push(request.From, -float32(request.From)) + return 0, errors.New("delivery not accepted") + } + // Clean up a successful fetch and try to deliver any sub-results + copy(q.headerResults[request.From-q.headerOffset:], headers) + delete(q.headerTaskPool, request.From) + + ready := 0 + for q.headerProced+ready < len(q.headerResults) && q.headerResults[q.headerProced+ready] != nil { + ready += MaxHeaderFetch + } + if ready > 0 { + // Headers are ready for delivery, gather them and push forward (non blocking) + process := make([]*types.Header, ready) + copy(process, q.headerResults[q.headerProced:q.headerProced+ready]) + + select { + case headerProcCh <- process: + glog.V(logger.Detail).Infof("%s: pre-scheduled %d headers from #%v", id, len(process), process[0].Number) + q.headerProced += len(process) + default: + } + } + // Check for termination and return + if len(q.headerTaskPool) == 0 { + q.headerContCh <- false + } + return len(headers), nil +} + // DeliverBodies injects a block body retrieval response into the results queue. // The method returns the number of blocks bodies accepted from the delivery and // also wakes any threads waiting for data delivery. @@ -1041,13 +1262,19 @@ func (q *queue) deliverNodeData(results []trie.SyncResult, callback func(error, // Prepare configures the result cache to allow accepting and caching inbound // fetch results. -func (q *queue) Prepare(offset uint64, mode SyncMode, pivot uint64) { +func (q *queue) Prepare(offset uint64, mode SyncMode, pivot uint64, head *types.Header) { q.lock.Lock() defer q.lock.Unlock() + // Prepare the queue for sync results if q.resultOffset < offset { q.resultOffset = offset } q.fastSyncPivot = pivot q.mode = mode + + // If long running fast sync, also start up a head stateretrieval immediately + if mode == FastSync && pivot > 0 { + q.stateScheduler = state.NewStateSync(head.Root, q.stateDatabase) + } } diff --git a/eth/handler.go b/eth/handler.go index 202acdc78a..1e4dc1289e 100644 --- a/eth/handler.go +++ b/eth/handler.go @@ -59,7 +59,9 @@ type blockFetcherFn func([]common.Hash) error type ProtocolManager struct { networkId int - fastSync uint32 + fastSync uint32 // Flag whether fast sync is enabled (gets disabled if we already have blocks) + synced uint32 // Flag whether we're considered synchronised (enables transaction processing) + txpool txPool blockchain *core.BlockChain chaindb ethdb.Database @@ -83,6 +85,8 @@ type ProtocolManager struct { // wait group is used for graceful shutdowns during downloading // and processing wg sync.WaitGroup + + badBlockReportingEnabled bool } // NewProtocolManager returns a new ethereum sub protocol manager. The Ethereum sub protocol manages peers capable @@ -150,7 +154,7 @@ func NewProtocolManager(config *core.ChainConfig, fastSync bool, networkId int, // Construct the different synchronisation mechanisms manager.downloader = downloader.New(chaindb, manager.eventMux, blockchain.HasHeader, blockchain.HasBlockAndState, blockchain.GetHeader, blockchain.GetBlock, blockchain.CurrentHeader, blockchain.CurrentBlock, blockchain.CurrentFastBlock, blockchain.FastSyncCommitHead, - blockchain.GetTd, blockchain.InsertHeaderChain, blockchain.InsertChain, blockchain.InsertReceiptChain, blockchain.Rollback, + blockchain.GetTd, blockchain.InsertHeaderChain, manager.insertChain, blockchain.InsertReceiptChain, blockchain.Rollback, manager.removePeer) validator := func(block *types.Block, parent *types.Block) error { @@ -159,11 +163,28 @@ func NewProtocolManager(config *core.ChainConfig, fastSync bool, networkId int, heighter := func() uint64 { return blockchain.CurrentBlock().NumberU64() } - manager.fetcher = fetcher.New(blockchain.GetBlock, validator, manager.BroadcastBlock, heighter, blockchain.InsertChain, manager.removePeer) + inserter := func(blocks types.Blocks) (int, error) { + atomic.StoreUint32(&manager.synced, 1) // Mark initial sync done on any fetcher import + return manager.insertChain(blocks) + } + manager.fetcher = fetcher.New(blockchain.GetBlock, validator, manager.BroadcastBlock, heighter, inserter, manager.removePeer) + + if blockchain.Genesis().Hash().Hex() == defaultGenesisHash && networkId == 1 { + glog.V(logger.Debug).Infoln("Bad Block Reporting is enabled") + manager.badBlockReportingEnabled = true + } return manager, nil } +func (pm *ProtocolManager) insertChain(blocks types.Blocks) (i int, err error) { + i, err = pm.blockchain.InsertChain(blocks) + if pm.badBlockReportingEnabled && core.IsValidationErr(err) && i < len(blocks) { + go sendBadBlockReport(blocks[i], err) + } + return i, err +} + func (pm *ProtocolManager) removePeer(id string) { // Short circuit if the peer was already removed peer := pm.peers.Peer(id) @@ -378,6 +399,7 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { // Update the receive timestamp of each block for _, block := range blocks { block.ReceivedAt = msg.ReceivedAt + block.ReceivedFrom = p } // Filter out any explicitly requested blocks, deliver the rest to the downloader if blocks := pm.fetcher.FilterBlocks(blocks); len(blocks) > 0 { @@ -664,6 +686,7 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { return errResp(ErrDecode, "block validation %v: %v", msg, err) } request.Block.ReceivedAt = msg.ReceivedAt + request.Block.ReceivedFrom = p // Mark the peer as owning the block and schedule it for import p.MarkBlock(request.Block.Hash()) @@ -681,8 +704,8 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { } case msg.Code == TxMsg: - // Transactions arrived, make sure we have a valid chain to handle them - if atomic.LoadUint32(&pm.fastSync) == 1 { + // Transactions arrived, make sure we have a valid and fresh chain to handle them + if atomic.LoadUint32(&pm.synced) == 0 { break } // Transactions can be processed, parse all of them and deliver to the pool diff --git a/eth/protocol_test.go b/eth/protocol_test.go index 0a82e2e795..f860d0a35a 100644 --- a/eth/protocol_test.go +++ b/eth/protocol_test.go @@ -97,6 +97,7 @@ func TestRecvTransactions63(t *testing.T) { testRecvTransactions(t, 63) } func testRecvTransactions(t *testing.T, protocol int) { txAdded := make(chan []*types.Transaction) pm := newTestProtocolManagerMust(t, false, 0, nil, txAdded) + pm.synced = 1 // mark synced to accept transactions p, _ := newTestPeer("peer", protocol, pm, true) defer pm.Stop() defer p.close() diff --git a/eth/sync.go b/eth/sync.go index 4b16c13226..52f7e90e7b 100644 --- a/eth/sync.go +++ b/eth/sync.go @@ -174,6 +174,8 @@ func (pm *ProtocolManager) synchronise(peer *peer) { if err := pm.downloader.Synchronise(peer.id, peer.Head(), peer.Td(), mode); err != nil { return } + atomic.StoreUint32(&pm.synced, 1) // Mark initial sync done + // If fast sync was enabled, and we synced up, disable it if atomic.LoadUint32(&pm.fastSync) == 1 { // Disable fast sync if we indeed have something in our chain diff --git a/jsre/bignumber_js.go b/internal/jsre/bignumber_js.go similarity index 100% rename from jsre/bignumber_js.go rename to internal/jsre/bignumber_js.go diff --git a/jsre/completion.go b/internal/jsre/completion.go similarity index 100% rename from jsre/completion.go rename to internal/jsre/completion.go diff --git a/jsre/completion_test.go b/internal/jsre/completion_test.go similarity index 98% rename from jsre/completion_test.go rename to internal/jsre/completion_test.go index 92af5ddb64..ccbd73dccc 100644 --- a/jsre/completion_test.go +++ b/internal/jsre/completion_test.go @@ -17,12 +17,13 @@ package jsre import ( + "os" "reflect" "testing" ) func TestCompleteKeywords(t *testing.T) { - re := New("") + re := New("", os.Stdout) re.Run(` function theClass() { this.foo = 3; diff --git a/jsre/ethereum_js.go b/internal/jsre/ethereum_js.go similarity index 100% rename from jsre/ethereum_js.go rename to internal/jsre/ethereum_js.go diff --git a/jsre/jsre.go b/internal/jsre/jsre.go similarity index 93% rename from jsre/jsre.go rename to internal/jsre/jsre.go index 59730bc0da..4813893047 100644 --- a/jsre/jsre.go +++ b/internal/jsre/jsre.go @@ -21,9 +21,9 @@ import ( crand "crypto/rand" "encoding/binary" "fmt" + "io" "io/ioutil" "math/rand" - "sync" "time" "github.com/ethereum/go-ethereum/common" @@ -40,9 +40,10 @@ It provides some helper functions to */ type JSRE struct { assetPath string + output io.Writer evalQueue chan *evalReq stopEventLoop chan bool - loopWg sync.WaitGroup + closed chan struct{} } // jsTimer is a single timer instance with a callback function @@ -60,13 +61,14 @@ type evalReq struct { } // runtime must be stopped with Stop() after use and cannot be used after stopping -func New(assetPath string) *JSRE { +func New(assetPath string, output io.Writer) *JSRE { re := &JSRE{ assetPath: assetPath, + output: output, + closed: make(chan struct{}), evalQueue: make(chan *evalReq), stopEventLoop: make(chan bool), } - re.loopWg.Add(1) go re.runEventLoop() re.Set("loadScript", re.loadScript) re.Set("inspect", prettyPrintJS) @@ -95,6 +97,8 @@ func randomSource() *rand.Rand { // functions should be used if and only if running a routine that was already // called from JS through an RPC call. func (self *JSRE) runEventLoop() { + defer close(self.closed) + vm := otto.New() r := randomSource() vm.SetRandomSource(r.Float64) @@ -210,8 +214,6 @@ loop: timer.timer.Stop() delete(registry, timer) } - - self.loopWg.Done() } // Do executes the given function on the JS event loop. @@ -224,8 +226,11 @@ func (self *JSRE) Do(fn func(*otto.Otto)) { // stops the event loop before exit, optionally waits for all timers to expire func (self *JSRE) Stop(waitForCallbacks bool) { - self.stopEventLoop <- waitForCallbacks - self.loopWg.Wait() + select { + case <-self.closed: + case self.stopEventLoop <- waitForCallbacks: + <-self.closed + } } // Exec(file) loads and runs the contents of a file @@ -292,19 +297,21 @@ func (self *JSRE) loadScript(call otto.FunctionCall) otto.Value { return otto.TrueValue() } -// EvalAndPrettyPrint evaluates code and pretty prints the result to -// standard output. -func (self *JSRE) EvalAndPrettyPrint(code string) (err error) { +// Evaluate executes code and pretty prints the result to the specified output +// stream. +func (self *JSRE) Evaluate(code string, w io.Writer) error { + var fail error + self.Do(func(vm *otto.Otto) { - var val otto.Value - val, err = vm.Run(code) + val, err := vm.Run(code) if err != nil { - return + prettyError(vm, err, w) + } else { + prettyPrint(vm, val, w) } - prettyPrint(vm, val) - fmt.Println() + fmt.Fprintln(w) }) - return err + return fail } // Compile compiles and then runs a piece of JS code. diff --git a/jsre/jsre_test.go b/internal/jsre/jsre_test.go similarity index 98% rename from jsre/jsre_test.go rename to internal/jsre/jsre_test.go index ffb6999db7..bcb6e0dd23 100644 --- a/jsre/jsre_test.go +++ b/internal/jsre/jsre_test.go @@ -51,7 +51,7 @@ func newWithTestJS(t *testing.T, testjs string) (*JSRE, string) { t.Fatal("cannot create test.js:", err) } } - return New(dir), dir + return New(dir, os.Stdout), dir } func TestExec(t *testing.T) { @@ -102,7 +102,7 @@ func TestNatto(t *testing.T) { } func TestBind(t *testing.T) { - jsre := New("") + jsre := New("", os.Stdout) defer jsre.Stop(false) jsre.Bind("no", &testNativeObjectBinding{}) diff --git a/jsre/pretty.go b/internal/jsre/pretty.go similarity index 72% rename from jsre/pretty.go rename to internal/jsre/pretty.go index cd7fa5232b..30d8660ff6 100644 --- a/jsre/pretty.go +++ b/internal/jsre/pretty.go @@ -18,6 +18,7 @@ package jsre import ( "fmt" + "io" "sort" "strconv" "strings" @@ -32,10 +33,11 @@ const ( ) var ( - functionColor = color.New(color.FgMagenta) - specialColor = color.New(color.Bold) - numberColor = color.New(color.FgRed) - stringColor = color.New(color.FgGreen) + FunctionColor = color.New(color.FgMagenta).SprintfFunc() + SpecialColor = color.New(color.Bold).SprintfFunc() + NumberColor = color.New(color.FgRed).SprintfFunc() + StringColor = color.New(color.FgGreen).SprintfFunc() + ErrorColor = color.New(color.FgHiRed).SprintfFunc() ) // these fields are hidden when printing objects. @@ -50,19 +52,39 @@ var boringKeys = map[string]bool{ } // prettyPrint writes value to standard output. -func prettyPrint(vm *otto.Otto, value otto.Value) { - ppctx{vm}.printValue(value, 0, false) +func prettyPrint(vm *otto.Otto, value otto.Value, w io.Writer) { + ppctx{vm: vm, w: w}.printValue(value, 0, false) } -func prettyPrintJS(call otto.FunctionCall) otto.Value { +// prettyError writes err to standard output. +func prettyError(vm *otto.Otto, err error, w io.Writer) { + failure := err.Error() + if ottoErr, ok := err.(*otto.Error); ok { + failure = ottoErr.String() + } + fmt.Fprint(w, ErrorColor("%s", failure)) +} + +// jsErrorString adds a backtrace to errors generated by otto. +func jsErrorString(err error) string { + if ottoErr, ok := err.(*otto.Error); ok { + return ottoErr.String() + } + return err.Error() +} + +func prettyPrintJS(call otto.FunctionCall, w io.Writer) otto.Value { for _, v := range call.ArgumentList { - prettyPrint(call.Otto, v) - fmt.Println() + prettyPrint(call.Otto, v, w) + fmt.Fprintln(w) } return otto.UndefinedValue() } -type ppctx struct{ vm *otto.Otto } +type ppctx struct { + vm *otto.Otto + w io.Writer +} func (ctx ppctx) indent(level int) string { return strings.Repeat(indentString, level) @@ -73,22 +95,22 @@ func (ctx ppctx) printValue(v otto.Value, level int, inArray bool) { case v.IsObject(): ctx.printObject(v.Object(), level, inArray) case v.IsNull(): - specialColor.Print("null") + fmt.Fprint(ctx.w, SpecialColor("null")) case v.IsUndefined(): - specialColor.Print("undefined") + fmt.Fprint(ctx.w, SpecialColor("undefined")) case v.IsString(): s, _ := v.ToString() - stringColor.Printf("%q", s) + fmt.Fprint(ctx.w, StringColor("%q", s)) case v.IsBoolean(): b, _ := v.ToBoolean() - specialColor.Printf("%t", b) + fmt.Fprint(ctx.w, SpecialColor("%t", b)) case v.IsNaN(): - numberColor.Printf("NaN") + fmt.Fprint(ctx.w, NumberColor("NaN")) case v.IsNumber(): s, _ := v.ToString() - numberColor.Printf("%s", s) + fmt.Fprint(ctx.w, NumberColor("%s", s)) default: - fmt.Printf("") + fmt.Fprint(ctx.w, "") } } @@ -98,75 +120,75 @@ func (ctx ppctx) printObject(obj *otto.Object, level int, inArray bool) { lv, _ := obj.Get("length") len, _ := lv.ToInteger() if len == 0 { - fmt.Printf("[]") + fmt.Fprintf(ctx.w, "[]") return } if level > maxPrettyPrintLevel { - fmt.Print("[...]") + fmt.Fprint(ctx.w, "[...]") return } - fmt.Print("[") + fmt.Fprint(ctx.w, "[") for i := int64(0); i < len; i++ { el, err := obj.Get(strconv.FormatInt(i, 10)) if err == nil { ctx.printValue(el, level+1, true) } if i < len-1 { - fmt.Printf(", ") + fmt.Fprintf(ctx.w, ", ") } } - fmt.Print("]") + fmt.Fprint(ctx.w, "]") case "Object": // Print values from bignumber.js as regular numbers. if ctx.isBigNumber(obj) { - numberColor.Print(toString(obj)) + fmt.Fprint(ctx.w, NumberColor("%s", toString(obj))) return } // Otherwise, print all fields indented, but stop if we're too deep. keys := ctx.fields(obj) if len(keys) == 0 { - fmt.Print("{}") + fmt.Fprint(ctx.w, "{}") return } if level > maxPrettyPrintLevel { - fmt.Print("{...}") + fmt.Fprint(ctx.w, "{...}") return } - fmt.Println("{") + fmt.Fprintln(ctx.w, "{") for i, k := range keys { v, _ := obj.Get(k) - fmt.Printf("%s%s: ", ctx.indent(level+1), k) + fmt.Fprintf(ctx.w, "%s%s: ", ctx.indent(level+1), k) ctx.printValue(v, level+1, false) if i < len(keys)-1 { - fmt.Printf(",") + fmt.Fprintf(ctx.w, ",") } - fmt.Println() + fmt.Fprintln(ctx.w) } if inArray { level-- } - fmt.Printf("%s}", ctx.indent(level)) + fmt.Fprintf(ctx.w, "%s}", ctx.indent(level)) case "Function": // Use toString() to display the argument list if possible. if robj, err := obj.Call("toString"); err != nil { - functionColor.Print("function()") + fmt.Fprint(ctx.w, FunctionColor("function()")) } else { desc := strings.Trim(strings.Split(robj.String(), "{")[0], " \t\n") desc = strings.Replace(desc, " (", "(", 1) - functionColor.Print(desc) + fmt.Fprint(ctx.w, FunctionColor("%s", desc)) } case "RegExp": - stringColor.Print(toString(obj)) + fmt.Fprint(ctx.w, StringColor("%s", toString(obj))) default: if v, _ := obj.Get("toString"); v.IsFunction() && level <= maxPrettyPrintLevel { s, _ := obj.Call("toString") - fmt.Printf("<%s %s>", obj.Class(), s.String()) + fmt.Fprintf(ctx.w, "<%s %s>", obj.Class(), s.String()) } else { - fmt.Printf("<%s>", obj.Class()) + fmt.Fprintf(ctx.w, "<%s>", obj.Class()) } } } diff --git a/internal/web3ext/web3ext.go b/internal/web3ext/web3ext.go index 1928913dea..8d5d1500ff 100644 --- a/internal/web3ext/web3ext.go +++ b/internal/web3ext/web3ext.go @@ -18,44 +18,17 @@ package web3ext var Modules = map[string]string{ - "txpool": TxPool_JS, "admin": Admin_JS, - "personal": Personal_JS, + "debug": Debug_JS, "eth": Eth_JS, "miner": Miner_JS, - "debug": Debug_JS, "net": Net_JS, + "personal": Personal_JS, + "rpc": RPC_JS, + "shh": Shh_JS, + "txpool": TxPool_JS, } -const TxPool_JS = ` -web3._extend({ - property: 'txpool', - methods: - [ - ], - properties: - [ - new web3._extend.Property({ - name: 'content', - getter: 'txpool_content' - }), - new web3._extend.Property({ - name: 'inspect', - getter: 'txpool_inspect' - }), - new web3._extend.Property({ - name: 'status', - getter: 'txpool_status', - outputFormatter: function(status) { - status.pending = web3._extend.utils.toDecimal(status.pending); - status.queued = web3._extend.utils.toDecimal(status.queued); - return status; - } - }) - ] -}); -` - const Admin_JS = ` web3._extend({ property: 'admin', @@ -176,60 +149,6 @@ web3._extend({ }); ` -const Eth_JS = ` -web3._extend({ - property: 'eth', - methods: - [ - new web3._extend.Method({ - name: 'sign', - call: 'eth_sign', - params: 2, - inputFormatter: [web3._extend.formatters.inputAddressFormatter, null] - }), - new web3._extend.Method({ - name: 'resend', - call: 'eth_resend', - params: 3, - inputFormatter: [web3._extend.formatters.inputTransactionFormatter, web3._extend.utils.fromDecimal, web3._extend.utils.fromDecimal] - }), - new web3._extend.Method({ - name: 'getNatSpec', - call: 'eth_getNatSpec', - params: 1, - inputFormatter: [web3._extend.formatters.inputTransactionFormatter] - }), - new web3._extend.Method({ - name: 'signTransaction', - call: 'eth_signTransaction', - params: 1, - inputFormatter: [web3._extend.formatters.inputTransactionFormatter] - }), - new web3._extend.Method({ - name: 'submitTransaction', - call: 'eth_submitTransaction', - params: 1, - inputFormatter: [web3._extend.formatters.inputTransactionFormatter] - }) - ], - properties: - [ - new web3._extend.Property({ - name: 'pendingTransactions', - getter: 'eth_pendingTransactions', - outputFormatter: function(txs) { - var formatted = []; - for (var i = 0; i < txs.length; i++) { - formatted.push(web3._extend.formatters.outputTransactionFormatter(txs[i])); - formatted[i].blockHash = null; - } - return formatted; - } - }) - ] -}); -` - const Debug_JS = ` web3._extend({ property: 'debug', @@ -382,6 +301,60 @@ web3._extend({ }); ` +const Eth_JS = ` +web3._extend({ + property: 'eth', + methods: + [ + new web3._extend.Method({ + name: 'sign', + call: 'eth_sign', + params: 2, + inputFormatter: [web3._extend.formatters.inputAddressFormatter, null] + }), + new web3._extend.Method({ + name: 'resend', + call: 'eth_resend', + params: 3, + inputFormatter: [web3._extend.formatters.inputTransactionFormatter, web3._extend.utils.fromDecimal, web3._extend.utils.fromDecimal] + }), + new web3._extend.Method({ + name: 'getNatSpec', + call: 'eth_getNatSpec', + params: 1, + inputFormatter: [web3._extend.formatters.inputTransactionFormatter] + }), + new web3._extend.Method({ + name: 'signTransaction', + call: 'eth_signTransaction', + params: 1, + inputFormatter: [web3._extend.formatters.inputTransactionFormatter] + }), + new web3._extend.Method({ + name: 'submitTransaction', + call: 'eth_submitTransaction', + params: 1, + inputFormatter: [web3._extend.formatters.inputTransactionFormatter] + }) + ], + properties: + [ + new web3._extend.Property({ + name: 'pendingTransactions', + getter: 'eth_pendingTransactions', + outputFormatter: function(txs) { + var formatted = []; + for (var i = 0; i < txs.length; i++) { + formatted.push(web3._extend.formatters.outputTransactionFormatter(txs[i])); + formatted[i].blockHash = null; + } + return formatted; + } + }) + ] +}); +` + const Miner_JS = ` web3._extend({ property: 'miner', @@ -412,7 +385,7 @@ web3._extend({ name: 'setGasPrice', call: 'miner_setGasPrice', params: 1, - inputFormatter: [web3._extend.utils.fromDecial] + inputFormatter: [web3._extend.utils.fromDecimal] }), new web3._extend.Method({ name: 'startAutoDAG', @@ -491,7 +464,35 @@ web3._extend({ [ new web3._extend.Property({ name: 'version', - getter: 'shh_version' + getter: 'shh_version', + outputFormatter: web3._extend.utils.toDecimal + }) + ] +}); +` + +const TxPool_JS = ` +web3._extend({ + property: 'txpool', + methods: [], + properties: + [ + new web3._extend.Property({ + name: 'content', + getter: 'txpool_content' + }), + new web3._extend.Property({ + name: 'inspect', + getter: 'txpool_inspect' + }), + new web3._extend.Property({ + name: 'status', + getter: 'txpool_status', + outputFormatter: function(status) { + status.pending = web3._extend.utils.toDecimal(status.pending); + status.queued = web3._extend.utils.toDecimal(status.queued); + return status; + } }) ] }); diff --git a/node/node.go b/node/node.go index 06a1b7aed7..1f517a027e 100644 --- a/node/node.go +++ b/node/node.go @@ -49,7 +49,7 @@ type Node struct { datadir string // Path to the currently used data directory eventmux *event.TypeMux // Event multiplexer used between the services of a stack - serverConfig *p2p.Server // Configuration of the underlying P2P networking layer + serverConfig p2p.Config server *p2p.Server // Currently running P2P networking layer serviceFuncs []ServiceConstructor // Service constructors (in dependency order) @@ -97,7 +97,7 @@ func New(conf *Config) (*Node, error) { } return &Node{ datadir: conf.DataDir, - serverConfig: &p2p.Server{ + serverConfig: p2p.Config{ PrivateKey: conf.NodeKey(), Name: conf.Name, Discovery: !conf.NoDiscovery, @@ -151,9 +151,7 @@ func (n *Node) Start() error { return ErrNodeRunning } // Otherwise copy and specialize the P2P configuration - running := new(p2p.Server) - *running = *n.serverConfig - + running := &p2p.Server{Config: n.serverConfig} services := make(map[reflect.Type]Service) for _, constructor := range n.serviceFuncs { // Create a new context for the particular service diff --git a/p2p/dial_test.go b/p2p/dial_test.go index 3447660a3f..05d9b75626 100644 --- a/p2p/dial_test.go +++ b/p2p/dial_test.go @@ -478,7 +478,8 @@ func TestDialResolve(t *testing.T) { } // Now run the task, it should resolve the ID once. - srv := &Server{ntab: table, Dialer: &net.Dialer{Deadline: time.Now().Add(-5 * time.Minute)}} + config := Config{Dialer: &net.Dialer{Deadline: time.Now().Add(-5 * time.Minute)}} + srv := &Server{ntab: table, Config: config} tasks[0].Do(srv) if !reflect.DeepEqual(table.resolveCalls, []discover.NodeID{dest.ID}) { t.Fatalf("wrong resolve calls, got %v", table.resolveCalls) diff --git a/p2p/server.go b/p2p/server.go index 3b2f2b0786..880aa7cf1f 100644 --- a/p2p/server.go +++ b/p2p/server.go @@ -54,12 +54,8 @@ var errServerStopped = errors.New("server stopped") var srvjslog = logger.NewJsonLogger() -// Server manages all peer connections. -// -// The fields of Server are used as configuration parameters. -// You should set them before starting the Server. Fields may not be -// modified while the server is running. -type Server struct { +// Config holds Server options. +type Config struct { // This field must be set to a valid secp256k1 private key. PrivateKey *ecdsa.PrivateKey @@ -120,6 +116,12 @@ type Server struct { // If NoDial is true, the server will not dial any peers. NoDial bool +} + +// Server manages all peer connections. +type Server struct { + // Config fields may not be modified while the server is running. + Config // Hooks for testing. These are useful because we can inhibit // the whole protocol stack. diff --git a/p2p/server_test.go b/p2p/server_test.go index b437ac3676..deb34f5bb1 100644 --- a/p2p/server_test.go +++ b/p2p/server_test.go @@ -67,11 +67,14 @@ func (c *testTransport) close(err error) { } func startTestServer(t *testing.T, id discover.NodeID, pf func(*Peer)) *Server { + config := Config{ + Name: "test", + MaxPeers: 10, + ListenAddr: "127.0.0.1:0", + PrivateKey: newkey(), + } server := &Server{ - Name: "test", - MaxPeers: 10, - ListenAddr: "127.0.0.1:0", - PrivateKey: newkey(), + Config: config, newPeerHook: pf, newTransport: func(fd net.Conn) transport { return newTestTransport(id, fd) }, } @@ -200,10 +203,10 @@ func TestServerTaskScheduling(t *testing.T) { // The Server in this test isn't actually running // because we're only interested in what run does. srv := &Server{ - MaxPeers: 10, - quit: make(chan struct{}), - ntab: fakeTable{}, - running: true, + Config: Config{MaxPeers: 10}, + quit: make(chan struct{}), + ntab: fakeTable{}, + running: true, } srv.loopWG.Add(1) go func() { @@ -314,10 +317,12 @@ func (t *testTask) Do(srv *Server) { func TestServerAtCap(t *testing.T) { trustedID := randomID() srv := &Server{ - PrivateKey: newkey(), - MaxPeers: 10, - NoDial: true, - TrustedNodes: []*discover.Node{{ID: trustedID}}, + Config: Config{ + PrivateKey: newkey(), + MaxPeers: 10, + NoDial: true, + TrustedNodes: []*discover.Node{{ID: trustedID}}, + }, } if err := srv.Start(); err != nil { t.Fatalf("could not start: %v", err) @@ -415,10 +420,12 @@ func TestServerSetupConn(t *testing.T) { for i, test := range tests { srv := &Server{ - PrivateKey: srvkey, - MaxPeers: 10, - NoDial: true, - Protocols: []Protocol{discard}, + Config: Config{ + PrivateKey: srvkey, + MaxPeers: 10, + NoDial: true, + Protocols: []Protocol{discard}, + }, newTransport: func(fd net.Conn) transport { return test.tt }, } if !test.dontstart { diff --git a/rpc/json.go b/rpc/json.go index 8a3bea2eeb..151ed546e7 100644 --- a/rpc/json.go +++ b/rpc/json.go @@ -30,7 +30,7 @@ import ( ) const ( - jsonRPCVersion = "2.0" + JSONRPCVersion = "2.0" serviceMethodSeparator = "_" subscribeMethod = "eth_subscribe" unsubscribeMethod = "eth_unsubscribe" @@ -302,31 +302,31 @@ func parsePositionalArguments(args json.RawMessage, callbackArgs []reflect.Type) // CreateResponse will create a JSON-RPC success response with the given id and reply as result. func (c *jsonCodec) CreateResponse(id interface{}, reply interface{}) interface{} { if isHexNum(reflect.TypeOf(reply)) { - return &JSONSuccessResponse{Version: jsonRPCVersion, Id: id, Result: fmt.Sprintf(`%#x`, reply)} + return &JSONSuccessResponse{Version: JSONRPCVersion, Id: id, Result: fmt.Sprintf(`%#x`, reply)} } - return &JSONSuccessResponse{Version: jsonRPCVersion, Id: id, Result: reply} + return &JSONSuccessResponse{Version: JSONRPCVersion, Id: id, Result: reply} } // CreateErrorResponse will create a JSON-RPC error response with the given id and error. func (c *jsonCodec) CreateErrorResponse(id interface{}, err RPCError) interface{} { - return &JSONErrResponse{Version: jsonRPCVersion, Id: id, Error: JSONError{Code: err.Code(), Message: err.Error()}} + return &JSONErrResponse{Version: JSONRPCVersion, Id: id, Error: JSONError{Code: err.Code(), Message: err.Error()}} } // CreateErrorResponseWithInfo will create a JSON-RPC error response with the given id and error. // info is optional and contains additional information about the error. When an empty string is passed it is ignored. func (c *jsonCodec) CreateErrorResponseWithInfo(id interface{}, err RPCError, info interface{}) interface{} { - return &JSONErrResponse{Version: jsonRPCVersion, Id: id, + return &JSONErrResponse{Version: JSONRPCVersion, Id: id, Error: JSONError{Code: err.Code(), Message: err.Error(), Data: info}} } // CreateNotification will create a JSON-RPC notification with the given subscription id and event as params. func (c *jsonCodec) CreateNotification(subid string, event interface{}) interface{} { if isHexNum(reflect.TypeOf(event)) { - return &jsonNotification{Version: jsonRPCVersion, Method: notificationMethod, + return &jsonNotification{Version: JSONRPCVersion, Method: notificationMethod, Params: jsonSubscription{Subscription: subid, Result: fmt.Sprintf(`%#x`, event)}} } - return &jsonNotification{Version: jsonRPCVersion, Method: notificationMethod, + return &jsonNotification{Version: JSONRPCVersion, Method: notificationMethod, Params: jsonSubscription{Subscription: subid, Result: event}} } diff --git a/rpc/server.go b/rpc/server.go index 001107a1b7..69f3271e8e 100644 --- a/rpc/server.go +++ b/rpc/server.go @@ -34,7 +34,8 @@ const ( notificationBufferSize = 10000 // max buffered notifications before codec is closed - DefaultIPCApis = "admin,eth,debug,miner,net,shh,txpool,personal,web3" + MetadataApi = "rpc" + DefaultIPCApis = "admin,debug,eth,miner,net,personal,shh,txpool,web3" DefaultHTTPApis = "eth,net,web3" ) @@ -61,7 +62,7 @@ func NewServer() *Server { // register a default service which will provide meta information about the RPC service such as the services and // methods it offers. rpcService := &RPCService{server} - server.RegisterName("rpc", rpcService) + server.RegisterName(MetadataApi, rpcService) return server } diff --git a/rpc/utils.go b/rpc/utils.go index 86938e9b37..fe482e19dd 100644 --- a/rpc/utils.go +++ b/rpc/utils.go @@ -234,7 +234,7 @@ func SupportedModules(client Client) (map[string]string, error) { req := JSONRequest{ Id: []byte("1"), Version: "2.0", - Method: "rpc_modules", + Method: MetadataApi + "_modules", } if err := client.Send(req); err != nil { return nil, err diff --git a/tests/init.go b/tests/init.go index 5112b274d0..0c07f8b237 100644 --- a/tests/init.go +++ b/tests/init.go @@ -25,8 +25,6 @@ import ( "net/http" "os" "path/filepath" - - "github.com/ethereum/go-ethereum/core" ) var ( @@ -59,11 +57,6 @@ var ( VmSkipTests = []string{} ) -// Disable reporting bad blocks for the tests -func init() { - core.DisableBadBlockReporting = true -} - func readJson(reader io.Reader, value interface{}) error { data, err := ioutil.ReadAll(reader) if err != nil { diff --git a/trie/iterator.go b/trie/iterator.go index ceef52ec8d..88c4cee7fa 100644 --- a/trie/iterator.go +++ b/trie/iterator.go @@ -62,7 +62,7 @@ func (self *Iterator) next(node interface{}, key []byte, isIterStart bool) []byt switch node := node.(type) { case fullNode: if len(key) > 0 { - k := self.next(node[key[0]], key[1:], isIterStart) + k := self.next(node.Children[key[0]], key[1:], isIterStart) if k != nil { return append([]byte{key[0]}, k...) } @@ -74,7 +74,7 @@ func (self *Iterator) next(node interface{}, key []byte, isIterStart bool) []byt } for i := r; i < 16; i++ { - k := self.key(node[i]) + k := self.key(node.Children[i]) if k != nil { return append([]byte{i}, k...) } @@ -130,12 +130,12 @@ func (self *Iterator) key(node interface{}) []byte { } return append(k, self.key(node.Val)...) case fullNode: - if node[16] != nil { - self.Value = node[16].(valueNode) + if node.Children[16] != nil { + self.Value = node.Children[16].(valueNode) return []byte{16} } for i := 0; i < 16; i++ { - k := self.key(node[i]) + k := self.key(node.Children[i]) if k != nil { return append([]byte{byte(i)}, k...) } @@ -175,7 +175,7 @@ type NodeIterator struct { // NewNodeIterator creates an post-order trie iterator. func NewNodeIterator(trie *Trie) *NodeIterator { - if bytes.Compare(trie.Root(), emptyRoot.Bytes()) == 0 { + if trie.Hash() == emptyState { return new(NodeIterator) } return &NodeIterator{trie: trie} @@ -205,9 +205,11 @@ func (it *NodeIterator) step() error { } // Initialize the iterator if we've just started, or pop off the old node otherwise if len(it.stack) == 0 { - it.stack = append(it.stack, &nodeIteratorState{node: it.trie.root, child: -1}) + // Always start with a collapsed root + root := it.trie.Hash() + it.stack = append(it.stack, &nodeIteratorState{node: hashNode(root[:]), child: -1}) if it.stack[0].node == nil { - return fmt.Errorf("root node missing: %x", it.trie.Root()) + return fmt.Errorf("root node missing: %x", it.trie.Hash()) } } else { it.stack = it.stack[:len(it.stack)-1] @@ -225,11 +227,11 @@ func (it *NodeIterator) step() error { } if node, ok := parent.node.(fullNode); ok { // Full node, traverse all children, then the node itself - if parent.child >= len(node) { + if parent.child >= len(node.Children) { break } - for parent.child++; parent.child < len(node); parent.child++ { - if current := node[parent.child]; current != nil { + for parent.child++; parent.child < len(node.Children); parent.child++ { + if current := node.Children[parent.child]; current != nil { it.stack = append(it.stack, &nodeIteratorState{node: current, parent: ancestor, child: -1}) break } diff --git a/trie/node.go b/trie/node.go index 0bfa21dc43..b97d370be4 100644 --- a/trie/node.go +++ b/trie/node.go @@ -29,18 +29,36 @@ var indices = []string{"0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "a", "b type node interface { fstring(string) string + cache() (hashNode, bool) } type ( - fullNode [17]node + fullNode struct { + Children [17]node // Actual trie node data to encode/decode (needs custom encoder) + hash hashNode // Cached hash of the node to prevent rehashing (may be nil) + dirty bool // Cached flag whether the node's new or already stored + } shortNode struct { - Key []byte - Val node + Key []byte + Val node + hash hashNode // Cached hash of the node to prevent rehashing (may be nil) + dirty bool // Cached flag whether the node's new or already stored } hashNode []byte valueNode []byte ) +// EncodeRLP encodes a full node into the consensus RLP format. +func (n fullNode) EncodeRLP(w io.Writer) error { + return rlp.Encode(w, n.Children) +} + +// Cache accessors to retrieve precalculated values (avoid lengthy type switches). +func (n fullNode) cache() (hashNode, bool) { return n.hash, n.dirty } +func (n shortNode) cache() (hashNode, bool) { return n.hash, n.dirty } +func (n hashNode) cache() (hashNode, bool) { return nil, true } +func (n valueNode) cache() (hashNode, bool) { return nil, true } + // Pretty printing. func (n fullNode) String() string { return n.fstring("") } func (n shortNode) String() string { return n.fstring("") } @@ -49,7 +67,7 @@ func (n valueNode) String() string { return n.fstring("") } func (n fullNode) fstring(ind string) string { resp := fmt.Sprintf("[\n%s ", ind) - for i, node := range n { + for i, node := range n.Children { if node == nil { resp += fmt.Sprintf("%s: ", indices[i]) } else { @@ -68,16 +86,16 @@ func (n valueNode) fstring(ind string) string { return fmt.Sprintf("%x ", []byte(n)) } -func mustDecodeNode(dbkey, buf []byte) node { - n, err := decodeNode(buf) +func mustDecodeNode(hash, buf []byte) node { + n, err := decodeNode(hash, buf) if err != nil { - panic(fmt.Sprintf("node %x: %v", dbkey, err)) + panic(fmt.Sprintf("node %x: %v", hash, err)) } return n } // decodeNode parses the RLP encoding of a trie node. -func decodeNode(buf []byte) (node, error) { +func decodeNode(hash, buf []byte) (node, error) { if len(buf) == 0 { return nil, io.ErrUnexpectedEOF } @@ -87,18 +105,18 @@ func decodeNode(buf []byte) (node, error) { } switch c, _ := rlp.CountValues(elems); c { case 2: - n, err := decodeShort(elems) + n, err := decodeShort(hash, buf, elems) return n, wrapError(err, "short") case 17: - n, err := decodeFull(elems) + n, err := decodeFull(hash, buf, elems) return n, wrapError(err, "full") default: return nil, fmt.Errorf("invalid number of list elements: %v", c) } } -func decodeShort(buf []byte) (node, error) { - kbuf, rest, err := rlp.SplitString(buf) +func decodeShort(hash, buf, elems []byte) (node, error) { + kbuf, rest, err := rlp.SplitString(elems) if err != nil { return nil, err } @@ -109,30 +127,30 @@ func decodeShort(buf []byte) (node, error) { if err != nil { return nil, fmt.Errorf("invalid value node: %v", err) } - return shortNode{key, valueNode(val)}, nil + return shortNode{key, valueNode(val), hash, false}, nil } r, _, err := decodeRef(rest) if err != nil { return nil, wrapError(err, "val") } - return shortNode{key, r}, nil + return shortNode{key, r, hash, false}, nil } -func decodeFull(buf []byte) (fullNode, error) { - var n fullNode +func decodeFull(hash, buf, elems []byte) (fullNode, error) { + n := fullNode{hash: hash} for i := 0; i < 16; i++ { - cld, rest, err := decodeRef(buf) + cld, rest, err := decodeRef(elems) if err != nil { return n, wrapError(err, fmt.Sprintf("[%d]", i)) } - n[i], buf = cld, rest + n.Children[i], elems = cld, rest } - val, _, err := rlp.SplitString(buf) + val, _, err := rlp.SplitString(elems) if err != nil { return n, err } if len(val) > 0 { - n[16] = valueNode(val) + n.Children[16] = valueNode(val) } return n, nil } @@ -152,7 +170,7 @@ func decodeRef(buf []byte) (node, []byte, error) { err := fmt.Errorf("oversized embedded node (size is %d bytes, want size < %d)", size, hashLen) return nil, buf, err } - n, err := decodeNode(buf) + n, err := decodeNode(nil, buf) return n, rest, err case kind == rlp.String && len(val) == 0: // empty node diff --git a/trie/proof.go b/trie/proof.go index 37a70fb34d..5135de0473 100644 --- a/trie/proof.go +++ b/trie/proof.go @@ -54,7 +54,7 @@ func (t *Trie) Prove(key []byte) []rlp.RawValue { } nodes = append(nodes, n) case fullNode: - tn = n[key[0]] + tn = n.Children[key[0]] key = key[1:] nodes = append(nodes, n) case hashNode: @@ -77,7 +77,7 @@ func (t *Trie) Prove(key []byte) []rlp.RawValue { for i, n := range nodes { // Don't bother checking for errors here since hasher panics // if encoding doesn't work and we're not writing to any database. - n, _ = t.hasher.replaceChildren(n, nil) + n, _, _ = t.hasher.hashChildren(n, nil) hn, _ := t.hasher.store(n, nil, false) if _, ok := hn.(hashNode); ok || i == 0 { // If the node's database encoding is a hash (or is the @@ -103,7 +103,7 @@ func VerifyProof(rootHash common.Hash, key []byte, proof []rlp.RawValue) (value if !bytes.Equal(sha.Sum(nil), wantHash) { return nil, fmt.Errorf("bad proof node %d: hash mismatch", i) } - n, err := decodeNode(buf) + n, err := decodeNode(wantHash, buf) if err != nil { return nil, fmt.Errorf("bad proof node %d: %v", i, err) } @@ -139,7 +139,7 @@ func get(tn node, key []byte) ([]byte, node) { tn = n.Val key = key[len(n.Key):] case fullNode: - tn = n[key[0]] + tn = n.Children[key[0]] key = key[1:] case hashNode: return key, n diff --git a/trie/secure_trie.go b/trie/secure_trie.go index be7defe83b..1d027c1027 100644 --- a/trie/secure_trie.go +++ b/trie/secure_trie.go @@ -162,11 +162,11 @@ func (t *SecureTrie) CommitTo(db DatabaseWriter) (root common.Hash, err error) { } t.secKeyCache = make(map[string][]byte) } - n, err := t.hashRoot(db) + n, clean, err := t.hashRoot(db) if err != nil { return (common.Hash{}), err } - t.root = n + t.root = clean return common.BytesToHash(n.(hashNode)), nil } diff --git a/trie/sync.go b/trie/sync.go index d55399d06b..6e9e029b93 100644 --- a/trie/sync.go +++ b/trie/sync.go @@ -17,6 +17,7 @@ package trie import ( + "errors" "fmt" "github.com/ethereum/go-ethereum/common" @@ -24,6 +25,10 @@ import ( "gopkg.in/karalabe/cookiejar.v2/collections/prque" ) +// ErrNotRequested is returned by the trie sync when it's requested to process a +// node it did not request. +var ErrNotRequested = errors.New("not requested") + // request represents a scheduled or already in-flight state retrieval request. type request struct { hash common.Hash // Hash of the node data content to retrieve @@ -75,8 +80,9 @@ func (s *TrieSync) AddSubTrie(root common.Hash, depth int, parent common.Hash, c if root == emptyRoot { return } - blob, _ := s.database.Get(root.Bytes()) - if local, err := decodeNode(blob); local != nil && err == nil { + key := root.Bytes() + blob, _ := s.database.Get(key) + if local, err := decodeNode(key, blob); local != nil && err == nil { return } // Assemble the new sub-trie sync request @@ -143,7 +149,7 @@ func (s *TrieSync) Process(results []SyncResult) (int, error) { // If the item was not requested, bail out request := s.requests[item.Hash] if request == nil { - return i, fmt.Errorf("not requested: %x", item.Hash) + return i, ErrNotRequested } // If the item is a raw entry request, commit directly if request.object == nil { @@ -152,7 +158,7 @@ func (s *TrieSync) Process(results []SyncResult) (int, error) { continue } // Decode the node data content and update the request - node, err := decodeNode(item.Data) + node, err := decodeNode(item.Hash[:], item.Data) if err != nil { return i, err } @@ -213,9 +219,9 @@ func (s *TrieSync) children(req *request) ([]*request, error) { }} case fullNode: for i := 0; i < 17; i++ { - if node[i] != nil { + if node.Children[i] != nil { children = append(children, child{ - node: &node[i], + node: &node.Children[i], depth: req.depth + 1, }) } @@ -238,7 +244,7 @@ func (s *TrieSync) children(req *request) ([]*request, error) { if node, ok := (*child.node).(hashNode); ok { // Try to resolve the node from the local database blob, _ := s.database.Get(node) - if local, err := decodeNode(blob); local != nil && err == nil { + if local, err := decodeNode(node[:], blob); local != nil && err == nil { *child.node = local continue } diff --git a/trie/trie.go b/trie/trie.go index cc5dcf2a65..a530e7b2a3 100644 --- a/trie/trie.go +++ b/trie/trie.go @@ -129,7 +129,7 @@ func (t *Trie) TryGet(key []byte) ([]byte, error) { tn = n.Val pos += len(n.Key) case fullNode: - tn = n[key[pos]] + tn = n.Children[key[pos]] pos++ case nil: return nil, nil @@ -169,13 +169,13 @@ func (t *Trie) Update(key, value []byte) { func (t *Trie) TryUpdate(key, value []byte) error { k := compactHexDecode(key) if len(value) != 0 { - n, err := t.insert(t.root, nil, k, valueNode(value)) + _, n, err := t.insert(t.root, nil, k, valueNode(value)) if err != nil { return err } t.root = n } else { - n, err := t.delete(t.root, nil, k) + _, n, err := t.delete(t.root, nil, k) if err != nil { return err } @@ -184,9 +184,12 @@ func (t *Trie) TryUpdate(key, value []byte) error { return nil } -func (t *Trie) insert(n node, prefix, key []byte, value node) (node, error) { +func (t *Trie) insert(n node, prefix, key []byte, value node) (bool, node, error) { if len(key) == 0 { - return value, nil + if v, ok := n.(valueNode); ok { + return !bytes.Equal(v, value.(valueNode)), value, nil + } + return true, value, nil } switch n := n.(type) { case shortNode: @@ -194,53 +197,63 @@ func (t *Trie) insert(n node, prefix, key []byte, value node) (node, error) { // If the whole key matches, keep this short node as is // and only update the value. if matchlen == len(n.Key) { - nn, err := t.insert(n.Val, append(prefix, key[:matchlen]...), key[matchlen:], value) + dirty, nn, err := t.insert(n.Val, append(prefix, key[:matchlen]...), key[matchlen:], value) if err != nil { - return nil, err + return false, nil, err } - return shortNode{n.Key, nn}, nil + if !dirty { + return false, n, nil + } + return true, shortNode{n.Key, nn, nil, true}, nil } // Otherwise branch out at the index where they differ. - var branch fullNode + branch := fullNode{dirty: true} var err error - branch[n.Key[matchlen]], err = t.insert(nil, append(prefix, n.Key[:matchlen+1]...), n.Key[matchlen+1:], n.Val) + _, branch.Children[n.Key[matchlen]], err = t.insert(nil, append(prefix, n.Key[:matchlen+1]...), n.Key[matchlen+1:], n.Val) if err != nil { - return nil, err + return false, nil, err } - branch[key[matchlen]], err = t.insert(nil, append(prefix, key[:matchlen+1]...), key[matchlen+1:], value) + _, branch.Children[key[matchlen]], err = t.insert(nil, append(prefix, key[:matchlen+1]...), key[matchlen+1:], value) if err != nil { - return nil, err + return false, nil, err } // Replace this shortNode with the branch if it occurs at index 0. if matchlen == 0 { - return branch, nil + return true, branch, nil } // Otherwise, replace it with a short node leading up to the branch. - return shortNode{key[:matchlen], branch}, nil + return true, shortNode{key[:matchlen], branch, nil, true}, nil case fullNode: - nn, err := t.insert(n[key[0]], append(prefix, key[0]), key[1:], value) + dirty, nn, err := t.insert(n.Children[key[0]], append(prefix, key[0]), key[1:], value) if err != nil { - return nil, err + return false, nil, err } - n[key[0]] = nn - return n, nil + if !dirty { + return false, n, nil + } + n.Children[key[0]], n.hash, n.dirty = nn, nil, true + return true, n, nil case nil: - return shortNode{key, value}, nil + return true, shortNode{key, value, nil, true}, nil case hashNode: // We've hit a part of the trie that isn't loaded yet. Load // the node and insert into it. This leaves all child nodes on // the path to the value in the trie. - // - // TODO: track whether insertion changed the value and keep - // n as a hash node if it didn't. rn, err := t.resolveHash(n, prefix, key) if err != nil { - return nil, err + return false, nil, err } - return t.insert(rn, prefix, key, value) + dirty, nn, err := t.insert(rn, prefix, key, value) + if err != nil { + return false, nil, err + } + if !dirty { + return false, rn, nil + } + return true, nn, nil default: panic(fmt.Sprintf("%T: invalid node: %v", n, n)) @@ -258,7 +271,7 @@ func (t *Trie) Delete(key []byte) { // If a node was not found in the database, a MissingNodeError is returned. func (t *Trie) TryDelete(key []byte) error { k := compactHexDecode(key) - n, err := t.delete(t.root, nil, k) + _, n, err := t.delete(t.root, nil, k) if err != nil { return err } @@ -269,23 +282,26 @@ func (t *Trie) TryDelete(key []byte) error { // delete returns the new root of the trie with key deleted. // It reduces the trie to minimal form by simplifying // nodes on the way up after deleting recursively. -func (t *Trie) delete(n node, prefix, key []byte) (node, error) { +func (t *Trie) delete(n node, prefix, key []byte) (bool, node, error) { switch n := n.(type) { case shortNode: matchlen := prefixLen(key, n.Key) if matchlen < len(n.Key) { - return n, nil // don't replace n on mismatch + return false, n, nil // don't replace n on mismatch } if matchlen == len(key) { - return nil, nil // remove n entirely for whole matches + return true, nil, nil // remove n entirely for whole matches } // The key is longer than n.Key. Remove the remaining suffix // from the subtrie. Child can never be nil here since the // subtrie must contain at least two other values with keys // longer than n.Key. - child, err := t.delete(n.Val, append(prefix, key[:len(n.Key)]...), key[len(n.Key):]) + dirty, child, err := t.delete(n.Val, append(prefix, key[:len(n.Key)]...), key[len(n.Key):]) if err != nil { - return nil, err + return false, nil, err + } + if !dirty { + return false, n, nil } switch child := child.(type) { case shortNode: @@ -295,17 +311,21 @@ func (t *Trie) delete(n node, prefix, key []byte) (node, error) { // always creates a new slice) instead of append to // avoid modifying n.Key since it might be shared with // other nodes. - return shortNode{concat(n.Key, child.Key...), child.Val}, nil + return true, shortNode{concat(n.Key, child.Key...), child.Val, nil, true}, nil default: - return shortNode{n.Key, child}, nil + return true, shortNode{n.Key, child, nil, true}, nil } case fullNode: - nn, err := t.delete(n[key[0]], append(prefix, key[0]), key[1:]) + dirty, nn, err := t.delete(n.Children[key[0]], append(prefix, key[0]), key[1:]) if err != nil { - return nil, err + return false, nil, err } - n[key[0]] = nn + if !dirty { + return false, n, nil + } + n.Children[key[0]], n.hash, n.dirty = nn, nil, true + // Check how many non-nil entries are left after deleting and // reduce the full node to a short node if only one entry is // left. Since n must've contained at least two children @@ -316,7 +336,7 @@ func (t *Trie) delete(n node, prefix, key []byte) (node, error) { // value that is left in n or -2 if n contains at least two // values. pos := -1 - for i, cld := range n { + for i, cld := range n.Children { if cld != nil { if pos == -1 { pos = i @@ -334,37 +354,41 @@ func (t *Trie) delete(n node, prefix, key []byte) (node, error) { // shortNode{..., shortNode{...}}. Since the entry // might not be loaded yet, resolve it just for this // check. - cnode, err := t.resolve(n[pos], prefix, []byte{byte(pos)}) + cnode, err := t.resolve(n.Children[pos], prefix, []byte{byte(pos)}) if err != nil { - return nil, err + return false, nil, err } if cnode, ok := cnode.(shortNode); ok { k := append([]byte{byte(pos)}, cnode.Key...) - return shortNode{k, cnode.Val}, nil + return true, shortNode{k, cnode.Val, nil, true}, nil } } // Otherwise, n is replaced by a one-nibble short node // containing the child. - return shortNode{[]byte{byte(pos)}, n[pos]}, nil + return true, shortNode{[]byte{byte(pos)}, n.Children[pos], nil, true}, nil } // n still contains at least two values and cannot be reduced. - return n, nil + return true, n, nil case nil: - return nil, nil + return false, nil, nil case hashNode: // We've hit a part of the trie that isn't loaded yet. Load // the node and delete from it. This leaves all child nodes on // the path to the value in the trie. - // - // TODO: track whether deletion actually hit a key and keep - // n as a hash node if it didn't. rn, err := t.resolveHash(n, prefix, key) if err != nil { - return nil, err + return false, nil, err } - return t.delete(rn, prefix, key) + dirty, nn, err := t.delete(rn, prefix, key) + if err != nil { + return false, nil, err + } + if !dirty { + return false, rn, nil + } + return true, nn, nil default: panic(fmt.Sprintf("%T: invalid node: %v (%v)", n, n, key)) @@ -413,8 +437,9 @@ func (t *Trie) Root() []byte { return t.Hash().Bytes() } // Hash returns the root hash of the trie. It does not write to the // database and can be used even if the trie doesn't have one. func (t *Trie) Hash() common.Hash { - root, _ := t.hashRoot(nil) - return common.BytesToHash(root.(hashNode)) + hash, cached, _ := t.hashRoot(nil) + t.root = cached + return common.BytesToHash(hash.(hashNode)) } // Commit writes all nodes to the trie's database. @@ -437,17 +462,17 @@ func (t *Trie) Commit() (root common.Hash, err error) { // the changes made to db are written back to the trie's attached // database before using the trie. func (t *Trie) CommitTo(db DatabaseWriter) (root common.Hash, err error) { - n, err := t.hashRoot(db) + hash, cached, err := t.hashRoot(db) if err != nil { return (common.Hash{}), err } - t.root = n - return common.BytesToHash(n.(hashNode)), nil + t.root = cached + return common.BytesToHash(hash.(hashNode)), nil } -func (t *Trie) hashRoot(db DatabaseWriter) (node, error) { +func (t *Trie) hashRoot(db DatabaseWriter) (node, node, error) { if t.root == nil { - return hashNode(emptyRoot.Bytes()), nil + return hashNode(emptyRoot.Bytes()), nil, nil } if t.hasher == nil { t.hasher = newHasher() @@ -464,51 +489,87 @@ func newHasher() *hasher { return &hasher{tmp: new(bytes.Buffer), sha: sha3.NewKeccak256()} } -func (h *hasher) hash(n node, db DatabaseWriter, force bool) (node, error) { - hashed, err := h.replaceChildren(n, db) +// hash collapses a node down into a hash node, also returning a copy of the +// original node initialzied with the computed hash to replace the original one. +func (h *hasher) hash(n node, db DatabaseWriter, force bool) (node, node, error) { + // If we're not storing the node, just hashing, use avaialble cached data + if hash, dirty := n.cache(); hash != nil && (db == nil || !dirty) { + return hash, n, nil + } + // Trie not processed yet or needs storage, walk the children + collapsed, cached, err := h.hashChildren(n, db) if err != nil { - return hashNode{}, err + return hashNode{}, n, err } - if n, err = h.store(hashed, db, force); err != nil { - return hashNode{}, err + hashed, err := h.store(collapsed, db, force) + if err != nil { + return hashNode{}, n, err } - return n, nil + // Cache the hash and RLP blob of the ndoe for later reuse + if hash, ok := hashed.(hashNode); ok && !force { + switch cached := cached.(type) { + case shortNode: + cached.hash = hash + if db != nil { + cached.dirty = false + } + return hashed, cached, nil + case fullNode: + cached.hash = hash + if db != nil { + cached.dirty = false + } + return hashed, cached, nil + } + } + return hashed, cached, nil } -// hashChildren replaces child nodes of n with their hashes if the encoded -// size of the child is larger than a hash. -func (h *hasher) replaceChildren(n node, db DatabaseWriter) (node, error) { +// hashChildren replaces the children of a node with their hashes if the encoded +// size of the child is larger than a hash, returning the collapsed node as well +// as a replacement for the original node with the child hashes cached in. +func (h *hasher) hashChildren(original node, db DatabaseWriter) (node, node, error) { var err error - switch n := n.(type) { + + switch n := original.(type) { case shortNode: + // Hash the short node's child, caching the newly hashed subtree + cached := n + cached.Key = common.CopyBytes(cached.Key) + n.Key = compactEncode(n.Key) if _, ok := n.Val.(valueNode); !ok { - if n.Val, err = h.hash(n.Val, db, false); err != nil { - return n, err + if n.Val, cached.Val, err = h.hash(n.Val, db, false); err != nil { + return n, original, err } } if n.Val == nil { - // Ensure that nil children are encoded as empty strings. - n.Val = valueNode(nil) + n.Val = valueNode(nil) // Ensure that nil children are encoded as empty strings. } - return n, nil + return n, cached, nil + case fullNode: + // Hash the full node's children, caching the newly hashed subtrees + cached := fullNode{dirty: n.dirty} + for i := 0; i < 16; i++ { - if n[i] != nil { - if n[i], err = h.hash(n[i], db, false); err != nil { - return n, err + if n.Children[i] != nil { + if n.Children[i], cached.Children[i], err = h.hash(n.Children[i], db, false); err != nil { + return n, original, err } } else { - // Ensure that nil children are encoded as empty strings. - n[i] = valueNode(nil) + n.Children[i] = valueNode(nil) // Ensure that nil children are encoded as empty strings. } } - if n[16] == nil { - n[16] = valueNode(nil) + cached.Children[16] = n.Children[16] + if n.Children[16] == nil { + n.Children[16] = valueNode(nil) } - return n, nil + return n, cached, nil + default: - return n, nil + // Value and hash nodes don't have children so they're left as were + return n, original, nil } } @@ -517,21 +578,23 @@ func (h *hasher) store(n node, db DatabaseWriter, force bool) (node, error) { if _, isHash := n.(hashNode); n == nil || isHash { return n, nil } + // Generate the RLP encoding of the node h.tmp.Reset() if err := rlp.Encode(h.tmp, n); err != nil { panic("encode error: " + err.Error()) } if h.tmp.Len() < 32 && !force { - // Nodes smaller than 32 bytes are stored inside their parent. - return n, nil + return n, nil // Nodes smaller than 32 bytes are stored inside their parent } // Larger nodes are replaced by their hash and stored in the database. - h.sha.Reset() - h.sha.Write(h.tmp.Bytes()) - key := hashNode(h.sha.Sum(nil)) - if db != nil { - err := db.Put(key, h.tmp.Bytes()) - return key, err + hash, _ := n.cache() + if hash == nil { + h.sha.Reset() + h.sha.Write(h.tmp.Bytes()) + hash = hashNode(h.sha.Sum(nil)) } - return key, nil + if db != nil { + return hash, db.Put(hash, h.tmp.Bytes()) + } + return hash, nil } diff --git a/trie/trie_test.go b/trie/trie_test.go index bb761b5551..121ba24c1e 100644 --- a/trie/trie_test.go +++ b/trie/trie_test.go @@ -295,7 +295,7 @@ func TestReplication(t *testing.T) { for _, val := range vals2 { updateString(trie2, val.k, val.v) } - if trie2.Hash() != exp { + if hash := trie2.Hash(); hash != exp { t.Errorf("root failure. expected %x got %x", exp, hash) } }