diff --git a/p2p/client_identity.go b/p2p/client_identity.go
new file mode 100644
index 0000000000..236b23106f
--- /dev/null
+++ b/p2p/client_identity.go
@@ -0,0 +1,63 @@
+package p2p
+
+import (
+ "fmt"
+ "runtime"
+)
+
+// should be used in Peer handleHandshake, incorporate Caps, ProtocolVersion, Pubkey etc.
+type ClientIdentity interface {
+ String() string
+ Pubkey() []byte
+}
+
+type SimpleClientIdentity struct {
+ clientIdentifier string
+ version string
+ customIdentifier string
+ os string
+ implementation string
+ pubkey string
+}
+
+func NewSimpleClientIdentity(clientIdentifier string, version string, customIdentifier string, pubkey string) *SimpleClientIdentity {
+ clientIdentity := &SimpleClientIdentity{
+ clientIdentifier: clientIdentifier,
+ version: version,
+ customIdentifier: customIdentifier,
+ os: runtime.GOOS,
+ implementation: runtime.Version(),
+ pubkey: pubkey,
+ }
+
+ return clientIdentity
+}
+
+func (c *SimpleClientIdentity) init() {
+}
+
+func (c *SimpleClientIdentity) String() string {
+ var id string
+ if len(c.customIdentifier) > 0 {
+ id = "/" + c.customIdentifier
+ }
+
+ return fmt.Sprintf("%s/v%s%s/%s/%s",
+ c.clientIdentifier,
+ c.version,
+ id,
+ c.os,
+ c.implementation)
+}
+
+func (c *SimpleClientIdentity) Pubkey() []byte {
+ return []byte(c.pubkey)
+}
+
+func (c *SimpleClientIdentity) SetCustomIdentifier(customIdentifier string) {
+ c.customIdentifier = customIdentifier
+}
+
+func (c *SimpleClientIdentity) GetCustomIdentifier() string {
+ return c.customIdentifier
+}
diff --git a/p2p/client_identity_test.go b/p2p/client_identity_test.go
new file mode 100644
index 0000000000..40b0e6f5e1
--- /dev/null
+++ b/p2p/client_identity_test.go
@@ -0,0 +1,30 @@
+package p2p
+
+import (
+ "fmt"
+ "runtime"
+ "testing"
+)
+
+func TestClientIdentity(t *testing.T) {
+ clientIdentity := NewSimpleClientIdentity("Ethereum(G)", "0.5.16", "test", "pubkey")
+ clientString := clientIdentity.String()
+ expected := fmt.Sprintf("Ethereum(G)/v0.5.16/test/%s/%s", runtime.GOOS, runtime.Version())
+ if clientString != expected {
+ t.Errorf("Expected clientIdentity to be %v, got %v", expected, clientString)
+ }
+ customIdentifier := clientIdentity.GetCustomIdentifier()
+ if customIdentifier != "test" {
+ t.Errorf("Expected clientIdentity.GetCustomIdentifier() to be 'test', got %v", customIdentifier)
+ }
+ clientIdentity.SetCustomIdentifier("test2")
+ customIdentifier = clientIdentity.GetCustomIdentifier()
+ if customIdentifier != "test2" {
+ t.Errorf("Expected clientIdentity.GetCustomIdentifier() to be 'test2', got %v", customIdentifier)
+ }
+ clientString = clientIdentity.String()
+ expected = fmt.Sprintf("Ethereum(G)/v0.5.16/test2/%s/%s", runtime.GOOS, runtime.Version())
+ if clientString != expected {
+ t.Errorf("Expected clientIdentity to be %v, got %v", expected, clientString)
+ }
+}
diff --git a/p2p/connection.go b/p2p/connection.go
new file mode 100644
index 0000000000..e999cbe55f
--- /dev/null
+++ b/p2p/connection.go
@@ -0,0 +1,275 @@
+package p2p
+
+import (
+ "bytes"
+ // "fmt"
+ "net"
+ "time"
+
+ "github.com/ethereum/eth-go/ethutil"
+)
+
+type Connection struct {
+ conn net.Conn
+ // conn NetworkConnection
+ timeout time.Duration
+ in chan []byte
+ out chan []byte
+ err chan *PeerError
+ closingIn chan chan bool
+ closingOut chan chan bool
+}
+
+// const readBufferLength = 2 //for testing
+
+const readBufferLength = 1440
+const partialsQueueSize = 10
+const maxPendingQueueSize = 1
+const defaultTimeout = 500
+
+var magicToken = []byte{34, 64, 8, 145}
+
+func (self *Connection) Open() {
+ go self.startRead()
+ go self.startWrite()
+}
+
+func (self *Connection) Close() {
+ self.closeIn()
+ self.closeOut()
+}
+
+func (self *Connection) closeIn() {
+ errc := make(chan bool)
+ self.closingIn <- errc
+ <-errc
+}
+
+func (self *Connection) closeOut() {
+ errc := make(chan bool)
+ self.closingOut <- errc
+ <-errc
+}
+
+func NewConnection(conn net.Conn, errchan chan *PeerError) *Connection {
+ return &Connection{
+ conn: conn,
+ timeout: defaultTimeout,
+ in: make(chan []byte),
+ out: make(chan []byte),
+ err: errchan,
+ closingIn: make(chan chan bool, 1),
+ closingOut: make(chan chan bool, 1),
+ }
+}
+
+func (self *Connection) Read() <-chan []byte {
+ return self.in
+}
+
+func (self *Connection) Write() chan<- []byte {
+ return self.out
+}
+
+func (self *Connection) Error() <-chan *PeerError {
+ return self.err
+}
+
+func (self *Connection) startRead() {
+ payloads := make(chan []byte)
+ done := make(chan *PeerError)
+ pending := [][]byte{}
+ var head []byte
+ var wait time.Duration // initally 0 (no delay)
+ read := time.After(wait * time.Millisecond)
+
+ for {
+ // if pending empty, nil channel blocks
+ var in chan []byte
+ if len(pending) > 0 {
+ in = self.in // enable send case
+ head = pending[0]
+ } else {
+ in = nil
+ }
+
+ select {
+ case <-read:
+ go self.read(payloads, done)
+ case err := <-done:
+ if err == nil { // no error but nothing to read
+ if len(pending) < maxPendingQueueSize {
+ wait = 100
+ } else if wait == 0 {
+ wait = 100
+ } else {
+ wait = 2 * wait
+ }
+ } else {
+ self.err <- err // report error
+ wait = 100
+ }
+ read = time.After(wait * time.Millisecond)
+ case payload := <-payloads:
+ pending = append(pending, payload)
+ if len(pending) < maxPendingQueueSize {
+ wait = 0
+ } else {
+ wait = 100
+ }
+ read = time.After(wait * time.Millisecond)
+ case in <- head:
+ pending = pending[1:]
+ case errc := <-self.closingIn:
+ errc <- true
+ close(self.in)
+ return
+ }
+
+ }
+}
+
+func (self *Connection) startWrite() {
+ pending := [][]byte{}
+ done := make(chan *PeerError)
+ writing := false
+ for {
+ if len(pending) > 0 && !writing {
+ writing = true
+ go self.write(pending[0], done)
+ }
+ select {
+ case payload := <-self.out:
+ pending = append(pending, payload)
+ case err := <-done:
+ if err == nil {
+ pending = pending[1:]
+ writing = false
+ } else {
+ self.err <- err // report error
+ }
+ case errc := <-self.closingOut:
+ errc <- true
+ close(self.out)
+ return
+ }
+ }
+}
+
+func pack(payload []byte) (packet []byte) {
+ length := ethutil.NumberToBytes(uint32(len(payload)), 32)
+ // return error if too long?
+ // Write magic token and payload length (first 8 bytes)
+ packet = append(magicToken, length...)
+ packet = append(packet, payload...)
+ return
+}
+
+func avoidPanic(done chan *PeerError) {
+ if rec := recover(); rec != nil {
+ err := NewPeerError(MiscError, " %v", rec)
+ logger.Debugln(err)
+ done <- err
+ }
+}
+
+func (self *Connection) write(payload []byte, done chan *PeerError) {
+ defer avoidPanic(done)
+ var err *PeerError
+ _, ok := self.conn.Write(pack(payload))
+ if ok != nil {
+ err = NewPeerError(WriteError, " %v", ok)
+ logger.Debugln(err)
+ }
+ done <- err
+}
+
+func (self *Connection) read(payloads chan []byte, done chan *PeerError) {
+ //defer avoidPanic(done)
+
+ partials := make(chan []byte, partialsQueueSize)
+ errc := make(chan *PeerError)
+ go self.readPartials(partials, errc)
+
+ packet := []byte{}
+ length := 8
+ start := true
+ var err *PeerError
+out:
+ for {
+ // appends partials read via connection until packet is
+ // - either parseable (>=8bytes)
+ // - or complete (payload fully consumed)
+ for len(packet) < length {
+ partial, ok := <-partials
+ if !ok { // partials channel is closed
+ err = <-errc
+ if err == nil && len(packet) > 0 {
+ if start {
+ err = NewPeerError(PacketTooShort, "%v", packet)
+ } else {
+ err = NewPeerError(PayloadTooShort, "%d < %d", len(packet), length)
+ }
+ }
+ break out
+ }
+ packet = append(packet, partial...)
+ }
+ if start {
+ // at least 8 bytes read, can validate packet
+ if bytes.Compare(magicToken, packet[:4]) != 0 {
+ err = NewPeerError(MagicTokenMismatch, " received %v", packet[:4])
+ break
+ }
+ length = int(ethutil.BytesToNumber(packet[4:8]))
+ packet = packet[8:]
+
+ if length > 0 {
+ start = false // now consuming payload
+ } else { //penalize peer but read on
+ self.err <- NewPeerError(EmptyPayload, "")
+ length = 8
+ }
+ } else {
+ // packet complete (payload fully consumed)
+ payloads <- packet[:length]
+ packet = packet[length:] // resclice packet
+ start = true
+ length = 8
+ }
+ }
+
+ // this stops partials read via the connection, should we?
+ //if err != nil {
+ // select {
+ // case errc <- err
+ // default:
+ //}
+ done <- err
+}
+
+func (self *Connection) readPartials(partials chan []byte, errc chan *PeerError) {
+ defer close(partials)
+ for {
+ // Give buffering some time
+ self.conn.SetReadDeadline(time.Now().Add(self.timeout * time.Millisecond))
+ buffer := make([]byte, readBufferLength)
+ // read partial from connection
+ bytesRead, err := self.conn.Read(buffer)
+ if err == nil || err.Error() == "EOF" {
+ if bytesRead > 0 {
+ partials <- buffer[:bytesRead]
+ }
+ if err != nil && err.Error() == "EOF" {
+ break
+ }
+ } else {
+ // unexpected error, report to errc
+ err := NewPeerError(ReadError, " %v", err)
+ logger.Debugln(err)
+ errc <- err
+ return // will close partials channel
+ }
+ }
+ close(errc)
+}
diff --git a/p2p/connection_test.go b/p2p/connection_test.go
new file mode 100644
index 0000000000..76ee8021c8
--- /dev/null
+++ b/p2p/connection_test.go
@@ -0,0 +1,222 @@
+package p2p
+
+import (
+ "bytes"
+ "fmt"
+ "io"
+ "net"
+ "testing"
+ "time"
+)
+
+type TestNetworkConnection struct {
+ in chan []byte
+ current []byte
+ Out [][]byte
+ addr net.Addr
+}
+
+func NewTestNetworkConnection(addr net.Addr) *TestNetworkConnection {
+ return &TestNetworkConnection{
+ in: make(chan []byte),
+ current: []byte{},
+ Out: [][]byte{},
+ addr: addr,
+ }
+}
+
+func (self *TestNetworkConnection) In(latency time.Duration, packets ...[]byte) {
+ time.Sleep(latency)
+ for _, s := range packets {
+ self.in <- s
+ }
+}
+
+func (self *TestNetworkConnection) Read(buff []byte) (n int, err error) {
+ if len(self.current) == 0 {
+ select {
+ case self.current = <-self.in:
+ default:
+ return 0, io.EOF
+ }
+ }
+ length := len(self.current)
+ if length > len(buff) {
+ copy(buff[:], self.current[:len(buff)])
+ self.current = self.current[len(buff):]
+ return len(buff), nil
+ } else {
+ copy(buff[:length], self.current[:])
+ self.current = []byte{}
+ return length, io.EOF
+ }
+}
+
+func (self *TestNetworkConnection) Write(buff []byte) (n int, err error) {
+ self.Out = append(self.Out, buff)
+ fmt.Printf("net write %v\n%v\n", len(self.Out), buff)
+ return len(buff), nil
+}
+
+func (self *TestNetworkConnection) Close() (err error) {
+ return
+}
+
+func (self *TestNetworkConnection) LocalAddr() (addr net.Addr) {
+ return
+}
+
+func (self *TestNetworkConnection) RemoteAddr() (addr net.Addr) {
+ return self.addr
+}
+
+func (self *TestNetworkConnection) SetDeadline(t time.Time) (err error) {
+ return
+}
+
+func (self *TestNetworkConnection) SetReadDeadline(t time.Time) (err error) {
+ return
+}
+
+func (self *TestNetworkConnection) SetWriteDeadline(t time.Time) (err error) {
+ return
+}
+
+func setupConnection() (*Connection, *TestNetworkConnection) {
+ addr := &TestAddr{"test:30303"}
+ net := NewTestNetworkConnection(addr)
+ conn := NewConnection(net, NewPeerErrorChannel())
+ conn.Open()
+ return conn, net
+}
+
+func TestReadingNilPacket(t *testing.T) {
+ conn, net := setupConnection()
+ go net.In(0, []byte{})
+ // time.Sleep(10 * time.Millisecond)
+ select {
+ case packet := <-conn.Read():
+ t.Errorf("read %v", packet)
+ case err := <-conn.Error():
+ t.Errorf("incorrect error %v", err)
+ default:
+ }
+ conn.Close()
+}
+
+func TestReadingShortPacket(t *testing.T) {
+ conn, net := setupConnection()
+ go net.In(0, []byte{0})
+ select {
+ case packet := <-conn.Read():
+ t.Errorf("read %v", packet)
+ case err := <-conn.Error():
+ if err.Code != PacketTooShort {
+ t.Errorf("incorrect error %v, expected %v", err.Code, PacketTooShort)
+ }
+ }
+ conn.Close()
+}
+
+func TestReadingInvalidPacket(t *testing.T) {
+ conn, net := setupConnection()
+ go net.In(0, []byte{1, 0, 0, 0, 0, 0, 0, 0})
+ select {
+ case packet := <-conn.Read():
+ t.Errorf("read %v", packet)
+ case err := <-conn.Error():
+ if err.Code != MagicTokenMismatch {
+ t.Errorf("incorrect error %v, expected %v", err.Code, MagicTokenMismatch)
+ }
+ }
+ conn.Close()
+}
+
+func TestReadingInvalidPayload(t *testing.T) {
+ conn, net := setupConnection()
+ go net.In(0, []byte{34, 64, 8, 145, 0, 0, 0, 2, 0})
+ select {
+ case packet := <-conn.Read():
+ t.Errorf("read %v", packet)
+ case err := <-conn.Error():
+ if err.Code != PayloadTooShort {
+ t.Errorf("incorrect error %v, expected %v", err.Code, PayloadTooShort)
+ }
+ }
+ conn.Close()
+}
+
+func TestReadingEmptyPayload(t *testing.T) {
+ conn, net := setupConnection()
+ go net.In(0, []byte{34, 64, 8, 145, 0, 0, 0, 0})
+ time.Sleep(10 * time.Millisecond)
+ select {
+ case packet := <-conn.Read():
+ t.Errorf("read %v", packet)
+ default:
+ }
+ select {
+ case err := <-conn.Error():
+ code := err.Code
+ if code != EmptyPayload {
+ t.Errorf("incorrect error, expected EmptyPayload, got %v", code)
+ }
+ default:
+ t.Errorf("no error, expected EmptyPayload")
+ }
+ conn.Close()
+}
+
+func TestReadingCompletePacket(t *testing.T) {
+ conn, net := setupConnection()
+ go net.In(0, []byte{34, 64, 8, 145, 0, 0, 0, 1, 1})
+ time.Sleep(10 * time.Millisecond)
+ select {
+ case packet := <-conn.Read():
+ if bytes.Compare(packet, []byte{1}) != 0 {
+ t.Errorf("incorrect payload read")
+ }
+ case err := <-conn.Error():
+ t.Errorf("incorrect error %v", err)
+ default:
+ t.Errorf("nothing read")
+ }
+ conn.Close()
+}
+
+func TestReadingTwoCompletePackets(t *testing.T) {
+ conn, net := setupConnection()
+ go net.In(0, []byte{34, 64, 8, 145, 0, 0, 0, 1, 0, 34, 64, 8, 145, 0, 0, 0, 1, 1})
+
+ for i := 0; i < 2; i++ {
+ time.Sleep(10 * time.Millisecond)
+ select {
+ case packet := <-conn.Read():
+ if bytes.Compare(packet, []byte{byte(i)}) != 0 {
+ t.Errorf("incorrect payload read")
+ }
+ case err := <-conn.Error():
+ t.Errorf("incorrect error %v", err)
+ default:
+ t.Errorf("nothing read")
+ }
+ }
+ conn.Close()
+}
+
+func TestWriting(t *testing.T) {
+ conn, net := setupConnection()
+ conn.Write() <- []byte{0}
+ time.Sleep(10 * time.Millisecond)
+ if len(net.Out) == 0 {
+ t.Errorf("no output")
+ } else {
+ out := net.Out[0]
+ if bytes.Compare(out, []byte{34, 64, 8, 145, 0, 0, 0, 1, 0}) != 0 {
+ t.Errorf("incorrect packet %v", out)
+ }
+ }
+ conn.Close()
+}
+
+// hello packet with client id ABC: 0x22 40 08 91 00 00 00 08 84 00 00 00 43414243
diff --git a/p2p/message.go b/p2p/message.go
new file mode 100644
index 0000000000..4886eaa1f2
--- /dev/null
+++ b/p2p/message.go
@@ -0,0 +1,75 @@
+package p2p
+
+import (
+ // "fmt"
+ "github.com/ethereum/eth-go/ethutil"
+)
+
+type MsgCode uint8
+
+type Msg struct {
+ code MsgCode // this is the raw code as per adaptive msg code scheme
+ data *ethutil.Value
+ encoded []byte
+}
+
+func (self *Msg) Code() MsgCode {
+ return self.code
+}
+
+func (self *Msg) Data() *ethutil.Value {
+ return self.data
+}
+
+func NewMsg(code MsgCode, params ...interface{}) (msg *Msg, err error) {
+
+ // // data := [][]interface{}{}
+ // data := []interface{}{}
+ // for _, value := range params {
+ // if encodable, ok := value.(ethutil.RlpEncodeDecode); ok {
+ // data = append(data, encodable.RlpValue())
+ // } else if raw, ok := value.([]interface{}); ok {
+ // data = append(data, raw)
+ // } else {
+ // // data = append(data, interface{}(raw))
+ // err = fmt.Errorf("Unable to encode object of type %T", value)
+ // return
+ // }
+ // }
+ return &Msg{
+ code: code,
+ data: ethutil.NewValue(interface{}(params)),
+ }, nil
+}
+
+func NewMsgFromBytes(encoded []byte) (msg *Msg, err error) {
+ value := ethutil.NewValueFromBytes(encoded)
+ // Type of message
+ code := value.Get(0).Uint()
+ // Actual data
+ data := value.SliceFrom(1)
+
+ msg = &Msg{
+ code: MsgCode(code),
+ data: data,
+ // data: ethutil.NewValue(data),
+ encoded: encoded,
+ }
+ return
+}
+
+func (self *Msg) Decode(offset MsgCode) {
+ self.code = self.code - offset
+}
+
+// encode takes an offset argument to implement adaptive message coding
+// the encoded message is memoized to make msgs relayed to several peers more efficient
+func (self *Msg) Encode(offset MsgCode) (res []byte) {
+ if len(self.encoded) == 0 {
+ res = ethutil.NewValue(append([]interface{}{byte(self.code + offset)}, self.data.Slice()...)).Encode()
+ self.encoded = res
+ } else {
+ res = self.encoded
+ }
+ return
+}
diff --git a/p2p/message_test.go b/p2p/message_test.go
new file mode 100644
index 0000000000..e9d46f2c3a
--- /dev/null
+++ b/p2p/message_test.go
@@ -0,0 +1,38 @@
+package p2p
+
+import (
+ "testing"
+)
+
+func TestNewMsg(t *testing.T) {
+ msg, _ := NewMsg(3, 1, "000")
+ if msg.Code() != 3 {
+ t.Errorf("incorrect code %v", msg.Code())
+ }
+ data0 := msg.Data().Get(0).Uint()
+ data1 := string(msg.Data().Get(1).Bytes())
+ if data0 != 1 {
+ t.Errorf("incorrect data %v", data0)
+ }
+ if data1 != "000" {
+ t.Errorf("incorrect data %v", data1)
+ }
+}
+
+func TestEncodeDecodeMsg(t *testing.T) {
+ msg, _ := NewMsg(3, 1, "000")
+ encoded := msg.Encode(3)
+ msg, _ = NewMsgFromBytes(encoded)
+ msg.Decode(3)
+ if msg.Code() != 3 {
+ t.Errorf("incorrect code %v", msg.Code())
+ }
+ data0 := msg.Data().Get(0).Uint()
+ data1 := msg.Data().Get(1).Str()
+ if data0 != 1 {
+ t.Errorf("incorrect data %v", data0)
+ }
+ if data1 != "000" {
+ t.Errorf("incorrect data %v", data1)
+ }
+}
diff --git a/p2p/messenger.go b/p2p/messenger.go
new file mode 100644
index 0000000000..d42ba1720e
--- /dev/null
+++ b/p2p/messenger.go
@@ -0,0 +1,220 @@
+package p2p
+
+import (
+ "fmt"
+ "sync"
+ "time"
+)
+
+const (
+ handlerTimeout = 1000
+)
+
+type Handlers map[string](func(p *Peer) Protocol)
+
+type Messenger struct {
+ conn *Connection
+ peer *Peer
+ handlers Handlers
+ protocolLock sync.RWMutex
+ protocols []Protocol
+ offsets []MsgCode // offsets for adaptive message idss
+ protocolTable map[string]int
+ quit chan chan bool
+ err chan *PeerError
+ pulse chan bool
+}
+
+func NewMessenger(peer *Peer, conn *Connection, errchan chan *PeerError, handlers Handlers) *Messenger {
+ baseProtocol := NewBaseProtocol(peer)
+ return &Messenger{
+ conn: conn,
+ peer: peer,
+ offsets: []MsgCode{baseProtocol.Offset()},
+ handlers: handlers,
+ protocols: []Protocol{baseProtocol},
+ protocolTable: make(map[string]int),
+ err: errchan,
+ pulse: make(chan bool, 1),
+ quit: make(chan chan bool, 1),
+ }
+}
+
+func (self *Messenger) Start() {
+ self.conn.Open()
+ go self.messenger()
+ self.protocolLock.RLock()
+ defer self.protocolLock.RUnlock()
+ self.protocols[0].Start()
+}
+
+func (self *Messenger) Stop() {
+ // close pulse to stop ping pong monitoring
+ close(self.pulse)
+ self.protocolLock.RLock()
+ defer self.protocolLock.RUnlock()
+ for _, protocol := range self.protocols {
+ protocol.Stop() // could be parallel
+ }
+ q := make(chan bool)
+ self.quit <- q
+ <-q
+ self.conn.Close()
+}
+
+func (self *Messenger) messenger() {
+ in := self.conn.Read()
+ for {
+ select {
+ case payload, ok := <-in:
+ //dispatches message to the protocol asynchronously
+ if ok {
+ go self.handle(payload)
+ } else {
+ return
+ }
+ case q := <-self.quit:
+ q <- true
+ return
+ }
+ }
+}
+
+// handles each message by dispatching to the appropriate protocol
+// using adaptive message codes
+// this function is started as a separate go routine for each message
+// it waits for the protocol response
+// then encodes and sends outgoing messages to the connection's write channel
+func (self *Messenger) handle(payload []byte) {
+ // send ping to heartbeat channel signalling time of last message
+ // select {
+ // case self.pulse <- true:
+ // default:
+ // }
+ self.pulse <- true
+ // initialise message from payload
+ msg, err := NewMsgFromBytes(payload)
+ if err != nil {
+ self.err <- NewPeerError(MiscError, " %v", err)
+ return
+ }
+ // retrieves protocol based on message Code
+ protocol, offset, peerErr := self.getProtocol(msg.Code())
+ if err != nil {
+ self.err <- peerErr
+ return
+ }
+ // reset message code based on adaptive offset
+ msg.Decode(offset)
+ // dispatches
+ response := make(chan *Msg)
+ go protocol.HandleIn(msg, response)
+ // protocol reponse timeout to prevent leaks
+ timer := time.After(handlerTimeout * time.Millisecond)
+ for {
+ select {
+ case outgoing, ok := <-response:
+ // we check if response channel is not closed
+ if ok {
+ self.conn.Write() <- outgoing.Encode(offset)
+ } else {
+ return
+ }
+ case <-timer:
+ return
+ }
+ }
+}
+
+// negotiated protocols
+// stores offsets needed for adaptive message id scheme
+
+// based on offsets set at handshake
+// get the right protocol to handle the message
+func (self *Messenger) getProtocol(code MsgCode) (Protocol, MsgCode, *PeerError) {
+ self.protocolLock.RLock()
+ defer self.protocolLock.RUnlock()
+ base := MsgCode(0)
+ for index, offset := range self.offsets {
+ if code < offset {
+ return self.protocols[index], base, nil
+ }
+ base = offset
+ }
+ return nil, MsgCode(0), NewPeerError(InvalidMsgCode, " %v", code)
+}
+
+func (self *Messenger) PingPong(timeout time.Duration, gracePeriod time.Duration, pingCallback func(), timeoutCallback func()) {
+ fmt.Printf("pingpong keepalive started at %v", time.Now())
+
+ timer := time.After(timeout)
+ pinged := false
+ for {
+ select {
+ case _, ok := <-self.pulse:
+ if ok {
+ pinged = false
+ timer = time.After(timeout)
+ } else {
+ // pulse is closed, stop monitoring
+ return
+ }
+ case <-timer:
+ if pinged {
+ fmt.Printf("timeout at %v", time.Now())
+ timeoutCallback()
+ return
+ } else {
+ fmt.Printf("pinged at %v", time.Now())
+ pingCallback()
+ timer = time.After(gracePeriod)
+ pinged = true
+ }
+ }
+ }
+}
+
+func (self *Messenger) AddProtocols(protocols []string) {
+ self.protocolLock.Lock()
+ defer self.protocolLock.Unlock()
+ i := len(self.offsets)
+ offset := self.offsets[i-1]
+ for _, name := range protocols {
+ protocolFunc, ok := self.handlers[name]
+ if ok {
+ protocol := protocolFunc(self.peer)
+ self.protocolTable[name] = i
+ i++
+ offset += protocol.Offset()
+ fmt.Println("offset ", name, offset)
+
+ self.offsets = append(self.offsets, offset)
+ self.protocols = append(self.protocols, protocol)
+ protocol.Start()
+ } else {
+ fmt.Println("no ", name)
+ // protocol not handled
+ }
+ }
+}
+
+func (self *Messenger) Write(protocol string, msg *Msg) error {
+ self.protocolLock.RLock()
+ defer self.protocolLock.RUnlock()
+ i := 0
+ offset := MsgCode(0)
+ if len(protocol) > 0 {
+ var ok bool
+ i, ok = self.protocolTable[protocol]
+ if !ok {
+ return fmt.Errorf("protocol %v not handled by peer", protocol)
+ }
+ offset = self.offsets[i-1]
+ }
+ handler := self.protocols[i]
+ // checking if protocol status/caps allows the message to be sent out
+ if handler.HandleOut(msg) {
+ self.conn.Write() <- msg.Encode(offset)
+ }
+ return nil
+}
diff --git a/p2p/messenger_test.go b/p2p/messenger_test.go
new file mode 100644
index 0000000000..bc21d34ba6
--- /dev/null
+++ b/p2p/messenger_test.go
@@ -0,0 +1,146 @@
+package p2p
+
+import (
+ // "fmt"
+ "bytes"
+ "github.com/ethereum/eth-go/ethutil"
+ "testing"
+ "time"
+)
+
+func setupMessenger(handlers Handlers) (*TestNetworkConnection, chan *PeerError, *Messenger) {
+ errchan := NewPeerErrorChannel()
+ addr := &TestAddr{"test:30303"}
+ net := NewTestNetworkConnection(addr)
+ conn := NewConnection(net, errchan)
+ mess := NewMessenger(nil, conn, errchan, handlers)
+ mess.Start()
+ return net, errchan, mess
+}
+
+type TestProtocol struct {
+ Msgs []*Msg
+}
+
+func (self *TestProtocol) Start() {
+}
+
+func (self *TestProtocol) Stop() {
+}
+
+func (self *TestProtocol) Offset() MsgCode {
+ return MsgCode(5)
+}
+
+func (self *TestProtocol) HandleIn(msg *Msg, response chan *Msg) {
+ self.Msgs = append(self.Msgs, msg)
+ close(response)
+}
+
+func (self *TestProtocol) HandleOut(msg *Msg) bool {
+ if msg.Code() > 3 {
+ return false
+ } else {
+ return true
+ }
+}
+
+func (self *TestProtocol) Name() string {
+ return "a"
+}
+
+func Packet(offset MsgCode, code MsgCode, params ...interface{}) []byte {
+ msg, _ := NewMsg(code, params...)
+ encoded := msg.Encode(offset)
+ packet := []byte{34, 64, 8, 145}
+ packet = append(packet, ethutil.NumberToBytes(uint32(len(encoded)), 32)...)
+ return append(packet, encoded...)
+}
+
+func TestRead(t *testing.T) {
+ handlers := make(Handlers)
+ testProtocol := &TestProtocol{Msgs: []*Msg{}}
+ handlers["a"] = func(p *Peer) Protocol { return testProtocol }
+ net, _, mess := setupMessenger(handlers)
+ mess.AddProtocols([]string{"a"})
+ defer mess.Stop()
+ wait := 1 * time.Millisecond
+ packet := Packet(16, 1, uint32(1), "000")
+ go net.In(0, packet)
+ time.Sleep(wait)
+ if len(testProtocol.Msgs) != 1 {
+ t.Errorf("msg not relayed to correct protocol")
+ } else {
+ if testProtocol.Msgs[0].Code() != 1 {
+ t.Errorf("incorrect msg code relayed to protocol")
+ }
+ }
+}
+
+func TestWrite(t *testing.T) {
+ handlers := make(Handlers)
+ testProtocol := &TestProtocol{Msgs: []*Msg{}}
+ handlers["a"] = func(p *Peer) Protocol { return testProtocol }
+ net, _, mess := setupMessenger(handlers)
+ mess.AddProtocols([]string{"a"})
+ defer mess.Stop()
+ wait := 1 * time.Millisecond
+ msg, _ := NewMsg(3, uint32(1), "000")
+ err := mess.Write("b", msg)
+ if err == nil {
+ t.Errorf("expect error for unknown protocol")
+ }
+ err = mess.Write("a", msg)
+ if err != nil {
+ t.Errorf("expect no error for known protocol: %v", err)
+ } else {
+ time.Sleep(wait)
+ if len(net.Out) != 1 {
+ t.Errorf("msg not written")
+ } else {
+ out := net.Out[0]
+ packet := Packet(16, 3, uint32(1), "000")
+ if bytes.Compare(out, packet) != 0 {
+ t.Errorf("incorrect packet %v", out)
+ }
+ }
+ }
+}
+
+func TestPulse(t *testing.T) {
+ net, _, mess := setupMessenger(make(Handlers))
+ defer mess.Stop()
+ ping := false
+ timeout := false
+ pingTimeout := 10 * time.Millisecond
+ gracePeriod := 200 * time.Millisecond
+ go mess.PingPong(pingTimeout, gracePeriod, func() { ping = true }, func() { timeout = true })
+ net.In(0, Packet(0, 1))
+ if ping {
+ t.Errorf("ping sent too early")
+ }
+ time.Sleep(pingTimeout + 100*time.Millisecond)
+ if !ping {
+ t.Errorf("no ping sent after timeout")
+ }
+ if timeout {
+ t.Errorf("timeout too early")
+ }
+ ping = false
+ net.In(0, Packet(0, 1))
+ time.Sleep(pingTimeout + 100*time.Millisecond)
+ if !ping {
+ t.Errorf("no ping sent after timeout")
+ }
+ if timeout {
+ t.Errorf("timeout too early")
+ }
+ ping = false
+ time.Sleep(gracePeriod)
+ if ping {
+ t.Errorf("ping called twice")
+ }
+ if !timeout {
+ t.Errorf("no timeout after grace period")
+ }
+}
diff --git a/p2p/natpmp.go b/p2p/natpmp.go
new file mode 100644
index 0000000000..ff966d0701
--- /dev/null
+++ b/p2p/natpmp.go
@@ -0,0 +1,55 @@
+package p2p
+
+import (
+ "fmt"
+ "net"
+
+ natpmp "github.com/jackpal/go-nat-pmp"
+)
+
+// Adapt the NAT-PMP protocol to the NAT interface
+
+// TODO:
+// + Register for changes to the external address.
+// + Re-register port mapping when router reboots.
+// + A mechanism for keeping a port mapping registered.
+
+type natPMPClient struct {
+ client *natpmp.Client
+}
+
+func NewNatPMP(gateway net.IP) (nat NAT) {
+ return &natPMPClient{natpmp.NewClient(gateway)}
+}
+
+func (n *natPMPClient) GetExternalAddress() (addr net.IP, err error) {
+ response, err := n.client.GetExternalAddress()
+ if err != nil {
+ return
+ }
+ ip := response.ExternalIPAddress
+ addr = net.IPv4(ip[0], ip[1], ip[2], ip[3])
+ return
+}
+
+func (n *natPMPClient) AddPortMapping(protocol string, externalPort, internalPort int,
+ description string, timeout int) (mappedExternalPort int, err error) {
+ if timeout <= 0 {
+ err = fmt.Errorf("timeout must not be <= 0")
+ return
+ }
+ // Note order of port arguments is switched between our AddPortMapping and the client's AddPortMapping.
+ response, err := n.client.AddPortMapping(protocol, internalPort, externalPort, timeout)
+ if err != nil {
+ return
+ }
+ mappedExternalPort = int(response.MappedExternalPort)
+ return
+}
+
+func (n *natPMPClient) DeletePortMapping(protocol string, externalPort, internalPort int) (err error) {
+ // To destroy a mapping, send an add-port with
+ // an internalPort of the internal port to destroy, an external port of zero and a time of zero.
+ _, err = n.client.AddPortMapping(protocol, internalPort, 0, 0)
+ return
+}
diff --git a/p2p/natupnp.go b/p2p/natupnp.go
new file mode 100644
index 0000000000..fa9798d4d5
--- /dev/null
+++ b/p2p/natupnp.go
@@ -0,0 +1,335 @@
+package p2p
+
+// Just enough UPnP to be able to forward ports
+//
+
+import (
+ "bytes"
+ "encoding/xml"
+ "errors"
+ "net"
+ "net/http"
+ "os"
+ "strconv"
+ "strings"
+ "time"
+)
+
+type upnpNAT struct {
+ serviceURL string
+ ourIP string
+}
+
+func upnpDiscover(attempts int) (nat NAT, err error) {
+ ssdp, err := net.ResolveUDPAddr("udp4", "239.255.255.250:1900")
+ if err != nil {
+ return
+ }
+ conn, err := net.ListenPacket("udp4", ":0")
+ if err != nil {
+ return
+ }
+ socket := conn.(*net.UDPConn)
+ defer socket.Close()
+
+ err = socket.SetDeadline(time.Now().Add(10 * time.Second))
+ if err != nil {
+ return
+ }
+
+ st := "ST: urn:schemas-upnp-org:device:InternetGatewayDevice:1\r\n"
+ buf := bytes.NewBufferString(
+ "M-SEARCH * HTTP/1.1\r\n" +
+ "HOST: 239.255.255.250:1900\r\n" +
+ st +
+ "MAN: \"ssdp:discover\"\r\n" +
+ "MX: 2\r\n\r\n")
+ message := buf.Bytes()
+ answerBytes := make([]byte, 1024)
+ for i := 0; i < attempts; i++ {
+ _, err = socket.WriteToUDP(message, ssdp)
+ if err != nil {
+ return
+ }
+ var n int
+ n, _, err = socket.ReadFromUDP(answerBytes)
+ if err != nil {
+ continue
+ // socket.Close()
+ // return
+ }
+ answer := string(answerBytes[0:n])
+ if strings.Index(answer, "\r\n"+st) < 0 {
+ continue
+ }
+ // HTTP header field names are case-insensitive.
+ // http://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html#sec4.2
+ locString := "\r\nlocation: "
+ answer = strings.ToLower(answer)
+ locIndex := strings.Index(answer, locString)
+ if locIndex < 0 {
+ continue
+ }
+ loc := answer[locIndex+len(locString):]
+ endIndex := strings.Index(loc, "\r\n")
+ if endIndex < 0 {
+ continue
+ }
+ locURL := loc[0:endIndex]
+ var serviceURL string
+ serviceURL, err = getServiceURL(locURL)
+ if err != nil {
+ return
+ }
+ var ourIP string
+ ourIP, err = getOurIP()
+ if err != nil {
+ return
+ }
+ nat = &upnpNAT{serviceURL: serviceURL, ourIP: ourIP}
+ return
+ }
+ err = errors.New("UPnP port discovery failed.")
+ return
+}
+
+// service represents the Service type in an UPnP xml description.
+// Only the parts we care about are present and thus the xml may have more
+// fields than present in the structure.
+type service struct {
+ ServiceType string `xml:"serviceType"`
+ ControlURL string `xml:"controlURL"`
+}
+
+// deviceList represents the deviceList type in an UPnP xml description.
+// Only the parts we care about are present and thus the xml may have more
+// fields than present in the structure.
+type deviceList struct {
+ XMLName xml.Name `xml:"deviceList"`
+ Device []device `xml:"device"`
+}
+
+// serviceList represents the serviceList type in an UPnP xml description.
+// Only the parts we care about are present and thus the xml may have more
+// fields than present in the structure.
+type serviceList struct {
+ XMLName xml.Name `xml:"serviceList"`
+ Service []service `xml:"service"`
+}
+
+// device represents the device type in an UPnP xml description.
+// Only the parts we care about are present and thus the xml may have more
+// fields than present in the structure.
+type device struct {
+ XMLName xml.Name `xml:"device"`
+ DeviceType string `xml:"deviceType"`
+ DeviceList deviceList `xml:"deviceList"`
+ ServiceList serviceList `xml:"serviceList"`
+}
+
+// specVersion represents the specVersion in a UPnP xml description.
+// Only the parts we care about are present and thus the xml may have more
+// fields than present in the structure.
+type specVersion struct {
+ XMLName xml.Name `xml:"specVersion"`
+ Major int `xml:"major"`
+ Minor int `xml:"minor"`
+}
+
+// root represents the Root document for a UPnP xml description.
+// Only the parts we care about are present and thus the xml may have more
+// fields than present in the structure.
+type root struct {
+ XMLName xml.Name `xml:"root"`
+ SpecVersion specVersion
+ Device device
+}
+
+func getChildDevice(d *device, deviceType string) *device {
+ dl := d.DeviceList.Device
+ for i := 0; i < len(dl); i++ {
+ if dl[i].DeviceType == deviceType {
+ return &dl[i]
+ }
+ }
+ return nil
+}
+
+func getChildService(d *device, serviceType string) *service {
+ sl := d.ServiceList.Service
+ for i := 0; i < len(sl); i++ {
+ if sl[i].ServiceType == serviceType {
+ return &sl[i]
+ }
+ }
+ return nil
+}
+
+func getOurIP() (ip string, err error) {
+ hostname, err := os.Hostname()
+ if err != nil {
+ return
+ }
+ p, err := net.LookupIP(hostname)
+ if err != nil && len(p) > 0 {
+ return
+ }
+ return p[0].String(), nil
+}
+
+func getServiceURL(rootURL string) (url string, err error) {
+ r, err := http.Get(rootURL)
+ if err != nil {
+ return
+ }
+ defer r.Body.Close()
+ if r.StatusCode >= 400 {
+ err = errors.New(string(r.StatusCode))
+ return
+ }
+ var root root
+ err = xml.NewDecoder(r.Body).Decode(&root)
+
+ if err != nil {
+ return
+ }
+ a := &root.Device
+ if a.DeviceType != "urn:schemas-upnp-org:device:InternetGatewayDevice:1" {
+ err = errors.New("No InternetGatewayDevice")
+ return
+ }
+ b := getChildDevice(a, "urn:schemas-upnp-org:device:WANDevice:1")
+ if b == nil {
+ err = errors.New("No WANDevice")
+ return
+ }
+ c := getChildDevice(b, "urn:schemas-upnp-org:device:WANConnectionDevice:1")
+ if c == nil {
+ err = errors.New("No WANConnectionDevice")
+ return
+ }
+ d := getChildService(c, "urn:schemas-upnp-org:service:WANIPConnection:1")
+ if d == nil {
+ err = errors.New("No WANIPConnection")
+ return
+ }
+ url = combineURL(rootURL, d.ControlURL)
+ return
+}
+
+func combineURL(rootURL, subURL string) string {
+ protocolEnd := "://"
+ protoEndIndex := strings.Index(rootURL, protocolEnd)
+ a := rootURL[protoEndIndex+len(protocolEnd):]
+ rootIndex := strings.Index(a, "/")
+ return rootURL[0:protoEndIndex+len(protocolEnd)+rootIndex] + subURL
+}
+
+func soapRequest(url, function, message string) (r *http.Response, err error) {
+ fullMessage := "" +
+ "\r\n" +
+ "" + message + ""
+
+ req, err := http.NewRequest("POST", url, strings.NewReader(fullMessage))
+ if err != nil {
+ return
+ }
+ req.Header.Set("Content-Type", "text/xml ; charset=\"utf-8\"")
+ req.Header.Set("User-Agent", "Darwin/10.0.0, UPnP/1.0, MiniUPnPc/1.3")
+ //req.Header.Set("Transfer-Encoding", "chunked")
+ req.Header.Set("SOAPAction", "\"urn:schemas-upnp-org:service:WANIPConnection:1#"+function+"\"")
+ req.Header.Set("Connection", "Close")
+ req.Header.Set("Cache-Control", "no-cache")
+ req.Header.Set("Pragma", "no-cache")
+
+ r, err = http.DefaultClient.Do(req)
+ if err != nil {
+ return
+ }
+
+ if r.Body != nil {
+ defer r.Body.Close()
+ }
+
+ if r.StatusCode >= 400 {
+ // log.Stderr(function, r.StatusCode)
+ err = errors.New("Error " + strconv.Itoa(r.StatusCode) + " for " + function)
+ r = nil
+ return
+ }
+ return
+}
+
+type statusInfo struct {
+ externalIpAddress string
+}
+
+func (n *upnpNAT) getStatusInfo() (info statusInfo, err error) {
+
+ message := "\r\n" +
+ ""
+
+ var response *http.Response
+ response, err = soapRequest(n.serviceURL, "GetStatusInfo", message)
+ if err != nil {
+ return
+ }
+
+ // TODO: Write a soap reply parser. It has to eat the Body and envelope tags...
+
+ response.Body.Close()
+ return
+}
+
+func (n *upnpNAT) GetExternalAddress() (addr net.IP, err error) {
+ info, err := n.getStatusInfo()
+ if err != nil {
+ return
+ }
+ addr = net.ParseIP(info.externalIpAddress)
+ return
+}
+
+func (n *upnpNAT) AddPortMapping(protocol string, externalPort, internalPort int, description string, timeout int) (mappedExternalPort int, err error) {
+ // A single concatenation would break ARM compilation.
+ message := "\r\n" +
+ "" + strconv.Itoa(externalPort)
+ message += "" + protocol + ""
+ message += "" + strconv.Itoa(internalPort) + "" +
+ "" + n.ourIP + "" +
+ "1"
+ message += description +
+ "" + strconv.Itoa(timeout) +
+ ""
+
+ var response *http.Response
+ response, err = soapRequest(n.serviceURL, "AddPortMapping", message)
+ if err != nil {
+ return
+ }
+
+ // TODO: check response to see if the port was forwarded
+ // log.Println(message, response)
+ mappedExternalPort = externalPort
+ _ = response
+ return
+}
+
+func (n *upnpNAT) DeletePortMapping(protocol string, externalPort, internalPort int) (err error) {
+
+ message := "\r\n" +
+ "" + strconv.Itoa(externalPort) +
+ "" + protocol + "" +
+ ""
+
+ var response *http.Response
+ response, err = soapRequest(n.serviceURL, "DeletePortMapping", message)
+ if err != nil {
+ return
+ }
+
+ // TODO: check response to see if the port was deleted
+ // log.Println(message, response)
+ _ = response
+ return
+}
diff --git a/p2p/network.go b/p2p/network.go
new file mode 100644
index 0000000000..820cef1a91
--- /dev/null
+++ b/p2p/network.go
@@ -0,0 +1,196 @@
+package p2p
+
+import (
+ "fmt"
+ "math/rand"
+ "net"
+ "strconv"
+ "time"
+)
+
+const (
+ DialerTimeout = 180 //seconds
+ KeepAlivePeriod = 60 //minutes
+ portMappingUpdateInterval = 900 // seconds = 15 mins
+ upnpDiscoverAttempts = 3
+)
+
+// Dialer is not an interface in net, so we define one
+// *net.Dialer conforms to this
+type Dialer interface {
+ Dial(network, address string) (net.Conn, error)
+}
+
+type Network interface {
+ Start() error
+ Listener(net.Addr) (net.Listener, error)
+ Dialer(net.Addr) (Dialer, error)
+ NewAddr(string, int) (addr net.Addr, err error)
+ ParseAddr(string) (addr net.Addr, err error)
+}
+
+type NAT interface {
+ GetExternalAddress() (addr net.IP, err error)
+ AddPortMapping(protocol string, externalPort, internalPort int, description string, timeout int) (mappedExternalPort int, err error)
+ DeletePortMapping(protocol string, externalPort, internalPort int) (err error)
+}
+
+type TCPNetwork struct {
+ nat NAT
+ natType NATType
+ quit chan chan bool
+ ports chan string
+}
+
+type NATType int
+
+const (
+ NONE = iota
+ UPNP
+ PMP
+)
+
+const (
+ portMappingTimeout = 1200 // 20 mins
+)
+
+func NewTCPNetwork(natType NATType) (net *TCPNetwork) {
+ return &TCPNetwork{
+ natType: natType,
+ ports: make(chan string),
+ }
+}
+
+func (self *TCPNetwork) Dialer(addr net.Addr) (Dialer, error) {
+ return &net.Dialer{
+ Timeout: DialerTimeout * time.Second,
+ // KeepAlive: KeepAlivePeriod * time.Minute,
+ LocalAddr: addr,
+ }, nil
+}
+
+func (self *TCPNetwork) Listener(addr net.Addr) (net.Listener, error) {
+ if self.natType == UPNP {
+ _, port, _ := net.SplitHostPort(addr.String())
+ if self.quit == nil {
+ self.quit = make(chan chan bool)
+ go self.updatePortMappings()
+ }
+ self.ports <- port
+ }
+ return net.Listen(addr.Network(), addr.String())
+}
+
+func (self *TCPNetwork) Start() (err error) {
+ switch self.natType {
+ case NONE:
+ case UPNP:
+ nat, uerr := upnpDiscover(upnpDiscoverAttempts)
+ if uerr != nil {
+ err = fmt.Errorf("UPNP failed: ", uerr)
+ } else {
+ self.nat = nat
+ }
+ case PMP:
+ err = fmt.Errorf("PMP not implemented")
+ default:
+ err = fmt.Errorf("Invalid NAT type: %v", self.natType)
+ }
+ return
+}
+
+func (self *TCPNetwork) Stop() {
+ q := make(chan bool)
+ self.quit <- q
+ <-q
+}
+
+func (self *TCPNetwork) addPortMapping(lport int) (err error) {
+ _, err = self.nat.AddPortMapping("TCP", lport, lport, "p2p listen port", portMappingTimeout)
+ if err != nil {
+ logger.Errorf("unable to add port mapping on %v: %v", lport, err)
+ } else {
+ logger.Debugf("succesfully added port mapping on %v", lport)
+ }
+ return
+}
+
+func (self *TCPNetwork) updatePortMappings() {
+ timer := time.NewTimer(portMappingUpdateInterval * time.Second)
+ lports := []int{}
+out:
+ for {
+ select {
+ case port := <-self.ports:
+ int64lport, _ := strconv.ParseInt(port, 10, 16)
+ lport := int(int64lport)
+ if err := self.addPortMapping(lport); err != nil {
+ lports = append(lports, lport)
+ }
+ case <-timer.C:
+ for lport := range lports {
+ if err := self.addPortMapping(lport); err != nil {
+ }
+ }
+ case errc := <-self.quit:
+ errc <- true
+ break out
+ }
+ }
+
+ timer.Stop()
+ for lport := range lports {
+ if err := self.nat.DeletePortMapping("TCP", lport, lport); err != nil {
+ logger.Debugf("unable to remove port mapping on %v: %v", lport, err)
+ } else {
+ logger.Debugf("succesfully removed port mapping on %v", lport)
+ }
+ }
+}
+
+func (self *TCPNetwork) NewAddr(host string, port int) (net.Addr, error) {
+ ip, err := self.lookupIP(host)
+ if err == nil {
+ return &net.TCPAddr{
+ IP: ip,
+ Port: port,
+ }, nil
+ }
+ return nil, err
+}
+
+func (self *TCPNetwork) ParseAddr(address string) (net.Addr, error) {
+ host, port, err := net.SplitHostPort(address)
+ if err == nil {
+ iport, _ := strconv.Atoi(port)
+ addr, e := self.NewAddr(host, iport)
+ return addr, e
+ }
+ return nil, err
+}
+
+func (*TCPNetwork) lookupIP(host string) (ip net.IP, err error) {
+ if ip = net.ParseIP(host); ip != nil {
+ return
+ }
+
+ var ips []net.IP
+ ips, err = net.LookupIP(host)
+ if err != nil {
+ logger.Warnln(err)
+ return
+ }
+ if len(ips) == 0 {
+ err = fmt.Errorf("No IP addresses available for %v", host)
+ logger.Warnln(err)
+ return
+ }
+ if len(ips) > 1 {
+ // Pick a random IP address, simulating round-robin DNS.
+ rand.Seed(time.Now().UTC().UnixNano())
+ ip = ips[rand.Intn(len(ips))]
+ } else {
+ ip = ips[0]
+ }
+ return
+}
diff --git a/p2p/peer.go b/p2p/peer.go
new file mode 100644
index 0000000000..f4b68a007a
--- /dev/null
+++ b/p2p/peer.go
@@ -0,0 +1,83 @@
+package p2p
+
+import (
+ "fmt"
+ "net"
+ "strconv"
+)
+
+type Peer struct {
+ // quit chan chan bool
+ Inbound bool // inbound (via listener) or outbound (via dialout)
+ Address net.Addr
+ Host []byte
+ Port uint16
+ Pubkey []byte
+ Id string
+ Caps []string
+ peerErrorChan chan *PeerError
+ messenger *Messenger
+ peerErrorHandler *PeerErrorHandler
+ server *Server
+}
+
+func (self *Peer) Messenger() *Messenger {
+ return self.messenger
+}
+
+func (self *Peer) PeerErrorChan() chan *PeerError {
+ return self.peerErrorChan
+}
+
+func (self *Peer) Server() *Server {
+ return self.server
+}
+
+func NewPeer(conn net.Conn, address net.Addr, inbound bool, server *Server) *Peer {
+ peerErrorChan := NewPeerErrorChannel()
+ host, port, _ := net.SplitHostPort(address.String())
+ intport, _ := strconv.Atoi(port)
+ peer := &Peer{
+ Inbound: inbound,
+ Address: address,
+ Port: uint16(intport),
+ Host: net.ParseIP(host),
+ peerErrorChan: peerErrorChan,
+ server: server,
+ }
+ connection := NewConnection(conn, peerErrorChan)
+ peer.messenger = NewMessenger(peer, connection, peerErrorChan, server.Handlers())
+ peer.peerErrorHandler = NewPeerErrorHandler(address, server.PeerDisconnect(), peerErrorChan, server.Blacklist())
+ return peer
+}
+
+func (self *Peer) String() string {
+ var kind string
+ if self.Inbound {
+ kind = "inbound"
+ } else {
+ kind = "outbound"
+ }
+ return fmt.Sprintf("%v:%v (%s) v%v %v", self.Host, self.Port, kind, self.Id, self.Caps)
+}
+
+func (self *Peer) Write(protocol string, msg *Msg) error {
+ return self.messenger.Write(protocol, msg)
+}
+
+func (self *Peer) Start() {
+ self.peerErrorHandler.Start()
+ self.messenger.Start()
+}
+
+func (self *Peer) Stop() {
+ self.peerErrorHandler.Stop()
+ self.messenger.Stop()
+ // q := make(chan bool)
+ // self.quit <- q
+ // <-q
+}
+
+func (p *Peer) Encode() []interface{} {
+ return []interface{}{p.Host, p.Port, p.Pubkey}
+}
diff --git a/p2p/peer_error.go b/p2p/peer_error.go
new file mode 100644
index 0000000000..de921878a0
--- /dev/null
+++ b/p2p/peer_error.go
@@ -0,0 +1,76 @@
+package p2p
+
+import (
+ "fmt"
+)
+
+type ErrorCode int
+
+const errorChanCapacity = 10
+
+const (
+ PacketTooShort = iota
+ PayloadTooShort
+ MagicTokenMismatch
+ EmptyPayload
+ ReadError
+ WriteError
+ MiscError
+ InvalidMsgCode
+ InvalidMsg
+ P2PVersionMismatch
+ PubkeyMissing
+ PubkeyInvalid
+ PubkeyForbidden
+ ProtocolBreach
+ PortMismatch
+ PingTimeout
+ InvalidGenesis
+ InvalidNetworkId
+ InvalidProtocolVersion
+)
+
+var errorToString = map[ErrorCode]string{
+ PacketTooShort: "Packet too short",
+ PayloadTooShort: "Payload too short",
+ MagicTokenMismatch: "Magic token mismatch",
+ EmptyPayload: "Empty payload",
+ ReadError: "Read error",
+ WriteError: "Write error",
+ MiscError: "Misc error",
+ InvalidMsgCode: "Invalid message code",
+ InvalidMsg: "Invalid message",
+ P2PVersionMismatch: "P2P Version Mismatch",
+ PubkeyMissing: "Public key missing",
+ PubkeyInvalid: "Public key invalid",
+ PubkeyForbidden: "Public key forbidden",
+ ProtocolBreach: "Protocol Breach",
+ PortMismatch: "Port mismatch",
+ PingTimeout: "Ping timeout",
+ InvalidGenesis: "Invalid genesis block",
+ InvalidNetworkId: "Invalid network id",
+ InvalidProtocolVersion: "Invalid protocol version",
+}
+
+type PeerError struct {
+ Code ErrorCode
+ message string
+}
+
+func NewPeerError(code ErrorCode, format string, v ...interface{}) *PeerError {
+ desc, ok := errorToString[code]
+ if !ok {
+ panic("invalid error code")
+ }
+ format = desc + ": " + format
+ message := fmt.Sprintf(format, v...)
+ return &PeerError{code, message}
+}
+
+func (self *PeerError) Error() string {
+ return self.message
+}
+
+func NewPeerErrorChannel() chan *PeerError {
+ return make(chan *PeerError, errorChanCapacity)
+}
diff --git a/p2p/peer_error_handler.go b/p2p/peer_error_handler.go
new file mode 100644
index 0000000000..ca6cae4dbc
--- /dev/null
+++ b/p2p/peer_error_handler.go
@@ -0,0 +1,101 @@
+package p2p
+
+import (
+ "net"
+)
+
+const (
+ severityThreshold = 10
+)
+
+type DisconnectRequest struct {
+ addr net.Addr
+ reason DiscReason
+}
+
+type PeerErrorHandler struct {
+ quit chan chan bool
+ address net.Addr
+ peerDisconnect chan DisconnectRequest
+ severity int
+ peerErrorChan chan *PeerError
+ blacklist Blacklist
+}
+
+func NewPeerErrorHandler(address net.Addr, peerDisconnect chan DisconnectRequest, peerErrorChan chan *PeerError, blacklist Blacklist) *PeerErrorHandler {
+ return &PeerErrorHandler{
+ quit: make(chan chan bool),
+ address: address,
+ peerDisconnect: peerDisconnect,
+ peerErrorChan: peerErrorChan,
+ blacklist: blacklist,
+ }
+}
+
+func (self *PeerErrorHandler) Start() {
+ go self.listen()
+}
+
+func (self *PeerErrorHandler) Stop() {
+ q := make(chan bool)
+ self.quit <- q
+ <-q
+}
+
+func (self *PeerErrorHandler) listen() {
+ for {
+ select {
+ case peerError, ok := <-self.peerErrorChan:
+ if ok {
+ logger.Debugf("error %v\n", peerError)
+ go self.handle(peerError)
+ } else {
+ return
+ }
+ case q := <-self.quit:
+ q <- true
+ return
+ }
+ }
+}
+
+func (self *PeerErrorHandler) handle(peerError *PeerError) {
+ reason := DiscReason(' ')
+ switch peerError.Code {
+ case P2PVersionMismatch:
+ reason = DiscIncompatibleVersion
+ case PubkeyMissing, PubkeyInvalid:
+ reason = DiscInvalidIdentity
+ case PubkeyForbidden:
+ reason = DiscUselessPeer
+ case InvalidMsgCode, PacketTooShort, PayloadTooShort, MagicTokenMismatch, EmptyPayload, ProtocolBreach:
+ reason = DiscProtocolError
+ case PingTimeout:
+ reason = DiscReadTimeout
+ case WriteError, MiscError:
+ reason = DiscNetworkError
+ case InvalidGenesis, InvalidNetworkId, InvalidProtocolVersion:
+ reason = DiscSubprotocolError
+ default:
+ self.severity += self.getSeverity(peerError)
+ }
+
+ if self.severity >= severityThreshold {
+ reason = DiscSubprotocolError
+ }
+ if reason != DiscReason(' ') {
+ self.peerDisconnect <- DisconnectRequest{
+ addr: self.address,
+ reason: reason,
+ }
+ }
+}
+
+func (self *PeerErrorHandler) getSeverity(peerError *PeerError) int {
+ switch peerError.Code {
+ case ReadError:
+ return 4 //tolerate 3 :)
+ default:
+ return 1
+ }
+}
diff --git a/p2p/peer_error_handler_test.go b/p2p/peer_error_handler_test.go
new file mode 100644
index 0000000000..790a7443b8
--- /dev/null
+++ b/p2p/peer_error_handler_test.go
@@ -0,0 +1,34 @@
+package p2p
+
+import (
+ // "fmt"
+ "net"
+ "testing"
+ "time"
+)
+
+func TestPeerErrorHandler(t *testing.T) {
+ address := &net.TCPAddr{IP: net.IP([]byte{1, 2, 3, 4}), Port: 30303}
+ peerDisconnect := make(chan DisconnectRequest)
+ peerErrorChan := NewPeerErrorChannel()
+ peh := NewPeerErrorHandler(address, peerDisconnect, peerErrorChan, NewBlacklist())
+ peh.Start()
+ defer peh.Stop()
+ for i := 0; i < 11; i++ {
+ select {
+ case <-peerDisconnect:
+ t.Errorf("expected no disconnect request")
+ default:
+ }
+ peerErrorChan <- NewPeerError(MiscError, "")
+ }
+ time.Sleep(1 * time.Millisecond)
+ select {
+ case request := <-peerDisconnect:
+ if request.addr.String() != address.String() {
+ t.Errorf("incorrect address %v != %v", request.addr, address)
+ }
+ default:
+ t.Errorf("expected disconnect request")
+ }
+}
diff --git a/p2p/peer_test.go b/p2p/peer_test.go
new file mode 100644
index 0000000000..c37540bef3
--- /dev/null
+++ b/p2p/peer_test.go
@@ -0,0 +1,96 @@
+package p2p
+
+import (
+ "bytes"
+ "fmt"
+ // "net"
+ "testing"
+ "time"
+)
+
+func TestPeer(t *testing.T) {
+ handlers := make(Handlers)
+ testProtocol := &TestProtocol{Msgs: []*Msg{}}
+ handlers["aaa"] = func(p *Peer) Protocol { return testProtocol }
+ handlers["ccc"] = func(p *Peer) Protocol { return testProtocol }
+ addr := &TestAddr{"test:30"}
+ conn := NewTestNetworkConnection(addr)
+ _, server := SetupTestServer(handlers)
+ server.Handshake()
+ peer := NewPeer(conn, addr, true, server)
+ // peer.Messenger().AddProtocols([]string{"aaa", "ccc"})
+ peer.Start()
+ defer peer.Stop()
+ time.Sleep(2 * time.Millisecond)
+ if len(conn.Out) != 1 {
+ t.Errorf("handshake not sent")
+ } else {
+ out := conn.Out[0]
+ packet := Packet(0, HandshakeMsg, P2PVersion, []byte(peer.server.identity.String()), []interface{}{peer.server.protocols}, peer.server.port, peer.server.identity.Pubkey()[1:])
+ if bytes.Compare(out, packet) != 0 {
+ t.Errorf("incorrect handshake packet %v != %v", out, packet)
+ }
+ }
+
+ packet := Packet(0, HandshakeMsg, P2PVersion, []byte("peer"), []interface{}{"bbb", "aaa", "ccc"}, 30, []byte("0000000000000000000000000000000000000000000000000000000000000000"))
+ conn.In(0, packet)
+ time.Sleep(10 * time.Millisecond)
+
+ pro, _ := peer.Messenger().protocols[0].(*BaseProtocol)
+ if pro.state != handshakeReceived {
+ t.Errorf("handshake not received")
+ }
+ if peer.Port != 30 {
+ t.Errorf("port incorrectly set")
+ }
+ if peer.Id != "peer" {
+ t.Errorf("id incorrectly set")
+ }
+ if string(peer.Pubkey) != "0000000000000000000000000000000000000000000000000000000000000000" {
+ t.Errorf("pubkey incorrectly set")
+ }
+ fmt.Println(peer.Caps)
+ if len(peer.Caps) != 3 || peer.Caps[0] != "aaa" || peer.Caps[1] != "bbb" || peer.Caps[2] != "ccc" {
+ t.Errorf("protocols incorrectly set")
+ }
+
+ msg, _ := NewMsg(3)
+ err := peer.Write("aaa", msg)
+ if err != nil {
+ t.Errorf("expect no error for known protocol: %v", err)
+ } else {
+ time.Sleep(1 * time.Millisecond)
+ if len(conn.Out) != 2 {
+ t.Errorf("msg not written")
+ } else {
+ out := conn.Out[1]
+ packet := Packet(16, 3)
+ if bytes.Compare(out, packet) != 0 {
+ t.Errorf("incorrect packet %v != %v", out, packet)
+ }
+ }
+ }
+
+ msg, _ = NewMsg(2)
+ err = peer.Write("ccc", msg)
+ if err != nil {
+ t.Errorf("expect no error for known protocol: %v", err)
+ } else {
+ time.Sleep(1 * time.Millisecond)
+ if len(conn.Out) != 3 {
+ t.Errorf("msg not written")
+ } else {
+ out := conn.Out[2]
+ packet := Packet(21, 2)
+ if bytes.Compare(out, packet) != 0 {
+ t.Errorf("incorrect packet %v != %v", out, packet)
+ }
+ }
+ }
+
+ err = peer.Write("bbb", msg)
+ time.Sleep(1 * time.Millisecond)
+ if err == nil {
+ t.Errorf("expect error for unknown protocol")
+ }
+}
diff --git a/p2p/protocol.go b/p2p/protocol.go
new file mode 100644
index 0000000000..5d05ced7d2
--- /dev/null
+++ b/p2p/protocol.go
@@ -0,0 +1,278 @@
+package p2p
+
+import (
+ "bytes"
+ "fmt"
+ "net"
+ "sort"
+ "sync"
+ "time"
+)
+
+type Protocol interface {
+ Start()
+ Stop()
+ HandleIn(*Msg, chan *Msg)
+ HandleOut(*Msg) bool
+ Offset() MsgCode
+ Name() string
+}
+
+const (
+ P2PVersion = 0
+ pingTimeout = 2
+ pingGracePeriod = 2
+)
+
+const (
+ HandshakeMsg = iota
+ DiscMsg
+ PingMsg
+ PongMsg
+ GetPeersMsg
+ PeersMsg
+ offset = 16
+)
+
+type ProtocolState uint8
+
+const (
+ nullState = iota
+ handshakeReceived
+)
+
+type DiscReason byte
+
+const (
+ // Values are given explicitly instead of by iota because these values are
+ // defined by the wire protocol spec; it is easier for humans to ensure
+ // correctness when values are explicit.
+ DiscRequested = 0x00
+ DiscNetworkError = 0x01
+ DiscProtocolError = 0x02
+ DiscUselessPeer = 0x03
+ DiscTooManyPeers = 0x04
+ DiscAlreadyConnected = 0x05
+ DiscIncompatibleVersion = 0x06
+ DiscInvalidIdentity = 0x07
+ DiscQuitting = 0x08
+ DiscUnexpectedIdentity = 0x09
+ DiscSelf = 0x0a
+ DiscReadTimeout = 0x0b
+ DiscSubprotocolError = 0x10
+)
+
+var discReasonToString = map[DiscReason]string{
+ DiscRequested: "Disconnect requested",
+ DiscNetworkError: "Network error",
+ DiscProtocolError: "Breach of protocol",
+ DiscUselessPeer: "Useless peer",
+ DiscTooManyPeers: "Too many peers",
+ DiscAlreadyConnected: "Already connected",
+ DiscIncompatibleVersion: "Incompatible P2P protocol version",
+ DiscInvalidIdentity: "Invalid node identity",
+ DiscQuitting: "Client quitting",
+ DiscUnexpectedIdentity: "Unexpected identity",
+ DiscSelf: "Connected to self",
+ DiscReadTimeout: "Read timeout",
+ DiscSubprotocolError: "Subprotocol error",
+}
+
+func (d DiscReason) String() string {
+ if len(discReasonToString) < int(d) {
+ return "Unknown"
+ }
+
+ return discReasonToString[d]
+}
+
+type BaseProtocol struct {
+ peer *Peer
+ state ProtocolState
+ stateLock sync.RWMutex
+}
+
+func NewBaseProtocol(peer *Peer) *BaseProtocol {
+ self := &BaseProtocol{
+ peer: peer,
+ }
+
+ return self
+}
+
+func (self *BaseProtocol) Start() {
+ if self.peer != nil {
+ self.peer.Write("", self.peer.Server().Handshake())
+ go self.peer.Messenger().PingPong(
+ pingTimeout*time.Second,
+ pingGracePeriod*time.Second,
+ self.Ping,
+ self.Timeout,
+ )
+ }
+}
+
+func (self *BaseProtocol) Stop() {
+}
+
+func (self *BaseProtocol) Ping() {
+ msg, _ := NewMsg(PingMsg)
+ self.peer.Write("", msg)
+}
+
+func (self *BaseProtocol) Timeout() {
+ self.peerError(PingTimeout, "")
+}
+
+func (self *BaseProtocol) Name() string {
+ return ""
+}
+
+func (self *BaseProtocol) Offset() MsgCode {
+ return offset
+}
+
+func (self *BaseProtocol) CheckState(state ProtocolState) bool {
+ self.stateLock.RLock()
+ self.stateLock.RUnlock()
+ if self.state != state {
+ return false
+ } else {
+ return true
+ }
+}
+
+func (self *BaseProtocol) HandleIn(msg *Msg, response chan *Msg) {
+ if msg.Code() == HandshakeMsg {
+ self.handleHandshake(msg)
+ } else {
+ if !self.CheckState(handshakeReceived) {
+ self.peerError(ProtocolBreach, "message code %v not allowed", msg.Code())
+ close(response)
+ return
+ }
+ switch msg.Code() {
+ case DiscMsg:
+ logger.Infof("Disconnect requested from peer %v, reason", DiscReason(msg.Data().Get(0).Uint()))
+ self.peer.Server().PeerDisconnect() <- DisconnectRequest{
+ addr: self.peer.Address,
+ reason: DiscRequested,
+ }
+ case PingMsg:
+ out, _ := NewMsg(PongMsg)
+ response <- out
+ case PongMsg:
+ case GetPeersMsg:
+ // Peer asked for list of connected peers
+ if out, err := self.peer.Server().PeersMessage(); err != nil {
+ response <- out
+ }
+ case PeersMsg:
+ self.handlePeers(msg)
+ default:
+ self.peerError(InvalidMsgCode, "unknown message code %v", msg.Code())
+ }
+ }
+ close(response)
+}
+
+func (self *BaseProtocol) HandleOut(msg *Msg) (allowed bool) {
+ // somewhat overly paranoid
+ allowed = msg.Code() == HandshakeMsg || msg.Code() == DiscMsg || msg.Code() < self.Offset() && self.CheckState(handshakeReceived)
+ return
+}
+
+func (self *BaseProtocol) peerError(errorCode ErrorCode, format string, v ...interface{}) {
+ err := NewPeerError(errorCode, format, v...)
+ logger.Warnln(err)
+ fmt.Println(self.peer, err)
+ if self.peer != nil {
+ self.peer.PeerErrorChan() <- err
+ }
+}
+
+func (self *BaseProtocol) handlePeers(msg *Msg) {
+ it := msg.Data().NewIterator()
+ for it.Next() {
+ ip := net.IP(it.Value().Get(0).Bytes())
+ port := it.Value().Get(1).Uint()
+ address := &net.TCPAddr{IP: ip, Port: int(port)}
+ go self.peer.Server().PeerConnect(address)
+ }
+}
+
+func (self *BaseProtocol) handleHandshake(msg *Msg) {
+ self.stateLock.Lock()
+ defer self.stateLock.Unlock()
+ if self.state != nullState {
+ self.peerError(ProtocolBreach, "extra handshake")
+ return
+ }
+
+ c := msg.Data()
+
+ var (
+ p2pVersion = c.Get(0).Uint()
+ id = c.Get(1).Str()
+ caps = c.Get(2)
+ port = c.Get(3).Uint()
+ pubkey = c.Get(4).Bytes()
+ )
+ fmt.Printf("handshake received %v, %v, %v, %v, %v ", p2pVersion, id, caps, port, pubkey)
+
+ // Check correctness of p2p protocol version
+ if p2pVersion != P2PVersion {
+ self.peerError(P2PVersionMismatch, "Require protocol %d, received %d\n", P2PVersion, p2pVersion)
+ return
+ }
+
+ // Handle the pub key (validation, uniqueness)
+ if len(pubkey) == 0 {
+ self.peerError(PubkeyMissing, "not supplied in handshake.")
+ return
+ }
+
+ if len(pubkey) != 64 {
+ self.peerError(PubkeyInvalid, "require 512 bit, got %v", len(pubkey)*8)
+ return
+ }
+
+ // Self connect detection
+ if bytes.Compare(self.peer.Server().ClientIdentity().Pubkey()[1:], pubkey) == 0 {
+ self.peerError(PubkeyForbidden, "not allowed to connect to self")
+ return
+ }
+
+ // register pubkey on server. this also sets the pubkey on the peer (need lock)
+ if err := self.peer.Server().RegisterPubkey(self.peer, pubkey); err != nil {
+ self.peerError(PubkeyForbidden, err.Error())
+ return
+ }
+
+ // check port
+ if self.peer.Inbound {
+ uint16port := uint16(port)
+ if self.peer.Port > 0 && self.peer.Port != uint16port {
+ self.peerError(PortMismatch, "port mismatch: %v != %v", self.peer.Port, port)
+ return
+ } else {
+ self.peer.Port = uint16port
+ }
+ }
+
+ capsIt := caps.NewIterator()
+ for capsIt.Next() {
+ cap := capsIt.Value().Str()
+ self.peer.Caps = append(self.peer.Caps, cap)
+ }
+ sort.Strings(self.peer.Caps)
+ self.peer.Messenger().AddProtocols(self.peer.Caps)
+
+ self.peer.Id = id
+
+ self.state = handshakeReceived
+
+ //p.ethereum.PushPeer(p)
+ // p.ethereum.reactor.Post("peerList", p.ethereum.Peers())
+ return
+}
diff --git a/p2p/server.go b/p2p/server.go
new file mode 100644
index 0000000000..a6bbd92601
--- /dev/null
+++ b/p2p/server.go
@@ -0,0 +1,484 @@
+package p2p
+
+import (
+ "bytes"
+ "fmt"
+ "net"
+ "sort"
+ "strconv"
+ "sync"
+ "time"
+
+ "github.com/ethereum/eth-go/ethlog"
+)
+
+const (
+ outboundAddressPoolSize = 10
+ disconnectGracePeriod = 2
+)
+
+type Blacklist interface {
+ Get([]byte) (bool, error)
+ Put([]byte) error
+ Delete([]byte) error
+ Exists(pubkey []byte) (ok bool)
+}
+
+type BlacklistMap struct {
+ blacklist map[string]bool
+ lock sync.RWMutex
+}
+
+func NewBlacklist() *BlacklistMap {
+ return &BlacklistMap{
+ blacklist: make(map[string]bool),
+ }
+}
+
+func (self *BlacklistMap) Get(pubkey []byte) (bool, error) {
+ self.lock.RLock()
+ defer self.lock.RUnlock()
+ v, ok := self.blacklist[string(pubkey)]
+ var err error
+ if !ok {
+ err = fmt.Errorf("not found")
+ }
+ return v, err
+}
+
+func (self *BlacklistMap) Exists(pubkey []byte) (ok bool) {
+ self.lock.RLock()
+ defer self.lock.RUnlock()
+ _, ok = self.blacklist[string(pubkey)]
+ return
+}
+
+func (self *BlacklistMap) Put(pubkey []byte) error {
+ self.lock.RLock()
+ defer self.lock.RUnlock()
+ self.blacklist[string(pubkey)] = true
+ return nil
+}
+
+func (self *BlacklistMap) Delete(pubkey []byte) error {
+ self.lock.RLock()
+ defer self.lock.RUnlock()
+ delete(self.blacklist, string(pubkey))
+ return nil
+}
+
+type Server struct {
+ network Network
+ listening bool //needed?
+ dialing bool //needed?
+ closed bool
+ identity ClientIdentity
+ addr net.Addr
+ port uint16
+ protocols []string
+
+ quit chan chan bool
+ peersLock sync.RWMutex
+
+ maxPeers int
+ peers []*Peer
+ peerSlots chan int
+ peersTable map[string]int
+ peersMsg *Msg
+ peerCount int
+
+ peerConnect chan net.Addr
+ peerDisconnect chan DisconnectRequest
+ blacklist Blacklist
+ handlers Handlers
+}
+
+var logger = ethlog.NewLogger("P2P")
+
+func New(network Network, addr net.Addr, identity ClientIdentity, handlers Handlers, maxPeers int, blacklist Blacklist) *Server {
+ // get alphabetical list of protocol names from handlers map
+ protocols := []string{}
+ for protocol := range handlers {
+ protocols = append(protocols, protocol)
+ }
+ sort.Strings(protocols)
+
+ _, port, _ := net.SplitHostPort(addr.String())
+ intport, _ := strconv.Atoi(port)
+
+ self := &Server{
+ // NewSimpleClientIdentity(clientIdentifier, version, customIdentifier)
+ network: network,
+ identity: identity,
+ addr: addr,
+ port: uint16(intport),
+ protocols: protocols,
+
+ quit: make(chan chan bool),
+
+ maxPeers: maxPeers,
+ peers: make([]*Peer, maxPeers),
+ peerSlots: make(chan int, maxPeers),
+ peersTable: make(map[string]int),
+
+ peerConnect: make(chan net.Addr, outboundAddressPoolSize),
+ peerDisconnect: make(chan DisconnectRequest),
+ blacklist: blacklist,
+
+ handlers: handlers,
+ }
+ for i := 0; i < maxPeers; i++ {
+ self.peerSlots <- i // fill up with indexes
+ }
+ return self
+}
+
+func (self *Server) NewAddr(host string, port int) (addr net.Addr, err error) {
+ addr, err = self.network.NewAddr(host, port)
+ return
+}
+
+func (self *Server) ParseAddr(address string) (addr net.Addr, err error) {
+ addr, err = self.network.ParseAddr(address)
+ return
+}
+
+func (self *Server) ClientIdentity() ClientIdentity {
+ return self.identity
+}
+
+func (self *Server) PeersMessage() (msg *Msg, err error) {
+ // TODO: memoize and reset when peers change
+ self.peersLock.RLock()
+ defer self.peersLock.RUnlock()
+ msg = self.peersMsg
+ if msg == nil {
+ var peerData []interface{}
+ for _, i := range self.peersTable {
+ peer := self.peers[i]
+ peerData = append(peerData, peer.Encode())
+ }
+ if len(peerData) == 0 {
+ err = fmt.Errorf("no peers")
+ } else {
+ msg, err = NewMsg(PeersMsg, peerData...)
+ self.peersMsg = msg //memoize
+ }
+ }
+ return
+}
+
+func (self *Server) Peers() (peers []*Peer) {
+ self.peersLock.RLock()
+ defer self.peersLock.RUnlock()
+ for _, peer := range self.peers {
+ if peer != nil {
+ peers = append(peers, peer)
+ }
+ }
+ return
+}
+
+func (self *Server) PeerCount() int {
+ self.peersLock.RLock()
+ defer self.peersLock.RUnlock()
+ return self.peerCount
+}
+
+var getPeersMsg, _ = NewMsg(GetPeersMsg)
+
+func (self *Server) PeerConnect(addr net.Addr) {
+ // TODO: should buffer, filter and uniq
+ // send GetPeersMsg if not blocking
+ select {
+ case self.peerConnect <- addr: // not enough peers
+ self.Broadcast("", getPeersMsg)
+ default: // we dont care
+ }
+}
+
+func (self *Server) PeerDisconnect() chan DisconnectRequest {
+ return self.peerDisconnect
+}
+
+func (self *Server) Blacklist() Blacklist {
+ return self.blacklist
+}
+
+func (self *Server) Handlers() Handlers {
+ return self.handlers
+}
+
+func (self *Server) Broadcast(protocol string, msg *Msg) {
+ self.peersLock.RLock()
+ defer self.peersLock.RUnlock()
+ for _, peer := range self.peers {
+ if peer != nil {
+ peer.Write(protocol, msg)
+ }
+ }
+}
+
+// Start the server
+func (self *Server) Start(listen bool, dial bool) {
+ self.network.Start()
+ if listen {
+ listener, err := self.network.Listener(self.addr)
+ if err != nil {
+ logger.Warnf("Error initializing listener: %v", err)
+ logger.Warnf("Connection listening disabled")
+ self.listening = false
+ } else {
+ self.listening = true
+ logger.Infoln("Listen on %v: ready and accepting connections", listener.Addr())
+ go self.inboundPeerHandler(listener)
+ }
+ }
+ if dial {
+ dialer, err := self.network.Dialer(self.addr)
+ if err != nil {
+ logger.Warnf("Error initializing dialer: %v", err)
+ logger.Warnf("Connection dialout disabled")
+ self.dialing = false
+ } else {
+ self.dialing = true
+ logger.Infoln("Dial peers watching outbound address pool")
+ go self.outboundPeerHandler(dialer)
+ }
+ }
+ logger.Infoln("server started")
+}
+
+func (self *Server) Stop() {
+ logger.Infoln("server stopping...")
+ // // quit one loop if dialing
+ if self.dialing {
+ logger.Infoln("stop dialout...")
+ dialq := make(chan bool)
+ self.quit <- dialq
+ <-dialq
+ fmt.Println("quit another")
+ }
+ // quit the other loop if listening
+ if self.listening {
+ logger.Infoln("stop listening...")
+ listenq := make(chan bool)
+ self.quit <- listenq
+ <-listenq
+ fmt.Println("quit one")
+ }
+
+ fmt.Println("quit waited")
+
+ logger.Infoln("stopping peers...")
+ peers := []net.Addr{}
+ self.peersLock.RLock()
+ self.closed = true
+ for _, peer := range self.peers {
+ if peer != nil {
+ peers = append(peers, peer.Address)
+ }
+ }
+ self.peersLock.RUnlock()
+ for _, address := range peers {
+ go self.removePeer(DisconnectRequest{
+ addr: address,
+ reason: DiscQuitting,
+ })
+ }
+ // wait till they actually disconnect
+ // this is checked by draining the peerSlots (slots are released back if a peer is removed)
+ i := 0
+ fmt.Println("draining peers")
+
+FOR:
+ for {
+ select {
+ case slot := <-self.peerSlots:
+ i++
+ fmt.Printf("%v: found slot %v", i, slot)
+ if i == self.maxPeers {
+ break FOR
+ }
+ }
+ }
+ logger.Infoln("server stopped")
+}
+
+// main loop for adding connections via listening
+func (self *Server) inboundPeerHandler(listener net.Listener) {
+ for {
+ select {
+ case slot := <-self.peerSlots:
+ go self.connectInboundPeer(listener, slot)
+ case errc := <-self.quit:
+ listener.Close()
+ fmt.Println("quit listenloop")
+ errc <- true
+ return
+ }
+ }
+}
+
+// main loop for adding outbound peers based on peerConnect address pool
+// this same loop handles peer disconnect requests as well
+func (self *Server) outboundPeerHandler(dialer Dialer) {
+ // addressChan initially set to nil (only watches peerConnect if we need more peers)
+ var addressChan chan net.Addr
+ slots := self.peerSlots
+ var slot *int
+ for {
+ select {
+ case i := <-slots:
+ // we need a peer in slot i, slot reserved
+ slot = &i
+ // now we can watch for candidate peers in the next loop
+ addressChan = self.peerConnect
+ // do not consume more until candidate peer is found
+ slots = nil
+ case address := <-addressChan:
+ // candidate peer found, will dial out asyncronously
+ // if connection fails slot will be released
+ go self.connectOutboundPeer(dialer, address, *slot)
+ // we can watch if more peers needed in the next loop
+ slots = self.peerSlots
+ // until then we dont care about candidate peers
+ addressChan = nil
+ case request := <-self.peerDisconnect:
+ go self.removePeer(request)
+ case errc := <-self.quit:
+ if addressChan != nil && slot != nil {
+ self.peerSlots <- *slot
+ }
+ fmt.Println("quit dialloop")
+ errc <- true
+ return
+ }
+ }
+}
+
+// check if peer address already connected
+func (self *Server) connected(address net.Addr) (err error) {
+ self.peersLock.RLock()
+ defer self.peersLock.RUnlock()
+ // fmt.Printf("address: %v\n", address)
+ slot, found := self.peersTable[address.String()]
+ if found {
+ err = fmt.Errorf("already connected as peer %v (%v)", slot, address)
+ }
+ return
+}
+
+// connect to peer via listener.Accept()
+func (self *Server) connectInboundPeer(listener net.Listener, slot int) {
+ var address net.Addr
+ conn, err := listener.Accept()
+ if err == nil {
+ address = conn.RemoteAddr()
+ err = self.connected(address)
+ if err != nil {
+ conn.Close()
+ }
+ }
+ if err != nil {
+ logger.Debugln(err)
+ self.peerSlots <- slot
+ } else {
+ fmt.Printf("adding %v\n", address)
+ go self.addPeer(conn, address, true, slot)
+ }
+}
+
+// connect to peer via dial out
+func (self *Server) connectOutboundPeer(dialer Dialer, address net.Addr, slot int) {
+ var conn net.Conn
+ err := self.connected(address)
+ if err == nil {
+ conn, err = dialer.Dial(address.Network(), address.String())
+ }
+ if err != nil {
+ logger.Debugln(err)
+ self.peerSlots <- slot
+ } else {
+ go self.addPeer(conn, address, false, slot)
+ }
+}
+
+// creates the new peer object and inserts it into its slot
+func (self *Server) addPeer(conn net.Conn, address net.Addr, inbound bool, slot int) {
+ self.peersLock.Lock()
+ defer self.peersLock.Unlock()
+ if self.closed {
+ fmt.Println("oopsy, not no longer need peer")
+ conn.Close() //oopsy our bad
+ self.peerSlots <- slot // release slot
+ } else {
+ peer := NewPeer(conn, address, inbound, self)
+ self.peers[slot] = peer
+ self.peersTable[address.String()] = slot
+ self.peerCount++
+ // reset peersmsg
+ self.peersMsg = nil
+ fmt.Printf("added peer %v %v (slot %v)\n", address, peer, slot)
+ peer.Start()
+ }
+}
+
+// removes peer: sending disconnect msg, stop peer, remove rom list/table, release slot
+func (self *Server) removePeer(request DisconnectRequest) {
+ self.peersLock.Lock()
+
+ address := request.addr
+ slot := self.peersTable[address.String()]
+ peer := self.peers[slot]
+ fmt.Printf("removing peer %v %v (slot %v)\n", address, peer, slot)
+ if peer == nil {
+ logger.Debugf("already removed peer on %v", address)
+ self.peersLock.Unlock()
+ return
+ }
+ // remove from list and index
+ self.peerCount--
+ self.peers[slot] = nil
+ delete(self.peersTable, address.String())
+ // reset peersmsg
+ self.peersMsg = nil
+ fmt.Printf("removed peer %v (slot %v)\n", peer, slot)
+ self.peersLock.Unlock()
+
+ // sending disconnect message
+ disconnectMsg, _ := NewMsg(DiscMsg, request.reason)
+ peer.Write("", disconnectMsg)
+ // be nice and wait
+ time.Sleep(disconnectGracePeriod * time.Second)
+ // switch off peer and close connections etc.
+ fmt.Println("stopping peer")
+ peer.Stop()
+ fmt.Println("stopped peer")
+ // release slot to signal need for a new peer, last!
+ self.peerSlots <- slot
+}
+
+// fix handshake message to push to peers
+func (self *Server) Handshake() *Msg {
+ fmt.Println(self.identity.Pubkey()[1:])
+ msg, _ := NewMsg(HandshakeMsg, P2PVersion, []byte(self.identity.String()), []interface{}{self.protocols}, self.port, self.identity.Pubkey()[1:])
+ return msg
+}
+
+func (self *Server) RegisterPubkey(candidate *Peer, pubkey []byte) error {
+ // Check for blacklisting
+ if self.blacklist.Exists(pubkey) {
+ return fmt.Errorf("blacklisted")
+ }
+
+ self.peersLock.RLock()
+ defer self.peersLock.RUnlock()
+ for _, peer := range self.peers {
+ if peer != nil && peer != candidate && bytes.Compare(peer.Pubkey, pubkey) == 0 {
+ return fmt.Errorf("already connected")
+ }
+ }
+ candidate.Pubkey = pubkey
+ return nil
+}
diff --git a/p2p/server_test.go b/p2p/server_test.go
new file mode 100644
index 0000000000..f749cc4908
--- /dev/null
+++ b/p2p/server_test.go
@@ -0,0 +1,208 @@
+package p2p
+
+import (
+ "bytes"
+ "fmt"
+ "net"
+ "testing"
+ "time"
+)
+
+type TestNetwork struct {
+ connections map[string]*TestNetworkConnection
+ dialer Dialer
+ maxinbound int
+}
+
+func NewTestNetwork(maxinbound int) *TestNetwork {
+ connections := make(map[string]*TestNetworkConnection)
+ return &TestNetwork{
+ connections: connections,
+ dialer: &TestDialer{connections},
+ maxinbound: maxinbound,
+ }
+}
+
+func (self *TestNetwork) Dialer(addr net.Addr) (Dialer, error) {
+ return self.dialer, nil
+}
+
+func (self *TestNetwork) Listener(addr net.Addr) (net.Listener, error) {
+ return &TestListener{
+ connections: self.connections,
+ addr: addr,
+ max: self.maxinbound,
+ }, nil
+}
+
+func (self *TestNetwork) Start() error {
+ return nil
+}
+
+func (self *TestNetwork) NewAddr(string, int) (addr net.Addr, err error) {
+ return
+}
+
+func (self *TestNetwork) ParseAddr(string) (addr net.Addr, err error) {
+ return
+}
+
+type TestAddr struct {
+ name string
+}
+
+func (self *TestAddr) String() string {
+ return self.name
+}
+
+func (*TestAddr) Network() string {
+ return "test"
+}
+
+type TestDialer struct {
+ connections map[string]*TestNetworkConnection
+}
+
+func (self *TestDialer) Dial(network string, addr string) (conn net.Conn, err error) {
+ address := &TestAddr{addr}
+ tconn := NewTestNetworkConnection(address)
+ self.connections[addr] = tconn
+ conn = net.Conn(tconn)
+ return
+}
+
+type TestListener struct {
+ connections map[string]*TestNetworkConnection
+ addr net.Addr
+ max int
+ i int
+}
+
+func (self *TestListener) Accept() (conn net.Conn, err error) {
+ self.i++
+ if self.i > self.max {
+ err = fmt.Errorf("no more")
+ } else {
+ addr := &TestAddr{fmt.Sprintf("inboundpeer-%d", self.i)}
+ tconn := NewTestNetworkConnection(addr)
+ key := tconn.RemoteAddr().String()
+ self.connections[key] = tconn
+ conn = net.Conn(tconn)
+ fmt.Printf("accepted connection from: %v \n", addr)
+ }
+ return
+}
+
+func (self *TestListener) Close() error {
+ return nil
+}
+
+func (self *TestListener) Addr() net.Addr {
+ return self.addr
+}
+
+func SetupTestServer(handlers Handlers) (network *TestNetwork, server *Server) {
+ network = NewTestNetwork(1)
+ addr := &TestAddr{"test:30303"}
+ identity := NewSimpleClientIdentity("clientIdentifier", "version", "customIdentifier", "pubkey")
+ maxPeers := 2
+ if handlers == nil {
+ handlers = make(Handlers)
+ }
+ blackist := NewBlacklist()
+ server = New(network, addr, identity, handlers, maxPeers, blackist)
+ fmt.Println(server.identity.Pubkey())
+ return
+}
+
+func TestServerListener(t *testing.T) {
+ network, server := SetupTestServer(nil)
+ server.Start(true, false)
+ time.Sleep(10 * time.Millisecond)
+ server.Stop()
+ peer1, ok := network.connections["inboundpeer-1"]
+ if !ok {
+ t.Error("not found inbound peer 1")
+ } else {
+ fmt.Printf("out: %v\n", peer1.Out)
+ if len(peer1.Out) != 2 {
+ t.Errorf("not enough messages sent to peer 1: %v ", len(peer1.Out))
+ }
+ }
+
+}
+
+func TestServerDialer(t *testing.T) {
+ network, server := SetupTestServer(nil)
+ server.Start(false, true)
+ server.peerConnect <- &TestAddr{"outboundpeer-1"}
+ time.Sleep(10 * time.Millisecond)
+ server.Stop()
+ peer1, ok := network.connections["outboundpeer-1"]
+ if !ok {
+ t.Error("not found outbound peer 1")
+ } else {
+ fmt.Printf("out: %v\n", peer1.Out)
+ if len(peer1.Out) != 2 {
+ t.Errorf("not enough messages sent to peer 1: %v ", len(peer1.Out))
+ }
+ }
+}
+
+func TestServerBroadcast(t *testing.T) {
+ handlers := make(Handlers)
+ testProtocol := &TestProtocol{Msgs: []*Msg{}}
+ handlers["aaa"] = func(p *Peer) Protocol { return testProtocol }
+ network, server := SetupTestServer(handlers)
+ server.Start(true, true)
+ server.peerConnect <- &TestAddr{"outboundpeer-1"}
+ time.Sleep(10 * time.Millisecond)
+ msg, _ := NewMsg(0)
+ server.Broadcast("", msg)
+ packet := Packet(0, 0)
+ time.Sleep(10 * time.Millisecond)
+ server.Stop()
+ peer1, ok := network.connections["outboundpeer-1"]
+ if !ok {
+ t.Error("not found outbound peer 1")
+ } else {
+ fmt.Printf("out: %v\n", peer1.Out)
+ if len(peer1.Out) != 3 {
+ t.Errorf("not enough messages sent to peer 1: %v ", len(peer1.Out))
+ } else {
+ if bytes.Compare(peer1.Out[1], packet) != 0 {
+ t.Errorf("incorrect broadcast packet %v != %v", peer1.Out[1], packet)
+ }
+ }
+ }
+ peer2, ok := network.connections["inboundpeer-1"]
+ if !ok {
+ t.Error("not found inbound peer 2")
+ } else {
+ fmt.Printf("out: %v\n", peer2.Out)
+ if len(peer1.Out) != 3 {
+ t.Errorf("not enough messages sent to peer 2: %v ", len(peer2.Out))
+ } else {
+ if bytes.Compare(peer2.Out[1], packet) != 0 {
+ t.Errorf("incorrect broadcast packet %v != %v", peer2.Out[1], packet)
+ }
+ }
+ }
+}
+
+func TestServerPeersMessage(t *testing.T) {
+ handlers := make(Handlers)
+ _, server := SetupTestServer(handlers)
+ server.Start(true, true)
+ defer server.Stop()
+ server.peerConnect <- &TestAddr{"outboundpeer-1"}
+ time.Sleep(10 * time.Millisecond)
+ peersMsg, err := server.PeersMessage()
+ fmt.Println(peersMsg)
+ if err != nil {
+ t.Errorf("expect no error, got %v", err)
+ }
+ if c := server.PeerCount(); c != 2 {
+ t.Errorf("expect 2 peers, got %v", c)
+ }
+}