Merge branch 'develop'
This commit is contained in:
commit
2f2dd80e48
|
@ -15,6 +15,10 @@
|
|||
"Comment": "1.2.0-95-g9b2bd2b",
|
||||
"Rev": "9b2bd2b3489748d4d0a204fa4eb2ee9e89e0ebc6"
|
||||
},
|
||||
{
|
||||
"ImportPath": "github.com/davecgh/go-spew/spew",
|
||||
"Rev": "3e6e67c4dcea3ac2f25fd4731abc0e1deaf36216"
|
||||
},
|
||||
{
|
||||
"ImportPath": "github.com/ethereum/ethash",
|
||||
"Comment": "v23.1-206-gf0e6321",
|
||||
|
|
|
@ -0,0 +1,450 @@
|
|||
/*
|
||||
* Copyright (c) 2013 Dave Collins <dave@davec.name>
|
||||
*
|
||||
* Permission to use, copy, modify, and distribute this software for any
|
||||
* purpose with or without fee is hereby granted, provided that the above
|
||||
* copyright notice and this permission notice appear in all copies.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
|
||||
* WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
|
||||
* MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
|
||||
* ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
|
||||
* WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
|
||||
* ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
|
||||
* OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||
*/
|
||||
|
||||
package spew
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"reflect"
|
||||
"sort"
|
||||
"strconv"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
const (
|
||||
// ptrSize is the size of a pointer on the current arch.
|
||||
ptrSize = unsafe.Sizeof((*byte)(nil))
|
||||
)
|
||||
|
||||
var (
|
||||
// offsetPtr, offsetScalar, and offsetFlag are the offsets for the
|
||||
// internal reflect.Value fields. These values are valid before golang
|
||||
// commit ecccf07e7f9d which changed the format. The are also valid
|
||||
// after commit 82f48826c6c7 which changed the format again to mirror
|
||||
// the original format. Code in the init function updates these offsets
|
||||
// as necessary.
|
||||
offsetPtr = uintptr(ptrSize)
|
||||
offsetScalar = uintptr(0)
|
||||
offsetFlag = uintptr(ptrSize * 2)
|
||||
|
||||
// flagKindWidth and flagKindShift indicate various bits that the
|
||||
// reflect package uses internally to track kind information.
|
||||
//
|
||||
// flagRO indicates whether or not the value field of a reflect.Value is
|
||||
// read-only.
|
||||
//
|
||||
// flagIndir indicates whether the value field of a reflect.Value is
|
||||
// the actual data or a pointer to the data.
|
||||
//
|
||||
// These values are valid before golang commit 90a7c3c86944 which
|
||||
// changed their positions. Code in the init function updates these
|
||||
// flags as necessary.
|
||||
flagKindWidth = uintptr(5)
|
||||
flagKindShift = uintptr(flagKindWidth - 1)
|
||||
flagRO = uintptr(1 << 0)
|
||||
flagIndir = uintptr(1 << 1)
|
||||
)
|
||||
|
||||
func init() {
|
||||
// Older versions of reflect.Value stored small integers directly in the
|
||||
// ptr field (which is named val in the older versions). Versions
|
||||
// between commits ecccf07e7f9d and 82f48826c6c7 added a new field named
|
||||
// scalar for this purpose which unfortunately came before the flag
|
||||
// field, so the offset of the flag field is different for those
|
||||
// versions.
|
||||
//
|
||||
// This code constructs a new reflect.Value from a known small integer
|
||||
// and checks if the size of the reflect.Value struct indicates it has
|
||||
// the scalar field. When it does, the offsets are updated accordingly.
|
||||
vv := reflect.ValueOf(0xf00)
|
||||
if unsafe.Sizeof(vv) == (ptrSize * 4) {
|
||||
offsetScalar = ptrSize * 2
|
||||
offsetFlag = ptrSize * 3
|
||||
}
|
||||
|
||||
// Commit 90a7c3c86944 changed the flag positions such that the low
|
||||
// order bits are the kind. This code extracts the kind from the flags
|
||||
// field and ensures it's the correct type. When it's not, the flag
|
||||
// order has been changed to the newer format, so the flags are updated
|
||||
// accordingly.
|
||||
upf := unsafe.Pointer(uintptr(unsafe.Pointer(&vv)) + offsetFlag)
|
||||
upfv := *(*uintptr)(upf)
|
||||
flagKindMask := uintptr((1<<flagKindWidth - 1) << flagKindShift)
|
||||
if (upfv&flagKindMask)>>flagKindShift != uintptr(reflect.Int) {
|
||||
flagKindShift = 0
|
||||
flagRO = 1 << 5
|
||||
flagIndir = 1 << 6
|
||||
}
|
||||
}
|
||||
|
||||
// unsafeReflectValue converts the passed reflect.Value into a one that bypasses
|
||||
// the typical safety restrictions preventing access to unaddressable and
|
||||
// unexported data. It works by digging the raw pointer to the underlying
|
||||
// value out of the protected value and generating a new unprotected (unsafe)
|
||||
// reflect.Value to it.
|
||||
//
|
||||
// This allows us to check for implementations of the Stringer and error
|
||||
// interfaces to be used for pretty printing ordinarily unaddressable and
|
||||
// inaccessible values such as unexported struct fields.
|
||||
func unsafeReflectValue(v reflect.Value) (rv reflect.Value) {
|
||||
indirects := 1
|
||||
vt := v.Type()
|
||||
upv := unsafe.Pointer(uintptr(unsafe.Pointer(&v)) + offsetPtr)
|
||||
rvf := *(*uintptr)(unsafe.Pointer(uintptr(unsafe.Pointer(&v)) + offsetFlag))
|
||||
if rvf&flagIndir != 0 {
|
||||
vt = reflect.PtrTo(v.Type())
|
||||
indirects++
|
||||
} else if offsetScalar != 0 {
|
||||
// The value is in the scalar field when it's not one of the
|
||||
// reference types.
|
||||
switch vt.Kind() {
|
||||
case reflect.Uintptr:
|
||||
case reflect.Chan:
|
||||
case reflect.Func:
|
||||
case reflect.Map:
|
||||
case reflect.Ptr:
|
||||
case reflect.UnsafePointer:
|
||||
default:
|
||||
upv = unsafe.Pointer(uintptr(unsafe.Pointer(&v)) +
|
||||
offsetScalar)
|
||||
}
|
||||
}
|
||||
|
||||
pv := reflect.NewAt(vt, upv)
|
||||
rv = pv
|
||||
for i := 0; i < indirects; i++ {
|
||||
rv = rv.Elem()
|
||||
}
|
||||
return rv
|
||||
}
|
||||
|
||||
// Some constants in the form of bytes to avoid string overhead. This mirrors
|
||||
// the technique used in the fmt package.
|
||||
var (
|
||||
panicBytes = []byte("(PANIC=")
|
||||
plusBytes = []byte("+")
|
||||
iBytes = []byte("i")
|
||||
trueBytes = []byte("true")
|
||||
falseBytes = []byte("false")
|
||||
interfaceBytes = []byte("(interface {})")
|
||||
commaNewlineBytes = []byte(",\n")
|
||||
newlineBytes = []byte("\n")
|
||||
openBraceBytes = []byte("{")
|
||||
openBraceNewlineBytes = []byte("{\n")
|
||||
closeBraceBytes = []byte("}")
|
||||
asteriskBytes = []byte("*")
|
||||
colonBytes = []byte(":")
|
||||
colonSpaceBytes = []byte(": ")
|
||||
openParenBytes = []byte("(")
|
||||
closeParenBytes = []byte(")")
|
||||
spaceBytes = []byte(" ")
|
||||
pointerChainBytes = []byte("->")
|
||||
nilAngleBytes = []byte("<nil>")
|
||||
maxNewlineBytes = []byte("<max depth reached>\n")
|
||||
maxShortBytes = []byte("<max>")
|
||||
circularBytes = []byte("<already shown>")
|
||||
circularShortBytes = []byte("<shown>")
|
||||
invalidAngleBytes = []byte("<invalid>")
|
||||
openBracketBytes = []byte("[")
|
||||
closeBracketBytes = []byte("]")
|
||||
percentBytes = []byte("%")
|
||||
precisionBytes = []byte(".")
|
||||
openAngleBytes = []byte("<")
|
||||
closeAngleBytes = []byte(">")
|
||||
openMapBytes = []byte("map[")
|
||||
closeMapBytes = []byte("]")
|
||||
lenEqualsBytes = []byte("len=")
|
||||
capEqualsBytes = []byte("cap=")
|
||||
)
|
||||
|
||||
// hexDigits is used to map a decimal value to a hex digit.
|
||||
var hexDigits = "0123456789abcdef"
|
||||
|
||||
// catchPanic handles any panics that might occur during the handleMethods
|
||||
// calls.
|
||||
func catchPanic(w io.Writer, v reflect.Value) {
|
||||
if err := recover(); err != nil {
|
||||
w.Write(panicBytes)
|
||||
fmt.Fprintf(w, "%v", err)
|
||||
w.Write(closeParenBytes)
|
||||
}
|
||||
}
|
||||
|
||||
// handleMethods attempts to call the Error and String methods on the underlying
|
||||
// type the passed reflect.Value represents and outputes the result to Writer w.
|
||||
//
|
||||
// It handles panics in any called methods by catching and displaying the error
|
||||
// as the formatted value.
|
||||
func handleMethods(cs *ConfigState, w io.Writer, v reflect.Value) (handled bool) {
|
||||
// We need an interface to check if the type implements the error or
|
||||
// Stringer interface. However, the reflect package won't give us an
|
||||
// interface on certain things like unexported struct fields in order
|
||||
// to enforce visibility rules. We use unsafe to bypass these restrictions
|
||||
// since this package does not mutate the values.
|
||||
if !v.CanInterface() {
|
||||
v = unsafeReflectValue(v)
|
||||
}
|
||||
|
||||
// Choose whether or not to do error and Stringer interface lookups against
|
||||
// the base type or a pointer to the base type depending on settings.
|
||||
// Technically calling one of these methods with a pointer receiver can
|
||||
// mutate the value, however, types which choose to satisify an error or
|
||||
// Stringer interface with a pointer receiver should not be mutating their
|
||||
// state inside these interface methods.
|
||||
var viface interface{}
|
||||
if !cs.DisablePointerMethods {
|
||||
if !v.CanAddr() {
|
||||
v = unsafeReflectValue(v)
|
||||
}
|
||||
viface = v.Addr().Interface()
|
||||
} else {
|
||||
if v.CanAddr() {
|
||||
v = v.Addr()
|
||||
}
|
||||
viface = v.Interface()
|
||||
}
|
||||
|
||||
// Is it an error or Stringer?
|
||||
switch iface := viface.(type) {
|
||||
case error:
|
||||
defer catchPanic(w, v)
|
||||
if cs.ContinueOnMethod {
|
||||
w.Write(openParenBytes)
|
||||
w.Write([]byte(iface.Error()))
|
||||
w.Write(closeParenBytes)
|
||||
w.Write(spaceBytes)
|
||||
return false
|
||||
}
|
||||
|
||||
w.Write([]byte(iface.Error()))
|
||||
return true
|
||||
|
||||
case fmt.Stringer:
|
||||
defer catchPanic(w, v)
|
||||
if cs.ContinueOnMethod {
|
||||
w.Write(openParenBytes)
|
||||
w.Write([]byte(iface.String()))
|
||||
w.Write(closeParenBytes)
|
||||
w.Write(spaceBytes)
|
||||
return false
|
||||
}
|
||||
w.Write([]byte(iface.String()))
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// printBool outputs a boolean value as true or false to Writer w.
|
||||
func printBool(w io.Writer, val bool) {
|
||||
if val {
|
||||
w.Write(trueBytes)
|
||||
} else {
|
||||
w.Write(falseBytes)
|
||||
}
|
||||
}
|
||||
|
||||
// printInt outputs a signed integer value to Writer w.
|
||||
func printInt(w io.Writer, val int64, base int) {
|
||||
w.Write([]byte(strconv.FormatInt(val, base)))
|
||||
}
|
||||
|
||||
// printUint outputs an unsigned integer value to Writer w.
|
||||
func printUint(w io.Writer, val uint64, base int) {
|
||||
w.Write([]byte(strconv.FormatUint(val, base)))
|
||||
}
|
||||
|
||||
// printFloat outputs a floating point value using the specified precision,
|
||||
// which is expected to be 32 or 64bit, to Writer w.
|
||||
func printFloat(w io.Writer, val float64, precision int) {
|
||||
w.Write([]byte(strconv.FormatFloat(val, 'g', -1, precision)))
|
||||
}
|
||||
|
||||
// printComplex outputs a complex value using the specified float precision
|
||||
// for the real and imaginary parts to Writer w.
|
||||
func printComplex(w io.Writer, c complex128, floatPrecision int) {
|
||||
r := real(c)
|
||||
w.Write(openParenBytes)
|
||||
w.Write([]byte(strconv.FormatFloat(r, 'g', -1, floatPrecision)))
|
||||
i := imag(c)
|
||||
if i >= 0 {
|
||||
w.Write(plusBytes)
|
||||
}
|
||||
w.Write([]byte(strconv.FormatFloat(i, 'g', -1, floatPrecision)))
|
||||
w.Write(iBytes)
|
||||
w.Write(closeParenBytes)
|
||||
}
|
||||
|
||||
// printHexPtr outputs a uintptr formatted as hexidecimal with a leading '0x'
|
||||
// prefix to Writer w.
|
||||
func printHexPtr(w io.Writer, p uintptr) {
|
||||
// Null pointer.
|
||||
num := uint64(p)
|
||||
if num == 0 {
|
||||
w.Write(nilAngleBytes)
|
||||
return
|
||||
}
|
||||
|
||||
// Max uint64 is 16 bytes in hex + 2 bytes for '0x' prefix
|
||||
buf := make([]byte, 18)
|
||||
|
||||
// It's simpler to construct the hex string right to left.
|
||||
base := uint64(16)
|
||||
i := len(buf) - 1
|
||||
for num >= base {
|
||||
buf[i] = hexDigits[num%base]
|
||||
num /= base
|
||||
i--
|
||||
}
|
||||
buf[i] = hexDigits[num]
|
||||
|
||||
// Add '0x' prefix.
|
||||
i--
|
||||
buf[i] = 'x'
|
||||
i--
|
||||
buf[i] = '0'
|
||||
|
||||
// Strip unused leading bytes.
|
||||
buf = buf[i:]
|
||||
w.Write(buf)
|
||||
}
|
||||
|
||||
// valuesSorter implements sort.Interface to allow a slice of reflect.Value
|
||||
// elements to be sorted.
|
||||
type valuesSorter struct {
|
||||
values []reflect.Value
|
||||
strings []string // either nil or same len and values
|
||||
cs *ConfigState
|
||||
}
|
||||
|
||||
// newValuesSorter initializes a valuesSorter instance, which holds a set of
|
||||
// surrogate keys on which the data should be sorted. It uses flags in
|
||||
// ConfigState to decide if and how to populate those surrogate keys.
|
||||
func newValuesSorter(values []reflect.Value, cs *ConfigState) sort.Interface {
|
||||
vs := &valuesSorter{values: values, cs: cs}
|
||||
if canSortSimply(vs.values[0].Kind()) {
|
||||
return vs
|
||||
}
|
||||
if !cs.DisableMethods {
|
||||
vs.strings = make([]string, len(values))
|
||||
for i := range vs.values {
|
||||
b := bytes.Buffer{}
|
||||
if !handleMethods(cs, &b, vs.values[i]) {
|
||||
vs.strings = nil
|
||||
break
|
||||
}
|
||||
vs.strings[i] = b.String()
|
||||
}
|
||||
}
|
||||
if vs.strings == nil && cs.SpewKeys {
|
||||
vs.strings = make([]string, len(values))
|
||||
for i := range vs.values {
|
||||
vs.strings[i] = Sprintf("%#v", vs.values[i].Interface())
|
||||
}
|
||||
}
|
||||
return vs
|
||||
}
|
||||
|
||||
// canSortSimply tests whether a reflect.Kind is a primitive that can be sorted
|
||||
// directly, or whether it should be considered for sorting by surrogate keys
|
||||
// (if the ConfigState allows it).
|
||||
func canSortSimply(kind reflect.Kind) bool {
|
||||
// This switch parallels valueSortLess, except for the default case.
|
||||
switch kind {
|
||||
case reflect.Bool:
|
||||
return true
|
||||
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int:
|
||||
return true
|
||||
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint:
|
||||
return true
|
||||
case reflect.Float32, reflect.Float64:
|
||||
return true
|
||||
case reflect.String:
|
||||
return true
|
||||
case reflect.Uintptr:
|
||||
return true
|
||||
case reflect.Array:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Len returns the number of values in the slice. It is part of the
|
||||
// sort.Interface implementation.
|
||||
func (s *valuesSorter) Len() int {
|
||||
return len(s.values)
|
||||
}
|
||||
|
||||
// Swap swaps the values at the passed indices. It is part of the
|
||||
// sort.Interface implementation.
|
||||
func (s *valuesSorter) Swap(i, j int) {
|
||||
s.values[i], s.values[j] = s.values[j], s.values[i]
|
||||
if s.strings != nil {
|
||||
s.strings[i], s.strings[j] = s.strings[j], s.strings[i]
|
||||
}
|
||||
}
|
||||
|
||||
// valueSortLess returns whether the first value should sort before the second
|
||||
// value. It is used by valueSorter.Less as part of the sort.Interface
|
||||
// implementation.
|
||||
func valueSortLess(a, b reflect.Value) bool {
|
||||
switch a.Kind() {
|
||||
case reflect.Bool:
|
||||
return !a.Bool() && b.Bool()
|
||||
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int:
|
||||
return a.Int() < b.Int()
|
||||
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint:
|
||||
return a.Uint() < b.Uint()
|
||||
case reflect.Float32, reflect.Float64:
|
||||
return a.Float() < b.Float()
|
||||
case reflect.String:
|
||||
return a.String() < b.String()
|
||||
case reflect.Uintptr:
|
||||
return a.Uint() < b.Uint()
|
||||
case reflect.Array:
|
||||
// Compare the contents of both arrays.
|
||||
l := a.Len()
|
||||
for i := 0; i < l; i++ {
|
||||
av := a.Index(i)
|
||||
bv := b.Index(i)
|
||||
if av.Interface() == bv.Interface() {
|
||||
continue
|
||||
}
|
||||
return valueSortLess(av, bv)
|
||||
}
|
||||
}
|
||||
return a.String() < b.String()
|
||||
}
|
||||
|
||||
// Less returns whether the value at index i should sort before the
|
||||
// value at index j. It is part of the sort.Interface implementation.
|
||||
func (s *valuesSorter) Less(i, j int) bool {
|
||||
if s.strings == nil {
|
||||
return valueSortLess(s.values[i], s.values[j])
|
||||
}
|
||||
return s.strings[i] < s.strings[j]
|
||||
}
|
||||
|
||||
// sortValues is a sort function that handles both native types and any type that
|
||||
// can be converted to error or Stringer. Other inputs are sorted according to
|
||||
// their Value.String() value to ensure display stability.
|
||||
func sortValues(values []reflect.Value, cs *ConfigState) {
|
||||
if len(values) == 0 {
|
||||
return
|
||||
}
|
||||
sort.Sort(newValuesSorter(values, cs))
|
||||
}
|
298
Godeps/_workspace/src/github.com/davecgh/go-spew/spew/common_test.go
generated
vendored
Normal file
298
Godeps/_workspace/src/github.com/davecgh/go-spew/spew/common_test.go
generated
vendored
Normal file
|
@ -0,0 +1,298 @@
|
|||
/*
|
||||
* Copyright (c) 2013 Dave Collins <dave@davec.name>
|
||||
*
|
||||
* Permission to use, copy, modify, and distribute this software for any
|
||||
* purpose with or without fee is hereby granted, provided that the above
|
||||
* copyright notice and this permission notice appear in all copies.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
|
||||
* WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
|
||||
* MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
|
||||
* ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
|
||||
* WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
|
||||
* ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
|
||||
* OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||
*/
|
||||
|
||||
package spew_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/davecgh/go-spew/spew"
|
||||
)
|
||||
|
||||
// custom type to test Stinger interface on non-pointer receiver.
|
||||
type stringer string
|
||||
|
||||
// String implements the Stringer interface for testing invocation of custom
|
||||
// stringers on types with non-pointer receivers.
|
||||
func (s stringer) String() string {
|
||||
return "stringer " + string(s)
|
||||
}
|
||||
|
||||
// custom type to test Stinger interface on pointer receiver.
|
||||
type pstringer string
|
||||
|
||||
// String implements the Stringer interface for testing invocation of custom
|
||||
// stringers on types with only pointer receivers.
|
||||
func (s *pstringer) String() string {
|
||||
return "stringer " + string(*s)
|
||||
}
|
||||
|
||||
// xref1 and xref2 are cross referencing structs for testing circular reference
|
||||
// detection.
|
||||
type xref1 struct {
|
||||
ps2 *xref2
|
||||
}
|
||||
type xref2 struct {
|
||||
ps1 *xref1
|
||||
}
|
||||
|
||||
// indirCir1, indirCir2, and indirCir3 are used to generate an indirect circular
|
||||
// reference for testing detection.
|
||||
type indirCir1 struct {
|
||||
ps2 *indirCir2
|
||||
}
|
||||
type indirCir2 struct {
|
||||
ps3 *indirCir3
|
||||
}
|
||||
type indirCir3 struct {
|
||||
ps1 *indirCir1
|
||||
}
|
||||
|
||||
// embed is used to test embedded structures.
|
||||
type embed struct {
|
||||
a string
|
||||
}
|
||||
|
||||
// embedwrap is used to test embedded structures.
|
||||
type embedwrap struct {
|
||||
*embed
|
||||
e *embed
|
||||
}
|
||||
|
||||
// panicer is used to intentionally cause a panic for testing spew properly
|
||||
// handles them
|
||||
type panicer int
|
||||
|
||||
func (p panicer) String() string {
|
||||
panic("test panic")
|
||||
}
|
||||
|
||||
// customError is used to test custom error interface invocation.
|
||||
type customError int
|
||||
|
||||
func (e customError) Error() string {
|
||||
return fmt.Sprintf("error: %d", int(e))
|
||||
}
|
||||
|
||||
// stringizeWants converts a slice of wanted test output into a format suitable
|
||||
// for a test error message.
|
||||
func stringizeWants(wants []string) string {
|
||||
s := ""
|
||||
for i, want := range wants {
|
||||
if i > 0 {
|
||||
s += fmt.Sprintf("want%d: %s", i+1, want)
|
||||
} else {
|
||||
s += "want: " + want
|
||||
}
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// testFailed returns whether or not a test failed by checking if the result
|
||||
// of the test is in the slice of wanted strings.
|
||||
func testFailed(result string, wants []string) bool {
|
||||
for _, want := range wants {
|
||||
if result == want {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
type sortableStruct struct {
|
||||
x int
|
||||
}
|
||||
|
||||
func (ss sortableStruct) String() string {
|
||||
return fmt.Sprintf("ss.%d", ss.x)
|
||||
}
|
||||
|
||||
type unsortableStruct struct {
|
||||
x int
|
||||
}
|
||||
|
||||
type sortTestCase struct {
|
||||
input []reflect.Value
|
||||
expected []reflect.Value
|
||||
}
|
||||
|
||||
func helpTestSortValues(tests []sortTestCase, cs *spew.ConfigState, t *testing.T) {
|
||||
getInterfaces := func(values []reflect.Value) []interface{} {
|
||||
interfaces := []interface{}{}
|
||||
for _, v := range values {
|
||||
interfaces = append(interfaces, v.Interface())
|
||||
}
|
||||
return interfaces
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
spew.SortValues(test.input, cs)
|
||||
// reflect.DeepEqual cannot really make sense of reflect.Value,
|
||||
// probably because of all the pointer tricks. For instance,
|
||||
// v(2.0) != v(2.0) on a 32-bits system. Turn them into interface{}
|
||||
// instead.
|
||||
input := getInterfaces(test.input)
|
||||
expected := getInterfaces(test.expected)
|
||||
if !reflect.DeepEqual(input, expected) {
|
||||
t.Errorf("Sort mismatch:\n %v != %v", input, expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestSortValues ensures the sort functionality for relect.Value based sorting
|
||||
// works as intended.
|
||||
func TestSortValues(t *testing.T) {
|
||||
v := reflect.ValueOf
|
||||
|
||||
a := v("a")
|
||||
b := v("b")
|
||||
c := v("c")
|
||||
embedA := v(embed{"a"})
|
||||
embedB := v(embed{"b"})
|
||||
embedC := v(embed{"c"})
|
||||
tests := []sortTestCase{
|
||||
// No values.
|
||||
{
|
||||
[]reflect.Value{},
|
||||
[]reflect.Value{},
|
||||
},
|
||||
// Bools.
|
||||
{
|
||||
[]reflect.Value{v(false), v(true), v(false)},
|
||||
[]reflect.Value{v(false), v(false), v(true)},
|
||||
},
|
||||
// Ints.
|
||||
{
|
||||
[]reflect.Value{v(2), v(1), v(3)},
|
||||
[]reflect.Value{v(1), v(2), v(3)},
|
||||
},
|
||||
// Uints.
|
||||
{
|
||||
[]reflect.Value{v(uint8(2)), v(uint8(1)), v(uint8(3))},
|
||||
[]reflect.Value{v(uint8(1)), v(uint8(2)), v(uint8(3))},
|
||||
},
|
||||
// Floats.
|
||||
{
|
||||
[]reflect.Value{v(2.0), v(1.0), v(3.0)},
|
||||
[]reflect.Value{v(1.0), v(2.0), v(3.0)},
|
||||
},
|
||||
// Strings.
|
||||
{
|
||||
[]reflect.Value{b, a, c},
|
||||
[]reflect.Value{a, b, c},
|
||||
},
|
||||
// Array
|
||||
{
|
||||
[]reflect.Value{v([3]int{3, 2, 1}), v([3]int{1, 3, 2}), v([3]int{1, 2, 3})},
|
||||
[]reflect.Value{v([3]int{1, 2, 3}), v([3]int{1, 3, 2}), v([3]int{3, 2, 1})},
|
||||
},
|
||||
// Uintptrs.
|
||||
{
|
||||
[]reflect.Value{v(uintptr(2)), v(uintptr(1)), v(uintptr(3))},
|
||||
[]reflect.Value{v(uintptr(1)), v(uintptr(2)), v(uintptr(3))},
|
||||
},
|
||||
// SortableStructs.
|
||||
{
|
||||
// Note: not sorted - DisableMethods is set.
|
||||
[]reflect.Value{v(sortableStruct{2}), v(sortableStruct{1}), v(sortableStruct{3})},
|
||||
[]reflect.Value{v(sortableStruct{2}), v(sortableStruct{1}), v(sortableStruct{3})},
|
||||
},
|
||||
// UnsortableStructs.
|
||||
{
|
||||
// Note: not sorted - SpewKeys is false.
|
||||
[]reflect.Value{v(unsortableStruct{2}), v(unsortableStruct{1}), v(unsortableStruct{3})},
|
||||
[]reflect.Value{v(unsortableStruct{2}), v(unsortableStruct{1}), v(unsortableStruct{3})},
|
||||
},
|
||||
// Invalid.
|
||||
{
|
||||
[]reflect.Value{embedB, embedA, embedC},
|
||||
[]reflect.Value{embedB, embedA, embedC},
|
||||
},
|
||||
}
|
||||
cs := spew.ConfigState{DisableMethods: true, SpewKeys: false}
|
||||
helpTestSortValues(tests, &cs, t)
|
||||
}
|
||||
|
||||
// TestSortValuesWithMethods ensures the sort functionality for relect.Value
|
||||
// based sorting works as intended when using string methods.
|
||||
func TestSortValuesWithMethods(t *testing.T) {
|
||||
v := reflect.ValueOf
|
||||
|
||||
a := v("a")
|
||||
b := v("b")
|
||||
c := v("c")
|
||||
tests := []sortTestCase{
|
||||
// Ints.
|
||||
{
|
||||
[]reflect.Value{v(2), v(1), v(3)},
|
||||
[]reflect.Value{v(1), v(2), v(3)},
|
||||
},
|
||||
// Strings.
|
||||
{
|
||||
[]reflect.Value{b, a, c},
|
||||
[]reflect.Value{a, b, c},
|
||||
},
|
||||
// SortableStructs.
|
||||
{
|
||||
[]reflect.Value{v(sortableStruct{2}), v(sortableStruct{1}), v(sortableStruct{3})},
|
||||
[]reflect.Value{v(sortableStruct{1}), v(sortableStruct{2}), v(sortableStruct{3})},
|
||||
},
|
||||
// UnsortableStructs.
|
||||
{
|
||||
// Note: not sorted - SpewKeys is false.
|
||||
[]reflect.Value{v(unsortableStruct{2}), v(unsortableStruct{1}), v(unsortableStruct{3})},
|
||||
[]reflect.Value{v(unsortableStruct{2}), v(unsortableStruct{1}), v(unsortableStruct{3})},
|
||||
},
|
||||
}
|
||||
cs := spew.ConfigState{DisableMethods: false, SpewKeys: false}
|
||||
helpTestSortValues(tests, &cs, t)
|
||||
}
|
||||
|
||||
// TestSortValuesWithSpew ensures the sort functionality for relect.Value
|
||||
// based sorting works as intended when using spew to stringify keys.
|
||||
func TestSortValuesWithSpew(t *testing.T) {
|
||||
v := reflect.ValueOf
|
||||
|
||||
a := v("a")
|
||||
b := v("b")
|
||||
c := v("c")
|
||||
tests := []sortTestCase{
|
||||
// Ints.
|
||||
{
|
||||
[]reflect.Value{v(2), v(1), v(3)},
|
||||
[]reflect.Value{v(1), v(2), v(3)},
|
||||
},
|
||||
// Strings.
|
||||
{
|
||||
[]reflect.Value{b, a, c},
|
||||
[]reflect.Value{a, b, c},
|
||||
},
|
||||
// SortableStructs.
|
||||
{
|
||||
[]reflect.Value{v(sortableStruct{2}), v(sortableStruct{1}), v(sortableStruct{3})},
|
||||
[]reflect.Value{v(sortableStruct{1}), v(sortableStruct{2}), v(sortableStruct{3})},
|
||||
},
|
||||
// UnsortableStructs.
|
||||
{
|
||||
[]reflect.Value{v(unsortableStruct{2}), v(unsortableStruct{1}), v(unsortableStruct{3})},
|
||||
[]reflect.Value{v(unsortableStruct{1}), v(unsortableStruct{2}), v(unsortableStruct{3})},
|
||||
},
|
||||
}
|
||||
cs := spew.ConfigState{DisableMethods: true, SpewKeys: true}
|
||||
helpTestSortValues(tests, &cs, t)
|
||||
}
|
|
@ -0,0 +1,294 @@
|
|||
/*
|
||||
* Copyright (c) 2013 Dave Collins <dave@davec.name>
|
||||
*
|
||||
* Permission to use, copy, modify, and distribute this software for any
|
||||
* purpose with or without fee is hereby granted, provided that the above
|
||||
* copyright notice and this permission notice appear in all copies.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
|
||||
* WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
|
||||
* MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
|
||||
* ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
|
||||
* WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
|
||||
* ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
|
||||
* OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||
*/
|
||||
|
||||
package spew
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
)
|
||||
|
||||
// ConfigState houses the configuration options used by spew to format and
|
||||
// display values. There is a global instance, Config, that is used to control
|
||||
// all top-level Formatter and Dump functionality. Each ConfigState instance
|
||||
// provides methods equivalent to the top-level functions.
|
||||
//
|
||||
// The zero value for ConfigState provides no indentation. You would typically
|
||||
// want to set it to a space or a tab.
|
||||
//
|
||||
// Alternatively, you can use NewDefaultConfig to get a ConfigState instance
|
||||
// with default settings. See the documentation of NewDefaultConfig for default
|
||||
// values.
|
||||
type ConfigState struct {
|
||||
// Indent specifies the string to use for each indentation level. The
|
||||
// global config instance that all top-level functions use set this to a
|
||||
// single space by default. If you would like more indentation, you might
|
||||
// set this to a tab with "\t" or perhaps two spaces with " ".
|
||||
Indent string
|
||||
|
||||
// MaxDepth controls the maximum number of levels to descend into nested
|
||||
// data structures. The default, 0, means there is no limit.
|
||||
//
|
||||
// NOTE: Circular data structures are properly detected, so it is not
|
||||
// necessary to set this value unless you specifically want to limit deeply
|
||||
// nested data structures.
|
||||
MaxDepth int
|
||||
|
||||
// DisableMethods specifies whether or not error and Stringer interfaces are
|
||||
// invoked for types that implement them.
|
||||
DisableMethods bool
|
||||
|
||||
// DisablePointerMethods specifies whether or not to check for and invoke
|
||||
// error and Stringer interfaces on types which only accept a pointer
|
||||
// receiver when the current type is not a pointer.
|
||||
//
|
||||
// NOTE: This might be an unsafe action since calling one of these methods
|
||||
// with a pointer receiver could technically mutate the value, however,
|
||||
// in practice, types which choose to satisify an error or Stringer
|
||||
// interface with a pointer receiver should not be mutating their state
|
||||
// inside these interface methods.
|
||||
DisablePointerMethods bool
|
||||
|
||||
// ContinueOnMethod specifies whether or not recursion should continue once
|
||||
// a custom error or Stringer interface is invoked. The default, false,
|
||||
// means it will print the results of invoking the custom error or Stringer
|
||||
// interface and return immediately instead of continuing to recurse into
|
||||
// the internals of the data type.
|
||||
//
|
||||
// NOTE: This flag does not have any effect if method invocation is disabled
|
||||
// via the DisableMethods or DisablePointerMethods options.
|
||||
ContinueOnMethod bool
|
||||
|
||||
// SortKeys specifies map keys should be sorted before being printed. Use
|
||||
// this to have a more deterministic, diffable output. Note that only
|
||||
// native types (bool, int, uint, floats, uintptr and string) and types
|
||||
// that support the error or Stringer interfaces (if methods are
|
||||
// enabled) are supported, with other types sorted according to the
|
||||
// reflect.Value.String() output which guarantees display stability.
|
||||
SortKeys bool
|
||||
|
||||
// SpewKeys specifies that, as a last resort attempt, map keys should
|
||||
// be spewed to strings and sorted by those strings. This is only
|
||||
// considered if SortKeys is true.
|
||||
SpewKeys bool
|
||||
}
|
||||
|
||||
// Config is the active configuration of the top-level functions.
|
||||
// The configuration can be changed by modifying the contents of spew.Config.
|
||||
var Config = ConfigState{Indent: " "}
|
||||
|
||||
// Errorf is a wrapper for fmt.Errorf that treats each argument as if it were
|
||||
// passed with a Formatter interface returned by c.NewFormatter. It returns
|
||||
// the formatted string as a value that satisfies error. See NewFormatter
|
||||
// for formatting details.
|
||||
//
|
||||
// This function is shorthand for the following syntax:
|
||||
//
|
||||
// fmt.Errorf(format, c.NewFormatter(a), c.NewFormatter(b))
|
||||
func (c *ConfigState) Errorf(format string, a ...interface{}) (err error) {
|
||||
return fmt.Errorf(format, c.convertArgs(a)...)
|
||||
}
|
||||
|
||||
// Fprint is a wrapper for fmt.Fprint that treats each argument as if it were
|
||||
// passed with a Formatter interface returned by c.NewFormatter. It returns
|
||||
// the number of bytes written and any write error encountered. See
|
||||
// NewFormatter for formatting details.
|
||||
//
|
||||
// This function is shorthand for the following syntax:
|
||||
//
|
||||
// fmt.Fprint(w, c.NewFormatter(a), c.NewFormatter(b))
|
||||
func (c *ConfigState) Fprint(w io.Writer, a ...interface{}) (n int, err error) {
|
||||
return fmt.Fprint(w, c.convertArgs(a)...)
|
||||
}
|
||||
|
||||
// Fprintf is a wrapper for fmt.Fprintf that treats each argument as if it were
|
||||
// passed with a Formatter interface returned by c.NewFormatter. It returns
|
||||
// the number of bytes written and any write error encountered. See
|
||||
// NewFormatter for formatting details.
|
||||
//
|
||||
// This function is shorthand for the following syntax:
|
||||
//
|
||||
// fmt.Fprintf(w, format, c.NewFormatter(a), c.NewFormatter(b))
|
||||
func (c *ConfigState) Fprintf(w io.Writer, format string, a ...interface{}) (n int, err error) {
|
||||
return fmt.Fprintf(w, format, c.convertArgs(a)...)
|
||||
}
|
||||
|
||||
// Fprintln is a wrapper for fmt.Fprintln that treats each argument as if it
|
||||
// passed with a Formatter interface returned by c.NewFormatter. See
|
||||
// NewFormatter for formatting details.
|
||||
//
|
||||
// This function is shorthand for the following syntax:
|
||||
//
|
||||
// fmt.Fprintln(w, c.NewFormatter(a), c.NewFormatter(b))
|
||||
func (c *ConfigState) Fprintln(w io.Writer, a ...interface{}) (n int, err error) {
|
||||
return fmt.Fprintln(w, c.convertArgs(a)...)
|
||||
}
|
||||
|
||||
// Print is a wrapper for fmt.Print that treats each argument as if it were
|
||||
// passed with a Formatter interface returned by c.NewFormatter. It returns
|
||||
// the number of bytes written and any write error encountered. See
|
||||
// NewFormatter for formatting details.
|
||||
//
|
||||
// This function is shorthand for the following syntax:
|
||||
//
|
||||
// fmt.Print(c.NewFormatter(a), c.NewFormatter(b))
|
||||
func (c *ConfigState) Print(a ...interface{}) (n int, err error) {
|
||||
return fmt.Print(c.convertArgs(a)...)
|
||||
}
|
||||
|
||||
// Printf is a wrapper for fmt.Printf that treats each argument as if it were
|
||||
// passed with a Formatter interface returned by c.NewFormatter. It returns
|
||||
// the number of bytes written and any write error encountered. See
|
||||
// NewFormatter for formatting details.
|
||||
//
|
||||
// This function is shorthand for the following syntax:
|
||||
//
|
||||
// fmt.Printf(format, c.NewFormatter(a), c.NewFormatter(b))
|
||||
func (c *ConfigState) Printf(format string, a ...interface{}) (n int, err error) {
|
||||
return fmt.Printf(format, c.convertArgs(a)...)
|
||||
}
|
||||
|
||||
// Println is a wrapper for fmt.Println that treats each argument as if it were
|
||||
// passed with a Formatter interface returned by c.NewFormatter. It returns
|
||||
// the number of bytes written and any write error encountered. See
|
||||
// NewFormatter for formatting details.
|
||||
//
|
||||
// This function is shorthand for the following syntax:
|
||||
//
|
||||
// fmt.Println(c.NewFormatter(a), c.NewFormatter(b))
|
||||
func (c *ConfigState) Println(a ...interface{}) (n int, err error) {
|
||||
return fmt.Println(c.convertArgs(a)...)
|
||||
}
|
||||
|
||||
// Sprint is a wrapper for fmt.Sprint that treats each argument as if it were
|
||||
// passed with a Formatter interface returned by c.NewFormatter. It returns
|
||||
// the resulting string. See NewFormatter for formatting details.
|
||||
//
|
||||
// This function is shorthand for the following syntax:
|
||||
//
|
||||
// fmt.Sprint(c.NewFormatter(a), c.NewFormatter(b))
|
||||
func (c *ConfigState) Sprint(a ...interface{}) string {
|
||||
return fmt.Sprint(c.convertArgs(a)...)
|
||||
}
|
||||
|
||||
// Sprintf is a wrapper for fmt.Sprintf that treats each argument as if it were
|
||||
// passed with a Formatter interface returned by c.NewFormatter. It returns
|
||||
// the resulting string. See NewFormatter for formatting details.
|
||||
//
|
||||
// This function is shorthand for the following syntax:
|
||||
//
|
||||
// fmt.Sprintf(format, c.NewFormatter(a), c.NewFormatter(b))
|
||||
func (c *ConfigState) Sprintf(format string, a ...interface{}) string {
|
||||
return fmt.Sprintf(format, c.convertArgs(a)...)
|
||||
}
|
||||
|
||||
// Sprintln is a wrapper for fmt.Sprintln that treats each argument as if it
|
||||
// were passed with a Formatter interface returned by c.NewFormatter. It
|
||||
// returns the resulting string. See NewFormatter for formatting details.
|
||||
//
|
||||
// This function is shorthand for the following syntax:
|
||||
//
|
||||
// fmt.Sprintln(c.NewFormatter(a), c.NewFormatter(b))
|
||||
func (c *ConfigState) Sprintln(a ...interface{}) string {
|
||||
return fmt.Sprintln(c.convertArgs(a)...)
|
||||
}
|
||||
|
||||
/*
|
||||
NewFormatter returns a custom formatter that satisfies the fmt.Formatter
|
||||
interface. As a result, it integrates cleanly with standard fmt package
|
||||
printing functions. The formatter is useful for inline printing of smaller data
|
||||
types similar to the standard %v format specifier.
|
||||
|
||||
The custom formatter only responds to the %v (most compact), %+v (adds pointer
|
||||
addresses), %#v (adds types), and %#+v (adds types and pointer addresses) verb
|
||||
combinations. Any other verbs such as %x and %q will be sent to the the
|
||||
standard fmt package for formatting. In addition, the custom formatter ignores
|
||||
the width and precision arguments (however they will still work on the format
|
||||
specifiers not handled by the custom formatter).
|
||||
|
||||
Typically this function shouldn't be called directly. It is much easier to make
|
||||
use of the custom formatter by calling one of the convenience functions such as
|
||||
c.Printf, c.Println, or c.Printf.
|
||||
*/
|
||||
func (c *ConfigState) NewFormatter(v interface{}) fmt.Formatter {
|
||||
return newFormatter(c, v)
|
||||
}
|
||||
|
||||
// Fdump formats and displays the passed arguments to io.Writer w. It formats
|
||||
// exactly the same as Dump.
|
||||
func (c *ConfigState) Fdump(w io.Writer, a ...interface{}) {
|
||||
fdump(c, w, a...)
|
||||
}
|
||||
|
||||
/*
|
||||
Dump displays the passed parameters to standard out with newlines, customizable
|
||||
indentation, and additional debug information such as complete types and all
|
||||
pointer addresses used to indirect to the final value. It provides the
|
||||
following features over the built-in printing facilities provided by the fmt
|
||||
package:
|
||||
|
||||
* Pointers are dereferenced and followed
|
||||
* Circular data structures are detected and handled properly
|
||||
* Custom Stringer/error interfaces are optionally invoked, including
|
||||
on unexported types
|
||||
* Custom types which only implement the Stringer/error interfaces via
|
||||
a pointer receiver are optionally invoked when passing non-pointer
|
||||
variables
|
||||
* Byte arrays and slices are dumped like the hexdump -C command which
|
||||
includes offsets, byte values in hex, and ASCII output
|
||||
|
||||
The configuration options are controlled by modifying the public members
|
||||
of c. See ConfigState for options documentation.
|
||||
|
||||
See Fdump if you would prefer dumping to an arbitrary io.Writer or Sdump to
|
||||
get the formatted result as a string.
|
||||
*/
|
||||
func (c *ConfigState) Dump(a ...interface{}) {
|
||||
fdump(c, os.Stdout, a...)
|
||||
}
|
||||
|
||||
// Sdump returns a string with the passed arguments formatted exactly the same
|
||||
// as Dump.
|
||||
func (c *ConfigState) Sdump(a ...interface{}) string {
|
||||
var buf bytes.Buffer
|
||||
fdump(c, &buf, a...)
|
||||
return buf.String()
|
||||
}
|
||||
|
||||
// convertArgs accepts a slice of arguments and returns a slice of the same
|
||||
// length with each argument converted to a spew Formatter interface using
|
||||
// the ConfigState associated with s.
|
||||
func (c *ConfigState) convertArgs(args []interface{}) (formatters []interface{}) {
|
||||
formatters = make([]interface{}, len(args))
|
||||
for index, arg := range args {
|
||||
formatters[index] = newFormatter(c, arg)
|
||||
}
|
||||
return formatters
|
||||
}
|
||||
|
||||
// NewDefaultConfig returns a ConfigState with the following default settings.
|
||||
//
|
||||
// Indent: " "
|
||||
// MaxDepth: 0
|
||||
// DisableMethods: false
|
||||
// DisablePointerMethods: false
|
||||
// ContinueOnMethod: false
|
||||
// SortKeys: false
|
||||
func NewDefaultConfig() *ConfigState {
|
||||
return &ConfigState{Indent: " "}
|
||||
}
|
|
@ -0,0 +1,202 @@
|
|||
/*
|
||||
* Copyright (c) 2013 Dave Collins <dave@davec.name>
|
||||
*
|
||||
* Permission to use, copy, modify, and distribute this software for any
|
||||
* purpose with or without fee is hereby granted, provided that the above
|
||||
* copyright notice and this permission notice appear in all copies.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
|
||||
* WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
|
||||
* MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
|
||||
* ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
|
||||
* WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
|
||||
* ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
|
||||
* OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||
*/
|
||||
|
||||
/*
|
||||
Package spew implements a deep pretty printer for Go data structures to aid in
|
||||
debugging.
|
||||
|
||||
A quick overview of the additional features spew provides over the built-in
|
||||
printing facilities for Go data types are as follows:
|
||||
|
||||
* Pointers are dereferenced and followed
|
||||
* Circular data structures are detected and handled properly
|
||||
* Custom Stringer/error interfaces are optionally invoked, including
|
||||
on unexported types
|
||||
* Custom types which only implement the Stringer/error interfaces via
|
||||
a pointer receiver are optionally invoked when passing non-pointer
|
||||
variables
|
||||
* Byte arrays and slices are dumped like the hexdump -C command which
|
||||
includes offsets, byte values in hex, and ASCII output (only when using
|
||||
Dump style)
|
||||
|
||||
There are two different approaches spew allows for dumping Go data structures:
|
||||
|
||||
* Dump style which prints with newlines, customizable indentation,
|
||||
and additional debug information such as types and all pointer addresses
|
||||
used to indirect to the final value
|
||||
* A custom Formatter interface that integrates cleanly with the standard fmt
|
||||
package and replaces %v, %+v, %#v, and %#+v to provide inline printing
|
||||
similar to the default %v while providing the additional functionality
|
||||
outlined above and passing unsupported format verbs such as %x and %q
|
||||
along to fmt
|
||||
|
||||
Quick Start
|
||||
|
||||
This section demonstrates how to quickly get started with spew. See the
|
||||
sections below for further details on formatting and configuration options.
|
||||
|
||||
To dump a variable with full newlines, indentation, type, and pointer
|
||||
information use Dump, Fdump, or Sdump:
|
||||
spew.Dump(myVar1, myVar2, ...)
|
||||
spew.Fdump(someWriter, myVar1, myVar2, ...)
|
||||
str := spew.Sdump(myVar1, myVar2, ...)
|
||||
|
||||
Alternatively, if you would prefer to use format strings with a compacted inline
|
||||
printing style, use the convenience wrappers Printf, Fprintf, etc with
|
||||
%v (most compact), %+v (adds pointer addresses), %#v (adds types), or
|
||||
%#+v (adds types and pointer addresses):
|
||||
spew.Printf("myVar1: %v -- myVar2: %+v", myVar1, myVar2)
|
||||
spew.Printf("myVar3: %#v -- myVar4: %#+v", myVar3, myVar4)
|
||||
spew.Fprintf(someWriter, "myVar1: %v -- myVar2: %+v", myVar1, myVar2)
|
||||
spew.Fprintf(someWriter, "myVar3: %#v -- myVar4: %#+v", myVar3, myVar4)
|
||||
|
||||
Configuration Options
|
||||
|
||||
Configuration of spew is handled by fields in the ConfigState type. For
|
||||
convenience, all of the top-level functions use a global state available
|
||||
via the spew.Config global.
|
||||
|
||||
It is also possible to create a ConfigState instance that provides methods
|
||||
equivalent to the top-level functions. This allows concurrent configuration
|
||||
options. See the ConfigState documentation for more details.
|
||||
|
||||
The following configuration options are available:
|
||||
* Indent
|
||||
String to use for each indentation level for Dump functions.
|
||||
It is a single space by default. A popular alternative is "\t".
|
||||
|
||||
* MaxDepth
|
||||
Maximum number of levels to descend into nested data structures.
|
||||
There is no limit by default.
|
||||
|
||||
* DisableMethods
|
||||
Disables invocation of error and Stringer interface methods.
|
||||
Method invocation is enabled by default.
|
||||
|
||||
* DisablePointerMethods
|
||||
Disables invocation of error and Stringer interface methods on types
|
||||
which only accept pointer receivers from non-pointer variables.
|
||||
Pointer method invocation is enabled by default.
|
||||
|
||||
* ContinueOnMethod
|
||||
Enables recursion into types after invoking error and Stringer interface
|
||||
methods. Recursion after method invocation is disabled by default.
|
||||
|
||||
* SortKeys
|
||||
Specifies map keys should be sorted before being printed. Use
|
||||
this to have a more deterministic, diffable output. Note that
|
||||
only native types (bool, int, uint, floats, uintptr and string)
|
||||
and types which implement error or Stringer interfaces are
|
||||
supported with other types sorted according to the
|
||||
reflect.Value.String() output which guarantees display
|
||||
stability. Natural map order is used by default.
|
||||
|
||||
* SpewKeys
|
||||
Specifies that, as a last resort attempt, map keys should be
|
||||
spewed to strings and sorted by those strings. This is only
|
||||
considered if SortKeys is true.
|
||||
|
||||
Dump Usage
|
||||
|
||||
Simply call spew.Dump with a list of variables you want to dump:
|
||||
|
||||
spew.Dump(myVar1, myVar2, ...)
|
||||
|
||||
You may also call spew.Fdump if you would prefer to output to an arbitrary
|
||||
io.Writer. For example, to dump to standard error:
|
||||
|
||||
spew.Fdump(os.Stderr, myVar1, myVar2, ...)
|
||||
|
||||
A third option is to call spew.Sdump to get the formatted output as a string:
|
||||
|
||||
str := spew.Sdump(myVar1, myVar2, ...)
|
||||
|
||||
Sample Dump Output
|
||||
|
||||
See the Dump example for details on the setup of the types and variables being
|
||||
shown here.
|
||||
|
||||
(main.Foo) {
|
||||
unexportedField: (*main.Bar)(0xf84002e210)({
|
||||
flag: (main.Flag) flagTwo,
|
||||
data: (uintptr) <nil>
|
||||
}),
|
||||
ExportedField: (map[interface {}]interface {}) (len=1) {
|
||||
(string) (len=3) "one": (bool) true
|
||||
}
|
||||
}
|
||||
|
||||
Byte (and uint8) arrays and slices are displayed uniquely like the hexdump -C
|
||||
command as shown.
|
||||
([]uint8) (len=32 cap=32) {
|
||||
00000000 11 12 13 14 15 16 17 18 19 1a 1b 1c 1d 1e 1f 20 |............... |
|
||||
00000010 21 22 23 24 25 26 27 28 29 2a 2b 2c 2d 2e 2f 30 |!"#$%&'()*+,-./0|
|
||||
00000020 31 32 |12|
|
||||
}
|
||||
|
||||
Custom Formatter
|
||||
|
||||
Spew provides a custom formatter that implements the fmt.Formatter interface
|
||||
so that it integrates cleanly with standard fmt package printing functions. The
|
||||
formatter is useful for inline printing of smaller data types similar to the
|
||||
standard %v format specifier.
|
||||
|
||||
The custom formatter only responds to the %v (most compact), %+v (adds pointer
|
||||
addresses), %#v (adds types), or %#+v (adds types and pointer addresses) verb
|
||||
combinations. Any other verbs such as %x and %q will be sent to the the
|
||||
standard fmt package for formatting. In addition, the custom formatter ignores
|
||||
the width and precision arguments (however they will still work on the format
|
||||
specifiers not handled by the custom formatter).
|
||||
|
||||
Custom Formatter Usage
|
||||
|
||||
The simplest way to make use of the spew custom formatter is to call one of the
|
||||
convenience functions such as spew.Printf, spew.Println, or spew.Printf. The
|
||||
functions have syntax you are most likely already familiar with:
|
||||
|
||||
spew.Printf("myVar1: %v -- myVar2: %+v", myVar1, myVar2)
|
||||
spew.Printf("myVar3: %#v -- myVar4: %#+v", myVar3, myVar4)
|
||||
spew.Println(myVar, myVar2)
|
||||
spew.Fprintf(os.Stderr, "myVar1: %v -- myVar2: %+v", myVar1, myVar2)
|
||||
spew.Fprintf(os.Stderr, "myVar3: %#v -- myVar4: %#+v", myVar3, myVar4)
|
||||
|
||||
See the Index for the full list convenience functions.
|
||||
|
||||
Sample Formatter Output
|
||||
|
||||
Double pointer to a uint8:
|
||||
%v: <**>5
|
||||
%+v: <**>(0xf8400420d0->0xf8400420c8)5
|
||||
%#v: (**uint8)5
|
||||
%#+v: (**uint8)(0xf8400420d0->0xf8400420c8)5
|
||||
|
||||
Pointer to circular struct with a uint8 field and a pointer to itself:
|
||||
%v: <*>{1 <*><shown>}
|
||||
%+v: <*>(0xf84003e260){ui8:1 c:<*>(0xf84003e260)<shown>}
|
||||
%#v: (*main.circular){ui8:(uint8)1 c:(*main.circular)<shown>}
|
||||
%#+v: (*main.circular)(0xf84003e260){ui8:(uint8)1 c:(*main.circular)(0xf84003e260)<shown>}
|
||||
|
||||
See the Printf example for details on the setup of variables being shown
|
||||
here.
|
||||
|
||||
Errors
|
||||
|
||||
Since it is possible for custom Stringer/error interfaces to panic, spew
|
||||
detects them and handles them internally by printing the panic information
|
||||
inline with the output. Since spew is intended to provide deep pretty printing
|
||||
capabilities on structures, it intentionally does not return any errors.
|
||||
*/
|
||||
package spew
|
|
@ -0,0 +1,506 @@
|
|||
/*
|
||||
* Copyright (c) 2013 Dave Collins <dave@davec.name>
|
||||
*
|
||||
* Permission to use, copy, modify, and distribute this software for any
|
||||
* purpose with or without fee is hereby granted, provided that the above
|
||||
* copyright notice and this permission notice appear in all copies.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
|
||||
* WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
|
||||
* MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
|
||||
* ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
|
||||
* WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
|
||||
* ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
|
||||
* OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||
*/
|
||||
|
||||
package spew
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var (
|
||||
// uint8Type is a reflect.Type representing a uint8. It is used to
|
||||
// convert cgo types to uint8 slices for hexdumping.
|
||||
uint8Type = reflect.TypeOf(uint8(0))
|
||||
|
||||
// cCharRE is a regular expression that matches a cgo char.
|
||||
// It is used to detect character arrays to hexdump them.
|
||||
cCharRE = regexp.MustCompile("^.*\\._Ctype_char$")
|
||||
|
||||
// cUnsignedCharRE is a regular expression that matches a cgo unsigned
|
||||
// char. It is used to detect unsigned character arrays to hexdump
|
||||
// them.
|
||||
cUnsignedCharRE = regexp.MustCompile("^.*\\._Ctype_unsignedchar$")
|
||||
|
||||
// cUint8tCharRE is a regular expression that matches a cgo uint8_t.
|
||||
// It is used to detect uint8_t arrays to hexdump them.
|
||||
cUint8tCharRE = regexp.MustCompile("^.*\\._Ctype_uint8_t$")
|
||||
)
|
||||
|
||||
// dumpState contains information about the state of a dump operation.
|
||||
type dumpState struct {
|
||||
w io.Writer
|
||||
depth int
|
||||
pointers map[uintptr]int
|
||||
ignoreNextType bool
|
||||
ignoreNextIndent bool
|
||||
cs *ConfigState
|
||||
}
|
||||
|
||||
// indent performs indentation according to the depth level and cs.Indent
|
||||
// option.
|
||||
func (d *dumpState) indent() {
|
||||
if d.ignoreNextIndent {
|
||||
d.ignoreNextIndent = false
|
||||
return
|
||||
}
|
||||
d.w.Write(bytes.Repeat([]byte(d.cs.Indent), d.depth))
|
||||
}
|
||||
|
||||
// unpackValue returns values inside of non-nil interfaces when possible.
|
||||
// This is useful for data types like structs, arrays, slices, and maps which
|
||||
// can contain varying types packed inside an interface.
|
||||
func (d *dumpState) unpackValue(v reflect.Value) reflect.Value {
|
||||
if v.Kind() == reflect.Interface && !v.IsNil() {
|
||||
v = v.Elem()
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
// dumpPtr handles formatting of pointers by indirecting them as necessary.
|
||||
func (d *dumpState) dumpPtr(v reflect.Value) {
|
||||
// Remove pointers at or below the current depth from map used to detect
|
||||
// circular refs.
|
||||
for k, depth := range d.pointers {
|
||||
if depth >= d.depth {
|
||||
delete(d.pointers, k)
|
||||
}
|
||||
}
|
||||
|
||||
// Keep list of all dereferenced pointers to show later.
|
||||
pointerChain := make([]uintptr, 0)
|
||||
|
||||
// Figure out how many levels of indirection there are by dereferencing
|
||||
// pointers and unpacking interfaces down the chain while detecting circular
|
||||
// references.
|
||||
nilFound := false
|
||||
cycleFound := false
|
||||
indirects := 0
|
||||
ve := v
|
||||
for ve.Kind() == reflect.Ptr {
|
||||
if ve.IsNil() {
|
||||
nilFound = true
|
||||
break
|
||||
}
|
||||
indirects++
|
||||
addr := ve.Pointer()
|
||||
pointerChain = append(pointerChain, addr)
|
||||
if pd, ok := d.pointers[addr]; ok && pd < d.depth {
|
||||
cycleFound = true
|
||||
indirects--
|
||||
break
|
||||
}
|
||||
d.pointers[addr] = d.depth
|
||||
|
||||
ve = ve.Elem()
|
||||
if ve.Kind() == reflect.Interface {
|
||||
if ve.IsNil() {
|
||||
nilFound = true
|
||||
break
|
||||
}
|
||||
ve = ve.Elem()
|
||||
}
|
||||
}
|
||||
|
||||
// Display type information.
|
||||
d.w.Write(openParenBytes)
|
||||
d.w.Write(bytes.Repeat(asteriskBytes, indirects))
|
||||
d.w.Write([]byte(ve.Type().String()))
|
||||
d.w.Write(closeParenBytes)
|
||||
|
||||
// Display pointer information.
|
||||
if len(pointerChain) > 0 {
|
||||
d.w.Write(openParenBytes)
|
||||
for i, addr := range pointerChain {
|
||||
if i > 0 {
|
||||
d.w.Write(pointerChainBytes)
|
||||
}
|
||||
printHexPtr(d.w, addr)
|
||||
}
|
||||
d.w.Write(closeParenBytes)
|
||||
}
|
||||
|
||||
// Display dereferenced value.
|
||||
d.w.Write(openParenBytes)
|
||||
switch {
|
||||
case nilFound == true:
|
||||
d.w.Write(nilAngleBytes)
|
||||
|
||||
case cycleFound == true:
|
||||
d.w.Write(circularBytes)
|
||||
|
||||
default:
|
||||
d.ignoreNextType = true
|
||||
d.dump(ve)
|
||||
}
|
||||
d.w.Write(closeParenBytes)
|
||||
}
|
||||
|
||||
// dumpSlice handles formatting of arrays and slices. Byte (uint8 under
|
||||
// reflection) arrays and slices are dumped in hexdump -C fashion.
|
||||
func (d *dumpState) dumpSlice(v reflect.Value) {
|
||||
// Determine whether this type should be hex dumped or not. Also,
|
||||
// for types which should be hexdumped, try to use the underlying data
|
||||
// first, then fall back to trying to convert them to a uint8 slice.
|
||||
var buf []uint8
|
||||
doConvert := false
|
||||
doHexDump := false
|
||||
numEntries := v.Len()
|
||||
if numEntries > 0 {
|
||||
vt := v.Index(0).Type()
|
||||
vts := vt.String()
|
||||
switch {
|
||||
// C types that need to be converted.
|
||||
case cCharRE.MatchString(vts):
|
||||
fallthrough
|
||||
case cUnsignedCharRE.MatchString(vts):
|
||||
fallthrough
|
||||
case cUint8tCharRE.MatchString(vts):
|
||||
doConvert = true
|
||||
|
||||
// Try to use existing uint8 slices and fall back to converting
|
||||
// and copying if that fails.
|
||||
case vt.Kind() == reflect.Uint8:
|
||||
// We need an addressable interface to convert the type back
|
||||
// into a byte slice. However, the reflect package won't give
|
||||
// us an interface on certain things like unexported struct
|
||||
// fields in order to enforce visibility rules. We use unsafe
|
||||
// to bypass these restrictions since this package does not
|
||||
// mutate the values.
|
||||
vs := v
|
||||
if !vs.CanInterface() || !vs.CanAddr() {
|
||||
vs = unsafeReflectValue(vs)
|
||||
}
|
||||
vs = vs.Slice(0, numEntries)
|
||||
|
||||
// Use the existing uint8 slice if it can be type
|
||||
// asserted.
|
||||
iface := vs.Interface()
|
||||
if slice, ok := iface.([]uint8); ok {
|
||||
buf = slice
|
||||
doHexDump = true
|
||||
break
|
||||
}
|
||||
|
||||
// The underlying data needs to be converted if it can't
|
||||
// be type asserted to a uint8 slice.
|
||||
doConvert = true
|
||||
}
|
||||
|
||||
// Copy and convert the underlying type if needed.
|
||||
if doConvert && vt.ConvertibleTo(uint8Type) {
|
||||
// Convert and copy each element into a uint8 byte
|
||||
// slice.
|
||||
buf = make([]uint8, numEntries)
|
||||
for i := 0; i < numEntries; i++ {
|
||||
vv := v.Index(i)
|
||||
buf[i] = uint8(vv.Convert(uint8Type).Uint())
|
||||
}
|
||||
doHexDump = true
|
||||
}
|
||||
}
|
||||
|
||||
// Hexdump the entire slice as needed.
|
||||
if doHexDump {
|
||||
indent := strings.Repeat(d.cs.Indent, d.depth)
|
||||
str := indent + hex.Dump(buf)
|
||||
str = strings.Replace(str, "\n", "\n"+indent, -1)
|
||||
str = strings.TrimRight(str, d.cs.Indent)
|
||||
d.w.Write([]byte(str))
|
||||
return
|
||||
}
|
||||
|
||||
// Recursively call dump for each item.
|
||||
for i := 0; i < numEntries; i++ {
|
||||
d.dump(d.unpackValue(v.Index(i)))
|
||||
if i < (numEntries - 1) {
|
||||
d.w.Write(commaNewlineBytes)
|
||||
} else {
|
||||
d.w.Write(newlineBytes)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// dump is the main workhorse for dumping a value. It uses the passed reflect
|
||||
// value to figure out what kind of object we are dealing with and formats it
|
||||
// appropriately. It is a recursive function, however circular data structures
|
||||
// are detected and handled properly.
|
||||
func (d *dumpState) dump(v reflect.Value) {
|
||||
// Handle invalid reflect values immediately.
|
||||
kind := v.Kind()
|
||||
if kind == reflect.Invalid {
|
||||
d.w.Write(invalidAngleBytes)
|
||||
return
|
||||
}
|
||||
|
||||
// Handle pointers specially.
|
||||
if kind == reflect.Ptr {
|
||||
d.indent()
|
||||
d.dumpPtr(v)
|
||||
return
|
||||
}
|
||||
|
||||
// Print type information unless already handled elsewhere.
|
||||
if !d.ignoreNextType {
|
||||
d.indent()
|
||||
d.w.Write(openParenBytes)
|
||||
d.w.Write([]byte(v.Type().String()))
|
||||
d.w.Write(closeParenBytes)
|
||||
d.w.Write(spaceBytes)
|
||||
}
|
||||
d.ignoreNextType = false
|
||||
|
||||
// Display length and capacity if the built-in len and cap functions
|
||||
// work with the value's kind and the len/cap itself is non-zero.
|
||||
valueLen, valueCap := 0, 0
|
||||
switch v.Kind() {
|
||||
case reflect.Array, reflect.Slice, reflect.Chan:
|
||||
valueLen, valueCap = v.Len(), v.Cap()
|
||||
case reflect.Map, reflect.String:
|
||||
valueLen = v.Len()
|
||||
}
|
||||
if valueLen != 0 || valueCap != 0 {
|
||||
d.w.Write(openParenBytes)
|
||||
if valueLen != 0 {
|
||||
d.w.Write(lenEqualsBytes)
|
||||
printInt(d.w, int64(valueLen), 10)
|
||||
}
|
||||
if valueCap != 0 {
|
||||
if valueLen != 0 {
|
||||
d.w.Write(spaceBytes)
|
||||
}
|
||||
d.w.Write(capEqualsBytes)
|
||||
printInt(d.w, int64(valueCap), 10)
|
||||
}
|
||||
d.w.Write(closeParenBytes)
|
||||
d.w.Write(spaceBytes)
|
||||
}
|
||||
|
||||
// Call Stringer/error interfaces if they exist and the handle methods flag
|
||||
// is enabled
|
||||
if !d.cs.DisableMethods {
|
||||
if (kind != reflect.Invalid) && (kind != reflect.Interface) {
|
||||
if handled := handleMethods(d.cs, d.w, v); handled {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
switch kind {
|
||||
case reflect.Invalid:
|
||||
// Do nothing. We should never get here since invalid has already
|
||||
// been handled above.
|
||||
|
||||
case reflect.Bool:
|
||||
printBool(d.w, v.Bool())
|
||||
|
||||
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int:
|
||||
printInt(d.w, v.Int(), 10)
|
||||
|
||||
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint:
|
||||
printUint(d.w, v.Uint(), 10)
|
||||
|
||||
case reflect.Float32:
|
||||
printFloat(d.w, v.Float(), 32)
|
||||
|
||||
case reflect.Float64:
|
||||
printFloat(d.w, v.Float(), 64)
|
||||
|
||||
case reflect.Complex64:
|
||||
printComplex(d.w, v.Complex(), 32)
|
||||
|
||||
case reflect.Complex128:
|
||||
printComplex(d.w, v.Complex(), 64)
|
||||
|
||||
case reflect.Slice:
|
||||
if v.IsNil() {
|
||||
d.w.Write(nilAngleBytes)
|
||||
break
|
||||
}
|
||||
fallthrough
|
||||
|
||||
case reflect.Array:
|
||||
d.w.Write(openBraceNewlineBytes)
|
||||
d.depth++
|
||||
if (d.cs.MaxDepth != 0) && (d.depth > d.cs.MaxDepth) {
|
||||
d.indent()
|
||||
d.w.Write(maxNewlineBytes)
|
||||
} else {
|
||||
d.dumpSlice(v)
|
||||
}
|
||||
d.depth--
|
||||
d.indent()
|
||||
d.w.Write(closeBraceBytes)
|
||||
|
||||
case reflect.String:
|
||||
d.w.Write([]byte(strconv.Quote(v.String())))
|
||||
|
||||
case reflect.Interface:
|
||||
// The only time we should get here is for nil interfaces due to
|
||||
// unpackValue calls.
|
||||
if v.IsNil() {
|
||||
d.w.Write(nilAngleBytes)
|
||||
}
|
||||
|
||||
case reflect.Ptr:
|
||||
// Do nothing. We should never get here since pointers have already
|
||||
// been handled above.
|
||||
|
||||
case reflect.Map:
|
||||
// nil maps should be indicated as different than empty maps
|
||||
if v.IsNil() {
|
||||
d.w.Write(nilAngleBytes)
|
||||
break
|
||||
}
|
||||
|
||||
d.w.Write(openBraceNewlineBytes)
|
||||
d.depth++
|
||||
if (d.cs.MaxDepth != 0) && (d.depth > d.cs.MaxDepth) {
|
||||
d.indent()
|
||||
d.w.Write(maxNewlineBytes)
|
||||
} else {
|
||||
numEntries := v.Len()
|
||||
keys := v.MapKeys()
|
||||
if d.cs.SortKeys {
|
||||
sortValues(keys, d.cs)
|
||||
}
|
||||
for i, key := range keys {
|
||||
d.dump(d.unpackValue(key))
|
||||
d.w.Write(colonSpaceBytes)
|
||||
d.ignoreNextIndent = true
|
||||
d.dump(d.unpackValue(v.MapIndex(key)))
|
||||
if i < (numEntries - 1) {
|
||||
d.w.Write(commaNewlineBytes)
|
||||
} else {
|
||||
d.w.Write(newlineBytes)
|
||||
}
|
||||
}
|
||||
}
|
||||
d.depth--
|
||||
d.indent()
|
||||
d.w.Write(closeBraceBytes)
|
||||
|
||||
case reflect.Struct:
|
||||
d.w.Write(openBraceNewlineBytes)
|
||||
d.depth++
|
||||
if (d.cs.MaxDepth != 0) && (d.depth > d.cs.MaxDepth) {
|
||||
d.indent()
|
||||
d.w.Write(maxNewlineBytes)
|
||||
} else {
|
||||
vt := v.Type()
|
||||
numFields := v.NumField()
|
||||
for i := 0; i < numFields; i++ {
|
||||
d.indent()
|
||||
vtf := vt.Field(i)
|
||||
d.w.Write([]byte(vtf.Name))
|
||||
d.w.Write(colonSpaceBytes)
|
||||
d.ignoreNextIndent = true
|
||||
d.dump(d.unpackValue(v.Field(i)))
|
||||
if i < (numFields - 1) {
|
||||
d.w.Write(commaNewlineBytes)
|
||||
} else {
|
||||
d.w.Write(newlineBytes)
|
||||
}
|
||||
}
|
||||
}
|
||||
d.depth--
|
||||
d.indent()
|
||||
d.w.Write(closeBraceBytes)
|
||||
|
||||
case reflect.Uintptr:
|
||||
printHexPtr(d.w, uintptr(v.Uint()))
|
||||
|
||||
case reflect.UnsafePointer, reflect.Chan, reflect.Func:
|
||||
printHexPtr(d.w, v.Pointer())
|
||||
|
||||
// There were not any other types at the time this code was written, but
|
||||
// fall back to letting the default fmt package handle it in case any new
|
||||
// types are added.
|
||||
default:
|
||||
if v.CanInterface() {
|
||||
fmt.Fprintf(d.w, "%v", v.Interface())
|
||||
} else {
|
||||
fmt.Fprintf(d.w, "%v", v.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// fdump is a helper function to consolidate the logic from the various public
|
||||
// methods which take varying writers and config states.
|
||||
func fdump(cs *ConfigState, w io.Writer, a ...interface{}) {
|
||||
for _, arg := range a {
|
||||
if arg == nil {
|
||||
w.Write(interfaceBytes)
|
||||
w.Write(spaceBytes)
|
||||
w.Write(nilAngleBytes)
|
||||
w.Write(newlineBytes)
|
||||
continue
|
||||
}
|
||||
|
||||
d := dumpState{w: w, cs: cs}
|
||||
d.pointers = make(map[uintptr]int)
|
||||
d.dump(reflect.ValueOf(arg))
|
||||
d.w.Write(newlineBytes)
|
||||
}
|
||||
}
|
||||
|
||||
// Fdump formats and displays the passed arguments to io.Writer w. It formats
|
||||
// exactly the same as Dump.
|
||||
func Fdump(w io.Writer, a ...interface{}) {
|
||||
fdump(&Config, w, a...)
|
||||
}
|
||||
|
||||
// Sdump returns a string with the passed arguments formatted exactly the same
|
||||
// as Dump.
|
||||
func Sdump(a ...interface{}) string {
|
||||
var buf bytes.Buffer
|
||||
fdump(&Config, &buf, a...)
|
||||
return buf.String()
|
||||
}
|
||||
|
||||
/*
|
||||
Dump displays the passed parameters to standard out with newlines, customizable
|
||||
indentation, and additional debug information such as complete types and all
|
||||
pointer addresses used to indirect to the final value. It provides the
|
||||
following features over the built-in printing facilities provided by the fmt
|
||||
package:
|
||||
|
||||
* Pointers are dereferenced and followed
|
||||
* Circular data structures are detected and handled properly
|
||||
* Custom Stringer/error interfaces are optionally invoked, including
|
||||
on unexported types
|
||||
* Custom types which only implement the Stringer/error interfaces via
|
||||
a pointer receiver are optionally invoked when passing non-pointer
|
||||
variables
|
||||
* Byte arrays and slices are dumped like the hexdump -C command which
|
||||
includes offsets, byte values in hex, and ASCII output
|
||||
|
||||
The configuration options are controlled by an exported package global,
|
||||
spew.Config. See ConfigState for options documentation.
|
||||
|
||||
See Fdump if you would prefer dumping to an arbitrary io.Writer or Sdump to
|
||||
get the formatted result as a string.
|
||||
*/
|
||||
func Dump(a ...interface{}) {
|
||||
fdump(&Config, os.Stdout, a...)
|
||||
}
|
1021
Godeps/_workspace/src/github.com/davecgh/go-spew/spew/dump_test.go
generated
vendored
Normal file
1021
Godeps/_workspace/src/github.com/davecgh/go-spew/spew/dump_test.go
generated
vendored
Normal file
File diff suppressed because it is too large
Load Diff
97
Godeps/_workspace/src/github.com/davecgh/go-spew/spew/dumpcgo_test.go
generated
vendored
Normal file
97
Godeps/_workspace/src/github.com/davecgh/go-spew/spew/dumpcgo_test.go
generated
vendored
Normal file
|
@ -0,0 +1,97 @@
|
|||
// Copyright (c) 2013 Dave Collins <dave@davec.name>
|
||||
//
|
||||
// Permission to use, copy, modify, and distribute this software for any
|
||||
// purpose with or without fee is hereby granted, provided that the above
|
||||
// copyright notice and this permission notice appear in all copies.
|
||||
//
|
||||
// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
|
||||
// WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
|
||||
// MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
|
||||
// ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
|
||||
// WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
|
||||
// ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
|
||||
// OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||
|
||||
// NOTE: Due to the following build constraints, this file will only be compiled
|
||||
// when both cgo is supported and "-tags testcgo" is added to the go test
|
||||
// command line. This means the cgo tests are only added (and hence run) when
|
||||
// specifially requested. This configuration is used because spew itself
|
||||
// does not require cgo to run even though it does handle certain cgo types
|
||||
// specially. Rather than forcing all clients to require cgo and an external
|
||||
// C compiler just to run the tests, this scheme makes them optional.
|
||||
// +build cgo,testcgo
|
||||
|
||||
package spew_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/davecgh/go-spew/spew/testdata"
|
||||
)
|
||||
|
||||
func addCgoDumpTests() {
|
||||
// C char pointer.
|
||||
v := testdata.GetCgoCharPointer()
|
||||
nv := testdata.GetCgoNullCharPointer()
|
||||
pv := &v
|
||||
vcAddr := fmt.Sprintf("%p", v)
|
||||
vAddr := fmt.Sprintf("%p", pv)
|
||||
pvAddr := fmt.Sprintf("%p", &pv)
|
||||
vt := "*testdata._Ctype_char"
|
||||
vs := "116"
|
||||
addDumpTest(v, "("+vt+")("+vcAddr+")("+vs+")\n")
|
||||
addDumpTest(pv, "(*"+vt+")("+vAddr+"->"+vcAddr+")("+vs+")\n")
|
||||
addDumpTest(&pv, "(**"+vt+")("+pvAddr+"->"+vAddr+"->"+vcAddr+")("+vs+")\n")
|
||||
addDumpTest(nv, "("+vt+")(<nil>)\n")
|
||||
|
||||
// C char array.
|
||||
v2, v2l, v2c := testdata.GetCgoCharArray()
|
||||
v2Len := fmt.Sprintf("%d", v2l)
|
||||
v2Cap := fmt.Sprintf("%d", v2c)
|
||||
v2t := "[6]testdata._Ctype_char"
|
||||
v2s := "(len=" + v2Len + " cap=" + v2Cap + ") " +
|
||||
"{\n 00000000 74 65 73 74 32 00 " +
|
||||
" |test2.|\n}"
|
||||
addDumpTest(v2, "("+v2t+") "+v2s+"\n")
|
||||
|
||||
// C unsigned char array.
|
||||
v3, v3l, v3c := testdata.GetCgoUnsignedCharArray()
|
||||
v3Len := fmt.Sprintf("%d", v3l)
|
||||
v3Cap := fmt.Sprintf("%d", v3c)
|
||||
v3t := "[6]testdata._Ctype_unsignedchar"
|
||||
v3s := "(len=" + v3Len + " cap=" + v3Cap + ") " +
|
||||
"{\n 00000000 74 65 73 74 33 00 " +
|
||||
" |test3.|\n}"
|
||||
addDumpTest(v3, "("+v3t+") "+v3s+"\n")
|
||||
|
||||
// C signed char array.
|
||||
v4, v4l, v4c := testdata.GetCgoSignedCharArray()
|
||||
v4Len := fmt.Sprintf("%d", v4l)
|
||||
v4Cap := fmt.Sprintf("%d", v4c)
|
||||
v4t := "[6]testdata._Ctype_schar"
|
||||
v4t2 := "testdata._Ctype_schar"
|
||||
v4s := "(len=" + v4Len + " cap=" + v4Cap + ") " +
|
||||
"{\n (" + v4t2 + ") 116,\n (" + v4t2 + ") 101,\n (" + v4t2 +
|
||||
") 115,\n (" + v4t2 + ") 116,\n (" + v4t2 + ") 52,\n (" + v4t2 +
|
||||
") 0\n}"
|
||||
addDumpTest(v4, "("+v4t+") "+v4s+"\n")
|
||||
|
||||
// C uint8_t array.
|
||||
v5, v5l, v5c := testdata.GetCgoUint8tArray()
|
||||
v5Len := fmt.Sprintf("%d", v5l)
|
||||
v5Cap := fmt.Sprintf("%d", v5c)
|
||||
v5t := "[6]testdata._Ctype_uint8_t"
|
||||
v5s := "(len=" + v5Len + " cap=" + v5Cap + ") " +
|
||||
"{\n 00000000 74 65 73 74 35 00 " +
|
||||
" |test5.|\n}"
|
||||
addDumpTest(v5, "("+v5t+") "+v5s+"\n")
|
||||
|
||||
// C typedefed unsigned char array.
|
||||
v6, v6l, v6c := testdata.GetCgoTypdefedUnsignedCharArray()
|
||||
v6Len := fmt.Sprintf("%d", v6l)
|
||||
v6Cap := fmt.Sprintf("%d", v6c)
|
||||
v6t := "[6]testdata._Ctype_custom_uchar_t"
|
||||
v6s := "(len=" + v6Len + " cap=" + v6Cap + ") " +
|
||||
"{\n 00000000 74 65 73 74 36 00 " +
|
||||
" |test6.|\n}"
|
||||
addDumpTest(v6, "("+v6t+") "+v6s+"\n")
|
||||
}
|
26
Godeps/_workspace/src/github.com/davecgh/go-spew/spew/dumpnocgo_test.go
generated
vendored
Normal file
26
Godeps/_workspace/src/github.com/davecgh/go-spew/spew/dumpnocgo_test.go
generated
vendored
Normal file
|
@ -0,0 +1,26 @@
|
|||
// Copyright (c) 2013 Dave Collins <dave@davec.name>
|
||||
//
|
||||
// Permission to use, copy, modify, and distribute this software for any
|
||||
// purpose with or without fee is hereby granted, provided that the above
|
||||
// copyright notice and this permission notice appear in all copies.
|
||||
//
|
||||
// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
|
||||
// WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
|
||||
// MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
|
||||
// ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
|
||||
// WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
|
||||
// ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
|
||||
// OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||
|
||||
// NOTE: Due to the following build constraints, this file will only be compiled
|
||||
// when either cgo is not supported or "-tags testcgo" is not added to the go
|
||||
// test command line. This file intentionally does not setup any cgo tests in
|
||||
// this scenario.
|
||||
// +build !cgo !testcgo
|
||||
|
||||
package spew_test
|
||||
|
||||
func addCgoDumpTests() {
|
||||
// Don't add any tests for cgo since this file is only compiled when
|
||||
// there should not be any cgo tests.
|
||||
}
|
230
Godeps/_workspace/src/github.com/davecgh/go-spew/spew/example_test.go
generated
vendored
Normal file
230
Godeps/_workspace/src/github.com/davecgh/go-spew/spew/example_test.go
generated
vendored
Normal file
|
@ -0,0 +1,230 @@
|
|||
/*
|
||||
* Copyright (c) 2013 Dave Collins <dave@davec.name>
|
||||
*
|
||||
* Permission to use, copy, modify, and distribute this software for any
|
||||
* purpose with or without fee is hereby granted, provided that the above
|
||||
* copyright notice and this permission notice appear in all copies.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
|
||||
* WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
|
||||
* MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
|
||||
* ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
|
||||
* WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
|
||||
* ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
|
||||
* OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||
*/
|
||||
|
||||
package spew_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/davecgh/go-spew/spew"
|
||||
)
|
||||
|
||||
type Flag int
|
||||
|
||||
const (
|
||||
flagOne Flag = iota
|
||||
flagTwo
|
||||
)
|
||||
|
||||
var flagStrings = map[Flag]string{
|
||||
flagOne: "flagOne",
|
||||
flagTwo: "flagTwo",
|
||||
}
|
||||
|
||||
func (f Flag) String() string {
|
||||
if s, ok := flagStrings[f]; ok {
|
||||
return s
|
||||
}
|
||||
return fmt.Sprintf("Unknown flag (%d)", int(f))
|
||||
}
|
||||
|
||||
type Bar struct {
|
||||
flag Flag
|
||||
data uintptr
|
||||
}
|
||||
|
||||
type Foo struct {
|
||||
unexportedField Bar
|
||||
ExportedField map[interface{}]interface{}
|
||||
}
|
||||
|
||||
// This example demonstrates how to use Dump to dump variables to stdout.
|
||||
func ExampleDump() {
|
||||
// The following package level declarations are assumed for this example:
|
||||
/*
|
||||
type Flag int
|
||||
|
||||
const (
|
||||
flagOne Flag = iota
|
||||
flagTwo
|
||||
)
|
||||
|
||||
var flagStrings = map[Flag]string{
|
||||
flagOne: "flagOne",
|
||||
flagTwo: "flagTwo",
|
||||
}
|
||||
|
||||
func (f Flag) String() string {
|
||||
if s, ok := flagStrings[f]; ok {
|
||||
return s
|
||||
}
|
||||
return fmt.Sprintf("Unknown flag (%d)", int(f))
|
||||
}
|
||||
|
||||
type Bar struct {
|
||||
flag Flag
|
||||
data uintptr
|
||||
}
|
||||
|
||||
type Foo struct {
|
||||
unexportedField Bar
|
||||
ExportedField map[interface{}]interface{}
|
||||
}
|
||||
*/
|
||||
|
||||
// Setup some sample data structures for the example.
|
||||
bar := Bar{Flag(flagTwo), uintptr(0)}
|
||||
s1 := Foo{bar, map[interface{}]interface{}{"one": true}}
|
||||
f := Flag(5)
|
||||
b := []byte{
|
||||
0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18,
|
||||
0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, 0x20,
|
||||
0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28,
|
||||
0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f, 0x30,
|
||||
0x31, 0x32,
|
||||
}
|
||||
|
||||
// Dump!
|
||||
spew.Dump(s1, f, b)
|
||||
|
||||
// Output:
|
||||
// (spew_test.Foo) {
|
||||
// unexportedField: (spew_test.Bar) {
|
||||
// flag: (spew_test.Flag) flagTwo,
|
||||
// data: (uintptr) <nil>
|
||||
// },
|
||||
// ExportedField: (map[interface {}]interface {}) (len=1) {
|
||||
// (string) (len=3) "one": (bool) true
|
||||
// }
|
||||
// }
|
||||
// (spew_test.Flag) Unknown flag (5)
|
||||
// ([]uint8) (len=34 cap=34) {
|
||||
// 00000000 11 12 13 14 15 16 17 18 19 1a 1b 1c 1d 1e 1f 20 |............... |
|
||||
// 00000010 21 22 23 24 25 26 27 28 29 2a 2b 2c 2d 2e 2f 30 |!"#$%&'()*+,-./0|
|
||||
// 00000020 31 32 |12|
|
||||
// }
|
||||
//
|
||||
}
|
||||
|
||||
// This example demonstrates how to use Printf to display a variable with a
|
||||
// format string and inline formatting.
|
||||
func ExamplePrintf() {
|
||||
// Create a double pointer to a uint 8.
|
||||
ui8 := uint8(5)
|
||||
pui8 := &ui8
|
||||
ppui8 := &pui8
|
||||
|
||||
// Create a circular data type.
|
||||
type circular struct {
|
||||
ui8 uint8
|
||||
c *circular
|
||||
}
|
||||
c := circular{ui8: 1}
|
||||
c.c = &c
|
||||
|
||||
// Print!
|
||||
spew.Printf("ppui8: %v\n", ppui8)
|
||||
spew.Printf("circular: %v\n", c)
|
||||
|
||||
// Output:
|
||||
// ppui8: <**>5
|
||||
// circular: {1 <*>{1 <*><shown>}}
|
||||
}
|
||||
|
||||
// This example demonstrates how to use a ConfigState.
|
||||
func ExampleConfigState() {
|
||||
// Modify the indent level of the ConfigState only. The global
|
||||
// configuration is not modified.
|
||||
scs := spew.ConfigState{Indent: "\t"}
|
||||
|
||||
// Output using the ConfigState instance.
|
||||
v := map[string]int{"one": 1}
|
||||
scs.Printf("v: %v\n", v)
|
||||
scs.Dump(v)
|
||||
|
||||
// Output:
|
||||
// v: map[one:1]
|
||||
// (map[string]int) (len=1) {
|
||||
// (string) (len=3) "one": (int) 1
|
||||
// }
|
||||
}
|
||||
|
||||
// This example demonstrates how to use ConfigState.Dump to dump variables to
|
||||
// stdout
|
||||
func ExampleConfigState_Dump() {
|
||||
// See the top-level Dump example for details on the types used in this
|
||||
// example.
|
||||
|
||||
// Create two ConfigState instances with different indentation.
|
||||
scs := spew.ConfigState{Indent: "\t"}
|
||||
scs2 := spew.ConfigState{Indent: " "}
|
||||
|
||||
// Setup some sample data structures for the example.
|
||||
bar := Bar{Flag(flagTwo), uintptr(0)}
|
||||
s1 := Foo{bar, map[interface{}]interface{}{"one": true}}
|
||||
|
||||
// Dump using the ConfigState instances.
|
||||
scs.Dump(s1)
|
||||
scs2.Dump(s1)
|
||||
|
||||
// Output:
|
||||
// (spew_test.Foo) {
|
||||
// unexportedField: (spew_test.Bar) {
|
||||
// flag: (spew_test.Flag) flagTwo,
|
||||
// data: (uintptr) <nil>
|
||||
// },
|
||||
// ExportedField: (map[interface {}]interface {}) (len=1) {
|
||||
// (string) (len=3) "one": (bool) true
|
||||
// }
|
||||
// }
|
||||
// (spew_test.Foo) {
|
||||
// unexportedField: (spew_test.Bar) {
|
||||
// flag: (spew_test.Flag) flagTwo,
|
||||
// data: (uintptr) <nil>
|
||||
// },
|
||||
// ExportedField: (map[interface {}]interface {}) (len=1) {
|
||||
// (string) (len=3) "one": (bool) true
|
||||
// }
|
||||
// }
|
||||
//
|
||||
}
|
||||
|
||||
// This example demonstrates how to use ConfigState.Printf to display a variable
|
||||
// with a format string and inline formatting.
|
||||
func ExampleConfigState_Printf() {
|
||||
// See the top-level Dump example for details on the types used in this
|
||||
// example.
|
||||
|
||||
// Create two ConfigState instances and modify the method handling of the
|
||||
// first ConfigState only.
|
||||
scs := spew.NewDefaultConfig()
|
||||
scs2 := spew.NewDefaultConfig()
|
||||
scs.DisableMethods = true
|
||||
|
||||
// Alternatively
|
||||
// scs := spew.ConfigState{Indent: " ", DisableMethods: true}
|
||||
// scs2 := spew.ConfigState{Indent: " "}
|
||||
|
||||
// This is of type Flag which implements a Stringer and has raw value 1.
|
||||
f := flagTwo
|
||||
|
||||
// Dump using the ConfigState instances.
|
||||
scs.Printf("f: %v\n", f)
|
||||
scs2.Printf("f: %v\n", f)
|
||||
|
||||
// Output:
|
||||
// f: 1
|
||||
// f: flagTwo
|
||||
}
|
|
@ -0,0 +1,419 @@
|
|||
/*
|
||||
* Copyright (c) 2013 Dave Collins <dave@davec.name>
|
||||
*
|
||||
* Permission to use, copy, modify, and distribute this software for any
|
||||
* purpose with or without fee is hereby granted, provided that the above
|
||||
* copyright notice and this permission notice appear in all copies.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
|
||||
* WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
|
||||
* MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
|
||||
* ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
|
||||
* WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
|
||||
* ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
|
||||
* OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||
*/
|
||||
|
||||
package spew
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// supportedFlags is a list of all the character flags supported by fmt package.
|
||||
const supportedFlags = "0-+# "
|
||||
|
||||
// formatState implements the fmt.Formatter interface and contains information
|
||||
// about the state of a formatting operation. The NewFormatter function can
|
||||
// be used to get a new Formatter which can be used directly as arguments
|
||||
// in standard fmt package printing calls.
|
||||
type formatState struct {
|
||||
value interface{}
|
||||
fs fmt.State
|
||||
depth int
|
||||
pointers map[uintptr]int
|
||||
ignoreNextType bool
|
||||
cs *ConfigState
|
||||
}
|
||||
|
||||
// buildDefaultFormat recreates the original format string without precision
|
||||
// and width information to pass in to fmt.Sprintf in the case of an
|
||||
// unrecognized type. Unless new types are added to the language, this
|
||||
// function won't ever be called.
|
||||
func (f *formatState) buildDefaultFormat() (format string) {
|
||||
buf := bytes.NewBuffer(percentBytes)
|
||||
|
||||
for _, flag := range supportedFlags {
|
||||
if f.fs.Flag(int(flag)) {
|
||||
buf.WriteRune(flag)
|
||||
}
|
||||
}
|
||||
|
||||
buf.WriteRune('v')
|
||||
|
||||
format = buf.String()
|
||||
return format
|
||||
}
|
||||
|
||||
// constructOrigFormat recreates the original format string including precision
|
||||
// and width information to pass along to the standard fmt package. This allows
|
||||
// automatic deferral of all format strings this package doesn't support.
|
||||
func (f *formatState) constructOrigFormat(verb rune) (format string) {
|
||||
buf := bytes.NewBuffer(percentBytes)
|
||||
|
||||
for _, flag := range supportedFlags {
|
||||
if f.fs.Flag(int(flag)) {
|
||||
buf.WriteRune(flag)
|
||||
}
|
||||
}
|
||||
|
||||
if width, ok := f.fs.Width(); ok {
|
||||
buf.WriteString(strconv.Itoa(width))
|
||||
}
|
||||
|
||||
if precision, ok := f.fs.Precision(); ok {
|
||||
buf.Write(precisionBytes)
|
||||
buf.WriteString(strconv.Itoa(precision))
|
||||
}
|
||||
|
||||
buf.WriteRune(verb)
|
||||
|
||||
format = buf.String()
|
||||
return format
|
||||
}
|
||||
|
||||
// unpackValue returns values inside of non-nil interfaces when possible and
|
||||
// ensures that types for values which have been unpacked from an interface
|
||||
// are displayed when the show types flag is also set.
|
||||
// This is useful for data types like structs, arrays, slices, and maps which
|
||||
// can contain varying types packed inside an interface.
|
||||
func (f *formatState) unpackValue(v reflect.Value) reflect.Value {
|
||||
if v.Kind() == reflect.Interface {
|
||||
f.ignoreNextType = false
|
||||
if !v.IsNil() {
|
||||
v = v.Elem()
|
||||
}
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
// formatPtr handles formatting of pointers by indirecting them as necessary.
|
||||
func (f *formatState) formatPtr(v reflect.Value) {
|
||||
// Display nil if top level pointer is nil.
|
||||
showTypes := f.fs.Flag('#')
|
||||
if v.IsNil() && (!showTypes || f.ignoreNextType) {
|
||||
f.fs.Write(nilAngleBytes)
|
||||
return
|
||||
}
|
||||
|
||||
// Remove pointers at or below the current depth from map used to detect
|
||||
// circular refs.
|
||||
for k, depth := range f.pointers {
|
||||
if depth >= f.depth {
|
||||
delete(f.pointers, k)
|
||||
}
|
||||
}
|
||||
|
||||
// Keep list of all dereferenced pointers to possibly show later.
|
||||
pointerChain := make([]uintptr, 0)
|
||||
|
||||
// Figure out how many levels of indirection there are by derferencing
|
||||
// pointers and unpacking interfaces down the chain while detecting circular
|
||||
// references.
|
||||
nilFound := false
|
||||
cycleFound := false
|
||||
indirects := 0
|
||||
ve := v
|
||||
for ve.Kind() == reflect.Ptr {
|
||||
if ve.IsNil() {
|
||||
nilFound = true
|
||||
break
|
||||
}
|
||||
indirects++
|
||||
addr := ve.Pointer()
|
||||
pointerChain = append(pointerChain, addr)
|
||||
if pd, ok := f.pointers[addr]; ok && pd < f.depth {
|
||||
cycleFound = true
|
||||
indirects--
|
||||
break
|
||||
}
|
||||
f.pointers[addr] = f.depth
|
||||
|
||||
ve = ve.Elem()
|
||||
if ve.Kind() == reflect.Interface {
|
||||
if ve.IsNil() {
|
||||
nilFound = true
|
||||
break
|
||||
}
|
||||
ve = ve.Elem()
|
||||
}
|
||||
}
|
||||
|
||||
// Display type or indirection level depending on flags.
|
||||
if showTypes && !f.ignoreNextType {
|
||||
f.fs.Write(openParenBytes)
|
||||
f.fs.Write(bytes.Repeat(asteriskBytes, indirects))
|
||||
f.fs.Write([]byte(ve.Type().String()))
|
||||
f.fs.Write(closeParenBytes)
|
||||
} else {
|
||||
if nilFound || cycleFound {
|
||||
indirects += strings.Count(ve.Type().String(), "*")
|
||||
}
|
||||
f.fs.Write(openAngleBytes)
|
||||
f.fs.Write([]byte(strings.Repeat("*", indirects)))
|
||||
f.fs.Write(closeAngleBytes)
|
||||
}
|
||||
|
||||
// Display pointer information depending on flags.
|
||||
if f.fs.Flag('+') && (len(pointerChain) > 0) {
|
||||
f.fs.Write(openParenBytes)
|
||||
for i, addr := range pointerChain {
|
||||
if i > 0 {
|
||||
f.fs.Write(pointerChainBytes)
|
||||
}
|
||||
printHexPtr(f.fs, addr)
|
||||
}
|
||||
f.fs.Write(closeParenBytes)
|
||||
}
|
||||
|
||||
// Display dereferenced value.
|
||||
switch {
|
||||
case nilFound == true:
|
||||
f.fs.Write(nilAngleBytes)
|
||||
|
||||
case cycleFound == true:
|
||||
f.fs.Write(circularShortBytes)
|
||||
|
||||
default:
|
||||
f.ignoreNextType = true
|
||||
f.format(ve)
|
||||
}
|
||||
}
|
||||
|
||||
// format is the main workhorse for providing the Formatter interface. It
|
||||
// uses the passed reflect value to figure out what kind of object we are
|
||||
// dealing with and formats it appropriately. It is a recursive function,
|
||||
// however circular data structures are detected and handled properly.
|
||||
func (f *formatState) format(v reflect.Value) {
|
||||
// Handle invalid reflect values immediately.
|
||||
kind := v.Kind()
|
||||
if kind == reflect.Invalid {
|
||||
f.fs.Write(invalidAngleBytes)
|
||||
return
|
||||
}
|
||||
|
||||
// Handle pointers specially.
|
||||
if kind == reflect.Ptr {
|
||||
f.formatPtr(v)
|
||||
return
|
||||
}
|
||||
|
||||
// Print type information unless already handled elsewhere.
|
||||
if !f.ignoreNextType && f.fs.Flag('#') {
|
||||
f.fs.Write(openParenBytes)
|
||||
f.fs.Write([]byte(v.Type().String()))
|
||||
f.fs.Write(closeParenBytes)
|
||||
}
|
||||
f.ignoreNextType = false
|
||||
|
||||
// Call Stringer/error interfaces if they exist and the handle methods
|
||||
// flag is enabled.
|
||||
if !f.cs.DisableMethods {
|
||||
if (kind != reflect.Invalid) && (kind != reflect.Interface) {
|
||||
if handled := handleMethods(f.cs, f.fs, v); handled {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
switch kind {
|
||||
case reflect.Invalid:
|
||||
// Do nothing. We should never get here since invalid has already
|
||||
// been handled above.
|
||||
|
||||
case reflect.Bool:
|
||||
printBool(f.fs, v.Bool())
|
||||
|
||||
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int:
|
||||
printInt(f.fs, v.Int(), 10)
|
||||
|
||||
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint:
|
||||
printUint(f.fs, v.Uint(), 10)
|
||||
|
||||
case reflect.Float32:
|
||||
printFloat(f.fs, v.Float(), 32)
|
||||
|
||||
case reflect.Float64:
|
||||
printFloat(f.fs, v.Float(), 64)
|
||||
|
||||
case reflect.Complex64:
|
||||
printComplex(f.fs, v.Complex(), 32)
|
||||
|
||||
case reflect.Complex128:
|
||||
printComplex(f.fs, v.Complex(), 64)
|
||||
|
||||
case reflect.Slice:
|
||||
if v.IsNil() {
|
||||
f.fs.Write(nilAngleBytes)
|
||||
break
|
||||
}
|
||||
fallthrough
|
||||
|
||||
case reflect.Array:
|
||||
f.fs.Write(openBracketBytes)
|
||||
f.depth++
|
||||
if (f.cs.MaxDepth != 0) && (f.depth > f.cs.MaxDepth) {
|
||||
f.fs.Write(maxShortBytes)
|
||||
} else {
|
||||
numEntries := v.Len()
|
||||
for i := 0; i < numEntries; i++ {
|
||||
if i > 0 {
|
||||
f.fs.Write(spaceBytes)
|
||||
}
|
||||
f.ignoreNextType = true
|
||||
f.format(f.unpackValue(v.Index(i)))
|
||||
}
|
||||
}
|
||||
f.depth--
|
||||
f.fs.Write(closeBracketBytes)
|
||||
|
||||
case reflect.String:
|
||||
f.fs.Write([]byte(v.String()))
|
||||
|
||||
case reflect.Interface:
|
||||
// The only time we should get here is for nil interfaces due to
|
||||
// unpackValue calls.
|
||||
if v.IsNil() {
|
||||
f.fs.Write(nilAngleBytes)
|
||||
}
|
||||
|
||||
case reflect.Ptr:
|
||||
// Do nothing. We should never get here since pointers have already
|
||||
// been handled above.
|
||||
|
||||
case reflect.Map:
|
||||
// nil maps should be indicated as different than empty maps
|
||||
if v.IsNil() {
|
||||
f.fs.Write(nilAngleBytes)
|
||||
break
|
||||
}
|
||||
|
||||
f.fs.Write(openMapBytes)
|
||||
f.depth++
|
||||
if (f.cs.MaxDepth != 0) && (f.depth > f.cs.MaxDepth) {
|
||||
f.fs.Write(maxShortBytes)
|
||||
} else {
|
||||
keys := v.MapKeys()
|
||||
if f.cs.SortKeys {
|
||||
sortValues(keys, f.cs)
|
||||
}
|
||||
for i, key := range keys {
|
||||
if i > 0 {
|
||||
f.fs.Write(spaceBytes)
|
||||
}
|
||||
f.ignoreNextType = true
|
||||
f.format(f.unpackValue(key))
|
||||
f.fs.Write(colonBytes)
|
||||
f.ignoreNextType = true
|
||||
f.format(f.unpackValue(v.MapIndex(key)))
|
||||
}
|
||||
}
|
||||
f.depth--
|
||||
f.fs.Write(closeMapBytes)
|
||||
|
||||
case reflect.Struct:
|
||||
numFields := v.NumField()
|
||||
f.fs.Write(openBraceBytes)
|
||||
f.depth++
|
||||
if (f.cs.MaxDepth != 0) && (f.depth > f.cs.MaxDepth) {
|
||||
f.fs.Write(maxShortBytes)
|
||||
} else {
|
||||
vt := v.Type()
|
||||
for i := 0; i < numFields; i++ {
|
||||
if i > 0 {
|
||||
f.fs.Write(spaceBytes)
|
||||
}
|
||||
vtf := vt.Field(i)
|
||||
if f.fs.Flag('+') || f.fs.Flag('#') {
|
||||
f.fs.Write([]byte(vtf.Name))
|
||||
f.fs.Write(colonBytes)
|
||||
}
|
||||
f.format(f.unpackValue(v.Field(i)))
|
||||
}
|
||||
}
|
||||
f.depth--
|
||||
f.fs.Write(closeBraceBytes)
|
||||
|
||||
case reflect.Uintptr:
|
||||
printHexPtr(f.fs, uintptr(v.Uint()))
|
||||
|
||||
case reflect.UnsafePointer, reflect.Chan, reflect.Func:
|
||||
printHexPtr(f.fs, v.Pointer())
|
||||
|
||||
// There were not any other types at the time this code was written, but
|
||||
// fall back to letting the default fmt package handle it if any get added.
|
||||
default:
|
||||
format := f.buildDefaultFormat()
|
||||
if v.CanInterface() {
|
||||
fmt.Fprintf(f.fs, format, v.Interface())
|
||||
} else {
|
||||
fmt.Fprintf(f.fs, format, v.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Format satisfies the fmt.Formatter interface. See NewFormatter for usage
|
||||
// details.
|
||||
func (f *formatState) Format(fs fmt.State, verb rune) {
|
||||
f.fs = fs
|
||||
|
||||
// Use standard formatting for verbs that are not v.
|
||||
if verb != 'v' {
|
||||
format := f.constructOrigFormat(verb)
|
||||
fmt.Fprintf(fs, format, f.value)
|
||||
return
|
||||
}
|
||||
|
||||
if f.value == nil {
|
||||
if fs.Flag('#') {
|
||||
fs.Write(interfaceBytes)
|
||||
}
|
||||
fs.Write(nilAngleBytes)
|
||||
return
|
||||
}
|
||||
|
||||
f.format(reflect.ValueOf(f.value))
|
||||
}
|
||||
|
||||
// newFormatter is a helper function to consolidate the logic from the various
|
||||
// public methods which take varying config states.
|
||||
func newFormatter(cs *ConfigState, v interface{}) fmt.Formatter {
|
||||
fs := &formatState{value: v, cs: cs}
|
||||
fs.pointers = make(map[uintptr]int)
|
||||
return fs
|
||||
}
|
||||
|
||||
/*
|
||||
NewFormatter returns a custom formatter that satisfies the fmt.Formatter
|
||||
interface. As a result, it integrates cleanly with standard fmt package
|
||||
printing functions. The formatter is useful for inline printing of smaller data
|
||||
types similar to the standard %v format specifier.
|
||||
|
||||
The custom formatter only responds to the %v (most compact), %+v (adds pointer
|
||||
addresses), %#v (adds types), or %#+v (adds types and pointer addresses) verb
|
||||
combinations. Any other verbs such as %x and %q will be sent to the the
|
||||
standard fmt package for formatting. In addition, the custom formatter ignores
|
||||
the width and precision arguments (however they will still work on the format
|
||||
specifiers not handled by the custom formatter).
|
||||
|
||||
Typically this function shouldn't be called directly. It is much easier to make
|
||||
use of the custom formatter by calling one of the convenience functions such as
|
||||
Printf, Println, or Fprintf.
|
||||
*/
|
||||
func NewFormatter(v interface{}) fmt.Formatter {
|
||||
return newFormatter(&Config, v)
|
||||
}
|
1535
Godeps/_workspace/src/github.com/davecgh/go-spew/spew/format_test.go
generated
vendored
Normal file
1535
Godeps/_workspace/src/github.com/davecgh/go-spew/spew/format_test.go
generated
vendored
Normal file
File diff suppressed because it is too large
Load Diff
156
Godeps/_workspace/src/github.com/davecgh/go-spew/spew/internal_test.go
generated
vendored
Normal file
156
Godeps/_workspace/src/github.com/davecgh/go-spew/spew/internal_test.go
generated
vendored
Normal file
|
@ -0,0 +1,156 @@
|
|||
/*
|
||||
* Copyright (c) 2013 Dave Collins <dave@davec.name>
|
||||
*
|
||||
* Permission to use, copy, modify, and distribute this software for any
|
||||
* purpose with or without fee is hereby granted, provided that the above
|
||||
* copyright notice and this permission notice appear in all copies.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
|
||||
* WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
|
||||
* MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
|
||||
* ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
|
||||
* WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
|
||||
* ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
|
||||
* OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||
*/
|
||||
|
||||
/*
|
||||
This test file is part of the spew package rather than than the spew_test
|
||||
package because it needs access to internals to properly test certain cases
|
||||
which are not possible via the public interface since they should never happen.
|
||||
*/
|
||||
|
||||
package spew
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"reflect"
|
||||
"testing"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// dummyFmtState implements a fake fmt.State to use for testing invalid
|
||||
// reflect.Value handling. This is necessary because the fmt package catches
|
||||
// invalid values before invoking the formatter on them.
|
||||
type dummyFmtState struct {
|
||||
bytes.Buffer
|
||||
}
|
||||
|
||||
func (dfs *dummyFmtState) Flag(f int) bool {
|
||||
if f == int('+') {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (dfs *dummyFmtState) Precision() (int, bool) {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
func (dfs *dummyFmtState) Width() (int, bool) {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// TestInvalidReflectValue ensures the dump and formatter code handles an
|
||||
// invalid reflect value properly. This needs access to internal state since it
|
||||
// should never happen in real code and therefore can't be tested via the public
|
||||
// API.
|
||||
func TestInvalidReflectValue(t *testing.T) {
|
||||
i := 1
|
||||
|
||||
// Dump invalid reflect value.
|
||||
v := new(reflect.Value)
|
||||
buf := new(bytes.Buffer)
|
||||
d := dumpState{w: buf, cs: &Config}
|
||||
d.dump(*v)
|
||||
s := buf.String()
|
||||
want := "<invalid>"
|
||||
if s != want {
|
||||
t.Errorf("InvalidReflectValue #%d\n got: %s want: %s", i, s, want)
|
||||
}
|
||||
i++
|
||||
|
||||
// Formatter invalid reflect value.
|
||||
buf2 := new(dummyFmtState)
|
||||
f := formatState{value: *v, cs: &Config, fs: buf2}
|
||||
f.format(*v)
|
||||
s = buf2.String()
|
||||
want = "<invalid>"
|
||||
if s != want {
|
||||
t.Errorf("InvalidReflectValue #%d got: %s want: %s", i, s, want)
|
||||
}
|
||||
}
|
||||
|
||||
// changeKind uses unsafe to intentionally change the kind of a reflect.Value to
|
||||
// the maximum kind value which does not exist. This is needed to test the
|
||||
// fallback code which punts to the standard fmt library for new types that
|
||||
// might get added to the language.
|
||||
func changeKind(v *reflect.Value, readOnly bool) {
|
||||
rvf := (*uintptr)(unsafe.Pointer(uintptr(unsafe.Pointer(v)) + offsetFlag))
|
||||
*rvf = *rvf | ((1<<flagKindWidth - 1) << flagKindShift)
|
||||
if readOnly {
|
||||
*rvf |= flagRO
|
||||
} else {
|
||||
*rvf &= ^uintptr(flagRO)
|
||||
}
|
||||
}
|
||||
|
||||
// TestAddedReflectValue tests functionaly of the dump and formatter code which
|
||||
// falls back to the standard fmt library for new types that might get added to
|
||||
// the language.
|
||||
func TestAddedReflectValue(t *testing.T) {
|
||||
i := 1
|
||||
|
||||
// Dump using a reflect.Value that is exported.
|
||||
v := reflect.ValueOf(int8(5))
|
||||
changeKind(&v, false)
|
||||
buf := new(bytes.Buffer)
|
||||
d := dumpState{w: buf, cs: &Config}
|
||||
d.dump(v)
|
||||
s := buf.String()
|
||||
want := "(int8) 5"
|
||||
if s != want {
|
||||
t.Errorf("TestAddedReflectValue #%d\n got: %s want: %s", i, s, want)
|
||||
}
|
||||
i++
|
||||
|
||||
// Dump using a reflect.Value that is not exported.
|
||||
changeKind(&v, true)
|
||||
buf.Reset()
|
||||
d.dump(v)
|
||||
s = buf.String()
|
||||
want = "(int8) <int8 Value>"
|
||||
if s != want {
|
||||
t.Errorf("TestAddedReflectValue #%d\n got: %s want: %s", i, s, want)
|
||||
}
|
||||
i++
|
||||
|
||||
// Formatter using a reflect.Value that is exported.
|
||||
changeKind(&v, false)
|
||||
buf2 := new(dummyFmtState)
|
||||
f := formatState{value: v, cs: &Config, fs: buf2}
|
||||
f.format(v)
|
||||
s = buf2.String()
|
||||
want = "5"
|
||||
if s != want {
|
||||
t.Errorf("TestAddedReflectValue #%d got: %s want: %s", i, s, want)
|
||||
}
|
||||
i++
|
||||
|
||||
// Formatter using a reflect.Value that is not exported.
|
||||
changeKind(&v, true)
|
||||
buf2.Reset()
|
||||
f = formatState{value: v, cs: &Config, fs: buf2}
|
||||
f.format(v)
|
||||
s = buf2.String()
|
||||
want = "<int8 Value>"
|
||||
if s != want {
|
||||
t.Errorf("TestAddedReflectValue #%d got: %s want: %s", i, s, want)
|
||||
}
|
||||
}
|
||||
|
||||
// SortValues makes the internal sortValues function available to the test
|
||||
// package.
|
||||
func SortValues(values []reflect.Value, cs *ConfigState) {
|
||||
sortValues(values, cs)
|
||||
}
|
|
@ -0,0 +1,148 @@
|
|||
/*
|
||||
* Copyright (c) 2013 Dave Collins <dave@davec.name>
|
||||
*
|
||||
* Permission to use, copy, modify, and distribute this software for any
|
||||
* purpose with or without fee is hereby granted, provided that the above
|
||||
* copyright notice and this permission notice appear in all copies.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
|
||||
* WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
|
||||
* MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
|
||||
* ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
|
||||
* WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
|
||||
* ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
|
||||
* OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||
*/
|
||||
|
||||
package spew
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
)
|
||||
|
||||
// Errorf is a wrapper for fmt.Errorf that treats each argument as if it were
|
||||
// passed with a default Formatter interface returned by NewFormatter. It
|
||||
// returns the formatted string as a value that satisfies error. See
|
||||
// NewFormatter for formatting details.
|
||||
//
|
||||
// This function is shorthand for the following syntax:
|
||||
//
|
||||
// fmt.Errorf(format, spew.NewFormatter(a), spew.NewFormatter(b))
|
||||
func Errorf(format string, a ...interface{}) (err error) {
|
||||
return fmt.Errorf(format, convertArgs(a)...)
|
||||
}
|
||||
|
||||
// Fprint is a wrapper for fmt.Fprint that treats each argument as if it were
|
||||
// passed with a default Formatter interface returned by NewFormatter. It
|
||||
// returns the number of bytes written and any write error encountered. See
|
||||
// NewFormatter for formatting details.
|
||||
//
|
||||
// This function is shorthand for the following syntax:
|
||||
//
|
||||
// fmt.Fprint(w, spew.NewFormatter(a), spew.NewFormatter(b))
|
||||
func Fprint(w io.Writer, a ...interface{}) (n int, err error) {
|
||||
return fmt.Fprint(w, convertArgs(a)...)
|
||||
}
|
||||
|
||||
// Fprintf is a wrapper for fmt.Fprintf that treats each argument as if it were
|
||||
// passed with a default Formatter interface returned by NewFormatter. It
|
||||
// returns the number of bytes written and any write error encountered. See
|
||||
// NewFormatter for formatting details.
|
||||
//
|
||||
// This function is shorthand for the following syntax:
|
||||
//
|
||||
// fmt.Fprintf(w, format, spew.NewFormatter(a), spew.NewFormatter(b))
|
||||
func Fprintf(w io.Writer, format string, a ...interface{}) (n int, err error) {
|
||||
return fmt.Fprintf(w, format, convertArgs(a)...)
|
||||
}
|
||||
|
||||
// Fprintln is a wrapper for fmt.Fprintln that treats each argument as if it
|
||||
// passed with a default Formatter interface returned by NewFormatter. See
|
||||
// NewFormatter for formatting details.
|
||||
//
|
||||
// This function is shorthand for the following syntax:
|
||||
//
|
||||
// fmt.Fprintln(w, spew.NewFormatter(a), spew.NewFormatter(b))
|
||||
func Fprintln(w io.Writer, a ...interface{}) (n int, err error) {
|
||||
return fmt.Fprintln(w, convertArgs(a)...)
|
||||
}
|
||||
|
||||
// Print is a wrapper for fmt.Print that treats each argument as if it were
|
||||
// passed with a default Formatter interface returned by NewFormatter. It
|
||||
// returns the number of bytes written and any write error encountered. See
|
||||
// NewFormatter for formatting details.
|
||||
//
|
||||
// This function is shorthand for the following syntax:
|
||||
//
|
||||
// fmt.Print(spew.NewFormatter(a), spew.NewFormatter(b))
|
||||
func Print(a ...interface{}) (n int, err error) {
|
||||
return fmt.Print(convertArgs(a)...)
|
||||
}
|
||||
|
||||
// Printf is a wrapper for fmt.Printf that treats each argument as if it were
|
||||
// passed with a default Formatter interface returned by NewFormatter. It
|
||||
// returns the number of bytes written and any write error encountered. See
|
||||
// NewFormatter for formatting details.
|
||||
//
|
||||
// This function is shorthand for the following syntax:
|
||||
//
|
||||
// fmt.Printf(format, spew.NewFormatter(a), spew.NewFormatter(b))
|
||||
func Printf(format string, a ...interface{}) (n int, err error) {
|
||||
return fmt.Printf(format, convertArgs(a)...)
|
||||
}
|
||||
|
||||
// Println is a wrapper for fmt.Println that treats each argument as if it were
|
||||
// passed with a default Formatter interface returned by NewFormatter. It
|
||||
// returns the number of bytes written and any write error encountered. See
|
||||
// NewFormatter for formatting details.
|
||||
//
|
||||
// This function is shorthand for the following syntax:
|
||||
//
|
||||
// fmt.Println(spew.NewFormatter(a), spew.NewFormatter(b))
|
||||
func Println(a ...interface{}) (n int, err error) {
|
||||
return fmt.Println(convertArgs(a)...)
|
||||
}
|
||||
|
||||
// Sprint is a wrapper for fmt.Sprint that treats each argument as if it were
|
||||
// passed with a default Formatter interface returned by NewFormatter. It
|
||||
// returns the resulting string. See NewFormatter for formatting details.
|
||||
//
|
||||
// This function is shorthand for the following syntax:
|
||||
//
|
||||
// fmt.Sprint(spew.NewFormatter(a), spew.NewFormatter(b))
|
||||
func Sprint(a ...interface{}) string {
|
||||
return fmt.Sprint(convertArgs(a)...)
|
||||
}
|
||||
|
||||
// Sprintf is a wrapper for fmt.Sprintf that treats each argument as if it were
|
||||
// passed with a default Formatter interface returned by NewFormatter. It
|
||||
// returns the resulting string. See NewFormatter for formatting details.
|
||||
//
|
||||
// This function is shorthand for the following syntax:
|
||||
//
|
||||
// fmt.Sprintf(format, spew.NewFormatter(a), spew.NewFormatter(b))
|
||||
func Sprintf(format string, a ...interface{}) string {
|
||||
return fmt.Sprintf(format, convertArgs(a)...)
|
||||
}
|
||||
|
||||
// Sprintln is a wrapper for fmt.Sprintln that treats each argument as if it
|
||||
// were passed with a default Formatter interface returned by NewFormatter. It
|
||||
// returns the resulting string. See NewFormatter for formatting details.
|
||||
//
|
||||
// This function is shorthand for the following syntax:
|
||||
//
|
||||
// fmt.Sprintln(spew.NewFormatter(a), spew.NewFormatter(b))
|
||||
func Sprintln(a ...interface{}) string {
|
||||
return fmt.Sprintln(convertArgs(a)...)
|
||||
}
|
||||
|
||||
// convertArgs accepts a slice of arguments and returns a slice of the same
|
||||
// length with each argument converted to a default spew Formatter interface.
|
||||
func convertArgs(args []interface{}) (formatters []interface{}) {
|
||||
formatters = make([]interface{}, len(args))
|
||||
for index, arg := range args {
|
||||
formatters[index] = NewFormatter(arg)
|
||||
}
|
||||
return formatters
|
||||
}
|
308
Godeps/_workspace/src/github.com/davecgh/go-spew/spew/spew_test.go
generated
vendored
Normal file
308
Godeps/_workspace/src/github.com/davecgh/go-spew/spew/spew_test.go
generated
vendored
Normal file
|
@ -0,0 +1,308 @@
|
|||
/*
|
||||
* Copyright (c) 2013 Dave Collins <dave@davec.name>
|
||||
*
|
||||
* Permission to use, copy, modify, and distribute this software for any
|
||||
* purpose with or without fee is hereby granted, provided that the above
|
||||
* copyright notice and this permission notice appear in all copies.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
|
||||
* WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
|
||||
* MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
|
||||
* ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
|
||||
* WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
|
||||
* ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
|
||||
* OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||
*/
|
||||
|
||||
package spew_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"github.com/davecgh/go-spew/spew"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// spewFunc is used to identify which public function of the spew package or
|
||||
// ConfigState a test applies to.
|
||||
type spewFunc int
|
||||
|
||||
const (
|
||||
fCSFdump spewFunc = iota
|
||||
fCSFprint
|
||||
fCSFprintf
|
||||
fCSFprintln
|
||||
fCSPrint
|
||||
fCSPrintln
|
||||
fCSSdump
|
||||
fCSSprint
|
||||
fCSSprintf
|
||||
fCSSprintln
|
||||
fCSErrorf
|
||||
fCSNewFormatter
|
||||
fErrorf
|
||||
fFprint
|
||||
fFprintln
|
||||
fPrint
|
||||
fPrintln
|
||||
fSdump
|
||||
fSprint
|
||||
fSprintf
|
||||
fSprintln
|
||||
)
|
||||
|
||||
// Map of spewFunc values to names for pretty printing.
|
||||
var spewFuncStrings = map[spewFunc]string{
|
||||
fCSFdump: "ConfigState.Fdump",
|
||||
fCSFprint: "ConfigState.Fprint",
|
||||
fCSFprintf: "ConfigState.Fprintf",
|
||||
fCSFprintln: "ConfigState.Fprintln",
|
||||
fCSSdump: "ConfigState.Sdump",
|
||||
fCSPrint: "ConfigState.Print",
|
||||
fCSPrintln: "ConfigState.Println",
|
||||
fCSSprint: "ConfigState.Sprint",
|
||||
fCSSprintf: "ConfigState.Sprintf",
|
||||
fCSSprintln: "ConfigState.Sprintln",
|
||||
fCSErrorf: "ConfigState.Errorf",
|
||||
fCSNewFormatter: "ConfigState.NewFormatter",
|
||||
fErrorf: "spew.Errorf",
|
||||
fFprint: "spew.Fprint",
|
||||
fFprintln: "spew.Fprintln",
|
||||
fPrint: "spew.Print",
|
||||
fPrintln: "spew.Println",
|
||||
fSdump: "spew.Sdump",
|
||||
fSprint: "spew.Sprint",
|
||||
fSprintf: "spew.Sprintf",
|
||||
fSprintln: "spew.Sprintln",
|
||||
}
|
||||
|
||||
func (f spewFunc) String() string {
|
||||
if s, ok := spewFuncStrings[f]; ok {
|
||||
return s
|
||||
}
|
||||
return fmt.Sprintf("Unknown spewFunc (%d)", int(f))
|
||||
}
|
||||
|
||||
// spewTest is used to describe a test to be performed against the public
|
||||
// functions of the spew package or ConfigState.
|
||||
type spewTest struct {
|
||||
cs *spew.ConfigState
|
||||
f spewFunc
|
||||
format string
|
||||
in interface{}
|
||||
want string
|
||||
}
|
||||
|
||||
// spewTests houses the tests to be performed against the public functions of
|
||||
// the spew package and ConfigState.
|
||||
//
|
||||
// These tests are only intended to ensure the public functions are exercised
|
||||
// and are intentionally not exhaustive of types. The exhaustive type
|
||||
// tests are handled in the dump and format tests.
|
||||
var spewTests []spewTest
|
||||
|
||||
// redirStdout is a helper function to return the standard output from f as a
|
||||
// byte slice.
|
||||
func redirStdout(f func()) ([]byte, error) {
|
||||
tempFile, err := ioutil.TempFile("", "ss-test")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
fileName := tempFile.Name()
|
||||
defer os.Remove(fileName) // Ignore error
|
||||
|
||||
origStdout := os.Stdout
|
||||
os.Stdout = tempFile
|
||||
f()
|
||||
os.Stdout = origStdout
|
||||
tempFile.Close()
|
||||
|
||||
return ioutil.ReadFile(fileName)
|
||||
}
|
||||
|
||||
func initSpewTests() {
|
||||
// Config states with various settings.
|
||||
scsDefault := spew.NewDefaultConfig()
|
||||
scsNoMethods := &spew.ConfigState{Indent: " ", DisableMethods: true}
|
||||
scsNoPmethods := &spew.ConfigState{Indent: " ", DisablePointerMethods: true}
|
||||
scsMaxDepth := &spew.ConfigState{Indent: " ", MaxDepth: 1}
|
||||
scsContinue := &spew.ConfigState{Indent: " ", ContinueOnMethod: true}
|
||||
|
||||
// Variables for tests on types which implement Stringer interface with and
|
||||
// without a pointer receiver.
|
||||
ts := stringer("test")
|
||||
tps := pstringer("test")
|
||||
|
||||
// depthTester is used to test max depth handling for structs, array, slices
|
||||
// and maps.
|
||||
type depthTester struct {
|
||||
ic indirCir1
|
||||
arr [1]string
|
||||
slice []string
|
||||
m map[string]int
|
||||
}
|
||||
dt := depthTester{indirCir1{nil}, [1]string{"arr"}, []string{"slice"},
|
||||
map[string]int{"one": 1}}
|
||||
|
||||
// Variable for tests on types which implement error interface.
|
||||
te := customError(10)
|
||||
|
||||
spewTests = []spewTest{
|
||||
{scsDefault, fCSFdump, "", int8(127), "(int8) 127\n"},
|
||||
{scsDefault, fCSFprint, "", int16(32767), "32767"},
|
||||
{scsDefault, fCSFprintf, "%v", int32(2147483647), "2147483647"},
|
||||
{scsDefault, fCSFprintln, "", int(2147483647), "2147483647\n"},
|
||||
{scsDefault, fCSPrint, "", int64(9223372036854775807), "9223372036854775807"},
|
||||
{scsDefault, fCSPrintln, "", uint8(255), "255\n"},
|
||||
{scsDefault, fCSSdump, "", uint8(64), "(uint8) 64\n"},
|
||||
{scsDefault, fCSSprint, "", complex(1, 2), "(1+2i)"},
|
||||
{scsDefault, fCSSprintf, "%v", complex(float32(3), 4), "(3+4i)"},
|
||||
{scsDefault, fCSSprintln, "", complex(float64(5), 6), "(5+6i)\n"},
|
||||
{scsDefault, fCSErrorf, "%#v", uint16(65535), "(uint16)65535"},
|
||||
{scsDefault, fCSNewFormatter, "%v", uint32(4294967295), "4294967295"},
|
||||
{scsDefault, fErrorf, "%v", uint64(18446744073709551615), "18446744073709551615"},
|
||||
{scsDefault, fFprint, "", float32(3.14), "3.14"},
|
||||
{scsDefault, fFprintln, "", float64(6.28), "6.28\n"},
|
||||
{scsDefault, fPrint, "", true, "true"},
|
||||
{scsDefault, fPrintln, "", false, "false\n"},
|
||||
{scsDefault, fSdump, "", complex(-10, -20), "(complex128) (-10-20i)\n"},
|
||||
{scsDefault, fSprint, "", complex(-1, -2), "(-1-2i)"},
|
||||
{scsDefault, fSprintf, "%v", complex(float32(-3), -4), "(-3-4i)"},
|
||||
{scsDefault, fSprintln, "", complex(float64(-5), -6), "(-5-6i)\n"},
|
||||
{scsNoMethods, fCSFprint, "", ts, "test"},
|
||||
{scsNoMethods, fCSFprint, "", &ts, "<*>test"},
|
||||
{scsNoMethods, fCSFprint, "", tps, "test"},
|
||||
{scsNoMethods, fCSFprint, "", &tps, "<*>test"},
|
||||
{scsNoPmethods, fCSFprint, "", ts, "stringer test"},
|
||||
{scsNoPmethods, fCSFprint, "", &ts, "<*>stringer test"},
|
||||
{scsNoPmethods, fCSFprint, "", tps, "test"},
|
||||
{scsNoPmethods, fCSFprint, "", &tps, "<*>stringer test"},
|
||||
{scsMaxDepth, fCSFprint, "", dt, "{{<max>} [<max>] [<max>] map[<max>]}"},
|
||||
{scsMaxDepth, fCSFdump, "", dt, "(spew_test.depthTester) {\n" +
|
||||
" ic: (spew_test.indirCir1) {\n <max depth reached>\n },\n" +
|
||||
" arr: ([1]string) (len=1 cap=1) {\n <max depth reached>\n },\n" +
|
||||
" slice: ([]string) (len=1 cap=1) {\n <max depth reached>\n },\n" +
|
||||
" m: (map[string]int) (len=1) {\n <max depth reached>\n }\n}\n"},
|
||||
{scsContinue, fCSFprint, "", ts, "(stringer test) test"},
|
||||
{scsContinue, fCSFdump, "", ts, "(spew_test.stringer) " +
|
||||
"(len=4) (stringer test) \"test\"\n"},
|
||||
{scsContinue, fCSFprint, "", te, "(error: 10) 10"},
|
||||
{scsContinue, fCSFdump, "", te, "(spew_test.customError) " +
|
||||
"(error: 10) 10\n"},
|
||||
}
|
||||
}
|
||||
|
||||
// TestSpew executes all of the tests described by spewTests.
|
||||
func TestSpew(t *testing.T) {
|
||||
initSpewTests()
|
||||
|
||||
t.Logf("Running %d tests", len(spewTests))
|
||||
for i, test := range spewTests {
|
||||
buf := new(bytes.Buffer)
|
||||
switch test.f {
|
||||
case fCSFdump:
|
||||
test.cs.Fdump(buf, test.in)
|
||||
|
||||
case fCSFprint:
|
||||
test.cs.Fprint(buf, test.in)
|
||||
|
||||
case fCSFprintf:
|
||||
test.cs.Fprintf(buf, test.format, test.in)
|
||||
|
||||
case fCSFprintln:
|
||||
test.cs.Fprintln(buf, test.in)
|
||||
|
||||
case fCSPrint:
|
||||
b, err := redirStdout(func() { test.cs.Print(test.in) })
|
||||
if err != nil {
|
||||
t.Errorf("%v #%d %v", test.f, i, err)
|
||||
continue
|
||||
}
|
||||
buf.Write(b)
|
||||
|
||||
case fCSPrintln:
|
||||
b, err := redirStdout(func() { test.cs.Println(test.in) })
|
||||
if err != nil {
|
||||
t.Errorf("%v #%d %v", test.f, i, err)
|
||||
continue
|
||||
}
|
||||
buf.Write(b)
|
||||
|
||||
case fCSSdump:
|
||||
str := test.cs.Sdump(test.in)
|
||||
buf.WriteString(str)
|
||||
|
||||
case fCSSprint:
|
||||
str := test.cs.Sprint(test.in)
|
||||
buf.WriteString(str)
|
||||
|
||||
case fCSSprintf:
|
||||
str := test.cs.Sprintf(test.format, test.in)
|
||||
buf.WriteString(str)
|
||||
|
||||
case fCSSprintln:
|
||||
str := test.cs.Sprintln(test.in)
|
||||
buf.WriteString(str)
|
||||
|
||||
case fCSErrorf:
|
||||
err := test.cs.Errorf(test.format, test.in)
|
||||
buf.WriteString(err.Error())
|
||||
|
||||
case fCSNewFormatter:
|
||||
fmt.Fprintf(buf, test.format, test.cs.NewFormatter(test.in))
|
||||
|
||||
case fErrorf:
|
||||
err := spew.Errorf(test.format, test.in)
|
||||
buf.WriteString(err.Error())
|
||||
|
||||
case fFprint:
|
||||
spew.Fprint(buf, test.in)
|
||||
|
||||
case fFprintln:
|
||||
spew.Fprintln(buf, test.in)
|
||||
|
||||
case fPrint:
|
||||
b, err := redirStdout(func() { spew.Print(test.in) })
|
||||
if err != nil {
|
||||
t.Errorf("%v #%d %v", test.f, i, err)
|
||||
continue
|
||||
}
|
||||
buf.Write(b)
|
||||
|
||||
case fPrintln:
|
||||
b, err := redirStdout(func() { spew.Println(test.in) })
|
||||
if err != nil {
|
||||
t.Errorf("%v #%d %v", test.f, i, err)
|
||||
continue
|
||||
}
|
||||
buf.Write(b)
|
||||
|
||||
case fSdump:
|
||||
str := spew.Sdump(test.in)
|
||||
buf.WriteString(str)
|
||||
|
||||
case fSprint:
|
||||
str := spew.Sprint(test.in)
|
||||
buf.WriteString(str)
|
||||
|
||||
case fSprintf:
|
||||
str := spew.Sprintf(test.format, test.in)
|
||||
buf.WriteString(str)
|
||||
|
||||
case fSprintln:
|
||||
str := spew.Sprintln(test.in)
|
||||
buf.WriteString(str)
|
||||
|
||||
default:
|
||||
t.Errorf("%v #%d unrecognized function", test.f, i)
|
||||
continue
|
||||
}
|
||||
s := buf.String()
|
||||
if test.want != s {
|
||||
t.Errorf("ConfigState #%d\n got: %s want: %s", i, s, test.want)
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
82
Godeps/_workspace/src/github.com/davecgh/go-spew/spew/testdata/dumpcgo.go
generated
vendored
Normal file
82
Godeps/_workspace/src/github.com/davecgh/go-spew/spew/testdata/dumpcgo.go
generated
vendored
Normal file
|
@ -0,0 +1,82 @@
|
|||
// Copyright (c) 2013 Dave Collins <dave@davec.name>
|
||||
//
|
||||
// Permission to use, copy, modify, and distribute this software for any
|
||||
// purpose with or without fee is hereby granted, provided that the above
|
||||
// copyright notice and this permission notice appear in all copies.
|
||||
//
|
||||
// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
|
||||
// WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
|
||||
// MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
|
||||
// ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
|
||||
// WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
|
||||
// ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
|
||||
// OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||
|
||||
// NOTE: Due to the following build constraints, this file will only be compiled
|
||||
// when both cgo is supported and "-tags testcgo" is added to the go test
|
||||
// command line. This code should really only be in the dumpcgo_test.go file,
|
||||
// but unfortunately Go will not allow cgo in test files, so this is a
|
||||
// workaround to allow cgo types to be tested. This configuration is used
|
||||
// because spew itself does not require cgo to run even though it does handle
|
||||
// certain cgo types specially. Rather than forcing all clients to require cgo
|
||||
// and an external C compiler just to run the tests, this scheme makes them
|
||||
// optional.
|
||||
// +build cgo,testcgo
|
||||
|
||||
package testdata
|
||||
|
||||
/*
|
||||
#include <stdint.h>
|
||||
typedef unsigned char custom_uchar_t;
|
||||
|
||||
char *ncp = 0;
|
||||
char *cp = "test";
|
||||
char ca[6] = {'t', 'e', 's', 't', '2', '\0'};
|
||||
unsigned char uca[6] = {'t', 'e', 's', 't', '3', '\0'};
|
||||
signed char sca[6] = {'t', 'e', 's', 't', '4', '\0'};
|
||||
uint8_t ui8ta[6] = {'t', 'e', 's', 't', '5', '\0'};
|
||||
custom_uchar_t tuca[6] = {'t', 'e', 's', 't', '6', '\0'};
|
||||
*/
|
||||
import "C"
|
||||
|
||||
// GetCgoNullCharPointer returns a null char pointer via cgo. This is only
|
||||
// used for tests.
|
||||
func GetCgoNullCharPointer() interface{} {
|
||||
return C.ncp
|
||||
}
|
||||
|
||||
// GetCgoCharPointer returns a char pointer via cgo. This is only used for
|
||||
// tests.
|
||||
func GetCgoCharPointer() interface{} {
|
||||
return C.cp
|
||||
}
|
||||
|
||||
// GetCgoCharArray returns a char array via cgo and the array's len and cap.
|
||||
// This is only used for tests.
|
||||
func GetCgoCharArray() (interface{}, int, int) {
|
||||
return C.ca, len(C.ca), cap(C.ca)
|
||||
}
|
||||
|
||||
// GetCgoUnsignedCharArray returns an unsigned char array via cgo and the
|
||||
// array's len and cap. This is only used for tests.
|
||||
func GetCgoUnsignedCharArray() (interface{}, int, int) {
|
||||
return C.uca, len(C.uca), cap(C.uca)
|
||||
}
|
||||
|
||||
// GetCgoSignedCharArray returns a signed char array via cgo and the array's len
|
||||
// and cap. This is only used for tests.
|
||||
func GetCgoSignedCharArray() (interface{}, int, int) {
|
||||
return C.sca, len(C.sca), cap(C.sca)
|
||||
}
|
||||
|
||||
// GetCgoUint8tArray returns a uint8_t array via cgo and the array's len and
|
||||
// cap. This is only used for tests.
|
||||
func GetCgoUint8tArray() (interface{}, int, int) {
|
||||
return C.ui8ta, len(C.ui8ta), cap(C.ui8ta)
|
||||
}
|
||||
|
||||
// GetCgoTypdefedUnsignedCharArray returns a typedefed unsigned char array via
|
||||
// cgo and the array's len and cap. This is only used for tests.
|
||||
func GetCgoTypdefedUnsignedCharArray() (interface{}, int, int) {
|
||||
return C.tuca, len(C.tuca), cap(C.tuca)
|
||||
}
|
|
@ -535,6 +535,7 @@ func (self *Ethereum) AddPeer(nodeURL string) error {
|
|||
func (s *Ethereum) Stop() {
|
||||
s.txSub.Unsubscribe() // quits txBroadcastLoop
|
||||
|
||||
s.net.Stop()
|
||||
s.protocolManager.Stop()
|
||||
s.chainManager.Stop()
|
||||
s.txPool.Stop()
|
||||
|
@ -544,7 +545,6 @@ func (s *Ethereum) Stop() {
|
|||
}
|
||||
s.StopAutoDAG()
|
||||
|
||||
glog.V(logger.Info).Infoln("Server stopped")
|
||||
close(s.shutdownChan)
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,276 @@
|
|||
package p2p
|
||||
|
||||
import (
|
||||
"container/heap"
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/ethereum/go-ethereum/logger"
|
||||
"github.com/ethereum/go-ethereum/logger/glog"
|
||||
"github.com/ethereum/go-ethereum/p2p/discover"
|
||||
)
|
||||
|
||||
const (
|
||||
// This is the amount of time spent waiting in between
|
||||
// redialing a certain node.
|
||||
dialHistoryExpiration = 30 * time.Second
|
||||
|
||||
// Discovery lookup tasks will wait for this long when
|
||||
// no results are returned. This can happen if the table
|
||||
// becomes empty (i.e. not often).
|
||||
emptyLookupDelay = 10 * time.Second
|
||||
)
|
||||
|
||||
// dialstate schedules dials and discovery lookups.
|
||||
// it get's a chance to compute new tasks on every iteration
|
||||
// of the main loop in Server.run.
|
||||
type dialstate struct {
|
||||
maxDynDials int
|
||||
ntab discoverTable
|
||||
|
||||
lookupRunning bool
|
||||
bootstrapped bool
|
||||
|
||||
dialing map[discover.NodeID]connFlag
|
||||
lookupBuf []*discover.Node // current discovery lookup results
|
||||
randomNodes []*discover.Node // filled from Table
|
||||
static map[discover.NodeID]*discover.Node
|
||||
hist *dialHistory
|
||||
}
|
||||
|
||||
type discoverTable interface {
|
||||
Self() *discover.Node
|
||||
Close()
|
||||
Bootstrap([]*discover.Node)
|
||||
Lookup(target discover.NodeID) []*discover.Node
|
||||
ReadRandomNodes([]*discover.Node) int
|
||||
}
|
||||
|
||||
// the dial history remembers recent dials.
|
||||
type dialHistory []pastDial
|
||||
|
||||
// pastDial is an entry in the dial history.
|
||||
type pastDial struct {
|
||||
id discover.NodeID
|
||||
exp time.Time
|
||||
}
|
||||
|
||||
type task interface {
|
||||
Do(*Server)
|
||||
}
|
||||
|
||||
// A dialTask is generated for each node that is dialed.
|
||||
type dialTask struct {
|
||||
flags connFlag
|
||||
dest *discover.Node
|
||||
}
|
||||
|
||||
// discoverTask runs discovery table operations.
|
||||
// Only one discoverTask is active at any time.
|
||||
//
|
||||
// If bootstrap is true, the task runs Table.Bootstrap,
|
||||
// otherwise it performs a random lookup and leaves the
|
||||
// results in the task.
|
||||
type discoverTask struct {
|
||||
bootstrap bool
|
||||
results []*discover.Node
|
||||
}
|
||||
|
||||
// A waitExpireTask is generated if there are no other tasks
|
||||
// to keep the loop in Server.run ticking.
|
||||
type waitExpireTask struct {
|
||||
time.Duration
|
||||
}
|
||||
|
||||
func newDialState(static []*discover.Node, ntab discoverTable, maxdyn int) *dialstate {
|
||||
s := &dialstate{
|
||||
maxDynDials: maxdyn,
|
||||
ntab: ntab,
|
||||
static: make(map[discover.NodeID]*discover.Node),
|
||||
dialing: make(map[discover.NodeID]connFlag),
|
||||
randomNodes: make([]*discover.Node, maxdyn/2),
|
||||
hist: new(dialHistory),
|
||||
}
|
||||
for _, n := range static {
|
||||
s.static[n.ID] = n
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *dialstate) addStatic(n *discover.Node) {
|
||||
s.static[n.ID] = n
|
||||
}
|
||||
|
||||
func (s *dialstate) newTasks(nRunning int, peers map[discover.NodeID]*Peer, now time.Time) []task {
|
||||
var newtasks []task
|
||||
addDial := func(flag connFlag, n *discover.Node) bool {
|
||||
_, dialing := s.dialing[n.ID]
|
||||
if dialing || peers[n.ID] != nil || s.hist.contains(n.ID) {
|
||||
return false
|
||||
}
|
||||
s.dialing[n.ID] = flag
|
||||
newtasks = append(newtasks, &dialTask{flags: flag, dest: n})
|
||||
return true
|
||||
}
|
||||
|
||||
// Compute number of dynamic dials necessary at this point.
|
||||
needDynDials := s.maxDynDials
|
||||
for _, p := range peers {
|
||||
if p.rw.is(dynDialedConn) {
|
||||
needDynDials--
|
||||
}
|
||||
}
|
||||
for _, flag := range s.dialing {
|
||||
if flag&dynDialedConn != 0 {
|
||||
needDynDials--
|
||||
}
|
||||
}
|
||||
|
||||
// Expire the dial history on every invocation.
|
||||
s.hist.expire(now)
|
||||
|
||||
// Create dials for static nodes if they are not connected.
|
||||
for _, n := range s.static {
|
||||
addDial(staticDialedConn, n)
|
||||
}
|
||||
|
||||
// Use random nodes from the table for half of the necessary
|
||||
// dynamic dials.
|
||||
randomCandidates := needDynDials / 2
|
||||
if randomCandidates > 0 && s.bootstrapped {
|
||||
n := s.ntab.ReadRandomNodes(s.randomNodes)
|
||||
for i := 0; i < randomCandidates && i < n; i++ {
|
||||
if addDial(dynDialedConn, s.randomNodes[i]) {
|
||||
needDynDials--
|
||||
}
|
||||
}
|
||||
}
|
||||
// Create dynamic dials from random lookup results, removing tried
|
||||
// items from the result buffer.
|
||||
i := 0
|
||||
for ; i < len(s.lookupBuf) && needDynDials > 0; i++ {
|
||||
if addDial(dynDialedConn, s.lookupBuf[i]) {
|
||||
needDynDials--
|
||||
}
|
||||
}
|
||||
s.lookupBuf = s.lookupBuf[:copy(s.lookupBuf, s.lookupBuf[i:])]
|
||||
// Launch a discovery lookup if more candidates are needed. The
|
||||
// first discoverTask bootstraps the table and won't return any
|
||||
// results.
|
||||
if len(s.lookupBuf) < needDynDials && !s.lookupRunning {
|
||||
s.lookupRunning = true
|
||||
newtasks = append(newtasks, &discoverTask{bootstrap: !s.bootstrapped})
|
||||
}
|
||||
|
||||
// Launch a timer to wait for the next node to expire if all
|
||||
// candidates have been tried and no task is currently active.
|
||||
// This should prevent cases where the dialer logic is not ticked
|
||||
// because there are no pending events.
|
||||
if nRunning == 0 && len(newtasks) == 0 && s.hist.Len() > 0 {
|
||||
t := &waitExpireTask{s.hist.min().exp.Sub(now)}
|
||||
newtasks = append(newtasks, t)
|
||||
}
|
||||
return newtasks
|
||||
}
|
||||
|
||||
func (s *dialstate) taskDone(t task, now time.Time) {
|
||||
switch t := t.(type) {
|
||||
case *dialTask:
|
||||
s.hist.add(t.dest.ID, now.Add(dialHistoryExpiration))
|
||||
delete(s.dialing, t.dest.ID)
|
||||
case *discoverTask:
|
||||
if t.bootstrap {
|
||||
s.bootstrapped = true
|
||||
}
|
||||
s.lookupRunning = false
|
||||
s.lookupBuf = append(s.lookupBuf, t.results...)
|
||||
}
|
||||
}
|
||||
|
||||
func (t *dialTask) Do(srv *Server) {
|
||||
addr := &net.TCPAddr{IP: t.dest.IP, Port: int(t.dest.TCP)}
|
||||
glog.V(logger.Debug).Infof("dialing %v\n", t.dest)
|
||||
fd, err := srv.Dialer.Dial("tcp", addr.String())
|
||||
if err != nil {
|
||||
glog.V(logger.Detail).Infof("dial error: %v", err)
|
||||
return
|
||||
}
|
||||
srv.setupConn(fd, t.flags, t.dest)
|
||||
}
|
||||
func (t *dialTask) String() string {
|
||||
return fmt.Sprintf("%v %x %v:%d", t.flags, t.dest.ID[:8], t.dest.IP, t.dest.TCP)
|
||||
}
|
||||
|
||||
func (t *discoverTask) Do(srv *Server) {
|
||||
if t.bootstrap {
|
||||
srv.ntab.Bootstrap(srv.BootstrapNodes)
|
||||
} else {
|
||||
var target discover.NodeID
|
||||
rand.Read(target[:])
|
||||
t.results = srv.ntab.Lookup(target)
|
||||
// newTasks generates a lookup task whenever dynamic dials are
|
||||
// necessary. Lookups need to take some time, otherwise the
|
||||
// event loop spins too fast. An empty result can only be
|
||||
// returned if the table is empty.
|
||||
if len(t.results) == 0 {
|
||||
time.Sleep(emptyLookupDelay)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (t *discoverTask) String() (s string) {
|
||||
if t.bootstrap {
|
||||
s = "discovery bootstrap"
|
||||
} else {
|
||||
s = "discovery lookup"
|
||||
}
|
||||
if len(t.results) > 0 {
|
||||
s += fmt.Sprintf(" (%d results)", len(t.results))
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func (t waitExpireTask) Do(*Server) {
|
||||
time.Sleep(t.Duration)
|
||||
}
|
||||
func (t waitExpireTask) String() string {
|
||||
return fmt.Sprintf("wait for dial hist expire (%v)", t.Duration)
|
||||
}
|
||||
|
||||
// Use only these methods to access or modify dialHistory.
|
||||
func (h dialHistory) min() pastDial {
|
||||
return h[0]
|
||||
}
|
||||
func (h *dialHistory) add(id discover.NodeID, exp time.Time) {
|
||||
heap.Push(h, pastDial{id, exp})
|
||||
}
|
||||
func (h dialHistory) contains(id discover.NodeID) bool {
|
||||
for _, v := range h {
|
||||
if v.id == id {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
func (h *dialHistory) expire(now time.Time) {
|
||||
for h.Len() > 0 && h.min().exp.Before(now) {
|
||||
heap.Pop(h)
|
||||
}
|
||||
}
|
||||
|
||||
// heap.Interface boilerplate
|
||||
func (h dialHistory) Len() int { return len(h) }
|
||||
func (h dialHistory) Less(i, j int) bool { return h[i].exp.Before(h[j].exp) }
|
||||
func (h dialHistory) Swap(i, j int) { h[i], h[j] = h[j], h[i] }
|
||||
func (h *dialHistory) Push(x interface{}) {
|
||||
*h = append(*h, x.(pastDial))
|
||||
}
|
||||
func (h *dialHistory) Pop() interface{} {
|
||||
old := *h
|
||||
n := len(old)
|
||||
x := old[n-1]
|
||||
*h = old[0 : n-1]
|
||||
return x
|
||||
}
|
|
@ -0,0 +1,482 @@
|
|||
package p2p
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/davecgh/go-spew/spew"
|
||||
"github.com/ethereum/go-ethereum/p2p/discover"
|
||||
)
|
||||
|
||||
func init() {
|
||||
spew.Config.Indent = "\t"
|
||||
}
|
||||
|
||||
type dialtest struct {
|
||||
init *dialstate // state before and after the test.
|
||||
rounds []round
|
||||
}
|
||||
|
||||
type round struct {
|
||||
peers []*Peer // current peer set
|
||||
done []task // tasks that got done this round
|
||||
new []task // the result must match this one
|
||||
}
|
||||
|
||||
func runDialTest(t *testing.T, test dialtest) {
|
||||
var (
|
||||
vtime time.Time
|
||||
running int
|
||||
)
|
||||
pm := func(ps []*Peer) map[discover.NodeID]*Peer {
|
||||
m := make(map[discover.NodeID]*Peer)
|
||||
for _, p := range ps {
|
||||
m[p.rw.id] = p
|
||||
}
|
||||
return m
|
||||
}
|
||||
for i, round := range test.rounds {
|
||||
for _, task := range round.done {
|
||||
running--
|
||||
if running < 0 {
|
||||
panic("running task counter underflow")
|
||||
}
|
||||
test.init.taskDone(task, vtime)
|
||||
}
|
||||
|
||||
new := test.init.newTasks(running, pm(round.peers), vtime)
|
||||
if !sametasks(new, round.new) {
|
||||
t.Errorf("round %d: new tasks mismatch:\ngot %v\nwant %v\nstate: %v\nrunning: %v\n",
|
||||
i, spew.Sdump(new), spew.Sdump(round.new), spew.Sdump(test.init), spew.Sdump(running))
|
||||
}
|
||||
|
||||
// Time advances by 16 seconds on every round.
|
||||
vtime = vtime.Add(16 * time.Second)
|
||||
running += len(new)
|
||||
}
|
||||
}
|
||||
|
||||
type fakeTable []*discover.Node
|
||||
|
||||
func (t fakeTable) Self() *discover.Node { return new(discover.Node) }
|
||||
func (t fakeTable) Close() {}
|
||||
func (t fakeTable) Bootstrap([]*discover.Node) {}
|
||||
func (t fakeTable) Lookup(target discover.NodeID) []*discover.Node {
|
||||
return nil
|
||||
}
|
||||
func (t fakeTable) ReadRandomNodes(buf []*discover.Node) int {
|
||||
return copy(buf, t)
|
||||
}
|
||||
|
||||
// This test checks that dynamic dials are launched from discovery results.
|
||||
func TestDialStateDynDial(t *testing.T) {
|
||||
runDialTest(t, dialtest{
|
||||
init: newDialState(nil, fakeTable{}, 5),
|
||||
rounds: []round{
|
||||
// A discovery query is launched.
|
||||
{
|
||||
peers: []*Peer{
|
||||
{rw: &conn{flags: staticDialedConn, id: uintID(0)}},
|
||||
{rw: &conn{flags: dynDialedConn, id: uintID(1)}},
|
||||
{rw: &conn{flags: dynDialedConn, id: uintID(2)}},
|
||||
},
|
||||
new: []task{&discoverTask{bootstrap: true}},
|
||||
},
|
||||
// Dynamic dials are launched when it completes.
|
||||
{
|
||||
peers: []*Peer{
|
||||
{rw: &conn{flags: staticDialedConn, id: uintID(0)}},
|
||||
{rw: &conn{flags: dynDialedConn, id: uintID(1)}},
|
||||
{rw: &conn{flags: dynDialedConn, id: uintID(2)}},
|
||||
},
|
||||
done: []task{
|
||||
&discoverTask{bootstrap: true, results: []*discover.Node{
|
||||
{ID: uintID(2)}, // this one is already connected and not dialed.
|
||||
{ID: uintID(3)},
|
||||
{ID: uintID(4)},
|
||||
{ID: uintID(5)},
|
||||
{ID: uintID(6)}, // these are not tried because max dyn dials is 5
|
||||
{ID: uintID(7)}, // ...
|
||||
}},
|
||||
},
|
||||
new: []task{
|
||||
&dialTask{dynDialedConn, &discover.Node{ID: uintID(3)}},
|
||||
&dialTask{dynDialedConn, &discover.Node{ID: uintID(4)}},
|
||||
&dialTask{dynDialedConn, &discover.Node{ID: uintID(5)}},
|
||||
},
|
||||
},
|
||||
// Some of the dials complete but no new ones are launched yet because
|
||||
// the sum of active dial count and dynamic peer count is == maxDynDials.
|
||||
{
|
||||
peers: []*Peer{
|
||||
{rw: &conn{flags: staticDialedConn, id: uintID(0)}},
|
||||
{rw: &conn{flags: dynDialedConn, id: uintID(1)}},
|
||||
{rw: &conn{flags: dynDialedConn, id: uintID(2)}},
|
||||
{rw: &conn{flags: dynDialedConn, id: uintID(3)}},
|
||||
{rw: &conn{flags: dynDialedConn, id: uintID(4)}},
|
||||
},
|
||||
done: []task{
|
||||
&dialTask{dynDialedConn, &discover.Node{ID: uintID(3)}},
|
||||
&dialTask{dynDialedConn, &discover.Node{ID: uintID(4)}},
|
||||
},
|
||||
},
|
||||
// No new dial tasks are launched in the this round because
|
||||
// maxDynDials has been reached.
|
||||
{
|
||||
peers: []*Peer{
|
||||
{rw: &conn{flags: staticDialedConn, id: uintID(0)}},
|
||||
{rw: &conn{flags: dynDialedConn, id: uintID(1)}},
|
||||
{rw: &conn{flags: dynDialedConn, id: uintID(2)}},
|
||||
{rw: &conn{flags: dynDialedConn, id: uintID(3)}},
|
||||
{rw: &conn{flags: dynDialedConn, id: uintID(4)}},
|
||||
{rw: &conn{flags: dynDialedConn, id: uintID(5)}},
|
||||
},
|
||||
done: []task{
|
||||
&dialTask{dynDialedConn, &discover.Node{ID: uintID(5)}},
|
||||
},
|
||||
new: []task{
|
||||
&waitExpireTask{Duration: 14 * time.Second},
|
||||
},
|
||||
},
|
||||
// In this round, the peer with id 2 drops off. The query
|
||||
// results from last discovery lookup are reused.
|
||||
{
|
||||
peers: []*Peer{
|
||||
{rw: &conn{flags: staticDialedConn, id: uintID(0)}},
|
||||
{rw: &conn{flags: dynDialedConn, id: uintID(1)}},
|
||||
{rw: &conn{flags: dynDialedConn, id: uintID(3)}},
|
||||
{rw: &conn{flags: dynDialedConn, id: uintID(4)}},
|
||||
{rw: &conn{flags: dynDialedConn, id: uintID(5)}},
|
||||
},
|
||||
new: []task{
|
||||
&dialTask{dynDialedConn, &discover.Node{ID: uintID(6)}},
|
||||
},
|
||||
},
|
||||
// More peers (3,4) drop off and dial for ID 6 completes.
|
||||
// The last query result from the discovery lookup is reused
|
||||
// and a new one is spawned because more candidates are needed.
|
||||
{
|
||||
peers: []*Peer{
|
||||
{rw: &conn{flags: staticDialedConn, id: uintID(0)}},
|
||||
{rw: &conn{flags: dynDialedConn, id: uintID(1)}},
|
||||
{rw: &conn{flags: dynDialedConn, id: uintID(5)}},
|
||||
},
|
||||
done: []task{
|
||||
&dialTask{dynDialedConn, &discover.Node{ID: uintID(6)}},
|
||||
},
|
||||
new: []task{
|
||||
&dialTask{dynDialedConn, &discover.Node{ID: uintID(7)}},
|
||||
&discoverTask{},
|
||||
},
|
||||
},
|
||||
// Peer 7 is connected, but there still aren't enough dynamic peers
|
||||
// (4 out of 5). However, a discovery is already running, so ensure
|
||||
// no new is started.
|
||||
{
|
||||
peers: []*Peer{
|
||||
{rw: &conn{flags: staticDialedConn, id: uintID(0)}},
|
||||
{rw: &conn{flags: dynDialedConn, id: uintID(1)}},
|
||||
{rw: &conn{flags: dynDialedConn, id: uintID(5)}},
|
||||
{rw: &conn{flags: dynDialedConn, id: uintID(7)}},
|
||||
},
|
||||
done: []task{
|
||||
&dialTask{dynDialedConn, &discover.Node{ID: uintID(7)}},
|
||||
},
|
||||
},
|
||||
// Finish the running node discovery with an empty set. A new lookup
|
||||
// should be immediately requested.
|
||||
{
|
||||
peers: []*Peer{
|
||||
{rw: &conn{flags: staticDialedConn, id: uintID(0)}},
|
||||
{rw: &conn{flags: dynDialedConn, id: uintID(1)}},
|
||||
{rw: &conn{flags: dynDialedConn, id: uintID(5)}},
|
||||
{rw: &conn{flags: dynDialedConn, id: uintID(7)}},
|
||||
},
|
||||
done: []task{
|
||||
&discoverTask{},
|
||||
},
|
||||
new: []task{
|
||||
&discoverTask{},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func TestDialStateDynDialFromTable(t *testing.T) {
|
||||
// This table always returns the same random nodes
|
||||
// in the order given below.
|
||||
table := fakeTable{
|
||||
{ID: uintID(1)},
|
||||
{ID: uintID(2)},
|
||||
{ID: uintID(3)},
|
||||
{ID: uintID(4)},
|
||||
{ID: uintID(5)},
|
||||
{ID: uintID(6)},
|
||||
{ID: uintID(7)},
|
||||
{ID: uintID(8)},
|
||||
}
|
||||
|
||||
runDialTest(t, dialtest{
|
||||
init: newDialState(nil, table, 10),
|
||||
rounds: []round{
|
||||
// Discovery bootstrap is launched.
|
||||
{
|
||||
new: []task{&discoverTask{bootstrap: true}},
|
||||
},
|
||||
// 5 out of 8 of the nodes returned by ReadRandomNodes are dialed.
|
||||
{
|
||||
done: []task{
|
||||
&discoverTask{bootstrap: true},
|
||||
},
|
||||
new: []task{
|
||||
&dialTask{dynDialedConn, &discover.Node{ID: uintID(1)}},
|
||||
&dialTask{dynDialedConn, &discover.Node{ID: uintID(2)}},
|
||||
&dialTask{dynDialedConn, &discover.Node{ID: uintID(3)}},
|
||||
&dialTask{dynDialedConn, &discover.Node{ID: uintID(4)}},
|
||||
&dialTask{dynDialedConn, &discover.Node{ID: uintID(5)}},
|
||||
&discoverTask{bootstrap: false},
|
||||
},
|
||||
},
|
||||
// Dialing nodes 1,2 succeeds. Dials from the lookup are launched.
|
||||
{
|
||||
peers: []*Peer{
|
||||
{rw: &conn{flags: dynDialedConn, id: uintID(1)}},
|
||||
{rw: &conn{flags: dynDialedConn, id: uintID(2)}},
|
||||
},
|
||||
done: []task{
|
||||
&dialTask{dynDialedConn, &discover.Node{ID: uintID(1)}},
|
||||
&dialTask{dynDialedConn, &discover.Node{ID: uintID(2)}},
|
||||
&discoverTask{results: []*discover.Node{
|
||||
{ID: uintID(10)},
|
||||
{ID: uintID(11)},
|
||||
{ID: uintID(12)},
|
||||
}},
|
||||
},
|
||||
new: []task{
|
||||
&dialTask{dynDialedConn, &discover.Node{ID: uintID(10)}},
|
||||
&dialTask{dynDialedConn, &discover.Node{ID: uintID(11)}},
|
||||
&dialTask{dynDialedConn, &discover.Node{ID: uintID(12)}},
|
||||
&discoverTask{bootstrap: false},
|
||||
},
|
||||
},
|
||||
// Dialing nodes 3,4,5 fails. The dials from the lookup succeed.
|
||||
{
|
||||
peers: []*Peer{
|
||||
{rw: &conn{flags: dynDialedConn, id: uintID(1)}},
|
||||
{rw: &conn{flags: dynDialedConn, id: uintID(2)}},
|
||||
{rw: &conn{flags: dynDialedConn, id: uintID(10)}},
|
||||
{rw: &conn{flags: dynDialedConn, id: uintID(11)}},
|
||||
{rw: &conn{flags: dynDialedConn, id: uintID(12)}},
|
||||
},
|
||||
done: []task{
|
||||
&dialTask{dynDialedConn, &discover.Node{ID: uintID(3)}},
|
||||
&dialTask{dynDialedConn, &discover.Node{ID: uintID(4)}},
|
||||
&dialTask{dynDialedConn, &discover.Node{ID: uintID(5)}},
|
||||
&dialTask{dynDialedConn, &discover.Node{ID: uintID(10)}},
|
||||
&dialTask{dynDialedConn, &discover.Node{ID: uintID(11)}},
|
||||
&dialTask{dynDialedConn, &discover.Node{ID: uintID(12)}},
|
||||
},
|
||||
},
|
||||
// Waiting for expiry. No waitExpireTask is launched because the
|
||||
// discovery query is still running.
|
||||
{
|
||||
peers: []*Peer{
|
||||
{rw: &conn{flags: dynDialedConn, id: uintID(1)}},
|
||||
{rw: &conn{flags: dynDialedConn, id: uintID(2)}},
|
||||
{rw: &conn{flags: dynDialedConn, id: uintID(10)}},
|
||||
{rw: &conn{flags: dynDialedConn, id: uintID(11)}},
|
||||
{rw: &conn{flags: dynDialedConn, id: uintID(12)}},
|
||||
},
|
||||
},
|
||||
// Nodes 3,4 are not tried again because only the first two
|
||||
// returned random nodes (nodes 1,2) are tried and they're
|
||||
// already connected.
|
||||
{
|
||||
peers: []*Peer{
|
||||
{rw: &conn{flags: dynDialedConn, id: uintID(1)}},
|
||||
{rw: &conn{flags: dynDialedConn, id: uintID(2)}},
|
||||
{rw: &conn{flags: dynDialedConn, id: uintID(10)}},
|
||||
{rw: &conn{flags: dynDialedConn, id: uintID(11)}},
|
||||
{rw: &conn{flags: dynDialedConn, id: uintID(12)}},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// This test checks that static dials are launched.
|
||||
func TestDialStateStaticDial(t *testing.T) {
|
||||
wantStatic := []*discover.Node{
|
||||
{ID: uintID(1)},
|
||||
{ID: uintID(2)},
|
||||
{ID: uintID(3)},
|
||||
{ID: uintID(4)},
|
||||
{ID: uintID(5)},
|
||||
}
|
||||
|
||||
runDialTest(t, dialtest{
|
||||
init: newDialState(wantStatic, fakeTable{}, 0),
|
||||
rounds: []round{
|
||||
// Static dials are launched for the nodes that
|
||||
// aren't yet connected.
|
||||
{
|
||||
peers: []*Peer{
|
||||
{rw: &conn{flags: dynDialedConn, id: uintID(1)}},
|
||||
{rw: &conn{flags: dynDialedConn, id: uintID(2)}},
|
||||
},
|
||||
new: []task{
|
||||
&dialTask{staticDialedConn, &discover.Node{ID: uintID(3)}},
|
||||
&dialTask{staticDialedConn, &discover.Node{ID: uintID(4)}},
|
||||
&dialTask{staticDialedConn, &discover.Node{ID: uintID(5)}},
|
||||
},
|
||||
},
|
||||
// No new tasks are launched in this round because all static
|
||||
// nodes are either connected or still being dialed.
|
||||
{
|
||||
peers: []*Peer{
|
||||
{rw: &conn{flags: dynDialedConn, id: uintID(1)}},
|
||||
{rw: &conn{flags: dynDialedConn, id: uintID(2)}},
|
||||
{rw: &conn{flags: staticDialedConn, id: uintID(3)}},
|
||||
},
|
||||
done: []task{
|
||||
&dialTask{staticDialedConn, &discover.Node{ID: uintID(3)}},
|
||||
},
|
||||
},
|
||||
// No new dial tasks are launched because all static
|
||||
// nodes are now connected.
|
||||
{
|
||||
peers: []*Peer{
|
||||
{rw: &conn{flags: dynDialedConn, id: uintID(1)}},
|
||||
{rw: &conn{flags: dynDialedConn, id: uintID(2)}},
|
||||
{rw: &conn{flags: staticDialedConn, id: uintID(3)}},
|
||||
{rw: &conn{flags: staticDialedConn, id: uintID(4)}},
|
||||
{rw: &conn{flags: staticDialedConn, id: uintID(5)}},
|
||||
},
|
||||
done: []task{
|
||||
&dialTask{staticDialedConn, &discover.Node{ID: uintID(4)}},
|
||||
&dialTask{staticDialedConn, &discover.Node{ID: uintID(5)}},
|
||||
},
|
||||
new: []task{
|
||||
&waitExpireTask{Duration: 14 * time.Second},
|
||||
},
|
||||
},
|
||||
// Wait a round for dial history to expire, no new tasks should spawn.
|
||||
{
|
||||
peers: []*Peer{
|
||||
{rw: &conn{flags: dynDialedConn, id: uintID(1)}},
|
||||
{rw: &conn{flags: dynDialedConn, id: uintID(2)}},
|
||||
{rw: &conn{flags: staticDialedConn, id: uintID(3)}},
|
||||
{rw: &conn{flags: staticDialedConn, id: uintID(4)}},
|
||||
{rw: &conn{flags: staticDialedConn, id: uintID(5)}},
|
||||
},
|
||||
},
|
||||
// If a static node is dropped, it should be immediately redialed,
|
||||
// irrespective whether it was originally static or dynamic.
|
||||
{
|
||||
peers: []*Peer{
|
||||
{rw: &conn{flags: dynDialedConn, id: uintID(1)}},
|
||||
{rw: &conn{flags: staticDialedConn, id: uintID(3)}},
|
||||
{rw: &conn{flags: staticDialedConn, id: uintID(5)}},
|
||||
},
|
||||
new: []task{
|
||||
&dialTask{staticDialedConn, &discover.Node{ID: uintID(2)}},
|
||||
&dialTask{staticDialedConn, &discover.Node{ID: uintID(4)}},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// This test checks that past dials are not retried for some time.
|
||||
func TestDialStateCache(t *testing.T) {
|
||||
wantStatic := []*discover.Node{
|
||||
{ID: uintID(1)},
|
||||
{ID: uintID(2)},
|
||||
{ID: uintID(3)},
|
||||
}
|
||||
|
||||
runDialTest(t, dialtest{
|
||||
init: newDialState(wantStatic, fakeTable{}, 0),
|
||||
rounds: []round{
|
||||
// Static dials are launched for the nodes that
|
||||
// aren't yet connected.
|
||||
{
|
||||
peers: nil,
|
||||
new: []task{
|
||||
&dialTask{staticDialedConn, &discover.Node{ID: uintID(1)}},
|
||||
&dialTask{staticDialedConn, &discover.Node{ID: uintID(2)}},
|
||||
&dialTask{staticDialedConn, &discover.Node{ID: uintID(3)}},
|
||||
},
|
||||
},
|
||||
// No new tasks are launched in this round because all static
|
||||
// nodes are either connected or still being dialed.
|
||||
{
|
||||
peers: []*Peer{
|
||||
{rw: &conn{flags: staticDialedConn, id: uintID(1)}},
|
||||
{rw: &conn{flags: staticDialedConn, id: uintID(2)}},
|
||||
},
|
||||
done: []task{
|
||||
&dialTask{staticDialedConn, &discover.Node{ID: uintID(1)}},
|
||||
&dialTask{staticDialedConn, &discover.Node{ID: uintID(2)}},
|
||||
},
|
||||
},
|
||||
// A salvage task is launched to wait for node 3's history
|
||||
// entry to expire.
|
||||
{
|
||||
peers: []*Peer{
|
||||
{rw: &conn{flags: dynDialedConn, id: uintID(1)}},
|
||||
{rw: &conn{flags: dynDialedConn, id: uintID(2)}},
|
||||
},
|
||||
done: []task{
|
||||
&dialTask{staticDialedConn, &discover.Node{ID: uintID(3)}},
|
||||
},
|
||||
new: []task{
|
||||
&waitExpireTask{Duration: 14 * time.Second},
|
||||
},
|
||||
},
|
||||
// Still waiting for node 3's entry to expire in the cache.
|
||||
{
|
||||
peers: []*Peer{
|
||||
{rw: &conn{flags: dynDialedConn, id: uintID(1)}},
|
||||
{rw: &conn{flags: dynDialedConn, id: uintID(2)}},
|
||||
},
|
||||
},
|
||||
// The cache entry for node 3 has expired and is retried.
|
||||
{
|
||||
peers: []*Peer{
|
||||
{rw: &conn{flags: dynDialedConn, id: uintID(1)}},
|
||||
{rw: &conn{flags: dynDialedConn, id: uintID(2)}},
|
||||
},
|
||||
new: []task{
|
||||
&dialTask{staticDialedConn, &discover.Node{ID: uintID(3)}},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// compares task lists but doesn't care about the order.
|
||||
func sametasks(a, b []task) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
next:
|
||||
for _, ta := range a {
|
||||
for _, tb := range b {
|
||||
if reflect.DeepEqual(ta, tb) {
|
||||
continue next
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func uintID(i uint32) discover.NodeID {
|
||||
var id discover.NodeID
|
||||
binary.BigEndian.PutUint32(id[:], i)
|
||||
return id
|
||||
}
|
|
@ -8,6 +8,7 @@ package discover
|
|||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/binary"
|
||||
"net"
|
||||
"sort"
|
||||
"sync"
|
||||
|
@ -90,10 +91,58 @@ func newTable(t transport, ourID NodeID, ourAddr *net.UDPAddr, nodeDBPath string
|
|||
}
|
||||
|
||||
// Self returns the local node.
|
||||
// The returned node should not be modified by the caller.
|
||||
func (tab *Table) Self() *Node {
|
||||
return tab.self
|
||||
}
|
||||
|
||||
// ReadRandomNodes fills the given slice with random nodes from the
|
||||
// table. It will not write the same node more than once. The nodes in
|
||||
// the slice are copies and can be modified by the caller.
|
||||
func (tab *Table) ReadRandomNodes(buf []*Node) (n int) {
|
||||
tab.mutex.Lock()
|
||||
defer tab.mutex.Unlock()
|
||||
// TODO: tree-based buckets would help here
|
||||
// Find all non-empty buckets and get a fresh slice of their entries.
|
||||
var buckets [][]*Node
|
||||
for _, b := range tab.buckets {
|
||||
if len(b.entries) > 0 {
|
||||
buckets = append(buckets, b.entries[:])
|
||||
}
|
||||
}
|
||||
if len(buckets) == 0 {
|
||||
return 0
|
||||
}
|
||||
// Shuffle the buckets.
|
||||
for i := uint32(len(buckets)) - 1; i > 0; i-- {
|
||||
j := randUint(i)
|
||||
buckets[i], buckets[j] = buckets[j], buckets[i]
|
||||
}
|
||||
// Move head of each bucket into buf, removing buckets that become empty.
|
||||
var i, j int
|
||||
for ; i < len(buf); i, j = i+1, (j+1)%len(buckets) {
|
||||
b := buckets[j]
|
||||
buf[i] = &(*b[0])
|
||||
buckets[j] = b[1:]
|
||||
if len(b) == 1 {
|
||||
buckets = append(buckets[:j], buckets[j+1:]...)
|
||||
}
|
||||
if len(buckets) == 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
return i + 1
|
||||
}
|
||||
|
||||
func randUint(max uint32) uint32 {
|
||||
if max == 0 {
|
||||
return 0
|
||||
}
|
||||
var b [4]byte
|
||||
rand.Read(b[:])
|
||||
return binary.BigEndian.Uint32(b[:]) % max
|
||||
}
|
||||
|
||||
// Close terminates the network listener and flushes the node database.
|
||||
func (tab *Table) Close() {
|
||||
tab.net.close()
|
||||
|
|
|
@ -210,6 +210,36 @@ func TestTable_closest(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestTable_ReadRandomNodesGetAll(t *testing.T) {
|
||||
cfg := &quick.Config{
|
||||
MaxCount: 200,
|
||||
Rand: quickrand,
|
||||
Values: func(args []reflect.Value, rand *rand.Rand) {
|
||||
args[0] = reflect.ValueOf(make([]*Node, rand.Intn(1000)))
|
||||
},
|
||||
}
|
||||
test := func(buf []*Node) bool {
|
||||
tab := newTable(nil, NodeID{}, &net.UDPAddr{}, "")
|
||||
for i := 0; i < len(buf); i++ {
|
||||
ld := quickrand.Intn(len(tab.buckets))
|
||||
tab.add([]*Node{nodeAtDistance(tab.self.sha, ld)})
|
||||
}
|
||||
gotN := tab.ReadRandomNodes(buf)
|
||||
if gotN != tab.len() {
|
||||
t.Errorf("wrong number of nodes, got %d, want %d", gotN, tab.len())
|
||||
return false
|
||||
}
|
||||
if hasDuplicates(buf[:gotN]) {
|
||||
t.Errorf("result contains duplicates")
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
if err := quick.Check(test, cfg); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
type closeTest struct {
|
||||
Self NodeID
|
||||
Target common.Hash
|
||||
|
@ -517,7 +547,10 @@ func (n *preminedTestnet) mine(target NodeID) {
|
|||
|
||||
func hasDuplicates(slice []*Node) bool {
|
||||
seen := make(map[NodeID]bool)
|
||||
for _, e := range slice {
|
||||
for i, e := range slice {
|
||||
if e == nil {
|
||||
panic(fmt.Sprintf("nil *Node at %d", i))
|
||||
}
|
||||
if seen[e.ID] {
|
||||
return true
|
||||
}
|
||||
|
|
448
p2p/handshake.go
448
p2p/handshake.go
|
@ -1,448 +0,0 @@
|
|||
package p2p
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"errors"
|
||||
"fmt"
|
||||
"hash"
|
||||
"io"
|
||||
"net"
|
||||
|
||||
"github.com/ethereum/go-ethereum/crypto"
|
||||
"github.com/ethereum/go-ethereum/crypto/ecies"
|
||||
"github.com/ethereum/go-ethereum/crypto/secp256k1"
|
||||
"github.com/ethereum/go-ethereum/crypto/sha3"
|
||||
"github.com/ethereum/go-ethereum/p2p/discover"
|
||||
"github.com/ethereum/go-ethereum/rlp"
|
||||
)
|
||||
|
||||
const (
|
||||
sskLen = 16 // ecies.MaxSharedKeyLength(pubKey) / 2
|
||||
sigLen = 65 // elliptic S256
|
||||
pubLen = 64 // 512 bit pubkey in uncompressed representation without format byte
|
||||
shaLen = 32 // hash length (for nonce etc)
|
||||
|
||||
authMsgLen = sigLen + shaLen + pubLen + shaLen + 1
|
||||
authRespLen = pubLen + shaLen + 1
|
||||
|
||||
eciesBytes = 65 + 16 + 32
|
||||
encAuthMsgLen = authMsgLen + eciesBytes // size of the final ECIES payload sent as initiator's handshake
|
||||
encAuthRespLen = authRespLen + eciesBytes // size of the final ECIES payload sent as receiver's handshake
|
||||
)
|
||||
|
||||
// conn represents a remote connection after encryption handshake
|
||||
// and protocol handshake have completed.
|
||||
//
|
||||
// The MsgReadWriter is usually layered as follows:
|
||||
//
|
||||
// netWrapper (I/O timeouts, thread-safe ReadMsg, WriteMsg)
|
||||
// rlpxFrameRW (message encoding, encryption, authentication)
|
||||
// bufio.ReadWriter (buffering)
|
||||
// net.Conn (network I/O)
|
||||
//
|
||||
type conn struct {
|
||||
MsgReadWriter
|
||||
*protoHandshake
|
||||
}
|
||||
|
||||
// secrets represents the connection secrets
|
||||
// which are negotiated during the encryption handshake.
|
||||
type secrets struct {
|
||||
RemoteID discover.NodeID
|
||||
AES, MAC []byte
|
||||
EgressMAC, IngressMAC hash.Hash
|
||||
Token []byte
|
||||
}
|
||||
|
||||
// protoHandshake is the RLP structure of the protocol handshake.
|
||||
type protoHandshake struct {
|
||||
Version uint64
|
||||
Name string
|
||||
Caps []Cap
|
||||
ListenPort uint64
|
||||
ID discover.NodeID
|
||||
}
|
||||
|
||||
// setupConn starts a protocol session on the given connection. It
|
||||
// runs the encryption handshake and the protocol handshake. If dial
|
||||
// is non-nil, the connection the local node is the initiator. If
|
||||
// keepconn returns false, the connection will be disconnected with
|
||||
// DiscTooManyPeers after the key exchange.
|
||||
func setupConn(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake, dial *discover.Node, keepconn func(discover.NodeID) bool) (*conn, error) {
|
||||
if dial == nil {
|
||||
return setupInboundConn(fd, prv, our, keepconn)
|
||||
} else {
|
||||
return setupOutboundConn(fd, prv, our, dial, keepconn)
|
||||
}
|
||||
}
|
||||
|
||||
func setupInboundConn(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake, keepconn func(discover.NodeID) bool) (*conn, error) {
|
||||
secrets, err := receiverEncHandshake(fd, prv, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("encryption handshake failed: %v", err)
|
||||
}
|
||||
rw := newRlpxFrameRW(fd, secrets)
|
||||
if !keepconn(secrets.RemoteID) {
|
||||
SendItems(rw, discMsg, DiscTooManyPeers)
|
||||
return nil, errors.New("we have too many peers")
|
||||
}
|
||||
// Run the protocol handshake using authenticated messages.
|
||||
rhs, err := readProtocolHandshake(rw, secrets.RemoteID, our)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := Send(rw, handshakeMsg, our); err != nil {
|
||||
return nil, fmt.Errorf("protocol handshake write error: %v", err)
|
||||
}
|
||||
return &conn{rw, rhs}, nil
|
||||
}
|
||||
|
||||
func setupOutboundConn(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake, dial *discover.Node, keepconn func(discover.NodeID) bool) (*conn, error) {
|
||||
secrets, err := initiatorEncHandshake(fd, prv, dial.ID, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("encryption handshake failed: %v", err)
|
||||
}
|
||||
rw := newRlpxFrameRW(fd, secrets)
|
||||
if !keepconn(secrets.RemoteID) {
|
||||
SendItems(rw, discMsg, DiscTooManyPeers)
|
||||
return nil, errors.New("we have too many peers")
|
||||
}
|
||||
// Run the protocol handshake using authenticated messages.
|
||||
//
|
||||
// Note that even though writing the handshake is first, we prefer
|
||||
// returning the handshake read error. If the remote side
|
||||
// disconnects us early with a valid reason, we should return it
|
||||
// as the error so it can be tracked elsewhere.
|
||||
werr := make(chan error, 1)
|
||||
go func() { werr <- Send(rw, handshakeMsg, our) }()
|
||||
rhs, err := readProtocolHandshake(rw, secrets.RemoteID, our)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := <-werr; err != nil {
|
||||
return nil, fmt.Errorf("protocol handshake write error: %v", err)
|
||||
}
|
||||
if rhs.ID != dial.ID {
|
||||
return nil, errors.New("dialed node id mismatch")
|
||||
}
|
||||
return &conn{rw, rhs}, nil
|
||||
}
|
||||
|
||||
// encHandshake contains the state of the encryption handshake.
|
||||
type encHandshake struct {
|
||||
initiator bool
|
||||
remoteID discover.NodeID
|
||||
|
||||
remotePub *ecies.PublicKey // remote-pubk
|
||||
initNonce, respNonce []byte // nonce
|
||||
randomPrivKey *ecies.PrivateKey // ecdhe-random
|
||||
remoteRandomPub *ecies.PublicKey // ecdhe-random-pubk
|
||||
}
|
||||
|
||||
// secrets is called after the handshake is completed.
|
||||
// It extracts the connection secrets from the handshake values.
|
||||
func (h *encHandshake) secrets(auth, authResp []byte) (secrets, error) {
|
||||
ecdheSecret, err := h.randomPrivKey.GenerateShared(h.remoteRandomPub, sskLen, sskLen)
|
||||
if err != nil {
|
||||
return secrets{}, err
|
||||
}
|
||||
|
||||
// derive base secrets from ephemeral key agreement
|
||||
sharedSecret := crypto.Sha3(ecdheSecret, crypto.Sha3(h.respNonce, h.initNonce))
|
||||
aesSecret := crypto.Sha3(ecdheSecret, sharedSecret)
|
||||
s := secrets{
|
||||
RemoteID: h.remoteID,
|
||||
AES: aesSecret,
|
||||
MAC: crypto.Sha3(ecdheSecret, aesSecret),
|
||||
Token: crypto.Sha3(sharedSecret),
|
||||
}
|
||||
|
||||
// setup sha3 instances for the MACs
|
||||
mac1 := sha3.NewKeccak256()
|
||||
mac1.Write(xor(s.MAC, h.respNonce))
|
||||
mac1.Write(auth)
|
||||
mac2 := sha3.NewKeccak256()
|
||||
mac2.Write(xor(s.MAC, h.initNonce))
|
||||
mac2.Write(authResp)
|
||||
if h.initiator {
|
||||
s.EgressMAC, s.IngressMAC = mac1, mac2
|
||||
} else {
|
||||
s.EgressMAC, s.IngressMAC = mac2, mac1
|
||||
}
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (h *encHandshake) ecdhShared(prv *ecdsa.PrivateKey) ([]byte, error) {
|
||||
return ecies.ImportECDSA(prv).GenerateShared(h.remotePub, sskLen, sskLen)
|
||||
}
|
||||
|
||||
// initiatorEncHandshake negotiates a session token on conn.
|
||||
// it should be called on the dialing side of the connection.
|
||||
//
|
||||
// prv is the local client's private key.
|
||||
// token is the token from a previous session with this node.
|
||||
func initiatorEncHandshake(conn io.ReadWriter, prv *ecdsa.PrivateKey, remoteID discover.NodeID, token []byte) (s secrets, err error) {
|
||||
h, err := newInitiatorHandshake(remoteID)
|
||||
if err != nil {
|
||||
return s, err
|
||||
}
|
||||
auth, err := h.authMsg(prv, token)
|
||||
if err != nil {
|
||||
return s, err
|
||||
}
|
||||
if _, err = conn.Write(auth); err != nil {
|
||||
return s, err
|
||||
}
|
||||
|
||||
response := make([]byte, encAuthRespLen)
|
||||
if _, err = io.ReadFull(conn, response); err != nil {
|
||||
return s, err
|
||||
}
|
||||
if err := h.decodeAuthResp(response, prv); err != nil {
|
||||
return s, err
|
||||
}
|
||||
return h.secrets(auth, response)
|
||||
}
|
||||
|
||||
func newInitiatorHandshake(remoteID discover.NodeID) (*encHandshake, error) {
|
||||
// generate random initiator nonce
|
||||
n := make([]byte, shaLen)
|
||||
if _, err := rand.Read(n); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// generate random keypair to use for signing
|
||||
randpriv, err := ecies.GenerateKey(rand.Reader, crypto.S256(), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rpub, err := remoteID.Pubkey()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("bad remoteID: %v", err)
|
||||
}
|
||||
h := &encHandshake{
|
||||
initiator: true,
|
||||
remoteID: remoteID,
|
||||
remotePub: ecies.ImportECDSAPublic(rpub),
|
||||
initNonce: n,
|
||||
randomPrivKey: randpriv,
|
||||
}
|
||||
return h, nil
|
||||
}
|
||||
|
||||
// authMsg creates an encrypted initiator handshake message.
|
||||
func (h *encHandshake) authMsg(prv *ecdsa.PrivateKey, token []byte) ([]byte, error) {
|
||||
var tokenFlag byte
|
||||
if token == nil {
|
||||
// no session token found means we need to generate shared secret.
|
||||
// ecies shared secret is used as initial session token for new peers
|
||||
// generate shared key from prv and remote pubkey
|
||||
var err error
|
||||
if token, err = h.ecdhShared(prv); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
// for known peers, we use stored token from the previous session
|
||||
tokenFlag = 0x01
|
||||
}
|
||||
|
||||
// sign known message:
|
||||
// ecdh-shared-secret^nonce for new peers
|
||||
// token^nonce for old peers
|
||||
signed := xor(token, h.initNonce)
|
||||
signature, err := crypto.Sign(signed, h.randomPrivKey.ExportECDSA())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// encode auth message
|
||||
// signature || sha3(ecdhe-random-pubk) || pubk || nonce || token-flag
|
||||
msg := make([]byte, authMsgLen)
|
||||
n := copy(msg, signature)
|
||||
n += copy(msg[n:], crypto.Sha3(exportPubkey(&h.randomPrivKey.PublicKey)))
|
||||
n += copy(msg[n:], crypto.FromECDSAPub(&prv.PublicKey)[1:])
|
||||
n += copy(msg[n:], h.initNonce)
|
||||
msg[n] = tokenFlag
|
||||
|
||||
// encrypt auth message using remote-pubk
|
||||
return ecies.Encrypt(rand.Reader, h.remotePub, msg, nil, nil)
|
||||
}
|
||||
|
||||
// decodeAuthResp decode an encrypted authentication response message.
|
||||
func (h *encHandshake) decodeAuthResp(auth []byte, prv *ecdsa.PrivateKey) error {
|
||||
msg, err := crypto.Decrypt(prv, auth)
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not decrypt auth response (%v)", err)
|
||||
}
|
||||
h.respNonce = msg[pubLen : pubLen+shaLen]
|
||||
h.remoteRandomPub, err = importPublicKey(msg[:pubLen])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// ignore token flag for now
|
||||
return nil
|
||||
}
|
||||
|
||||
// receiverEncHandshake negotiates a session token on conn.
|
||||
// it should be called on the listening side of the connection.
|
||||
//
|
||||
// prv is the local client's private key.
|
||||
// token is the token from a previous session with this node.
|
||||
func receiverEncHandshake(conn io.ReadWriter, prv *ecdsa.PrivateKey, token []byte) (s secrets, err error) {
|
||||
// read remote auth sent by initiator.
|
||||
auth := make([]byte, encAuthMsgLen)
|
||||
if _, err := io.ReadFull(conn, auth); err != nil {
|
||||
return s, err
|
||||
}
|
||||
h, err := decodeAuthMsg(prv, token, auth)
|
||||
if err != nil {
|
||||
return s, err
|
||||
}
|
||||
|
||||
// send auth response
|
||||
resp, err := h.authResp(prv, token)
|
||||
if err != nil {
|
||||
return s, err
|
||||
}
|
||||
if _, err = conn.Write(resp); err != nil {
|
||||
return s, err
|
||||
}
|
||||
|
||||
return h.secrets(auth, resp)
|
||||
}
|
||||
|
||||
func decodeAuthMsg(prv *ecdsa.PrivateKey, token []byte, auth []byte) (*encHandshake, error) {
|
||||
var err error
|
||||
h := new(encHandshake)
|
||||
// generate random keypair for session
|
||||
h.randomPrivKey, err = ecies.GenerateKey(rand.Reader, crypto.S256(), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// generate random nonce
|
||||
h.respNonce = make([]byte, shaLen)
|
||||
if _, err = rand.Read(h.respNonce); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
msg, err := crypto.Decrypt(prv, auth)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not decrypt auth message (%v)", err)
|
||||
}
|
||||
|
||||
// decode message parameters
|
||||
// signature || sha3(ecdhe-random-pubk) || pubk || nonce || token-flag
|
||||
h.initNonce = msg[authMsgLen-shaLen-1 : authMsgLen-1]
|
||||
copy(h.remoteID[:], msg[sigLen+shaLen:sigLen+shaLen+pubLen])
|
||||
rpub, err := h.remoteID.Pubkey()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("bad remoteID: %#v", err)
|
||||
}
|
||||
h.remotePub = ecies.ImportECDSAPublic(rpub)
|
||||
|
||||
// recover remote random pubkey from signed message.
|
||||
if token == nil {
|
||||
// TODO: it is an error if the initiator has a token and we don't. check that.
|
||||
|
||||
// no session token means we need to generate shared secret.
|
||||
// ecies shared secret is used as initial session token for new peers.
|
||||
// generate shared key from prv and remote pubkey.
|
||||
if token, err = h.ecdhShared(prv); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
signedMsg := xor(token, h.initNonce)
|
||||
remoteRandomPub, err := secp256k1.RecoverPubkey(signedMsg, msg[:sigLen])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
h.remoteRandomPub, _ = importPublicKey(remoteRandomPub)
|
||||
return h, nil
|
||||
}
|
||||
|
||||
// authResp generates the encrypted authentication response message.
|
||||
func (h *encHandshake) authResp(prv *ecdsa.PrivateKey, token []byte) ([]byte, error) {
|
||||
// responder auth message
|
||||
// E(remote-pubk, ecdhe-random-pubk || nonce || 0x0)
|
||||
resp := make([]byte, authRespLen)
|
||||
n := copy(resp, exportPubkey(&h.randomPrivKey.PublicKey))
|
||||
n += copy(resp[n:], h.respNonce)
|
||||
if token == nil {
|
||||
resp[n] = 0
|
||||
} else {
|
||||
resp[n] = 1
|
||||
}
|
||||
// encrypt using remote-pubk
|
||||
return ecies.Encrypt(rand.Reader, h.remotePub, resp, nil, nil)
|
||||
}
|
||||
|
||||
// importPublicKey unmarshals 512 bit public keys.
|
||||
func importPublicKey(pubKey []byte) (*ecies.PublicKey, error) {
|
||||
var pubKey65 []byte
|
||||
switch len(pubKey) {
|
||||
case 64:
|
||||
// add 'uncompressed key' flag
|
||||
pubKey65 = append([]byte{0x04}, pubKey...)
|
||||
case 65:
|
||||
pubKey65 = pubKey
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid public key length %v (expect 64/65)", len(pubKey))
|
||||
}
|
||||
// TODO: fewer pointless conversions
|
||||
return ecies.ImportECDSAPublic(crypto.ToECDSAPub(pubKey65)), nil
|
||||
}
|
||||
|
||||
func exportPubkey(pub *ecies.PublicKey) []byte {
|
||||
if pub == nil {
|
||||
panic("nil pubkey")
|
||||
}
|
||||
return elliptic.Marshal(pub.Curve, pub.X, pub.Y)[1:]
|
||||
}
|
||||
|
||||
func xor(one, other []byte) (xor []byte) {
|
||||
xor = make([]byte, len(one))
|
||||
for i := 0; i < len(one); i++ {
|
||||
xor[i] = one[i] ^ other[i]
|
||||
}
|
||||
return xor
|
||||
}
|
||||
|
||||
func readProtocolHandshake(rw MsgReadWriter, wantID discover.NodeID, our *protoHandshake) (*protoHandshake, error) {
|
||||
msg, err := rw.ReadMsg()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if msg.Code == discMsg {
|
||||
// disconnect before protocol handshake is valid according to the
|
||||
// spec and we send it ourself if Server.addPeer fails.
|
||||
var reason [1]DiscReason
|
||||
rlp.Decode(msg.Payload, &reason)
|
||||
return nil, reason[0]
|
||||
}
|
||||
if msg.Code != handshakeMsg {
|
||||
return nil, fmt.Errorf("expected handshake, got %x", msg.Code)
|
||||
}
|
||||
if msg.Size > baseProtocolMaxMsgSize {
|
||||
return nil, fmt.Errorf("message too big (%d > %d)", msg.Size, baseProtocolMaxMsgSize)
|
||||
}
|
||||
var hs protoHandshake
|
||||
if err := msg.Decode(&hs); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// validate handshake info
|
||||
if hs.Version != our.Version {
|
||||
SendItems(rw, discMsg, DiscIncompatibleVersion)
|
||||
return nil, fmt.Errorf("required version %d, received %d\n", baseProtocolVersion, hs.Version)
|
||||
}
|
||||
if (hs.ID == discover.NodeID{}) {
|
||||
SendItems(rw, discMsg, DiscInvalidIdentity)
|
||||
return nil, errors.New("invalid public key in handshake")
|
||||
}
|
||||
if hs.ID != wantID {
|
||||
SendItems(rw, discMsg, DiscUnexpectedIdentity)
|
||||
return nil, errors.New("handshake node ID does not match encryption handshake")
|
||||
}
|
||||
return &hs, nil
|
||||
}
|
|
@ -1,172 +0,0 @@
|
|||
package p2p
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"net"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ethereum/go-ethereum/crypto"
|
||||
"github.com/ethereum/go-ethereum/crypto/ecies"
|
||||
"github.com/ethereum/go-ethereum/p2p/discover"
|
||||
)
|
||||
|
||||
func TestSharedSecret(t *testing.T) {
|
||||
prv0, _ := crypto.GenerateKey() // = ecdsa.GenerateKey(crypto.S256(), rand.Reader)
|
||||
pub0 := &prv0.PublicKey
|
||||
prv1, _ := crypto.GenerateKey()
|
||||
pub1 := &prv1.PublicKey
|
||||
|
||||
ss0, err := ecies.ImportECDSA(prv0).GenerateShared(ecies.ImportECDSAPublic(pub1), sskLen, sskLen)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
ss1, err := ecies.ImportECDSA(prv1).GenerateShared(ecies.ImportECDSAPublic(pub0), sskLen, sskLen)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
t.Logf("Secret:\n%v %x\n%v %x", len(ss0), ss0, len(ss0), ss1)
|
||||
if !bytes.Equal(ss0, ss1) {
|
||||
t.Errorf("dont match :(")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncHandshake(t *testing.T) {
|
||||
for i := 0; i < 20; i++ {
|
||||
start := time.Now()
|
||||
if err := testEncHandshake(nil); err != nil {
|
||||
t.Fatalf("i=%d %v", i, err)
|
||||
}
|
||||
t.Logf("(without token) %d %v\n", i+1, time.Since(start))
|
||||
}
|
||||
|
||||
for i := 0; i < 20; i++ {
|
||||
tok := make([]byte, shaLen)
|
||||
rand.Reader.Read(tok)
|
||||
start := time.Now()
|
||||
if err := testEncHandshake(tok); err != nil {
|
||||
t.Fatalf("i=%d %v", i, err)
|
||||
}
|
||||
t.Logf("(with token) %d %v\n", i+1, time.Since(start))
|
||||
}
|
||||
}
|
||||
|
||||
func testEncHandshake(token []byte) error {
|
||||
type result struct {
|
||||
side string
|
||||
s secrets
|
||||
err error
|
||||
}
|
||||
var (
|
||||
prv0, _ = crypto.GenerateKey()
|
||||
prv1, _ = crypto.GenerateKey()
|
||||
rw0, rw1 = net.Pipe()
|
||||
output = make(chan result)
|
||||
)
|
||||
|
||||
go func() {
|
||||
r := result{side: "initiator"}
|
||||
defer func() { output <- r }()
|
||||
|
||||
pub1s := discover.PubkeyID(&prv1.PublicKey)
|
||||
r.s, r.err = initiatorEncHandshake(rw0, prv0, pub1s, token)
|
||||
if r.err != nil {
|
||||
return
|
||||
}
|
||||
id1 := discover.PubkeyID(&prv1.PublicKey)
|
||||
if r.s.RemoteID != id1 {
|
||||
r.err = fmt.Errorf("remote ID mismatch: got %v, want: %v", r.s.RemoteID, id1)
|
||||
}
|
||||
}()
|
||||
go func() {
|
||||
r := result{side: "receiver"}
|
||||
defer func() { output <- r }()
|
||||
|
||||
r.s, r.err = receiverEncHandshake(rw1, prv1, token)
|
||||
if r.err != nil {
|
||||
return
|
||||
}
|
||||
id0 := discover.PubkeyID(&prv0.PublicKey)
|
||||
if r.s.RemoteID != id0 {
|
||||
r.err = fmt.Errorf("remote ID mismatch: got %v, want: %v", r.s.RemoteID, id0)
|
||||
}
|
||||
}()
|
||||
|
||||
// wait for results from both sides
|
||||
r1, r2 := <-output, <-output
|
||||
|
||||
if r1.err != nil {
|
||||
return fmt.Errorf("%s side error: %v", r1.side, r1.err)
|
||||
}
|
||||
if r2.err != nil {
|
||||
return fmt.Errorf("%s side error: %v", r2.side, r2.err)
|
||||
}
|
||||
|
||||
// don't compare remote node IDs
|
||||
r1.s.RemoteID, r2.s.RemoteID = discover.NodeID{}, discover.NodeID{}
|
||||
// flip MACs on one of them so they compare equal
|
||||
r1.s.EgressMAC, r1.s.IngressMAC = r1.s.IngressMAC, r1.s.EgressMAC
|
||||
if !reflect.DeepEqual(r1.s, r2.s) {
|
||||
return fmt.Errorf("secrets mismatch:\n t1: %#v\n t2: %#v", r1.s, r2.s)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestSetupConn(t *testing.T) {
|
||||
prv0, _ := crypto.GenerateKey()
|
||||
prv1, _ := crypto.GenerateKey()
|
||||
node0 := &discover.Node{
|
||||
ID: discover.PubkeyID(&prv0.PublicKey),
|
||||
IP: net.IP{1, 2, 3, 4},
|
||||
TCP: 33,
|
||||
}
|
||||
node1 := &discover.Node{
|
||||
ID: discover.PubkeyID(&prv1.PublicKey),
|
||||
IP: net.IP{5, 6, 7, 8},
|
||||
TCP: 44,
|
||||
}
|
||||
hs0 := &protoHandshake{
|
||||
Version: baseProtocolVersion,
|
||||
ID: node0.ID,
|
||||
Caps: []Cap{{"a", 0}, {"b", 2}},
|
||||
}
|
||||
hs1 := &protoHandshake{
|
||||
Version: baseProtocolVersion,
|
||||
ID: node1.ID,
|
||||
Caps: []Cap{{"c", 1}, {"d", 3}},
|
||||
}
|
||||
fd0, fd1 := net.Pipe()
|
||||
|
||||
done := make(chan struct{})
|
||||
keepalways := func(discover.NodeID) bool { return true }
|
||||
go func() {
|
||||
defer close(done)
|
||||
conn0, err := setupConn(fd0, prv0, hs0, node1, keepalways)
|
||||
if err != nil {
|
||||
t.Errorf("outbound side error: %v", err)
|
||||
return
|
||||
}
|
||||
if conn0.ID != node1.ID {
|
||||
t.Errorf("outbound conn id mismatch: got %v, want %v", conn0.ID, node1.ID)
|
||||
}
|
||||
if !reflect.DeepEqual(conn0.Caps, hs1.Caps) {
|
||||
t.Errorf("outbound caps mismatch: got %v, want %v", conn0.Caps, hs1.Caps)
|
||||
}
|
||||
}()
|
||||
|
||||
conn1, err := setupConn(fd1, prv1, hs1, nil, keepalways)
|
||||
if err != nil {
|
||||
t.Fatalf("inbound side error: %v", err)
|
||||
}
|
||||
if conn1.ID != node0.ID {
|
||||
t.Errorf("inbound conn id mismatch: got %v, want %v", conn1.ID, node0.ID)
|
||||
}
|
||||
if !reflect.DeepEqual(conn1.Caps, hs0.Caps) {
|
||||
t.Errorf("inbound caps mismatch: got %v, want %v", conn1.Caps, hs0.Caps)
|
||||
}
|
||||
|
||||
<-done
|
||||
}
|
73
p2p/peer.go
73
p2p/peer.go
|
@ -18,7 +18,7 @@ import (
|
|||
const (
|
||||
baseProtocolVersion = 4
|
||||
baseProtocolLength = uint64(16)
|
||||
baseProtocolMaxMsgSize = 10 * 1024 * 1024
|
||||
baseProtocolMaxMsgSize = 2 * 1024
|
||||
|
||||
pingInterval = 15 * time.Second
|
||||
)
|
||||
|
@ -33,9 +33,17 @@ const (
|
|||
peersMsg = 0x05
|
||||
)
|
||||
|
||||
// protoHandshake is the RLP structure of the protocol handshake.
|
||||
type protoHandshake struct {
|
||||
Version uint64
|
||||
Name string
|
||||
Caps []Cap
|
||||
ListenPort uint64
|
||||
ID discover.NodeID
|
||||
}
|
||||
|
||||
// Peer represents a connected remote node.
|
||||
type Peer struct {
|
||||
conn net.Conn
|
||||
rw *conn
|
||||
running map[string]*protoRW
|
||||
|
||||
|
@ -48,37 +56,36 @@ type Peer struct {
|
|||
// NewPeer returns a peer for testing purposes.
|
||||
func NewPeer(id discover.NodeID, name string, caps []Cap) *Peer {
|
||||
pipe, _ := net.Pipe()
|
||||
msgpipe, _ := MsgPipe()
|
||||
conn := &conn{msgpipe, &protoHandshake{ID: id, Name: name, Caps: caps}}
|
||||
peer := newPeer(pipe, conn, nil)
|
||||
conn := &conn{fd: pipe, transport: nil, id: id, caps: caps, name: name}
|
||||
peer := newPeer(conn, nil)
|
||||
close(peer.closed) // ensures Disconnect doesn't block
|
||||
return peer
|
||||
}
|
||||
|
||||
// ID returns the node's public key.
|
||||
func (p *Peer) ID() discover.NodeID {
|
||||
return p.rw.ID
|
||||
return p.rw.id
|
||||
}
|
||||
|
||||
// Name returns the node name that the remote node advertised.
|
||||
func (p *Peer) Name() string {
|
||||
return p.rw.Name
|
||||
return p.rw.name
|
||||
}
|
||||
|
||||
// Caps returns the capabilities (supported subprotocols) of the remote peer.
|
||||
func (p *Peer) Caps() []Cap {
|
||||
// TODO: maybe return copy
|
||||
return p.rw.Caps
|
||||
return p.rw.caps
|
||||
}
|
||||
|
||||
// RemoteAddr returns the remote address of the network connection.
|
||||
func (p *Peer) RemoteAddr() net.Addr {
|
||||
return p.conn.RemoteAddr()
|
||||
return p.rw.fd.RemoteAddr()
|
||||
}
|
||||
|
||||
// LocalAddr returns the local address of the network connection.
|
||||
func (p *Peer) LocalAddr() net.Addr {
|
||||
return p.conn.LocalAddr()
|
||||
return p.rw.fd.LocalAddr()
|
||||
}
|
||||
|
||||
// Disconnect terminates the peer connection with the given reason.
|
||||
|
@ -92,13 +99,12 @@ func (p *Peer) Disconnect(reason DiscReason) {
|
|||
|
||||
// String implements fmt.Stringer.
|
||||
func (p *Peer) String() string {
|
||||
return fmt.Sprintf("Peer %.8x %v", p.rw.ID[:], p.RemoteAddr())
|
||||
return fmt.Sprintf("Peer %x %v", p.rw.id[:8], p.RemoteAddr())
|
||||
}
|
||||
|
||||
func newPeer(fd net.Conn, conn *conn, protocols []Protocol) *Peer {
|
||||
protomap := matchProtocols(protocols, conn.Caps, conn)
|
||||
func newPeer(conn *conn, protocols []Protocol) *Peer {
|
||||
protomap := matchProtocols(protocols, conn.caps, conn)
|
||||
p := &Peer{
|
||||
conn: fd,
|
||||
rw: conn,
|
||||
running: protomap,
|
||||
disc: make(chan DiscReason),
|
||||
|
@ -117,7 +123,10 @@ func (p *Peer) run() DiscReason {
|
|||
p.startProtocols()
|
||||
|
||||
// Wait for an error or disconnect.
|
||||
var reason DiscReason
|
||||
var (
|
||||
reason DiscReason
|
||||
requested bool
|
||||
)
|
||||
select {
|
||||
case err := <-readErr:
|
||||
if r, ok := err.(DiscReason); ok {
|
||||
|
@ -131,23 +140,19 @@ func (p *Peer) run() DiscReason {
|
|||
case err := <-p.protoErr:
|
||||
reason = discReasonForError(err)
|
||||
case reason = <-p.disc:
|
||||
p.politeDisconnect(reason)
|
||||
requested = true
|
||||
}
|
||||
close(p.closed)
|
||||
p.rw.close(reason)
|
||||
p.wg.Wait()
|
||||
|
||||
if requested {
|
||||
reason = DiscRequested
|
||||
}
|
||||
|
||||
close(p.closed)
|
||||
p.wg.Wait()
|
||||
glog.V(logger.Debug).Infof("%v: Disconnected: %v\n", p, reason)
|
||||
return reason
|
||||
}
|
||||
|
||||
func (p *Peer) politeDisconnect(reason DiscReason) {
|
||||
if reason != DiscNetworkError {
|
||||
SendItems(p.rw, discMsg, uint(reason))
|
||||
}
|
||||
p.conn.Close()
|
||||
}
|
||||
|
||||
func (p *Peer) pingLoop() {
|
||||
ping := time.NewTicker(pingInterval)
|
||||
defer p.wg.Done()
|
||||
|
@ -254,7 +259,7 @@ func (p *Peer) startProtocols() {
|
|||
glog.V(logger.Detail).Infof("%v: Protocol %s/%d returned\n", p, proto.Name, proto.Version)
|
||||
err = errors.New("protocol returned")
|
||||
} else if err != io.EOF {
|
||||
glog.V(logger.Detail).Infof("%v: Protocol %s/%d error: \n", p, proto.Name, proto.Version, err)
|
||||
glog.V(logger.Detail).Infof("%v: Protocol %s/%d error: %v\n", p, proto.Name, proto.Version, err)
|
||||
}
|
||||
p.protoErr <- err
|
||||
p.wg.Done()
|
||||
|
@ -273,20 +278,6 @@ func (p *Peer) getProto(code uint64) (*protoRW, error) {
|
|||
return nil, newPeerError(errInvalidMsgCode, "%d", code)
|
||||
}
|
||||
|
||||
// writeProtoMsg sends the given message on behalf of the given named protocol.
|
||||
// this exists because of Server.Broadcast.
|
||||
func (p *Peer) writeProtoMsg(protoName string, msg Msg) error {
|
||||
proto, ok := p.running[protoName]
|
||||
if !ok {
|
||||
return fmt.Errorf("protocol %s not handled by peer", protoName)
|
||||
}
|
||||
if msg.Code >= proto.Length {
|
||||
return newPeerError(errInvalidMsgCode, "code %x is out of range for protocol %q", msg.Code, protoName)
|
||||
}
|
||||
msg.Code += proto.offset
|
||||
return p.rw.WriteMsg(msg)
|
||||
}
|
||||
|
||||
type protoRW struct {
|
||||
Protocol
|
||||
in chan Msg
|
||||
|
|
|
@ -5,39 +5,17 @@ import (
|
|||
)
|
||||
|
||||
const (
|
||||
errMagicTokenMismatch = iota
|
||||
errRead
|
||||
errWrite
|
||||
errMisc
|
||||
errInvalidMsgCode
|
||||
errInvalidMsgCode = iota
|
||||
errInvalidMsg
|
||||
errP2PVersionMismatch
|
||||
errPubkeyInvalid
|
||||
errPubkeyForbidden
|
||||
errProtocolBreach
|
||||
errPingTimeout
|
||||
errInvalidNetworkId
|
||||
errInvalidProtocolVersion
|
||||
)
|
||||
|
||||
var errorToString = map[int]string{
|
||||
errMagicTokenMismatch: "magic token mismatch",
|
||||
errRead: "read error",
|
||||
errWrite: "write error",
|
||||
errMisc: "misc error",
|
||||
errInvalidMsgCode: "invalid message code",
|
||||
errInvalidMsg: "invalid message",
|
||||
errP2PVersionMismatch: "P2P Version Mismatch",
|
||||
errPubkeyInvalid: "public key invalid",
|
||||
errPubkeyForbidden: "public key forbidden",
|
||||
errProtocolBreach: "protocol Breach",
|
||||
errPingTimeout: "ping timeout",
|
||||
errInvalidNetworkId: "invalid network id",
|
||||
errInvalidProtocolVersion: "invalid protocol version",
|
||||
errInvalidMsgCode: "invalid message code",
|
||||
errInvalidMsg: "invalid message",
|
||||
}
|
||||
|
||||
type peerError struct {
|
||||
Code int
|
||||
code int
|
||||
message string
|
||||
}
|
||||
|
||||
|
@ -107,23 +85,13 @@ func discReasonForError(err error) DiscReason {
|
|||
return reason
|
||||
}
|
||||
peerError, ok := err.(*peerError)
|
||||
if !ok {
|
||||
return DiscSubprotocolError
|
||||
}
|
||||
switch peerError.Code {
|
||||
case errP2PVersionMismatch:
|
||||
return DiscIncompatibleVersion
|
||||
case errPubkeyInvalid:
|
||||
return DiscInvalidIdentity
|
||||
case errPubkeyForbidden:
|
||||
return DiscUselessPeer
|
||||
case errInvalidMsgCode, errMagicTokenMismatch, errProtocolBreach:
|
||||
return DiscProtocolError
|
||||
case errPingTimeout:
|
||||
return DiscReadTimeout
|
||||
case errRead, errWrite:
|
||||
return DiscNetworkError
|
||||
default:
|
||||
return DiscSubprotocolError
|
||||
if ok {
|
||||
switch peerError.code {
|
||||
case errInvalidMsgCode, errInvalidMsg:
|
||||
return DiscProtocolError
|
||||
default:
|
||||
return DiscSubprotocolError
|
||||
}
|
||||
}
|
||||
return DiscSubprotocolError
|
||||
}
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
package p2p
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
|
@ -29,24 +28,20 @@ var discard = Protocol{
|
|||
}
|
||||
|
||||
func testPeer(protos []Protocol) (func(), *conn, *Peer, <-chan DiscReason) {
|
||||
fd1, _ := net.Pipe()
|
||||
hs1 := &protoHandshake{ID: randomID(), Version: baseProtocolVersion}
|
||||
hs2 := &protoHandshake{ID: randomID(), Version: baseProtocolVersion}
|
||||
fd1, fd2 := net.Pipe()
|
||||
c1 := &conn{fd: fd1, transport: newTestTransport(randomID(), fd1)}
|
||||
c2 := &conn{fd: fd2, transport: newTestTransport(randomID(), fd2)}
|
||||
for _, p := range protos {
|
||||
hs1.Caps = append(hs1.Caps, p.cap())
|
||||
hs2.Caps = append(hs2.Caps, p.cap())
|
||||
c1.caps = append(c1.caps, p.cap())
|
||||
c2.caps = append(c2.caps, p.cap())
|
||||
}
|
||||
|
||||
p1, p2 := MsgPipe()
|
||||
peer := newPeer(fd1, &conn{p1, hs1}, protos)
|
||||
peer := newPeer(c1, protos)
|
||||
errc := make(chan DiscReason, 1)
|
||||
go func() { errc <- peer.run() }()
|
||||
|
||||
closer := func() {
|
||||
p1.Close()
|
||||
fd1.Close()
|
||||
}
|
||||
return closer, &conn{p2, hs2}, peer, errc
|
||||
closer := func() { c2.close(errors.New("close func called")) }
|
||||
return closer, c2, peer, errc
|
||||
}
|
||||
|
||||
func TestPeerProtoReadMsg(t *testing.T) {
|
||||
|
@ -107,44 +102,6 @@ func TestPeerProtoEncodeMsg(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestPeerWriteForBroadcast(t *testing.T) {
|
||||
closer, rw, peer, peerErr := testPeer([]Protocol{discard})
|
||||
defer closer()
|
||||
|
||||
emptymsg := func(code uint64) Msg {
|
||||
return Msg{Code: code, Size: 0, Payload: bytes.NewReader(nil)}
|
||||
}
|
||||
|
||||
// test write errors
|
||||
if err := peer.writeProtoMsg("b", emptymsg(3)); err == nil {
|
||||
t.Errorf("expected error for unknown protocol, got nil")
|
||||
}
|
||||
if err := peer.writeProtoMsg("discard", emptymsg(8)); err == nil {
|
||||
t.Errorf("expected error for out-of-range msg code, got nil")
|
||||
} else if perr, ok := err.(*peerError); !ok || perr.Code != errInvalidMsgCode {
|
||||
t.Errorf("wrong error for out-of-range msg code, got %#v", err)
|
||||
}
|
||||
|
||||
// setup for reading the message on the other end
|
||||
read := make(chan struct{})
|
||||
go func() {
|
||||
if err := ExpectMsg(rw, 16, nil); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
close(read)
|
||||
}()
|
||||
|
||||
// test successful write
|
||||
if err := peer.writeProtoMsg("discard", emptymsg(0)); err != nil {
|
||||
t.Errorf("expect no error for known protocol: %v", err)
|
||||
}
|
||||
select {
|
||||
case <-read:
|
||||
case err := <-peerErr:
|
||||
t.Fatalf("peer stopped: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPeerPing(t *testing.T) {
|
||||
closer, rw, _, _ := testPeer(nil)
|
||||
defer closer()
|
||||
|
|
444
p2p/rlpx.go
444
p2p/rlpx.go
|
@ -4,23 +4,459 @@ import (
|
|||
"bytes"
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/hmac"
|
||||
"crypto/rand"
|
||||
"errors"
|
||||
"fmt"
|
||||
"hash"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/ethereum/go-ethereum/crypto"
|
||||
"github.com/ethereum/go-ethereum/crypto/ecies"
|
||||
"github.com/ethereum/go-ethereum/crypto/secp256k1"
|
||||
"github.com/ethereum/go-ethereum/crypto/sha3"
|
||||
"github.com/ethereum/go-ethereum/p2p/discover"
|
||||
"github.com/ethereum/go-ethereum/rlp"
|
||||
)
|
||||
|
||||
const (
|
||||
maxUint24 = ^uint32(0) >> 8
|
||||
|
||||
sskLen = 16 // ecies.MaxSharedKeyLength(pubKey) / 2
|
||||
sigLen = 65 // elliptic S256
|
||||
pubLen = 64 // 512 bit pubkey in uncompressed representation without format byte
|
||||
shaLen = 32 // hash length (for nonce etc)
|
||||
|
||||
authMsgLen = sigLen + shaLen + pubLen + shaLen + 1
|
||||
authRespLen = pubLen + shaLen + 1
|
||||
|
||||
eciesBytes = 65 + 16 + 32
|
||||
encAuthMsgLen = authMsgLen + eciesBytes // size of the final ECIES payload sent as initiator's handshake
|
||||
encAuthRespLen = authRespLen + eciesBytes // size of the final ECIES payload sent as receiver's handshake
|
||||
|
||||
// total timeout for encryption handshake and protocol
|
||||
// handshake in both directions.
|
||||
handshakeTimeout = 5 * time.Second
|
||||
|
||||
// This is the timeout for sending the disconnect reason.
|
||||
// This is shorter than the usual timeout because we don't want
|
||||
// to wait if the connection is known to be bad anyway.
|
||||
discWriteTimeout = 1 * time.Second
|
||||
)
|
||||
|
||||
// rlpx is the transport protocol used by actual (non-test) connections.
|
||||
// It wraps the frame encoder with locks and read/write deadlines.
|
||||
type rlpx struct {
|
||||
fd net.Conn
|
||||
|
||||
rmu, wmu sync.Mutex
|
||||
rw *rlpxFrameRW
|
||||
}
|
||||
|
||||
func newRLPX(fd net.Conn) transport {
|
||||
fd.SetDeadline(time.Now().Add(handshakeTimeout))
|
||||
return &rlpx{fd: fd}
|
||||
}
|
||||
|
||||
func (t *rlpx) ReadMsg() (Msg, error) {
|
||||
t.rmu.Lock()
|
||||
defer t.rmu.Unlock()
|
||||
t.fd.SetReadDeadline(time.Now().Add(frameReadTimeout))
|
||||
return t.rw.ReadMsg()
|
||||
}
|
||||
|
||||
func (t *rlpx) WriteMsg(msg Msg) error {
|
||||
t.wmu.Lock()
|
||||
defer t.wmu.Unlock()
|
||||
t.fd.SetWriteDeadline(time.Now().Add(frameWriteTimeout))
|
||||
return t.rw.WriteMsg(msg)
|
||||
}
|
||||
|
||||
func (t *rlpx) close(err error) {
|
||||
t.wmu.Lock()
|
||||
defer t.wmu.Unlock()
|
||||
// Tell the remote end why we're disconnecting if possible.
|
||||
if t.rw != nil {
|
||||
if r, ok := err.(DiscReason); ok && r != DiscNetworkError {
|
||||
t.fd.SetWriteDeadline(time.Now().Add(discWriteTimeout))
|
||||
SendItems(t.rw, discMsg, r)
|
||||
}
|
||||
}
|
||||
t.fd.Close()
|
||||
}
|
||||
|
||||
// doEncHandshake runs the protocol handshake using authenticated
|
||||
// messages. the protocol handshake is the first authenticated message
|
||||
// and also verifies whether the encryption handshake 'worked' and the
|
||||
// remote side actually provided the right public key.
|
||||
func (t *rlpx) doProtoHandshake(our *protoHandshake) (their *protoHandshake, err error) {
|
||||
// Writing our handshake happens concurrently, we prefer
|
||||
// returning the handshake read error. If the remote side
|
||||
// disconnects us early with a valid reason, we should return it
|
||||
// as the error so it can be tracked elsewhere.
|
||||
werr := make(chan error, 1)
|
||||
go func() { werr <- Send(t.rw, handshakeMsg, our) }()
|
||||
if their, err = readProtocolHandshake(t.rw, our); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := <-werr; err != nil {
|
||||
return nil, fmt.Errorf("write error: %v", err)
|
||||
}
|
||||
return their, nil
|
||||
}
|
||||
|
||||
func readProtocolHandshake(rw MsgReader, our *protoHandshake) (*protoHandshake, error) {
|
||||
msg, err := rw.ReadMsg()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if msg.Size > baseProtocolMaxMsgSize {
|
||||
return nil, fmt.Errorf("message too big")
|
||||
}
|
||||
if msg.Code == discMsg {
|
||||
// Disconnect before protocol handshake is valid according to the
|
||||
// spec and we send it ourself if the posthanshake checks fail.
|
||||
// We can't return the reason directly, though, because it is echoed
|
||||
// back otherwise. Wrap it in a string instead.
|
||||
var reason [1]DiscReason
|
||||
rlp.Decode(msg.Payload, &reason)
|
||||
return nil, reason[0]
|
||||
}
|
||||
if msg.Code != handshakeMsg {
|
||||
return nil, fmt.Errorf("expected handshake, got %x", msg.Code)
|
||||
}
|
||||
var hs protoHandshake
|
||||
if err := msg.Decode(&hs); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// validate handshake info
|
||||
if hs.Version != our.Version {
|
||||
return nil, DiscIncompatibleVersion
|
||||
}
|
||||
if (hs.ID == discover.NodeID{}) {
|
||||
return nil, DiscInvalidIdentity
|
||||
}
|
||||
return &hs, nil
|
||||
}
|
||||
|
||||
func (t *rlpx) doEncHandshake(prv *ecdsa.PrivateKey, dial *discover.Node) (discover.NodeID, error) {
|
||||
var (
|
||||
sec secrets
|
||||
err error
|
||||
)
|
||||
if dial == nil {
|
||||
sec, err = receiverEncHandshake(t.fd, prv, nil)
|
||||
} else {
|
||||
sec, err = initiatorEncHandshake(t.fd, prv, dial.ID, nil)
|
||||
}
|
||||
if err != nil {
|
||||
return discover.NodeID{}, err
|
||||
}
|
||||
t.wmu.Lock()
|
||||
t.rw = newRLPXFrameRW(t.fd, sec)
|
||||
t.wmu.Unlock()
|
||||
return sec.RemoteID, nil
|
||||
}
|
||||
|
||||
// encHandshake contains the state of the encryption handshake.
|
||||
type encHandshake struct {
|
||||
initiator bool
|
||||
remoteID discover.NodeID
|
||||
|
||||
remotePub *ecies.PublicKey // remote-pubk
|
||||
initNonce, respNonce []byte // nonce
|
||||
randomPrivKey *ecies.PrivateKey // ecdhe-random
|
||||
remoteRandomPub *ecies.PublicKey // ecdhe-random-pubk
|
||||
}
|
||||
|
||||
// secrets represents the connection secrets
|
||||
// which are negotiated during the encryption handshake.
|
||||
type secrets struct {
|
||||
RemoteID discover.NodeID
|
||||
AES, MAC []byte
|
||||
EgressMAC, IngressMAC hash.Hash
|
||||
Token []byte
|
||||
}
|
||||
|
||||
// secrets is called after the handshake is completed.
|
||||
// It extracts the connection secrets from the handshake values.
|
||||
func (h *encHandshake) secrets(auth, authResp []byte) (secrets, error) {
|
||||
ecdheSecret, err := h.randomPrivKey.GenerateShared(h.remoteRandomPub, sskLen, sskLen)
|
||||
if err != nil {
|
||||
return secrets{}, err
|
||||
}
|
||||
|
||||
// derive base secrets from ephemeral key agreement
|
||||
sharedSecret := crypto.Sha3(ecdheSecret, crypto.Sha3(h.respNonce, h.initNonce))
|
||||
aesSecret := crypto.Sha3(ecdheSecret, sharedSecret)
|
||||
s := secrets{
|
||||
RemoteID: h.remoteID,
|
||||
AES: aesSecret,
|
||||
MAC: crypto.Sha3(ecdheSecret, aesSecret),
|
||||
Token: crypto.Sha3(sharedSecret),
|
||||
}
|
||||
|
||||
// setup sha3 instances for the MACs
|
||||
mac1 := sha3.NewKeccak256()
|
||||
mac1.Write(xor(s.MAC, h.respNonce))
|
||||
mac1.Write(auth)
|
||||
mac2 := sha3.NewKeccak256()
|
||||
mac2.Write(xor(s.MAC, h.initNonce))
|
||||
mac2.Write(authResp)
|
||||
if h.initiator {
|
||||
s.EgressMAC, s.IngressMAC = mac1, mac2
|
||||
} else {
|
||||
s.EgressMAC, s.IngressMAC = mac2, mac1
|
||||
}
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (h *encHandshake) ecdhShared(prv *ecdsa.PrivateKey) ([]byte, error) {
|
||||
return ecies.ImportECDSA(prv).GenerateShared(h.remotePub, sskLen, sskLen)
|
||||
}
|
||||
|
||||
// initiatorEncHandshake negotiates a session token on conn.
|
||||
// it should be called on the dialing side of the connection.
|
||||
//
|
||||
// prv is the local client's private key.
|
||||
// token is the token from a previous session with this node.
|
||||
func initiatorEncHandshake(conn io.ReadWriter, prv *ecdsa.PrivateKey, remoteID discover.NodeID, token []byte) (s secrets, err error) {
|
||||
h, err := newInitiatorHandshake(remoteID)
|
||||
if err != nil {
|
||||
return s, err
|
||||
}
|
||||
auth, err := h.authMsg(prv, token)
|
||||
if err != nil {
|
||||
return s, err
|
||||
}
|
||||
if _, err = conn.Write(auth); err != nil {
|
||||
return s, err
|
||||
}
|
||||
|
||||
response := make([]byte, encAuthRespLen)
|
||||
if _, err = io.ReadFull(conn, response); err != nil {
|
||||
return s, err
|
||||
}
|
||||
if err := h.decodeAuthResp(response, prv); err != nil {
|
||||
return s, err
|
||||
}
|
||||
return h.secrets(auth, response)
|
||||
}
|
||||
|
||||
func newInitiatorHandshake(remoteID discover.NodeID) (*encHandshake, error) {
|
||||
// generate random initiator nonce
|
||||
n := make([]byte, shaLen)
|
||||
if _, err := rand.Read(n); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// generate random keypair to use for signing
|
||||
randpriv, err := ecies.GenerateKey(rand.Reader, crypto.S256(), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rpub, err := remoteID.Pubkey()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("bad remoteID: %v", err)
|
||||
}
|
||||
h := &encHandshake{
|
||||
initiator: true,
|
||||
remoteID: remoteID,
|
||||
remotePub: ecies.ImportECDSAPublic(rpub),
|
||||
initNonce: n,
|
||||
randomPrivKey: randpriv,
|
||||
}
|
||||
return h, nil
|
||||
}
|
||||
|
||||
// authMsg creates an encrypted initiator handshake message.
|
||||
func (h *encHandshake) authMsg(prv *ecdsa.PrivateKey, token []byte) ([]byte, error) {
|
||||
var tokenFlag byte
|
||||
if token == nil {
|
||||
// no session token found means we need to generate shared secret.
|
||||
// ecies shared secret is used as initial session token for new peers
|
||||
// generate shared key from prv and remote pubkey
|
||||
var err error
|
||||
if token, err = h.ecdhShared(prv); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
// for known peers, we use stored token from the previous session
|
||||
tokenFlag = 0x01
|
||||
}
|
||||
|
||||
// sign known message:
|
||||
// ecdh-shared-secret^nonce for new peers
|
||||
// token^nonce for old peers
|
||||
signed := xor(token, h.initNonce)
|
||||
signature, err := crypto.Sign(signed, h.randomPrivKey.ExportECDSA())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// encode auth message
|
||||
// signature || sha3(ecdhe-random-pubk) || pubk || nonce || token-flag
|
||||
msg := make([]byte, authMsgLen)
|
||||
n := copy(msg, signature)
|
||||
n += copy(msg[n:], crypto.Sha3(exportPubkey(&h.randomPrivKey.PublicKey)))
|
||||
n += copy(msg[n:], crypto.FromECDSAPub(&prv.PublicKey)[1:])
|
||||
n += copy(msg[n:], h.initNonce)
|
||||
msg[n] = tokenFlag
|
||||
|
||||
// encrypt auth message using remote-pubk
|
||||
return ecies.Encrypt(rand.Reader, h.remotePub, msg, nil, nil)
|
||||
}
|
||||
|
||||
// decodeAuthResp decode an encrypted authentication response message.
|
||||
func (h *encHandshake) decodeAuthResp(auth []byte, prv *ecdsa.PrivateKey) error {
|
||||
msg, err := crypto.Decrypt(prv, auth)
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not decrypt auth response (%v)", err)
|
||||
}
|
||||
h.respNonce = msg[pubLen : pubLen+shaLen]
|
||||
h.remoteRandomPub, err = importPublicKey(msg[:pubLen])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// ignore token flag for now
|
||||
return nil
|
||||
}
|
||||
|
||||
// receiverEncHandshake negotiates a session token on conn.
|
||||
// it should be called on the listening side of the connection.
|
||||
//
|
||||
// prv is the local client's private key.
|
||||
// token is the token from a previous session with this node.
|
||||
func receiverEncHandshake(conn io.ReadWriter, prv *ecdsa.PrivateKey, token []byte) (s secrets, err error) {
|
||||
// read remote auth sent by initiator.
|
||||
auth := make([]byte, encAuthMsgLen)
|
||||
if _, err := io.ReadFull(conn, auth); err != nil {
|
||||
return s, err
|
||||
}
|
||||
h, err := decodeAuthMsg(prv, token, auth)
|
||||
if err != nil {
|
||||
return s, err
|
||||
}
|
||||
|
||||
// send auth response
|
||||
resp, err := h.authResp(prv, token)
|
||||
if err != nil {
|
||||
return s, err
|
||||
}
|
||||
if _, err = conn.Write(resp); err != nil {
|
||||
return s, err
|
||||
}
|
||||
|
||||
return h.secrets(auth, resp)
|
||||
}
|
||||
|
||||
func decodeAuthMsg(prv *ecdsa.PrivateKey, token []byte, auth []byte) (*encHandshake, error) {
|
||||
var err error
|
||||
h := new(encHandshake)
|
||||
// generate random keypair for session
|
||||
h.randomPrivKey, err = ecies.GenerateKey(rand.Reader, crypto.S256(), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// generate random nonce
|
||||
h.respNonce = make([]byte, shaLen)
|
||||
if _, err = rand.Read(h.respNonce); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
msg, err := crypto.Decrypt(prv, auth)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not decrypt auth message (%v)", err)
|
||||
}
|
||||
|
||||
// decode message parameters
|
||||
// signature || sha3(ecdhe-random-pubk) || pubk || nonce || token-flag
|
||||
h.initNonce = msg[authMsgLen-shaLen-1 : authMsgLen-1]
|
||||
copy(h.remoteID[:], msg[sigLen+shaLen:sigLen+shaLen+pubLen])
|
||||
rpub, err := h.remoteID.Pubkey()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("bad remoteID: %#v", err)
|
||||
}
|
||||
h.remotePub = ecies.ImportECDSAPublic(rpub)
|
||||
|
||||
// recover remote random pubkey from signed message.
|
||||
if token == nil {
|
||||
// TODO: it is an error if the initiator has a token and we don't. check that.
|
||||
|
||||
// no session token means we need to generate shared secret.
|
||||
// ecies shared secret is used as initial session token for new peers.
|
||||
// generate shared key from prv and remote pubkey.
|
||||
if token, err = h.ecdhShared(prv); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
signedMsg := xor(token, h.initNonce)
|
||||
remoteRandomPub, err := secp256k1.RecoverPubkey(signedMsg, msg[:sigLen])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
h.remoteRandomPub, _ = importPublicKey(remoteRandomPub)
|
||||
return h, nil
|
||||
}
|
||||
|
||||
// authResp generates the encrypted authentication response message.
|
||||
func (h *encHandshake) authResp(prv *ecdsa.PrivateKey, token []byte) ([]byte, error) {
|
||||
// responder auth message
|
||||
// E(remote-pubk, ecdhe-random-pubk || nonce || 0x0)
|
||||
resp := make([]byte, authRespLen)
|
||||
n := copy(resp, exportPubkey(&h.randomPrivKey.PublicKey))
|
||||
n += copy(resp[n:], h.respNonce)
|
||||
if token == nil {
|
||||
resp[n] = 0
|
||||
} else {
|
||||
resp[n] = 1
|
||||
}
|
||||
// encrypt using remote-pubk
|
||||
return ecies.Encrypt(rand.Reader, h.remotePub, resp, nil, nil)
|
||||
}
|
||||
|
||||
// importPublicKey unmarshals 512 bit public keys.
|
||||
func importPublicKey(pubKey []byte) (*ecies.PublicKey, error) {
|
||||
var pubKey65 []byte
|
||||
switch len(pubKey) {
|
||||
case 64:
|
||||
// add 'uncompressed key' flag
|
||||
pubKey65 = append([]byte{0x04}, pubKey...)
|
||||
case 65:
|
||||
pubKey65 = pubKey
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid public key length %v (expect 64/65)", len(pubKey))
|
||||
}
|
||||
// TODO: fewer pointless conversions
|
||||
return ecies.ImportECDSAPublic(crypto.ToECDSAPub(pubKey65)), nil
|
||||
}
|
||||
|
||||
func exportPubkey(pub *ecies.PublicKey) []byte {
|
||||
if pub == nil {
|
||||
panic("nil pubkey")
|
||||
}
|
||||
return elliptic.Marshal(pub.Curve, pub.X, pub.Y)[1:]
|
||||
}
|
||||
|
||||
func xor(one, other []byte) (xor []byte) {
|
||||
xor = make([]byte, len(one))
|
||||
for i := 0; i < len(one); i++ {
|
||||
xor[i] = one[i] ^ other[i]
|
||||
}
|
||||
return xor
|
||||
}
|
||||
|
||||
var (
|
||||
// this is used in place of actual frame header data.
|
||||
// TODO: replace this when Msg contains the protocol type code.
|
||||
zeroHeader = []byte{0xC2, 0x80, 0x80}
|
||||
|
||||
// sixteen zero bytes
|
||||
zero16 = make([]byte, 16)
|
||||
|
||||
maxUint24 = ^uint32(0) >> 8
|
||||
)
|
||||
|
||||
// rlpxFrameRW implements a simplified version of RLPx framing.
|
||||
|
@ -38,7 +474,7 @@ type rlpxFrameRW struct {
|
|||
ingressMAC hash.Hash
|
||||
}
|
||||
|
||||
func newRlpxFrameRW(conn io.ReadWriter, s secrets) *rlpxFrameRW {
|
||||
func newRLPXFrameRW(conn io.ReadWriter, s secrets) *rlpxFrameRW {
|
||||
macc, err := aes.NewCipher(s.MAC)
|
||||
if err != nil {
|
||||
panic("invalid MAC secret: " + err.Error())
|
||||
|
|
244
p2p/rlpx_test.go
244
p2p/rlpx_test.go
|
@ -3,19 +3,253 @@ package p2p
|
|||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"reflect"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/davecgh/go-spew/spew"
|
||||
"github.com/ethereum/go-ethereum/crypto"
|
||||
"github.com/ethereum/go-ethereum/crypto/ecies"
|
||||
"github.com/ethereum/go-ethereum/crypto/sha3"
|
||||
"github.com/ethereum/go-ethereum/p2p/discover"
|
||||
"github.com/ethereum/go-ethereum/rlp"
|
||||
)
|
||||
|
||||
func TestRlpxFrameFake(t *testing.T) {
|
||||
func TestSharedSecret(t *testing.T) {
|
||||
prv0, _ := crypto.GenerateKey() // = ecdsa.GenerateKey(crypto.S256(), rand.Reader)
|
||||
pub0 := &prv0.PublicKey
|
||||
prv1, _ := crypto.GenerateKey()
|
||||
pub1 := &prv1.PublicKey
|
||||
|
||||
ss0, err := ecies.ImportECDSA(prv0).GenerateShared(ecies.ImportECDSAPublic(pub1), sskLen, sskLen)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
ss1, err := ecies.ImportECDSA(prv1).GenerateShared(ecies.ImportECDSAPublic(pub0), sskLen, sskLen)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
t.Logf("Secret:\n%v %x\n%v %x", len(ss0), ss0, len(ss0), ss1)
|
||||
if !bytes.Equal(ss0, ss1) {
|
||||
t.Errorf("dont match :(")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncHandshake(t *testing.T) {
|
||||
for i := 0; i < 10; i++ {
|
||||
start := time.Now()
|
||||
if err := testEncHandshake(nil); err != nil {
|
||||
t.Fatalf("i=%d %v", i, err)
|
||||
}
|
||||
t.Logf("(without token) %d %v\n", i+1, time.Since(start))
|
||||
}
|
||||
for i := 0; i < 10; i++ {
|
||||
tok := make([]byte, shaLen)
|
||||
rand.Reader.Read(tok)
|
||||
start := time.Now()
|
||||
if err := testEncHandshake(tok); err != nil {
|
||||
t.Fatalf("i=%d %v", i, err)
|
||||
}
|
||||
t.Logf("(with token) %d %v\n", i+1, time.Since(start))
|
||||
}
|
||||
}
|
||||
|
||||
func testEncHandshake(token []byte) error {
|
||||
type result struct {
|
||||
side string
|
||||
id discover.NodeID
|
||||
err error
|
||||
}
|
||||
var (
|
||||
prv0, _ = crypto.GenerateKey()
|
||||
prv1, _ = crypto.GenerateKey()
|
||||
fd0, fd1 = net.Pipe()
|
||||
c0, c1 = newRLPX(fd0).(*rlpx), newRLPX(fd1).(*rlpx)
|
||||
output = make(chan result)
|
||||
)
|
||||
|
||||
go func() {
|
||||
r := result{side: "initiator"}
|
||||
defer func() { output <- r }()
|
||||
|
||||
dest := &discover.Node{ID: discover.PubkeyID(&prv1.PublicKey)}
|
||||
r.id, r.err = c0.doEncHandshake(prv0, dest)
|
||||
if r.err != nil {
|
||||
return
|
||||
}
|
||||
id1 := discover.PubkeyID(&prv1.PublicKey)
|
||||
if r.id != id1 {
|
||||
r.err = fmt.Errorf("remote ID mismatch: got %v, want: %v", r.id, id1)
|
||||
}
|
||||
}()
|
||||
go func() {
|
||||
r := result{side: "receiver"}
|
||||
defer func() { output <- r }()
|
||||
|
||||
r.id, r.err = c1.doEncHandshake(prv1, nil)
|
||||
if r.err != nil {
|
||||
return
|
||||
}
|
||||
id0 := discover.PubkeyID(&prv0.PublicKey)
|
||||
if r.id != id0 {
|
||||
r.err = fmt.Errorf("remote ID mismatch: got %v, want: %v", r.id, id0)
|
||||
}
|
||||
}()
|
||||
|
||||
// wait for results from both sides
|
||||
r1, r2 := <-output, <-output
|
||||
if r1.err != nil {
|
||||
return fmt.Errorf("%s side error: %v", r1.side, r1.err)
|
||||
}
|
||||
if r2.err != nil {
|
||||
return fmt.Errorf("%s side error: %v", r2.side, r2.err)
|
||||
}
|
||||
|
||||
// compare derived secrets
|
||||
if !reflect.DeepEqual(c0.rw.egressMAC, c1.rw.ingressMAC) {
|
||||
return fmt.Errorf("egress mac mismatch:\n c0.rw: %#v\n c1.rw: %#v", c0.rw.egressMAC, c1.rw.ingressMAC)
|
||||
}
|
||||
if !reflect.DeepEqual(c0.rw.ingressMAC, c1.rw.egressMAC) {
|
||||
return fmt.Errorf("ingress mac mismatch:\n c0.rw: %#v\n c1.rw: %#v", c0.rw.ingressMAC, c1.rw.egressMAC)
|
||||
}
|
||||
if !reflect.DeepEqual(c0.rw.enc, c1.rw.enc) {
|
||||
return fmt.Errorf("enc cipher mismatch:\n c0.rw: %#v\n c1.rw: %#v", c0.rw.enc, c1.rw.enc)
|
||||
}
|
||||
if !reflect.DeepEqual(c0.rw.dec, c1.rw.dec) {
|
||||
return fmt.Errorf("dec cipher mismatch:\n c0.rw: %#v\n c1.rw: %#v", c0.rw.dec, c1.rw.dec)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestProtocolHandshake(t *testing.T) {
|
||||
var (
|
||||
prv0, _ = crypto.GenerateKey()
|
||||
node0 = &discover.Node{ID: discover.PubkeyID(&prv0.PublicKey), IP: net.IP{1, 2, 3, 4}, TCP: 33}
|
||||
hs0 = &protoHandshake{Version: 3, ID: node0.ID, Caps: []Cap{{"a", 0}, {"b", 2}}}
|
||||
|
||||
prv1, _ = crypto.GenerateKey()
|
||||
node1 = &discover.Node{ID: discover.PubkeyID(&prv1.PublicKey), IP: net.IP{5, 6, 7, 8}, TCP: 44}
|
||||
hs1 = &protoHandshake{Version: 3, ID: node1.ID, Caps: []Cap{{"c", 1}, {"d", 3}}}
|
||||
|
||||
fd0, fd1 = net.Pipe()
|
||||
wg sync.WaitGroup
|
||||
)
|
||||
|
||||
wg.Add(2)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
rlpx := newRLPX(fd0)
|
||||
remid, err := rlpx.doEncHandshake(prv0, node1)
|
||||
if err != nil {
|
||||
t.Errorf("dial side enc handshake failed: %v", err)
|
||||
return
|
||||
}
|
||||
if remid != node1.ID {
|
||||
t.Errorf("dial side remote id mismatch: got %v, want %v", remid, node1.ID)
|
||||
return
|
||||
}
|
||||
|
||||
phs, err := rlpx.doProtoHandshake(hs0)
|
||||
if err != nil {
|
||||
t.Errorf("dial side proto handshake error: %v", err)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(phs, hs1) {
|
||||
t.Errorf("dial side proto handshake mismatch:\ngot: %s\nwant: %s\n", spew.Sdump(phs), spew.Sdump(hs1))
|
||||
return
|
||||
}
|
||||
rlpx.close(DiscQuitting)
|
||||
}()
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
rlpx := newRLPX(fd1)
|
||||
remid, err := rlpx.doEncHandshake(prv1, nil)
|
||||
if err != nil {
|
||||
t.Errorf("listen side enc handshake failed: %v", err)
|
||||
return
|
||||
}
|
||||
if remid != node0.ID {
|
||||
t.Errorf("listen side remote id mismatch: got %v, want %v", remid, node0.ID)
|
||||
return
|
||||
}
|
||||
|
||||
phs, err := rlpx.doProtoHandshake(hs1)
|
||||
if err != nil {
|
||||
t.Errorf("listen side proto handshake error: %v", err)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(phs, hs0) {
|
||||
t.Errorf("listen side proto handshake mismatch:\ngot: %s\nwant: %s\n", spew.Sdump(phs), spew.Sdump(hs0))
|
||||
return
|
||||
}
|
||||
|
||||
if err := ExpectMsg(rlpx, discMsg, []DiscReason{DiscQuitting}); err != nil {
|
||||
t.Errorf("error receiving disconnect: %v", err)
|
||||
}
|
||||
}()
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestProtocolHandshakeErrors(t *testing.T) {
|
||||
our := &protoHandshake{Version: 3, Caps: []Cap{{"foo", 2}, {"bar", 3}}, Name: "quux"}
|
||||
id := randomID()
|
||||
tests := []struct {
|
||||
code uint64
|
||||
msg interface{}
|
||||
err error
|
||||
}{
|
||||
{
|
||||
code: discMsg,
|
||||
msg: []DiscReason{DiscQuitting},
|
||||
err: DiscQuitting,
|
||||
},
|
||||
{
|
||||
code: 0x989898,
|
||||
msg: []byte{1},
|
||||
err: errors.New("expected handshake, got 989898"),
|
||||
},
|
||||
{
|
||||
code: handshakeMsg,
|
||||
msg: make([]byte, baseProtocolMaxMsgSize+2),
|
||||
err: errors.New("message too big"),
|
||||
},
|
||||
{
|
||||
code: handshakeMsg,
|
||||
msg: []byte{1, 2, 3},
|
||||
err: newPeerError(errInvalidMsg, "(code 0) (size 4) rlp: expected input list for p2p.protoHandshake"),
|
||||
},
|
||||
{
|
||||
code: handshakeMsg,
|
||||
msg: &protoHandshake{Version: 9944, ID: id},
|
||||
err: DiscIncompatibleVersion,
|
||||
},
|
||||
{
|
||||
code: handshakeMsg,
|
||||
msg: &protoHandshake{Version: 3},
|
||||
err: DiscInvalidIdentity,
|
||||
},
|
||||
}
|
||||
|
||||
for i, test := range tests {
|
||||
p1, p2 := MsgPipe()
|
||||
go Send(p1, test.code, test.msg)
|
||||
_, err := readProtocolHandshake(p2, our)
|
||||
if !reflect.DeepEqual(err, test.err) {
|
||||
t.Errorf("test %d: error mismatch: got %q, want %q", i, err, test.err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRLPXFrameFake(t *testing.T) {
|
||||
buf := new(bytes.Buffer)
|
||||
hash := fakeHash([]byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1})
|
||||
rw := newRlpxFrameRW(buf, secrets{
|
||||
rw := newRLPXFrameRW(buf, secrets{
|
||||
AES: crypto.Sha3(),
|
||||
MAC: crypto.Sha3(),
|
||||
IngressMAC: hash,
|
||||
|
@ -66,7 +300,7 @@ func (fakeHash) BlockSize() int { return 0 }
|
|||
func (h fakeHash) Size() int { return len(h) }
|
||||
func (h fakeHash) Sum(b []byte) []byte { return append(b, h...) }
|
||||
|
||||
func TestRlpxFrameRW(t *testing.T) {
|
||||
func TestRLPXFrameRW(t *testing.T) {
|
||||
var (
|
||||
aesSecret = make([]byte, 16)
|
||||
macSecret = make([]byte, 16)
|
||||
|
@ -86,7 +320,7 @@ func TestRlpxFrameRW(t *testing.T) {
|
|||
}
|
||||
s1.EgressMAC.Write(egressMACinit)
|
||||
s1.IngressMAC.Write(ingressMACinit)
|
||||
rw1 := newRlpxFrameRW(conn, s1)
|
||||
rw1 := newRLPXFrameRW(conn, s1)
|
||||
|
||||
s2 := secrets{
|
||||
AES: aesSecret,
|
||||
|
@ -96,7 +330,7 @@ func TestRlpxFrameRW(t *testing.T) {
|
|||
}
|
||||
s2.EgressMAC.Write(ingressMACinit)
|
||||
s2.IngressMAC.Write(egressMACinit)
|
||||
rw2 := newRlpxFrameRW(conn, s2)
|
||||
rw2 := newRLPXFrameRW(conn, s2)
|
||||
|
||||
// send some messages
|
||||
for i := 0; i < 10; i++ {
|
||||
|
|
718
p2p/server.go
718
p2p/server.go
|
@ -1,9 +1,7 @@
|
|||
package p2p
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/ecdsa"
|
||||
"crypto/rand"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
|
@ -14,7 +12,6 @@ import (
|
|||
"github.com/ethereum/go-ethereum/logger/glog"
|
||||
"github.com/ethereum/go-ethereum/p2p/discover"
|
||||
"github.com/ethereum/go-ethereum/p2p/nat"
|
||||
"github.com/ethereum/go-ethereum/rlp"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -26,18 +23,18 @@ const (
|
|||
maxAcceptConns = 50
|
||||
|
||||
// Maximum number of concurrently dialing outbound connections.
|
||||
maxDialingConns = 10
|
||||
maxActiveDialTasks = 16
|
||||
|
||||
// total timeout for encryption handshake and protocol
|
||||
// handshake in both directions.
|
||||
handshakeTimeout = 5 * time.Second
|
||||
// maximum time allowed for reading a complete message.
|
||||
// this is effectively the amount of time a connection can be idle.
|
||||
frameReadTimeout = 1 * time.Minute
|
||||
// maximum amount of time allowed for writing a complete message.
|
||||
// Maximum time allowed for reading a complete message.
|
||||
// This is effectively the amount of time a connection can be idle.
|
||||
frameReadTimeout = 30 * time.Second
|
||||
|
||||
// Maximum amount of time allowed for writing a complete message.
|
||||
frameWriteTimeout = 5 * time.Second
|
||||
)
|
||||
|
||||
var errServerStopped = errors.New("server stopped")
|
||||
|
||||
var srvjslog = logger.NewJsonLogger()
|
||||
|
||||
// Server manages all peer connections.
|
||||
|
@ -105,107 +102,173 @@ type Server struct {
|
|||
|
||||
// Hooks for testing. These are useful because we can inhibit
|
||||
// the whole protocol stack.
|
||||
setupFunc
|
||||
newPeerHook
|
||||
newTransport func(net.Conn) transport
|
||||
newPeerHook func(*Peer)
|
||||
|
||||
lock sync.Mutex // protects running
|
||||
running bool
|
||||
|
||||
ntab discoverTable
|
||||
listener net.Listener
|
||||
ourHandshake *protoHandshake
|
||||
|
||||
lock sync.RWMutex // protects running, peers and the trust fields
|
||||
running bool
|
||||
peers map[discover.NodeID]*Peer
|
||||
staticNodes map[discover.NodeID]*discover.Node // Map of currently maintained static remote nodes
|
||||
staticDial chan *discover.Node // Dial request channel reserved for the static nodes
|
||||
staticCycle time.Duration // Overrides staticPeerCheckInterval, used for testing
|
||||
trustedNodes map[discover.NodeID]bool // Set of currently trusted remote nodes
|
||||
// These are for Peers, PeerCount (and nothing else).
|
||||
peerOp chan peerOpFunc
|
||||
peerOpDone chan struct{}
|
||||
|
||||
ntab *discover.Table
|
||||
listener net.Listener
|
||||
|
||||
quit chan struct{}
|
||||
loopWG sync.WaitGroup // {dial,listen,nat}Loop
|
||||
peerWG sync.WaitGroup // active peer goroutines
|
||||
quit chan struct{}
|
||||
addstatic chan *discover.Node
|
||||
posthandshake chan *conn
|
||||
addpeer chan *conn
|
||||
delpeer chan *Peer
|
||||
loopWG sync.WaitGroup // loop, listenLoop
|
||||
}
|
||||
|
||||
type setupFunc func(net.Conn, *ecdsa.PrivateKey, *protoHandshake, *discover.Node, func(discover.NodeID) bool) (*conn, error)
|
||||
type newPeerHook func(*Peer)
|
||||
type peerOpFunc func(map[discover.NodeID]*Peer)
|
||||
|
||||
type connFlag int
|
||||
|
||||
const (
|
||||
dynDialedConn connFlag = 1 << iota
|
||||
staticDialedConn
|
||||
inboundConn
|
||||
trustedConn
|
||||
)
|
||||
|
||||
// conn wraps a network connection with information gathered
|
||||
// during the two handshakes.
|
||||
type conn struct {
|
||||
fd net.Conn
|
||||
transport
|
||||
flags connFlag
|
||||
cont chan error // The run loop uses cont to signal errors to setupConn.
|
||||
id discover.NodeID // valid after the encryption handshake
|
||||
caps []Cap // valid after the protocol handshake
|
||||
name string // valid after the protocol handshake
|
||||
}
|
||||
|
||||
type transport interface {
|
||||
// The two handshakes.
|
||||
doEncHandshake(prv *ecdsa.PrivateKey, dialDest *discover.Node) (discover.NodeID, error)
|
||||
doProtoHandshake(our *protoHandshake) (*protoHandshake, error)
|
||||
// The MsgReadWriter can only be used after the encryption
|
||||
// handshake has completed. The code uses conn.id to track this
|
||||
// by setting it to a non-nil value after the encryption handshake.
|
||||
MsgReadWriter
|
||||
// transports must provide Close because we use MsgPipe in some of
|
||||
// the tests. Closing the actual network connection doesn't do
|
||||
// anything in those tests because NsgPipe doesn't use it.
|
||||
close(err error)
|
||||
}
|
||||
|
||||
func (c *conn) String() string {
|
||||
s := c.flags.String() + " conn"
|
||||
if (c.id != discover.NodeID{}) {
|
||||
s += fmt.Sprintf(" %x", c.id[:8])
|
||||
}
|
||||
s += " " + c.fd.RemoteAddr().String()
|
||||
return s
|
||||
}
|
||||
|
||||
func (f connFlag) String() string {
|
||||
s := ""
|
||||
if f&trustedConn != 0 {
|
||||
s += " trusted"
|
||||
}
|
||||
if f&dynDialedConn != 0 {
|
||||
s += " dyn dial"
|
||||
}
|
||||
if f&staticDialedConn != 0 {
|
||||
s += " static dial"
|
||||
}
|
||||
if f&inboundConn != 0 {
|
||||
s += " inbound"
|
||||
}
|
||||
if s != "" {
|
||||
s = s[1:]
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func (c *conn) is(f connFlag) bool {
|
||||
return c.flags&f != 0
|
||||
}
|
||||
|
||||
// Peers returns all connected peers.
|
||||
func (srv *Server) Peers() (peers []*Peer) {
|
||||
srv.lock.RLock()
|
||||
defer srv.lock.RUnlock()
|
||||
for _, peer := range srv.peers {
|
||||
if peer != nil {
|
||||
peers = append(peers, peer)
|
||||
func (srv *Server) Peers() []*Peer {
|
||||
var ps []*Peer
|
||||
select {
|
||||
// Note: We'd love to put this function into a variable but
|
||||
// that seems to cause a weird compiler error in some
|
||||
// environments.
|
||||
case srv.peerOp <- func(peers map[discover.NodeID]*Peer) {
|
||||
for _, p := range peers {
|
||||
ps = append(ps, p)
|
||||
}
|
||||
}:
|
||||
<-srv.peerOpDone
|
||||
case <-srv.quit:
|
||||
}
|
||||
return
|
||||
return ps
|
||||
}
|
||||
|
||||
// PeerCount returns the number of connected peers.
|
||||
func (srv *Server) PeerCount() int {
|
||||
srv.lock.RLock()
|
||||
n := len(srv.peers)
|
||||
srv.lock.RUnlock()
|
||||
return n
|
||||
var count int
|
||||
select {
|
||||
case srv.peerOp <- func(ps map[discover.NodeID]*Peer) { count = len(ps) }:
|
||||
<-srv.peerOpDone
|
||||
case <-srv.quit:
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
// AddPeer connects to the given node and maintains the connection until the
|
||||
// server is shut down. If the connection fails for any reason, the server will
|
||||
// attempt to reconnect the peer.
|
||||
func (srv *Server) AddPeer(node *discover.Node) {
|
||||
select {
|
||||
case srv.addstatic <- node:
|
||||
case <-srv.quit:
|
||||
}
|
||||
}
|
||||
|
||||
// Self returns the local node's endpoint information.
|
||||
func (srv *Server) Self() *discover.Node {
|
||||
srv.lock.Lock()
|
||||
defer srv.lock.Unlock()
|
||||
|
||||
srv.staticNodes[node.ID] = node
|
||||
if !srv.running {
|
||||
return &discover.Node{IP: net.ParseIP("0.0.0.0")}
|
||||
}
|
||||
return srv.ntab.Self()
|
||||
}
|
||||
|
||||
// Broadcast sends an RLP-encoded message to all connected peers.
|
||||
// This method is deprecated and will be removed later.
|
||||
func (srv *Server) Broadcast(protocol string, code uint64, data interface{}) error {
|
||||
return srv.BroadcastLimited(protocol, code, func(i float64) float64 { return i }, data)
|
||||
}
|
||||
|
||||
// BroadcastsRange an RLP-encoded message to a random set of peers using the limit function to limit the amount
|
||||
// of peers.
|
||||
func (srv *Server) BroadcastLimited(protocol string, code uint64, limit func(float64) float64, data interface{}) error {
|
||||
var payload []byte
|
||||
if data != nil {
|
||||
var err error
|
||||
payload, err = rlp.EncodeToBytes(data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Stop terminates the server and all active peer connections.
|
||||
// It blocks until all active connections have been closed.
|
||||
func (srv *Server) Stop() {
|
||||
srv.lock.Lock()
|
||||
defer srv.lock.Unlock()
|
||||
if !srv.running {
|
||||
return
|
||||
}
|
||||
srv.lock.RLock()
|
||||
defer srv.lock.RUnlock()
|
||||
|
||||
i, max := 0, int(limit(float64(len(srv.peers))))
|
||||
for _, peer := range srv.peers {
|
||||
if i >= max {
|
||||
break
|
||||
}
|
||||
|
||||
if peer != nil {
|
||||
var msg = Msg{Code: code}
|
||||
if data != nil {
|
||||
msg.Payload = bytes.NewReader(payload)
|
||||
msg.Size = uint32(len(payload))
|
||||
}
|
||||
peer.writeProtoMsg(protocol, msg)
|
||||
i++
|
||||
}
|
||||
srv.running = false
|
||||
if srv.listener != nil {
|
||||
// this unblocks listener Accept
|
||||
srv.listener.Close()
|
||||
}
|
||||
return nil
|
||||
close(srv.quit)
|
||||
srv.loopWG.Wait()
|
||||
}
|
||||
|
||||
// Start starts running the server.
|
||||
// Servers can be re-used and started again after stopping.
|
||||
// Servers can not be re-used after stopping.
|
||||
func (srv *Server) Start() (err error) {
|
||||
srv.lock.Lock()
|
||||
defer srv.lock.Unlock()
|
||||
if srv.running {
|
||||
return errors.New("server already running")
|
||||
}
|
||||
srv.running = true
|
||||
glog.V(logger.Info).Infoln("Starting Server")
|
||||
|
||||
// static fields
|
||||
|
@ -215,23 +278,19 @@ func (srv *Server) Start() (err error) {
|
|||
if srv.MaxPeers <= 0 {
|
||||
return fmt.Errorf("Server.MaxPeers must be > 0")
|
||||
}
|
||||
if srv.newTransport == nil {
|
||||
srv.newTransport = newRLPX
|
||||
}
|
||||
if srv.Dialer == nil {
|
||||
srv.Dialer = &net.Dialer{Timeout: defaultDialTimeout}
|
||||
}
|
||||
srv.quit = make(chan struct{})
|
||||
srv.peers = make(map[discover.NodeID]*Peer)
|
||||
|
||||
// Create the current trust maps, and the associated dialing channel
|
||||
srv.trustedNodes = make(map[discover.NodeID]bool)
|
||||
for _, node := range srv.TrustedNodes {
|
||||
srv.trustedNodes[node.ID] = true
|
||||
}
|
||||
srv.staticNodes = make(map[discover.NodeID]*discover.Node)
|
||||
for _, node := range srv.StaticNodes {
|
||||
srv.staticNodes[node.ID] = node
|
||||
}
|
||||
srv.staticDial = make(chan *discover.Node)
|
||||
|
||||
if srv.setupFunc == nil {
|
||||
srv.setupFunc = setupConn
|
||||
}
|
||||
srv.addpeer = make(chan *conn)
|
||||
srv.delpeer = make(chan *Peer)
|
||||
srv.posthandshake = make(chan *conn)
|
||||
srv.addstatic = make(chan *discover.Node)
|
||||
srv.peerOp = make(chan peerOpFunc)
|
||||
srv.peerOpDone = make(chan struct{})
|
||||
|
||||
// node table
|
||||
ntab, err := discover.ListenUDP(srv.PrivateKey, srv.ListenAddr, srv.NAT, srv.NodeDatabase)
|
||||
|
@ -239,37 +298,31 @@ func (srv *Server) Start() (err error) {
|
|||
return err
|
||||
}
|
||||
srv.ntab = ntab
|
||||
dialer := newDialState(srv.StaticNodes, srv.ntab, srv.MaxPeers/2)
|
||||
|
||||
// handshake
|
||||
srv.ourHandshake = &protoHandshake{Version: baseProtocolVersion, Name: srv.Name, ID: ntab.Self().ID}
|
||||
for _, p := range srv.Protocols {
|
||||
srv.ourHandshake.Caps = append(srv.ourHandshake.Caps, p.cap())
|
||||
}
|
||||
|
||||
// listen/dial
|
||||
if srv.ListenAddr != "" {
|
||||
if err := srv.startListening(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if srv.Dialer == nil {
|
||||
srv.Dialer = &net.Dialer{Timeout: defaultDialTimeout}
|
||||
}
|
||||
if !srv.NoDial {
|
||||
srv.loopWG.Add(1)
|
||||
go srv.dialLoop()
|
||||
}
|
||||
if srv.NoDial && srv.ListenAddr == "" {
|
||||
glog.V(logger.Warn).Infoln("I will be kind-of useless, neither dialing nor listening.")
|
||||
}
|
||||
// maintain the static peers
|
||||
go srv.staticNodesLoop()
|
||||
|
||||
srv.loopWG.Add(1)
|
||||
go srv.run(dialer)
|
||||
srv.running = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (srv *Server) startListening() error {
|
||||
// Launch the TCP listener.
|
||||
listener, err := net.Listen("tcp", srv.ListenAddr)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -279,6 +332,7 @@ func (srv *Server) startListening() error {
|
|||
srv.listener = listener
|
||||
srv.loopWG.Add(1)
|
||||
go srv.listenLoop()
|
||||
// Map the TCP listening port if NAT is configured.
|
||||
if !laddr.IP.IsLoopback() && srv.NAT != nil {
|
||||
srv.loopWG.Add(1)
|
||||
go func() {
|
||||
|
@ -289,50 +343,164 @@ func (srv *Server) startListening() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// Stop terminates the server and all active peer connections.
|
||||
// It blocks until all active connections have been closed.
|
||||
func (srv *Server) Stop() {
|
||||
srv.lock.Lock()
|
||||
if !srv.running {
|
||||
srv.lock.Unlock()
|
||||
return
|
||||
}
|
||||
srv.running = false
|
||||
srv.lock.Unlock()
|
||||
type dialer interface {
|
||||
newTasks(running int, peers map[discover.NodeID]*Peer, now time.Time) []task
|
||||
taskDone(task, time.Time)
|
||||
addStatic(*discover.Node)
|
||||
}
|
||||
|
||||
glog.V(logger.Info).Infoln("Stopping Server")
|
||||
func (srv *Server) run(dialstate dialer) {
|
||||
defer srv.loopWG.Done()
|
||||
var (
|
||||
peers = make(map[discover.NodeID]*Peer)
|
||||
trusted = make(map[discover.NodeID]bool, len(srv.TrustedNodes))
|
||||
|
||||
tasks []task
|
||||
pendingTasks []task
|
||||
taskdone = make(chan task, maxActiveDialTasks)
|
||||
)
|
||||
// Put trusted nodes into a map to speed up checks.
|
||||
// Trusted peers are loaded on startup and cannot be
|
||||
// modified while the server is running.
|
||||
for _, n := range srv.TrustedNodes {
|
||||
trusted[n.ID] = true
|
||||
}
|
||||
|
||||
// Some task list helpers.
|
||||
delTask := func(t task) {
|
||||
for i := range tasks {
|
||||
if tasks[i] == t {
|
||||
tasks = append(tasks[:i], tasks[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
scheduleTasks := func(new []task) {
|
||||
pt := append(pendingTasks, new...)
|
||||
start := maxActiveDialTasks - len(tasks)
|
||||
if len(pt) < start {
|
||||
start = len(pt)
|
||||
}
|
||||
if start > 0 {
|
||||
tasks = append(tasks, pt[:start]...)
|
||||
for _, t := range pt[:start] {
|
||||
t := t
|
||||
glog.V(logger.Detail).Infoln("new task:", t)
|
||||
go func() { t.Do(srv); taskdone <- t }()
|
||||
}
|
||||
copy(pt, pt[start:])
|
||||
pendingTasks = pt[:len(pt)-start]
|
||||
}
|
||||
}
|
||||
|
||||
running:
|
||||
for {
|
||||
// Query the dialer for new tasks and launch them.
|
||||
now := time.Now()
|
||||
nt := dialstate.newTasks(len(pendingTasks)+len(tasks), peers, now)
|
||||
scheduleTasks(nt)
|
||||
|
||||
select {
|
||||
case <-srv.quit:
|
||||
// The server was stopped. Run the cleanup logic.
|
||||
glog.V(logger.Detail).Infoln("<-quit: spinning down")
|
||||
break running
|
||||
case n := <-srv.addstatic:
|
||||
// This channel is used by AddPeer to add to the
|
||||
// ephemeral static peer list. Add it to the dialer,
|
||||
// it will keep the node connected.
|
||||
glog.V(logger.Detail).Infoln("<-addstatic:", n)
|
||||
dialstate.addStatic(n)
|
||||
case op := <-srv.peerOp:
|
||||
// This channel is used by Peers and PeerCount.
|
||||
op(peers)
|
||||
srv.peerOpDone <- struct{}{}
|
||||
case t := <-taskdone:
|
||||
// A task got done. Tell dialstate about it so it
|
||||
// can update its state and remove it from the active
|
||||
// tasks list.
|
||||
glog.V(logger.Detail).Infoln("<-taskdone:", t)
|
||||
dialstate.taskDone(t, now)
|
||||
delTask(t)
|
||||
case c := <-srv.posthandshake:
|
||||
// A connection has passed the encryption handshake so
|
||||
// the remote identity is known (but hasn't been verified yet).
|
||||
if trusted[c.id] {
|
||||
// Ensure that the trusted flag is set before checking against MaxPeers.
|
||||
c.flags |= trustedConn
|
||||
}
|
||||
glog.V(logger.Detail).Infoln("<-posthandshake:", c)
|
||||
// TODO: track in-progress inbound node IDs (pre-Peer) to avoid dialing them.
|
||||
c.cont <- srv.encHandshakeChecks(peers, c)
|
||||
case c := <-srv.addpeer:
|
||||
// At this point the connection is past the protocol handshake.
|
||||
// Its capabilities are known and the remote identity is verified.
|
||||
glog.V(logger.Detail).Infoln("<-addpeer:", c)
|
||||
err := srv.protoHandshakeChecks(peers, c)
|
||||
if err != nil {
|
||||
glog.V(logger.Detail).Infof("Not adding %v as peer: %v", c, err)
|
||||
} else {
|
||||
// The handshakes are done and it passed all checks.
|
||||
p := newPeer(c, srv.Protocols)
|
||||
peers[c.id] = p
|
||||
go srv.runPeer(p)
|
||||
}
|
||||
// The dialer logic relies on the assumption that
|
||||
// dial tasks complete after the peer has been added or
|
||||
// discarded. Unblock the task last.
|
||||
c.cont <- err
|
||||
case p := <-srv.delpeer:
|
||||
// A peer disconnected.
|
||||
glog.V(logger.Detail).Infoln("<-delpeer:", p)
|
||||
delete(peers, p.ID())
|
||||
}
|
||||
}
|
||||
|
||||
// Terminate discovery. If there is a running lookup it will terminate soon.
|
||||
srv.ntab.Close()
|
||||
if srv.listener != nil {
|
||||
// this unblocks listener Accept
|
||||
srv.listener.Close()
|
||||
// Disconnect all peers.
|
||||
for _, p := range peers {
|
||||
p.Disconnect(DiscQuitting)
|
||||
}
|
||||
close(srv.quit)
|
||||
srv.loopWG.Wait()
|
||||
|
||||
// No new peers can be added at this point because dialLoop and
|
||||
// listenLoop are down. It is safe to call peerWG.Wait because
|
||||
// peerWG.Add is not called outside of those loops.
|
||||
srv.lock.Lock()
|
||||
for _, peer := range srv.peers {
|
||||
peer.Disconnect(DiscQuitting)
|
||||
// Wait for peers to shut down. Pending connections and tasks are
|
||||
// not handled here and will terminate soon-ish because srv.quit
|
||||
// is closed.
|
||||
glog.V(logger.Detail).Infof("ignoring %d pending tasks at spindown", len(tasks))
|
||||
for len(peers) > 0 {
|
||||
p := <-srv.delpeer
|
||||
glog.V(logger.Detail).Infoln("<-delpeer (spindown):", p)
|
||||
delete(peers, p.ID())
|
||||
}
|
||||
srv.lock.Unlock()
|
||||
srv.peerWG.Wait()
|
||||
}
|
||||
|
||||
// Self returns the local node's endpoint information.
|
||||
func (srv *Server) Self() *discover.Node {
|
||||
srv.lock.RLock()
|
||||
defer srv.lock.RUnlock()
|
||||
if !srv.running {
|
||||
return &discover.Node{IP: net.ParseIP("0.0.0.0")}
|
||||
func (srv *Server) protoHandshakeChecks(peers map[discover.NodeID]*Peer, c *conn) error {
|
||||
// Drop connections with no matching protocols.
|
||||
if len(srv.Protocols) > 0 && countMatchingProtocols(srv.Protocols, c.caps) == 0 {
|
||||
return DiscUselessPeer
|
||||
}
|
||||
return srv.ntab.Self()
|
||||
// Repeat the encryption handshake checks because the
|
||||
// peer set might have changed between the handshakes.
|
||||
return srv.encHandshakeChecks(peers, c)
|
||||
}
|
||||
|
||||
// main loop for adding connections via listening
|
||||
func (srv *Server) encHandshakeChecks(peers map[discover.NodeID]*Peer, c *conn) error {
|
||||
switch {
|
||||
case !c.is(trustedConn|staticDialedConn) && len(peers) >= srv.MaxPeers:
|
||||
return DiscTooManyPeers
|
||||
case peers[c.id] != nil:
|
||||
return DiscAlreadyConnected
|
||||
case c.id == srv.ntab.Self().ID:
|
||||
return DiscSelf
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// listenLoop runs in its own goroutine and accepts
|
||||
// inbound connections.
|
||||
func (srv *Server) listenLoop() {
|
||||
defer srv.loopWG.Done()
|
||||
glog.V(logger.Info).Infoln("Listening on", srv.listener.Addr())
|
||||
|
||||
// This channel acts as a semaphore limiting
|
||||
// active inbound connections that are lingering pre-handshake.
|
||||
|
@ -346,204 +514,92 @@ func (srv *Server) listenLoop() {
|
|||
slots <- struct{}{}
|
||||
}
|
||||
|
||||
glog.V(logger.Info).Infoln("Listening on", srv.listener.Addr())
|
||||
for {
|
||||
<-slots
|
||||
conn, err := srv.listener.Accept()
|
||||
fd, err := srv.listener.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
glog.V(logger.Debug).Infof("Accepted conn %v\n", conn.RemoteAddr())
|
||||
srv.peerWG.Add(1)
|
||||
glog.V(logger.Debug).Infof("Accepted conn %v\n", fd.RemoteAddr())
|
||||
go func() {
|
||||
srv.startPeer(conn, nil)
|
||||
srv.setupConn(fd, inboundConn, nil)
|
||||
slots <- struct{}{}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
// staticNodesLoop is responsible for periodically checking that static
|
||||
// connections are actually live, and requests dialing if not.
|
||||
func (srv *Server) staticNodesLoop() {
|
||||
// Create a default maintenance ticker, but override it requested
|
||||
cycle := staticPeerCheckInterval
|
||||
if srv.staticCycle != 0 {
|
||||
cycle = srv.staticCycle
|
||||
// setupConn runs the handshakes and attempts to add the connection
|
||||
// as a peer. It returns when the connection has been added as a peer
|
||||
// or the handshakes have failed.
|
||||
func (srv *Server) setupConn(fd net.Conn, flags connFlag, dialDest *discover.Node) {
|
||||
// Prevent leftover pending conns from entering the handshake.
|
||||
srv.lock.Lock()
|
||||
running := srv.running
|
||||
srv.lock.Unlock()
|
||||
c := &conn{fd: fd, transport: srv.newTransport(fd), flags: flags, cont: make(chan error)}
|
||||
if !running {
|
||||
c.close(errServerStopped)
|
||||
return
|
||||
}
|
||||
tick := time.NewTicker(cycle)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-srv.quit:
|
||||
return
|
||||
|
||||
case <-tick.C:
|
||||
// Collect all the non-connected static nodes
|
||||
needed := []*discover.Node{}
|
||||
srv.lock.RLock()
|
||||
for id, node := range srv.staticNodes {
|
||||
if _, ok := srv.peers[id]; !ok {
|
||||
needed = append(needed, node)
|
||||
}
|
||||
}
|
||||
srv.lock.RUnlock()
|
||||
|
||||
// Try to dial each of them (don't hang if server terminates)
|
||||
for _, node := range needed {
|
||||
glog.V(logger.Debug).Infof("Dialing static peer %v", node)
|
||||
select {
|
||||
case srv.staticDial <- node:
|
||||
case <-srv.quit:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
// Run the encryption handshake.
|
||||
var err error
|
||||
if c.id, err = c.doEncHandshake(srv.PrivateKey, dialDest); err != nil {
|
||||
glog.V(logger.Debug).Infof("%v faild enc handshake: %v", c, err)
|
||||
c.close(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (srv *Server) dialLoop() {
|
||||
var (
|
||||
dialed = make(chan *discover.Node)
|
||||
dialing = make(map[discover.NodeID]bool)
|
||||
findresults = make(chan []*discover.Node)
|
||||
refresh = time.NewTimer(0)
|
||||
)
|
||||
defer srv.loopWG.Done()
|
||||
defer refresh.Stop()
|
||||
|
||||
// Limit the number of concurrent dials
|
||||
tokens := maxDialingConns
|
||||
if srv.MaxPendingPeers > 0 {
|
||||
tokens = srv.MaxPendingPeers
|
||||
// For dialed connections, check that the remote public key matches.
|
||||
if dialDest != nil && c.id != dialDest.ID {
|
||||
c.close(DiscUnexpectedIdentity)
|
||||
glog.V(logger.Debug).Infof("%v dialed identity mismatch, want %x", c, dialDest.ID[:8])
|
||||
return
|
||||
}
|
||||
slots := make(chan struct{}, tokens)
|
||||
for i := 0; i < tokens; i++ {
|
||||
slots <- struct{}{}
|
||||
if err := srv.checkpoint(c, srv.posthandshake); err != nil {
|
||||
glog.V(logger.Debug).Infof("%v failed checkpoint posthandshake: %v", c, err)
|
||||
c.close(err)
|
||||
return
|
||||
}
|
||||
dial := func(dest *discover.Node) {
|
||||
// Don't dial nodes that would fail the checks in addPeer.
|
||||
// This is important because the connection handshake is a lot
|
||||
// of work and we'd rather avoid doing that work for peers
|
||||
// that can't be added.
|
||||
srv.lock.RLock()
|
||||
ok, _ := srv.checkPeer(dest.ID)
|
||||
srv.lock.RUnlock()
|
||||
if !ok || dialing[dest.ID] {
|
||||
return
|
||||
}
|
||||
// Request a dial slot to prevent CPU exhaustion
|
||||
<-slots
|
||||
|
||||
dialing[dest.ID] = true
|
||||
srv.peerWG.Add(1)
|
||||
go func() {
|
||||
srv.dialNode(dest)
|
||||
slots <- struct{}{}
|
||||
dialed <- dest
|
||||
}()
|
||||
}
|
||||
|
||||
srv.ntab.Bootstrap(srv.BootstrapNodes)
|
||||
for {
|
||||
select {
|
||||
case <-refresh.C:
|
||||
// Grab some nodes to connect to if we're not at capacity.
|
||||
srv.lock.RLock()
|
||||
needpeers := len(srv.peers) < srv.MaxPeers/2
|
||||
srv.lock.RUnlock()
|
||||
if needpeers {
|
||||
go func() {
|
||||
var target discover.NodeID
|
||||
rand.Read(target[:])
|
||||
findresults <- srv.ntab.Lookup(target)
|
||||
}()
|
||||
} else {
|
||||
// Make sure we check again if the peer count falls
|
||||
// below MaxPeers.
|
||||
refresh.Reset(refreshPeersInterval)
|
||||
}
|
||||
case dest := <-srv.staticDial:
|
||||
dial(dest)
|
||||
case dests := <-findresults:
|
||||
for _, dest := range dests {
|
||||
dial(dest)
|
||||
}
|
||||
refresh.Reset(refreshPeersInterval)
|
||||
case dest := <-dialed:
|
||||
delete(dialing, dest.ID)
|
||||
if len(dialing) == 0 {
|
||||
// Check again immediately after dialing all current candidates.
|
||||
refresh.Reset(0)
|
||||
}
|
||||
case <-srv.quit:
|
||||
// TODO: maybe wait for active dials
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (srv *Server) dialNode(dest *discover.Node) {
|
||||
addr := &net.TCPAddr{IP: dest.IP, Port: int(dest.TCP)}
|
||||
glog.V(logger.Debug).Infof("Dialing %v\n", dest)
|
||||
conn, err := srv.Dialer.Dial("tcp", addr.String())
|
||||
// Run the protocol handshake
|
||||
phs, err := c.doProtoHandshake(srv.ourHandshake)
|
||||
if err != nil {
|
||||
// dialLoop adds to the wait group counter when launching
|
||||
// dialNode, so we need to count it down again. startPeer also
|
||||
// does that when an error occurs.
|
||||
srv.peerWG.Done()
|
||||
glog.V(logger.Detail).Infof("dial error: %v", err)
|
||||
glog.V(logger.Debug).Infof("%v failed proto handshake: %v", c, err)
|
||||
c.close(err)
|
||||
return
|
||||
}
|
||||
srv.startPeer(conn, dest)
|
||||
}
|
||||
|
||||
func (srv *Server) startPeer(fd net.Conn, dest *discover.Node) {
|
||||
// TODO: handle/store session token
|
||||
|
||||
// Run setupFunc, which should create an authenticated connection
|
||||
// and run the capability exchange. Note that any early error
|
||||
// returns during that exchange need to call peerWG.Done because
|
||||
// the callers of startPeer added the peer to the wait group already.
|
||||
fd.SetDeadline(time.Now().Add(handshakeTimeout))
|
||||
|
||||
conn, err := srv.setupFunc(fd, srv.PrivateKey, srv.ourHandshake, dest, srv.keepconn)
|
||||
if err != nil {
|
||||
fd.Close()
|
||||
glog.V(logger.Debug).Infof("Handshake with %v failed: %v", fd.RemoteAddr(), err)
|
||||
srv.peerWG.Done()
|
||||
if phs.ID != c.id {
|
||||
glog.V(logger.Debug).Infof("%v wrong proto handshake identity: %x", c, phs.ID[:8])
|
||||
c.close(DiscUnexpectedIdentity)
|
||||
return
|
||||
}
|
||||
conn.MsgReadWriter = &netWrapper{
|
||||
wrapped: conn.MsgReadWriter,
|
||||
conn: fd, rtimeout: frameReadTimeout, wtimeout: frameWriteTimeout,
|
||||
}
|
||||
p := newPeer(fd, conn, srv.Protocols)
|
||||
if ok, reason := srv.addPeer(conn, p); !ok {
|
||||
glog.V(logger.Detail).Infof("Not adding %v (%v)\n", p, reason)
|
||||
p.politeDisconnect(reason)
|
||||
srv.peerWG.Done()
|
||||
c.caps, c.name = phs.Caps, phs.Name
|
||||
if err := srv.checkpoint(c, srv.addpeer); err != nil {
|
||||
glog.V(logger.Debug).Infof("%v failed checkpoint addpeer: %v", c, err)
|
||||
c.close(err)
|
||||
return
|
||||
}
|
||||
// The handshakes are done and it passed all checks.
|
||||
// Spawn the Peer loops.
|
||||
go srv.runPeer(p)
|
||||
// If the checks completed successfully, runPeer has now been
|
||||
// launched by run.
|
||||
}
|
||||
|
||||
// preflight checks whether a connection should be kept. it runs
|
||||
// after the encryption handshake, as soon as the remote identity is
|
||||
// known.
|
||||
func (srv *Server) keepconn(id discover.NodeID) bool {
|
||||
srv.lock.RLock()
|
||||
defer srv.lock.RUnlock()
|
||||
if _, ok := srv.staticNodes[id]; ok {
|
||||
return true // static nodes are always allowed
|
||||
// checkpoint sends the conn to run, which performs the
|
||||
// post-handshake checks for the stage (posthandshake, addpeer).
|
||||
func (srv *Server) checkpoint(c *conn, stage chan<- *conn) error {
|
||||
select {
|
||||
case stage <- c:
|
||||
case <-srv.quit:
|
||||
return errServerStopped
|
||||
}
|
||||
if _, ok := srv.trustedNodes[id]; ok {
|
||||
return true // trusted nodes are always allowed
|
||||
select {
|
||||
case err := <-c.cont:
|
||||
return err
|
||||
case <-srv.quit:
|
||||
return errServerStopped
|
||||
}
|
||||
return len(srv.peers) < srv.MaxPeers
|
||||
}
|
||||
|
||||
// runPeer runs in its own goroutine for each peer.
|
||||
// it waits until the Peer logic returns and removes
|
||||
// the peer.
|
||||
func (srv *Server) runPeer(p *Peer) {
|
||||
glog.V(logger.Debug).Infof("Added %v\n", p)
|
||||
srvjslog.LogJson(&logger.P2PConnected{
|
||||
|
@ -552,58 +608,18 @@ func (srv *Server) runPeer(p *Peer) {
|
|||
RemoteVersionString: p.Name(),
|
||||
NumConnections: srv.PeerCount(),
|
||||
})
|
||||
|
||||
if srv.newPeerHook != nil {
|
||||
srv.newPeerHook(p)
|
||||
}
|
||||
discreason := p.run()
|
||||
srv.removePeer(p)
|
||||
// Note: run waits for existing peers to be sent on srv.delpeer
|
||||
// before returning, so this send should not select on srv.quit.
|
||||
srv.delpeer <- p
|
||||
|
||||
glog.V(logger.Debug).Infof("Removed %v (%v)\n", p, discreason)
|
||||
srvjslog.LogJson(&logger.P2PDisconnected{
|
||||
RemoteId: p.ID().String(),
|
||||
NumConnections: srv.PeerCount(),
|
||||
})
|
||||
}
|
||||
|
||||
func (srv *Server) addPeer(conn *conn, p *Peer) (bool, DiscReason) {
|
||||
// drop connections with no matching protocols.
|
||||
if len(srv.Protocols) > 0 && countMatchingProtocols(srv.Protocols, conn.protoHandshake.Caps) == 0 {
|
||||
return false, DiscUselessPeer
|
||||
}
|
||||
// add the peer if it passes the other checks.
|
||||
srv.lock.Lock()
|
||||
defer srv.lock.Unlock()
|
||||
if ok, reason := srv.checkPeer(conn.ID); !ok {
|
||||
return false, reason
|
||||
}
|
||||
srv.peers[conn.ID] = p
|
||||
return true, 0
|
||||
}
|
||||
|
||||
// checkPeer verifies whether a peer looks promising and should be allowed/kept
|
||||
// in the pool, or if it's of no use.
|
||||
func (srv *Server) checkPeer(id discover.NodeID) (bool, DiscReason) {
|
||||
// First up, figure out if the peer is static or trusted
|
||||
_, static := srv.staticNodes[id]
|
||||
trusted := srv.trustedNodes[id]
|
||||
|
||||
// Make sure the peer passes all required checks
|
||||
switch {
|
||||
case !srv.running:
|
||||
return false, DiscQuitting
|
||||
case !static && !trusted && len(srv.peers) >= srv.MaxPeers:
|
||||
return false, DiscTooManyPeers
|
||||
case srv.peers[id] != nil:
|
||||
return false, DiscAlreadyConnected
|
||||
case id == srv.ntab.Self().ID:
|
||||
return false, DiscSelf
|
||||
default:
|
||||
return true, 0
|
||||
}
|
||||
}
|
||||
|
||||
func (srv *Server) removePeer(p *Peer) {
|
||||
srv.lock.Lock()
|
||||
delete(srv.peers, p.ID())
|
||||
srv.lock.Unlock()
|
||||
srv.peerWG.Done()
|
||||
}
|
||||
|
|
|
@ -1,12 +1,11 @@
|
|||
package p2p
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/ecdsa"
|
||||
"io"
|
||||
"errors"
|
||||
"math/rand"
|
||||
"net"
|
||||
"sync"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
|
@ -15,29 +14,50 @@ import (
|
|||
"github.com/ethereum/go-ethereum/p2p/discover"
|
||||
)
|
||||
|
||||
func startTestServer(t *testing.T, pf newPeerHook) *Server {
|
||||
func init() {
|
||||
// glog.SetV(6)
|
||||
// glog.SetToStderr(true)
|
||||
}
|
||||
|
||||
type testTransport struct {
|
||||
id discover.NodeID
|
||||
*rlpx
|
||||
|
||||
closeErr error
|
||||
}
|
||||
|
||||
func newTestTransport(id discover.NodeID, fd net.Conn) transport {
|
||||
wrapped := newRLPX(fd).(*rlpx)
|
||||
wrapped.rw = newRLPXFrameRW(fd, secrets{
|
||||
MAC: zero16,
|
||||
AES: zero16,
|
||||
IngressMAC: sha3.NewKeccak256(),
|
||||
EgressMAC: sha3.NewKeccak256(),
|
||||
})
|
||||
return &testTransport{id: id, rlpx: wrapped}
|
||||
}
|
||||
|
||||
func (c *testTransport) doEncHandshake(prv *ecdsa.PrivateKey, dialDest *discover.Node) (discover.NodeID, error) {
|
||||
return c.id, nil
|
||||
}
|
||||
|
||||
func (c *testTransport) doProtoHandshake(our *protoHandshake) (*protoHandshake, error) {
|
||||
return &protoHandshake{ID: c.id, Name: "test"}, nil
|
||||
}
|
||||
|
||||
func (c *testTransport) close(err error) {
|
||||
c.rlpx.fd.Close()
|
||||
c.closeErr = err
|
||||
}
|
||||
|
||||
func startTestServer(t *testing.T, id discover.NodeID, pf func(*Peer)) *Server {
|
||||
server := &Server{
|
||||
Name: "test",
|
||||
MaxPeers: 10,
|
||||
ListenAddr: "127.0.0.1:0",
|
||||
PrivateKey: newkey(),
|
||||
newPeerHook: pf,
|
||||
setupFunc: func(fd net.Conn, prv *ecdsa.PrivateKey, our *protoHandshake, dial *discover.Node, keepconn func(discover.NodeID) bool) (*conn, error) {
|
||||
id := randomID()
|
||||
if !keepconn(id) {
|
||||
return nil, DiscAlreadyConnected
|
||||
}
|
||||
rw := newRlpxFrameRW(fd, secrets{
|
||||
MAC: zero16,
|
||||
AES: zero16,
|
||||
IngressMAC: sha3.NewKeccak256(),
|
||||
EgressMAC: sha3.NewKeccak256(),
|
||||
})
|
||||
return &conn{
|
||||
MsgReadWriter: rw,
|
||||
protoHandshake: &protoHandshake{ID: id, Version: baseProtocolVersion},
|
||||
}, nil
|
||||
},
|
||||
Name: "test",
|
||||
MaxPeers: 10,
|
||||
ListenAddr: "127.0.0.1:0",
|
||||
PrivateKey: newkey(),
|
||||
newPeerHook: pf,
|
||||
newTransport: func(fd net.Conn) transport { return newTestTransport(id, fd) },
|
||||
}
|
||||
if err := server.Start(); err != nil {
|
||||
t.Fatalf("Could not start server: %v", err)
|
||||
|
@ -48,7 +68,11 @@ func startTestServer(t *testing.T, pf newPeerHook) *Server {
|
|||
func TestServerListen(t *testing.T) {
|
||||
// start the test server
|
||||
connected := make(chan *Peer)
|
||||
srv := startTestServer(t, func(p *Peer) {
|
||||
remid := randomID()
|
||||
srv := startTestServer(t, remid, func(p *Peer) {
|
||||
if p.ID() != remid {
|
||||
t.Error("peer func called with wrong node id")
|
||||
}
|
||||
if p == nil {
|
||||
t.Error("peer func called with nil conn")
|
||||
}
|
||||
|
@ -70,6 +94,10 @@ func TestServerListen(t *testing.T) {
|
|||
t.Errorf("peer started with wrong conn: got %v, want %v",
|
||||
peer.LocalAddr(), conn.RemoteAddr())
|
||||
}
|
||||
peers := srv.Peers()
|
||||
if !reflect.DeepEqual(peers, []*Peer{peer}) {
|
||||
t.Errorf("Peers mismatch: got %v, want %v", peers, []*Peer{peer})
|
||||
}
|
||||
case <-time.After(1 * time.Second):
|
||||
t.Error("server did not accept within one second")
|
||||
}
|
||||
|
@ -95,23 +123,33 @@ func TestServerDial(t *testing.T) {
|
|||
|
||||
// start the server
|
||||
connected := make(chan *Peer)
|
||||
srv := startTestServer(t, func(p *Peer) { connected <- p })
|
||||
remid := randomID()
|
||||
srv := startTestServer(t, remid, func(p *Peer) { connected <- p })
|
||||
defer close(connected)
|
||||
defer srv.Stop()
|
||||
|
||||
// tell the server to connect
|
||||
tcpAddr := listener.Addr().(*net.TCPAddr)
|
||||
srv.staticDial <- &discover.Node{IP: tcpAddr.IP, TCP: uint16(tcpAddr.Port)}
|
||||
srv.AddPeer(&discover.Node{ID: remid, IP: tcpAddr.IP, TCP: uint16(tcpAddr.Port)})
|
||||
|
||||
select {
|
||||
case conn := <-accepted:
|
||||
select {
|
||||
case peer := <-connected:
|
||||
if peer.ID() != remid {
|
||||
t.Errorf("peer has wrong id")
|
||||
}
|
||||
if peer.Name() != "test" {
|
||||
t.Errorf("peer has wrong name")
|
||||
}
|
||||
if peer.RemoteAddr().String() != conn.LocalAddr().String() {
|
||||
t.Errorf("peer started with wrong conn: got %v, want %v",
|
||||
peer.RemoteAddr(), conn.LocalAddr())
|
||||
}
|
||||
// TODO: validate more fields
|
||||
peers := srv.Peers()
|
||||
if !reflect.DeepEqual(peers, []*Peer{peer}) {
|
||||
t.Errorf("Peers mismatch: got %v, want %v", peers, []*Peer{peer})
|
||||
}
|
||||
case <-time.After(1 * time.Second):
|
||||
t.Error("server did not launch peer within one second")
|
||||
}
|
||||
|
@ -121,370 +159,250 @@ func TestServerDial(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestServerBroadcast(t *testing.T) {
|
||||
var connected sync.WaitGroup
|
||||
srv := startTestServer(t, func(p *Peer) {
|
||||
p.running = matchProtocols([]Protocol{discard}, []Cap{discard.cap()}, p.rw)
|
||||
connected.Done()
|
||||
})
|
||||
defer srv.Stop()
|
||||
|
||||
// create a few peers
|
||||
var conns = make([]net.Conn, 8)
|
||||
connected.Add(len(conns))
|
||||
deadline := time.Now().Add(3 * time.Second)
|
||||
dialer := &net.Dialer{Deadline: deadline}
|
||||
for i := range conns {
|
||||
conn, err := dialer.Dial("tcp", srv.ListenAddr)
|
||||
if err != nil {
|
||||
t.Fatalf("conn %d: dial error: %v", i, err)
|
||||
// This test checks that tasks generated by dialstate are
|
||||
// actually executed and taskdone is called for them.
|
||||
func TestServerTaskScheduling(t *testing.T) {
|
||||
var (
|
||||
done = make(chan *testTask)
|
||||
quit, returned = make(chan struct{}), make(chan struct{})
|
||||
tc = 0
|
||||
tg = taskgen{
|
||||
newFunc: func(running int, peers map[discover.NodeID]*Peer) []task {
|
||||
tc++
|
||||
return []task{&testTask{index: tc - 1}}
|
||||
},
|
||||
doneFunc: func(t task) {
|
||||
select {
|
||||
case done <- t.(*testTask):
|
||||
case <-quit:
|
||||
}
|
||||
},
|
||||
}
|
||||
defer conn.Close()
|
||||
conn.SetDeadline(deadline)
|
||||
conns[i] = conn
|
||||
)
|
||||
|
||||
// The Server in this test isn't actually running
|
||||
// because we're only interested in what run does.
|
||||
srv := &Server{
|
||||
MaxPeers: 10,
|
||||
quit: make(chan struct{}),
|
||||
ntab: fakeTable{},
|
||||
running: true,
|
||||
}
|
||||
connected.Wait()
|
||||
srv.loopWG.Add(1)
|
||||
go func() {
|
||||
srv.run(tg)
|
||||
close(returned)
|
||||
}()
|
||||
|
||||
// broadcast one message
|
||||
srv.Broadcast("discard", 0, []string{"foo"})
|
||||
golden := unhex("66e94d166f0a2c3b884cfa59ca34")
|
||||
|
||||
// check that the message has been written everywhere
|
||||
for i, conn := range conns {
|
||||
buf := make([]byte, len(golden))
|
||||
if _, err := io.ReadFull(conn, buf); err != nil {
|
||||
t.Errorf("conn %d: read error: %v", i, err)
|
||||
} else if !bytes.Equal(buf, golden) {
|
||||
t.Errorf("conn %d: msg mismatch\ngot: %x\nwant: %x", i, buf, golden)
|
||||
var gotdone []*testTask
|
||||
for i := 0; i < 100; i++ {
|
||||
gotdone = append(gotdone, <-done)
|
||||
}
|
||||
for i, task := range gotdone {
|
||||
if task.index != i {
|
||||
t.Errorf("task %d has wrong index, got %d", i, task.index)
|
||||
break
|
||||
}
|
||||
if !task.called {
|
||||
t.Errorf("task %d was not called", i)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
close(quit)
|
||||
srv.Stop()
|
||||
select {
|
||||
case <-returned:
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
t.Error("Server.run did not return within 500ms")
|
||||
}
|
||||
}
|
||||
|
||||
type taskgen struct {
|
||||
newFunc func(running int, peers map[discover.NodeID]*Peer) []task
|
||||
doneFunc func(task)
|
||||
}
|
||||
|
||||
func (tg taskgen) newTasks(running int, peers map[discover.NodeID]*Peer, now time.Time) []task {
|
||||
return tg.newFunc(running, peers)
|
||||
}
|
||||
func (tg taskgen) taskDone(t task, now time.Time) {
|
||||
tg.doneFunc(t)
|
||||
}
|
||||
func (tg taskgen) addStatic(*discover.Node) {
|
||||
}
|
||||
|
||||
type testTask struct {
|
||||
index int
|
||||
called bool
|
||||
}
|
||||
|
||||
func (t *testTask) Do(srv *Server) {
|
||||
t.called = true
|
||||
}
|
||||
|
||||
// This test checks that connections are disconnected
|
||||
// just after the encryption handshake when the server is
|
||||
// at capacity.
|
||||
//
|
||||
// It also serves as a light-weight integration test.
|
||||
func TestServerDisconnectAtCap(t *testing.T) {
|
||||
started := make(chan *Peer)
|
||||
// at capacity. Trusted connections should still be accepted.
|
||||
func TestServerAtCap(t *testing.T) {
|
||||
trustedID := randomID()
|
||||
srv := &Server{
|
||||
ListenAddr: "127.0.0.1:0",
|
||||
PrivateKey: newkey(),
|
||||
MaxPeers: 10,
|
||||
NoDial: true,
|
||||
// This hook signals that the peer was actually started. We
|
||||
// need to wait for the peer to be started before dialing the
|
||||
// next connection to get a deterministic peer count.
|
||||
newPeerHook: func(p *Peer) { started <- p },
|
||||
PrivateKey: newkey(),
|
||||
MaxPeers: 10,
|
||||
NoDial: true,
|
||||
TrustedNodes: []*discover.Node{{ID: trustedID}},
|
||||
}
|
||||
if err := srv.Start(); err != nil {
|
||||
t.Fatal(err)
|
||||
t.Fatalf("could not start: %v", err)
|
||||
}
|
||||
defer srv.Stop()
|
||||
|
||||
nconns := srv.MaxPeers + 1
|
||||
dialer := &net.Dialer{Deadline: time.Now().Add(3 * time.Second)}
|
||||
for i := 0; i < nconns; i++ {
|
||||
conn, err := dialer.Dial("tcp", srv.ListenAddr)
|
||||
if err != nil {
|
||||
t.Fatalf("conn %d: dial error: %v", i, err)
|
||||
newconn := func(id discover.NodeID) *conn {
|
||||
fd, _ := net.Pipe()
|
||||
tx := newTestTransport(id, fd)
|
||||
return &conn{fd: fd, transport: tx, flags: inboundConn, id: id, cont: make(chan error)}
|
||||
}
|
||||
|
||||
// Inject a few connections to fill up the peer set.
|
||||
for i := 0; i < 10; i++ {
|
||||
c := newconn(randomID())
|
||||
if err := srv.checkpoint(c, srv.addpeer); err != nil {
|
||||
t.Fatalf("could not add conn %d: %v", i, err)
|
||||
}
|
||||
// Close the connection when the test ends, before
|
||||
// shutting down the server.
|
||||
defer conn.Close()
|
||||
// Run the handshakes just like a real peer would.
|
||||
key := newkey()
|
||||
hs := &protoHandshake{Version: baseProtocolVersion, ID: discover.PubkeyID(&key.PublicKey)}
|
||||
_, err = setupConn(conn, key, hs, srv.Self(), keepalways)
|
||||
if i == nconns-1 {
|
||||
// When handling the last connection, the server should
|
||||
// disconnect immediately instead of running the protocol
|
||||
// handshake.
|
||||
if err != DiscTooManyPeers {
|
||||
t.Errorf("conn %d: got error %q, expected %q", i, err, DiscTooManyPeers)
|
||||
}
|
||||
// Try inserting a non-trusted connection.
|
||||
c := newconn(randomID())
|
||||
if err := srv.checkpoint(c, srv.posthandshake); err != DiscTooManyPeers {
|
||||
t.Error("wrong error for insert:", err)
|
||||
}
|
||||
// Try inserting a trusted connection.
|
||||
c = newconn(trustedID)
|
||||
if err := srv.checkpoint(c, srv.posthandshake); err != nil {
|
||||
t.Error("unexpected error for trusted conn @posthandshake:", err)
|
||||
}
|
||||
if !c.is(trustedConn) {
|
||||
t.Error("Server did not set trusted flag")
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestServerSetupConn(t *testing.T) {
|
||||
id := randomID()
|
||||
srvkey := newkey()
|
||||
srvid := discover.PubkeyID(&srvkey.PublicKey)
|
||||
tests := []struct {
|
||||
dontstart bool
|
||||
tt *setupTransport
|
||||
flags connFlag
|
||||
dialDest *discover.Node
|
||||
|
||||
wantCloseErr error
|
||||
wantCalls string
|
||||
}{
|
||||
{
|
||||
dontstart: true,
|
||||
tt: &setupTransport{id: id},
|
||||
wantCalls: "close,",
|
||||
wantCloseErr: errServerStopped,
|
||||
},
|
||||
{
|
||||
tt: &setupTransport{id: id, encHandshakeErr: errors.New("read error")},
|
||||
flags: inboundConn,
|
||||
wantCalls: "doEncHandshake,close,",
|
||||
wantCloseErr: errors.New("read error"),
|
||||
},
|
||||
{
|
||||
tt: &setupTransport{id: id},
|
||||
dialDest: &discover.Node{ID: randomID()},
|
||||
flags: dynDialedConn,
|
||||
wantCalls: "doEncHandshake,close,",
|
||||
wantCloseErr: DiscUnexpectedIdentity,
|
||||
},
|
||||
{
|
||||
tt: &setupTransport{id: id, phs: &protoHandshake{ID: randomID()}},
|
||||
dialDest: &discover.Node{ID: id},
|
||||
flags: dynDialedConn,
|
||||
wantCalls: "doEncHandshake,doProtoHandshake,close,",
|
||||
wantCloseErr: DiscUnexpectedIdentity,
|
||||
},
|
||||
{
|
||||
tt: &setupTransport{id: id, protoHandshakeErr: errors.New("foo")},
|
||||
dialDest: &discover.Node{ID: id},
|
||||
flags: dynDialedConn,
|
||||
wantCalls: "doEncHandshake,doProtoHandshake,close,",
|
||||
wantCloseErr: errors.New("foo"),
|
||||
},
|
||||
{
|
||||
tt: &setupTransport{id: srvid, phs: &protoHandshake{ID: srvid}},
|
||||
flags: inboundConn,
|
||||
wantCalls: "doEncHandshake,close,",
|
||||
wantCloseErr: DiscSelf,
|
||||
},
|
||||
{
|
||||
tt: &setupTransport{id: id, phs: &protoHandshake{ID: id}},
|
||||
flags: inboundConn,
|
||||
wantCalls: "doEncHandshake,doProtoHandshake,close,",
|
||||
wantCloseErr: DiscUselessPeer,
|
||||
},
|
||||
}
|
||||
|
||||
for i, test := range tests {
|
||||
srv := &Server{
|
||||
PrivateKey: srvkey,
|
||||
MaxPeers: 10,
|
||||
NoDial: true,
|
||||
Protocols: []Protocol{discard},
|
||||
newTransport: func(fd net.Conn) transport { return test.tt },
|
||||
}
|
||||
if !test.dontstart {
|
||||
if err := srv.Start(); err != nil {
|
||||
t.Fatalf("couldn't start server: %v", err)
|
||||
}
|
||||
} else {
|
||||
// For all earlier connections, the handshake should go through.
|
||||
if err != nil {
|
||||
t.Fatalf("conn %d: unexpected error: %v", i, err)
|
||||
}
|
||||
// Wait for runPeer to be started.
|
||||
<-started
|
||||
}
|
||||
p1, _ := net.Pipe()
|
||||
srv.setupConn(p1, test.flags, test.dialDest)
|
||||
if !reflect.DeepEqual(test.tt.closeErr, test.wantCloseErr) {
|
||||
t.Errorf("test %d: close error mismatch: got %q, want %q", i, test.tt.closeErr, test.wantCloseErr)
|
||||
}
|
||||
if test.tt.calls != test.wantCalls {
|
||||
t.Errorf("test %d: calls mismatch: got %q, want %q", i, test.tt.calls, test.wantCalls)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Tests that static peers are (re)connected, and done so even above max peers.
|
||||
func TestServerStaticPeers(t *testing.T) {
|
||||
// Create a test server with limited connection slots
|
||||
started := make(chan *Peer)
|
||||
server := &Server{
|
||||
ListenAddr: "127.0.0.1:0",
|
||||
PrivateKey: newkey(),
|
||||
MaxPeers: 3,
|
||||
newPeerHook: func(p *Peer) { started <- p },
|
||||
staticCycle: time.Second,
|
||||
}
|
||||
if err := server.Start(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer server.Stop()
|
||||
type setupTransport struct {
|
||||
id discover.NodeID
|
||||
encHandshakeErr error
|
||||
|
||||
// Fill up all the slots on the server
|
||||
dialer := &net.Dialer{Deadline: time.Now().Add(3 * time.Second)}
|
||||
for i := 0; i < server.MaxPeers; i++ {
|
||||
// Establish a new connection
|
||||
conn, err := dialer.Dial("tcp", server.ListenAddr)
|
||||
if err != nil {
|
||||
t.Fatalf("conn %d: dial error: %v", i, err)
|
||||
}
|
||||
defer conn.Close()
|
||||
phs *protoHandshake
|
||||
protoHandshakeErr error
|
||||
|
||||
// Run the handshakes just like a real peer would, and wait for completion
|
||||
key := newkey()
|
||||
shake := &protoHandshake{Version: baseProtocolVersion, ID: discover.PubkeyID(&key.PublicKey)}
|
||||
if _, err = setupConn(conn, key, shake, server.Self(), keepalways); err != nil {
|
||||
t.Fatalf("conn %d: unexpected error: %v", i, err)
|
||||
}
|
||||
<-started
|
||||
}
|
||||
// Open a TCP listener to accept static connections
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to setup listener: %v", err)
|
||||
}
|
||||
defer listener.Close()
|
||||
|
||||
connected := make(chan net.Conn)
|
||||
go func() {
|
||||
for i := 0; i < 3; i++ {
|
||||
conn, err := listener.Accept()
|
||||
if err == nil {
|
||||
connected <- conn
|
||||
}
|
||||
}
|
||||
}()
|
||||
// Inject a static node and wait for a remote dial, then redial, then nothing
|
||||
addr := listener.Addr().(*net.TCPAddr)
|
||||
static := &discover.Node{
|
||||
ID: discover.PubkeyID(&newkey().PublicKey),
|
||||
IP: addr.IP,
|
||||
TCP: uint16(addr.Port),
|
||||
}
|
||||
server.AddPeer(static)
|
||||
|
||||
select {
|
||||
case conn := <-connected:
|
||||
// Close the first connection, expect redial
|
||||
conn.Close()
|
||||
|
||||
case <-time.After(2 * server.staticCycle):
|
||||
t.Fatalf("remote dial timeout")
|
||||
}
|
||||
|
||||
select {
|
||||
case conn := <-connected:
|
||||
// Keep the second connection, don't expect redial
|
||||
defer conn.Close()
|
||||
|
||||
case <-time.After(2 * server.staticCycle):
|
||||
t.Fatalf("remote re-dial timeout")
|
||||
}
|
||||
|
||||
select {
|
||||
case <-time.After(2 * server.staticCycle):
|
||||
// Timeout as no dial occurred
|
||||
|
||||
case <-connected:
|
||||
t.Fatalf("connected node dialed")
|
||||
}
|
||||
calls string
|
||||
closeErr error
|
||||
}
|
||||
|
||||
// Tests that trusted peers and can connect above max peer caps.
|
||||
func TestServerTrustedPeers(t *testing.T) {
|
||||
|
||||
// Create a trusted peer to accept connections from
|
||||
key := newkey()
|
||||
trusted := &discover.Node{
|
||||
ID: discover.PubkeyID(&key.PublicKey),
|
||||
}
|
||||
// Create a test server with limited connection slots
|
||||
started := make(chan *Peer)
|
||||
server := &Server{
|
||||
ListenAddr: "127.0.0.1:0",
|
||||
PrivateKey: newkey(),
|
||||
MaxPeers: 3,
|
||||
NoDial: true,
|
||||
TrustedNodes: []*discover.Node{trusted},
|
||||
newPeerHook: func(p *Peer) { started <- p },
|
||||
}
|
||||
if err := server.Start(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer server.Stop()
|
||||
|
||||
// Fill up all the slots on the server
|
||||
dialer := &net.Dialer{Deadline: time.Now().Add(3 * time.Second)}
|
||||
for i := 0; i < server.MaxPeers; i++ {
|
||||
// Establish a new connection
|
||||
conn, err := dialer.Dial("tcp", server.ListenAddr)
|
||||
if err != nil {
|
||||
t.Fatalf("conn %d: dial error: %v", i, err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
// Run the handshakes just like a real peer would, and wait for completion
|
||||
key := newkey()
|
||||
shake := &protoHandshake{Version: baseProtocolVersion, ID: discover.PubkeyID(&key.PublicKey)}
|
||||
if _, err = setupConn(conn, key, shake, server.Self(), keepalways); err != nil {
|
||||
t.Fatalf("conn %d: unexpected error: %v", i, err)
|
||||
}
|
||||
<-started
|
||||
}
|
||||
// Dial from the trusted peer, ensure connection is accepted
|
||||
conn, err := dialer.Dial("tcp", server.ListenAddr)
|
||||
if err != nil {
|
||||
t.Fatalf("trusted node: dial error: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
shake := &protoHandshake{Version: baseProtocolVersion, ID: trusted.ID}
|
||||
if _, err = setupConn(conn, key, shake, server.Self(), keepalways); err != nil {
|
||||
t.Fatalf("trusted node: unexpected error: %v", err)
|
||||
}
|
||||
select {
|
||||
case <-started:
|
||||
// Ok, trusted peer accepted
|
||||
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Fatalf("trusted node timeout")
|
||||
func (c *setupTransport) doEncHandshake(prv *ecdsa.PrivateKey, dialDest *discover.Node) (discover.NodeID, error) {
|
||||
c.calls += "doEncHandshake,"
|
||||
return c.id, c.encHandshakeErr
|
||||
}
|
||||
func (c *setupTransport) doProtoHandshake(our *protoHandshake) (*protoHandshake, error) {
|
||||
c.calls += "doProtoHandshake,"
|
||||
if c.protoHandshakeErr != nil {
|
||||
return nil, c.protoHandshakeErr
|
||||
}
|
||||
return c.phs, nil
|
||||
}
|
||||
func (c *setupTransport) close(err error) {
|
||||
c.calls += "close,"
|
||||
c.closeErr = err
|
||||
}
|
||||
|
||||
// Tests that a failed dial will temporarily throttle a peer.
|
||||
func TestServerMaxPendingDials(t *testing.T) {
|
||||
// Start a simple test server
|
||||
server := &Server{
|
||||
ListenAddr: "127.0.0.1:0",
|
||||
PrivateKey: newkey(),
|
||||
MaxPeers: 10,
|
||||
MaxPendingPeers: 1,
|
||||
}
|
||||
if err := server.Start(); err != nil {
|
||||
t.Fatal("failed to start test server: %v", err)
|
||||
}
|
||||
defer server.Stop()
|
||||
|
||||
// Simulate two separate remote peers
|
||||
peers := make(chan *discover.Node, 2)
|
||||
conns := make(chan net.Conn, 2)
|
||||
for i := 0; i < 2; i++ {
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("listener %d: failed to setup: %v", i, err)
|
||||
}
|
||||
defer listener.Close()
|
||||
|
||||
addr := listener.Addr().(*net.TCPAddr)
|
||||
peers <- &discover.Node{
|
||||
ID: discover.PubkeyID(&newkey().PublicKey),
|
||||
IP: addr.IP,
|
||||
TCP: uint16(addr.Port),
|
||||
}
|
||||
go func() {
|
||||
conn, err := listener.Accept()
|
||||
if err == nil {
|
||||
conns <- conn
|
||||
}
|
||||
}()
|
||||
}
|
||||
// Request a dial for both peers
|
||||
go func() {
|
||||
for i := 0; i < 2; i++ {
|
||||
server.staticDial <- <-peers // hack piggybacking the static implementation
|
||||
}
|
||||
}()
|
||||
|
||||
// Make sure only one outbound connection goes through
|
||||
var conn net.Conn
|
||||
|
||||
select {
|
||||
case conn = <-conns:
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Fatalf("first dial timeout")
|
||||
}
|
||||
select {
|
||||
case conn = <-conns:
|
||||
t.Fatalf("second dial completed prematurely")
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
}
|
||||
// Finish the first dial, check the second
|
||||
conn.Close()
|
||||
select {
|
||||
case conn = <-conns:
|
||||
conn.Close()
|
||||
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Fatalf("second dial timeout")
|
||||
}
|
||||
// setupConn shouldn't write to/read from the connection.
|
||||
func (c *setupTransport) WriteMsg(Msg) error {
|
||||
panic("WriteMsg called on setupTransport")
|
||||
}
|
||||
|
||||
func TestServerMaxPendingAccepts(t *testing.T) {
|
||||
// Start a test server and a peer sink for synchronization
|
||||
started := make(chan *Peer)
|
||||
server := &Server{
|
||||
ListenAddr: "127.0.0.1:0",
|
||||
PrivateKey: newkey(),
|
||||
MaxPeers: 10,
|
||||
MaxPendingPeers: 1,
|
||||
NoDial: true,
|
||||
newPeerHook: func(p *Peer) { started <- p },
|
||||
}
|
||||
if err := server.Start(); err != nil {
|
||||
t.Fatal("failed to start test server: %v", err)
|
||||
}
|
||||
defer server.Stop()
|
||||
|
||||
// Try and connect to the server on multiple threads concurrently
|
||||
conns := make([]net.Conn, 2)
|
||||
for i := 0; i < 2; i++ {
|
||||
dialer := &net.Dialer{Deadline: time.Now().Add(3 * time.Second)}
|
||||
|
||||
conn, err := dialer.Dial("tcp", server.ListenAddr)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to dial server: %v", err)
|
||||
}
|
||||
conns[i] = conn
|
||||
}
|
||||
// Check that a handshake on the second doesn't pass
|
||||
go func() {
|
||||
key := newkey()
|
||||
shake := &protoHandshake{Version: baseProtocolVersion, ID: discover.PubkeyID(&key.PublicKey)}
|
||||
if _, err := setupConn(conns[1], key, shake, server.Self(), keepalways); err != nil {
|
||||
t.Fatalf("failed to run handshake: %v", err)
|
||||
}
|
||||
}()
|
||||
select {
|
||||
case <-started:
|
||||
t.Fatalf("handshake on second connection accepted")
|
||||
|
||||
case <-time.After(time.Second):
|
||||
}
|
||||
// Shake on first, check that both go through
|
||||
go func() {
|
||||
key := newkey()
|
||||
shake := &protoHandshake{Version: baseProtocolVersion, ID: discover.PubkeyID(&key.PublicKey)}
|
||||
if _, err := setupConn(conns[0], key, shake, server.Self(), keepalways); err != nil {
|
||||
t.Fatalf("failed to run handshake: %v", err)
|
||||
}
|
||||
}()
|
||||
for i := 0; i < 2; i++ {
|
||||
select {
|
||||
case <-started:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatalf("peer %d: handshake timeout", i)
|
||||
}
|
||||
}
|
||||
func (c *setupTransport) ReadMsg() (Msg, error) {
|
||||
panic("ReadMsg called on setupTransport")
|
||||
}
|
||||
|
||||
func newkey() *ecdsa.PrivateKey {
|
||||
|
@ -501,7 +419,3 @@ func randomID() (id discover.NodeID) {
|
|||
}
|
||||
return id
|
||||
}
|
||||
|
||||
func keepalways(id discover.NodeID) bool {
|
||||
return true
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue