diff --git a/rpc/client.go b/rpc/client.go index e43760c22c..d55af75545 100644 --- a/rpc/client.go +++ b/rpc/client.go @@ -58,12 +58,6 @@ const ( maxClientSubscriptionBuffer = 20000 ) -const ( - httpScheme = "http" - wsScheme = "ws" - ipcScheme = "ipc" -) - // BatchElem is an element in a batch request. type BatchElem struct { Method string @@ -80,7 +74,7 @@ type BatchElem struct { // Client represents a connection to an RPC server. type Client struct { idgen func() ID // for subscriptions - scheme string // connection type: http, ws or ipc + isHTTP bool // connection type: http, ws or ipc services *serviceRegistry idCounter uint32 @@ -115,11 +109,9 @@ type clientConn struct { } func (c *Client) newClientConn(conn ServerCodec) *clientConn { - ctx := context.WithValue(context.Background(), clientContextKey{}, c) - // Http connections have already set the scheme - if !c.isHTTP() && c.scheme != "" { - ctx = context.WithValue(ctx, "scheme", c.scheme) - } + ctx := context.Background() + ctx = context.WithValue(ctx, clientContextKey{}, c) + ctx = context.WithValue(ctx, peerInfoContextKey{}, conn.peerInfo()) handler := newHandler(ctx, conn, c.idgen, c.services) return &clientConn{conn, handler} } @@ -145,7 +137,7 @@ func (op *requestOp) wait(ctx context.Context, c *Client) (*jsonrpcMessage, erro select { case <-ctx.Done(): // Send the timeout to dispatch so it can remove the request IDs. - if !c.isHTTP() { + if !c.isHTTP { select { case c.reqTimeout <- op: case <-c.closing: @@ -212,18 +204,10 @@ func newClient(initctx context.Context, connect reconnectFunc) (*Client, error) } func initClient(conn ServerCodec, idgen func() ID, services *serviceRegistry) *Client { - scheme := "" - switch conn.(type) { - case *httpConn: - scheme = httpScheme - case *websocketCodec: - scheme = wsScheme - case *jsonCodec: - scheme = ipcScheme - } + _, isHTTP := conn.(*httpConn) c := &Client{ + isHTTP: isHTTP, idgen: idgen, - scheme: scheme, services: services, writeConn: conn, close: make(chan struct{}), @@ -236,7 +220,7 @@ func initClient(conn ServerCodec, idgen func() ID, services *serviceRegistry) *C reqSent: make(chan error, 1), reqTimeout: make(chan *requestOp), } - if !c.isHTTP() { + if !isHTTP { go c.dispatch(conn) } return c @@ -267,7 +251,7 @@ func (c *Client) SupportedModules() (map[string]string, error) { // Close closes the client, aborting any in-flight requests. func (c *Client) Close() { - if c.isHTTP() { + if c.isHTTP { return } select { @@ -281,7 +265,7 @@ func (c *Client) Close() { // This method only works for clients using HTTP, it doesn't have // any effect for clients using another transport. func (c *Client) SetHeader(key, value string) { - if !c.isHTTP() { + if !c.isHTTP { return } conn := c.writeConn.(*httpConn) @@ -315,7 +299,7 @@ func (c *Client) CallContext(ctx context.Context, result interface{}, method str } op := &requestOp{ids: []json.RawMessage{msg.ID}, resp: make(chan *jsonrpcMessage, 1)} - if c.isHTTP() { + if c.isHTTP { err = c.sendHTTP(ctx, op, msg) } else { err = c.send(ctx, op, msg) @@ -378,7 +362,7 @@ func (c *Client) BatchCallContext(ctx context.Context, b []BatchElem) error { } var err error - if c.isHTTP() { + if c.isHTTP { err = c.sendBatchHTTP(ctx, op, msgs) } else { err = c.send(ctx, op, msgs) @@ -417,7 +401,7 @@ func (c *Client) Notify(ctx context.Context, method string, args ...interface{}) } msg.ID = nil - if c.isHTTP() { + if c.isHTTP { return c.sendHTTP(ctx, op, msg) } return c.send(ctx, op, msg) @@ -450,12 +434,12 @@ func (c *Client) Subscribe(ctx context.Context, namespace string, channel interf // Check type of channel first. chanVal := reflect.ValueOf(channel) if chanVal.Kind() != reflect.Chan || chanVal.Type().ChanDir()&reflect.SendDir == 0 { - panic("first argument to Subscribe must be a writable channel") + panic(fmt.Sprintf("channel argument of Subscribe has type %T, need writable channel", channel)) } if chanVal.IsNil() { panic("channel given to Subscribe must not be nil") } - if c.isHTTP() { + if c.isHTTP { return nil, ErrNotificationsUnsupported } @@ -509,8 +493,8 @@ func (c *Client) send(ctx context.Context, op *requestOp, msg interface{}) error } func (c *Client) write(ctx context.Context, msg interface{}, retry bool) error { - // The previous write failed. Try to establish a new connection. if c.writeConn == nil { + // The previous write failed. Try to establish a new connection. if err := c.reconnect(ctx); err != nil { return err } @@ -657,7 +641,3 @@ func (c *Client) read(codec ServerCodec) { c.readOp <- readOp{msgs, batch} } } - -func (c *Client) isHTTP() bool { - return c.scheme == httpScheme -} diff --git a/rpc/http.go b/rpc/http.go index 32f4e7d90a..9c5a5cc0f2 100644 --- a/rpc/http.go +++ b/rpc/http.go @@ -48,11 +48,18 @@ type httpConn struct { headers http.Header } -// httpConn is treated specially by Client. +// httpConn implements ServerCodec, but it is treated specially by Client +// and some methods don't work. The panic() stubs here exist to ensure +// this special treatment is correct. + func (hc *httpConn) writeJSON(context.Context, interface{}) error { panic("writeJSON called on httpConn") } +func (hc *httpConn) peerInfo() PeerInfo { + panic("peerInfo called on httpConn") +} + func (hc *httpConn) remoteAddr() string { return hc.url } @@ -236,20 +243,19 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { http.Error(w, err.Error(), code) return } + + // Create request-scoped context. + connInfo := PeerInfo{Transport: "http", RemoteAddr: r.RemoteAddr} + connInfo.HTTP.Version = r.Proto + connInfo.HTTP.Host = r.Host + connInfo.HTTP.Origin = r.Header.Get("Origin") + connInfo.HTTP.UserAgent = r.Header.Get("User-Agent") + ctx := r.Context() + ctx = context.WithValue(ctx, peerInfoContextKey{}, connInfo) + // All checks passed, create a codec that reads directly from the request body // until EOF, writes the response to w, and orders the server to process a // single request. - ctx := r.Context() - ctx = context.WithValue(ctx, "remote", r.RemoteAddr) - ctx = context.WithValue(ctx, "scheme", r.Proto) - ctx = context.WithValue(ctx, "local", r.Host) - if ua := r.Header.Get("User-Agent"); ua != "" { - ctx = context.WithValue(ctx, "User-Agent", ua) - } - if origin := r.Header.Get("Origin"); origin != "" { - ctx = context.WithValue(ctx, "Origin", origin) - } - w.Header().Set("content-type", contentType) codec := newHTTPServerConn(r, w) defer codec.close() diff --git a/rpc/http_test.go b/rpc/http_test.go index 97f8d44c39..c84d7705f2 100644 --- a/rpc/http_test.go +++ b/rpc/http_test.go @@ -162,3 +162,39 @@ func TestHTTPErrorResponse(t *testing.T) { t.Error("unexpected error message", errMsg) } } + +func TestHTTPPeerInfo(t *testing.T) { + s := newTestServer() + defer s.Stop() + ts := httptest.NewServer(s) + defer ts.Close() + + c, err := Dial(ts.URL) + if err != nil { + t.Fatal(err) + } + c.SetHeader("user-agent", "ua-testing") + c.SetHeader("origin", "origin.example.com") + + // Request peer information. + var info PeerInfo + if err := c.Call(&info, "test_peerInfo"); err != nil { + t.Fatal(err) + } + + if info.RemoteAddr == "" { + t.Error("RemoteAddr not set") + } + if info.Transport != "http" { + t.Errorf("wrong Transport %q", info.Transport) + } + if info.HTTP.Version != "HTTP/1.1" { + t.Errorf("wrong HTTP.Version %q", info.HTTP.Version) + } + if info.HTTP.UserAgent != "ua-testing" { + t.Errorf("wrong HTTP.UserAgent %q", info.HTTP.UserAgent) + } + if info.HTTP.Origin != "origin.example.com" { + t.Errorf("wrong HTTP.Origin %q", info.HTTP.UserAgent) + } +} diff --git a/rpc/json.go b/rpc/json.go index 1daee3db82..6024f1e7dc 100644 --- a/rpc/json.go +++ b/rpc/json.go @@ -198,6 +198,11 @@ func NewCodec(conn Conn) ServerCodec { return NewFuncCodec(conn, enc.Encode, dec.Decode) } +func (c *jsonCodec) peerInfo() PeerInfo { + // This returns "ipc" because all other built-in transports have a separate codec type. + return PeerInfo{Transport: "ipc", RemoteAddr: c.remote} +} + func (c *jsonCodec) remoteAddr() string { return c.remote } diff --git a/rpc/server.go b/rpc/server.go index 64e078a7fd..e2d5c03835 100644 --- a/rpc/server.go +++ b/rpc/server.go @@ -145,3 +145,38 @@ func (s *RPCService) Modules() map[string]string { } return modules } + +// PeerInfo contains information about the remote end of the network connection. +// +// This is available within RPC method handlers through the context. Call +// PeerInfoFromContext to get information about the client connection related to +// the current method call. +type PeerInfo struct { + // Transport is name of the protocol used by the client. + // This can be "http", "ws" or "ipc". + Transport string + + // Address of client. This will usually contain the IP address and port. + RemoteAddr string + + // Addditional information for HTTP and WebSocket connections. + HTTP struct { + // Protocol version, i.e. "HTTP/1.1". This is not set for WebSocket. + Version string + // Header values sent by the client. + UserAgent string + Origin string + Host string + } +} + +type peerInfoContextKey struct{} + +// PeerInfoFromContext returns information about the client's network connection. +// Use this with the context passed to RPC method handler functions. +// +// The zero value is returned if no connection info is present in ctx. +func PeerInfoFromContext(ctx context.Context) PeerInfo { + info, _ := ctx.Value(peerInfoContextKey{}).(PeerInfo) + return info +} diff --git a/rpc/server_test.go b/rpc/server_test.go index 6a2b09e449..c692a071cf 100644 --- a/rpc/server_test.go +++ b/rpc/server_test.go @@ -45,7 +45,7 @@ func TestServerRegisterName(t *testing.T) { t.Fatalf("Expected service calc to be registered") } - wantCallbacks := 9 + wantCallbacks := 10 if len(svc.callbacks) != wantCallbacks { t.Errorf("Expected %d callbacks for service 'service', got %d", wantCallbacks, len(svc.callbacks)) } diff --git a/rpc/testservice_test.go b/rpc/testservice_test.go index 62afc1df44..253e263289 100644 --- a/rpc/testservice_test.go +++ b/rpc/testservice_test.go @@ -80,6 +80,10 @@ func (s *testService) EchoWithCtx(ctx context.Context, str string, i int, args * return echoResult{str, i, args} } +func (s *testService) PeerInfo(ctx context.Context) PeerInfo { + return PeerInfoFromContext(ctx) +} + func (s *testService) Sleep(ctx context.Context, duration time.Duration) { time.Sleep(duration) } diff --git a/rpc/types.go b/rpc/types.go index ca52d474d9..959e383723 100644 --- a/rpc/types.go +++ b/rpc/types.go @@ -40,8 +40,10 @@ type API struct { // a RPC session. Implementations must be go-routine safe since the codec can be called in // multiple go-routines concurrently. type ServerCodec interface { + peerInfo() PeerInfo readBatch() (msgs []*jsonrpcMessage, isBatch bool, err error) close() + jsonWriter } diff --git a/rpc/websocket.go b/rpc/websocket.go index 5571324af8..28380d8aa4 100644 --- a/rpc/websocket.go +++ b/rpc/websocket.go @@ -60,7 +60,7 @@ func (s *Server) WebsocketHandler(allowedOrigins []string) http.Handler { log.Debug("WebSocket upgrade failed", "err", err) return } - codec := newWebsocketCodec(conn) + codec := newWebsocketCodec(conn, r.Host, r.Header) s.ServeCodec(codec, 0) }) } @@ -197,7 +197,7 @@ func DialWebsocketWithDialer(ctx context.Context, endpoint, origin string, diale } return nil, hErr } - return newWebsocketCodec(conn), nil + return newWebsocketCodec(conn, endpoint, header), nil }) } @@ -235,12 +235,13 @@ func wsClientHeaders(endpoint, origin string) (string, http.Header, error) { type websocketCodec struct { *jsonCodec conn *websocket.Conn + info PeerInfo wg sync.WaitGroup pingReset chan struct{} } -func newWebsocketCodec(conn *websocket.Conn) ServerCodec { +func newWebsocketCodec(conn *websocket.Conn, host string, req http.Header) ServerCodec { conn.SetReadLimit(wsMessageSizeLimit) conn.SetPongHandler(func(appData string) error { conn.SetReadDeadline(time.Time{}) @@ -250,7 +251,16 @@ func newWebsocketCodec(conn *websocket.Conn) ServerCodec { jsonCodec: NewFuncCodec(conn, conn.WriteJSON, conn.ReadJSON).(*jsonCodec), conn: conn, pingReset: make(chan struct{}, 1), + info: PeerInfo{ + Transport: "ws", + RemoteAddr: conn.RemoteAddr().String(), + }, } + // Fill in connection details. + wc.info.HTTP.Host = host + wc.info.HTTP.Origin = req.Get("Origin") + wc.info.HTTP.UserAgent = req.Get("User-Agent") + // Start pinger. wc.wg.Add(1) go wc.pingLoop() return wc @@ -261,6 +271,10 @@ func (wc *websocketCodec) close() { wc.wg.Wait() } +func (wc *websocketCodec) peerInfo() PeerInfo { + return wc.info +} + func (wc *websocketCodec) writeJSON(ctx context.Context, v interface{}) error { err := wc.jsonCodec.writeJSON(ctx, v) if err == nil { diff --git a/rpc/websocket_test.go b/rpc/websocket_test.go index cf83b621f1..8659f798e4 100644 --- a/rpc/websocket_test.go +++ b/rpc/websocket_test.go @@ -117,6 +117,41 @@ func TestWebsocketLargeCall(t *testing.T) { } } +func TestWebsocketPeerInfo(t *testing.T) { + var ( + s = newTestServer() + ts = httptest.NewServer(s.WebsocketHandler([]string{"origin.example.com"})) + tsurl = "ws:" + strings.TrimPrefix(ts.URL, "http:") + ) + defer s.Stop() + defer ts.Close() + + ctx := context.Background() + c, err := DialWebsocket(ctx, tsurl, "origin.example.com") + if err != nil { + t.Fatal(err) + } + + // Request peer information. + var connInfo PeerInfo + if err := c.Call(&connInfo, "test_peerInfo"); err != nil { + t.Fatal(err) + } + + if connInfo.RemoteAddr == "" { + t.Error("RemoteAddr not set") + } + if connInfo.Transport != "ws" { + t.Errorf("wrong Transport %q", connInfo.Transport) + } + if connInfo.HTTP.UserAgent != "Go-http-client/1.1" { + t.Errorf("wrong HTTP.UserAgent %q", connInfo.HTTP.UserAgent) + } + if connInfo.HTTP.Origin != "origin.example.com" { + t.Errorf("wrong HTTP.Origin %q", connInfo.HTTP.UserAgent) + } +} + // This test checks that client handles WebSocket ping frames correctly. func TestClientWebsocketPing(t *testing.T) { t.Parallel() diff --git a/signer/core/api.go b/signer/core/api.go index 48b54b8f43..f06fbeb76d 100644 --- a/signer/core/api.go +++ b/signer/core/api.go @@ -33,6 +33,7 @@ import ( "github.com/ethereum/go-ethereum/common/hexutil" "github.com/ethereum/go-ethereum/internal/ethapi" "github.com/ethereum/go-ethereum/log" + "github.com/ethereum/go-ethereum/rpc" "github.com/ethereum/go-ethereum/signer/core/apitypes" "github.com/ethereum/go-ethereum/signer/storage" ) @@ -188,23 +189,24 @@ func StartClefAccountManager(ksLocation string, nousb, lightKDF bool, scpath str // MetadataFromContext extracts Metadata from a given context.Context func MetadataFromContext(ctx context.Context) Metadata { + info := rpc.PeerInfoFromContext(ctx) + m := Metadata{"NA", "NA", "NA", "", ""} // batman - if v := ctx.Value("remote"); v != nil { - m.Remote = v.(string) + if info.Transport != "" { + if info.Transport == "http" { + m.Scheme = info.HTTP.Version + } + m.Scheme = info.Transport } - if v := ctx.Value("scheme"); v != nil { - m.Scheme = v.(string) + if info.RemoteAddr != "" { + m.Remote = info.RemoteAddr } - if v := ctx.Value("local"); v != nil { - m.Local = v.(string) - } - if v := ctx.Value("Origin"); v != nil { - m.Origin = v.(string) - } - if v := ctx.Value("User-Agent"); v != nil { - m.UserAgent = v.(string) + if info.HTTP.Host != "" { + m.Local = info.HTTP.Host } + m.Origin = info.HTTP.Origin + m.UserAgent = info.HTTP.UserAgent return m }