Merge branch 'fjl-poc8-net-integration' into develop
This commit is contained in:
commit
5c251b6928
|
@ -0,0 +1,93 @@
|
||||||
|
/*
|
||||||
|
This file is part of go-ethereum
|
||||||
|
|
||||||
|
go-ethereum is free software: you can redistribute it and/or modify
|
||||||
|
it under the terms of the GNU General Public License as published by
|
||||||
|
the Free Software Foundation, either version 3 of the License, or
|
||||||
|
(at your option) any later version.
|
||||||
|
|
||||||
|
go-ethereum is distributed in the hope that it will be useful,
|
||||||
|
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
GNU General Public License for more details.
|
||||||
|
|
||||||
|
You should have received a copy of the GNU General Public License
|
||||||
|
along with go-ethereum. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
*/
|
||||||
|
|
||||||
|
// Command bootnode runs a bootstrap node for the Discovery Protocol.
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/ecdsa"
|
||||||
|
"encoding/hex"
|
||||||
|
"flag"
|
||||||
|
"fmt"
|
||||||
|
"io/ioutil"
|
||||||
|
"log"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"github.com/ethereum/go-ethereum/crypto"
|
||||||
|
"github.com/ethereum/go-ethereum/logger"
|
||||||
|
"github.com/ethereum/go-ethereum/p2p/discover"
|
||||||
|
"github.com/ethereum/go-ethereum/p2p/nat"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
var (
|
||||||
|
listenAddr = flag.String("addr", ":30301", "listen address")
|
||||||
|
genKey = flag.String("genkey", "", "generate a node key and quit")
|
||||||
|
nodeKeyFile = flag.String("nodekey", "", "private key filename")
|
||||||
|
nodeKeyHex = flag.String("nodekeyhex", "", "private key as hex (for testing)")
|
||||||
|
natdesc = flag.String("nat", "none", "port mapping mechanism (any|none|upnp|pmp|extip:<IP>)")
|
||||||
|
|
||||||
|
nodeKey *ecdsa.PrivateKey
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
flag.Parse()
|
||||||
|
logger.AddLogSystem(logger.NewStdLogSystem(os.Stdout, log.LstdFlags, logger.DebugLevel))
|
||||||
|
|
||||||
|
if *genKey != "" {
|
||||||
|
writeKey(*genKey)
|
||||||
|
os.Exit(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
natm, err := nat.Parse(*natdesc)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("-nat: %v", err)
|
||||||
|
}
|
||||||
|
switch {
|
||||||
|
case *nodeKeyFile == "" && *nodeKeyHex == "":
|
||||||
|
log.Fatal("Use -nodekey or -nodekeyhex to specify a private key")
|
||||||
|
case *nodeKeyFile != "" && *nodeKeyHex != "":
|
||||||
|
log.Fatal("Options -nodekey and -nodekeyhex are mutually exclusive")
|
||||||
|
case *nodeKeyFile != "":
|
||||||
|
if nodeKey, err = crypto.LoadECDSA(*nodeKeyFile); err != nil {
|
||||||
|
log.Fatalf("-nodekey: %v", err)
|
||||||
|
}
|
||||||
|
case *nodeKeyHex != "":
|
||||||
|
if nodeKey, err = crypto.HexToECDSA(*nodeKeyHex); err != nil {
|
||||||
|
log.Fatalf("-nodekeyhex: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := discover.ListenUDP(nodeKey, *listenAddr, natm); err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
select {}
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeKey(target string) {
|
||||||
|
key, err := crypto.GenerateKey()
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal("could not generate key: %v", err)
|
||||||
|
}
|
||||||
|
b := crypto.FromECDSA(key)
|
||||||
|
if target == "-" {
|
||||||
|
fmt.Println(hex.EncodeToString(b))
|
||||||
|
} else {
|
||||||
|
if err := ioutil.WriteFile(target, b, 0600); err != nil {
|
||||||
|
log.Fatal("write error: ", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -21,6 +21,7 @@
|
||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"crypto/ecdsa"
|
||||||
"flag"
|
"flag"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
|
@ -28,7 +29,9 @@ import (
|
||||||
"os/user"
|
"os/user"
|
||||||
"path"
|
"path"
|
||||||
|
|
||||||
|
"github.com/ethereum/go-ethereum/crypto"
|
||||||
"github.com/ethereum/go-ethereum/logger"
|
"github.com/ethereum/go-ethereum/logger"
|
||||||
|
"github.com/ethereum/go-ethereum/p2p/nat"
|
||||||
"github.com/ethereum/go-ethereum/vm"
|
"github.com/ethereum/go-ethereum/vm"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -42,14 +45,14 @@ var (
|
||||||
StartWebSockets bool
|
StartWebSockets bool
|
||||||
RpcPort int
|
RpcPort int
|
||||||
WsPort int
|
WsPort int
|
||||||
NatType string
|
|
||||||
PMPGateway string
|
|
||||||
OutboundPort string
|
OutboundPort string
|
||||||
ShowGenesis bool
|
ShowGenesis bool
|
||||||
AddPeer string
|
AddPeer string
|
||||||
MaxPeer int
|
MaxPeer int
|
||||||
GenAddr bool
|
GenAddr bool
|
||||||
SeedNode string
|
BootNodes string
|
||||||
|
NodeKey *ecdsa.PrivateKey
|
||||||
|
NAT nat.Interface
|
||||||
SecretFile string
|
SecretFile string
|
||||||
ExportDir string
|
ExportDir string
|
||||||
NonInteractive bool
|
NonInteractive bool
|
||||||
|
@ -84,6 +87,7 @@ func defaultDataDir() string {
|
||||||
var defaultConfigFile = path.Join(defaultDataDir(), "conf.ini")
|
var defaultConfigFile = path.Join(defaultDataDir(), "conf.ini")
|
||||||
|
|
||||||
func Init() {
|
func Init() {
|
||||||
|
// TODO: move common flag processing to cmd/util
|
||||||
flag.Usage = func() {
|
flag.Usage = func() {
|
||||||
fmt.Fprintf(os.Stderr, "%s [options] [filename]:\noptions precedence: default < config file < environment variables < command line\n", os.Args[0])
|
fmt.Fprintf(os.Stderr, "%s [options] [filename]:\noptions precedence: default < config file < environment variables < command line\n", os.Args[0])
|
||||||
flag.PrintDefaults()
|
flag.PrintDefaults()
|
||||||
|
@ -93,18 +97,12 @@ func Init() {
|
||||||
flag.StringVar(&Identifier, "id", "", "Custom client identifier")
|
flag.StringVar(&Identifier, "id", "", "Custom client identifier")
|
||||||
flag.StringVar(&KeyRing, "keyring", "", "identifier for keyring to use")
|
flag.StringVar(&KeyRing, "keyring", "", "identifier for keyring to use")
|
||||||
flag.StringVar(&KeyStore, "keystore", "db", "system to store keyrings: db|file (db)")
|
flag.StringVar(&KeyStore, "keystore", "db", "system to store keyrings: db|file (db)")
|
||||||
flag.StringVar(&OutboundPort, "port", "30303", "listening port")
|
|
||||||
flag.StringVar(&NatType, "nat", "", "NAT support (UPNP|PMP) (none)")
|
|
||||||
flag.StringVar(&PMPGateway, "pmp", "", "Gateway IP for PMP")
|
|
||||||
flag.IntVar(&MaxPeer, "maxpeer", 30, "maximum desired peers")
|
|
||||||
flag.IntVar(&RpcPort, "rpcport", 8545, "port to start json-rpc server on")
|
flag.IntVar(&RpcPort, "rpcport", 8545, "port to start json-rpc server on")
|
||||||
flag.IntVar(&WsPort, "wsport", 40404, "port to start websocket rpc server on")
|
flag.IntVar(&WsPort, "wsport", 40404, "port to start websocket rpc server on")
|
||||||
flag.BoolVar(&StartRpc, "rpc", false, "start rpc server")
|
flag.BoolVar(&StartRpc, "rpc", false, "start rpc server")
|
||||||
flag.BoolVar(&StartWebSockets, "ws", false, "start websocket server")
|
flag.BoolVar(&StartWebSockets, "ws", false, "start websocket server")
|
||||||
flag.BoolVar(&NonInteractive, "y", false, "non-interactive mode (say yes to confirmations)")
|
flag.BoolVar(&NonInteractive, "y", false, "non-interactive mode (say yes to confirmations)")
|
||||||
flag.StringVar(&SeedNode, "seednode", "poc-8.ethdev.com:30303", "ip:port of seed node to connect to. Set to blank for skip")
|
|
||||||
flag.BoolVar(&SHH, "shh", true, "whisper protocol (on)")
|
|
||||||
flag.BoolVar(&Dial, "dial", true, "dial out connections (on)")
|
|
||||||
flag.BoolVar(&GenAddr, "genaddr", false, "create a new priv/pub key")
|
flag.BoolVar(&GenAddr, "genaddr", false, "create a new priv/pub key")
|
||||||
flag.StringVar(&SecretFile, "import", "", "imports the file given (hex or mnemonic formats)")
|
flag.StringVar(&SecretFile, "import", "", "imports the file given (hex or mnemonic formats)")
|
||||||
flag.StringVar(&ExportDir, "export", "", "exports the session keyring to files in the directory given")
|
flag.StringVar(&ExportDir, "export", "", "exports the session keyring to files in the directory given")
|
||||||
|
@ -127,8 +125,38 @@ func Init() {
|
||||||
flag.BoolVar(&StartJsConsole, "js", false, "launches javascript console")
|
flag.BoolVar(&StartJsConsole, "js", false, "launches javascript console")
|
||||||
flag.BoolVar(&PrintVersion, "version", false, "prints version number")
|
flag.BoolVar(&PrintVersion, "version", false, "prints version number")
|
||||||
|
|
||||||
|
// Network stuff
|
||||||
|
var (
|
||||||
|
nodeKeyFile = flag.String("nodekey", "", "network private key file")
|
||||||
|
nodeKeyHex = flag.String("nodekeyhex", "", "network private key (for testing)")
|
||||||
|
natstr = flag.String("nat", "any", "port mapping mechanism (any|none|upnp|pmp|extip:<IP>)")
|
||||||
|
)
|
||||||
|
flag.BoolVar(&Dial, "dial", true, "dial out connections (default on)")
|
||||||
|
flag.BoolVar(&SHH, "shh", true, "run whisper protocol (default on)")
|
||||||
|
flag.StringVar(&OutboundPort, "port", "30303", "listening port")
|
||||||
|
|
||||||
|
flag.StringVar(&BootNodes, "bootnodes", "", "space-separated node URLs for discovery bootstrap")
|
||||||
|
flag.IntVar(&MaxPeer, "maxpeer", 30, "maximum desired peers")
|
||||||
|
|
||||||
flag.Parse()
|
flag.Parse()
|
||||||
|
|
||||||
|
var err error
|
||||||
|
if NAT, err = nat.Parse(*natstr); err != nil {
|
||||||
|
log.Fatalf("-nat: %v", err)
|
||||||
|
}
|
||||||
|
switch {
|
||||||
|
case *nodeKeyFile != "" && *nodeKeyHex != "":
|
||||||
|
log.Fatal("Options -nodekey and -nodekeyhex are mutually exclusive")
|
||||||
|
case *nodeKeyFile != "":
|
||||||
|
if NodeKey, err = crypto.LoadECDSA(*nodeKeyFile); err != nil {
|
||||||
|
log.Fatalf("-nodekey: %v", err)
|
||||||
|
}
|
||||||
|
case *nodeKeyHex != "":
|
||||||
|
if NodeKey, err = crypto.HexToECDSA(*nodeKeyHex); err != nil {
|
||||||
|
log.Fatalf("-nodekeyhex: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if VmType >= int(vm.MaxVmTy) {
|
if VmType >= int(vm.MaxVmTy) {
|
||||||
log.Fatal("Invalid VM type ", VmType)
|
log.Fatal("Invalid VM type ", VmType)
|
||||||
}
|
}
|
||||||
|
|
|
@ -31,6 +31,7 @@ import (
|
||||||
"github.com/ethereum/go-ethereum/eth"
|
"github.com/ethereum/go-ethereum/eth"
|
||||||
"github.com/ethereum/go-ethereum/ethutil"
|
"github.com/ethereum/go-ethereum/ethutil"
|
||||||
"github.com/ethereum/go-ethereum/logger"
|
"github.com/ethereum/go-ethereum/logger"
|
||||||
|
"github.com/ethereum/go-ethereum/p2p"
|
||||||
"github.com/ethereum/go-ethereum/state"
|
"github.com/ethereum/go-ethereum/state"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -61,21 +62,19 @@ func main() {
|
||||||
utils.InitConfig(VmType, ConfigFile, Datadir, "ETH")
|
utils.InitConfig(VmType, ConfigFile, Datadir, "ETH")
|
||||||
|
|
||||||
ethereum, err := eth.New(ð.Config{
|
ethereum, err := eth.New(ð.Config{
|
||||||
Name: ClientIdentifier,
|
Name: p2p.MakeName(ClientIdentifier, Version),
|
||||||
Version: Version,
|
KeyStore: KeyStore,
|
||||||
KeyStore: KeyStore,
|
DataDir: Datadir,
|
||||||
DataDir: Datadir,
|
LogFile: LogFile,
|
||||||
LogFile: LogFile,
|
LogLevel: LogLevel,
|
||||||
LogLevel: LogLevel,
|
MaxPeers: MaxPeer,
|
||||||
LogFormat: LogFormat,
|
Port: OutboundPort,
|
||||||
Identifier: Identifier,
|
NAT: NAT,
|
||||||
MaxPeers: MaxPeer,
|
KeyRing: KeyRing,
|
||||||
Port: OutboundPort,
|
Shh: SHH,
|
||||||
NATType: PMPGateway,
|
Dial: Dial,
|
||||||
PMPGateway: PMPGateway,
|
BootNodes: BootNodes,
|
||||||
KeyRing: KeyRing,
|
NodeKey: NodeKey,
|
||||||
Shh: SHH,
|
|
||||||
Dial: Dial,
|
|
||||||
})
|
})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -135,7 +134,7 @@ func main() {
|
||||||
utils.StartWebSockets(ethereum, WsPort)
|
utils.StartWebSockets(ethereum, WsPort)
|
||||||
}
|
}
|
||||||
|
|
||||||
utils.StartEthereum(ethereum, SeedNode)
|
utils.StartEthereum(ethereum)
|
||||||
|
|
||||||
if StartJsConsole {
|
if StartJsConsole {
|
||||||
InitJsConsole(ethereum)
|
InitJsConsole(ethereum)
|
||||||
|
|
|
@ -79,6 +79,12 @@
|
||||||
contract.received({from: eth.coinbase}).changed(function() {
|
contract.received({from: eth.coinbase}).changed(function() {
|
||||||
refresh();
|
refresh();
|
||||||
});
|
});
|
||||||
|
|
||||||
|
var ev = contract.SingleTransact({})
|
||||||
|
ev.watch(function(log) {
|
||||||
|
someElement.innerHTML += "tnaheousnthaoeu";
|
||||||
|
});
|
||||||
|
|
||||||
eth.watch('chain').changed(function() {
|
eth.watch('chain').changed(function() {
|
||||||
refresh();
|
refresh();
|
||||||
});
|
});
|
||||||
|
|
|
@ -32,18 +32,6 @@ Rectangle {
|
||||||
width: 500
|
width: 500
|
||||||
}
|
}
|
||||||
|
|
||||||
Label {
|
|
||||||
text: "Client ID"
|
|
||||||
}
|
|
||||||
TextField {
|
|
||||||
text: gui.getCustomIdentifier()
|
|
||||||
width: 500
|
|
||||||
placeholderText: "Anonymous"
|
|
||||||
onTextChanged: {
|
|
||||||
gui.setCustomIdentifier(text)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
TextArea {
|
TextArea {
|
||||||
objectName: "statsPane"
|
objectName: "statsPane"
|
||||||
width: parent.width
|
width: parent.width
|
||||||
|
|
|
@ -64,15 +64,6 @@ func (gui *Gui) Transact(recipient, value, gas, gasPrice, d string) (string, err
|
||||||
return gui.xeth.Transact(recipient, value, gas, gasPrice, data)
|
return gui.xeth.Transact(recipient, value, gas, gasPrice, data)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (gui *Gui) SetCustomIdentifier(customIdentifier string) {
|
|
||||||
gui.clientIdentity.SetCustomIdentifier(customIdentifier)
|
|
||||||
gui.config.Save("id", customIdentifier)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (gui *Gui) GetCustomIdentifier() string {
|
|
||||||
return gui.clientIdentity.GetCustomIdentifier()
|
|
||||||
}
|
|
||||||
|
|
||||||
// functions that allow Gui to implement interface guilogger.LogSystem
|
// functions that allow Gui to implement interface guilogger.LogSystem
|
||||||
func (gui *Gui) SetLogLevel(level logger.LogLevel) {
|
func (gui *Gui) SetLogLevel(level logger.LogLevel) {
|
||||||
gui.logLevel = level
|
gui.logLevel = level
|
||||||
|
|
|
@ -21,6 +21,7 @@
|
||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"crypto/ecdsa"
|
||||||
"flag"
|
"flag"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
|
@ -31,7 +32,9 @@ import (
|
||||||
"runtime"
|
"runtime"
|
||||||
|
|
||||||
"bitbucket.org/kardianos/osext"
|
"bitbucket.org/kardianos/osext"
|
||||||
|
"github.com/ethereum/go-ethereum/crypto"
|
||||||
"github.com/ethereum/go-ethereum/logger"
|
"github.com/ethereum/go-ethereum/logger"
|
||||||
|
"github.com/ethereum/go-ethereum/p2p/nat"
|
||||||
"github.com/ethereum/go-ethereum/vm"
|
"github.com/ethereum/go-ethereum/vm"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -39,19 +42,18 @@ var (
|
||||||
Identifier string
|
Identifier string
|
||||||
KeyRing string
|
KeyRing string
|
||||||
KeyStore string
|
KeyStore string
|
||||||
PMPGateway string
|
|
||||||
StartRpc bool
|
StartRpc bool
|
||||||
StartWebSockets bool
|
StartWebSockets bool
|
||||||
RpcPort int
|
RpcPort int
|
||||||
WsPort int
|
WsPort int
|
||||||
UseUPnP bool
|
|
||||||
NatType string
|
|
||||||
OutboundPort string
|
OutboundPort string
|
||||||
ShowGenesis bool
|
ShowGenesis bool
|
||||||
AddPeer string
|
AddPeer string
|
||||||
MaxPeer int
|
MaxPeer int
|
||||||
GenAddr bool
|
GenAddr bool
|
||||||
SeedNode string
|
BootNodes string
|
||||||
|
NodeKey *ecdsa.PrivateKey
|
||||||
|
NAT nat.Interface
|
||||||
SecretFile string
|
SecretFile string
|
||||||
ExportDir string
|
ExportDir string
|
||||||
NonInteractive bool
|
NonInteractive bool
|
||||||
|
@ -99,6 +101,7 @@ func defaultDataDir() string {
|
||||||
var defaultConfigFile = path.Join(defaultDataDir(), "conf.ini")
|
var defaultConfigFile = path.Join(defaultDataDir(), "conf.ini")
|
||||||
|
|
||||||
func Init() {
|
func Init() {
|
||||||
|
// TODO: move common flag processing to cmd/utils
|
||||||
flag.Usage = func() {
|
flag.Usage = func() {
|
||||||
fmt.Fprintf(os.Stderr, "%s [options] [filename]:\noptions precedence: default < config file < environment variables < command line\n", os.Args[0])
|
fmt.Fprintf(os.Stderr, "%s [options] [filename]:\noptions precedence: default < config file < environment variables < command line\n", os.Args[0])
|
||||||
flag.PrintDefaults()
|
flag.PrintDefaults()
|
||||||
|
@ -108,30 +111,51 @@ func Init() {
|
||||||
flag.StringVar(&Identifier, "id", "", "Custom client identifier")
|
flag.StringVar(&Identifier, "id", "", "Custom client identifier")
|
||||||
flag.StringVar(&KeyRing, "keyring", "", "identifier for keyring to use")
|
flag.StringVar(&KeyRing, "keyring", "", "identifier for keyring to use")
|
||||||
flag.StringVar(&KeyStore, "keystore", "db", "system to store keyrings: db|file (db)")
|
flag.StringVar(&KeyStore, "keystore", "db", "system to store keyrings: db|file (db)")
|
||||||
flag.StringVar(&OutboundPort, "port", "30303", "listening port")
|
|
||||||
flag.BoolVar(&UseUPnP, "upnp", true, "enable UPnP support")
|
|
||||||
flag.IntVar(&MaxPeer, "maxpeer", 30, "maximum desired peers")
|
|
||||||
flag.IntVar(&RpcPort, "rpcport", 8545, "port to start json-rpc server on")
|
flag.IntVar(&RpcPort, "rpcport", 8545, "port to start json-rpc server on")
|
||||||
flag.IntVar(&WsPort, "wsport", 40404, "port to start websocket rpc server on")
|
flag.IntVar(&WsPort, "wsport", 40404, "port to start websocket rpc server on")
|
||||||
flag.BoolVar(&StartRpc, "rpc", true, "start rpc server")
|
flag.BoolVar(&StartRpc, "rpc", true, "start rpc server")
|
||||||
flag.BoolVar(&StartWebSockets, "ws", false, "start websocket server")
|
flag.BoolVar(&StartWebSockets, "ws", false, "start websocket server")
|
||||||
flag.BoolVar(&NonInteractive, "y", false, "non-interactive mode (say yes to confirmations)")
|
flag.BoolVar(&NonInteractive, "y", false, "non-interactive mode (say yes to confirmations)")
|
||||||
flag.StringVar(&SeedNode, "seednode", "poc-8.ethdev.com:30303", "ip:port of seed node to connect to. Set to blank for skip")
|
|
||||||
flag.BoolVar(&GenAddr, "genaddr", false, "create a new priv/pub key")
|
flag.BoolVar(&GenAddr, "genaddr", false, "create a new priv/pub key")
|
||||||
flag.StringVar(&NatType, "nat", "", "NAT support (UPNP|PMP) (none)")
|
|
||||||
flag.StringVar(&SecretFile, "import", "", "imports the file given (hex or mnemonic formats)")
|
flag.StringVar(&SecretFile, "import", "", "imports the file given (hex or mnemonic formats)")
|
||||||
flag.StringVar(&ExportDir, "export", "", "exports the session keyring to files in the directory given")
|
flag.StringVar(&ExportDir, "export", "", "exports the session keyring to files in the directory given")
|
||||||
flag.StringVar(&LogFile, "logfile", "", "log file (defaults to standard output)")
|
flag.StringVar(&LogFile, "logfile", "", "log file (defaults to standard output)")
|
||||||
flag.StringVar(&Datadir, "datadir", defaultDataDir(), "specifies the datadir to use")
|
flag.StringVar(&Datadir, "datadir", defaultDataDir(), "specifies the datadir to use")
|
||||||
flag.StringVar(&PMPGateway, "pmp", "", "Gateway IP for PMP")
|
|
||||||
flag.StringVar(&ConfigFile, "conf", defaultConfigFile, "config file")
|
flag.StringVar(&ConfigFile, "conf", defaultConfigFile, "config file")
|
||||||
flag.StringVar(&DebugFile, "debug", "", "debug file (no debugging if not set)")
|
flag.StringVar(&DebugFile, "debug", "", "debug file (no debugging if not set)")
|
||||||
flag.IntVar(&LogLevel, "loglevel", int(logger.InfoLevel), "loglevel: 0-5: silent,error,warn,info,debug,debug detail)")
|
flag.IntVar(&LogLevel, "loglevel", int(logger.InfoLevel), "loglevel: 0-5: silent,error,warn,info,debug,debug detail)")
|
||||||
|
|
||||||
flag.StringVar(&AssetPath, "asset_path", defaultAssetPath(), "absolute path to GUI assets directory")
|
flag.StringVar(&AssetPath, "asset_path", defaultAssetPath(), "absolute path to GUI assets directory")
|
||||||
|
|
||||||
|
// Network stuff
|
||||||
|
var (
|
||||||
|
nodeKeyFile = flag.String("nodekey", "", "network private key file")
|
||||||
|
nodeKeyHex = flag.String("nodekeyhex", "", "network private key (for testing)")
|
||||||
|
natstr = flag.String("nat", "any", "port mapping mechanism (any|none|upnp|pmp|extip:<IP>)")
|
||||||
|
)
|
||||||
|
flag.StringVar(&OutboundPort, "port", "30303", "listening port")
|
||||||
|
flag.StringVar(&BootNodes, "bootnodes", "", "space-separated node URLs for discovery bootstrap")
|
||||||
|
flag.IntVar(&MaxPeer, "maxpeer", 30, "maximum desired peers")
|
||||||
|
|
||||||
flag.Parse()
|
flag.Parse()
|
||||||
|
|
||||||
|
var err error
|
||||||
|
if NAT, err = nat.Parse(*natstr); err != nil {
|
||||||
|
log.Fatalf("-nat: %v", err)
|
||||||
|
}
|
||||||
|
switch {
|
||||||
|
case *nodeKeyFile != "" && *nodeKeyHex != "":
|
||||||
|
log.Fatal("Options -nodekey and -nodekeyhex are mutually exclusive")
|
||||||
|
case *nodeKeyFile != "":
|
||||||
|
if NodeKey, err = crypto.LoadECDSA(*nodeKeyFile); err != nil {
|
||||||
|
log.Fatalf("-nodekey: %v", err)
|
||||||
|
}
|
||||||
|
case *nodeKeyHex != "":
|
||||||
|
if NodeKey, err = crypto.HexToECDSA(*nodeKeyHex); err != nil {
|
||||||
|
log.Fatalf("-nodekeyhex: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if VmType >= int(vm.MaxVmTy) {
|
if VmType >= int(vm.MaxVmTy) {
|
||||||
log.Fatal("Invalid VM type ", VmType)
|
log.Fatal("Invalid VM type ", VmType)
|
||||||
}
|
}
|
||||||
|
|
|
@ -41,7 +41,6 @@ import (
|
||||||
"github.com/ethereum/go-ethereum/ethutil"
|
"github.com/ethereum/go-ethereum/ethutil"
|
||||||
"github.com/ethereum/go-ethereum/logger"
|
"github.com/ethereum/go-ethereum/logger"
|
||||||
"github.com/ethereum/go-ethereum/miner"
|
"github.com/ethereum/go-ethereum/miner"
|
||||||
"github.com/ethereum/go-ethereum/p2p"
|
|
||||||
"github.com/ethereum/go-ethereum/ui/qt/qwhisper"
|
"github.com/ethereum/go-ethereum/ui/qt/qwhisper"
|
||||||
"github.com/ethereum/go-ethereum/xeth"
|
"github.com/ethereum/go-ethereum/xeth"
|
||||||
"github.com/obscuren/qml"
|
"github.com/obscuren/qml"
|
||||||
|
@ -77,9 +76,8 @@ type Gui struct {
|
||||||
|
|
||||||
xeth *xeth.XEth
|
xeth *xeth.XEth
|
||||||
|
|
||||||
Session string
|
Session string
|
||||||
clientIdentity *p2p.SimpleClientIdentity
|
config *ethutil.ConfigManager
|
||||||
config *ethutil.ConfigManager
|
|
||||||
|
|
||||||
plugins map[string]plugin
|
plugins map[string]plugin
|
||||||
|
|
||||||
|
@ -87,7 +85,7 @@ type Gui struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create GUI, but doesn't start it
|
// Create GUI, but doesn't start it
|
||||||
func NewWindow(ethereum *eth.Ethereum, config *ethutil.ConfigManager, clientIdentity *p2p.SimpleClientIdentity, session string, logLevel int) *Gui {
|
func NewWindow(ethereum *eth.Ethereum, config *ethutil.ConfigManager, session string, logLevel int) *Gui {
|
||||||
db, err := ethdb.NewLDBDatabase("tx_database")
|
db, err := ethdb.NewLDBDatabase("tx_database")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
|
@ -95,15 +93,14 @@ func NewWindow(ethereum *eth.Ethereum, config *ethutil.ConfigManager, clientIden
|
||||||
|
|
||||||
xeth := xeth.New(ethereum)
|
xeth := xeth.New(ethereum)
|
||||||
gui := &Gui{eth: ethereum,
|
gui := &Gui{eth: ethereum,
|
||||||
txDb: db,
|
txDb: db,
|
||||||
xeth: xeth,
|
xeth: xeth,
|
||||||
logLevel: logger.LogLevel(logLevel),
|
logLevel: logger.LogLevel(logLevel),
|
||||||
Session: session,
|
Session: session,
|
||||||
open: false,
|
open: false,
|
||||||
clientIdentity: clientIdentity,
|
config: config,
|
||||||
config: config,
|
plugins: make(map[string]plugin),
|
||||||
plugins: make(map[string]plugin),
|
serviceEvents: make(chan ServEv, 1),
|
||||||
serviceEvents: make(chan ServEv, 1),
|
|
||||||
}
|
}
|
||||||
data, _ := ethutil.ReadAllFile(path.Join(ethutil.Config.ExecPath, "plugins.json"))
|
data, _ := ethutil.ReadAllFile(path.Join(ethutil.Config.ExecPath, "plugins.json"))
|
||||||
json.Unmarshal([]byte(data), &gui.plugins)
|
json.Unmarshal([]byte(data), &gui.plugins)
|
||||||
|
|
|
@ -52,19 +52,18 @@ func run() error {
|
||||||
config := utils.InitConfig(VmType, ConfigFile, Datadir, "ETH")
|
config := utils.InitConfig(VmType, ConfigFile, Datadir, "ETH")
|
||||||
|
|
||||||
ethereum, err := eth.New(ð.Config{
|
ethereum, err := eth.New(ð.Config{
|
||||||
Name: ClientIdentifier,
|
Name: p2p.MakeName(ClientIdentifier, Version),
|
||||||
Version: Version,
|
KeyStore: KeyStore,
|
||||||
KeyStore: KeyStore,
|
DataDir: Datadir,
|
||||||
DataDir: Datadir,
|
LogFile: LogFile,
|
||||||
LogFile: LogFile,
|
LogLevel: LogLevel,
|
||||||
LogLevel: LogLevel,
|
MaxPeers: MaxPeer,
|
||||||
Identifier: Identifier,
|
Port: OutboundPort,
|
||||||
MaxPeers: MaxPeer,
|
NAT: NAT,
|
||||||
Port: OutboundPort,
|
BootNodes: BootNodes,
|
||||||
NATType: PMPGateway,
|
NodeKey: NodeKey,
|
||||||
PMPGateway: PMPGateway,
|
KeyRing: KeyRing,
|
||||||
KeyRing: KeyRing,
|
Dial: true,
|
||||||
Dial: true,
|
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
mainlogger.Fatalln(err)
|
mainlogger.Fatalln(err)
|
||||||
|
@ -79,12 +78,12 @@ func run() error {
|
||||||
utils.StartWebSockets(ethereum, WsPort)
|
utils.StartWebSockets(ethereum, WsPort)
|
||||||
}
|
}
|
||||||
|
|
||||||
gui := NewWindow(ethereum, config, ethereum.ClientIdentity().(*p2p.SimpleClientIdentity), KeyRing, LogLevel)
|
gui := NewWindow(ethereum, config, KeyRing, LogLevel)
|
||||||
|
|
||||||
utils.RegisterInterrupt(func(os.Signal) {
|
utils.RegisterInterrupt(func(os.Signal) {
|
||||||
gui.Stop()
|
gui.Stop()
|
||||||
})
|
})
|
||||||
go utils.StartEthereum(ethereum, SeedNode)
|
go utils.StartEthereum(ethereum)
|
||||||
|
|
||||||
fmt.Println("ETH stack took", time.Since(tstart))
|
fmt.Println("ETH stack took", time.Since(tstart))
|
||||||
|
|
||||||
|
|
|
@ -136,15 +136,15 @@ func (ui *UiLib) Muted(content string) {
|
||||||
|
|
||||||
func (ui *UiLib) Connect(button qml.Object) {
|
func (ui *UiLib) Connect(button qml.Object) {
|
||||||
if !ui.connected {
|
if !ui.connected {
|
||||||
ui.eth.Start(SeedNode)
|
ui.eth.Start()
|
||||||
ui.connected = true
|
ui.connected = true
|
||||||
button.Set("enabled", false)
|
button.Set("enabled", false)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ui *UiLib) ConnectToPeer(addr string) {
|
func (ui *UiLib) ConnectToPeer(nodeURL string) {
|
||||||
if err := ui.eth.SuggestPeer(addr); err != nil {
|
if err := ui.eth.SuggestPeer(nodeURL); err != nil {
|
||||||
guilogger.Infoln(err)
|
guilogger.Infoln("SuggestPeer error: " + err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,58 +0,0 @@
|
||||||
/*
|
|
||||||
This file is part of go-ethereum
|
|
||||||
|
|
||||||
go-ethereum is free software: you can redistribute it and/or modify
|
|
||||||
it under the terms of the GNU General Public License as published by
|
|
||||||
the Free Software Foundation, either version 3 of the License, or
|
|
||||||
(at your option) any later version.
|
|
||||||
|
|
||||||
go-ethereum is distributed in the hope that it will be useful,
|
|
||||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
||||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
||||||
GNU General Public License for more details.
|
|
||||||
|
|
||||||
You should have received a copy of the GNU General Public License
|
|
||||||
along with go-ethereum. If not, see <http://www.gnu.org/licenses/>.
|
|
||||||
*/
|
|
||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/elliptic"
|
|
||||||
"flag"
|
|
||||||
"log"
|
|
||||||
"os"
|
|
||||||
|
|
||||||
"github.com/ethereum/go-ethereum/crypto"
|
|
||||||
"github.com/ethereum/go-ethereum/logger"
|
|
||||||
"github.com/ethereum/go-ethereum/p2p"
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
natType = flag.String("nat", "", "NAT traversal implementation")
|
|
||||||
pmpGateway = flag.String("gateway", "", "gateway address for NAT-PMP")
|
|
||||||
listenAddr = flag.String("addr", ":30301", "listen address")
|
|
||||||
)
|
|
||||||
|
|
||||||
func main() {
|
|
||||||
flag.Parse()
|
|
||||||
nat, err := p2p.ParseNAT(*natType, *pmpGateway)
|
|
||||||
if err != nil {
|
|
||||||
log.Fatal("invalid nat:", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.AddLogSystem(logger.NewStdLogSystem(os.Stdout, log.LstdFlags, logger.InfoLevel))
|
|
||||||
key, _ := crypto.GenerateKey()
|
|
||||||
marshaled := elliptic.Marshal(crypto.S256(), key.PublicKey.X, key.PublicKey.Y)
|
|
||||||
|
|
||||||
srv := p2p.Server{
|
|
||||||
MaxPeers: 100,
|
|
||||||
Identity: p2p.NewSimpleClientIdentity("Ethereum(G)", "0.1", "Peer Server Two", marshaled),
|
|
||||||
ListenAddr: *listenAddr,
|
|
||||||
NAT: nat,
|
|
||||||
NoDial: true,
|
|
||||||
}
|
|
||||||
if err := srv.Start(); err != nil {
|
|
||||||
log.Fatal("could not start server:", err)
|
|
||||||
}
|
|
||||||
select {}
|
|
||||||
}
|
|
|
@ -121,13 +121,11 @@ func exit(err error) {
|
||||||
os.Exit(status)
|
os.Exit(status)
|
||||||
}
|
}
|
||||||
|
|
||||||
func StartEthereum(ethereum *eth.Ethereum, SeedNode string) {
|
func StartEthereum(ethereum *eth.Ethereum) {
|
||||||
clilogger.Infof("Starting %s", ethereum.ClientIdentity())
|
clilogger.Infoln("Starting ", ethereum.Name())
|
||||||
err := ethereum.Start(SeedNode)
|
if err := ethereum.Start(); err != nil {
|
||||||
if err != nil {
|
|
||||||
exit(err)
|
exit(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
RegisterInterrupt(func(sig os.Signal) {
|
RegisterInterrupt(func(sig os.Signal) {
|
||||||
ethereum.Stop()
|
ethereum.Stop()
|
||||||
logger.Flush()
|
logger.Flush()
|
||||||
|
|
|
@ -9,7 +9,6 @@ import (
|
||||||
"github.com/ethereum/go-ethereum/ethdb"
|
"github.com/ethereum/go-ethereum/ethdb"
|
||||||
"github.com/ethereum/go-ethereum/ethutil"
|
"github.com/ethereum/go-ethereum/ethutil"
|
||||||
"github.com/ethereum/go-ethereum/event"
|
"github.com/ethereum/go-ethereum/event"
|
||||||
"github.com/ethereum/go-ethereum/p2p"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// Implement our EthTest Manager
|
// Implement our EthTest Manager
|
||||||
|
@ -54,13 +53,6 @@ func (tm *TestManager) TxPool() *TxPool {
|
||||||
func (tm *TestManager) EventMux() *event.TypeMux {
|
func (tm *TestManager) EventMux() *event.TypeMux {
|
||||||
return tm.eventMux
|
return tm.eventMux
|
||||||
}
|
}
|
||||||
func (tm *TestManager) Broadcast(msgType p2p.Msg, data []interface{}) {
|
|
||||||
fmt.Println("Broadcast not implemented")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (tm *TestManager) ClientIdentity() p2p.ClientIdentity {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
func (tm *TestManager) KeyManager() *crypto.KeyManager {
|
func (tm *TestManager) KeyManager() *crypto.KeyManager {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -16,7 +16,6 @@ type EthManager interface {
|
||||||
IsListening() bool
|
IsListening() bool
|
||||||
Peers() []*p2p.Peer
|
Peers() []*p2p.Peer
|
||||||
KeyManager() *crypto.KeyManager
|
KeyManager() *crypto.KeyManager
|
||||||
ClientIdentity() p2p.ClientIdentity
|
|
||||||
Db() ethutil.Database
|
Db() ethutil.Database
|
||||||
EventMux() *event.TypeMux
|
EventMux() *event.TypeMux
|
||||||
}
|
}
|
||||||
|
|
|
@ -8,6 +8,8 @@ import (
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"os"
|
||||||
|
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
@ -27,10 +29,11 @@ func init() {
|
||||||
ecies.AddParamsForCurve(S256(), ecies.ECIES_AES128_SHA256)
|
ecies.AddParamsForCurve(S256(), ecies.ECIES_AES128_SHA256)
|
||||||
}
|
}
|
||||||
|
|
||||||
func Sha3(data []byte) []byte {
|
func Sha3(data ...[]byte) []byte {
|
||||||
d := sha3.NewKeccak256()
|
d := sha3.NewKeccak256()
|
||||||
d.Write(data)
|
for _, b := range data {
|
||||||
|
d.Write(b)
|
||||||
|
}
|
||||||
return d.Sum(nil)
|
return d.Sum(nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -98,6 +101,32 @@ func FromECDSAPub(pub *ecdsa.PublicKey) []byte {
|
||||||
return elliptic.Marshal(S256(), pub.X, pub.Y)
|
return elliptic.Marshal(S256(), pub.X, pub.Y)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// HexToECDSA parses a secp256k1 private key.
|
||||||
|
func HexToECDSA(hexkey string) (*ecdsa.PrivateKey, error) {
|
||||||
|
b, err := hex.DecodeString(hexkey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.New("invalid hex string")
|
||||||
|
}
|
||||||
|
if len(b) != 32 {
|
||||||
|
return nil, errors.New("invalid length, need 256 bits")
|
||||||
|
}
|
||||||
|
return ToECDSA(b), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoadECDSA loads a secp256k1 private key from the given file.
|
||||||
|
func LoadECDSA(file string) (*ecdsa.PrivateKey, error) {
|
||||||
|
buf := make([]byte, 32)
|
||||||
|
fd, err := os.Open(file)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer fd.Close()
|
||||||
|
if _, err := io.ReadFull(fd, buf); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return ToECDSA(buf), nil
|
||||||
|
}
|
||||||
|
|
||||||
func GenerateKey() (*ecdsa.PrivateKey, error) {
|
func GenerateKey() (*ecdsa.PrivateKey, error) {
|
||||||
return ecdsa.GenerateKey(S256(), rand.Reader)
|
return ecdsa.GenerateKey(S256(), rand.Reader)
|
||||||
}
|
}
|
||||||
|
|
|
@ -18,7 +18,7 @@ import (
|
||||||
func TestSha3(t *testing.T) {
|
func TestSha3(t *testing.T) {
|
||||||
msg := []byte("abc")
|
msg := []byte("abc")
|
||||||
exp, _ := hex.DecodeString("4e03657aea45a94fc7d47ba826c8d667c0d1e6e33a64a036ec44f58fa12d6c45")
|
exp, _ := hex.DecodeString("4e03657aea45a94fc7d47ba826c8d667c0d1e6e33a64a036ec44f58fa12d6c45")
|
||||||
checkhash(t, "Sha3-256", Sha3, msg, exp)
|
checkhash(t, "Sha3-256", func(in []byte) []byte { return Sha3(in) }, msg, exp)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSha256(t *testing.T) {
|
func TestSha256(t *testing.T) {
|
||||||
|
|
|
@ -25,11 +25,12 @@ package crypto
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"code.google.com/p/go-uuid/uuid"
|
|
||||||
"crypto/ecdsa"
|
"crypto/ecdsa"
|
||||||
"crypto/elliptic"
|
"crypto/elliptic"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"io"
|
"io"
|
||||||
|
|
||||||
|
"code.google.com/p/go-uuid/uuid"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Key struct {
|
type Key struct {
|
||||||
|
|
134
eth/backend.go
134
eth/backend.go
|
@ -1,9 +1,9 @@
|
||||||
package eth
|
package eth
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"crypto/ecdsa"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"strings"
|
||||||
"sync"
|
|
||||||
|
|
||||||
"github.com/ethereum/go-ethereum/core"
|
"github.com/ethereum/go-ethereum/core"
|
||||||
"github.com/ethereum/go-ethereum/crypto"
|
"github.com/ethereum/go-ethereum/crypto"
|
||||||
|
@ -12,35 +12,58 @@ import (
|
||||||
"github.com/ethereum/go-ethereum/event"
|
"github.com/ethereum/go-ethereum/event"
|
||||||
ethlogger "github.com/ethereum/go-ethereum/logger"
|
ethlogger "github.com/ethereum/go-ethereum/logger"
|
||||||
"github.com/ethereum/go-ethereum/p2p"
|
"github.com/ethereum/go-ethereum/p2p"
|
||||||
|
"github.com/ethereum/go-ethereum/p2p/discover"
|
||||||
|
"github.com/ethereum/go-ethereum/p2p/nat"
|
||||||
"github.com/ethereum/go-ethereum/pow/ezp"
|
"github.com/ethereum/go-ethereum/pow/ezp"
|
||||||
"github.com/ethereum/go-ethereum/rpc"
|
"github.com/ethereum/go-ethereum/rpc"
|
||||||
"github.com/ethereum/go-ethereum/whisper"
|
"github.com/ethereum/go-ethereum/whisper"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var logger = ethlogger.NewLogger("SERV")
|
||||||
|
var jsonlogger = ethlogger.NewJsonLogger()
|
||||||
|
|
||||||
type Config struct {
|
type Config struct {
|
||||||
Name string
|
Name string
|
||||||
Version string
|
KeyStore string
|
||||||
Identifier string
|
DataDir string
|
||||||
KeyStore string
|
LogFile string
|
||||||
DataDir string
|
LogLevel int
|
||||||
LogFile string
|
KeyRing string
|
||||||
LogLevel int
|
LogFormat string
|
||||||
LogFormat string
|
|
||||||
KeyRing string
|
|
||||||
|
|
||||||
MaxPeers int
|
MaxPeers int
|
||||||
Port string
|
Port string
|
||||||
NATType string
|
|
||||||
PMPGateway string
|
|
||||||
|
|
||||||
|
// This should be a space-separated list of
|
||||||
|
// discovery node URLs.
|
||||||
|
BootNodes string
|
||||||
|
|
||||||
|
// This key is used to identify the node on the network.
|
||||||
|
// If nil, an ephemeral key is used.
|
||||||
|
NodeKey *ecdsa.PrivateKey
|
||||||
|
|
||||||
|
NAT nat.Interface
|
||||||
Shh bool
|
Shh bool
|
||||||
Dial bool
|
Dial bool
|
||||||
|
|
||||||
KeyManager *crypto.KeyManager
|
KeyManager *crypto.KeyManager
|
||||||
}
|
}
|
||||||
|
|
||||||
var logger = ethlogger.NewLogger("SERV")
|
func (cfg *Config) parseBootNodes() []*discover.Node {
|
||||||
var jsonlogger = ethlogger.NewJsonLogger()
|
var ns []*discover.Node
|
||||||
|
for _, url := range strings.Split(cfg.BootNodes, " ") {
|
||||||
|
if url == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
n, err := discover.ParseNode(url)
|
||||||
|
if err != nil {
|
||||||
|
logger.Errorf("Bootstrap URL %s: %v\n", url, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
ns = append(ns, n)
|
||||||
|
}
|
||||||
|
return ns
|
||||||
|
}
|
||||||
|
|
||||||
type Ethereum struct {
|
type Ethereum struct {
|
||||||
// Channel for shutting down the ethereum
|
// Channel for shutting down the ethereum
|
||||||
|
@ -68,11 +91,7 @@ type Ethereum struct {
|
||||||
WsServer rpc.RpcServer
|
WsServer rpc.RpcServer
|
||||||
keyManager *crypto.KeyManager
|
keyManager *crypto.KeyManager
|
||||||
|
|
||||||
clientIdentity p2p.ClientIdentity
|
logger ethlogger.LogSystem
|
||||||
logger ethlogger.LogSystem
|
|
||||||
|
|
||||||
synclock sync.Mutex
|
|
||||||
syncGroup sync.WaitGroup
|
|
||||||
|
|
||||||
Mining bool
|
Mining bool
|
||||||
}
|
}
|
||||||
|
@ -105,21 +124,17 @@ func New(config *Config) (*Ethereum, error) {
|
||||||
// Initialise the keyring
|
// Initialise the keyring
|
||||||
keyManager.Init(config.KeyRing, 0, false)
|
keyManager.Init(config.KeyRing, 0, false)
|
||||||
|
|
||||||
// Create a new client id for this instance. This will help identifying the node on the network
|
|
||||||
clientId := p2p.NewSimpleClientIdentity(config.Name, config.Version, config.Identifier, keyManager.PublicKey())
|
|
||||||
|
|
||||||
saveProtocolVersion(db)
|
saveProtocolVersion(db)
|
||||||
//ethutil.Config.Db = db
|
//ethutil.Config.Db = db
|
||||||
|
|
||||||
eth := &Ethereum{
|
eth := &Ethereum{
|
||||||
shutdownChan: make(chan bool),
|
shutdownChan: make(chan bool),
|
||||||
quit: make(chan bool),
|
quit: make(chan bool),
|
||||||
db: db,
|
db: db,
|
||||||
keyManager: keyManager,
|
keyManager: keyManager,
|
||||||
clientIdentity: clientId,
|
blacklist: p2p.NewBlacklist(),
|
||||||
blacklist: p2p.NewBlacklist(),
|
eventMux: &event.TypeMux{},
|
||||||
eventMux: &event.TypeMux{},
|
logger: logger,
|
||||||
logger: logger,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
eth.chainManager = core.NewChainManager(db, eth.EventMux())
|
eth.chainManager = core.NewChainManager(db, eth.EventMux())
|
||||||
|
@ -134,21 +149,22 @@ func New(config *Config) (*Ethereum, error) {
|
||||||
|
|
||||||
ethProto := EthProtocol(eth.txPool, eth.chainManager, eth.blockPool)
|
ethProto := EthProtocol(eth.txPool, eth.chainManager, eth.blockPool)
|
||||||
protocols := []p2p.Protocol{ethProto, eth.whisper.Protocol()}
|
protocols := []p2p.Protocol{ethProto, eth.whisper.Protocol()}
|
||||||
|
netprv := config.NodeKey
|
||||||
nat, err := p2p.ParseNAT(config.NATType, config.PMPGateway)
|
if netprv == nil {
|
||||||
if err != nil {
|
if netprv, err = crypto.GenerateKey(); err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("could not generate server key: %v", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
eth.net = &p2p.Server{
|
eth.net = &p2p.Server{
|
||||||
Identity: clientId,
|
PrivateKey: netprv,
|
||||||
MaxPeers: config.MaxPeers,
|
Name: config.Name,
|
||||||
Protocols: protocols,
|
MaxPeers: config.MaxPeers,
|
||||||
Blacklist: eth.blacklist,
|
Protocols: protocols,
|
||||||
NAT: nat,
|
Blacklist: eth.blacklist,
|
||||||
NoDial: !config.Dial,
|
NAT: config.NAT,
|
||||||
|
NoDial: !config.Dial,
|
||||||
|
BootstrapNodes: config.parseBootNodes(),
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(config.Port) > 0 {
|
if len(config.Port) > 0 {
|
||||||
eth.net.ListenAddr = ":" + config.Port
|
eth.net.ListenAddr = ":" + config.Port
|
||||||
}
|
}
|
||||||
|
@ -164,8 +180,8 @@ func (s *Ethereum) Logger() ethlogger.LogSystem {
|
||||||
return s.logger
|
return s.logger
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Ethereum) ClientIdentity() p2p.ClientIdentity {
|
func (s *Ethereum) Name() string {
|
||||||
return s.clientIdentity
|
return s.net.Name
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Ethereum) ChainManager() *core.ChainManager {
|
func (s *Ethereum) ChainManager() *core.ChainManager {
|
||||||
|
@ -221,12 +237,12 @@ func (s *Ethereum) Coinbase() []byte {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start the ethereum
|
// Start the ethereum
|
||||||
func (s *Ethereum) Start(seedNode string) error {
|
func (s *Ethereum) Start() error {
|
||||||
jsonlogger.LogJson(ðlogger.LogStarting{
|
jsonlogger.LogJson(ðlogger.LogStarting{
|
||||||
ClientString: s.ClientIdentity().String(),
|
ClientString: s.net.Name,
|
||||||
Coinbase: ethutil.Bytes2Hex(s.KeyManager().Address()),
|
Coinbase: ethutil.Bytes2Hex(s.KeyManager().Address()),
|
||||||
ProtocolVersion: ProtocolVersion,
|
ProtocolVersion: ProtocolVersion,
|
||||||
LogEvent: ethlogger.LogEvent{Guid: ethutil.Bytes2Hex(s.ClientIdentity().Pubkey())},
|
LogEvent: ethlogger.LogEvent{Guid: ethutil.Bytes2Hex(crypto.FromECDSAPub(&s.net.PrivateKey.PublicKey))},
|
||||||
})
|
})
|
||||||
|
|
||||||
err := s.net.Start()
|
err := s.net.Start()
|
||||||
|
@ -250,26 +266,16 @@ func (s *Ethereum) Start(seedNode string) error {
|
||||||
s.blockSub = s.eventMux.Subscribe(core.NewMinedBlockEvent{})
|
s.blockSub = s.eventMux.Subscribe(core.NewMinedBlockEvent{})
|
||||||
go s.blockBroadcastLoop()
|
go s.blockBroadcastLoop()
|
||||||
|
|
||||||
// TODO: read peers here
|
|
||||||
if len(seedNode) > 0 {
|
|
||||||
logger.Infof("Connect to seed node %v", seedNode)
|
|
||||||
if err := s.SuggestPeer(seedNode); err != nil {
|
|
||||||
logger.Infoln(err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.Infoln("Server started")
|
logger.Infoln("Server started")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (self *Ethereum) SuggestPeer(addr string) error {
|
func (self *Ethereum) SuggestPeer(nodeURL string) error {
|
||||||
netaddr, err := net.ResolveTCPAddr("tcp", addr)
|
n, err := discover.ParseNode(nodeURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Errorf("couldn't resolve %s:", addr, err)
|
return fmt.Errorf("invalid node URL: %v", err)
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
self.net.SuggestPeer(n)
|
||||||
self.net.SuggestPeer(netaddr.IP, netaddr.Port, nil)
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -92,13 +92,14 @@ func EthProtocol(txPool txPool, chainManager chainManager, blockPool blockPool)
|
||||||
// the main loop that handles incoming messages
|
// the main loop that handles incoming messages
|
||||||
// note RemovePeer in the post-disconnect hook
|
// note RemovePeer in the post-disconnect hook
|
||||||
func runEthProtocol(txPool txPool, chainManager chainManager, blockPool blockPool, peer *p2p.Peer, rw p2p.MsgReadWriter) (err error) {
|
func runEthProtocol(txPool txPool, chainManager chainManager, blockPool blockPool, peer *p2p.Peer, rw p2p.MsgReadWriter) (err error) {
|
||||||
|
id := peer.ID()
|
||||||
self := ðProtocol{
|
self := ðProtocol{
|
||||||
txPool: txPool,
|
txPool: txPool,
|
||||||
chainManager: chainManager,
|
chainManager: chainManager,
|
||||||
blockPool: blockPool,
|
blockPool: blockPool,
|
||||||
rw: rw,
|
rw: rw,
|
||||||
peer: peer,
|
peer: peer,
|
||||||
id: fmt.Sprintf("%x", peer.Identity().Pubkey()[:8]),
|
id: fmt.Sprintf("%x", id[:8]),
|
||||||
}
|
}
|
||||||
err = self.handleStatus()
|
err = self.handleStatus()
|
||||||
if err == nil {
|
if err == nil {
|
||||||
|
|
|
@ -14,6 +14,7 @@ import (
|
||||||
"github.com/ethereum/go-ethereum/ethutil"
|
"github.com/ethereum/go-ethereum/ethutil"
|
||||||
ethlogger "github.com/ethereum/go-ethereum/logger"
|
ethlogger "github.com/ethereum/go-ethereum/logger"
|
||||||
"github.com/ethereum/go-ethereum/p2p"
|
"github.com/ethereum/go-ethereum/p2p"
|
||||||
|
"github.com/ethereum/go-ethereum/p2p/discover"
|
||||||
)
|
)
|
||||||
|
|
||||||
var sys = ethlogger.NewStdLogSystem(os.Stdout, log.LstdFlags, ethlogger.LogLevel(ethlogger.DebugDetailLevel))
|
var sys = ethlogger.NewStdLogSystem(os.Stdout, log.LstdFlags, ethlogger.LogLevel(ethlogger.DebugDetailLevel))
|
||||||
|
@ -128,26 +129,11 @@ func (self *testBlockPool) RemovePeer(peerId string) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: refactor this into p2p/client_identity
|
|
||||||
type peerId struct {
|
|
||||||
pubkey []byte
|
|
||||||
}
|
|
||||||
|
|
||||||
func (self *peerId) String() string {
|
|
||||||
return "test peer"
|
|
||||||
}
|
|
||||||
|
|
||||||
func (self *peerId) Pubkey() (pubkey []byte) {
|
|
||||||
pubkey = self.pubkey
|
|
||||||
if len(pubkey) == 0 {
|
|
||||||
pubkey = crypto.GenerateNewKeyPair().PublicKey
|
|
||||||
self.pubkey = pubkey
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func testPeer() *p2p.Peer {
|
func testPeer() *p2p.Peer {
|
||||||
return p2p.NewPeer(&peerId{}, []p2p.Cap{})
|
var id discover.NodeID
|
||||||
|
pk := crypto.GenerateNewKeyPair().PublicKey
|
||||||
|
copy(id[:], pk)
|
||||||
|
return p2p.NewPeer(id, "test peer", []p2p.Cap{})
|
||||||
}
|
}
|
||||||
|
|
||||||
type ethProtocolTester struct {
|
type ethProtocolTester struct {
|
||||||
|
|
|
@ -197,12 +197,13 @@ func (self *JSRE) watch(call otto.FunctionCall) otto.Value {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (self *JSRE) addPeer(call otto.FunctionCall) otto.Value {
|
func (self *JSRE) addPeer(call otto.FunctionCall) otto.Value {
|
||||||
host, err := call.Argument(0).ToString()
|
nodeURL, err := call.Argument(0).ToString()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return otto.FalseValue()
|
return otto.FalseValue()
|
||||||
}
|
}
|
||||||
self.ethereum.SuggestPeer(host)
|
if err := self.ethereum.SuggestPeer(nodeURL); err != nil {
|
||||||
|
return otto.FalseValue()
|
||||||
|
}
|
||||||
return otto.TrueValue()
|
return otto.TrueValue()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,63 +0,0 @@
|
||||||
package p2p
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"runtime"
|
|
||||||
)
|
|
||||||
|
|
||||||
// ClientIdentity represents the identity of a peer.
|
|
||||||
type ClientIdentity interface {
|
|
||||||
String() string // human readable identity
|
|
||||||
Pubkey() []byte // 512-bit public key
|
|
||||||
}
|
|
||||||
|
|
||||||
type SimpleClientIdentity struct {
|
|
||||||
clientIdentifier string
|
|
||||||
version string
|
|
||||||
customIdentifier string
|
|
||||||
os string
|
|
||||||
implementation string
|
|
||||||
pubkey []byte
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewSimpleClientIdentity(clientIdentifier string, version string, customIdentifier string, pubkey []byte) *SimpleClientIdentity {
|
|
||||||
clientIdentity := &SimpleClientIdentity{
|
|
||||||
clientIdentifier: clientIdentifier,
|
|
||||||
version: version,
|
|
||||||
customIdentifier: customIdentifier,
|
|
||||||
os: runtime.GOOS,
|
|
||||||
implementation: runtime.Version(),
|
|
||||||
pubkey: pubkey,
|
|
||||||
}
|
|
||||||
|
|
||||||
return clientIdentity
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *SimpleClientIdentity) init() {
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *SimpleClientIdentity) String() string {
|
|
||||||
var id string
|
|
||||||
if len(c.customIdentifier) > 0 {
|
|
||||||
id = "/" + c.customIdentifier
|
|
||||||
}
|
|
||||||
|
|
||||||
return fmt.Sprintf("%s/v%s%s/%s/%s",
|
|
||||||
c.clientIdentifier,
|
|
||||||
c.version,
|
|
||||||
id,
|
|
||||||
c.os,
|
|
||||||
c.implementation)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *SimpleClientIdentity) Pubkey() []byte {
|
|
||||||
return []byte(c.pubkey)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *SimpleClientIdentity) SetCustomIdentifier(customIdentifier string) {
|
|
||||||
c.customIdentifier = customIdentifier
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *SimpleClientIdentity) GetCustomIdentifier() string {
|
|
||||||
return c.customIdentifier
|
|
||||||
}
|
|
|
@ -1,30 +0,0 @@
|
||||||
package p2p
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"runtime"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestClientIdentity(t *testing.T) {
|
|
||||||
clientIdentity := NewSimpleClientIdentity("Ethereum(G)", "0.5.16", "test", []byte("pubkey"))
|
|
||||||
clientString := clientIdentity.String()
|
|
||||||
expected := fmt.Sprintf("Ethereum(G)/v0.5.16/test/%s/%s", runtime.GOOS, runtime.Version())
|
|
||||||
if clientString != expected {
|
|
||||||
t.Errorf("Expected clientIdentity to be %v, got %v", expected, clientString)
|
|
||||||
}
|
|
||||||
customIdentifier := clientIdentity.GetCustomIdentifier()
|
|
||||||
if customIdentifier != "test" {
|
|
||||||
t.Errorf("Expected clientIdentity.GetCustomIdentifier() to be 'test', got %v", customIdentifier)
|
|
||||||
}
|
|
||||||
clientIdentity.SetCustomIdentifier("test2")
|
|
||||||
customIdentifier = clientIdentity.GetCustomIdentifier()
|
|
||||||
if customIdentifier != "test2" {
|
|
||||||
t.Errorf("Expected clientIdentity.GetCustomIdentifier() to be 'test2', got %v", customIdentifier)
|
|
||||||
}
|
|
||||||
clientString = clientIdentity.String()
|
|
||||||
expected = fmt.Sprintf("Ethereum(G)/v0.5.16/test2/%s/%s", runtime.GOOS, runtime.Version())
|
|
||||||
if clientString != expected {
|
|
||||||
t.Errorf("Expected clientIdentity to be %v, got %v", expected, clientString)
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -0,0 +1,363 @@
|
||||||
|
package p2p
|
||||||
|
|
||||||
|
import (
|
||||||
|
// "binary"
|
||||||
|
"crypto/ecdsa"
|
||||||
|
"crypto/rand"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
|
||||||
|
"github.com/ethereum/go-ethereum/crypto"
|
||||||
|
"github.com/ethereum/go-ethereum/crypto/secp256k1"
|
||||||
|
ethlogger "github.com/ethereum/go-ethereum/logger"
|
||||||
|
"github.com/ethereum/go-ethereum/p2p/discover"
|
||||||
|
"github.com/obscuren/ecies"
|
||||||
|
)
|
||||||
|
|
||||||
|
var clogger = ethlogger.NewLogger("CRYPTOID")
|
||||||
|
|
||||||
|
const (
|
||||||
|
sskLen = 16 // ecies.MaxSharedKeyLength(pubKey) / 2
|
||||||
|
sigLen = 65 // elliptic S256
|
||||||
|
pubLen = 64 // 512 bit pubkey in uncompressed representation without format byte
|
||||||
|
shaLen = 32 // hash length (for nonce etc)
|
||||||
|
|
||||||
|
authMsgLen = sigLen + shaLen + pubLen + shaLen + 1
|
||||||
|
authRespLen = pubLen + shaLen + 1
|
||||||
|
|
||||||
|
eciesBytes = 65 + 16 + 32
|
||||||
|
iHSLen = authMsgLen + eciesBytes // size of the final ECIES payload sent as initiator's handshake
|
||||||
|
rHSLen = authRespLen + eciesBytes // size of the final ECIES payload sent as receiver's handshake
|
||||||
|
)
|
||||||
|
|
||||||
|
type hexkey []byte
|
||||||
|
|
||||||
|
func (self hexkey) String() string {
|
||||||
|
return fmt.Sprintf("(%d) %x", len(self), []byte(self))
|
||||||
|
}
|
||||||
|
|
||||||
|
func encHandshake(conn io.ReadWriter, prv *ecdsa.PrivateKey, dial *discover.Node) (
|
||||||
|
remoteID discover.NodeID,
|
||||||
|
sessionToken []byte,
|
||||||
|
err error,
|
||||||
|
) {
|
||||||
|
if dial == nil {
|
||||||
|
var remotePubkey []byte
|
||||||
|
sessionToken, remotePubkey, err = inboundEncHandshake(conn, prv, nil)
|
||||||
|
copy(remoteID[:], remotePubkey)
|
||||||
|
} else {
|
||||||
|
remoteID = dial.ID
|
||||||
|
sessionToken, err = outboundEncHandshake(conn, prv, remoteID[:], nil)
|
||||||
|
}
|
||||||
|
return remoteID, sessionToken, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// outboundEncHandshake negotiates a session token on conn.
|
||||||
|
// it should be called on the dialing side of the connection.
|
||||||
|
//
|
||||||
|
// privateKey is the local client's private key
|
||||||
|
// remotePublicKey is the remote peer's node ID
|
||||||
|
// sessionToken is the token from a previous session with this node.
|
||||||
|
func outboundEncHandshake(conn io.ReadWriter, prvKey *ecdsa.PrivateKey, remotePublicKey []byte, sessionToken []byte) (
|
||||||
|
newSessionToken []byte,
|
||||||
|
err error,
|
||||||
|
) {
|
||||||
|
auth, initNonce, randomPrivKey, err := authMsg(prvKey, remotePublicKey, sessionToken)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if sessionToken != nil {
|
||||||
|
clogger.Debugf("session-token: %v", hexkey(sessionToken))
|
||||||
|
}
|
||||||
|
|
||||||
|
clogger.Debugf("initiator-nonce: %v", hexkey(initNonce))
|
||||||
|
clogger.Debugf("initiator-random-private-key: %v", hexkey(crypto.FromECDSA(randomPrivKey)))
|
||||||
|
randomPublicKeyS, _ := exportPublicKey(&randomPrivKey.PublicKey)
|
||||||
|
clogger.Debugf("initiator-random-public-key: %v", hexkey(randomPublicKeyS))
|
||||||
|
if _, err = conn.Write(auth); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
clogger.Debugf("initiator handshake: %v", hexkey(auth))
|
||||||
|
|
||||||
|
response := make([]byte, rHSLen)
|
||||||
|
if _, err = io.ReadFull(conn, response); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
recNonce, remoteRandomPubKey, _, err := completeHandshake(response, prvKey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
clogger.Debugf("receiver-nonce: %v", hexkey(recNonce))
|
||||||
|
remoteRandomPubKeyS, _ := exportPublicKey(remoteRandomPubKey)
|
||||||
|
clogger.Debugf("receiver-random-public-key: %v", hexkey(remoteRandomPubKeyS))
|
||||||
|
return newSession(initNonce, recNonce, randomPrivKey, remoteRandomPubKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
// authMsg creates the initiator handshake.
|
||||||
|
func authMsg(prvKey *ecdsa.PrivateKey, remotePubKeyS, sessionToken []byte) (
|
||||||
|
auth, initNonce []byte,
|
||||||
|
randomPrvKey *ecdsa.PrivateKey,
|
||||||
|
err error,
|
||||||
|
) {
|
||||||
|
// session init, common to both parties
|
||||||
|
remotePubKey, err := importPublicKey(remotePubKeyS)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var tokenFlag byte // = 0x00
|
||||||
|
if sessionToken == nil {
|
||||||
|
// no session token found means we need to generate shared secret.
|
||||||
|
// ecies shared secret is used as initial session token for new peers
|
||||||
|
// generate shared key from prv and remote pubkey
|
||||||
|
if sessionToken, err = ecies.ImportECDSA(prvKey).GenerateShared(ecies.ImportECDSAPublic(remotePubKey), sskLen, sskLen); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// tokenFlag = 0x00 // redundant
|
||||||
|
} else {
|
||||||
|
// for known peers, we use stored token from the previous session
|
||||||
|
tokenFlag = 0x01
|
||||||
|
}
|
||||||
|
|
||||||
|
//E(remote-pubk, S(ecdhe-random, ecdh-shared-secret^nonce) || H(ecdhe-random-pubk) || pubk || nonce || 0x0)
|
||||||
|
// E(remote-pubk, S(ecdhe-random, token^nonce) || H(ecdhe-random-pubk) || pubk || nonce || 0x1)
|
||||||
|
// allocate msgLen long message,
|
||||||
|
var msg []byte = make([]byte, authMsgLen)
|
||||||
|
initNonce = msg[authMsgLen-shaLen-1 : authMsgLen-1]
|
||||||
|
if _, err = rand.Read(initNonce); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// create known message
|
||||||
|
// ecdh-shared-secret^nonce for new peers
|
||||||
|
// token^nonce for old peers
|
||||||
|
var sharedSecret = xor(sessionToken, initNonce)
|
||||||
|
|
||||||
|
// generate random keypair to use for signing
|
||||||
|
if randomPrvKey, err = crypto.GenerateKey(); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// sign shared secret (message known to both parties): shared-secret
|
||||||
|
var signature []byte
|
||||||
|
// signature = sign(ecdhe-random, shared-secret)
|
||||||
|
// uses secp256k1.Sign
|
||||||
|
if signature, err = crypto.Sign(sharedSecret, randomPrvKey); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// message
|
||||||
|
// signed-shared-secret || H(ecdhe-random-pubk) || pubk || nonce || 0x0
|
||||||
|
copy(msg, signature) // copy signed-shared-secret
|
||||||
|
// H(ecdhe-random-pubk)
|
||||||
|
var randomPubKey64 []byte
|
||||||
|
if randomPubKey64, err = exportPublicKey(&randomPrvKey.PublicKey); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var pubKey64 []byte
|
||||||
|
if pubKey64, err = exportPublicKey(&prvKey.PublicKey); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
copy(msg[sigLen:sigLen+shaLen], crypto.Sha3(randomPubKey64))
|
||||||
|
// pubkey copied to the correct segment.
|
||||||
|
copy(msg[sigLen+shaLen:sigLen+shaLen+pubLen], pubKey64)
|
||||||
|
// nonce is already in the slice
|
||||||
|
// stick tokenFlag byte to the end
|
||||||
|
msg[authMsgLen-1] = tokenFlag
|
||||||
|
|
||||||
|
// encrypt using remote-pubk
|
||||||
|
// auth = eciesEncrypt(remote-pubk, msg)
|
||||||
|
if auth, err = crypto.Encrypt(remotePubKey, msg); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// completeHandshake is called when the initiator receives an
|
||||||
|
// authentication response (aka receiver handshake). It completes the
|
||||||
|
// handshake by reading off parameters the remote peer provides needed
|
||||||
|
// to set up the secure session.
|
||||||
|
func completeHandshake(auth []byte, prvKey *ecdsa.PrivateKey) (
|
||||||
|
respNonce []byte,
|
||||||
|
remoteRandomPubKey *ecdsa.PublicKey,
|
||||||
|
tokenFlag bool,
|
||||||
|
err error,
|
||||||
|
) {
|
||||||
|
var msg []byte
|
||||||
|
// they prove that msg is meant for me,
|
||||||
|
// I prove I possess private key if i can read it
|
||||||
|
if msg, err = crypto.Decrypt(prvKey, auth); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
respNonce = msg[pubLen : pubLen+shaLen]
|
||||||
|
var remoteRandomPubKeyS = msg[:pubLen]
|
||||||
|
if remoteRandomPubKey, err = importPublicKey(remoteRandomPubKeyS); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if msg[authRespLen-1] == 0x01 {
|
||||||
|
tokenFlag = true
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// inboundEncHandshake negotiates a session token on conn.
|
||||||
|
// it should be called on the listening side of the connection.
|
||||||
|
//
|
||||||
|
// privateKey is the local client's private key
|
||||||
|
// sessionToken is the token from a previous session with this node.
|
||||||
|
func inboundEncHandshake(conn io.ReadWriter, prvKey *ecdsa.PrivateKey, sessionToken []byte) (
|
||||||
|
token, remotePubKey []byte,
|
||||||
|
err error,
|
||||||
|
) {
|
||||||
|
// we are listening connection. we are responders in the
|
||||||
|
// handshake. Extract info from the authentication. The initiator
|
||||||
|
// starts by sending us a handshake that we need to respond to. so
|
||||||
|
// we read auth message first, then respond.
|
||||||
|
auth := make([]byte, iHSLen)
|
||||||
|
if _, err := io.ReadFull(conn, auth); err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
response, recNonce, initNonce, remotePubKey, randomPrivKey, remoteRandomPubKey, err := authResp(auth, sessionToken, prvKey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
clogger.Debugf("receiver-nonce: %v", hexkey(recNonce))
|
||||||
|
clogger.Debugf("receiver-random-priv-key: %v", hexkey(crypto.FromECDSA(randomPrivKey)))
|
||||||
|
if _, err = conn.Write(response); err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
clogger.Debugf("receiver handshake:\n%v", hexkey(response))
|
||||||
|
token, err = newSession(initNonce, recNonce, randomPrivKey, remoteRandomPubKey)
|
||||||
|
return token, remotePubKey, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// authResp is called by peer if it accepted (but not
|
||||||
|
// initiated) the connection from the remote. It is passed the initiator
|
||||||
|
// handshake received and the session token belonging to the
|
||||||
|
// remote initiator.
|
||||||
|
//
|
||||||
|
// The first return value is the authentication response (aka receiver
|
||||||
|
// handshake) that is to be sent to the remote initiator.
|
||||||
|
func authResp(auth, sessionToken []byte, prvKey *ecdsa.PrivateKey) (
|
||||||
|
authResp, respNonce, initNonce, remotePubKeyS []byte,
|
||||||
|
randomPrivKey *ecdsa.PrivateKey,
|
||||||
|
remoteRandomPubKey *ecdsa.PublicKey,
|
||||||
|
err error,
|
||||||
|
) {
|
||||||
|
// they prove that msg is meant for me,
|
||||||
|
// I prove I possess private key if i can read it
|
||||||
|
msg, err := crypto.Decrypt(prvKey, auth)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
remotePubKeyS = msg[sigLen+shaLen : sigLen+shaLen+pubLen]
|
||||||
|
remotePubKey, _ := importPublicKey(remotePubKeyS)
|
||||||
|
|
||||||
|
var tokenFlag byte
|
||||||
|
if sessionToken == nil {
|
||||||
|
// no session token found means we need to generate shared secret.
|
||||||
|
// ecies shared secret is used as initial session token for new peers
|
||||||
|
// generate shared key from prv and remote pubkey
|
||||||
|
if sessionToken, err = ecies.ImportECDSA(prvKey).GenerateShared(ecies.ImportECDSAPublic(remotePubKey), sskLen, sskLen); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// tokenFlag = 0x00 // redundant
|
||||||
|
} else {
|
||||||
|
// for known peers, we use stored token from the previous session
|
||||||
|
tokenFlag = 0x01
|
||||||
|
}
|
||||||
|
|
||||||
|
// the initiator nonce is read off the end of the message
|
||||||
|
initNonce = msg[authMsgLen-shaLen-1 : authMsgLen-1]
|
||||||
|
// I prove that i own prv key (to derive shared secret, and read
|
||||||
|
// nonce off encrypted msg) and that I own shared secret they
|
||||||
|
// prove they own the private key belonging to ecdhe-random-pubk
|
||||||
|
// we can now reconstruct the signed message and recover the peers
|
||||||
|
// pubkey
|
||||||
|
var signedMsg = xor(sessionToken, initNonce)
|
||||||
|
var remoteRandomPubKeyS []byte
|
||||||
|
if remoteRandomPubKeyS, err = secp256k1.RecoverPubkey(signedMsg, msg[:sigLen]); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// convert to ECDSA standard
|
||||||
|
if remoteRandomPubKey, err = importPublicKey(remoteRandomPubKeyS); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// now we find ourselves a long task too, fill it random
|
||||||
|
var resp = make([]byte, authRespLen)
|
||||||
|
// generate shaLen long nonce
|
||||||
|
respNonce = resp[pubLen : pubLen+shaLen]
|
||||||
|
if _, err = rand.Read(respNonce); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// generate random keypair for session
|
||||||
|
if randomPrivKey, err = crypto.GenerateKey(); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// responder auth message
|
||||||
|
// E(remote-pubk, ecdhe-random-pubk || nonce || 0x0)
|
||||||
|
var randomPubKeyS []byte
|
||||||
|
if randomPubKeyS, err = exportPublicKey(&randomPrivKey.PublicKey); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
copy(resp[:pubLen], randomPubKeyS)
|
||||||
|
// nonce is already in the slice
|
||||||
|
resp[authRespLen-1] = tokenFlag
|
||||||
|
|
||||||
|
// encrypt using remote-pubk
|
||||||
|
// auth = eciesEncrypt(remote-pubk, msg)
|
||||||
|
// why not encrypt with ecdhe-random-remote
|
||||||
|
if authResp, err = crypto.Encrypt(remotePubKey, resp); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// newSession is called after the handshake is completed. The
|
||||||
|
// arguments are values negotiated in the handshake. The return value
|
||||||
|
// is a new session Token to be remembered for the next time we
|
||||||
|
// connect with this peer.
|
||||||
|
func newSession(initNonce, respNonce []byte, privKey *ecdsa.PrivateKey, remoteRandomPubKey *ecdsa.PublicKey) ([]byte, error) {
|
||||||
|
// 3) Now we can trust ecdhe-random-pubk to derive new keys
|
||||||
|
//ecdhe-shared-secret = ecdh.agree(ecdhe-random, remote-ecdhe-random-pubk)
|
||||||
|
pubKey := ecies.ImportECDSAPublic(remoteRandomPubKey)
|
||||||
|
dhSharedSecret, err := ecies.ImportECDSA(privKey).GenerateShared(pubKey, sskLen, sskLen)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
sharedSecret := crypto.Sha3(dhSharedSecret, crypto.Sha3(respNonce, initNonce))
|
||||||
|
sessionToken := crypto.Sha3(sharedSecret)
|
||||||
|
return sessionToken, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// importPublicKey unmarshals 512 bit public keys.
|
||||||
|
func importPublicKey(pubKey []byte) (pubKeyEC *ecdsa.PublicKey, err error) {
|
||||||
|
var pubKey65 []byte
|
||||||
|
switch len(pubKey) {
|
||||||
|
case 64:
|
||||||
|
// add 'uncompressed key' flag
|
||||||
|
pubKey65 = append([]byte{0x04}, pubKey...)
|
||||||
|
case 65:
|
||||||
|
pubKey65 = pubKey
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("invalid public key length %v (expect 64/65)", len(pubKey))
|
||||||
|
}
|
||||||
|
return crypto.ToECDSAPub(pubKey65), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func exportPublicKey(pubKeyEC *ecdsa.PublicKey) (pubKey []byte, err error) {
|
||||||
|
if pubKeyEC == nil {
|
||||||
|
return nil, fmt.Errorf("no ECDSA public key given")
|
||||||
|
}
|
||||||
|
return crypto.FromECDSAPub(pubKeyEC)[1:], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func xor(one, other []byte) (xor []byte) {
|
||||||
|
xor = make([]byte, len(one))
|
||||||
|
for i := 0; i < len(one); i++ {
|
||||||
|
xor[i] = one[i] ^ other[i]
|
||||||
|
}
|
||||||
|
return xor
|
||||||
|
}
|
|
@ -0,0 +1,167 @@
|
||||||
|
package p2p
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"crypto/ecdsa"
|
||||||
|
"crypto/rand"
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/ethereum/go-ethereum/crypto"
|
||||||
|
"github.com/obscuren/ecies"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestPublicKeyEncoding(t *testing.T) {
|
||||||
|
prv0, _ := crypto.GenerateKey() // = ecdsa.GenerateKey(crypto.S256(), rand.Reader)
|
||||||
|
pub0 := &prv0.PublicKey
|
||||||
|
pub0s := crypto.FromECDSAPub(pub0)
|
||||||
|
pub1, err := importPublicKey(pub0s)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("%v", err)
|
||||||
|
}
|
||||||
|
eciesPub1 := ecies.ImportECDSAPublic(pub1)
|
||||||
|
if eciesPub1 == nil {
|
||||||
|
t.Errorf("invalid ecdsa public key")
|
||||||
|
}
|
||||||
|
pub1s, err := exportPublicKey(pub1)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("%v", err)
|
||||||
|
}
|
||||||
|
if len(pub1s) != 64 {
|
||||||
|
t.Errorf("wrong length expect 64, got", len(pub1s))
|
||||||
|
}
|
||||||
|
pub2, err := importPublicKey(pub1s)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("%v", err)
|
||||||
|
}
|
||||||
|
pub2s, err := exportPublicKey(pub2)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("%v", err)
|
||||||
|
}
|
||||||
|
if !bytes.Equal(pub1s, pub2s) {
|
||||||
|
t.Errorf("exports dont match")
|
||||||
|
}
|
||||||
|
pub2sEC := crypto.FromECDSAPub(pub2)
|
||||||
|
if !bytes.Equal(pub0s, pub2sEC) {
|
||||||
|
t.Errorf("exports dont match")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSharedSecret(t *testing.T) {
|
||||||
|
prv0, _ := crypto.GenerateKey() // = ecdsa.GenerateKey(crypto.S256(), rand.Reader)
|
||||||
|
pub0 := &prv0.PublicKey
|
||||||
|
prv1, _ := crypto.GenerateKey()
|
||||||
|
pub1 := &prv1.PublicKey
|
||||||
|
|
||||||
|
ss0, err := ecies.ImportECDSA(prv0).GenerateShared(ecies.ImportECDSAPublic(pub1), sskLen, sskLen)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ss1, err := ecies.ImportECDSA(prv1).GenerateShared(ecies.ImportECDSAPublic(pub0), sskLen, sskLen)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
t.Logf("Secret:\n%v %x\n%v %x", len(ss0), ss0, len(ss0), ss1)
|
||||||
|
if !bytes.Equal(ss0, ss1) {
|
||||||
|
t.Errorf("dont match :(")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCryptoHandshake(t *testing.T) {
|
||||||
|
testCryptoHandshake(newkey(), newkey(), nil, t)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCryptoHandshakeWithToken(t *testing.T) {
|
||||||
|
sessionToken := make([]byte, shaLen)
|
||||||
|
rand.Read(sessionToken)
|
||||||
|
testCryptoHandshake(newkey(), newkey(), sessionToken, t)
|
||||||
|
}
|
||||||
|
|
||||||
|
func testCryptoHandshake(prv0, prv1 *ecdsa.PrivateKey, sessionToken []byte, t *testing.T) {
|
||||||
|
var err error
|
||||||
|
// pub0 := &prv0.PublicKey
|
||||||
|
pub1 := &prv1.PublicKey
|
||||||
|
|
||||||
|
// pub0s := crypto.FromECDSAPub(pub0)
|
||||||
|
pub1s := crypto.FromECDSAPub(pub1)
|
||||||
|
|
||||||
|
// simulate handshake by feeding output to input
|
||||||
|
// initiator sends handshake 'auth'
|
||||||
|
auth, initNonce, randomPrivKey, err := authMsg(prv0, pub1s, sessionToken)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("%v", err)
|
||||||
|
}
|
||||||
|
t.Logf("-> %v", hexkey(auth))
|
||||||
|
|
||||||
|
// receiver reads auth and responds with response
|
||||||
|
response, remoteRecNonce, remoteInitNonce, _, remoteRandomPrivKey, remoteInitRandomPubKey, err := authResp(auth, sessionToken, prv1)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("%v", err)
|
||||||
|
}
|
||||||
|
t.Logf("<- %v\n", hexkey(response))
|
||||||
|
|
||||||
|
// initiator reads receiver's response and the key exchange completes
|
||||||
|
recNonce, remoteRandomPubKey, _, err := completeHandshake(response, prv0)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("completeHandshake error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// now both parties should have the same session parameters
|
||||||
|
initSessionToken, err := newSession(initNonce, recNonce, randomPrivKey, remoteRandomPubKey)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("newSession error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
recSessionToken, err := newSession(remoteInitNonce, remoteRecNonce, remoteRandomPrivKey, remoteInitRandomPubKey)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("newSession error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// fmt.Printf("\nauth (%v) %x\n\nresp (%v) %x\n\n", len(auth), auth, len(response), response)
|
||||||
|
|
||||||
|
// fmt.Printf("\nauth %x\ninitNonce %x\nresponse%x\nremoteRecNonce %x\nremoteInitNonce %x\nremoteRandomPubKey %x\nrecNonce %x\nremoteInitRandomPubKey %x\ninitSessionToken %x\n\n", auth, initNonce, response, remoteRecNonce, remoteInitNonce, remoteRandomPubKey, recNonce, remoteInitRandomPubKey, initSessionToken)
|
||||||
|
|
||||||
|
if !bytes.Equal(initNonce, remoteInitNonce) {
|
||||||
|
t.Errorf("nonces do not match")
|
||||||
|
}
|
||||||
|
if !bytes.Equal(recNonce, remoteRecNonce) {
|
||||||
|
t.Errorf("receiver nonces do not match")
|
||||||
|
}
|
||||||
|
if !bytes.Equal(initSessionToken, recSessionToken) {
|
||||||
|
t.Errorf("session tokens do not match")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandshake(t *testing.T) {
|
||||||
|
defer testlog(t).detach()
|
||||||
|
|
||||||
|
prv0, _ := crypto.GenerateKey()
|
||||||
|
prv1, _ := crypto.GenerateKey()
|
||||||
|
pub0s, _ := exportPublicKey(&prv0.PublicKey)
|
||||||
|
pub1s, _ := exportPublicKey(&prv1.PublicKey)
|
||||||
|
rw0, rw1 := net.Pipe()
|
||||||
|
tokens := make(chan []byte)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
token, err := outboundEncHandshake(rw0, prv0, pub1s, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("outbound side error: %v", err)
|
||||||
|
}
|
||||||
|
tokens <- token
|
||||||
|
}()
|
||||||
|
go func() {
|
||||||
|
token, remotePubkey, err := inboundEncHandshake(rw1, prv1, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("inbound side error: %v", err)
|
||||||
|
}
|
||||||
|
if !bytes.Equal(remotePubkey, pub0s) {
|
||||||
|
t.Errorf("inbound side returned wrong remote pubkey\n got: %x\n want: %x", remotePubkey, pub0s)
|
||||||
|
}
|
||||||
|
tokens <- token
|
||||||
|
}()
|
||||||
|
|
||||||
|
t1, t2 := <-tokens, <-tokens
|
||||||
|
if !bytes.Equal(t1, t2) {
|
||||||
|
t.Error("session token mismatch")
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,291 @@
|
||||||
|
package discover
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/ecdsa"
|
||||||
|
"crypto/elliptic"
|
||||||
|
"encoding/hex"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"math/rand"
|
||||||
|
"net"
|
||||||
|
"net/url"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/ethereum/go-ethereum/crypto/secp256k1"
|
||||||
|
"github.com/ethereum/go-ethereum/rlp"
|
||||||
|
)
|
||||||
|
|
||||||
|
const nodeIDBits = 512
|
||||||
|
|
||||||
|
// Node represents a host on the network.
|
||||||
|
type Node struct {
|
||||||
|
ID NodeID
|
||||||
|
IP net.IP
|
||||||
|
|
||||||
|
DiscPort int // UDP listening port for discovery protocol
|
||||||
|
TCPPort int // TCP listening port for RLPx
|
||||||
|
|
||||||
|
active time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
func newNode(id NodeID, addr *net.UDPAddr) *Node {
|
||||||
|
return &Node{
|
||||||
|
ID: id,
|
||||||
|
IP: addr.IP,
|
||||||
|
DiscPort: addr.Port,
|
||||||
|
TCPPort: addr.Port,
|
||||||
|
active: time.Now(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *Node) isValid() bool {
|
||||||
|
// TODO: don't accept localhost, LAN addresses from internet hosts
|
||||||
|
return !n.IP.IsMulticast() && !n.IP.IsUnspecified() && n.TCPPort != 0 && n.DiscPort != 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// The string representation of a Node is a URL.
|
||||||
|
// Please see ParseNode for a description of the format.
|
||||||
|
func (n *Node) String() string {
|
||||||
|
addr := net.TCPAddr{IP: n.IP, Port: n.TCPPort}
|
||||||
|
u := url.URL{
|
||||||
|
Scheme: "enode",
|
||||||
|
User: url.User(fmt.Sprintf("%x", n.ID[:])),
|
||||||
|
Host: addr.String(),
|
||||||
|
}
|
||||||
|
if n.DiscPort != n.TCPPort {
|
||||||
|
u.RawQuery = "discport=" + strconv.Itoa(n.DiscPort)
|
||||||
|
}
|
||||||
|
return u.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseNode parses a node URL.
|
||||||
|
//
|
||||||
|
// A node URL has scheme "enode".
|
||||||
|
//
|
||||||
|
// The hexadecimal node ID is encoded in the username portion of the
|
||||||
|
// URL, separated from the host by an @ sign. The hostname can only be
|
||||||
|
// given as an IP address, DNS domain names are not allowed. The port
|
||||||
|
// in the host name section is the TCP listening port. If the TCP and
|
||||||
|
// UDP (discovery) ports differ, the UDP port is specified as query
|
||||||
|
// parameter "discport".
|
||||||
|
//
|
||||||
|
// In the following example, the node URL describes
|
||||||
|
// a node with IP address 10.3.58.6, TCP listening port 30303
|
||||||
|
// and UDP discovery port 30301.
|
||||||
|
//
|
||||||
|
// enode://<hex node id>@10.3.58.6:30303?discport=30301
|
||||||
|
func ParseNode(rawurl string) (*Node, error) {
|
||||||
|
var n Node
|
||||||
|
u, err := url.Parse(rawurl)
|
||||||
|
if u.Scheme != "enode" {
|
||||||
|
return nil, errors.New("invalid URL scheme, want \"enode\"")
|
||||||
|
}
|
||||||
|
if u.User == nil {
|
||||||
|
return nil, errors.New("does not contain node ID")
|
||||||
|
}
|
||||||
|
if n.ID, err = HexID(u.User.String()); err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid node ID (%v)", err)
|
||||||
|
}
|
||||||
|
ip, port, err := net.SplitHostPort(u.Host)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid host: %v", err)
|
||||||
|
}
|
||||||
|
if n.IP = net.ParseIP(ip); n.IP == nil {
|
||||||
|
return nil, errors.New("invalid IP address")
|
||||||
|
}
|
||||||
|
if n.TCPPort, err = strconv.Atoi(port); err != nil {
|
||||||
|
return nil, errors.New("invalid port")
|
||||||
|
}
|
||||||
|
qv := u.Query()
|
||||||
|
if qv.Get("discport") == "" {
|
||||||
|
n.DiscPort = n.TCPPort
|
||||||
|
} else {
|
||||||
|
if n.DiscPort, err = strconv.Atoi(qv.Get("discport")); err != nil {
|
||||||
|
return nil, errors.New("invalid discport in query")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return &n, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// MustParseNode parses a node URL. It panics if the URL is not valid.
|
||||||
|
func MustParseNode(rawurl string) *Node {
|
||||||
|
n, err := ParseNode(rawurl)
|
||||||
|
if err != nil {
|
||||||
|
panic("invalid node URL: " + err.Error())
|
||||||
|
}
|
||||||
|
return n
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n Node) EncodeRLP(w io.Writer) error {
|
||||||
|
return rlp.Encode(w, rpcNode{IP: n.IP.String(), Port: uint16(n.TCPPort), ID: n.ID})
|
||||||
|
}
|
||||||
|
func (n *Node) DecodeRLP(s *rlp.Stream) (err error) {
|
||||||
|
var ext rpcNode
|
||||||
|
if err = s.Decode(&ext); err == nil {
|
||||||
|
n.TCPPort = int(ext.Port)
|
||||||
|
n.DiscPort = int(ext.Port)
|
||||||
|
n.ID = ext.ID
|
||||||
|
if n.IP = net.ParseIP(ext.IP); n.IP == nil {
|
||||||
|
return errors.New("invalid IP string")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// NodeID is a unique identifier for each node.
|
||||||
|
// The node identifier is a marshaled elliptic curve public key.
|
||||||
|
type NodeID [nodeIDBits / 8]byte
|
||||||
|
|
||||||
|
// NodeID prints as a long hexadecimal number.
|
||||||
|
func (n NodeID) String() string {
|
||||||
|
return fmt.Sprintf("%#x", n[:])
|
||||||
|
}
|
||||||
|
|
||||||
|
// The Go syntax representation of a NodeID is a call to HexID.
|
||||||
|
func (n NodeID) GoString() string {
|
||||||
|
return fmt.Sprintf("discover.HexID(\"%#x\")", n[:])
|
||||||
|
}
|
||||||
|
|
||||||
|
// HexID converts a hex string to a NodeID.
|
||||||
|
// The string may be prefixed with 0x.
|
||||||
|
func HexID(in string) (NodeID, error) {
|
||||||
|
if strings.HasPrefix(in, "0x") {
|
||||||
|
in = in[2:]
|
||||||
|
}
|
||||||
|
var id NodeID
|
||||||
|
b, err := hex.DecodeString(in)
|
||||||
|
if err != nil {
|
||||||
|
return id, err
|
||||||
|
} else if len(b) != len(id) {
|
||||||
|
return id, fmt.Errorf("wrong length, need %d hex bytes", len(id))
|
||||||
|
}
|
||||||
|
copy(id[:], b)
|
||||||
|
return id, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// MustHexID converts a hex string to a NodeID.
|
||||||
|
// It panics if the string is not a valid NodeID.
|
||||||
|
func MustHexID(in string) NodeID {
|
||||||
|
id, err := HexID(in)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return id
|
||||||
|
}
|
||||||
|
|
||||||
|
// PubkeyID returns a marshaled representation of the given public key.
|
||||||
|
func PubkeyID(pub *ecdsa.PublicKey) NodeID {
|
||||||
|
var id NodeID
|
||||||
|
pbytes := elliptic.Marshal(pub.Curve, pub.X, pub.Y)
|
||||||
|
if len(pbytes)-1 != len(id) {
|
||||||
|
panic(fmt.Errorf("need %d bit pubkey, got %d bits", (len(id)+1)*8, len(pbytes)))
|
||||||
|
}
|
||||||
|
copy(id[:], pbytes[1:])
|
||||||
|
return id
|
||||||
|
}
|
||||||
|
|
||||||
|
// recoverNodeID computes the public key used to sign the
|
||||||
|
// given hash from the signature.
|
||||||
|
func recoverNodeID(hash, sig []byte) (id NodeID, err error) {
|
||||||
|
pubkey, err := secp256k1.RecoverPubkey(hash, sig)
|
||||||
|
if err != nil {
|
||||||
|
return id, err
|
||||||
|
}
|
||||||
|
if len(pubkey)-1 != len(id) {
|
||||||
|
return id, fmt.Errorf("recovered pubkey has %d bits, want %d bits", len(pubkey)*8, (len(id)+1)*8)
|
||||||
|
}
|
||||||
|
for i := range id {
|
||||||
|
id[i] = pubkey[i+1]
|
||||||
|
}
|
||||||
|
return id, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// distcmp compares the distances a->target and b->target.
|
||||||
|
// Returns -1 if a is closer to target, 1 if b is closer to target
|
||||||
|
// and 0 if they are equal.
|
||||||
|
func distcmp(target, a, b NodeID) int {
|
||||||
|
for i := range target {
|
||||||
|
da := a[i] ^ target[i]
|
||||||
|
db := b[i] ^ target[i]
|
||||||
|
if da > db {
|
||||||
|
return 1
|
||||||
|
} else if da < db {
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// table of leading zero counts for bytes [0..255]
|
||||||
|
var lzcount = [256]int{
|
||||||
|
8, 7, 6, 6, 5, 5, 5, 5,
|
||||||
|
4, 4, 4, 4, 4, 4, 4, 4,
|
||||||
|
3, 3, 3, 3, 3, 3, 3, 3,
|
||||||
|
3, 3, 3, 3, 3, 3, 3, 3,
|
||||||
|
2, 2, 2, 2, 2, 2, 2, 2,
|
||||||
|
2, 2, 2, 2, 2, 2, 2, 2,
|
||||||
|
2, 2, 2, 2, 2, 2, 2, 2,
|
||||||
|
2, 2, 2, 2, 2, 2, 2, 2,
|
||||||
|
1, 1, 1, 1, 1, 1, 1, 1,
|
||||||
|
1, 1, 1, 1, 1, 1, 1, 1,
|
||||||
|
1, 1, 1, 1, 1, 1, 1, 1,
|
||||||
|
1, 1, 1, 1, 1, 1, 1, 1,
|
||||||
|
1, 1, 1, 1, 1, 1, 1, 1,
|
||||||
|
1, 1, 1, 1, 1, 1, 1, 1,
|
||||||
|
1, 1, 1, 1, 1, 1, 1, 1,
|
||||||
|
1, 1, 1, 1, 1, 1, 1, 1,
|
||||||
|
0, 0, 0, 0, 0, 0, 0, 0,
|
||||||
|
0, 0, 0, 0, 0, 0, 0, 0,
|
||||||
|
0, 0, 0, 0, 0, 0, 0, 0,
|
||||||
|
0, 0, 0, 0, 0, 0, 0, 0,
|
||||||
|
0, 0, 0, 0, 0, 0, 0, 0,
|
||||||
|
0, 0, 0, 0, 0, 0, 0, 0,
|
||||||
|
0, 0, 0, 0, 0, 0, 0, 0,
|
||||||
|
0, 0, 0, 0, 0, 0, 0, 0,
|
||||||
|
0, 0, 0, 0, 0, 0, 0, 0,
|
||||||
|
0, 0, 0, 0, 0, 0, 0, 0,
|
||||||
|
0, 0, 0, 0, 0, 0, 0, 0,
|
||||||
|
0, 0, 0, 0, 0, 0, 0, 0,
|
||||||
|
0, 0, 0, 0, 0, 0, 0, 0,
|
||||||
|
0, 0, 0, 0, 0, 0, 0, 0,
|
||||||
|
0, 0, 0, 0, 0, 0, 0, 0,
|
||||||
|
0, 0, 0, 0, 0, 0, 0, 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
// logdist returns the logarithmic distance between a and b, log2(a ^ b).
|
||||||
|
func logdist(a, b NodeID) int {
|
||||||
|
lz := 0
|
||||||
|
for i := range a {
|
||||||
|
x := a[i] ^ b[i]
|
||||||
|
if x == 0 {
|
||||||
|
lz += 8
|
||||||
|
} else {
|
||||||
|
lz += lzcount[x]
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return len(a)*8 - lz
|
||||||
|
}
|
||||||
|
|
||||||
|
// randomID returns a random NodeID such that logdist(a, b) == n
|
||||||
|
func randomID(a NodeID, n int) (b NodeID) {
|
||||||
|
if n == 0 {
|
||||||
|
return a
|
||||||
|
}
|
||||||
|
// flip bit at position n, fill the rest with random bits
|
||||||
|
b = a
|
||||||
|
pos := len(a) - n/8 - 1
|
||||||
|
bit := byte(0x01) << (byte(n%8) - 1)
|
||||||
|
if bit == 0 {
|
||||||
|
pos++
|
||||||
|
bit = 0x80
|
||||||
|
}
|
||||||
|
b[pos] = a[pos]&^bit | ^a[pos]&bit // TODO: randomize end bits
|
||||||
|
for i := pos + 1; i < len(a); i++ {
|
||||||
|
b[i] = byte(rand.Intn(255))
|
||||||
|
}
|
||||||
|
return b
|
||||||
|
}
|
|
@ -0,0 +1,201 @@
|
||||||
|
package discover
|
||||||
|
|
||||||
|
import (
|
||||||
|
"math/big"
|
||||||
|
"math/rand"
|
||||||
|
"net"
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
"testing/quick"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/ethereum/go-ethereum/crypto"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
quickrand = rand.New(rand.NewSource(time.Now().Unix()))
|
||||||
|
quickcfg = &quick.Config{MaxCount: 5000, Rand: quickrand}
|
||||||
|
)
|
||||||
|
|
||||||
|
var parseNodeTests = []struct {
|
||||||
|
rawurl string
|
||||||
|
wantError string
|
||||||
|
wantResult *Node
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
rawurl: "http://foobar",
|
||||||
|
wantError: `invalid URL scheme, want "enode"`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
rawurl: "enode://foobar",
|
||||||
|
wantError: `does not contain node ID`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
rawurl: "enode://01010101@123.124.125.126:3",
|
||||||
|
wantError: `invalid node ID (wrong length, need 64 hex bytes)`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
rawurl: "enode://1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439@hostname:3",
|
||||||
|
wantError: `invalid IP address`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
rawurl: "enode://1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439@127.0.0.1:foo",
|
||||||
|
wantError: `invalid port`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
rawurl: "enode://1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439@127.0.0.1:3?discport=foo",
|
||||||
|
wantError: `invalid discport in query`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
rawurl: "enode://1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439@127.0.0.1:52150",
|
||||||
|
wantResult: &Node{
|
||||||
|
ID: MustHexID("0x1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439"),
|
||||||
|
IP: net.ParseIP("127.0.0.1"),
|
||||||
|
DiscPort: 52150,
|
||||||
|
TCPPort: 52150,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
rawurl: "enode://1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439@[::]:52150",
|
||||||
|
wantResult: &Node{
|
||||||
|
ID: MustHexID("0x1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439"),
|
||||||
|
IP: net.ParseIP("::"),
|
||||||
|
DiscPort: 52150,
|
||||||
|
TCPPort: 52150,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
rawurl: "enode://1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439@127.0.0.1:52150?discport=223344",
|
||||||
|
wantResult: &Node{
|
||||||
|
ID: MustHexID("0x1dd9d65c4552b5eb43d5ad55a2ee3f56c6cbc1c64a5c8d659f51fcd51bace24351232b8d7821617d2b29b54b81cdefb9b3e9c37d7fd5f63270bcc9e1a6f6a439"),
|
||||||
|
IP: net.ParseIP("127.0.0.1"),
|
||||||
|
DiscPort: 223344,
|
||||||
|
TCPPort: 52150,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseNode(t *testing.T) {
|
||||||
|
for i, test := range parseNodeTests {
|
||||||
|
n, err := ParseNode(test.rawurl)
|
||||||
|
if err == nil && test.wantError != "" {
|
||||||
|
t.Errorf("test %d: got nil error, expected %#q", i, test.wantError)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err != nil && err.Error() != test.wantError {
|
||||||
|
t.Errorf("test %d: got error %#q, expected %#q", i, err.Error(), test.wantError)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(n, test.wantResult) {
|
||||||
|
t.Errorf("test %d: result mismatch:\ngot: %#v, want: %#v", i, n, test.wantResult)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNodeString(t *testing.T) {
|
||||||
|
for i, test := range parseNodeTests {
|
||||||
|
if test.wantError != "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
str := test.wantResult.String()
|
||||||
|
if str != test.rawurl {
|
||||||
|
t.Errorf("test %d: Node.String() mismatch:\ngot: %s\nwant: %s", i, str, test.rawurl)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHexID(t *testing.T) {
|
||||||
|
ref := NodeID{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 128, 106, 217, 182, 31, 165, 174, 1, 67, 7, 235, 220, 150, 66, 83, 173, 205, 159, 44, 10, 57, 42, 161, 26, 188}
|
||||||
|
id1 := MustHexID("0x000000000000000000000000000000000000000000000000000000000000000000000000000000806ad9b61fa5ae014307ebdc964253adcd9f2c0a392aa11abc")
|
||||||
|
id2 := MustHexID("000000000000000000000000000000000000000000000000000000000000000000000000000000806ad9b61fa5ae014307ebdc964253adcd9f2c0a392aa11abc")
|
||||||
|
|
||||||
|
if id1 != ref {
|
||||||
|
t.Errorf("wrong id1\ngot %v\nwant %v", id1[:], ref[:])
|
||||||
|
}
|
||||||
|
if id2 != ref {
|
||||||
|
t.Errorf("wrong id2\ngot %v\nwant %v", id2[:], ref[:])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNodeID_recover(t *testing.T) {
|
||||||
|
prv := newkey()
|
||||||
|
hash := make([]byte, 32)
|
||||||
|
sig, err := crypto.Sign(hash, prv)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("signing error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub := PubkeyID(&prv.PublicKey)
|
||||||
|
recpub, err := recoverNodeID(hash, sig)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("recovery error: %v", err)
|
||||||
|
}
|
||||||
|
if pub != recpub {
|
||||||
|
t.Errorf("recovered wrong pubkey:\ngot: %v\nwant: %v", recpub, pub)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNodeID_distcmp(t *testing.T) {
|
||||||
|
distcmpBig := func(target, a, b NodeID) int {
|
||||||
|
tbig := new(big.Int).SetBytes(target[:])
|
||||||
|
abig := new(big.Int).SetBytes(a[:])
|
||||||
|
bbig := new(big.Int).SetBytes(b[:])
|
||||||
|
return new(big.Int).Xor(tbig, abig).Cmp(new(big.Int).Xor(tbig, bbig))
|
||||||
|
}
|
||||||
|
if err := quick.CheckEqual(distcmp, distcmpBig, quickcfg); err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// the random tests is likely to miss the case where they're equal.
|
||||||
|
func TestNodeID_distcmpEqual(t *testing.T) {
|
||||||
|
base := NodeID{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}
|
||||||
|
x := NodeID{15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0}
|
||||||
|
if distcmp(base, x, x) != 0 {
|
||||||
|
t.Errorf("distcmp(base, x, x) != 0")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNodeID_logdist(t *testing.T) {
|
||||||
|
logdistBig := func(a, b NodeID) int {
|
||||||
|
abig, bbig := new(big.Int).SetBytes(a[:]), new(big.Int).SetBytes(b[:])
|
||||||
|
return new(big.Int).Xor(abig, bbig).BitLen()
|
||||||
|
}
|
||||||
|
if err := quick.CheckEqual(logdist, logdistBig, quickcfg); err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// the random tests is likely to miss the case where they're equal.
|
||||||
|
func TestNodeID_logdistEqual(t *testing.T) {
|
||||||
|
x := NodeID{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}
|
||||||
|
if logdist(x, x) != 0 {
|
||||||
|
t.Errorf("logdist(x, x) != 0")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNodeID_randomID(t *testing.T) {
|
||||||
|
// we don't use quick.Check here because its output isn't
|
||||||
|
// very helpful when the test fails.
|
||||||
|
for i := 0; i < quickcfg.MaxCount; i++ {
|
||||||
|
a := gen(NodeID{}, quickrand).(NodeID)
|
||||||
|
dist := quickrand.Intn(len(NodeID{}) * 8)
|
||||||
|
result := randomID(a, dist)
|
||||||
|
actualdist := logdist(result, a)
|
||||||
|
|
||||||
|
if dist != actualdist {
|
||||||
|
t.Log("a: ", a)
|
||||||
|
t.Log("result:", result)
|
||||||
|
t.Fatalf("#%d: distance of result is %d, want %d", i, actualdist, dist)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (NodeID) Generate(rand *rand.Rand, size int) reflect.Value {
|
||||||
|
var id NodeID
|
||||||
|
m := rand.Intn(len(id))
|
||||||
|
for i := len(id) - 1; i > m; i-- {
|
||||||
|
id[i] = byte(rand.Uint32())
|
||||||
|
}
|
||||||
|
return reflect.ValueOf(id)
|
||||||
|
}
|
|
@ -0,0 +1,280 @@
|
||||||
|
// Package discover implements the Node Discovery Protocol.
|
||||||
|
//
|
||||||
|
// The Node Discovery protocol provides a way to find RLPx nodes that
|
||||||
|
// can be connected to. It uses a Kademlia-like protocol to maintain a
|
||||||
|
// distributed database of the IDs and endpoints of all listening
|
||||||
|
// nodes.
|
||||||
|
package discover
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"sort"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
alpha = 3 // Kademlia concurrency factor
|
||||||
|
bucketSize = 16 // Kademlia bucket size
|
||||||
|
nBuckets = nodeIDBits + 1 // Number of buckets
|
||||||
|
)
|
||||||
|
|
||||||
|
type Table struct {
|
||||||
|
mutex sync.Mutex // protects buckets, their content, and nursery
|
||||||
|
buckets [nBuckets]*bucket // index of known nodes by distance
|
||||||
|
nursery []*Node // bootstrap nodes
|
||||||
|
|
||||||
|
net transport
|
||||||
|
self *Node // metadata of the local node
|
||||||
|
}
|
||||||
|
|
||||||
|
// transport is implemented by the UDP transport.
|
||||||
|
// it is an interface so we can test without opening lots of UDP
|
||||||
|
// sockets and without generating a private key.
|
||||||
|
type transport interface {
|
||||||
|
ping(*Node) error
|
||||||
|
findnode(e *Node, target NodeID) ([]*Node, error)
|
||||||
|
close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// bucket contains nodes, ordered by their last activity.
|
||||||
|
type bucket struct {
|
||||||
|
lastLookup time.Time
|
||||||
|
entries []*Node
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTable(t transport, ourID NodeID, ourAddr *net.UDPAddr) *Table {
|
||||||
|
tab := &Table{net: t, self: newNode(ourID, ourAddr)}
|
||||||
|
for i := range tab.buckets {
|
||||||
|
tab.buckets[i] = new(bucket)
|
||||||
|
}
|
||||||
|
return tab
|
||||||
|
}
|
||||||
|
|
||||||
|
// Self returns the local node ID.
|
||||||
|
func (tab *Table) Self() NodeID {
|
||||||
|
return tab.self.ID
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close terminates the network listener.
|
||||||
|
func (tab *Table) Close() {
|
||||||
|
tab.net.close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Bootstrap sets the bootstrap nodes. These nodes are used to connect
|
||||||
|
// to the network if the table is empty. Bootstrap will also attempt to
|
||||||
|
// fill the table by performing random lookup operations on the
|
||||||
|
// network.
|
||||||
|
func (tab *Table) Bootstrap(nodes []*Node) {
|
||||||
|
tab.mutex.Lock()
|
||||||
|
// TODO: maybe filter nodes with bad fields (nil, etc.) to avoid strange crashes
|
||||||
|
tab.nursery = make([]*Node, 0, len(nodes))
|
||||||
|
for _, n := range nodes {
|
||||||
|
cpy := *n
|
||||||
|
tab.nursery = append(tab.nursery, &cpy)
|
||||||
|
}
|
||||||
|
tab.mutex.Unlock()
|
||||||
|
tab.refresh()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Lookup performs a network search for nodes close
|
||||||
|
// to the given target. It approaches the target by querying
|
||||||
|
// nodes that are closer to it on each iteration.
|
||||||
|
func (tab *Table) Lookup(target NodeID) []*Node {
|
||||||
|
var (
|
||||||
|
asked = make(map[NodeID]bool)
|
||||||
|
seen = make(map[NodeID]bool)
|
||||||
|
reply = make(chan []*Node, alpha)
|
||||||
|
pendingQueries = 0
|
||||||
|
)
|
||||||
|
// don't query further if we hit the target or ourself.
|
||||||
|
// unlikely to happen often in practice.
|
||||||
|
asked[target] = true
|
||||||
|
asked[tab.self.ID] = true
|
||||||
|
|
||||||
|
tab.mutex.Lock()
|
||||||
|
// update last lookup stamp (for refresh logic)
|
||||||
|
tab.buckets[logdist(tab.self.ID, target)].lastLookup = time.Now()
|
||||||
|
// generate initial result set
|
||||||
|
result := tab.closest(target, bucketSize)
|
||||||
|
tab.mutex.Unlock()
|
||||||
|
|
||||||
|
for {
|
||||||
|
// ask the alpha closest nodes that we haven't asked yet
|
||||||
|
for i := 0; i < len(result.entries) && pendingQueries < alpha; i++ {
|
||||||
|
n := result.entries[i]
|
||||||
|
if !asked[n.ID] {
|
||||||
|
asked[n.ID] = true
|
||||||
|
pendingQueries++
|
||||||
|
go func() {
|
||||||
|
result, _ := tab.net.findnode(n, target)
|
||||||
|
reply <- result
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if pendingQueries == 0 {
|
||||||
|
// we have asked all closest nodes, stop the search
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
// wait for the next reply
|
||||||
|
for _, n := range <-reply {
|
||||||
|
cn := n
|
||||||
|
if !seen[n.ID] {
|
||||||
|
seen[n.ID] = true
|
||||||
|
result.push(cn, bucketSize)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
pendingQueries--
|
||||||
|
}
|
||||||
|
return result.entries
|
||||||
|
}
|
||||||
|
|
||||||
|
// refresh performs a lookup for a random target to keep buckets full.
|
||||||
|
func (tab *Table) refresh() {
|
||||||
|
ld := -1 // logdist of chosen bucket
|
||||||
|
tab.mutex.Lock()
|
||||||
|
for i, b := range tab.buckets {
|
||||||
|
if i > 0 && b.lastLookup.Before(time.Now().Add(-1*time.Hour)) {
|
||||||
|
ld = i
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
tab.mutex.Unlock()
|
||||||
|
|
||||||
|
result := tab.Lookup(randomID(tab.self.ID, ld))
|
||||||
|
if len(result) == 0 {
|
||||||
|
// bootstrap the table with a self lookup
|
||||||
|
tab.mutex.Lock()
|
||||||
|
tab.add(tab.nursery)
|
||||||
|
tab.mutex.Unlock()
|
||||||
|
tab.Lookup(tab.self.ID)
|
||||||
|
// TODO: the Kademlia paper says that we're supposed to perform
|
||||||
|
// random lookups in all buckets further away than our closest neighbor.
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// closest returns the n nodes in the table that are closest to the
|
||||||
|
// given id. The caller must hold tab.mutex.
|
||||||
|
func (tab *Table) closest(target NodeID, nresults int) *nodesByDistance {
|
||||||
|
// This is a very wasteful way to find the closest nodes but
|
||||||
|
// obviously correct. I believe that tree-based buckets would make
|
||||||
|
// this easier to implement efficiently.
|
||||||
|
close := &nodesByDistance{target: target}
|
||||||
|
for _, b := range tab.buckets {
|
||||||
|
for _, n := range b.entries {
|
||||||
|
close.push(n, nresults)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return close
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tab *Table) len() (n int) {
|
||||||
|
for _, b := range tab.buckets {
|
||||||
|
n += len(b.entries)
|
||||||
|
}
|
||||||
|
return n
|
||||||
|
}
|
||||||
|
|
||||||
|
// bumpOrAdd updates the activity timestamp for the given node and
|
||||||
|
// attempts to insert the node into a bucket. The returned Node might
|
||||||
|
// not be part of the table. The caller must hold tab.mutex.
|
||||||
|
func (tab *Table) bumpOrAdd(node NodeID, from *net.UDPAddr) (n *Node) {
|
||||||
|
b := tab.buckets[logdist(tab.self.ID, node)]
|
||||||
|
if n = b.bump(node); n == nil {
|
||||||
|
n = newNode(node, from)
|
||||||
|
if len(b.entries) == bucketSize {
|
||||||
|
tab.pingReplace(n, b)
|
||||||
|
} else {
|
||||||
|
b.entries = append(b.entries, n)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return n
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tab *Table) pingReplace(n *Node, b *bucket) {
|
||||||
|
old := b.entries[bucketSize-1]
|
||||||
|
go func() {
|
||||||
|
if err := tab.net.ping(old); err == nil {
|
||||||
|
// it responded, we don't need to replace it.
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// it didn't respond, replace the node if it is still the oldest node.
|
||||||
|
tab.mutex.Lock()
|
||||||
|
if len(b.entries) > 0 && b.entries[len(b.entries)-1] == old {
|
||||||
|
// slide down other entries and put the new one in front.
|
||||||
|
// TODO: insert in correct position to keep the order
|
||||||
|
copy(b.entries[1:], b.entries)
|
||||||
|
b.entries[0] = n
|
||||||
|
}
|
||||||
|
tab.mutex.Unlock()
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
// bump updates the activity timestamp for the given node.
|
||||||
|
// The caller must hold tab.mutex.
|
||||||
|
func (tab *Table) bump(node NodeID) {
|
||||||
|
tab.buckets[logdist(tab.self.ID, node)].bump(node)
|
||||||
|
}
|
||||||
|
|
||||||
|
// add puts the entries into the table if their corresponding
|
||||||
|
// bucket is not full. The caller must hold tab.mutex.
|
||||||
|
func (tab *Table) add(entries []*Node) {
|
||||||
|
outer:
|
||||||
|
for _, n := range entries {
|
||||||
|
if n == nil || n.ID == tab.self.ID {
|
||||||
|
// skip bad entries. The RLP decoder returns nil for empty
|
||||||
|
// input lists.
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
bucket := tab.buckets[logdist(tab.self.ID, n.ID)]
|
||||||
|
for i := range bucket.entries {
|
||||||
|
if bucket.entries[i].ID == n.ID {
|
||||||
|
// already in bucket
|
||||||
|
continue outer
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(bucket.entries) < bucketSize {
|
||||||
|
bucket.entries = append(bucket.entries, n)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *bucket) bump(id NodeID) *Node {
|
||||||
|
for i, n := range b.entries {
|
||||||
|
if n.ID == id {
|
||||||
|
n.active = time.Now()
|
||||||
|
// move it to the front
|
||||||
|
copy(b.entries[1:], b.entries[:i+1])
|
||||||
|
b.entries[0] = n
|
||||||
|
return n
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// nodesByDistance is a list of nodes, ordered by
|
||||||
|
// distance to target.
|
||||||
|
type nodesByDistance struct {
|
||||||
|
entries []*Node
|
||||||
|
target NodeID
|
||||||
|
}
|
||||||
|
|
||||||
|
// push adds the given node to the list, keeping the total size below maxElems.
|
||||||
|
func (h *nodesByDistance) push(n *Node, maxElems int) {
|
||||||
|
ix := sort.Search(len(h.entries), func(i int) bool {
|
||||||
|
return distcmp(h.target, h.entries[i].ID, n.ID) > 0
|
||||||
|
})
|
||||||
|
if len(h.entries) < maxElems {
|
||||||
|
h.entries = append(h.entries, n)
|
||||||
|
}
|
||||||
|
if ix == len(h.entries) {
|
||||||
|
// farther away than all nodes we already have.
|
||||||
|
// if there was room for it, the node is now the last element.
|
||||||
|
} else {
|
||||||
|
// slide existing entries down to make room
|
||||||
|
// this will overwrite the entry we just appended.
|
||||||
|
copy(h.entries[ix+1:], h.entries[ix:])
|
||||||
|
h.entries[ix] = n
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,311 @@
|
||||||
|
package discover
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/ecdsa"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"math/rand"
|
||||||
|
"net"
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
"testing/quick"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/ethereum/go-ethereum/crypto"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestTable_bumpOrAddBucketAssign(t *testing.T) {
|
||||||
|
tab := newTable(nil, NodeID{}, &net.UDPAddr{})
|
||||||
|
for i := 1; i < len(tab.buckets); i++ {
|
||||||
|
tab.bumpOrAdd(randomID(tab.self.ID, i), &net.UDPAddr{})
|
||||||
|
}
|
||||||
|
for i, b := range tab.buckets {
|
||||||
|
if i > 0 && len(b.entries) != 1 {
|
||||||
|
t.Errorf("bucket %d has %d entries, want 1", i, len(b.entries))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTable_bumpOrAddPingReplace(t *testing.T) {
|
||||||
|
pingC := make(pingC)
|
||||||
|
tab := newTable(pingC, NodeID{}, &net.UDPAddr{})
|
||||||
|
last := fillBucket(tab, 200)
|
||||||
|
|
||||||
|
// this bumpOrAdd should not replace the last node
|
||||||
|
// because the node replies to ping.
|
||||||
|
new := tab.bumpOrAdd(randomID(tab.self.ID, 200), &net.UDPAddr{})
|
||||||
|
|
||||||
|
pinged := <-pingC
|
||||||
|
if pinged != last.ID {
|
||||||
|
t.Fatalf("pinged wrong node: %v\nwant %v", pinged, last.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
tab.mutex.Lock()
|
||||||
|
defer tab.mutex.Unlock()
|
||||||
|
if l := len(tab.buckets[200].entries); l != bucketSize {
|
||||||
|
t.Errorf("wrong bucket size after bumpOrAdd: got %d, want %d", bucketSize, l)
|
||||||
|
}
|
||||||
|
if !contains(tab.buckets[200].entries, last.ID) {
|
||||||
|
t.Error("last entry was removed")
|
||||||
|
}
|
||||||
|
if contains(tab.buckets[200].entries, new.ID) {
|
||||||
|
t.Error("new entry was added")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTable_bumpOrAddPingTimeout(t *testing.T) {
|
||||||
|
tab := newTable(pingC(nil), NodeID{}, &net.UDPAddr{})
|
||||||
|
last := fillBucket(tab, 200)
|
||||||
|
|
||||||
|
// this bumpOrAdd should replace the last node
|
||||||
|
// because the node does not reply to ping.
|
||||||
|
new := tab.bumpOrAdd(randomID(tab.self.ID, 200), &net.UDPAddr{})
|
||||||
|
|
||||||
|
// wait for async bucket update. damn. this needs to go away.
|
||||||
|
time.Sleep(2 * time.Millisecond)
|
||||||
|
|
||||||
|
tab.mutex.Lock()
|
||||||
|
defer tab.mutex.Unlock()
|
||||||
|
if l := len(tab.buckets[200].entries); l != bucketSize {
|
||||||
|
t.Errorf("wrong bucket size after bumpOrAdd: got %d, want %d", bucketSize, l)
|
||||||
|
}
|
||||||
|
if contains(tab.buckets[200].entries, last.ID) {
|
||||||
|
t.Error("last entry was not removed")
|
||||||
|
}
|
||||||
|
if !contains(tab.buckets[200].entries, new.ID) {
|
||||||
|
t.Error("new entry was not added")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func fillBucket(tab *Table, ld int) (last *Node) {
|
||||||
|
b := tab.buckets[ld]
|
||||||
|
for len(b.entries) < bucketSize {
|
||||||
|
b.entries = append(b.entries, &Node{ID: randomID(tab.self.ID, ld)})
|
||||||
|
}
|
||||||
|
return b.entries[bucketSize-1]
|
||||||
|
}
|
||||||
|
|
||||||
|
type pingC chan NodeID
|
||||||
|
|
||||||
|
func (t pingC) findnode(n *Node, target NodeID) ([]*Node, error) {
|
||||||
|
panic("findnode called on pingRecorder")
|
||||||
|
}
|
||||||
|
func (t pingC) close() {
|
||||||
|
panic("close called on pingRecorder")
|
||||||
|
}
|
||||||
|
func (t pingC) ping(n *Node) error {
|
||||||
|
if t == nil {
|
||||||
|
return errTimeout
|
||||||
|
}
|
||||||
|
t <- n.ID
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTable_bump(t *testing.T) {
|
||||||
|
tab := newTable(nil, NodeID{}, &net.UDPAddr{})
|
||||||
|
|
||||||
|
// add an old entry and two recent ones
|
||||||
|
oldactive := time.Now().Add(-2 * time.Minute)
|
||||||
|
old := &Node{ID: randomID(tab.self.ID, 200), active: oldactive}
|
||||||
|
others := []*Node{
|
||||||
|
&Node{ID: randomID(tab.self.ID, 200), active: time.Now()},
|
||||||
|
&Node{ID: randomID(tab.self.ID, 200), active: time.Now()},
|
||||||
|
}
|
||||||
|
tab.add(append(others, old))
|
||||||
|
if tab.buckets[200].entries[0] == old {
|
||||||
|
t.Fatal("old entry is at front of bucket")
|
||||||
|
}
|
||||||
|
|
||||||
|
// bumping the old entry should move it to the front
|
||||||
|
tab.bump(old.ID)
|
||||||
|
if old.active == oldactive {
|
||||||
|
t.Error("activity timestamp not updated")
|
||||||
|
}
|
||||||
|
if tab.buckets[200].entries[0] != old {
|
||||||
|
t.Errorf("bumped entry did not move to the front of bucket")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTable_closest(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
test := func(test *closeTest) bool {
|
||||||
|
// for any node table, Target and N
|
||||||
|
tab := newTable(nil, test.Self, &net.UDPAddr{})
|
||||||
|
tab.add(test.All)
|
||||||
|
|
||||||
|
// check that doClosest(Target, N) returns nodes
|
||||||
|
result := tab.closest(test.Target, test.N).entries
|
||||||
|
if hasDuplicates(result) {
|
||||||
|
t.Errorf("result contains duplicates")
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if !sortedByDistanceTo(test.Target, result) {
|
||||||
|
t.Errorf("result is not sorted by distance to target")
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// check that the number of results is min(N, tablen)
|
||||||
|
wantN := test.N
|
||||||
|
if tlen := tab.len(); tlen < test.N {
|
||||||
|
wantN = tlen
|
||||||
|
}
|
||||||
|
if len(result) != wantN {
|
||||||
|
t.Errorf("wrong number of nodes: got %d, want %d", len(result), wantN)
|
||||||
|
return false
|
||||||
|
} else if len(result) == 0 {
|
||||||
|
return true // no need to check distance
|
||||||
|
}
|
||||||
|
|
||||||
|
// check that the result nodes have minimum distance to target.
|
||||||
|
for _, b := range tab.buckets {
|
||||||
|
for _, n := range b.entries {
|
||||||
|
if contains(result, n.ID) {
|
||||||
|
continue // don't run the check below for nodes in result
|
||||||
|
}
|
||||||
|
farthestResult := result[len(result)-1].ID
|
||||||
|
if distcmp(test.Target, n.ID, farthestResult) < 0 {
|
||||||
|
t.Errorf("table contains node that is closer to target but it's not in result")
|
||||||
|
t.Logf(" Target: %v", test.Target)
|
||||||
|
t.Logf(" Farthest Result: %v", farthestResult)
|
||||||
|
t.Logf(" ID: %v", n.ID)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if err := quick.Check(test, quickcfg); err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type closeTest struct {
|
||||||
|
Self NodeID
|
||||||
|
Target NodeID
|
||||||
|
All []*Node
|
||||||
|
N int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (*closeTest) Generate(rand *rand.Rand, size int) reflect.Value {
|
||||||
|
t := &closeTest{
|
||||||
|
Self: gen(NodeID{}, rand).(NodeID),
|
||||||
|
Target: gen(NodeID{}, rand).(NodeID),
|
||||||
|
N: rand.Intn(bucketSize),
|
||||||
|
}
|
||||||
|
for _, id := range gen([]NodeID{}, rand).([]NodeID) {
|
||||||
|
t.All = append(t.All, &Node{ID: id})
|
||||||
|
}
|
||||||
|
return reflect.ValueOf(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTable_Lookup(t *testing.T) {
|
||||||
|
self := gen(NodeID{}, quickrand).(NodeID)
|
||||||
|
target := randomID(self, 200)
|
||||||
|
transport := findnodeOracle{t, target}
|
||||||
|
tab := newTable(transport, self, &net.UDPAddr{})
|
||||||
|
|
||||||
|
// lookup on empty table returns no nodes
|
||||||
|
if results := tab.Lookup(target); len(results) > 0 {
|
||||||
|
t.Fatalf("lookup on empty table returned %d results: %#v", len(results), results)
|
||||||
|
}
|
||||||
|
// seed table with initial node (otherwise lookup will terminate immediately)
|
||||||
|
tab.bumpOrAdd(randomID(target, 200), &net.UDPAddr{Port: 200})
|
||||||
|
|
||||||
|
results := tab.Lookup(target)
|
||||||
|
t.Logf("results:")
|
||||||
|
for _, e := range results {
|
||||||
|
t.Logf(" ld=%d, %v", logdist(target, e.ID), e.ID)
|
||||||
|
}
|
||||||
|
if len(results) != bucketSize {
|
||||||
|
t.Errorf("wrong number of results: got %d, want %d", len(results), bucketSize)
|
||||||
|
}
|
||||||
|
if hasDuplicates(results) {
|
||||||
|
t.Errorf("result set contains duplicate entries")
|
||||||
|
}
|
||||||
|
if !sortedByDistanceTo(target, results) {
|
||||||
|
t.Errorf("result set not sorted by distance to target")
|
||||||
|
}
|
||||||
|
if !contains(results, target) {
|
||||||
|
t.Errorf("result set does not contain target")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// findnode on this transport always returns at least one node
|
||||||
|
// that is one bucket closer to the target.
|
||||||
|
type findnodeOracle struct {
|
||||||
|
t *testing.T
|
||||||
|
target NodeID
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t findnodeOracle) findnode(n *Node, target NodeID) ([]*Node, error) {
|
||||||
|
t.t.Logf("findnode query at dist %d", n.DiscPort)
|
||||||
|
// current log distance is encoded in port number
|
||||||
|
var result []*Node
|
||||||
|
switch n.DiscPort {
|
||||||
|
case 0:
|
||||||
|
panic("query to node at distance 0")
|
||||||
|
default:
|
||||||
|
// TODO: add more randomness to distances
|
||||||
|
next := n.DiscPort - 1
|
||||||
|
for i := 0; i < bucketSize; i++ {
|
||||||
|
result = append(result, &Node{ID: randomID(t.target, next), DiscPort: next})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t findnodeOracle) close() {}
|
||||||
|
|
||||||
|
func (t findnodeOracle) ping(n *Node) error {
|
||||||
|
return errors.New("ping is not supported by this transport")
|
||||||
|
}
|
||||||
|
|
||||||
|
func hasDuplicates(slice []*Node) bool {
|
||||||
|
seen := make(map[NodeID]bool)
|
||||||
|
for _, e := range slice {
|
||||||
|
if seen[e.ID] {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
seen[e.ID] = true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func sortedByDistanceTo(distbase NodeID, slice []*Node) bool {
|
||||||
|
var last NodeID
|
||||||
|
for i, e := range slice {
|
||||||
|
if i > 0 && distcmp(distbase, e.ID, last) < 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
last = e.ID
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func contains(ns []*Node, id NodeID) bool {
|
||||||
|
for _, n := range ns {
|
||||||
|
if n.ID == id {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// gen wraps quick.Value so it's easier to use.
|
||||||
|
// it generates a random value of the given value's type.
|
||||||
|
func gen(typ interface{}, rand *rand.Rand) interface{} {
|
||||||
|
v, ok := quick.Value(reflect.TypeOf(typ), rand)
|
||||||
|
if !ok {
|
||||||
|
panic(fmt.Sprintf("couldn't generate random value of type %T", typ))
|
||||||
|
}
|
||||||
|
return v.Interface()
|
||||||
|
}
|
||||||
|
|
||||||
|
func newkey() *ecdsa.PrivateKey {
|
||||||
|
key, err := crypto.GenerateKey()
|
||||||
|
if err != nil {
|
||||||
|
panic("couldn't generate key: " + err.Error())
|
||||||
|
}
|
||||||
|
return key
|
||||||
|
}
|
|
@ -0,0 +1,431 @@
|
||||||
|
package discover
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"crypto/ecdsa"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/ethereum/go-ethereum/crypto"
|
||||||
|
"github.com/ethereum/go-ethereum/logger"
|
||||||
|
"github.com/ethereum/go-ethereum/p2p/nat"
|
||||||
|
"github.com/ethereum/go-ethereum/rlp"
|
||||||
|
)
|
||||||
|
|
||||||
|
var log = logger.NewLogger("P2P Discovery")
|
||||||
|
|
||||||
|
// Errors
|
||||||
|
var (
|
||||||
|
errPacketTooSmall = errors.New("too small")
|
||||||
|
errBadHash = errors.New("bad hash")
|
||||||
|
errExpired = errors.New("expired")
|
||||||
|
errTimeout = errors.New("RPC timeout")
|
||||||
|
errClosed = errors.New("socket closed")
|
||||||
|
)
|
||||||
|
|
||||||
|
// Timeouts
|
||||||
|
const (
|
||||||
|
respTimeout = 300 * time.Millisecond
|
||||||
|
sendTimeout = 300 * time.Millisecond
|
||||||
|
expiration = 20 * time.Second
|
||||||
|
|
||||||
|
refreshInterval = 1 * time.Hour
|
||||||
|
)
|
||||||
|
|
||||||
|
// RPC packet types
|
||||||
|
const (
|
||||||
|
pingPacket = iota + 1 // zero is 'reserved'
|
||||||
|
pongPacket
|
||||||
|
findnodePacket
|
||||||
|
neighborsPacket
|
||||||
|
)
|
||||||
|
|
||||||
|
// RPC request structures
|
||||||
|
type (
|
||||||
|
ping struct {
|
||||||
|
IP string // our IP
|
||||||
|
Port uint16 // our port
|
||||||
|
Expiration uint64
|
||||||
|
}
|
||||||
|
|
||||||
|
// reply to Ping
|
||||||
|
pong struct {
|
||||||
|
ReplyTok []byte
|
||||||
|
Expiration uint64
|
||||||
|
}
|
||||||
|
|
||||||
|
findnode struct {
|
||||||
|
// Id to look up. The responding node will send back nodes
|
||||||
|
// closest to the target.
|
||||||
|
Target NodeID
|
||||||
|
Expiration uint64
|
||||||
|
}
|
||||||
|
|
||||||
|
// reply to findnode
|
||||||
|
neighbors struct {
|
||||||
|
Nodes []*Node
|
||||||
|
Expiration uint64
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
type rpcNode struct {
|
||||||
|
IP string
|
||||||
|
Port uint16
|
||||||
|
ID NodeID
|
||||||
|
}
|
||||||
|
|
||||||
|
// udp implements the RPC protocol.
|
||||||
|
type udp struct {
|
||||||
|
conn *net.UDPConn
|
||||||
|
priv *ecdsa.PrivateKey
|
||||||
|
addpending chan *pending
|
||||||
|
replies chan reply
|
||||||
|
closing chan struct{}
|
||||||
|
nat nat.Interface
|
||||||
|
|
||||||
|
*Table
|
||||||
|
}
|
||||||
|
|
||||||
|
// pending represents a pending reply.
|
||||||
|
//
|
||||||
|
// some implementations of the protocol wish to send more than one
|
||||||
|
// reply packet to findnode. in general, any neighbors packet cannot
|
||||||
|
// be matched up with a specific findnode packet.
|
||||||
|
//
|
||||||
|
// our implementation handles this by storing a callback function for
|
||||||
|
// each pending reply. incoming packets from a node are dispatched
|
||||||
|
// to all the callback functions for that node.
|
||||||
|
type pending struct {
|
||||||
|
// these fields must match in the reply.
|
||||||
|
from NodeID
|
||||||
|
ptype byte
|
||||||
|
|
||||||
|
// time when the request must complete
|
||||||
|
deadline time.Time
|
||||||
|
|
||||||
|
// callback is called when a matching reply arrives. if it returns
|
||||||
|
// true, the callback is removed from the pending reply queue.
|
||||||
|
// if it returns false, the reply is considered incomplete and
|
||||||
|
// the callback will be invoked again for the next matching reply.
|
||||||
|
callback func(resp interface{}) (done bool)
|
||||||
|
|
||||||
|
// errc receives nil when the callback indicates completion or an
|
||||||
|
// error if no further reply is received within the timeout.
|
||||||
|
errc chan<- error
|
||||||
|
}
|
||||||
|
|
||||||
|
type reply struct {
|
||||||
|
from NodeID
|
||||||
|
ptype byte
|
||||||
|
data interface{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListenUDP returns a new table that listens for UDP packets on laddr.
|
||||||
|
func ListenUDP(priv *ecdsa.PrivateKey, laddr string, natm nat.Interface) (*Table, error) {
|
||||||
|
addr, err := net.ResolveUDPAddr("udp", laddr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
conn, err := net.ListenUDP("udp", addr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
udp := &udp{
|
||||||
|
conn: conn,
|
||||||
|
priv: priv,
|
||||||
|
closing: make(chan struct{}),
|
||||||
|
addpending: make(chan *pending),
|
||||||
|
replies: make(chan reply),
|
||||||
|
}
|
||||||
|
|
||||||
|
realaddr := conn.LocalAddr().(*net.UDPAddr)
|
||||||
|
if natm != nil {
|
||||||
|
if !realaddr.IP.IsLoopback() {
|
||||||
|
go nat.Map(natm, udp.closing, "udp", realaddr.Port, realaddr.Port, "ethereum discovery")
|
||||||
|
}
|
||||||
|
// TODO: react to external IP changes over time.
|
||||||
|
if ext, err := natm.ExternalIP(); err == nil {
|
||||||
|
realaddr = &net.UDPAddr{IP: ext, Port: realaddr.Port}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
udp.Table = newTable(udp, PubkeyID(&priv.PublicKey), realaddr)
|
||||||
|
|
||||||
|
go udp.loop()
|
||||||
|
go udp.readLoop()
|
||||||
|
log.Infoln("Listening, ", udp.self)
|
||||||
|
return udp.Table, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *udp) close() {
|
||||||
|
close(t.closing)
|
||||||
|
t.conn.Close()
|
||||||
|
// TODO: wait for the loops to end.
|
||||||
|
}
|
||||||
|
|
||||||
|
// ping sends a ping message to the given node and waits for a reply.
|
||||||
|
func (t *udp) ping(e *Node) error {
|
||||||
|
// TODO: maybe check for ReplyTo field in callback to measure RTT
|
||||||
|
errc := t.pending(e.ID, pongPacket, func(interface{}) bool { return true })
|
||||||
|
t.send(e, pingPacket, ping{
|
||||||
|
IP: t.self.IP.String(),
|
||||||
|
Port: uint16(t.self.TCPPort),
|
||||||
|
Expiration: uint64(time.Now().Add(expiration).Unix()),
|
||||||
|
})
|
||||||
|
return <-errc
|
||||||
|
}
|
||||||
|
|
||||||
|
// findnode sends a findnode request to the given node and waits until
|
||||||
|
// the node has sent up to k neighbors.
|
||||||
|
func (t *udp) findnode(to *Node, target NodeID) ([]*Node, error) {
|
||||||
|
nodes := make([]*Node, 0, bucketSize)
|
||||||
|
nreceived := 0
|
||||||
|
errc := t.pending(to.ID, neighborsPacket, func(r interface{}) bool {
|
||||||
|
reply := r.(*neighbors)
|
||||||
|
for _, n := range reply.Nodes {
|
||||||
|
nreceived++
|
||||||
|
if n.isValid() {
|
||||||
|
nodes = append(nodes, n)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nreceived >= bucketSize
|
||||||
|
})
|
||||||
|
|
||||||
|
t.send(to, findnodePacket, findnode{
|
||||||
|
Target: target,
|
||||||
|
Expiration: uint64(time.Now().Add(expiration).Unix()),
|
||||||
|
})
|
||||||
|
err := <-errc
|
||||||
|
return nodes, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// pending adds a reply callback to the pending reply queue.
|
||||||
|
// see the documentation of type pending for a detailed explanation.
|
||||||
|
func (t *udp) pending(id NodeID, ptype byte, callback func(interface{}) bool) <-chan error {
|
||||||
|
ch := make(chan error, 1)
|
||||||
|
p := &pending{from: id, ptype: ptype, callback: callback, errc: ch}
|
||||||
|
select {
|
||||||
|
case t.addpending <- p:
|
||||||
|
// loop will handle it
|
||||||
|
case <-t.closing:
|
||||||
|
ch <- errClosed
|
||||||
|
}
|
||||||
|
return ch
|
||||||
|
}
|
||||||
|
|
||||||
|
// loop runs in its own goroutin. it keeps track of
|
||||||
|
// the refresh timer and the pending reply queue.
|
||||||
|
func (t *udp) loop() {
|
||||||
|
var (
|
||||||
|
pending []*pending
|
||||||
|
nextDeadline time.Time
|
||||||
|
timeout = time.NewTimer(0)
|
||||||
|
refresh = time.NewTicker(refreshInterval)
|
||||||
|
)
|
||||||
|
<-timeout.C // ignore first timeout
|
||||||
|
defer refresh.Stop()
|
||||||
|
defer timeout.Stop()
|
||||||
|
|
||||||
|
rearmTimeout := func() {
|
||||||
|
if len(pending) == 0 || nextDeadline == pending[0].deadline {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
nextDeadline = pending[0].deadline
|
||||||
|
timeout.Reset(nextDeadline.Sub(time.Now()))
|
||||||
|
}
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-refresh.C:
|
||||||
|
go t.refresh()
|
||||||
|
|
||||||
|
case <-t.closing:
|
||||||
|
for _, p := range pending {
|
||||||
|
p.errc <- errClosed
|
||||||
|
}
|
||||||
|
return
|
||||||
|
|
||||||
|
case p := <-t.addpending:
|
||||||
|
p.deadline = time.Now().Add(respTimeout)
|
||||||
|
pending = append(pending, p)
|
||||||
|
rearmTimeout()
|
||||||
|
|
||||||
|
case reply := <-t.replies:
|
||||||
|
// run matching callbacks, remove if they return false.
|
||||||
|
for i, p := range pending {
|
||||||
|
if reply.from == p.from && reply.ptype == p.ptype && p.callback(reply.data) {
|
||||||
|
p.errc <- nil
|
||||||
|
copy(pending[i:], pending[i+1:])
|
||||||
|
pending = pending[:len(pending)-1]
|
||||||
|
i--
|
||||||
|
}
|
||||||
|
}
|
||||||
|
rearmTimeout()
|
||||||
|
|
||||||
|
case now := <-timeout.C:
|
||||||
|
// notify and remove callbacks whose deadline is in the past.
|
||||||
|
i := 0
|
||||||
|
for ; i < len(pending) && now.After(pending[i].deadline); i++ {
|
||||||
|
pending[i].errc <- errTimeout
|
||||||
|
}
|
||||||
|
if i > 0 {
|
||||||
|
copy(pending, pending[i:])
|
||||||
|
pending = pending[:len(pending)-i]
|
||||||
|
}
|
||||||
|
rearmTimeout()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
macSize = 256 / 8
|
||||||
|
sigSize = 520 / 8
|
||||||
|
headSize = macSize + sigSize // space of packet frame data
|
||||||
|
)
|
||||||
|
|
||||||
|
var headSpace = make([]byte, headSize)
|
||||||
|
|
||||||
|
func (t *udp) send(to *Node, ptype byte, req interface{}) error {
|
||||||
|
b := new(bytes.Buffer)
|
||||||
|
b.Write(headSpace)
|
||||||
|
b.WriteByte(ptype)
|
||||||
|
if err := rlp.Encode(b, req); err != nil {
|
||||||
|
log.Errorln("error encoding packet:", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
packet := b.Bytes()
|
||||||
|
sig, err := crypto.Sign(crypto.Sha3(packet[headSize:]), t.priv)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorln("could not sign packet:", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
copy(packet[macSize:], sig)
|
||||||
|
// add the hash to the front. Note: this doesn't protect the
|
||||||
|
// packet in any way. Our public key will be part of this hash in
|
||||||
|
// the future.
|
||||||
|
copy(packet, crypto.Sha3(packet[macSize:]))
|
||||||
|
|
||||||
|
toaddr := &net.UDPAddr{IP: to.IP, Port: to.DiscPort}
|
||||||
|
log.DebugDetailf(">>> %v %T %v\n", toaddr, req, req)
|
||||||
|
if _, err = t.conn.WriteToUDP(packet, toaddr); err != nil {
|
||||||
|
log.DebugDetailln("UDP send failed:", err)
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// readLoop runs in its own goroutine. it handles incoming UDP packets.
|
||||||
|
func (t *udp) readLoop() {
|
||||||
|
defer t.conn.Close()
|
||||||
|
buf := make([]byte, 4096) // TODO: good buffer size
|
||||||
|
for {
|
||||||
|
nbytes, from, err := t.conn.ReadFromUDP(buf)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := t.packetIn(from, buf[:nbytes]); err != nil {
|
||||||
|
log.Debugf("Bad packet from %v: %v\n", from, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *udp) packetIn(from *net.UDPAddr, buf []byte) error {
|
||||||
|
if len(buf) < headSize+1 {
|
||||||
|
return errPacketTooSmall
|
||||||
|
}
|
||||||
|
hash, sig, sigdata := buf[:macSize], buf[macSize:headSize], buf[headSize:]
|
||||||
|
shouldhash := crypto.Sha3(buf[macSize:])
|
||||||
|
if !bytes.Equal(hash, shouldhash) {
|
||||||
|
return errBadHash
|
||||||
|
}
|
||||||
|
fromID, err := recoverNodeID(crypto.Sha3(buf[headSize:]), sig)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
var req interface {
|
||||||
|
handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error
|
||||||
|
}
|
||||||
|
switch ptype := sigdata[0]; ptype {
|
||||||
|
case pingPacket:
|
||||||
|
req = new(ping)
|
||||||
|
case pongPacket:
|
||||||
|
req = new(pong)
|
||||||
|
case findnodePacket:
|
||||||
|
req = new(findnode)
|
||||||
|
case neighborsPacket:
|
||||||
|
req = new(neighbors)
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("unknown type: %d", ptype)
|
||||||
|
}
|
||||||
|
if err := rlp.Decode(bytes.NewReader(sigdata[1:]), req); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
log.DebugDetailf("<<< %v %T %v\n", from, req, req)
|
||||||
|
return req.handle(t, from, fromID, hash)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (req *ping) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error {
|
||||||
|
if expired(req.Expiration) {
|
||||||
|
return errExpired
|
||||||
|
}
|
||||||
|
t.mutex.Lock()
|
||||||
|
// Note: we're ignoring the provided IP address right now
|
||||||
|
n := t.bumpOrAdd(fromID, from)
|
||||||
|
if req.Port != 0 {
|
||||||
|
n.TCPPort = int(req.Port)
|
||||||
|
}
|
||||||
|
t.mutex.Unlock()
|
||||||
|
|
||||||
|
t.send(n, pongPacket, pong{
|
||||||
|
ReplyTok: mac,
|
||||||
|
Expiration: uint64(time.Now().Add(expiration).Unix()),
|
||||||
|
})
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (req *pong) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error {
|
||||||
|
if expired(req.Expiration) {
|
||||||
|
return errExpired
|
||||||
|
}
|
||||||
|
t.mutex.Lock()
|
||||||
|
t.bump(fromID)
|
||||||
|
t.mutex.Unlock()
|
||||||
|
|
||||||
|
t.replies <- reply{fromID, pongPacket, req}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (req *findnode) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error {
|
||||||
|
if expired(req.Expiration) {
|
||||||
|
return errExpired
|
||||||
|
}
|
||||||
|
t.mutex.Lock()
|
||||||
|
e := t.bumpOrAdd(fromID, from)
|
||||||
|
closest := t.closest(req.Target, bucketSize).entries
|
||||||
|
t.mutex.Unlock()
|
||||||
|
|
||||||
|
t.send(e, neighborsPacket, neighbors{
|
||||||
|
Nodes: closest,
|
||||||
|
Expiration: uint64(time.Now().Add(expiration).Unix()),
|
||||||
|
})
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (req *neighbors) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error {
|
||||||
|
if expired(req.Expiration) {
|
||||||
|
return errExpired
|
||||||
|
}
|
||||||
|
t.mutex.Lock()
|
||||||
|
t.bump(fromID)
|
||||||
|
t.add(req.Nodes)
|
||||||
|
t.mutex.Unlock()
|
||||||
|
|
||||||
|
t.replies <- reply{fromID, neighborsPacket, req}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func expired(ts uint64) bool {
|
||||||
|
return time.Unix(int64(ts), 0).Before(time.Now())
|
||||||
|
}
|
|
@ -0,0 +1,211 @@
|
||||||
|
package discover
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
logpkg "log"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/ethereum/go-ethereum/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
logger.AddLogSystem(logger.NewStdLogSystem(os.Stdout, logpkg.LstdFlags, logger.ErrorLevel))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUDP_ping(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
n1, _ := ListenUDP(newkey(), "127.0.0.1:0", nil)
|
||||||
|
n2, _ := ListenUDP(newkey(), "127.0.0.1:0", nil)
|
||||||
|
defer n1.Close()
|
||||||
|
defer n2.Close()
|
||||||
|
|
||||||
|
if err := n1.net.ping(n2.self); err != nil {
|
||||||
|
t.Fatalf("ping error: %v", err)
|
||||||
|
}
|
||||||
|
if find(n2, n1.self.ID) == nil {
|
||||||
|
t.Errorf("node 2 does not contain id of node 1")
|
||||||
|
}
|
||||||
|
if e := find(n1, n2.self.ID); e != nil {
|
||||||
|
t.Errorf("node 1 does contains id of node 2: %v", e)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func find(tab *Table, id NodeID) *Node {
|
||||||
|
for _, b := range tab.buckets {
|
||||||
|
for _, e := range b.entries {
|
||||||
|
if e.ID == id {
|
||||||
|
return e
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUDP_findnode(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
n1, _ := ListenUDP(newkey(), "127.0.0.1:0", nil)
|
||||||
|
n2, _ := ListenUDP(newkey(), "127.0.0.1:0", nil)
|
||||||
|
defer n1.Close()
|
||||||
|
defer n2.Close()
|
||||||
|
|
||||||
|
// put a few nodes into n2. the exact distribution shouldn't
|
||||||
|
// matter much, altough we need to take care not to overflow
|
||||||
|
// any bucket.
|
||||||
|
target := randomID(n1.self.ID, 100)
|
||||||
|
nodes := &nodesByDistance{target: target}
|
||||||
|
for i := 0; i < bucketSize; i++ {
|
||||||
|
n2.add([]*Node{&Node{
|
||||||
|
IP: net.IP{1, 2, 3, byte(i)},
|
||||||
|
DiscPort: i + 2,
|
||||||
|
TCPPort: i + 2,
|
||||||
|
ID: randomID(n2.self.ID, i+2),
|
||||||
|
}})
|
||||||
|
}
|
||||||
|
n2.add(nodes.entries)
|
||||||
|
n2.bumpOrAdd(n1.self.ID, &net.UDPAddr{IP: n1.self.IP, Port: n1.self.DiscPort})
|
||||||
|
expected := n2.closest(target, bucketSize)
|
||||||
|
|
||||||
|
err := runUDP(10, func() error {
|
||||||
|
result, _ := n1.net.findnode(n2.self, target)
|
||||||
|
if len(result) != bucketSize {
|
||||||
|
return fmt.Errorf("wrong number of results: got %d, want %d", len(result), bucketSize)
|
||||||
|
}
|
||||||
|
for i := range result {
|
||||||
|
if result[i].ID != expected.entries[i].ID {
|
||||||
|
return fmt.Errorf("result mismatch at %d:\n got: %v\n want: %v", i, result[i], expected.entries[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUDP_replytimeout(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
// reserve a port so we don't talk to an existing service by accident
|
||||||
|
addr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:0")
|
||||||
|
fd, err := net.ListenUDP("udp", addr)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer fd.Close()
|
||||||
|
|
||||||
|
n1, _ := ListenUDP(newkey(), "127.0.0.1:0", nil)
|
||||||
|
defer n1.Close()
|
||||||
|
n2 := n1.bumpOrAdd(randomID(n1.self.ID, 10), fd.LocalAddr().(*net.UDPAddr))
|
||||||
|
|
||||||
|
if err := n1.net.ping(n2); err != errTimeout {
|
||||||
|
t.Error("expected timeout error, got", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if result, err := n1.net.findnode(n2, n1.self.ID); err != errTimeout {
|
||||||
|
t.Error("expected timeout error, got", err)
|
||||||
|
} else if len(result) > 0 {
|
||||||
|
t.Error("expected empty result, got", result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUDP_findnodeMultiReply(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
n1, _ := ListenUDP(newkey(), "127.0.0.1:0", nil)
|
||||||
|
n2, _ := ListenUDP(newkey(), "127.0.0.1:0", nil)
|
||||||
|
udp2 := n2.net.(*udp)
|
||||||
|
defer n1.Close()
|
||||||
|
defer n2.Close()
|
||||||
|
|
||||||
|
err := runUDP(10, func() error {
|
||||||
|
nodes := make([]*Node, bucketSize)
|
||||||
|
for i := range nodes {
|
||||||
|
nodes[i] = &Node{
|
||||||
|
IP: net.IP{1, 2, 3, 4},
|
||||||
|
DiscPort: i + 1,
|
||||||
|
TCPPort: i + 1,
|
||||||
|
ID: randomID(n2.self.ID, i+1),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ask N2 for neighbors. it will send an empty reply back.
|
||||||
|
// the request will wait for up to bucketSize replies.
|
||||||
|
resultc := make(chan []*Node)
|
||||||
|
errc := make(chan error)
|
||||||
|
go func() {
|
||||||
|
ns, err := n1.net.findnode(n2.self, n1.self.ID)
|
||||||
|
if err != nil {
|
||||||
|
errc <- err
|
||||||
|
} else {
|
||||||
|
resultc <- ns
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// send a few more neighbors packets to N1.
|
||||||
|
// it should collect those.
|
||||||
|
for end := 0; end < len(nodes); {
|
||||||
|
off := end
|
||||||
|
if end = end + 5; end > len(nodes) {
|
||||||
|
end = len(nodes)
|
||||||
|
}
|
||||||
|
udp2.send(n1.self, neighborsPacket, neighbors{
|
||||||
|
Nodes: nodes[off:end],
|
||||||
|
Expiration: uint64(time.Now().Add(10 * time.Second).Unix()),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// check that they are all returned. we cannot just check for
|
||||||
|
// equality because they might not be returned in the order they
|
||||||
|
// were sent.
|
||||||
|
var result []*Node
|
||||||
|
select {
|
||||||
|
case result = <-resultc:
|
||||||
|
case err := <-errc:
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if hasDuplicates(result) {
|
||||||
|
return fmt.Errorf("result slice contains duplicates")
|
||||||
|
}
|
||||||
|
if len(result) != len(nodes) {
|
||||||
|
return fmt.Errorf("wrong number of nodes returned: got %d, want %d", len(result), len(nodes))
|
||||||
|
}
|
||||||
|
matched := make(map[NodeID]bool)
|
||||||
|
for _, n := range result {
|
||||||
|
for _, expn := range nodes {
|
||||||
|
if n.ID == expn.ID { // && bytes.Equal(n.Addr.IP, expn.Addr.IP) && n.Addr.Port == expn.Addr.Port {
|
||||||
|
matched[n.ID] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(matched) != len(nodes) {
|
||||||
|
return fmt.Errorf("wrong number of matching nodes: got %d, want %d", len(matched), len(nodes))
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// runUDP runs a test n times and returns an error if the test failed
|
||||||
|
// in all n runs. This is necessary because UDP is unreliable even for
|
||||||
|
// connections on the local machine, causing test failures.
|
||||||
|
func runUDP(n int, test func() error) error {
|
||||||
|
errcount := 0
|
||||||
|
errors := ""
|
||||||
|
for i := 0; i < n; i++ {
|
||||||
|
if err := test(); err != nil {
|
||||||
|
errors += fmt.Sprintf("\n#%d: %v", i, err)
|
||||||
|
errcount++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if errcount == n {
|
||||||
|
return fmt.Errorf("failed on all %d iterations:%s", n, errors)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
143
p2p/message.go
143
p2p/message.go
|
@ -1,6 +1,7 @@
|
||||||
package p2p
|
package p2p
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bufio"
|
||||||
"bytes"
|
"bytes"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"errors"
|
"errors"
|
||||||
|
@ -8,12 +9,37 @@ import (
|
||||||
"io"
|
"io"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"math/big"
|
"math/big"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/ethereum/go-ethereum/ethutil"
|
"github.com/ethereum/go-ethereum/ethutil"
|
||||||
"github.com/ethereum/go-ethereum/rlp"
|
"github.com/ethereum/go-ethereum/rlp"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// parameters for frameRW
|
||||||
|
const (
|
||||||
|
// maximum time allowed for reading a message header.
|
||||||
|
// this is effectively the amount of time a connection can be idle.
|
||||||
|
frameReadTimeout = 1 * time.Minute
|
||||||
|
|
||||||
|
// maximum time allowed for reading the payload data of a message.
|
||||||
|
// this is shorter than (and distinct from) frameReadTimeout because
|
||||||
|
// the connection is not considered idle while a message is transferred.
|
||||||
|
// this also limits the payload size of messages to how much the connection
|
||||||
|
// can transfer within the timeout.
|
||||||
|
payloadReadTimeout = 5 * time.Second
|
||||||
|
|
||||||
|
// maximum amount of time allowed for writing a complete message.
|
||||||
|
msgWriteTimeout = 5 * time.Second
|
||||||
|
|
||||||
|
// messages smaller than this many bytes will be read at
|
||||||
|
// once before passing them to a protocol. this increases
|
||||||
|
// concurrency in the processing.
|
||||||
|
wholePayloadSize = 64 * 1024
|
||||||
|
)
|
||||||
|
|
||||||
// Msg defines the structure of a p2p message.
|
// Msg defines the structure of a p2p message.
|
||||||
//
|
//
|
||||||
// Note that a Msg can only be sent once since the Payload reader is
|
// Note that a Msg can only be sent once since the Payload reader is
|
||||||
|
@ -74,11 +100,14 @@ type MsgWriter interface {
|
||||||
// WriteMsg sends a message. It will block until the message's
|
// WriteMsg sends a message. It will block until the message's
|
||||||
// Payload has been consumed by the other end.
|
// Payload has been consumed by the other end.
|
||||||
//
|
//
|
||||||
// Note that messages can be sent only once.
|
// Note that messages can be sent only once because their
|
||||||
|
// payload reader is drained.
|
||||||
WriteMsg(Msg) error
|
WriteMsg(Msg) error
|
||||||
}
|
}
|
||||||
|
|
||||||
// MsgReadWriter provides reading and writing of encoded messages.
|
// MsgReadWriter provides reading and writing of encoded messages.
|
||||||
|
// Implementations should ensure that ReadMsg and WriteMsg can be
|
||||||
|
// called simultaneously from multiple goroutines.
|
||||||
type MsgReadWriter interface {
|
type MsgReadWriter interface {
|
||||||
MsgReader
|
MsgReader
|
||||||
MsgWriter
|
MsgWriter
|
||||||
|
@ -90,8 +119,45 @@ func EncodeMsg(w MsgWriter, code uint64, data ...interface{}) error {
|
||||||
return w.WriteMsg(NewMsg(code, data...))
|
return w.WriteMsg(NewMsg(code, data...))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// frameRW is a MsgReadWriter that reads and writes devp2p message frames.
|
||||||
|
// As required by the interface, ReadMsg and WriteMsg can be called from
|
||||||
|
// multiple goroutines.
|
||||||
|
type frameRW struct {
|
||||||
|
net.Conn // make Conn methods available. be careful.
|
||||||
|
bufconn *bufio.ReadWriter
|
||||||
|
|
||||||
|
// this channel is used to 'lend' bufconn to a caller of ReadMsg
|
||||||
|
// until the message payload has been consumed. the channel
|
||||||
|
// receives a value when EOF is reached on the payload, unblocking
|
||||||
|
// a pending call to ReadMsg.
|
||||||
|
rsync chan struct{}
|
||||||
|
|
||||||
|
// this mutex guards writes to bufconn.
|
||||||
|
writeMu sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
func newFrameRW(conn net.Conn, timeout time.Duration) *frameRW {
|
||||||
|
rsync := make(chan struct{}, 1)
|
||||||
|
rsync <- struct{}{}
|
||||||
|
return &frameRW{
|
||||||
|
Conn: conn,
|
||||||
|
bufconn: bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn)),
|
||||||
|
rsync: rsync,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
var magicToken = []byte{34, 64, 8, 145}
|
var magicToken = []byte{34, 64, 8, 145}
|
||||||
|
|
||||||
|
func (rw *frameRW) WriteMsg(msg Msg) error {
|
||||||
|
rw.writeMu.Lock()
|
||||||
|
defer rw.writeMu.Unlock()
|
||||||
|
rw.SetWriteDeadline(time.Now().Add(msgWriteTimeout))
|
||||||
|
if err := writeMsg(rw.bufconn, msg); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return rw.bufconn.Flush()
|
||||||
|
}
|
||||||
|
|
||||||
func writeMsg(w io.Writer, msg Msg) error {
|
func writeMsg(w io.Writer, msg Msg) error {
|
||||||
// TODO: handle case when Size + len(code) + len(listhdr) overflows uint32
|
// TODO: handle case when Size + len(code) + len(listhdr) overflows uint32
|
||||||
code := ethutil.Encode(uint32(msg.Code))
|
code := ethutil.Encode(uint32(msg.Code))
|
||||||
|
@ -120,31 +186,51 @@ func makeListHeader(length uint32) []byte {
|
||||||
return append([]byte{lenb}, enc...)
|
return append([]byte{lenb}, enc...)
|
||||||
}
|
}
|
||||||
|
|
||||||
// readMsg reads a message header from r.
|
func (rw *frameRW) ReadMsg() (msg Msg, err error) {
|
||||||
// It takes an rlp.ByteReader to ensure that the decoding doesn't buffer.
|
<-rw.rsync // wait until bufconn is ours
|
||||||
func readMsg(r rlp.ByteReader) (msg Msg, err error) {
|
|
||||||
|
rw.SetReadDeadline(time.Now().Add(frameReadTimeout))
|
||||||
|
|
||||||
// read magic and payload size
|
// read magic and payload size
|
||||||
start := make([]byte, 8)
|
start := make([]byte, 8)
|
||||||
if _, err = io.ReadFull(r, start); err != nil {
|
if _, err = io.ReadFull(rw.bufconn, start); err != nil {
|
||||||
return msg, newPeerError(errRead, "%v", err)
|
return msg, err
|
||||||
}
|
}
|
||||||
if !bytes.HasPrefix(start, magicToken) {
|
if !bytes.HasPrefix(start, magicToken) {
|
||||||
return msg, newPeerError(errMagicTokenMismatch, "got %x, want %x", start[:4], magicToken)
|
return msg, fmt.Errorf("bad magic token %x", start[:4], magicToken)
|
||||||
}
|
}
|
||||||
size := binary.BigEndian.Uint32(start[4:])
|
size := binary.BigEndian.Uint32(start[4:])
|
||||||
|
|
||||||
// decode start of RLP message to get the message code
|
// decode start of RLP message to get the message code
|
||||||
posr := &postrack{r, 0}
|
posr := &postrack{rw.bufconn, 0}
|
||||||
s := rlp.NewStream(posr)
|
s := rlp.NewStream(posr)
|
||||||
if _, err := s.List(); err != nil {
|
if _, err := s.List(); err != nil {
|
||||||
return msg, err
|
return msg, err
|
||||||
}
|
}
|
||||||
code, err := s.Uint()
|
msg.Code, err = s.Uint()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return msg, err
|
return msg, err
|
||||||
}
|
}
|
||||||
payloadsize := size - posr.p
|
msg.Size = size - posr.p
|
||||||
return Msg{code, payloadsize, io.LimitReader(r, int64(payloadsize))}, nil
|
|
||||||
|
rw.SetReadDeadline(time.Now().Add(payloadReadTimeout))
|
||||||
|
|
||||||
|
if msg.Size <= wholePayloadSize {
|
||||||
|
// msg is small, read all of it and move on to the next message.
|
||||||
|
pbuf := make([]byte, msg.Size)
|
||||||
|
if _, err := io.ReadFull(rw.bufconn, pbuf); err != nil {
|
||||||
|
return msg, err
|
||||||
|
}
|
||||||
|
rw.rsync <- struct{}{} // bufconn is available again
|
||||||
|
msg.Payload = bytes.NewReader(pbuf)
|
||||||
|
} else {
|
||||||
|
// lend bufconn to the caller until it has
|
||||||
|
// consumed the payload. eofSignal will send a value
|
||||||
|
// on rw.rsync when EOF is reached.
|
||||||
|
pr := &eofSignal{rw.bufconn, msg.Size, rw.rsync}
|
||||||
|
msg.Payload = pr
|
||||||
|
}
|
||||||
|
return msg, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// postrack wraps an rlp.ByteReader with a position counter.
|
// postrack wraps an rlp.ByteReader with a position counter.
|
||||||
|
@ -167,6 +253,39 @@ func (r *postrack) ReadByte() (byte, error) {
|
||||||
return b, err
|
return b, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// eofSignal wraps a reader with eof signaling. the eof channel is
|
||||||
|
// closed when the wrapped reader returns an error or when count bytes
|
||||||
|
// have been read.
|
||||||
|
type eofSignal struct {
|
||||||
|
wrapped io.Reader
|
||||||
|
count uint32 // number of bytes left
|
||||||
|
eof chan<- struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// note: when using eofSignal to detect whether a message payload
|
||||||
|
// has been read, Read might not be called for zero sized messages.
|
||||||
|
func (r *eofSignal) Read(buf []byte) (int, error) {
|
||||||
|
if r.count == 0 {
|
||||||
|
if r.eof != nil {
|
||||||
|
r.eof <- struct{}{}
|
||||||
|
r.eof = nil
|
||||||
|
}
|
||||||
|
return 0, io.EOF
|
||||||
|
}
|
||||||
|
|
||||||
|
max := len(buf)
|
||||||
|
if int(r.count) < len(buf) {
|
||||||
|
max = int(r.count)
|
||||||
|
}
|
||||||
|
n, err := r.wrapped.Read(buf[:max])
|
||||||
|
r.count -= uint32(n)
|
||||||
|
if (err != nil || r.count == 0) && r.eof != nil {
|
||||||
|
r.eof <- struct{}{} // tell Peer that msg has been consumed
|
||||||
|
r.eof = nil
|
||||||
|
}
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
|
||||||
// MsgPipe creates a message pipe. Reads on one end are matched
|
// MsgPipe creates a message pipe. Reads on one end are matched
|
||||||
// with writes on the other. The pipe is full-duplex, both ends
|
// with writes on the other. The pipe is full-duplex, both ends
|
||||||
// implement MsgReadWriter.
|
// implement MsgReadWriter.
|
||||||
|
@ -198,7 +317,7 @@ type MsgPipeRW struct {
|
||||||
func (p *MsgPipeRW) WriteMsg(msg Msg) error {
|
func (p *MsgPipeRW) WriteMsg(msg Msg) error {
|
||||||
if atomic.LoadInt32(p.closed) == 0 {
|
if atomic.LoadInt32(p.closed) == 0 {
|
||||||
consumed := make(chan struct{}, 1)
|
consumed := make(chan struct{}, 1)
|
||||||
msg.Payload = &eofSignal{msg.Payload, int64(msg.Size), consumed}
|
msg.Payload = &eofSignal{msg.Payload, msg.Size, consumed}
|
||||||
select {
|
select {
|
||||||
case p.w <- msg:
|
case p.w <- msg:
|
||||||
if msg.Size > 0 {
|
if msg.Size > 0 {
|
||||||
|
|
|
@ -3,12 +3,11 @@ package p2p
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"runtime"
|
"runtime"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/ethereum/go-ethereum/ethutil"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestNewMsg(t *testing.T) {
|
func TestNewMsg(t *testing.T) {
|
||||||
|
@ -26,51 +25,51 @@ func TestNewMsg(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestEncodeDecodeMsg(t *testing.T) {
|
// func TestEncodeDecodeMsg(t *testing.T) {
|
||||||
msg := NewMsg(3, 1, "000")
|
// msg := NewMsg(3, 1, "000")
|
||||||
buf := new(bytes.Buffer)
|
// buf := new(bytes.Buffer)
|
||||||
if err := writeMsg(buf, msg); err != nil {
|
// if err := writeMsg(buf, msg); err != nil {
|
||||||
t.Fatalf("encodeMsg error: %v", err)
|
// t.Fatalf("encodeMsg error: %v", err)
|
||||||
}
|
// }
|
||||||
// t.Logf("encoded: %x", buf.Bytes())
|
// // t.Logf("encoded: %x", buf.Bytes())
|
||||||
|
|
||||||
decmsg, err := readMsg(buf)
|
// decmsg, err := readMsg(buf)
|
||||||
if err != nil {
|
// if err != nil {
|
||||||
t.Fatalf("readMsg error: %v", err)
|
// t.Fatalf("readMsg error: %v", err)
|
||||||
}
|
// }
|
||||||
if decmsg.Code != 3 {
|
// if decmsg.Code != 3 {
|
||||||
t.Errorf("incorrect code %d, want %d", decmsg.Code, 3)
|
// t.Errorf("incorrect code %d, want %d", decmsg.Code, 3)
|
||||||
}
|
// }
|
||||||
if decmsg.Size != 5 {
|
// if decmsg.Size != 5 {
|
||||||
t.Errorf("incorrect size %d, want %d", decmsg.Size, 5)
|
// t.Errorf("incorrect size %d, want %d", decmsg.Size, 5)
|
||||||
}
|
// }
|
||||||
|
|
||||||
var data struct {
|
// var data struct {
|
||||||
I uint
|
// I uint
|
||||||
S string
|
// S string
|
||||||
}
|
// }
|
||||||
if err := decmsg.Decode(&data); err != nil {
|
// if err := decmsg.Decode(&data); err != nil {
|
||||||
t.Fatalf("Decode error: %v", err)
|
// t.Fatalf("Decode error: %v", err)
|
||||||
}
|
// }
|
||||||
if data.I != 1 {
|
// if data.I != 1 {
|
||||||
t.Errorf("incorrect data.I: got %v, expected %d", data.I, 1)
|
// t.Errorf("incorrect data.I: got %v, expected %d", data.I, 1)
|
||||||
}
|
// }
|
||||||
if data.S != "000" {
|
// if data.S != "000" {
|
||||||
t.Errorf("incorrect data.S: got %q, expected %q", data.S, "000")
|
// t.Errorf("incorrect data.S: got %q, expected %q", data.S, "000")
|
||||||
}
|
// }
|
||||||
}
|
// }
|
||||||
|
|
||||||
func TestDecodeRealMsg(t *testing.T) {
|
// func TestDecodeRealMsg(t *testing.T) {
|
||||||
data := ethutil.Hex2Bytes("2240089100000080f87e8002b5457468657265756d282b2b292f5065657220536572766572204f6e652f76302e372e382f52656c656173652f4c696e75782f672b2bc082765fb84086dd80b7aefd6a6d2e3b93f4f300a86bfb6ef7bdc97cb03f793db6bb")
|
// data := ethutil.Hex2Bytes("2240089100000080f87e8002b5457468657265756d282b2b292f5065657220536572766572204f6e652f76302e372e382f52656c656173652f4c696e75782f672b2bc082765fb84086dd80b7aefd6a6d2e3b93f4f300a86bfb6ef7bdc97cb03f793db6bb")
|
||||||
msg, err := readMsg(bytes.NewReader(data))
|
// msg, err := readMsg(bytes.NewReader(data))
|
||||||
if err != nil {
|
// if err != nil {
|
||||||
t.Fatalf("unexpected error: %v", err)
|
// t.Fatalf("unexpected error: %v", err)
|
||||||
}
|
// }
|
||||||
|
|
||||||
if msg.Code != 0 {
|
// if msg.Code != 0 {
|
||||||
t.Errorf("incorrect code %d, want %d", msg.Code, 0)
|
// t.Errorf("incorrect code %d, want %d", msg.Code, 0)
|
||||||
}
|
// }
|
||||||
}
|
// }
|
||||||
|
|
||||||
func ExampleMsgPipe() {
|
func ExampleMsgPipe() {
|
||||||
rw1, rw2 := MsgPipe()
|
rw1, rw2 := MsgPipe()
|
||||||
|
@ -131,3 +130,58 @@ func TestMsgPipeConcurrentClose(t *testing.T) {
|
||||||
go rw1.Close()
|
go rw1.Close()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestEOFSignal(t *testing.T) {
|
||||||
|
rb := make([]byte, 10)
|
||||||
|
|
||||||
|
// empty reader
|
||||||
|
eof := make(chan struct{}, 1)
|
||||||
|
sig := &eofSignal{new(bytes.Buffer), 0, eof}
|
||||||
|
if n, err := sig.Read(rb); n != 0 || err != io.EOF {
|
||||||
|
t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-eof:
|
||||||
|
default:
|
||||||
|
t.Error("EOF chan not signaled")
|
||||||
|
}
|
||||||
|
|
||||||
|
// count before error
|
||||||
|
eof = make(chan struct{}, 1)
|
||||||
|
sig = &eofSignal{bytes.NewBufferString("aaaaaaaa"), 4, eof}
|
||||||
|
if n, err := sig.Read(rb); n != 4 || err != nil {
|
||||||
|
t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-eof:
|
||||||
|
default:
|
||||||
|
t.Error("EOF chan not signaled")
|
||||||
|
}
|
||||||
|
|
||||||
|
// error before count
|
||||||
|
eof = make(chan struct{}, 1)
|
||||||
|
sig = &eofSignal{bytes.NewBufferString("aaaa"), 999, eof}
|
||||||
|
if n, err := sig.Read(rb); n != 4 || err != nil {
|
||||||
|
t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
|
||||||
|
}
|
||||||
|
if n, err := sig.Read(rb); n != 0 || err != io.EOF {
|
||||||
|
t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-eof:
|
||||||
|
default:
|
||||||
|
t.Error("EOF chan not signaled")
|
||||||
|
}
|
||||||
|
|
||||||
|
// no signal if neither occurs
|
||||||
|
eof = make(chan struct{}, 1)
|
||||||
|
sig = &eofSignal{bytes.NewBufferString("aaaaaaaaaaaaaaaaaaaaa"), 999, eof}
|
||||||
|
if n, err := sig.Read(rb); n != 10 || err != nil {
|
||||||
|
t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-eof:
|
||||||
|
t.Error("unexpected EOF signal")
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
23
p2p/nat.go
23
p2p/nat.go
|
@ -1,23 +0,0 @@
|
||||||
package p2p
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
)
|
|
||||||
|
|
||||||
func ParseNAT(natType string, gateway string) (nat NAT, err error) {
|
|
||||||
switch natType {
|
|
||||||
case "UPNP":
|
|
||||||
nat = UPNP()
|
|
||||||
case "PMP":
|
|
||||||
ip := net.ParseIP(gateway)
|
|
||||||
if ip == nil {
|
|
||||||
return nil, fmt.Errorf("cannot resolve PMP gateway IP %s", gateway)
|
|
||||||
}
|
|
||||||
nat = PMP(ip)
|
|
||||||
case "":
|
|
||||||
default:
|
|
||||||
return nil, fmt.Errorf("unrecognised NAT type '%s'", natType)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
|
@ -0,0 +1,235 @@
|
||||||
|
// Package nat provides access to common port mapping protocols.
|
||||||
|
package nat
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/ethereum/go-ethereum/logger"
|
||||||
|
"github.com/jackpal/go-nat-pmp"
|
||||||
|
)
|
||||||
|
|
||||||
|
var log = logger.NewLogger("P2P NAT")
|
||||||
|
|
||||||
|
// An implementation of nat.Interface can map local ports to ports
|
||||||
|
// accessible from the Internet.
|
||||||
|
type Interface interface {
|
||||||
|
// These methods manage a mapping between a port on the local
|
||||||
|
// machine to a port that can be connected to from the internet.
|
||||||
|
//
|
||||||
|
// protocol is "UDP" or "TCP". Some implementations allow setting
|
||||||
|
// a display name for the mapping. The mapping may be removed by
|
||||||
|
// the gateway when its lifetime ends.
|
||||||
|
AddMapping(protocol string, extport, intport int, name string, lifetime time.Duration) error
|
||||||
|
DeleteMapping(protocol string, extport, intport int) error
|
||||||
|
|
||||||
|
// This method should return the external (Internet-facing)
|
||||||
|
// address of the gateway device.
|
||||||
|
ExternalIP() (net.IP, error)
|
||||||
|
|
||||||
|
// Should return name of the method. This is used for logging.
|
||||||
|
String() string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse parses a NAT interface description.
|
||||||
|
// The following formats are currently accepted.
|
||||||
|
// Note that mechanism names are not case-sensitive.
|
||||||
|
//
|
||||||
|
// "" or "none" return nil
|
||||||
|
// "extip:77.12.33.4" will assume the local machine is reachable on the given IP
|
||||||
|
// "any" uses the first auto-detected mechanism
|
||||||
|
// "upnp" uses the Universal Plug and Play protocol
|
||||||
|
// "pmp" uses NAT-PMP with an auto-detected gateway address
|
||||||
|
// "pmp:192.168.0.1" uses NAT-PMP with the given gateway address
|
||||||
|
func Parse(spec string) (Interface, error) {
|
||||||
|
var (
|
||||||
|
parts = strings.SplitN(spec, ":", 2)
|
||||||
|
mech = strings.ToLower(parts[0])
|
||||||
|
ip net.IP
|
||||||
|
)
|
||||||
|
if len(parts) > 1 {
|
||||||
|
ip = net.ParseIP(parts[1])
|
||||||
|
if ip == nil {
|
||||||
|
return nil, errors.New("invalid IP address")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
switch mech {
|
||||||
|
case "", "none", "off":
|
||||||
|
return nil, nil
|
||||||
|
case "any", "auto", "on":
|
||||||
|
return Any(), nil
|
||||||
|
case "extip", "ip":
|
||||||
|
if ip == nil {
|
||||||
|
return nil, errors.New("missing IP address")
|
||||||
|
}
|
||||||
|
return ExtIP(ip), nil
|
||||||
|
case "upnp":
|
||||||
|
return UPnP(), nil
|
||||||
|
case "pmp", "natpmp", "nat-pmp":
|
||||||
|
return PMP(ip), nil
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("unknown mechanism %q", parts[0])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
mapTimeout = 20 * time.Minute
|
||||||
|
mapUpdateInterval = 15 * time.Minute
|
||||||
|
)
|
||||||
|
|
||||||
|
// Map adds a port mapping on m and keeps it alive until c is closed.
|
||||||
|
// This function is typically invoked in its own goroutine.
|
||||||
|
func Map(m Interface, c chan struct{}, protocol string, extport, intport int, name string) {
|
||||||
|
refresh := time.NewTimer(mapUpdateInterval)
|
||||||
|
defer func() {
|
||||||
|
refresh.Stop()
|
||||||
|
log.Debugf("Deleting port mapping: %s %d -> %d (%s) using %s\n", protocol, extport, intport, name, m)
|
||||||
|
m.DeleteMapping(protocol, extport, intport)
|
||||||
|
}()
|
||||||
|
log.Debugf("add mapping: %s %d -> %d (%s) using %s\n", protocol, extport, intport, name, m)
|
||||||
|
if err := m.AddMapping(protocol, intport, extport, name, mapTimeout); err != nil {
|
||||||
|
log.Errorf("mapping error: %v\n", err)
|
||||||
|
}
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case _, ok := <-c:
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
case <-refresh.C:
|
||||||
|
log.DebugDetailf("refresh mapping: %s %d -> %d (%s) using %s\n", protocol, extport, intport, name, m)
|
||||||
|
if err := m.AddMapping(protocol, intport, extport, name, mapTimeout); err != nil {
|
||||||
|
log.Errorf("mapping error: %v\n", err)
|
||||||
|
}
|
||||||
|
refresh.Reset(mapUpdateInterval)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExtIP assumes that the local machine is reachable on the given
|
||||||
|
// external IP address, and that any required ports were mapped manually.
|
||||||
|
// Mapping operations will not return an error but won't actually do anything.
|
||||||
|
func ExtIP(ip net.IP) Interface {
|
||||||
|
if ip == nil {
|
||||||
|
panic("IP must not be nil")
|
||||||
|
}
|
||||||
|
return extIP(ip)
|
||||||
|
}
|
||||||
|
|
||||||
|
type extIP net.IP
|
||||||
|
|
||||||
|
func (n extIP) ExternalIP() (net.IP, error) { return net.IP(n), nil }
|
||||||
|
func (n extIP) String() string { return fmt.Sprintf("ExtIP(%v)", net.IP(n)) }
|
||||||
|
|
||||||
|
// These do nothing.
|
||||||
|
func (extIP) AddMapping(string, int, int, string, time.Duration) error { return nil }
|
||||||
|
func (extIP) DeleteMapping(string, int, int) error { return nil }
|
||||||
|
|
||||||
|
// Any returns a port mapper that tries to discover any supported
|
||||||
|
// mechanism on the local network.
|
||||||
|
func Any() Interface {
|
||||||
|
// TODO: attempt to discover whether the local machine has an
|
||||||
|
// Internet-class address. Return ExtIP in this case.
|
||||||
|
return startautodisc("UPnP or NAT-PMP", func() Interface {
|
||||||
|
found := make(chan Interface, 2)
|
||||||
|
go func() { found <- discoverUPnP() }()
|
||||||
|
go func() { found <- discoverPMP() }()
|
||||||
|
for i := 0; i < cap(found); i++ {
|
||||||
|
if c := <-found; c != nil {
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UPnP returns a port mapper that uses UPnP. It will attempt to
|
||||||
|
// discover the address of your router using UDP broadcasts.
|
||||||
|
func UPnP() Interface {
|
||||||
|
return startautodisc("UPnP", discoverUPnP)
|
||||||
|
}
|
||||||
|
|
||||||
|
// PMP returns a port mapper that uses NAT-PMP. The provided gateway
|
||||||
|
// address should be the IP of your router. If the given gateway
|
||||||
|
// address is nil, PMP will attempt to auto-discover the router.
|
||||||
|
func PMP(gateway net.IP) Interface {
|
||||||
|
if gateway != nil {
|
||||||
|
return &pmp{gw: gateway, c: natpmp.NewClient(gateway)}
|
||||||
|
}
|
||||||
|
return startautodisc("NAT-PMP", discoverPMP)
|
||||||
|
}
|
||||||
|
|
||||||
|
// autodisc represents a port mapping mechanism that is still being
|
||||||
|
// auto-discovered. Calls to the Interface methods on this type will
|
||||||
|
// wait until the discovery is done and then call the method on the
|
||||||
|
// discovered mechanism.
|
||||||
|
//
|
||||||
|
// This type is useful because discovery can take a while but we
|
||||||
|
// want return an Interface value from UPnP, PMP and Auto immediately.
|
||||||
|
type autodisc struct {
|
||||||
|
what string
|
||||||
|
done <-chan Interface
|
||||||
|
|
||||||
|
mu sync.Mutex
|
||||||
|
found Interface
|
||||||
|
}
|
||||||
|
|
||||||
|
func startautodisc(what string, doit func() Interface) Interface {
|
||||||
|
// TODO: monitor network configuration and rerun doit when it changes.
|
||||||
|
done := make(chan Interface)
|
||||||
|
ad := &autodisc{what: what, done: done}
|
||||||
|
go func() { done <- doit(); close(done) }()
|
||||||
|
return ad
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *autodisc) AddMapping(protocol string, extport, intport int, name string, lifetime time.Duration) error {
|
||||||
|
if err := n.wait(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return n.found.AddMapping(protocol, extport, intport, name, lifetime)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *autodisc) DeleteMapping(protocol string, extport, intport int) error {
|
||||||
|
if err := n.wait(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return n.found.DeleteMapping(protocol, extport, intport)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *autodisc) ExternalIP() (net.IP, error) {
|
||||||
|
if err := n.wait(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return n.found.ExternalIP()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *autodisc) String() string {
|
||||||
|
n.mu.Lock()
|
||||||
|
defer n.mu.Unlock()
|
||||||
|
if n.found == nil {
|
||||||
|
return n.what
|
||||||
|
} else {
|
||||||
|
return n.found.String()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *autodisc) wait() error {
|
||||||
|
n.mu.Lock()
|
||||||
|
found := n.found
|
||||||
|
n.mu.Unlock()
|
||||||
|
if found != nil {
|
||||||
|
// already discovered
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if found = <-n.done; found == nil {
|
||||||
|
return errors.New("no devices discovered")
|
||||||
|
}
|
||||||
|
n.mu.Lock()
|
||||||
|
n.found = found
|
||||||
|
n.mu.Unlock()
|
||||||
|
return nil
|
||||||
|
}
|
|
@ -0,0 +1,115 @@
|
||||||
|
package nat
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/jackpal/go-nat-pmp"
|
||||||
|
)
|
||||||
|
|
||||||
|
// natPMPClient adapts the NAT-PMP protocol implementation so it conforms to
|
||||||
|
// the common interface.
|
||||||
|
type pmp struct {
|
||||||
|
gw net.IP
|
||||||
|
c *natpmp.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *pmp) String() string {
|
||||||
|
return fmt.Sprintf("NAT-PMP(%v)", n.gw)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *pmp) ExternalIP() (net.IP, error) {
|
||||||
|
response, err := n.c.GetExternalAddress()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return response.ExternalIPAddress[:], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *pmp) AddMapping(protocol string, extport, intport int, name string, lifetime time.Duration) error {
|
||||||
|
if lifetime <= 0 {
|
||||||
|
return fmt.Errorf("lifetime must not be <= 0")
|
||||||
|
}
|
||||||
|
// Note order of port arguments is switched between our
|
||||||
|
// AddMapping and the client's AddPortMapping.
|
||||||
|
_, err := n.c.AddPortMapping(strings.ToLower(protocol), intport, extport, int(lifetime/time.Second))
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *pmp) DeleteMapping(protocol string, extport, intport int) (err error) {
|
||||||
|
// To destroy a mapping, send an add-port with an internalPort of
|
||||||
|
// the internal port to destroy, an external port of zero and a
|
||||||
|
// time of zero.
|
||||||
|
_, err = n.c.AddPortMapping(strings.ToLower(protocol), intport, 0, 0)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func discoverPMP() Interface {
|
||||||
|
// run external address lookups on all potential gateways
|
||||||
|
gws := potentialGateways()
|
||||||
|
found := make(chan *pmp, len(gws))
|
||||||
|
for i := range gws {
|
||||||
|
gw := gws[i]
|
||||||
|
go func() {
|
||||||
|
c := natpmp.NewClient(gw)
|
||||||
|
if _, err := c.GetExternalAddress(); err != nil {
|
||||||
|
found <- nil
|
||||||
|
} else {
|
||||||
|
found <- &pmp{gw, c}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
// return the one that responds first.
|
||||||
|
// discovery needs to be quick, so we stop caring about
|
||||||
|
// any responses after a very short timeout.
|
||||||
|
timeout := time.NewTimer(1 * time.Second)
|
||||||
|
defer timeout.Stop()
|
||||||
|
for _ = range gws {
|
||||||
|
select {
|
||||||
|
case c := <-found:
|
||||||
|
if c != nil {
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
case <-timeout.C:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
// LAN IP ranges
|
||||||
|
_, lan10, _ = net.ParseCIDR("10.0.0.0/8")
|
||||||
|
_, lan176, _ = net.ParseCIDR("172.16.0.0/12")
|
||||||
|
_, lan192, _ = net.ParseCIDR("192.168.0.0/16")
|
||||||
|
)
|
||||||
|
|
||||||
|
// TODO: improve this. We currently assume that (on most networks)
|
||||||
|
// the router is X.X.X.1 in a local LAN range.
|
||||||
|
func potentialGateways() (gws []net.IP) {
|
||||||
|
ifaces, err := net.Interfaces()
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
for _, iface := range ifaces {
|
||||||
|
ifaddrs, err := iface.Addrs()
|
||||||
|
if err != nil {
|
||||||
|
return gws
|
||||||
|
}
|
||||||
|
for _, addr := range ifaddrs {
|
||||||
|
switch x := addr.(type) {
|
||||||
|
case *net.IPNet:
|
||||||
|
if lan10.Contains(x.IP) || lan176.Contains(x.IP) || lan192.Contains(x.IP) {
|
||||||
|
ip := x.IP.Mask(x.Mask).To4()
|
||||||
|
if ip != nil {
|
||||||
|
ip[3] = ip[3] | 0x01
|
||||||
|
gws = append(gws, ip)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return gws
|
||||||
|
}
|
|
@ -0,0 +1,149 @@
|
||||||
|
package nat
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/fjl/goupnp"
|
||||||
|
"github.com/fjl/goupnp/dcps/internetgateway1"
|
||||||
|
"github.com/fjl/goupnp/dcps/internetgateway2"
|
||||||
|
)
|
||||||
|
|
||||||
|
type upnp struct {
|
||||||
|
dev *goupnp.RootDevice
|
||||||
|
service string
|
||||||
|
client upnpClient
|
||||||
|
}
|
||||||
|
|
||||||
|
type upnpClient interface {
|
||||||
|
GetExternalIPAddress() (string, error)
|
||||||
|
AddPortMapping(string, uint16, string, uint16, string, bool, string, uint32) error
|
||||||
|
DeletePortMapping(string, uint16, string) error
|
||||||
|
GetNATRSIPStatus() (sip bool, nat bool, err error)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *upnp) ExternalIP() (addr net.IP, err error) {
|
||||||
|
ipString, err := n.client.GetExternalIPAddress()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
ip := net.ParseIP(ipString)
|
||||||
|
if ip == nil {
|
||||||
|
return nil, errors.New("bad IP in response")
|
||||||
|
}
|
||||||
|
return ip, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *upnp) AddMapping(protocol string, extport, intport int, desc string, lifetime time.Duration) error {
|
||||||
|
ip, err := n.internalAddress()
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
protocol = strings.ToUpper(protocol)
|
||||||
|
lifetimeS := uint32(lifetime / time.Second)
|
||||||
|
return n.client.AddPortMapping("", uint16(extport), protocol, uint16(intport), ip.String(), true, desc, lifetimeS)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *upnp) internalAddress() (net.IP, error) {
|
||||||
|
devaddr, err := net.ResolveUDPAddr("udp4", n.dev.URLBase.Host)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
ifaces, err := net.Interfaces()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
for _, iface := range ifaces {
|
||||||
|
addrs, err := iface.Addrs()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
for _, addr := range addrs {
|
||||||
|
switch x := addr.(type) {
|
||||||
|
case *net.IPNet:
|
||||||
|
if x.Contains(devaddr.IP) {
|
||||||
|
return x.IP, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("could not find local address in same net as %v", devaddr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *upnp) DeleteMapping(protocol string, extport, intport int) error {
|
||||||
|
return n.client.DeletePortMapping("", uint16(extport), strings.ToUpper(protocol))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *upnp) String() string {
|
||||||
|
return "UPNP " + n.service
|
||||||
|
}
|
||||||
|
|
||||||
|
// discoverUPnP searches for Internet Gateway Devices
|
||||||
|
// and returns the first one it can find on the local network.
|
||||||
|
func discoverUPnP() Interface {
|
||||||
|
found := make(chan *upnp, 2)
|
||||||
|
// IGDv1
|
||||||
|
go discover(found, internetgateway1.URN_WANConnectionDevice_1, func(dev *goupnp.RootDevice, sc goupnp.ServiceClient) *upnp {
|
||||||
|
switch sc.Service.ServiceType {
|
||||||
|
case internetgateway1.URN_WANIPConnection_1:
|
||||||
|
return &upnp{dev, "IGDv1-IP1", &internetgateway1.WANIPConnection1{sc}}
|
||||||
|
case internetgateway1.URN_WANPPPConnection_1:
|
||||||
|
return &upnp{dev, "IGDv1-PPP1", &internetgateway1.WANPPPConnection1{sc}}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
// IGDv2
|
||||||
|
go discover(found, internetgateway2.URN_WANConnectionDevice_2, func(dev *goupnp.RootDevice, sc goupnp.ServiceClient) *upnp {
|
||||||
|
switch sc.Service.ServiceType {
|
||||||
|
case internetgateway2.URN_WANIPConnection_1:
|
||||||
|
return &upnp{dev, "IGDv2-IP1", &internetgateway2.WANIPConnection1{sc}}
|
||||||
|
case internetgateway2.URN_WANIPConnection_2:
|
||||||
|
return &upnp{dev, "IGDv2-IP2", &internetgateway2.WANIPConnection2{sc}}
|
||||||
|
case internetgateway2.URN_WANPPPConnection_1:
|
||||||
|
return &upnp{dev, "IGDv2-PPP1", &internetgateway2.WANPPPConnection1{sc}}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
for i := 0; i < cap(found); i++ {
|
||||||
|
if c := <-found; c != nil {
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func discover(out chan<- *upnp, target string, matcher func(*goupnp.RootDevice, goupnp.ServiceClient) *upnp) {
|
||||||
|
devs, err := goupnp.DiscoverDevices(target)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
found := false
|
||||||
|
for i := 0; i < len(devs) && !found; i++ {
|
||||||
|
if devs[i].Root == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
devs[i].Root.Device.VisitServices(func(service *goupnp.Service) {
|
||||||
|
if found {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// check for a matching IGD service
|
||||||
|
sc := goupnp.ServiceClient{service.NewSOAPClient(), devs[i].Root, service}
|
||||||
|
upnp := matcher(devs[i].Root, sc)
|
||||||
|
if upnp == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// check whether port mapping is enabled
|
||||||
|
if _, nat, err := upnp.client.GetNATRSIPStatus(); err != nil || !nat {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
out <- upnp
|
||||||
|
found = true
|
||||||
|
})
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
out <- nil
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,55 +0,0 @@
|
||||||
package p2p
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
natpmp "github.com/jackpal/go-nat-pmp"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Adapt the NAT-PMP protocol to the NAT interface
|
|
||||||
|
|
||||||
// TODO:
|
|
||||||
// + Register for changes to the external address.
|
|
||||||
// + Re-register port mapping when router reboots.
|
|
||||||
// + A mechanism for keeping a port mapping registered.
|
|
||||||
// + Discover gateway address automatically.
|
|
||||||
|
|
||||||
type natPMPClient struct {
|
|
||||||
client *natpmp.Client
|
|
||||||
}
|
|
||||||
|
|
||||||
// PMP returns a NAT traverser that uses NAT-PMP. The provided gateway
|
|
||||||
// address should be the IP of your router.
|
|
||||||
func PMP(gateway net.IP) (nat NAT) {
|
|
||||||
return &natPMPClient{natpmp.NewClient(gateway)}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (*natPMPClient) String() string {
|
|
||||||
return "NAT-PMP"
|
|
||||||
}
|
|
||||||
|
|
||||||
func (n *natPMPClient) GetExternalAddress() (net.IP, error) {
|
|
||||||
response, err := n.client.GetExternalAddress()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return response.ExternalIPAddress[:], nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (n *natPMPClient) AddPortMapping(protocol string, extport, intport int, name string, lifetime time.Duration) error {
|
|
||||||
if lifetime <= 0 {
|
|
||||||
return fmt.Errorf("lifetime must not be <= 0")
|
|
||||||
}
|
|
||||||
// Note order of port arguments is switched between our AddPortMapping and the client's AddPortMapping.
|
|
||||||
_, err := n.client.AddPortMapping(protocol, intport, extport, int(lifetime/time.Second))
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (n *natPMPClient) DeletePortMapping(protocol string, externalPort, internalPort int) (err error) {
|
|
||||||
// To destroy a mapping, send an add-port with
|
|
||||||
// an internalPort of the internal port to destroy, an external port of zero and a time of zero.
|
|
||||||
_, err = n.client.AddPortMapping(protocol, internalPort, 0, 0)
|
|
||||||
return
|
|
||||||
}
|
|
341
p2p/natupnp.go
341
p2p/natupnp.go
|
@ -1,341 +0,0 @@
|
||||||
package p2p
|
|
||||||
|
|
||||||
// Just enough UPnP to be able to forward ports
|
|
||||||
//
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"encoding/xml"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
"net/http"
|
|
||||||
"os"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
upnpDiscoverAttempts = 3
|
|
||||||
upnpDiscoverTimeout = 5 * time.Second
|
|
||||||
)
|
|
||||||
|
|
||||||
// UPNP returns a NAT port mapper that uses UPnP. It will attempt to
|
|
||||||
// discover the address of your router using UDP broadcasts.
|
|
||||||
func UPNP() NAT {
|
|
||||||
return &upnpNAT{}
|
|
||||||
}
|
|
||||||
|
|
||||||
type upnpNAT struct {
|
|
||||||
serviceURL string
|
|
||||||
ourIP string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (n *upnpNAT) String() string {
|
|
||||||
return "UPNP"
|
|
||||||
}
|
|
||||||
|
|
||||||
func (n *upnpNAT) discover() error {
|
|
||||||
if n.serviceURL != "" {
|
|
||||||
// already discovered
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
ssdp, err := net.ResolveUDPAddr("udp4", "239.255.255.250:1900")
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
// TODO: try on all network interfaces simultaneously.
|
|
||||||
// Broadcasting on 0.0.0.0 could select a random interface
|
|
||||||
// to send on (platform specific).
|
|
||||||
conn, err := net.ListenPacket("udp4", ":0")
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer conn.Close()
|
|
||||||
|
|
||||||
conn.SetDeadline(time.Now().Add(10 * time.Second))
|
|
||||||
st := "ST: urn:schemas-upnp-org:device:InternetGatewayDevice:1\r\n"
|
|
||||||
buf := bytes.NewBufferString(
|
|
||||||
"M-SEARCH * HTTP/1.1\r\n" +
|
|
||||||
"HOST: 239.255.255.250:1900\r\n" +
|
|
||||||
st +
|
|
||||||
"MAN: \"ssdp:discover\"\r\n" +
|
|
||||||
"MX: 2\r\n\r\n")
|
|
||||||
message := buf.Bytes()
|
|
||||||
answerBytes := make([]byte, 1024)
|
|
||||||
for i := 0; i < upnpDiscoverAttempts; i++ {
|
|
||||||
_, err = conn.WriteTo(message, ssdp)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
nn, _, err := conn.ReadFrom(answerBytes)
|
|
||||||
if err != nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
answer := string(answerBytes[0:nn])
|
|
||||||
if strings.Index(answer, "\r\n"+st) < 0 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
// HTTP header field names are case-insensitive.
|
|
||||||
// http://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html#sec4.2
|
|
||||||
locString := "\r\nlocation: "
|
|
||||||
answer = strings.ToLower(answer)
|
|
||||||
locIndex := strings.Index(answer, locString)
|
|
||||||
if locIndex < 0 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
loc := answer[locIndex+len(locString):]
|
|
||||||
endIndex := strings.Index(loc, "\r\n")
|
|
||||||
if endIndex < 0 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
locURL := loc[0:endIndex]
|
|
||||||
var serviceURL string
|
|
||||||
serviceURL, err = getServiceURL(locURL)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
var ourIP string
|
|
||||||
ourIP, err = getOurIP()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
n.serviceURL = serviceURL
|
|
||||||
n.ourIP = ourIP
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return errors.New("UPnP port discovery failed.")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (n *upnpNAT) GetExternalAddress() (addr net.IP, err error) {
|
|
||||||
if err := n.discover(); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
info, err := n.getStatusInfo()
|
|
||||||
return net.ParseIP(info.externalIpAddress), err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (n *upnpNAT) AddPortMapping(protocol string, extport, intport int, description string, lifetime time.Duration) error {
|
|
||||||
if err := n.discover(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// A single concatenation would break ARM compilation.
|
|
||||||
message := "<u:AddPortMapping xmlns:u=\"urn:schemas-upnp-org:service:WANIPConnection:1\">\r\n" +
|
|
||||||
"<NewRemoteHost></NewRemoteHost><NewExternalPort>" + strconv.Itoa(extport)
|
|
||||||
message += "</NewExternalPort><NewProtocol>" + protocol + "</NewProtocol>"
|
|
||||||
message += "<NewInternalPort>" + strconv.Itoa(extport) + "</NewInternalPort>" +
|
|
||||||
"<NewInternalClient>" + n.ourIP + "</NewInternalClient>" +
|
|
||||||
"<NewEnabled>1</NewEnabled><NewPortMappingDescription>"
|
|
||||||
message += description +
|
|
||||||
"</NewPortMappingDescription><NewLeaseDuration>" + fmt.Sprint(lifetime/time.Second) +
|
|
||||||
"</NewLeaseDuration></u:AddPortMapping>"
|
|
||||||
|
|
||||||
// TODO: check response to see if the port was forwarded
|
|
||||||
_, err := soapRequest(n.serviceURL, "AddPortMapping", message)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (n *upnpNAT) DeletePortMapping(protocol string, externalPort, internalPort int) error {
|
|
||||||
if err := n.discover(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
message := "<u:DeletePortMapping xmlns:u=\"urn:schemas-upnp-org:service:WANIPConnection:1\">\r\n" +
|
|
||||||
"<NewRemoteHost></NewRemoteHost><NewExternalPort>" + strconv.Itoa(externalPort) +
|
|
||||||
"</NewExternalPort><NewProtocol>" + protocol + "</NewProtocol>" +
|
|
||||||
"</u:DeletePortMapping>"
|
|
||||||
|
|
||||||
// TODO: check response to see if the port was deleted
|
|
||||||
_, err := soapRequest(n.serviceURL, "DeletePortMapping", message)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
type statusInfo struct {
|
|
||||||
externalIpAddress string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (n *upnpNAT) getStatusInfo() (info statusInfo, err error) {
|
|
||||||
message := "<u:GetStatusInfo xmlns:u=\"urn:schemas-upnp-org:service:WANIPConnection:1\">\r\n" +
|
|
||||||
"</u:GetStatusInfo>"
|
|
||||||
|
|
||||||
var response *http.Response
|
|
||||||
response, err = soapRequest(n.serviceURL, "GetStatusInfo", message)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: Write a soap reply parser. It has to eat the Body and envelope tags...
|
|
||||||
|
|
||||||
response.Body.Close()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// service represents the Service type in an UPnP xml description.
|
|
||||||
// Only the parts we care about are present and thus the xml may have more
|
|
||||||
// fields than present in the structure.
|
|
||||||
type service struct {
|
|
||||||
ServiceType string `xml:"serviceType"`
|
|
||||||
ControlURL string `xml:"controlURL"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// deviceList represents the deviceList type in an UPnP xml description.
|
|
||||||
// Only the parts we care about are present and thus the xml may have more
|
|
||||||
// fields than present in the structure.
|
|
||||||
type deviceList struct {
|
|
||||||
XMLName xml.Name `xml:"deviceList"`
|
|
||||||
Device []device `xml:"device"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// serviceList represents the serviceList type in an UPnP xml description.
|
|
||||||
// Only the parts we care about are present and thus the xml may have more
|
|
||||||
// fields than present in the structure.
|
|
||||||
type serviceList struct {
|
|
||||||
XMLName xml.Name `xml:"serviceList"`
|
|
||||||
Service []service `xml:"service"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// device represents the device type in an UPnP xml description.
|
|
||||||
// Only the parts we care about are present and thus the xml may have more
|
|
||||||
// fields than present in the structure.
|
|
||||||
type device struct {
|
|
||||||
XMLName xml.Name `xml:"device"`
|
|
||||||
DeviceType string `xml:"deviceType"`
|
|
||||||
DeviceList deviceList `xml:"deviceList"`
|
|
||||||
ServiceList serviceList `xml:"serviceList"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// specVersion represents the specVersion in a UPnP xml description.
|
|
||||||
// Only the parts we care about are present and thus the xml may have more
|
|
||||||
// fields than present in the structure.
|
|
||||||
type specVersion struct {
|
|
||||||
XMLName xml.Name `xml:"specVersion"`
|
|
||||||
Major int `xml:"major"`
|
|
||||||
Minor int `xml:"minor"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// root represents the Root document for a UPnP xml description.
|
|
||||||
// Only the parts we care about are present and thus the xml may have more
|
|
||||||
// fields than present in the structure.
|
|
||||||
type root struct {
|
|
||||||
XMLName xml.Name `xml:"root"`
|
|
||||||
SpecVersion specVersion
|
|
||||||
Device device
|
|
||||||
}
|
|
||||||
|
|
||||||
func getChildDevice(d *device, deviceType string) *device {
|
|
||||||
dl := d.DeviceList.Device
|
|
||||||
for i := 0; i < len(dl); i++ {
|
|
||||||
if dl[i].DeviceType == deviceType {
|
|
||||||
return &dl[i]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func getChildService(d *device, serviceType string) *service {
|
|
||||||
sl := d.ServiceList.Service
|
|
||||||
for i := 0; i < len(sl); i++ {
|
|
||||||
if sl[i].ServiceType == serviceType {
|
|
||||||
return &sl[i]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func getOurIP() (ip string, err error) {
|
|
||||||
hostname, err := os.Hostname()
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
p, err := net.LookupIP(hostname)
|
|
||||||
if err != nil && len(p) > 0 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
return p[0].String(), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func getServiceURL(rootURL string) (url string, err error) {
|
|
||||||
r, err := http.Get(rootURL)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer r.Body.Close()
|
|
||||||
if r.StatusCode >= 400 {
|
|
||||||
err = errors.New(string(r.StatusCode))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
var root root
|
|
||||||
err = xml.NewDecoder(r.Body).Decode(&root)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
a := &root.Device
|
|
||||||
if a.DeviceType != "urn:schemas-upnp-org:device:InternetGatewayDevice:1" {
|
|
||||||
err = errors.New("No InternetGatewayDevice")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
b := getChildDevice(a, "urn:schemas-upnp-org:device:WANDevice:1")
|
|
||||||
if b == nil {
|
|
||||||
err = errors.New("No WANDevice")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
c := getChildDevice(b, "urn:schemas-upnp-org:device:WANConnectionDevice:1")
|
|
||||||
if c == nil {
|
|
||||||
err = errors.New("No WANConnectionDevice")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
d := getChildService(c, "urn:schemas-upnp-org:service:WANIPConnection:1")
|
|
||||||
if d == nil {
|
|
||||||
err = errors.New("No WANIPConnection")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
url = combineURL(rootURL, d.ControlURL)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func combineURL(rootURL, subURL string) string {
|
|
||||||
protocolEnd := "://"
|
|
||||||
protoEndIndex := strings.Index(rootURL, protocolEnd)
|
|
||||||
a := rootURL[protoEndIndex+len(protocolEnd):]
|
|
||||||
rootIndex := strings.Index(a, "/")
|
|
||||||
return rootURL[0:protoEndIndex+len(protocolEnd)+rootIndex] + subURL
|
|
||||||
}
|
|
||||||
|
|
||||||
func soapRequest(url, function, message string) (r *http.Response, err error) {
|
|
||||||
fullMessage := "<?xml version=\"1.0\" ?>" +
|
|
||||||
"<s:Envelope xmlns:s=\"http://schemas.xmlsoap.org/soap/envelope/\" s:encodingStyle=\"http://schemas.xmlsoap.org/soap/encoding/\">\r\n" +
|
|
||||||
"<s:Body>" + message + "</s:Body></s:Envelope>"
|
|
||||||
|
|
||||||
req, err := http.NewRequest("POST", url, strings.NewReader(fullMessage))
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
req.Header.Set("Content-Type", "text/xml ; charset=\"utf-8\"")
|
|
||||||
req.Header.Set("User-Agent", "Darwin/10.0.0, UPnP/1.0, MiniUPnPc/1.3")
|
|
||||||
//req.Header.Set("Transfer-Encoding", "chunked")
|
|
||||||
req.Header.Set("SOAPAction", "\"urn:schemas-upnp-org:service:WANIPConnection:1#"+function+"\"")
|
|
||||||
req.Header.Set("Connection", "Close")
|
|
||||||
req.Header.Set("Cache-Control", "no-cache")
|
|
||||||
req.Header.Set("Pragma", "no-cache")
|
|
||||||
|
|
||||||
r, err = http.DefaultClient.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if r.Body != nil {
|
|
||||||
defer r.Body.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
if r.StatusCode >= 400 {
|
|
||||||
// log.Stderr(function, r.StatusCode)
|
|
||||||
err = errors.New("Error " + strconv.Itoa(r.StatusCode) + " for " + function)
|
|
||||||
r = nil
|
|
||||||
return
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
480
p2p/peer.go
480
p2p/peer.go
|
@ -1,8 +1,7 @@
|
||||||
package p2p
|
package p2p
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
"errors"
|
||||||
"bytes"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
|
@ -11,159 +10,109 @@ import (
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/ethereum/go-ethereum/event"
|
|
||||||
"github.com/ethereum/go-ethereum/logger"
|
"github.com/ethereum/go-ethereum/logger"
|
||||||
|
"github.com/ethereum/go-ethereum/p2p/discover"
|
||||||
|
"github.com/ethereum/go-ethereum/rlp"
|
||||||
)
|
)
|
||||||
|
|
||||||
// peerAddr is the structure of a peer list element.
|
const (
|
||||||
// It is also a valid net.Addr.
|
baseProtocolVersion = 3
|
||||||
type peerAddr struct {
|
baseProtocolLength = uint64(16)
|
||||||
IP net.IP
|
baseProtocolMaxMsgSize = 10 * 1024 * 1024
|
||||||
Port uint64
|
|
||||||
Pubkey []byte // optional
|
disconnectGracePeriod = 2 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// devp2p message codes
|
||||||
|
handshakeMsg = 0x00
|
||||||
|
discMsg = 0x01
|
||||||
|
pingMsg = 0x02
|
||||||
|
pongMsg = 0x03
|
||||||
|
getPeersMsg = 0x04
|
||||||
|
peersMsg = 0x05
|
||||||
|
)
|
||||||
|
|
||||||
|
// handshake is the RLP structure of the protocol handshake.
|
||||||
|
type handshake struct {
|
||||||
|
Version uint64
|
||||||
|
Name string
|
||||||
|
Caps []Cap
|
||||||
|
ListenPort uint64
|
||||||
|
NodeID discover.NodeID
|
||||||
}
|
}
|
||||||
|
|
||||||
func newPeerAddr(addr net.Addr, pubkey []byte) *peerAddr {
|
// Peer represents a connected remote node.
|
||||||
n := addr.Network()
|
|
||||||
if n != "tcp" && n != "tcp4" && n != "tcp6" {
|
|
||||||
// for testing with non-TCP
|
|
||||||
return &peerAddr{net.ParseIP("127.0.0.1"), 30303, pubkey}
|
|
||||||
}
|
|
||||||
ta := addr.(*net.TCPAddr)
|
|
||||||
return &peerAddr{ta.IP, uint64(ta.Port), pubkey}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d peerAddr) Network() string {
|
|
||||||
if d.IP.To4() != nil {
|
|
||||||
return "tcp4"
|
|
||||||
} else {
|
|
||||||
return "tcp6"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d peerAddr) String() string {
|
|
||||||
return fmt.Sprintf("%v:%d", d.IP, d.Port)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *peerAddr) RlpData() interface{} {
|
|
||||||
return []interface{}{string(d.IP), d.Port, d.Pubkey}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Peer represents a remote peer.
|
|
||||||
type Peer struct {
|
type Peer struct {
|
||||||
// Peers have all the log methods.
|
// Peers have all the log methods.
|
||||||
// Use them to display messages related to the peer.
|
// Use them to display messages related to the peer.
|
||||||
*logger.Logger
|
*logger.Logger
|
||||||
|
|
||||||
infolock sync.Mutex
|
infoMu sync.Mutex
|
||||||
identity ClientIdentity
|
name string
|
||||||
caps []Cap
|
caps []Cap
|
||||||
listenAddr *peerAddr // what remote peer is listening on
|
|
||||||
dialAddr *peerAddr // non-nil if dialing
|
|
||||||
|
|
||||||
// The mutex protects the connection
|
ourID, remoteID *discover.NodeID
|
||||||
// so only one protocol can write at a time.
|
ourName string
|
||||||
writeMu sync.Mutex
|
|
||||||
conn net.Conn
|
rw *frameRW
|
||||||
bufconn *bufio.ReadWriter
|
|
||||||
|
|
||||||
// These fields maintain the running protocols.
|
// These fields maintain the running protocols.
|
||||||
protocols []Protocol
|
protocols []Protocol
|
||||||
runBaseProtocol bool // for testing
|
runlock sync.RWMutex // protects running
|
||||||
|
running map[string]*proto
|
||||||
|
|
||||||
runlock sync.RWMutex // protects running
|
// disables protocol handshake, for testing
|
||||||
running map[string]*proto
|
noHandshake bool
|
||||||
|
|
||||||
protoWG sync.WaitGroup
|
protoWG sync.WaitGroup
|
||||||
protoErr chan error
|
protoErr chan error
|
||||||
closed chan struct{}
|
closed chan struct{}
|
||||||
disc chan DiscReason
|
disc chan DiscReason
|
||||||
|
|
||||||
activity event.TypeMux // for activity events
|
|
||||||
|
|
||||||
slot int // index into Server peer list
|
|
||||||
|
|
||||||
// These fields are kept so base protocol can access them.
|
|
||||||
// TODO: this should be one or more interfaces
|
|
||||||
ourID ClientIdentity // client id of the Server
|
|
||||||
ourListenAddr *peerAddr // listen addr of Server, nil if not listening
|
|
||||||
newPeerAddr chan<- *peerAddr // tell server about received peers
|
|
||||||
otherPeers func() []*Peer // should return the list of all peers
|
|
||||||
pubkeyHook func(*peerAddr) error // called at end of handshake to validate pubkey
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewPeer returns a peer for testing purposes.
|
// NewPeer returns a peer for testing purposes.
|
||||||
func NewPeer(id ClientIdentity, caps []Cap) *Peer {
|
func NewPeer(id discover.NodeID, name string, caps []Cap) *Peer {
|
||||||
conn, _ := net.Pipe()
|
conn, _ := net.Pipe()
|
||||||
peer := newPeer(conn, nil, nil)
|
peer := newPeer(conn, nil, "", nil, &id)
|
||||||
peer.setHandshakeInfo(id, nil, caps)
|
peer.setHandshakeInfo(name, caps)
|
||||||
close(peer.closed)
|
close(peer.closed) // ensures Disconnect doesn't block
|
||||||
return peer
|
return peer
|
||||||
}
|
}
|
||||||
|
|
||||||
func newServerPeer(server *Server, conn net.Conn, dialAddr *peerAddr) *Peer {
|
// ID returns the node's public key.
|
||||||
p := newPeer(conn, server.Protocols, dialAddr)
|
func (p *Peer) ID() discover.NodeID {
|
||||||
p.ourID = server.Identity
|
return *p.remoteID
|
||||||
p.newPeerAddr = server.peerConnect
|
|
||||||
p.otherPeers = server.Peers
|
|
||||||
p.pubkeyHook = server.verifyPeer
|
|
||||||
p.runBaseProtocol = true
|
|
||||||
|
|
||||||
// laddr can be updated concurrently by NAT traversal.
|
|
||||||
// newServerPeer must be called with the server lock held.
|
|
||||||
if server.laddr != nil {
|
|
||||||
p.ourListenAddr = newPeerAddr(server.laddr, server.Identity.Pubkey())
|
|
||||||
}
|
|
||||||
return p
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func newPeer(conn net.Conn, protocols []Protocol, dialAddr *peerAddr) *Peer {
|
// Name returns the node name that the remote node advertised.
|
||||||
p := &Peer{
|
func (p *Peer) Name() string {
|
||||||
Logger: logger.NewLogger("P2P " + conn.RemoteAddr().String()),
|
// this needs a lock because the information is part of the
|
||||||
conn: conn,
|
// protocol handshake.
|
||||||
dialAddr: dialAddr,
|
p.infoMu.Lock()
|
||||||
bufconn: bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn)),
|
name := p.name
|
||||||
protocols: protocols,
|
p.infoMu.Unlock()
|
||||||
running: make(map[string]*proto),
|
return name
|
||||||
disc: make(chan DiscReason),
|
|
||||||
protoErr: make(chan error),
|
|
||||||
closed: make(chan struct{}),
|
|
||||||
}
|
|
||||||
return p
|
|
||||||
}
|
|
||||||
|
|
||||||
// Identity returns the client identity of the remote peer. The
|
|
||||||
// identity can be nil if the peer has not yet completed the
|
|
||||||
// handshake.
|
|
||||||
func (p *Peer) Identity() ClientIdentity {
|
|
||||||
p.infolock.Lock()
|
|
||||||
defer p.infolock.Unlock()
|
|
||||||
return p.identity
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Caps returns the capabilities (supported subprotocols) of the remote peer.
|
// Caps returns the capabilities (supported subprotocols) of the remote peer.
|
||||||
func (p *Peer) Caps() []Cap {
|
func (p *Peer) Caps() []Cap {
|
||||||
p.infolock.Lock()
|
// this needs a lock because the information is part of the
|
||||||
defer p.infolock.Unlock()
|
// protocol handshake.
|
||||||
return p.caps
|
p.infoMu.Lock()
|
||||||
}
|
caps := p.caps
|
||||||
|
p.infoMu.Unlock()
|
||||||
func (p *Peer) setHandshakeInfo(id ClientIdentity, laddr *peerAddr, caps []Cap) {
|
return caps
|
||||||
p.infolock.Lock()
|
|
||||||
p.identity = id
|
|
||||||
p.listenAddr = laddr
|
|
||||||
p.caps = caps
|
|
||||||
p.infolock.Unlock()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// RemoteAddr returns the remote address of the network connection.
|
// RemoteAddr returns the remote address of the network connection.
|
||||||
func (p *Peer) RemoteAddr() net.Addr {
|
func (p *Peer) RemoteAddr() net.Addr {
|
||||||
return p.conn.RemoteAddr()
|
return p.rw.RemoteAddr()
|
||||||
}
|
}
|
||||||
|
|
||||||
// LocalAddr returns the local address of the network connection.
|
// LocalAddr returns the local address of the network connection.
|
||||||
func (p *Peer) LocalAddr() net.Addr {
|
func (p *Peer) LocalAddr() net.Addr {
|
||||||
return p.conn.LocalAddr()
|
return p.rw.LocalAddr()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Disconnect terminates the peer connection with the given reason.
|
// Disconnect terminates the peer connection with the given reason.
|
||||||
|
@ -177,149 +126,177 @@ func (p *Peer) Disconnect(reason DiscReason) {
|
||||||
|
|
||||||
// String implements fmt.Stringer.
|
// String implements fmt.Stringer.
|
||||||
func (p *Peer) String() string {
|
func (p *Peer) String() string {
|
||||||
kind := "inbound"
|
return fmt.Sprintf("Peer %.8x %v", p.remoteID[:], p.RemoteAddr())
|
||||||
p.infolock.Lock()
|
|
||||||
if p.dialAddr != nil {
|
|
||||||
kind = "outbound"
|
|
||||||
}
|
|
||||||
p.infolock.Unlock()
|
|
||||||
return fmt.Sprintf("Peer(%p %v %s)", p, p.conn.RemoteAddr(), kind)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
const (
|
func newPeer(conn net.Conn, protocols []Protocol, ourName string, ourID, remoteID *discover.NodeID) *Peer {
|
||||||
// maximum amount of time allowed for reading a message
|
logtag := fmt.Sprintf("Peer %.8x %v", remoteID[:], conn.RemoteAddr())
|
||||||
msgReadTimeout = 5 * time.Second
|
return &Peer{
|
||||||
// maximum amount of time allowed for writing a message
|
Logger: logger.NewLogger(logtag),
|
||||||
msgWriteTimeout = 5 * time.Second
|
rw: newFrameRW(conn, msgWriteTimeout),
|
||||||
// messages smaller than this many bytes will be read at
|
ourID: ourID,
|
||||||
// once before passing them to a protocol.
|
ourName: ourName,
|
||||||
wholePayloadSize = 64 * 1024
|
remoteID: remoteID,
|
||||||
)
|
protocols: protocols,
|
||||||
|
running: make(map[string]*proto),
|
||||||
|
disc: make(chan DiscReason),
|
||||||
|
protoErr: make(chan error),
|
||||||
|
closed: make(chan struct{}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
var (
|
func (p *Peer) setHandshakeInfo(name string, caps []Cap) {
|
||||||
inactivityTimeout = 2 * time.Second
|
p.infoMu.Lock()
|
||||||
disconnectGracePeriod = 2 * time.Second
|
p.name = name
|
||||||
)
|
p.caps = caps
|
||||||
|
p.infoMu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
func (p *Peer) loop() (reason DiscReason, err error) {
|
func (p *Peer) run() DiscReason {
|
||||||
defer p.activity.Stop()
|
var readErr = make(chan error, 1)
|
||||||
defer p.closeProtocols()
|
defer p.closeProtocols()
|
||||||
defer close(p.closed)
|
defer close(p.closed)
|
||||||
defer p.conn.Close()
|
|
||||||
|
|
||||||
// read loop
|
go func() { readErr <- p.readLoop() }()
|
||||||
readMsg := make(chan Msg)
|
|
||||||
readErr := make(chan error)
|
|
||||||
readNext := make(chan bool, 1)
|
|
||||||
protoDone := make(chan struct{}, 1)
|
|
||||||
go p.readLoop(readMsg, readErr, readNext)
|
|
||||||
readNext <- true
|
|
||||||
|
|
||||||
if p.runBaseProtocol {
|
if !p.noHandshake {
|
||||||
p.startBaseProtocol()
|
if err := writeProtocolHandshake(p.rw, p.ourName, *p.ourID, p.protocols); err != nil {
|
||||||
}
|
p.DebugDetailf("Protocol handshake error: %v\n", err)
|
||||||
|
p.rw.Close()
|
||||||
loop:
|
return DiscProtocolError
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case msg := <-readMsg:
|
|
||||||
// a new message has arrived.
|
|
||||||
var wait bool
|
|
||||||
if wait, err = p.dispatch(msg, protoDone); err != nil {
|
|
||||||
p.Errorf("msg dispatch error: %v\n", err)
|
|
||||||
reason = discReasonForError(err)
|
|
||||||
break loop
|
|
||||||
}
|
|
||||||
if !wait {
|
|
||||||
// Msg has already been read completely, continue with next message.
|
|
||||||
readNext <- true
|
|
||||||
}
|
|
||||||
p.activity.Post(time.Now())
|
|
||||||
case <-protoDone:
|
|
||||||
// protocol has consumed the message payload,
|
|
||||||
// we can continue reading from the socket.
|
|
||||||
readNext <- true
|
|
||||||
|
|
||||||
case err := <-readErr:
|
|
||||||
// read failed. there is no need to run the
|
|
||||||
// polite disconnect sequence because the connection
|
|
||||||
// is probably dead anyway.
|
|
||||||
// TODO: handle write errors as well
|
|
||||||
return DiscNetworkError, err
|
|
||||||
case err = <-p.protoErr:
|
|
||||||
reason = discReasonForError(err)
|
|
||||||
break loop
|
|
||||||
case reason = <-p.disc:
|
|
||||||
break loop
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// wait for read loop to return.
|
// Wait for an error or disconnect.
|
||||||
close(readNext)
|
var reason DiscReason
|
||||||
|
select {
|
||||||
|
case err := <-readErr:
|
||||||
|
// We rely on protocols to abort if there is a write error. It
|
||||||
|
// might be more robust to handle them here as well.
|
||||||
|
p.DebugDetailf("Read error: %v\n", err)
|
||||||
|
p.rw.Close()
|
||||||
|
return DiscNetworkError
|
||||||
|
|
||||||
|
case err := <-p.protoErr:
|
||||||
|
reason = discReasonForError(err)
|
||||||
|
case reason = <-p.disc:
|
||||||
|
}
|
||||||
|
p.politeDisconnect(reason)
|
||||||
|
|
||||||
|
// Wait for readLoop. It will end because conn is now closed.
|
||||||
<-readErr
|
<-readErr
|
||||||
// tell the remote end to disconnect
|
p.Debugf("Disconnected: %v\n", reason)
|
||||||
|
return reason
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Peer) politeDisconnect(reason DiscReason) {
|
||||||
done := make(chan struct{})
|
done := make(chan struct{})
|
||||||
go func() {
|
go func() {
|
||||||
p.conn.SetDeadline(time.Now().Add(disconnectGracePeriod))
|
EncodeMsg(p.rw, discMsg, uint(reason))
|
||||||
p.writeMsg(NewMsg(discMsg, reason), disconnectGracePeriod)
|
// Wait for the other side to close the connection.
|
||||||
io.Copy(ioutil.Discard, p.conn)
|
// Discard any data that they send until then.
|
||||||
|
io.Copy(ioutil.Discard, p.rw)
|
||||||
close(done)
|
close(done)
|
||||||
}()
|
}()
|
||||||
select {
|
select {
|
||||||
case <-done:
|
case <-done:
|
||||||
case <-time.After(disconnectGracePeriod):
|
case <-time.After(disconnectGracePeriod):
|
||||||
}
|
}
|
||||||
return reason, err
|
p.rw.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Peer) readLoop(msgc chan<- Msg, errc chan<- error, unblock <-chan bool) {
|
func (p *Peer) readLoop() error {
|
||||||
for _ = range unblock {
|
if !p.noHandshake {
|
||||||
p.conn.SetReadDeadline(time.Now().Add(msgReadTimeout))
|
if err := readProtocolHandshake(p, p.rw); err != nil {
|
||||||
if msg, err := readMsg(p.bufconn); err != nil {
|
return err
|
||||||
errc <- err
|
|
||||||
} else {
|
|
||||||
msgc <- msg
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
close(errc)
|
for {
|
||||||
}
|
msg, err := p.rw.ReadMsg()
|
||||||
|
|
||||||
func (p *Peer) dispatch(msg Msg, protoDone chan struct{}) (wait bool, err error) {
|
|
||||||
proto, err := p.getProto(msg.Code)
|
|
||||||
if err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
if msg.Size <= wholePayloadSize {
|
|
||||||
// optimization: msg is small enough, read all
|
|
||||||
// of it and move on to the next message
|
|
||||||
buf, err := ioutil.ReadAll(msg.Payload)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, err
|
return err
|
||||||
|
}
|
||||||
|
if err = p.handle(msg); err != nil {
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
msg.Payload = bytes.NewReader(buf)
|
|
||||||
proto.in <- msg
|
|
||||||
} else {
|
|
||||||
wait = true
|
|
||||||
pr := &eofSignal{msg.Payload, int64(msg.Size), protoDone}
|
|
||||||
msg.Payload = pr
|
|
||||||
proto.in <- msg
|
|
||||||
}
|
}
|
||||||
return wait, nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Peer) startBaseProtocol() {
|
func (p *Peer) handle(msg Msg) error {
|
||||||
p.runlock.Lock()
|
switch {
|
||||||
defer p.runlock.Unlock()
|
case msg.Code == pingMsg:
|
||||||
p.running[""] = p.startProto(0, Protocol{
|
msg.Discard()
|
||||||
Length: baseProtocolLength,
|
go EncodeMsg(p.rw, pongMsg)
|
||||||
Run: runBaseProtocol,
|
case msg.Code == discMsg:
|
||||||
})
|
var reason DiscReason
|
||||||
|
// no need to discard or for error checking, we'll close the
|
||||||
|
// connection after this.
|
||||||
|
rlp.Decode(msg.Payload, &reason)
|
||||||
|
p.Disconnect(DiscRequested)
|
||||||
|
return discRequestedError(reason)
|
||||||
|
case msg.Code < baseProtocolLength:
|
||||||
|
// ignore other base protocol messages
|
||||||
|
return msg.Discard()
|
||||||
|
default:
|
||||||
|
// it's a subprotocol message
|
||||||
|
proto, err := p.getProto(msg.Code)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("msg code out of range: %v", msg.Code)
|
||||||
|
}
|
||||||
|
proto.in <- msg
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func readProtocolHandshake(p *Peer, rw MsgReadWriter) error {
|
||||||
|
// read and handle remote handshake
|
||||||
|
msg, err := rw.ReadMsg()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if msg.Code == discMsg {
|
||||||
|
// disconnect before protocol handshake is valid according to the
|
||||||
|
// spec and we send it ourself if Server.addPeer fails.
|
||||||
|
var reason DiscReason
|
||||||
|
rlp.Decode(msg.Payload, &reason)
|
||||||
|
return discRequestedError(reason)
|
||||||
|
}
|
||||||
|
if msg.Code != handshakeMsg {
|
||||||
|
return newPeerError(errProtocolBreach, "expected handshake, got %x", msg.Code)
|
||||||
|
}
|
||||||
|
if msg.Size > baseProtocolMaxMsgSize {
|
||||||
|
return newPeerError(errInvalidMsg, "message too big")
|
||||||
|
}
|
||||||
|
var hs handshake
|
||||||
|
if err := msg.Decode(&hs); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
// validate handshake info
|
||||||
|
if hs.Version != baseProtocolVersion {
|
||||||
|
return newPeerError(errP2PVersionMismatch, "required version %d, received %d\n",
|
||||||
|
baseProtocolVersion, hs.Version)
|
||||||
|
}
|
||||||
|
if hs.NodeID == *p.remoteID {
|
||||||
|
return newPeerError(errPubkeyForbidden, "node ID mismatch")
|
||||||
|
}
|
||||||
|
// TODO: remove Caps with empty name
|
||||||
|
p.setHandshakeInfo(hs.Name, hs.Caps)
|
||||||
|
p.startSubprotocols(hs.Caps)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeProtocolHandshake(w MsgWriter, name string, id discover.NodeID, ps []Protocol) error {
|
||||||
|
var caps []interface{}
|
||||||
|
for _, proto := range ps {
|
||||||
|
caps = append(caps, proto.cap())
|
||||||
|
}
|
||||||
|
return EncodeMsg(w, handshakeMsg, baseProtocolVersion, name, caps, 0, id)
|
||||||
}
|
}
|
||||||
|
|
||||||
// startProtocols starts matching named subprotocols.
|
// startProtocols starts matching named subprotocols.
|
||||||
func (p *Peer) startSubprotocols(caps []Cap) {
|
func (p *Peer) startSubprotocols(caps []Cap) {
|
||||||
sort.Sort(capsByName(caps))
|
sort.Sort(capsByName(caps))
|
||||||
|
|
||||||
p.runlock.Lock()
|
p.runlock.Lock()
|
||||||
defer p.runlock.Unlock()
|
defer p.runlock.Unlock()
|
||||||
offset := baseProtocolLength
|
offset := baseProtocolLength
|
||||||
|
@ -338,20 +315,22 @@ outer:
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Peer) startProto(offset uint64, impl Protocol) *proto {
|
func (p *Peer) startProto(offset uint64, impl Protocol) *proto {
|
||||||
|
p.DebugDetailf("Starting protocol %s/%d\n", impl.Name, impl.Version)
|
||||||
rw := &proto{
|
rw := &proto{
|
||||||
|
name: impl.Name,
|
||||||
in: make(chan Msg),
|
in: make(chan Msg),
|
||||||
offset: offset,
|
offset: offset,
|
||||||
maxcode: impl.Length,
|
maxcode: impl.Length,
|
||||||
peer: p,
|
w: p.rw,
|
||||||
}
|
}
|
||||||
p.protoWG.Add(1)
|
p.protoWG.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
err := impl.Run(p, rw)
|
err := impl.Run(p, rw)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
p.Infof("protocol %q returned", impl.Name)
|
p.DebugDetailf("Protocol %s/%d returned\n", impl.Name, impl.Version)
|
||||||
err = newPeerError(errMisc, "protocol returned")
|
err = errors.New("protocol returned")
|
||||||
} else {
|
} else {
|
||||||
p.Errorf("protocol %q error: %v\n", impl.Name, err)
|
p.DebugDetailf("Protocol %s/%d error: %v\n", impl.Name, impl.Version, err)
|
||||||
}
|
}
|
||||||
select {
|
select {
|
||||||
case p.protoErr <- err:
|
case p.protoErr <- err:
|
||||||
|
@ -385,6 +364,7 @@ func (p *Peer) closeProtocols() {
|
||||||
}
|
}
|
||||||
|
|
||||||
// writeProtoMsg sends the given message on behalf of the given named protocol.
|
// writeProtoMsg sends the given message on behalf of the given named protocol.
|
||||||
|
// this exists because of Server.Broadcast.
|
||||||
func (p *Peer) writeProtoMsg(protoName string, msg Msg) error {
|
func (p *Peer) writeProtoMsg(protoName string, msg Msg) error {
|
||||||
p.runlock.RLock()
|
p.runlock.RLock()
|
||||||
proto, ok := p.running[protoName]
|
proto, ok := p.running[protoName]
|
||||||
|
@ -396,25 +376,14 @@ func (p *Peer) writeProtoMsg(protoName string, msg Msg) error {
|
||||||
return newPeerError(errInvalidMsgCode, "code %x is out of range for protocol %q", msg.Code, protoName)
|
return newPeerError(errInvalidMsgCode, "code %x is out of range for protocol %q", msg.Code, protoName)
|
||||||
}
|
}
|
||||||
msg.Code += proto.offset
|
msg.Code += proto.offset
|
||||||
return p.writeMsg(msg, msgWriteTimeout)
|
return p.rw.WriteMsg(msg)
|
||||||
}
|
|
||||||
|
|
||||||
// writeMsg writes a message to the connection.
|
|
||||||
func (p *Peer) writeMsg(msg Msg, timeout time.Duration) error {
|
|
||||||
p.writeMu.Lock()
|
|
||||||
defer p.writeMu.Unlock()
|
|
||||||
p.conn.SetWriteDeadline(time.Now().Add(timeout))
|
|
||||||
if err := writeMsg(p.bufconn, msg); err != nil {
|
|
||||||
return newPeerError(errWrite, "%v", err)
|
|
||||||
}
|
|
||||||
return p.bufconn.Flush()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type proto struct {
|
type proto struct {
|
||||||
name string
|
name string
|
||||||
in chan Msg
|
in chan Msg
|
||||||
maxcode, offset uint64
|
maxcode, offset uint64
|
||||||
peer *Peer
|
w MsgWriter
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rw *proto) WriteMsg(msg Msg) error {
|
func (rw *proto) WriteMsg(msg Msg) error {
|
||||||
|
@ -422,11 +391,7 @@ func (rw *proto) WriteMsg(msg Msg) error {
|
||||||
return newPeerError(errInvalidMsgCode, "not handled")
|
return newPeerError(errInvalidMsgCode, "not handled")
|
||||||
}
|
}
|
||||||
msg.Code += rw.offset
|
msg.Code += rw.offset
|
||||||
return rw.peer.writeMsg(msg, msgWriteTimeout)
|
return rw.w.WriteMsg(msg)
|
||||||
}
|
|
||||||
|
|
||||||
func (rw *proto) EncodeMsg(code uint64, data ...interface{}) error {
|
|
||||||
return rw.WriteMsg(NewMsg(code, data...))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rw *proto) ReadMsg() (Msg, error) {
|
func (rw *proto) ReadMsg() (Msg, error) {
|
||||||
|
@ -437,26 +402,3 @@ func (rw *proto) ReadMsg() (Msg, error) {
|
||||||
msg.Code -= rw.offset
|
msg.Code -= rw.offset
|
||||||
return msg, nil
|
return msg, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// eofSignal wraps a reader with eof signaling. the eof channel is
|
|
||||||
// closed when the wrapped reader returns an error or when count bytes
|
|
||||||
// have been read.
|
|
||||||
//
|
|
||||||
type eofSignal struct {
|
|
||||||
wrapped io.Reader
|
|
||||||
count int64
|
|
||||||
eof chan<- struct{}
|
|
||||||
}
|
|
||||||
|
|
||||||
// note: when using eofSignal to detect whether a message payload
|
|
||||||
// has been read, Read might not be called for zero sized messages.
|
|
||||||
|
|
||||||
func (r *eofSignal) Read(buf []byte) (int, error) {
|
|
||||||
n, err := r.wrapped.Read(buf)
|
|
||||||
r.count -= int64(n)
|
|
||||||
if (err != nil || r.count <= 0) && r.eof != nil {
|
|
||||||
r.eof <- struct{}{} // tell Peer that msg has been consumed
|
|
||||||
r.eof = nil
|
|
||||||
}
|
|
||||||
return n, err
|
|
||||||
}
|
|
||||||
|
|
|
@ -12,7 +12,6 @@ const (
|
||||||
errInvalidMsgCode
|
errInvalidMsgCode
|
||||||
errInvalidMsg
|
errInvalidMsg
|
||||||
errP2PVersionMismatch
|
errP2PVersionMismatch
|
||||||
errPubkeyMissing
|
|
||||||
errPubkeyInvalid
|
errPubkeyInvalid
|
||||||
errPubkeyForbidden
|
errPubkeyForbidden
|
||||||
errProtocolBreach
|
errProtocolBreach
|
||||||
|
@ -22,20 +21,19 @@ const (
|
||||||
)
|
)
|
||||||
|
|
||||||
var errorToString = map[int]string{
|
var errorToString = map[int]string{
|
||||||
errMagicTokenMismatch: "Magic token mismatch",
|
errMagicTokenMismatch: "magic token mismatch",
|
||||||
errRead: "Read error",
|
errRead: "read error",
|
||||||
errWrite: "Write error",
|
errWrite: "write error",
|
||||||
errMisc: "Misc error",
|
errMisc: "misc error",
|
||||||
errInvalidMsgCode: "Invalid message code",
|
errInvalidMsgCode: "invalid message code",
|
||||||
errInvalidMsg: "Invalid message",
|
errInvalidMsg: "invalid message",
|
||||||
errP2PVersionMismatch: "P2P Version Mismatch",
|
errP2PVersionMismatch: "P2P Version Mismatch",
|
||||||
errPubkeyMissing: "Public key missing",
|
errPubkeyInvalid: "public key invalid",
|
||||||
errPubkeyInvalid: "Public key invalid",
|
errPubkeyForbidden: "public key forbidden",
|
||||||
errPubkeyForbidden: "Public key forbidden",
|
errProtocolBreach: "protocol Breach",
|
||||||
errProtocolBreach: "Protocol Breach",
|
errPingTimeout: "ping timeout",
|
||||||
errPingTimeout: "Ping timeout",
|
errInvalidNetworkId: "invalid network id",
|
||||||
errInvalidNetworkId: "Invalid network id",
|
errInvalidProtocolVersion: "invalid protocol version",
|
||||||
errInvalidProtocolVersion: "Invalid protocol version",
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type peerError struct {
|
type peerError struct {
|
||||||
|
@ -62,22 +60,22 @@ func (self *peerError) Error() string {
|
||||||
type DiscReason byte
|
type DiscReason byte
|
||||||
|
|
||||||
const (
|
const (
|
||||||
DiscRequested DiscReason = 0x00
|
DiscRequested DiscReason = iota
|
||||||
DiscNetworkError = 0x01
|
DiscNetworkError
|
||||||
DiscProtocolError = 0x02
|
DiscProtocolError
|
||||||
DiscUselessPeer = 0x03
|
DiscUselessPeer
|
||||||
DiscTooManyPeers = 0x04
|
DiscTooManyPeers
|
||||||
DiscAlreadyConnected = 0x05
|
DiscAlreadyConnected
|
||||||
DiscIncompatibleVersion = 0x06
|
DiscIncompatibleVersion
|
||||||
DiscInvalidIdentity = 0x07
|
DiscInvalidIdentity
|
||||||
DiscQuitting = 0x08
|
DiscQuitting
|
||||||
DiscUnexpectedIdentity = 0x09
|
DiscUnexpectedIdentity
|
||||||
DiscSelf = 0x0a
|
DiscSelf
|
||||||
DiscReadTimeout = 0x0b
|
DiscReadTimeout
|
||||||
DiscSubprotocolError = 0x10
|
DiscSubprotocolError
|
||||||
)
|
)
|
||||||
|
|
||||||
var discReasonToString = [DiscSubprotocolError + 1]string{
|
var discReasonToString = [...]string{
|
||||||
DiscRequested: "Disconnect requested",
|
DiscRequested: "Disconnect requested",
|
||||||
DiscNetworkError: "Network error",
|
DiscNetworkError: "Network error",
|
||||||
DiscProtocolError: "Breach of protocol",
|
DiscProtocolError: "Breach of protocol",
|
||||||
|
@ -117,7 +115,7 @@ func discReasonForError(err error) DiscReason {
|
||||||
switch peerError.Code {
|
switch peerError.Code {
|
||||||
case errP2PVersionMismatch:
|
case errP2PVersionMismatch:
|
||||||
return DiscIncompatibleVersion
|
return DiscIncompatibleVersion
|
||||||
case errPubkeyMissing, errPubkeyInvalid:
|
case errPubkeyInvalid:
|
||||||
return DiscInvalidIdentity
|
return DiscInvalidIdentity
|
||||||
case errPubkeyForbidden:
|
case errPubkeyForbidden:
|
||||||
return DiscUselessPeer
|
return DiscUselessPeer
|
||||||
|
@ -125,7 +123,7 @@ func discReasonForError(err error) DiscReason {
|
||||||
return DiscProtocolError
|
return DiscProtocolError
|
||||||
case errPingTimeout:
|
case errPingTimeout:
|
||||||
return DiscReadTimeout
|
return DiscReadTimeout
|
||||||
case errRead, errWrite, errMisc:
|
case errRead, errWrite:
|
||||||
return DiscNetworkError
|
return DiscNetworkError
|
||||||
default:
|
default:
|
||||||
return DiscSubprotocolError
|
return DiscSubprotocolError
|
||||||
|
|
307
p2p/peer_test.go
307
p2p/peer_test.go
|
@ -1,15 +1,17 @@
|
||||||
package p2p
|
package p2p
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
|
||||||
"bytes"
|
"bytes"
|
||||||
"encoding/hex"
|
"fmt"
|
||||||
"io"
|
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net"
|
"net"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
"sort"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/ethereum/go-ethereum/p2p/discover"
|
||||||
|
"github.com/ethereum/go-ethereum/rlp"
|
||||||
)
|
)
|
||||||
|
|
||||||
var discard = Protocol{
|
var discard = Protocol{
|
||||||
|
@ -28,17 +30,13 @@ var discard = Protocol{
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
func testPeer(protos []Protocol) (net.Conn, *Peer, <-chan error) {
|
func testPeer(noHandshake bool, protos []Protocol) (*frameRW, *Peer, <-chan DiscReason) {
|
||||||
conn1, conn2 := net.Pipe()
|
conn1, conn2 := net.Pipe()
|
||||||
peer := newPeer(conn1, protos, nil)
|
peer := newPeer(conn1, protos, "name", &discover.NodeID{}, &discover.NodeID{})
|
||||||
peer.ourID = &peerId{}
|
peer.noHandshake = noHandshake
|
||||||
peer.pubkeyHook = func(*peerAddr) error { return nil }
|
errc := make(chan DiscReason, 1)
|
||||||
errc := make(chan error, 1)
|
go func() { errc <- peer.run() }()
|
||||||
go func() {
|
return newFrameRW(conn2, msgWriteTimeout), peer, errc
|
||||||
_, err := peer.loop()
|
|
||||||
errc <- err
|
|
||||||
}()
|
|
||||||
return conn2, peer, errc
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPeerProtoReadMsg(t *testing.T) {
|
func TestPeerProtoReadMsg(t *testing.T) {
|
||||||
|
@ -49,31 +47,28 @@ func TestPeerProtoReadMsg(t *testing.T) {
|
||||||
Name: "a",
|
Name: "a",
|
||||||
Length: 5,
|
Length: 5,
|
||||||
Run: func(peer *Peer, rw MsgReadWriter) error {
|
Run: func(peer *Peer, rw MsgReadWriter) error {
|
||||||
msg, err := rw.ReadMsg()
|
if err := expectMsg(rw, 2, []uint{1}); err != nil {
|
||||||
if err != nil {
|
t.Error(err)
|
||||||
t.Errorf("read error: %v", err)
|
|
||||||
}
|
}
|
||||||
if msg.Code != 2 {
|
if err := expectMsg(rw, 3, []uint{2}); err != nil {
|
||||||
t.Errorf("incorrect msg code %d relayed to protocol", msg.Code)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
data, err := ioutil.ReadAll(msg.Payload)
|
if err := expectMsg(rw, 4, []uint{3}); err != nil {
|
||||||
if err != nil {
|
t.Error(err)
|
||||||
t.Errorf("payload read error: %v", err)
|
|
||||||
}
|
|
||||||
expdata, _ := hex.DecodeString("0183303030")
|
|
||||||
if !bytes.Equal(expdata, data) {
|
|
||||||
t.Errorf("incorrect msg data %x", data)
|
|
||||||
}
|
}
|
||||||
close(done)
|
close(done)
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
net, peer, errc := testPeer([]Protocol{proto})
|
rw, peer, errc := testPeer(true, []Protocol{proto})
|
||||||
defer net.Close()
|
defer rw.Close()
|
||||||
peer.startSubprotocols([]Cap{proto.cap()})
|
peer.startSubprotocols([]Cap{proto.cap()})
|
||||||
|
|
||||||
writeMsg(net, NewMsg(18, 1, "000"))
|
EncodeMsg(rw, baseProtocolLength+2, 1)
|
||||||
|
EncodeMsg(rw, baseProtocolLength+3, 2)
|
||||||
|
EncodeMsg(rw, baseProtocolLength+4, 3)
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-done:
|
case <-done:
|
||||||
case err := <-errc:
|
case err := <-errc:
|
||||||
|
@ -105,11 +100,11 @@ func TestPeerProtoReadLargeMsg(t *testing.T) {
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
net, peer, errc := testPeer([]Protocol{proto})
|
rw, peer, errc := testPeer(true, []Protocol{proto})
|
||||||
defer net.Close()
|
defer rw.Close()
|
||||||
peer.startSubprotocols([]Cap{proto.cap()})
|
peer.startSubprotocols([]Cap{proto.cap()})
|
||||||
|
|
||||||
writeMsg(net, NewMsg(18, make([]byte, msgsize)))
|
EncodeMsg(rw, 18, make([]byte, msgsize))
|
||||||
select {
|
select {
|
||||||
case <-done:
|
case <-done:
|
||||||
case err := <-errc:
|
case err := <-errc:
|
||||||
|
@ -135,32 +130,20 @@ func TestPeerProtoEncodeMsg(t *testing.T) {
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
net, peer, _ := testPeer([]Protocol{proto})
|
rw, peer, _ := testPeer(true, []Protocol{proto})
|
||||||
defer net.Close()
|
defer rw.Close()
|
||||||
peer.startSubprotocols([]Cap{proto.cap()})
|
peer.startSubprotocols([]Cap{proto.cap()})
|
||||||
|
|
||||||
bufr := bufio.NewReader(net)
|
if err := expectMsg(rw, 17, []string{"foo", "bar"}); err != nil {
|
||||||
msg, err := readMsg(bufr)
|
t.Error(err)
|
||||||
if err != nil {
|
|
||||||
t.Errorf("read error: %v", err)
|
|
||||||
}
|
|
||||||
if msg.Code != 17 {
|
|
||||||
t.Errorf("incorrect message code: got %d, expected %d", msg.Code, 17)
|
|
||||||
}
|
|
||||||
var data []string
|
|
||||||
if err := msg.Decode(&data); err != nil {
|
|
||||||
t.Errorf("payload decode error: %v", err)
|
|
||||||
}
|
|
||||||
if !reflect.DeepEqual(data, []string{"foo", "bar"}) {
|
|
||||||
t.Errorf("payload RLP mismatch, got %#v, want %#v", data, []string{"foo", "bar"})
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPeerWrite(t *testing.T) {
|
func TestPeerWriteForBroadcast(t *testing.T) {
|
||||||
defer testlog(t).detach()
|
defer testlog(t).detach()
|
||||||
|
|
||||||
net, peer, peerErr := testPeer([]Protocol{discard})
|
rw, peer, peerErr := testPeer(true, []Protocol{discard})
|
||||||
defer net.Close()
|
defer rw.Close()
|
||||||
peer.startSubprotocols([]Cap{discard.cap()})
|
peer.startSubprotocols([]Cap{discard.cap()})
|
||||||
|
|
||||||
// test write errors
|
// test write errors
|
||||||
|
@ -176,18 +159,13 @@ func TestPeerWrite(t *testing.T) {
|
||||||
// setup for reading the message on the other end
|
// setup for reading the message on the other end
|
||||||
read := make(chan struct{})
|
read := make(chan struct{})
|
||||||
go func() {
|
go func() {
|
||||||
bufr := bufio.NewReader(net)
|
if err := expectMsg(rw, 16, nil); err != nil {
|
||||||
msg, err := readMsg(bufr)
|
t.Error()
|
||||||
if err != nil {
|
|
||||||
t.Errorf("read error: %v", err)
|
|
||||||
} else if msg.Code != 16 {
|
|
||||||
t.Errorf("wrong code, got %d, expected %d", msg.Code, 16)
|
|
||||||
}
|
}
|
||||||
msg.Discard()
|
|
||||||
close(read)
|
close(read)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// test succcessful write
|
// test successful write
|
||||||
if err := peer.writeProtoMsg("discard", NewMsg(0)); err != nil {
|
if err := peer.writeProtoMsg("discard", NewMsg(0)); err != nil {
|
||||||
t.Errorf("expect no error for known protocol: %v", err)
|
t.Errorf("expect no error for known protocol: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -198,104 +176,153 @@ func TestPeerWrite(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPeerActivity(t *testing.T) {
|
func TestPeerPing(t *testing.T) {
|
||||||
// shorten inactivityTimeout while this test is running
|
defer testlog(t).detach()
|
||||||
oldT := inactivityTimeout
|
|
||||||
defer func() { inactivityTimeout = oldT }()
|
|
||||||
inactivityTimeout = 20 * time.Millisecond
|
|
||||||
|
|
||||||
net, peer, peerErr := testPeer([]Protocol{discard})
|
rw, _, _ := testPeer(true, nil)
|
||||||
defer net.Close()
|
defer rw.Close()
|
||||||
peer.startSubprotocols([]Cap{discard.cap()})
|
if err := EncodeMsg(rw, pingMsg); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if err := expectMsg(rw, pongMsg, nil); err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
sub := peer.activity.Subscribe(time.Time{})
|
func TestPeerDisconnect(t *testing.T) {
|
||||||
defer sub.Unsubscribe()
|
defer testlog(t).detach()
|
||||||
|
|
||||||
for i := 0; i < 6; i++ {
|
rw, _, disc := testPeer(true, nil)
|
||||||
writeMsg(net, NewMsg(16))
|
defer rw.Close()
|
||||||
|
if err := EncodeMsg(rw, discMsg, DiscQuitting); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if err := expectMsg(rw, discMsg, []interface{}{DiscRequested}); err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
rw.Close() // make test end faster
|
||||||
|
if reason := <-disc; reason != DiscRequested {
|
||||||
|
t.Errorf("run returned wrong reason: got %v, want %v", reason, DiscRequested)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPeerHandshake(t *testing.T) {
|
||||||
|
defer testlog(t).detach()
|
||||||
|
|
||||||
|
// remote has two matching protocols: a and c
|
||||||
|
remote := NewPeer(randomID(), "", []Cap{{"a", 1}, {"b", 999}, {"c", 3}})
|
||||||
|
remoteID := randomID()
|
||||||
|
remote.ourID = &remoteID
|
||||||
|
remote.ourName = "remote peer"
|
||||||
|
|
||||||
|
start := make(chan string)
|
||||||
|
stop := make(chan struct{})
|
||||||
|
run := func(p *Peer, rw MsgReadWriter) error {
|
||||||
|
name := rw.(*proto).name
|
||||||
|
if name != "a" && name != "c" {
|
||||||
|
t.Errorf("protocol %q should not be started", name)
|
||||||
|
} else {
|
||||||
|
start <- name
|
||||||
|
}
|
||||||
|
<-stop
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
protocols := []Protocol{
|
||||||
|
{Name: "a", Version: 1, Length: 1, Run: run},
|
||||||
|
{Name: "b", Version: 2, Length: 1, Run: run},
|
||||||
|
{Name: "c", Version: 3, Length: 1, Run: run},
|
||||||
|
{Name: "d", Version: 4, Length: 1, Run: run},
|
||||||
|
}
|
||||||
|
rw, p, disc := testPeer(false, protocols)
|
||||||
|
p.remoteID = remote.ourID
|
||||||
|
defer rw.Close()
|
||||||
|
|
||||||
|
// run the handshake
|
||||||
|
remoteProtocols := []Protocol{protocols[0], protocols[2]}
|
||||||
|
if err := writeProtocolHandshake(rw, "remote peer", remoteID, remoteProtocols); err != nil {
|
||||||
|
t.Fatalf("handshake write error: %v", err)
|
||||||
|
}
|
||||||
|
if err := readProtocolHandshake(remote, rw); err != nil {
|
||||||
|
t.Fatalf("handshake read error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// check that all protocols have been started
|
||||||
|
var started []string
|
||||||
|
for i := 0; i < 2; i++ {
|
||||||
select {
|
select {
|
||||||
case <-sub.Chan():
|
case name := <-start:
|
||||||
case <-time.After(inactivityTimeout / 2):
|
started = append(started, name)
|
||||||
t.Fatal("no event within ", inactivityTimeout/2)
|
case <-time.After(100 * time.Millisecond):
|
||||||
case err := <-peerErr:
|
|
||||||
t.Fatal("peer error", err)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
sort.Strings(started)
|
||||||
select {
|
if !reflect.DeepEqual(started, []string{"a", "c"}) {
|
||||||
case <-time.After(inactivityTimeout * 2):
|
t.Errorf("wrong protocols started: %v", started)
|
||||||
case <-sub.Chan():
|
|
||||||
t.Fatal("got activity event while connection was inactive")
|
|
||||||
case err := <-peerErr:
|
|
||||||
t.Fatal("peer error", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// check that metadata has been set
|
||||||
|
if p.ID() != remoteID {
|
||||||
|
t.Errorf("peer has wrong node ID: got %v, want %v", p.ID(), remoteID)
|
||||||
|
}
|
||||||
|
if p.Name() != remote.ourName {
|
||||||
|
t.Errorf("peer has wrong node name: got %q, want %q", p.Name(), remote.ourName)
|
||||||
|
}
|
||||||
|
|
||||||
|
close(stop)
|
||||||
|
expectMsg(rw, discMsg, nil)
|
||||||
|
t.Logf("disc reason: %v", <-disc)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNewPeer(t *testing.T) {
|
func TestNewPeer(t *testing.T) {
|
||||||
|
name := "nodename"
|
||||||
caps := []Cap{{"foo", 2}, {"bar", 3}}
|
caps := []Cap{{"foo", 2}, {"bar", 3}}
|
||||||
id := &peerId{}
|
id := randomID()
|
||||||
p := NewPeer(id, caps)
|
p := NewPeer(id, name, caps)
|
||||||
|
if p.ID() != id {
|
||||||
|
t.Errorf("ID mismatch: got %v, expected %v", p.ID(), id)
|
||||||
|
}
|
||||||
|
if p.Name() != name {
|
||||||
|
t.Errorf("Name mismatch: got %v, expected %v", p.Name(), name)
|
||||||
|
}
|
||||||
if !reflect.DeepEqual(p.Caps(), caps) {
|
if !reflect.DeepEqual(p.Caps(), caps) {
|
||||||
t.Errorf("Caps mismatch: got %v, expected %v", p.Caps(), caps)
|
t.Errorf("Caps mismatch: got %v, expected %v", p.Caps(), caps)
|
||||||
}
|
}
|
||||||
if p.Identity() != id {
|
|
||||||
t.Errorf("Identity mismatch: got %v, expected %v", p.Identity(), id)
|
p.Disconnect(DiscAlreadyConnected) // Should not hang
|
||||||
}
|
|
||||||
// Should not hang.
|
|
||||||
p.Disconnect(DiscAlreadyConnected)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestEOFSignal(t *testing.T) {
|
// expectMsg reads a message from r and verifies that its
|
||||||
rb := make([]byte, 10)
|
// code and encoded RLP content match the provided values.
|
||||||
|
// If content is nil, the payload is discarded and not verified.
|
||||||
|
func expectMsg(r MsgReader, code uint64, content interface{}) error {
|
||||||
|
msg, err := r.ReadMsg()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if msg.Code != code {
|
||||||
|
return fmt.Errorf("message code mismatch: got %d, expected %d", msg.Code, code)
|
||||||
|
}
|
||||||
|
if content == nil {
|
||||||
|
return msg.Discard()
|
||||||
|
} else {
|
||||||
|
contentEnc, err := rlp.EncodeToBytes(content)
|
||||||
|
if err != nil {
|
||||||
|
panic("content encode error: " + err.Error())
|
||||||
|
}
|
||||||
|
// skip over list header in encoded value. this is temporary.
|
||||||
|
contentEncR := bytes.NewReader(contentEnc)
|
||||||
|
if k, _, err := rlp.NewStream(contentEncR).Kind(); k != rlp.List || err != nil {
|
||||||
|
panic("content must encode as RLP list")
|
||||||
|
}
|
||||||
|
contentEnc = contentEnc[len(contentEnc)-contentEncR.Len():]
|
||||||
|
|
||||||
// empty reader
|
actualContent, err := ioutil.ReadAll(msg.Payload)
|
||||||
eof := make(chan struct{}, 1)
|
if err != nil {
|
||||||
sig := &eofSignal{new(bytes.Buffer), 0, eof}
|
return err
|
||||||
if n, err := sig.Read(rb); n != 0 || err != io.EOF {
|
}
|
||||||
t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
|
if !bytes.Equal(actualContent, contentEnc) {
|
||||||
}
|
return fmt.Errorf("message payload mismatch:\ngot: %x\nwant: %x", actualContent, contentEnc)
|
||||||
select {
|
}
|
||||||
case <-eof:
|
|
||||||
default:
|
|
||||||
t.Error("EOF chan not signaled")
|
|
||||||
}
|
|
||||||
|
|
||||||
// count before error
|
|
||||||
eof = make(chan struct{}, 1)
|
|
||||||
sig = &eofSignal{bytes.NewBufferString("aaaaaaaa"), 4, eof}
|
|
||||||
if n, err := sig.Read(rb); n != 8 || err != nil {
|
|
||||||
t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
|
|
||||||
}
|
|
||||||
select {
|
|
||||||
case <-eof:
|
|
||||||
default:
|
|
||||||
t.Error("EOF chan not signaled")
|
|
||||||
}
|
|
||||||
|
|
||||||
// error before count
|
|
||||||
eof = make(chan struct{}, 1)
|
|
||||||
sig = &eofSignal{bytes.NewBufferString("aaaa"), 999, eof}
|
|
||||||
if n, err := sig.Read(rb); n != 4 || err != nil {
|
|
||||||
t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
|
|
||||||
}
|
|
||||||
if n, err := sig.Read(rb); n != 0 || err != io.EOF {
|
|
||||||
t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
|
|
||||||
}
|
|
||||||
select {
|
|
||||||
case <-eof:
|
|
||||||
default:
|
|
||||||
t.Error("EOF chan not signaled")
|
|
||||||
}
|
|
||||||
|
|
||||||
// no signal if neither occurs
|
|
||||||
eof = make(chan struct{}, 1)
|
|
||||||
sig = &eofSignal{bytes.NewBufferString("aaaaaaaaaaaaaaaaaaaaa"), 999, eof}
|
|
||||||
if n, err := sig.Read(rb); n != 10 || err != nil {
|
|
||||||
t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
|
|
||||||
}
|
|
||||||
select {
|
|
||||||
case <-eof:
|
|
||||||
t.Error("unexpected EOF signal")
|
|
||||||
default:
|
|
||||||
}
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
244
p2p/protocol.go
244
p2p/protocol.go
|
@ -1,10 +1,5 @@
|
||||||
package p2p
|
package p2p
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Protocol represents a P2P subprotocol implementation.
|
// Protocol represents a P2P subprotocol implementation.
|
||||||
type Protocol struct {
|
type Protocol struct {
|
||||||
// Name should contain the official protocol name,
|
// Name should contain the official protocol name,
|
||||||
|
@ -32,38 +27,6 @@ func (p Protocol) cap() Cap {
|
||||||
return Cap{p.Name, p.Version}
|
return Cap{p.Name, p.Version}
|
||||||
}
|
}
|
||||||
|
|
||||||
const (
|
|
||||||
baseProtocolVersion = 2
|
|
||||||
baseProtocolLength = uint64(16)
|
|
||||||
baseProtocolMaxMsgSize = 10 * 1024 * 1024
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
// devp2p message codes
|
|
||||||
handshakeMsg = 0x00
|
|
||||||
discMsg = 0x01
|
|
||||||
pingMsg = 0x02
|
|
||||||
pongMsg = 0x03
|
|
||||||
getPeersMsg = 0x04
|
|
||||||
peersMsg = 0x05
|
|
||||||
)
|
|
||||||
|
|
||||||
// handshake is the structure of a handshake list.
|
|
||||||
type handshake struct {
|
|
||||||
Version uint64
|
|
||||||
ID string
|
|
||||||
Caps []Cap
|
|
||||||
ListenPort uint64
|
|
||||||
NodeID []byte
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *handshake) String() string {
|
|
||||||
return h.ID
|
|
||||||
}
|
|
||||||
func (h *handshake) Pubkey() []byte {
|
|
||||||
return h.NodeID
|
|
||||||
}
|
|
||||||
|
|
||||||
// Cap is the structure of a peer capability.
|
// Cap is the structure of a peer capability.
|
||||||
type Cap struct {
|
type Cap struct {
|
||||||
Name string
|
Name string
|
||||||
|
@ -79,210 +42,3 @@ type capsByName []Cap
|
||||||
func (cs capsByName) Len() int { return len(cs) }
|
func (cs capsByName) Len() int { return len(cs) }
|
||||||
func (cs capsByName) Less(i, j int) bool { return cs[i].Name < cs[j].Name }
|
func (cs capsByName) Less(i, j int) bool { return cs[i].Name < cs[j].Name }
|
||||||
func (cs capsByName) Swap(i, j int) { cs[i], cs[j] = cs[j], cs[i] }
|
func (cs capsByName) Swap(i, j int) { cs[i], cs[j] = cs[j], cs[i] }
|
||||||
|
|
||||||
type baseProtocol struct {
|
|
||||||
rw MsgReadWriter
|
|
||||||
peer *Peer
|
|
||||||
}
|
|
||||||
|
|
||||||
func runBaseProtocol(peer *Peer, rw MsgReadWriter) error {
|
|
||||||
bp := &baseProtocol{rw, peer}
|
|
||||||
errc := make(chan error, 1)
|
|
||||||
go func() { errc <- rw.WriteMsg(bp.handshakeMsg()) }()
|
|
||||||
if err := bp.readHandshake(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
// handle write error
|
|
||||||
if err := <-errc; err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
// run main loop
|
|
||||||
go func() {
|
|
||||||
for {
|
|
||||||
if err := bp.handle(rw); err != nil {
|
|
||||||
errc <- err
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
return bp.loop(errc)
|
|
||||||
}
|
|
||||||
|
|
||||||
var pingTimeout = 2 * time.Second
|
|
||||||
|
|
||||||
func (bp *baseProtocol) loop(quit <-chan error) error {
|
|
||||||
ping := time.NewTimer(pingTimeout)
|
|
||||||
activity := bp.peer.activity.Subscribe(time.Time{})
|
|
||||||
lastActive := time.Time{}
|
|
||||||
defer ping.Stop()
|
|
||||||
defer activity.Unsubscribe()
|
|
||||||
|
|
||||||
getPeersTick := time.NewTicker(10 * time.Second)
|
|
||||||
defer getPeersTick.Stop()
|
|
||||||
err := EncodeMsg(bp.rw, getPeersMsg)
|
|
||||||
|
|
||||||
for err == nil {
|
|
||||||
select {
|
|
||||||
case err = <-quit:
|
|
||||||
return err
|
|
||||||
case <-getPeersTick.C:
|
|
||||||
err = EncodeMsg(bp.rw, getPeersMsg)
|
|
||||||
case event := <-activity.Chan():
|
|
||||||
ping.Reset(pingTimeout)
|
|
||||||
lastActive = event.(time.Time)
|
|
||||||
case t := <-ping.C:
|
|
||||||
if lastActive.Add(pingTimeout * 2).Before(t) {
|
|
||||||
err = newPeerError(errPingTimeout, "")
|
|
||||||
} else if lastActive.Add(pingTimeout).Before(t) {
|
|
||||||
err = EncodeMsg(bp.rw, pingMsg)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (bp *baseProtocol) handle(rw MsgReadWriter) error {
|
|
||||||
msg, err := rw.ReadMsg()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if msg.Size > baseProtocolMaxMsgSize {
|
|
||||||
return newPeerError(errMisc, "message too big")
|
|
||||||
}
|
|
||||||
// make sure that the payload has been fully consumed
|
|
||||||
defer msg.Discard()
|
|
||||||
|
|
||||||
switch msg.Code {
|
|
||||||
case handshakeMsg:
|
|
||||||
return newPeerError(errProtocolBreach, "extra handshake received")
|
|
||||||
|
|
||||||
case discMsg:
|
|
||||||
var reason [1]DiscReason
|
|
||||||
if err := msg.Decode(&reason); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return discRequestedError(reason[0])
|
|
||||||
|
|
||||||
case pingMsg:
|
|
||||||
return EncodeMsg(bp.rw, pongMsg)
|
|
||||||
|
|
||||||
case pongMsg:
|
|
||||||
|
|
||||||
case getPeersMsg:
|
|
||||||
peers := bp.peerList()
|
|
||||||
// this is dangerous. the spec says that we should _delay_
|
|
||||||
// sending the response if no new information is available.
|
|
||||||
// this means that would need to send a response later when
|
|
||||||
// new peers become available.
|
|
||||||
//
|
|
||||||
// TODO: add event mechanism to notify baseProtocol for new peers
|
|
||||||
if len(peers) > 0 {
|
|
||||||
return EncodeMsg(bp.rw, peersMsg, peers...)
|
|
||||||
}
|
|
||||||
|
|
||||||
case peersMsg:
|
|
||||||
var peers []*peerAddr
|
|
||||||
if err := msg.Decode(&peers); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
for _, addr := range peers {
|
|
||||||
bp.peer.Debugf("received peer suggestion: %v", addr)
|
|
||||||
bp.peer.newPeerAddr <- addr
|
|
||||||
}
|
|
||||||
|
|
||||||
default:
|
|
||||||
return newPeerError(errInvalidMsgCode, "unknown message code %v", msg.Code)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (bp *baseProtocol) readHandshake() error {
|
|
||||||
// read and handle remote handshake
|
|
||||||
msg, err := bp.rw.ReadMsg()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if msg.Code != handshakeMsg {
|
|
||||||
return newPeerError(errProtocolBreach, "first message must be handshake, got %x", msg.Code)
|
|
||||||
}
|
|
||||||
if msg.Size > baseProtocolMaxMsgSize {
|
|
||||||
return newPeerError(errMisc, "message too big")
|
|
||||||
}
|
|
||||||
var hs handshake
|
|
||||||
if err := msg.Decode(&hs); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
// validate handshake info
|
|
||||||
if hs.Version != baseProtocolVersion {
|
|
||||||
return newPeerError(errP2PVersionMismatch, "Require protocol %d, received %d\n",
|
|
||||||
baseProtocolVersion, hs.Version)
|
|
||||||
}
|
|
||||||
if len(hs.NodeID) == 0 {
|
|
||||||
return newPeerError(errPubkeyMissing, "")
|
|
||||||
}
|
|
||||||
if len(hs.NodeID) != 64 {
|
|
||||||
return newPeerError(errPubkeyInvalid, "require 512 bit, got %v", len(hs.NodeID)*8)
|
|
||||||
}
|
|
||||||
if da := bp.peer.dialAddr; da != nil {
|
|
||||||
// verify that the peer we wanted to connect to
|
|
||||||
// actually holds the target public key.
|
|
||||||
if da.Pubkey != nil && !bytes.Equal(da.Pubkey, hs.NodeID) {
|
|
||||||
return newPeerError(errPubkeyForbidden, "dial address pubkey mismatch")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
pa := newPeerAddr(bp.peer.conn.RemoteAddr(), hs.NodeID)
|
|
||||||
if err := bp.peer.pubkeyHook(pa); err != nil {
|
|
||||||
return newPeerError(errPubkeyForbidden, "%v", err)
|
|
||||||
}
|
|
||||||
// TODO: remove Caps with empty name
|
|
||||||
var addr *peerAddr
|
|
||||||
if hs.ListenPort != 0 {
|
|
||||||
addr = newPeerAddr(bp.peer.conn.RemoteAddr(), hs.NodeID)
|
|
||||||
addr.Port = hs.ListenPort
|
|
||||||
}
|
|
||||||
bp.peer.setHandshakeInfo(&hs, addr, hs.Caps)
|
|
||||||
bp.peer.startSubprotocols(hs.Caps)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (bp *baseProtocol) handshakeMsg() Msg {
|
|
||||||
var (
|
|
||||||
port uint64
|
|
||||||
caps []interface{}
|
|
||||||
)
|
|
||||||
if bp.peer.ourListenAddr != nil {
|
|
||||||
port = bp.peer.ourListenAddr.Port
|
|
||||||
}
|
|
||||||
for _, proto := range bp.peer.protocols {
|
|
||||||
caps = append(caps, proto.cap())
|
|
||||||
}
|
|
||||||
return NewMsg(handshakeMsg,
|
|
||||||
baseProtocolVersion,
|
|
||||||
bp.peer.ourID.String(),
|
|
||||||
caps,
|
|
||||||
port,
|
|
||||||
bp.peer.ourID.Pubkey()[1:],
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (bp *baseProtocol) peerList() []interface{} {
|
|
||||||
peers := bp.peer.otherPeers()
|
|
||||||
ds := make([]interface{}, 0, len(peers))
|
|
||||||
for _, p := range peers {
|
|
||||||
p.infolock.Lock()
|
|
||||||
addr := p.listenAddr
|
|
||||||
p.infolock.Unlock()
|
|
||||||
// filter out this peer and peers that are not listening or
|
|
||||||
// have not completed the handshake.
|
|
||||||
// TODO: track previously sent peers and exclude them as well.
|
|
||||||
if p == bp.peer || addr == nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
ds = append(ds, addr)
|
|
||||||
}
|
|
||||||
ourAddr := bp.peer.ourListenAddr
|
|
||||||
if ourAddr != nil && !ourAddr.IP.IsLoopback() && !ourAddr.IP.IsUnspecified() {
|
|
||||||
ds = append(ds, ourAddr)
|
|
||||||
}
|
|
||||||
return ds
|
|
||||||
}
|
|
||||||
|
|
|
@ -1,158 +0,0 @@
|
||||||
package p2p
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
"reflect"
|
|
||||||
"sync"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/ethereum/go-ethereum/crypto"
|
|
||||||
)
|
|
||||||
|
|
||||||
type peerId struct {
|
|
||||||
pubkey []byte
|
|
||||||
}
|
|
||||||
|
|
||||||
func (self *peerId) String() string {
|
|
||||||
return fmt.Sprintf("test peer %x", self.Pubkey()[:4])
|
|
||||||
}
|
|
||||||
|
|
||||||
func (self *peerId) Pubkey() (pubkey []byte) {
|
|
||||||
pubkey = self.pubkey
|
|
||||||
if len(pubkey) == 0 {
|
|
||||||
pubkey = crypto.GenerateNewKeyPair().PublicKey
|
|
||||||
self.pubkey = pubkey
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func newTestPeer() (peer *Peer) {
|
|
||||||
peer = NewPeer(&peerId{}, []Cap{})
|
|
||||||
peer.pubkeyHook = func(*peerAddr) error { return nil }
|
|
||||||
peer.ourID = &peerId{}
|
|
||||||
peer.listenAddr = &peerAddr{}
|
|
||||||
peer.otherPeers = func() []*Peer { return nil }
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestBaseProtocolPeers(t *testing.T) {
|
|
||||||
peerList := []*peerAddr{
|
|
||||||
{IP: net.ParseIP("1.2.3.4"), Port: 2222, Pubkey: []byte{}},
|
|
||||||
{IP: net.ParseIP("5.6.7.8"), Port: 3333, Pubkey: []byte{}},
|
|
||||||
}
|
|
||||||
listenAddr := &peerAddr{IP: net.ParseIP("1.3.5.7"), Port: 1111, Pubkey: []byte{}}
|
|
||||||
rw1, rw2 := MsgPipe()
|
|
||||||
defer rw1.Close()
|
|
||||||
wg := new(sync.WaitGroup)
|
|
||||||
|
|
||||||
// run matcher, close pipe when addresses have arrived
|
|
||||||
numPeers := len(peerList) + 1
|
|
||||||
addrChan := make(chan *peerAddr)
|
|
||||||
wg.Add(1)
|
|
||||||
go func() {
|
|
||||||
i := 0
|
|
||||||
for got := range addrChan {
|
|
||||||
var want *peerAddr
|
|
||||||
switch {
|
|
||||||
case i < len(peerList):
|
|
||||||
want = peerList[i]
|
|
||||||
case i == len(peerList):
|
|
||||||
want = listenAddr // listenAddr should be the last thing sent
|
|
||||||
}
|
|
||||||
t.Logf("got peer %d/%d: %v", i+1, numPeers, got)
|
|
||||||
if !reflect.DeepEqual(want, got) {
|
|
||||||
t.Errorf("mismatch: got %+v, want %+v", got, want)
|
|
||||||
}
|
|
||||||
i++
|
|
||||||
if i == numPeers {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if i != numPeers {
|
|
||||||
t.Errorf("wrong number of peers received: got %d, want %d", i, numPeers)
|
|
||||||
}
|
|
||||||
rw1.Close()
|
|
||||||
wg.Done()
|
|
||||||
}()
|
|
||||||
|
|
||||||
// run first peer (in background)
|
|
||||||
peer1 := newTestPeer()
|
|
||||||
peer1.ourListenAddr = listenAddr
|
|
||||||
peer1.otherPeers = func() []*Peer {
|
|
||||||
pl := make([]*Peer, len(peerList))
|
|
||||||
for i, addr := range peerList {
|
|
||||||
pl[i] = &Peer{listenAddr: addr}
|
|
||||||
}
|
|
||||||
return pl
|
|
||||||
}
|
|
||||||
wg.Add(1)
|
|
||||||
go func() {
|
|
||||||
runBaseProtocol(peer1, rw1)
|
|
||||||
wg.Done()
|
|
||||||
}()
|
|
||||||
|
|
||||||
// run second peer
|
|
||||||
peer2 := newTestPeer()
|
|
||||||
peer2.newPeerAddr = addrChan // feed peer suggestions into matcher
|
|
||||||
if err := runBaseProtocol(peer2, rw2); err != ErrPipeClosed {
|
|
||||||
t.Errorf("peer2 terminated with unexpected error: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// terminate matcher
|
|
||||||
close(addrChan)
|
|
||||||
wg.Wait()
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestBaseProtocolDisconnect(t *testing.T) {
|
|
||||||
peer := NewPeer(&peerId{}, nil)
|
|
||||||
peer.ourID = &peerId{}
|
|
||||||
peer.pubkeyHook = func(*peerAddr) error { return nil }
|
|
||||||
|
|
||||||
rw1, rw2 := MsgPipe()
|
|
||||||
done := make(chan struct{})
|
|
||||||
go func() {
|
|
||||||
if err := expectMsg(rw2, handshakeMsg); err != nil {
|
|
||||||
t.Error(err)
|
|
||||||
}
|
|
||||||
err := EncodeMsg(rw2, handshakeMsg,
|
|
||||||
baseProtocolVersion,
|
|
||||||
"",
|
|
||||||
[]interface{}{},
|
|
||||||
0,
|
|
||||||
make([]byte, 64),
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
t.Error(err)
|
|
||||||
}
|
|
||||||
if err := expectMsg(rw2, getPeersMsg); err != nil {
|
|
||||||
t.Error(err)
|
|
||||||
}
|
|
||||||
if err := EncodeMsg(rw2, discMsg, DiscQuitting); err != nil {
|
|
||||||
t.Error(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
close(done)
|
|
||||||
}()
|
|
||||||
|
|
||||||
if err := runBaseProtocol(peer, rw1); err == nil {
|
|
||||||
t.Errorf("base protocol returned without error")
|
|
||||||
} else if reason, ok := err.(discRequestedError); !ok || reason != DiscQuitting {
|
|
||||||
t.Errorf("base protocol returned wrong error: %v", err)
|
|
||||||
}
|
|
||||||
<-done
|
|
||||||
}
|
|
||||||
|
|
||||||
func expectMsg(r MsgReader, code uint64) error {
|
|
||||||
msg, err := r.ReadMsg()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if err := msg.Discard(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if msg.Code != code {
|
|
||||||
return fmt.Errorf("wrong message code: got %d, expected %d", msg.Code, code)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
420
p2p/server.go
420
p2p/server.go
|
@ -2,37 +2,56 @@ package p2p
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"crypto/ecdsa"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
"net"
|
"net"
|
||||||
|
"runtime"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/ethereum/go-ethereum/logger"
|
"github.com/ethereum/go-ethereum/logger"
|
||||||
|
"github.com/ethereum/go-ethereum/p2p/discover"
|
||||||
|
"github.com/ethereum/go-ethereum/p2p/nat"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
outboundAddressPoolSize = 500
|
handshakeTimeout = 5 * time.Second
|
||||||
defaultDialTimeout = 10 * time.Second
|
defaultDialTimeout = 10 * time.Second
|
||||||
portMappingUpdateInterval = 15 * time.Minute
|
refreshPeersInterval = 30 * time.Second
|
||||||
portMappingTimeout = 20 * time.Minute
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var srvlog = logger.NewLogger("P2P Server")
|
var srvlog = logger.NewLogger("P2P Server")
|
||||||
|
|
||||||
|
// MakeName creates a node name that follows the ethereum convention
|
||||||
|
// for such names. It adds the operation system name and Go runtime version
|
||||||
|
// the name.
|
||||||
|
func MakeName(name, version string) string {
|
||||||
|
return fmt.Sprintf("%s/v%s/%s/%s", name, version, runtime.GOOS, runtime.Version())
|
||||||
|
}
|
||||||
|
|
||||||
// Server manages all peer connections.
|
// Server manages all peer connections.
|
||||||
//
|
//
|
||||||
// The fields of Server are used as configuration parameters.
|
// The fields of Server are used as configuration parameters.
|
||||||
// You should set them before starting the Server. Fields may not be
|
// You should set them before starting the Server. Fields may not be
|
||||||
// modified while the server is running.
|
// modified while the server is running.
|
||||||
type Server struct {
|
type Server struct {
|
||||||
// This field must be set to a valid client identity.
|
// This field must be set to a valid secp256k1 private key.
|
||||||
Identity ClientIdentity
|
PrivateKey *ecdsa.PrivateKey
|
||||||
|
|
||||||
// MaxPeers is the maximum number of peers that can be
|
// MaxPeers is the maximum number of peers that can be
|
||||||
// connected. It must be greater than zero.
|
// connected. It must be greater than zero.
|
||||||
MaxPeers int
|
MaxPeers int
|
||||||
|
|
||||||
|
// Name sets the node name of this server.
|
||||||
|
// Use MakeName to create a name that follows existing conventions.
|
||||||
|
Name string
|
||||||
|
|
||||||
|
// Bootstrap nodes are used to establish connectivity
|
||||||
|
// with the rest of the network.
|
||||||
|
BootstrapNodes []*discover.Node
|
||||||
|
|
||||||
// Protocols should contain the protocols supported
|
// Protocols should contain the protocols supported
|
||||||
// by the server. Matching protocols are launched for
|
// by the server. Matching protocols are launched for
|
||||||
// each peer.
|
// each peer.
|
||||||
|
@ -53,7 +72,7 @@ type Server struct {
|
||||||
// If set to a non-nil value, the given NAT port mapper
|
// If set to a non-nil value, the given NAT port mapper
|
||||||
// is used to make the listening port available to the
|
// is used to make the listening port available to the
|
||||||
// Internet.
|
// Internet.
|
||||||
NAT NAT
|
NAT nat.Interface
|
||||||
|
|
||||||
// If Dialer is set to a non-nil value, the given Dialer
|
// If Dialer is set to a non-nil value, the given Dialer
|
||||||
// is used to dial outbound peer connections.
|
// is used to dial outbound peer connections.
|
||||||
|
@ -62,35 +81,26 @@ type Server struct {
|
||||||
// If NoDial is true, the server will not dial any peers.
|
// If NoDial is true, the server will not dial any peers.
|
||||||
NoDial bool
|
NoDial bool
|
||||||
|
|
||||||
// Hook for testing. This is useful because we can inhibit
|
// Hooks for testing. These are useful because we can inhibit
|
||||||
// the whole protocol stack.
|
// the whole protocol stack.
|
||||||
newPeerFunc peerFunc
|
handshakeFunc
|
||||||
|
newPeerHook
|
||||||
|
|
||||||
lock sync.RWMutex
|
lock sync.RWMutex
|
||||||
running bool
|
running bool
|
||||||
listener net.Listener
|
listener net.Listener
|
||||||
laddr *net.TCPAddr // real listen addr
|
peers map[discover.NodeID]*Peer
|
||||||
peers []*Peer
|
|
||||||
peerSlots chan int
|
|
||||||
peerCount int
|
|
||||||
|
|
||||||
quit chan struct{}
|
ntab *discover.Table
|
||||||
wg sync.WaitGroup
|
|
||||||
peerConnect chan *peerAddr
|
quit chan struct{}
|
||||||
peerDisconnect chan *Peer
|
loopWG sync.WaitGroup // {dial,listen,nat}Loop
|
||||||
|
peerWG sync.WaitGroup // active peer goroutines
|
||||||
|
peerConnect chan *discover.Node
|
||||||
}
|
}
|
||||||
|
|
||||||
// NAT is implemented by NAT traversal methods.
|
type handshakeFunc func(io.ReadWriter, *ecdsa.PrivateKey, *discover.Node) (discover.NodeID, []byte, error)
|
||||||
type NAT interface {
|
type newPeerHook func(*Peer)
|
||||||
GetExternalAddress() (net.IP, error)
|
|
||||||
AddPortMapping(protocol string, extport, intport int, name string, lifetime time.Duration) error
|
|
||||||
DeletePortMapping(protocol string, extport, intport int) error
|
|
||||||
|
|
||||||
// Should return name of the method.
|
|
||||||
String() string
|
|
||||||
}
|
|
||||||
|
|
||||||
type peerFunc func(srv *Server, c net.Conn, dialAddr *peerAddr) *Peer
|
|
||||||
|
|
||||||
// Peers returns all connected peers.
|
// Peers returns all connected peers.
|
||||||
func (srv *Server) Peers() (peers []*Peer) {
|
func (srv *Server) Peers() (peers []*Peer) {
|
||||||
|
@ -107,18 +117,15 @@ func (srv *Server) Peers() (peers []*Peer) {
|
||||||
// PeerCount returns the number of connected peers.
|
// PeerCount returns the number of connected peers.
|
||||||
func (srv *Server) PeerCount() int {
|
func (srv *Server) PeerCount() int {
|
||||||
srv.lock.RLock()
|
srv.lock.RLock()
|
||||||
defer srv.lock.RUnlock()
|
n := len(srv.peers)
|
||||||
return srv.peerCount
|
srv.lock.RUnlock()
|
||||||
|
return n
|
||||||
}
|
}
|
||||||
|
|
||||||
// SuggestPeer injects an address into the outbound address pool.
|
// SuggestPeer creates a connection to the given Node if it
|
||||||
func (srv *Server) SuggestPeer(ip net.IP, port int, nodeID []byte) {
|
// is not already connected.
|
||||||
addr := &peerAddr{ip, uint64(port), nodeID}
|
func (srv *Server) SuggestPeer(n *discover.Node) {
|
||||||
select {
|
srv.peerConnect <- n
|
||||||
case srv.peerConnect <- addr:
|
|
||||||
default: // don't block
|
|
||||||
srvlog.Warnf("peer suggestion %v ignored", addr)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Broadcast sends an RLP-encoded message to all connected peers.
|
// Broadcast sends an RLP-encoded message to all connected peers.
|
||||||
|
@ -152,47 +159,46 @@ func (srv *Server) Start() (err error) {
|
||||||
}
|
}
|
||||||
srvlog.Infoln("Starting Server")
|
srvlog.Infoln("Starting Server")
|
||||||
|
|
||||||
// initialize fields
|
// initialize all the fields
|
||||||
if srv.Identity == nil {
|
if srv.PrivateKey == nil {
|
||||||
return fmt.Errorf("Server.Identity must be set to a non-nil identity")
|
return fmt.Errorf("Server.PrivateKey must be set to a non-nil key")
|
||||||
}
|
}
|
||||||
if srv.MaxPeers <= 0 {
|
if srv.MaxPeers <= 0 {
|
||||||
return fmt.Errorf("Server.MaxPeers must be > 0")
|
return fmt.Errorf("Server.MaxPeers must be > 0")
|
||||||
}
|
}
|
||||||
srv.quit = make(chan struct{})
|
srv.quit = make(chan struct{})
|
||||||
srv.peers = make([]*Peer, srv.MaxPeers)
|
srv.peers = make(map[discover.NodeID]*Peer)
|
||||||
srv.peerSlots = make(chan int, srv.MaxPeers)
|
srv.peerConnect = make(chan *discover.Node)
|
||||||
srv.peerConnect = make(chan *peerAddr, outboundAddressPoolSize)
|
|
||||||
srv.peerDisconnect = make(chan *Peer)
|
if srv.handshakeFunc == nil {
|
||||||
if srv.newPeerFunc == nil {
|
srv.handshakeFunc = encHandshake
|
||||||
srv.newPeerFunc = newServerPeer
|
|
||||||
}
|
}
|
||||||
if srv.Blacklist == nil {
|
if srv.Blacklist == nil {
|
||||||
srv.Blacklist = NewBlacklist()
|
srv.Blacklist = NewBlacklist()
|
||||||
}
|
}
|
||||||
if srv.Dialer == nil {
|
|
||||||
srv.Dialer = &net.Dialer{Timeout: defaultDialTimeout}
|
|
||||||
}
|
|
||||||
|
|
||||||
if srv.ListenAddr != "" {
|
if srv.ListenAddr != "" {
|
||||||
if err := srv.startListening(); err != nil {
|
if err := srv.startListening(); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// dial stuff
|
||||||
|
dt, err := discover.ListenUDP(srv.PrivateKey, srv.ListenAddr, srv.NAT)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
srv.ntab = dt
|
||||||
|
if srv.Dialer == nil {
|
||||||
|
srv.Dialer = &net.Dialer{Timeout: defaultDialTimeout}
|
||||||
|
}
|
||||||
if !srv.NoDial {
|
if !srv.NoDial {
|
||||||
srv.wg.Add(1)
|
srv.loopWG.Add(1)
|
||||||
go srv.dialLoop()
|
go srv.dialLoop()
|
||||||
}
|
}
|
||||||
if srv.NoDial && srv.ListenAddr == "" {
|
if srv.NoDial && srv.ListenAddr == "" {
|
||||||
srvlog.Warnln("I will be kind-of useless, neither dialing nor listening.")
|
srvlog.Warnln("I will be kind-of useless, neither dialing nor listening.")
|
||||||
}
|
}
|
||||||
|
|
||||||
// make all slots available
|
|
||||||
for i := range srv.peers {
|
|
||||||
srv.peerSlots <- i
|
|
||||||
}
|
|
||||||
// note: discLoop is not part of WaitGroup
|
|
||||||
go srv.discLoop()
|
|
||||||
srv.running = true
|
srv.running = true
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -202,14 +208,17 @@ func (srv *Server) startListening() error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
srv.ListenAddr = listener.Addr().String()
|
laddr := listener.Addr().(*net.TCPAddr)
|
||||||
srv.laddr = listener.Addr().(*net.TCPAddr)
|
srv.ListenAddr = laddr.String()
|
||||||
srv.listener = listener
|
srv.listener = listener
|
||||||
srv.wg.Add(1)
|
srv.loopWG.Add(1)
|
||||||
go srv.listenLoop()
|
go srv.listenLoop()
|
||||||
if !srv.laddr.IP.IsLoopback() && srv.NAT != nil {
|
if !laddr.IP.IsLoopback() && srv.NAT != nil {
|
||||||
srv.wg.Add(1)
|
srv.loopWG.Add(1)
|
||||||
go srv.natLoop(srv.laddr.Port)
|
go func() {
|
||||||
|
nat.Map(srv.NAT, srv.quit, "tcp", laddr.Port, laddr.Port, "ethereum p2p")
|
||||||
|
srv.loopWG.Done()
|
||||||
|
}()
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -225,200 +234,171 @@ func (srv *Server) Stop() {
|
||||||
srv.running = false
|
srv.running = false
|
||||||
srv.lock.Unlock()
|
srv.lock.Unlock()
|
||||||
|
|
||||||
srvlog.Infoln("Stopping server")
|
srvlog.Infoln("Stopping Server")
|
||||||
|
srv.ntab.Close()
|
||||||
if srv.listener != nil {
|
if srv.listener != nil {
|
||||||
// this unblocks listener Accept
|
// this unblocks listener Accept
|
||||||
srv.listener.Close()
|
srv.listener.Close()
|
||||||
}
|
}
|
||||||
close(srv.quit)
|
close(srv.quit)
|
||||||
for _, peer := range srv.Peers() {
|
srv.loopWG.Wait()
|
||||||
|
|
||||||
|
// No new peers can be added at this point because dialLoop and
|
||||||
|
// listenLoop are down. It is safe to call peerWG.Wait because
|
||||||
|
// peerWG.Add is not called outside of those loops.
|
||||||
|
for _, peer := range srv.peers {
|
||||||
peer.Disconnect(DiscQuitting)
|
peer.Disconnect(DiscQuitting)
|
||||||
}
|
}
|
||||||
srv.wg.Wait()
|
srv.peerWG.Wait()
|
||||||
|
|
||||||
// wait till they actually disconnect
|
|
||||||
// this is checked by claiming all peerSlots.
|
|
||||||
// slots become available as the peers disconnect.
|
|
||||||
for i := 0; i < cap(srv.peerSlots); i++ {
|
|
||||||
<-srv.peerSlots
|
|
||||||
}
|
|
||||||
// terminate discLoop
|
|
||||||
close(srv.peerDisconnect)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (srv *Server) discLoop() {
|
|
||||||
for peer := range srv.peerDisconnect {
|
|
||||||
srv.removePeer(peer)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// main loop for adding connections via listening
|
// main loop for adding connections via listening
|
||||||
func (srv *Server) listenLoop() {
|
func (srv *Server) listenLoop() {
|
||||||
defer srv.wg.Done()
|
defer srv.loopWG.Done()
|
||||||
|
|
||||||
srvlog.Infoln("Listening on", srv.listener.Addr())
|
srvlog.Infoln("Listening on", srv.listener.Addr())
|
||||||
for {
|
for {
|
||||||
select {
|
conn, err := srv.listener.Accept()
|
||||||
case slot := <-srv.peerSlots:
|
if err != nil {
|
||||||
srvlog.Debugf("grabbed slot %v for listening", slot)
|
|
||||||
conn, err := srv.listener.Accept()
|
|
||||||
if err != nil {
|
|
||||||
srv.peerSlots <- slot
|
|
||||||
return
|
|
||||||
}
|
|
||||||
srvlog.Debugf("Accepted conn %v (slot %d)\n", conn.RemoteAddr(), slot)
|
|
||||||
srv.addPeer(conn, nil, slot)
|
|
||||||
case <-srv.quit:
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
srvlog.Debugf("Accepted conn %v\n", conn.RemoteAddr())
|
||||||
|
srv.peerWG.Add(1)
|
||||||
|
go srv.startPeer(conn, nil)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (srv *Server) natLoop(port int) {
|
|
||||||
defer srv.wg.Done()
|
|
||||||
for {
|
|
||||||
srv.updatePortMapping(port)
|
|
||||||
select {
|
|
||||||
case <-time.After(portMappingUpdateInterval):
|
|
||||||
// one more round
|
|
||||||
case <-srv.quit:
|
|
||||||
srv.removePortMapping(port)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (srv *Server) updatePortMapping(port int) {
|
|
||||||
srvlog.Infoln("Attempting to map port", port, "with", srv.NAT)
|
|
||||||
err := srv.NAT.AddPortMapping("tcp", port, port, "ethereum p2p", portMappingTimeout)
|
|
||||||
if err != nil {
|
|
||||||
srvlog.Errorln("Port mapping error:", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
extip, err := srv.NAT.GetExternalAddress()
|
|
||||||
if err != nil {
|
|
||||||
srvlog.Errorln("Error getting external IP:", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
srv.lock.Lock()
|
|
||||||
extaddr := *(srv.listener.Addr().(*net.TCPAddr))
|
|
||||||
extaddr.IP = extip
|
|
||||||
srvlog.Infoln("Mapped port, external addr is", &extaddr)
|
|
||||||
srv.laddr = &extaddr
|
|
||||||
srv.lock.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (srv *Server) removePortMapping(port int) {
|
|
||||||
srvlog.Infoln("Removing port mapping for", port, "with", srv.NAT)
|
|
||||||
srv.NAT.DeletePortMapping("tcp", port, port)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (srv *Server) dialLoop() {
|
func (srv *Server) dialLoop() {
|
||||||
defer srv.wg.Done()
|
defer srv.loopWG.Done()
|
||||||
var (
|
refresh := time.NewTicker(refreshPeersInterval)
|
||||||
suggest chan *peerAddr
|
defer refresh.Stop()
|
||||||
slot *int
|
|
||||||
slots = srv.peerSlots
|
srv.ntab.Bootstrap(srv.BootstrapNodes)
|
||||||
)
|
go srv.findPeers()
|
||||||
|
|
||||||
|
dialed := make(chan *discover.Node)
|
||||||
|
dialing := make(map[discover.NodeID]bool)
|
||||||
|
|
||||||
|
// TODO: limit number of active dials
|
||||||
|
// TODO: ensure only one findPeers goroutine is running
|
||||||
|
// TODO: pause findPeers when we're at capacity
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case i := <-slots:
|
case <-refresh.C:
|
||||||
// we need a peer in slot i, slot reserved
|
|
||||||
slot = &i
|
|
||||||
// now we can watch for candidate peers in the next loop
|
|
||||||
suggest = srv.peerConnect
|
|
||||||
// do not consume more until candidate peer is found
|
|
||||||
slots = nil
|
|
||||||
|
|
||||||
case desc := <-suggest:
|
go srv.findPeers()
|
||||||
// candidate peer found, will dial out asyncronously
|
|
||||||
// if connection fails slot will be released
|
case dest := <-srv.peerConnect:
|
||||||
srvlog.DebugDetailf("dial %v (%v)", desc, *slot)
|
// avoid dialing nodes that are already connected.
|
||||||
go srv.dialPeer(desc, *slot)
|
// there is another check for this in addPeer,
|
||||||
// we can watch if more peers needed in the next loop
|
// which runs after the handshake.
|
||||||
slots = srv.peerSlots
|
srv.lock.Lock()
|
||||||
// until then we dont care about candidate peers
|
_, isconnected := srv.peers[dest.ID]
|
||||||
suggest = nil
|
srv.lock.Unlock()
|
||||||
|
if isconnected || dialing[dest.ID] || dest.ID == srv.ntab.Self() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
dialing[dest.ID] = true
|
||||||
|
srv.peerWG.Add(1)
|
||||||
|
go func() {
|
||||||
|
srv.dialNode(dest)
|
||||||
|
// at this point, the peer has been added
|
||||||
|
// or discarded. either way, we're not dialing it anymore.
|
||||||
|
dialed <- dest
|
||||||
|
}()
|
||||||
|
|
||||||
|
case dest := <-dialed:
|
||||||
|
delete(dialing, dest.ID)
|
||||||
|
|
||||||
case <-srv.quit:
|
case <-srv.quit:
|
||||||
// give back the currently reserved slot
|
// TODO: maybe wait for active dials
|
||||||
if slot != nil {
|
|
||||||
srv.peerSlots <- *slot
|
|
||||||
}
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// connect to peer via dial out
|
func (srv *Server) dialNode(dest *discover.Node) {
|
||||||
func (srv *Server) dialPeer(desc *peerAddr, slot int) {
|
addr := &net.TCPAddr{IP: dest.IP, Port: dest.TCPPort}
|
||||||
srvlog.Debugf("Dialing %v (slot %d)\n", desc, slot)
|
srvlog.Debugf("Dialing %v\n", dest)
|
||||||
conn, err := srv.Dialer.Dial(desc.Network(), desc.String())
|
conn, err := srv.Dialer.Dial("tcp", addr.String())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
srvlog.DebugDetailf("dial error: %v", err)
|
srvlog.DebugDetailf("dial error: %v", err)
|
||||||
srv.peerSlots <- slot
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
go srv.addPeer(conn, desc, slot)
|
srv.startPeer(conn, dest)
|
||||||
}
|
}
|
||||||
|
|
||||||
// creates the new peer object and inserts it into its slot
|
func (srv *Server) findPeers() {
|
||||||
func (srv *Server) addPeer(conn net.Conn, desc *peerAddr, slot int) *Peer {
|
far := srv.ntab.Self()
|
||||||
srv.lock.Lock()
|
for i := range far {
|
||||||
defer srv.lock.Unlock()
|
far[i] = ^far[i]
|
||||||
if !srv.running {
|
|
||||||
conn.Close()
|
|
||||||
srv.peerSlots <- slot // release slot
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
peer := srv.newPeerFunc(srv, conn, desc)
|
closeToSelf := srv.ntab.Lookup(srv.ntab.Self())
|
||||||
peer.slot = slot
|
farFromSelf := srv.ntab.Lookup(far)
|
||||||
srv.peers[slot] = peer
|
|
||||||
srv.peerCount++
|
|
||||||
go func() {
|
|
||||||
peer.loop()
|
|
||||||
srv.peerDisconnect <- peer
|
|
||||||
}()
|
|
||||||
return peer
|
|
||||||
}
|
|
||||||
|
|
||||||
// removes peer: sending disconnect msg, stop peer, remove rom list/table, release slot
|
for i := 0; i < len(closeToSelf) || i < len(farFromSelf); i++ {
|
||||||
func (srv *Server) removePeer(peer *Peer) {
|
if i < len(closeToSelf) {
|
||||||
srv.lock.Lock()
|
srv.peerConnect <- closeToSelf[i]
|
||||||
defer srv.lock.Unlock()
|
}
|
||||||
srvlog.Debugf("Removing %v (slot %v)\n", peer, peer.slot)
|
if i < len(farFromSelf) {
|
||||||
if srv.peers[peer.slot] != peer {
|
srv.peerConnect <- farFromSelf[i]
|
||||||
srvlog.Warnln("Invalid peer to remove:", peer)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
// remove from list and index
|
|
||||||
srv.peerCount--
|
|
||||||
srv.peers[peer.slot] = nil
|
|
||||||
// release slot to signal need for a new peer, last!
|
|
||||||
srv.peerSlots <- peer.slot
|
|
||||||
}
|
|
||||||
|
|
||||||
func (srv *Server) verifyPeer(addr *peerAddr) error {
|
|
||||||
if srv.Blacklist.Exists(addr.Pubkey) {
|
|
||||||
return errors.New("blacklisted")
|
|
||||||
}
|
|
||||||
if bytes.Equal(srv.Identity.Pubkey()[1:], addr.Pubkey) {
|
|
||||||
return newPeerError(errPubkeyForbidden, "not allowed to connect to srv")
|
|
||||||
}
|
|
||||||
srv.lock.RLock()
|
|
||||||
defer srv.lock.RUnlock()
|
|
||||||
for _, peer := range srv.peers {
|
|
||||||
if peer != nil {
|
|
||||||
id := peer.Identity()
|
|
||||||
if id != nil && bytes.Equal(id.Pubkey(), addr.Pubkey) {
|
|
||||||
return errors.New("already connected")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO replace with "Set"
|
func (srv *Server) startPeer(conn net.Conn, dest *discover.Node) {
|
||||||
|
// TODO: handle/store session token
|
||||||
|
conn.SetDeadline(time.Now().Add(handshakeTimeout))
|
||||||
|
remoteID, _, err := srv.handshakeFunc(conn, srv.PrivateKey, dest)
|
||||||
|
if err != nil {
|
||||||
|
conn.Close()
|
||||||
|
srvlog.Debugf("Encryption Handshake with %v failed: %v", conn.RemoteAddr(), err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ourID := srv.ntab.Self()
|
||||||
|
p := newPeer(conn, srv.Protocols, srv.Name, &ourID, &remoteID)
|
||||||
|
if ok, reason := srv.addPeer(remoteID, p); !ok {
|
||||||
|
srvlog.DebugDetailf("Not adding %v (%v)\n", p, reason)
|
||||||
|
p.politeDisconnect(reason)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
srvlog.Debugf("Added %v\n", p)
|
||||||
|
|
||||||
|
if srv.newPeerHook != nil {
|
||||||
|
srv.newPeerHook(p)
|
||||||
|
}
|
||||||
|
discreason := p.run()
|
||||||
|
srv.removePeer(p)
|
||||||
|
srvlog.Debugf("Removed %v (%v)\n", p, discreason)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (srv *Server) addPeer(id discover.NodeID, p *Peer) (bool, DiscReason) {
|
||||||
|
srv.lock.Lock()
|
||||||
|
defer srv.lock.Unlock()
|
||||||
|
switch {
|
||||||
|
case !srv.running:
|
||||||
|
return false, DiscQuitting
|
||||||
|
case len(srv.peers) >= srv.MaxPeers:
|
||||||
|
return false, DiscTooManyPeers
|
||||||
|
case srv.peers[id] != nil:
|
||||||
|
return false, DiscAlreadyConnected
|
||||||
|
case srv.Blacklist.Exists(id[:]):
|
||||||
|
return false, DiscUselessPeer
|
||||||
|
case id == srv.ntab.Self():
|
||||||
|
return false, DiscSelf
|
||||||
|
}
|
||||||
|
srv.peers[id] = p
|
||||||
|
return true, 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (srv *Server) removePeer(p *Peer) {
|
||||||
|
srv.lock.Lock()
|
||||||
|
delete(srv.peers, *p.remoteID)
|
||||||
|
srv.lock.Unlock()
|
||||||
|
srv.peerWG.Done()
|
||||||
|
}
|
||||||
|
|
||||||
type Blacklist interface {
|
type Blacklist interface {
|
||||||
Get([]byte) (bool, error)
|
Get([]byte) (bool, error)
|
||||||
Put([]byte) error
|
Put([]byte) error
|
||||||
|
|
|
@ -2,19 +2,28 @@ package p2p
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"crypto/ecdsa"
|
||||||
"io"
|
"io"
|
||||||
|
"math/rand"
|
||||||
"net"
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/ethereum/go-ethereum/crypto"
|
||||||
|
"github.com/ethereum/go-ethereum/p2p/discover"
|
||||||
)
|
)
|
||||||
|
|
||||||
func startTestServer(t *testing.T, pf peerFunc) *Server {
|
func startTestServer(t *testing.T, pf newPeerHook) *Server {
|
||||||
server := &Server{
|
server := &Server{
|
||||||
Identity: &peerId{},
|
Name: "test",
|
||||||
MaxPeers: 10,
|
MaxPeers: 10,
|
||||||
ListenAddr: "127.0.0.1:0",
|
ListenAddr: "127.0.0.1:0",
|
||||||
newPeerFunc: pf,
|
PrivateKey: newkey(),
|
||||||
|
newPeerHook: pf,
|
||||||
|
handshakeFunc: func(io.ReadWriter, *ecdsa.PrivateKey, *discover.Node) (id discover.NodeID, st []byte, err error) {
|
||||||
|
return randomID(), nil, err
|
||||||
|
},
|
||||||
}
|
}
|
||||||
if err := server.Start(); err != nil {
|
if err := server.Start(); err != nil {
|
||||||
t.Fatalf("Could not start server: %v", err)
|
t.Fatalf("Could not start server: %v", err)
|
||||||
|
@ -27,16 +36,11 @@ func TestServerListen(t *testing.T) {
|
||||||
|
|
||||||
// start the test server
|
// start the test server
|
||||||
connected := make(chan *Peer)
|
connected := make(chan *Peer)
|
||||||
srv := startTestServer(t, func(srv *Server, conn net.Conn, dialAddr *peerAddr) *Peer {
|
srv := startTestServer(t, func(p *Peer) {
|
||||||
if conn == nil {
|
if p == nil {
|
||||||
t.Error("peer func called with nil conn")
|
t.Error("peer func called with nil conn")
|
||||||
}
|
}
|
||||||
if dialAddr != nil {
|
connected <- p
|
||||||
t.Error("peer func called with non-nil dialAddr")
|
|
||||||
}
|
|
||||||
peer := newPeer(conn, nil, dialAddr)
|
|
||||||
connected <- peer
|
|
||||||
return peer
|
|
||||||
})
|
})
|
||||||
defer close(connected)
|
defer close(connected)
|
||||||
defer srv.Stop()
|
defer srv.Stop()
|
||||||
|
@ -50,9 +54,9 @@ func TestServerListen(t *testing.T) {
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case peer := <-connected:
|
case peer := <-connected:
|
||||||
if peer.conn.LocalAddr().String() != conn.RemoteAddr().String() {
|
if peer.LocalAddr().String() != conn.RemoteAddr().String() {
|
||||||
t.Errorf("peer started with wrong conn: got %v, want %v",
|
t.Errorf("peer started with wrong conn: got %v, want %v",
|
||||||
peer.conn.LocalAddr(), conn.RemoteAddr())
|
peer.LocalAddr(), conn.RemoteAddr())
|
||||||
}
|
}
|
||||||
case <-time.After(1 * time.Second):
|
case <-time.After(1 * time.Second):
|
||||||
t.Error("server did not accept within one second")
|
t.Error("server did not accept within one second")
|
||||||
|
@ -62,7 +66,7 @@ func TestServerListen(t *testing.T) {
|
||||||
func TestServerDial(t *testing.T) {
|
func TestServerDial(t *testing.T) {
|
||||||
defer testlog(t).detach()
|
defer testlog(t).detach()
|
||||||
|
|
||||||
// run a fake TCP server to handle the connection.
|
// run a one-shot TCP server to handle the connection.
|
||||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("could not setup listener: %v")
|
t.Fatalf("could not setup listener: %v")
|
||||||
|
@ -72,41 +76,32 @@ func TestServerDial(t *testing.T) {
|
||||||
go func() {
|
go func() {
|
||||||
conn, err := listener.Accept()
|
conn, err := listener.Accept()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error("acccept error:", err)
|
t.Error("accept error:", err)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
conn.Close()
|
conn.Close()
|
||||||
accepted <- conn
|
accepted <- conn
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// start the test server
|
// start the server
|
||||||
connected := make(chan *Peer)
|
connected := make(chan *Peer)
|
||||||
srv := startTestServer(t, func(srv *Server, conn net.Conn, dialAddr *peerAddr) *Peer {
|
srv := startTestServer(t, func(p *Peer) { connected <- p })
|
||||||
if conn == nil {
|
|
||||||
t.Error("peer func called with nil conn")
|
|
||||||
}
|
|
||||||
peer := newPeer(conn, nil, dialAddr)
|
|
||||||
connected <- peer
|
|
||||||
return peer
|
|
||||||
})
|
|
||||||
defer close(connected)
|
defer close(connected)
|
||||||
defer srv.Stop()
|
defer srv.Stop()
|
||||||
|
|
||||||
// tell the server to connect.
|
// tell the server to connect
|
||||||
connAddr := newPeerAddr(listener.Addr(), nil)
|
tcpAddr := listener.Addr().(*net.TCPAddr)
|
||||||
srv.peerConnect <- connAddr
|
srv.SuggestPeer(&discover.Node{IP: tcpAddr.IP, TCPPort: tcpAddr.Port})
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case conn := <-accepted:
|
case conn := <-accepted:
|
||||||
select {
|
select {
|
||||||
case peer := <-connected:
|
case peer := <-connected:
|
||||||
if peer.conn.RemoteAddr().String() != conn.LocalAddr().String() {
|
if peer.RemoteAddr().String() != conn.LocalAddr().String() {
|
||||||
t.Errorf("peer started with wrong conn: got %v, want %v",
|
t.Errorf("peer started with wrong conn: got %v, want %v",
|
||||||
peer.conn.RemoteAddr(), conn.LocalAddr())
|
peer.RemoteAddr(), conn.LocalAddr())
|
||||||
}
|
|
||||||
if peer.dialAddr != connAddr {
|
|
||||||
t.Errorf("peer started with wrong dialAddr: got %v, want %v",
|
|
||||||
peer.dialAddr, connAddr)
|
|
||||||
}
|
}
|
||||||
|
// TODO: validate more fields
|
||||||
case <-time.After(1 * time.Second):
|
case <-time.After(1 * time.Second):
|
||||||
t.Error("server did not launch peer within one second")
|
t.Error("server did not launch peer within one second")
|
||||||
}
|
}
|
||||||
|
@ -118,16 +113,17 @@ func TestServerDial(t *testing.T) {
|
||||||
|
|
||||||
func TestServerBroadcast(t *testing.T) {
|
func TestServerBroadcast(t *testing.T) {
|
||||||
defer testlog(t).detach()
|
defer testlog(t).detach()
|
||||||
|
|
||||||
var connected sync.WaitGroup
|
var connected sync.WaitGroup
|
||||||
srv := startTestServer(t, func(srv *Server, c net.Conn, dialAddr *peerAddr) *Peer {
|
srv := startTestServer(t, func(p *Peer) {
|
||||||
peer := newPeer(c, []Protocol{discard}, dialAddr)
|
p.protocols = []Protocol{discard}
|
||||||
peer.startSubprotocols([]Cap{discard.cap()})
|
p.startSubprotocols([]Cap{discard.cap()})
|
||||||
|
p.noHandshake = true
|
||||||
connected.Done()
|
connected.Done()
|
||||||
return peer
|
|
||||||
})
|
})
|
||||||
defer srv.Stop()
|
defer srv.Stop()
|
||||||
|
|
||||||
// dial a bunch of conns
|
// create a few peers
|
||||||
var conns = make([]net.Conn, 8)
|
var conns = make([]net.Conn, 8)
|
||||||
connected.Add(len(conns))
|
connected.Add(len(conns))
|
||||||
deadline := time.Now().Add(3 * time.Second)
|
deadline := time.Now().Add(3 * time.Second)
|
||||||
|
@ -159,3 +155,18 @@ func TestServerBroadcast(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func newkey() *ecdsa.PrivateKey {
|
||||||
|
key, err := crypto.GenerateKey()
|
||||||
|
if err != nil {
|
||||||
|
panic("couldn't generate key: " + err.Error())
|
||||||
|
}
|
||||||
|
return key
|
||||||
|
}
|
||||||
|
|
||||||
|
func randomID() (id discover.NodeID) {
|
||||||
|
for i := range id {
|
||||||
|
id[i] = byte(rand.Intn(255))
|
||||||
|
}
|
||||||
|
return id
|
||||||
|
}
|
||||||
|
|
|
@ -15,7 +15,7 @@ func testlog(t *testing.T) testLogger {
|
||||||
return l
|
return l
|
||||||
}
|
}
|
||||||
|
|
||||||
func (testLogger) GetLogLevel() logger.LogLevel { return logger.DebugLevel }
|
func (testLogger) GetLogLevel() logger.LogLevel { return logger.DebugDetailLevel }
|
||||||
func (testLogger) SetLogLevel(logger.LogLevel) {}
|
func (testLogger) SetLogLevel(logger.LogLevel) {}
|
||||||
|
|
||||||
func (l testLogger) LogPrint(level logger.LogLevel, msg string) {
|
func (l testLogger) LogPrint(level logger.LogLevel, msg string) {
|
||||||
|
|
|
@ -1,40 +0,0 @@
|
||||||
// +build none
|
|
||||||
|
|
||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"log"
|
|
||||||
"net"
|
|
||||||
"os"
|
|
||||||
|
|
||||||
"github.com/ethereum/go-ethereum/crypto/secp256k1"
|
|
||||||
"github.com/ethereum/go-ethereum/logger"
|
|
||||||
"github.com/ethereum/go-ethereum/p2p"
|
|
||||||
)
|
|
||||||
|
|
||||||
func main() {
|
|
||||||
logger.AddLogSystem(logger.NewStdLogSystem(os.Stdout, log.LstdFlags, logger.DebugLevel))
|
|
||||||
|
|
||||||
pub, _ := secp256k1.GenerateKeyPair()
|
|
||||||
srv := p2p.Server{
|
|
||||||
MaxPeers: 10,
|
|
||||||
Identity: p2p.NewSimpleClientIdentity("test", "1.0", "", string(pub)),
|
|
||||||
ListenAddr: ":30303",
|
|
||||||
NAT: p2p.PMP(net.ParseIP("10.0.0.1")),
|
|
||||||
}
|
|
||||||
if err := srv.Start(); err != nil {
|
|
||||||
fmt.Println("could not start server:", err)
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
|
||||||
|
|
||||||
// add seed peers
|
|
||||||
seed, err := net.ResolveTCPAddr("tcp", "poc-7.ethdev.com:30303")
|
|
||||||
if err != nil {
|
|
||||||
fmt.Println("couldn't resolve:", err)
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
|
||||||
srv.SuggestPeer(seed.IP, seed.Port, nil)
|
|
||||||
|
|
||||||
select {}
|
|
||||||
}
|
|
|
@ -350,8 +350,10 @@ func makeWriter(typ reflect.Type) (writer, error) {
|
||||||
return writeUint, nil
|
return writeUint, nil
|
||||||
case kind == reflect.String:
|
case kind == reflect.String:
|
||||||
return writeString, nil
|
return writeString, nil
|
||||||
case kind == reflect.Slice && typ.Elem().Kind() == reflect.Uint8 && !typ.Elem().Implements(encoderInterface):
|
case kind == reflect.Slice && isByte(typ.Elem()):
|
||||||
return writeBytes, nil
|
return writeBytes, nil
|
||||||
|
case kind == reflect.Array && isByte(typ.Elem()):
|
||||||
|
return writeByteArray, nil
|
||||||
case kind == reflect.Slice || kind == reflect.Array:
|
case kind == reflect.Slice || kind == reflect.Array:
|
||||||
return makeSliceWriter(typ)
|
return makeSliceWriter(typ)
|
||||||
case kind == reflect.Struct:
|
case kind == reflect.Struct:
|
||||||
|
@ -363,6 +365,10 @@ func makeWriter(typ reflect.Type) (writer, error) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func isByte(typ reflect.Type) bool {
|
||||||
|
return typ.Kind() == reflect.Uint8 && !typ.Implements(encoderInterface)
|
||||||
|
}
|
||||||
|
|
||||||
func writeUint(val reflect.Value, w *encbuf) error {
|
func writeUint(val reflect.Value, w *encbuf) error {
|
||||||
i := val.Uint()
|
i := val.Uint()
|
||||||
if i == 0 {
|
if i == 0 {
|
||||||
|
@ -407,6 +413,20 @@ func writeBytes(val reflect.Value, w *encbuf) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func writeByteArray(val reflect.Value, w *encbuf) error {
|
||||||
|
if !val.CanAddr() {
|
||||||
|
// Slice requires the value to be addressable.
|
||||||
|
// Make it addressable by copying.
|
||||||
|
copy := reflect.New(val.Type()).Elem()
|
||||||
|
copy.Set(val)
|
||||||
|
val = copy
|
||||||
|
}
|
||||||
|
size := val.Len()
|
||||||
|
slice := val.Slice(0, size).Bytes()
|
||||||
|
w.encodeString(slice)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func writeString(val reflect.Value, w *encbuf) error {
|
func writeString(val reflect.Value, w *encbuf) error {
|
||||||
s := val.String()
|
s := val.String()
|
||||||
w.encodeStringHeader(len(s))
|
w.encodeStringHeader(len(s))
|
||||||
|
|
|
@ -40,6 +40,8 @@ func (e *encodableReader) Read(b []byte) (int, error) {
|
||||||
panic("called")
|
panic("called")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type namedByteType byte
|
||||||
|
|
||||||
var (
|
var (
|
||||||
_ = Encoder(&testEncoder{})
|
_ = Encoder(&testEncoder{})
|
||||||
_ = Encoder(byteEncoder(0))
|
_ = Encoder(byteEncoder(0))
|
||||||
|
@ -102,6 +104,10 @@ var encTests = []encTest{
|
||||||
// byte slices, strings
|
// byte slices, strings
|
||||||
{val: []byte{}, output: "80"},
|
{val: []byte{}, output: "80"},
|
||||||
{val: []byte{1, 2, 3}, output: "83010203"},
|
{val: []byte{1, 2, 3}, output: "83010203"},
|
||||||
|
|
||||||
|
{val: []namedByteType{1, 2, 3}, output: "83010203"},
|
||||||
|
{val: [...]namedByteType{1, 2, 3}, output: "83010203"},
|
||||||
|
|
||||||
{val: "", output: "80"},
|
{val: "", output: "80"},
|
||||||
{val: "dog", output: "83646F67"},
|
{val: "dog", output: "83646F67"},
|
||||||
{
|
{
|
||||||
|
|
|
@ -215,7 +215,7 @@ func NewPeer(peer *p2p.Peer) *Peer {
|
||||||
return &Peer{
|
return &Peer{
|
||||||
ref: peer,
|
ref: peer,
|
||||||
Ip: fmt.Sprintf("%v", peer.RemoteAddr()),
|
Ip: fmt.Sprintf("%v", peer.RemoteAddr()),
|
||||||
Version: fmt.Sprintf("%v", peer.Identity()),
|
Version: fmt.Sprintf("%v", peer.ID()),
|
||||||
Caps: fmt.Sprintf("%v", caps),
|
Caps: fmt.Sprintf("%v", caps),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -31,7 +31,6 @@ type Backend interface {
|
||||||
IsListening() bool
|
IsListening() bool
|
||||||
Peers() []*p2p.Peer
|
Peers() []*p2p.Peer
|
||||||
KeyManager() *crypto.KeyManager
|
KeyManager() *crypto.KeyManager
|
||||||
ClientIdentity() p2p.ClientIdentity
|
|
||||||
Db() ethutil.Database
|
Db() ethutil.Database
|
||||||
EventMux() *event.TypeMux
|
EventMux() *event.TypeMux
|
||||||
Whisper() *whisper.Whisper
|
Whisper() *whisper.Whisper
|
||||||
|
|
Loading…
Reference in New Issue