go-ethereum/rpc/websocket.go

232 lines
6.9 KiB
Go

// Copyright 2015 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package rpc
import (
"bytes"
"context"
"crypto/tls"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"net"
"net/http"
"net/url"
"os"
"strings"
"time"
mapset "github.com/deckarep/golang-set"
"github.com/ethereum/go-ethereum/log"
"golang.org/x/net/websocket"
)
// websocketJSONCodec is the JSON codec used for WebSocket RPC messages. Values
// are sent as JSON text frames, and received frames are decoded with
// json.Number enabled so numbers keep their original textual form.
var websocketJSONCodec = websocket.Codec{
	// Marshal encodes v with the standard JSON marshaller and labels the
	// result as a text frame.
	Marshal: func(v interface{}) ([]byte, byte, error) {
		encoded, err := json.Marshal(v)
		return encoded, websocket.TextFrame, err
	},
	// Unmarshal decodes msg into v. The frame's payload type is ignored;
	// numbers are decoded as json.Number instead of float64.
	Unmarshal: func(msg []byte, _ byte, v interface{}) error {
		decoder := json.NewDecoder(bytes.NewReader(msg))
		decoder.UseNumber()
		return decoder.Decode(v)
	},
}
// WebsocketHandler returns an http.Handler that serves JSON-RPC over WebSocket
// connections, with both method invocation and subscriptions enabled.
//
// allowedOrigins is the list of origins that are allowed to connect; it is
// passed to wsHandshakeValidator, which checks the Origin header of each
// upgrade request. Include "*" in the list to allow connections from any
// origin.
func (s *Server) WebsocketHandler(allowedOrigins []string) http.Handler {
	serve := func(conn *websocket.Conn) {
		s.ServeCodec(newWebsocketCodec(conn), OptionMethodInvocation|OptionSubscriptions)
	}
	return websocket.Server{
		Handshake: wsHandshakeValidator(allowedOrigins),
		Handler:   serve,
	}
}
// newWebsocketCodec wraps conn in a ServerCodec whose messages are sent and
// received through websocketJSONCodec. It sets conn.MaxPayloadBytes to
// maxRequestContentLength to enforce the payload size.
func newWebsocketCodec(conn *websocket.Conn) ServerCodec {
	conn.MaxPayloadBytes = maxRequestContentLength

	rpcconn := Conn(conn)
	if conn.IsServerConn() {
		// Report the actual socket address of the HTTP request as the remote
		// address, because package websocket crashes if there is no request
		// origin.
		remote := conn.Request().RemoteAddr
		// Append the origin URL when the connection carries one.
		if origin := conn.RemoteAddr().(*websocket.Addr).URL; origin != nil {
			remote += "(" + origin.String() + ")"
		}
		rpcconn = connWithRemoteAddr{conn, remote}
	}

	// The encode/decode pair routes every message through the custom codec so
	// that its number handling applies.
	send := func(v interface{}) error {
		return websocketJSONCodec.Send(conn, v)
	}
	receive := func(v interface{}) error {
		return websocketJSONCodec.Receive(conn, v)
	}
	return NewCodec(rpcconn, send, receive)
}
// NewWSServer returns an http.Server whose handler is srv's websocket RPC
// handler restricted to allowedOrigins.
//
// Deprecated: use Server.WebsocketHandler
func NewWSServer(allowedOrigins []string, srv *Server) *http.Server {
	handler := srv.WebsocketHandler(allowedOrigins)
	return &http.Server{Handler: handler}
}
// wsHandshakeValidator builds the callback that verifies the request origin
// during the websocket upgrade. The callback lowercases the Origin header and
// accepts the request only if that value is among the lowercased, non-empty
// entries of allowedOrigins. When allowedOrigins contains "*", every
// connection is accepted. When it has no non-empty entries, http://localhost
// and http://<local hostname> are allowed instead.
func wsHandshakeValidator(allowedOrigins []string) func(*websocket.Config, *http.Request) error {
	var (
		allowed  = mapset.NewSet()
		allowAny bool
	)
	for _, entry := range allowedOrigins {
		if entry == "" {
			continue
		}
		if entry == "*" {
			allowAny = true
		}
		allowed.Add(strings.ToLower(entry))
	}
	// Nothing configured: fall back to localhost and, if it can be
	// determined, the machine's own hostname.
	if len(allowed.ToSlice()) == 0 {
		allowed.Add("http://localhost")
		if host, err := os.Hostname(); err == nil {
			allowed.Add("http://" + strings.ToLower(host))
		}
	}
	log.Debug(fmt.Sprintf("Allowed origin(s) for WS RPC interface %v", allowed.ToSlice()))

	return func(cfg *websocket.Config, req *http.Request) error {
		// Check the request origin against the whitelist.
		origin := strings.ToLower(req.Header.Get("Origin"))
		if !allowAny && !allowed.Contains(origin) {
			log.Warn("Rejected WebSocket connection", "origin", origin)
			return errors.New("origin not allowed")
		}
		return nil
	}
}
// wsGetConfig builds the websocket client configuration for endpoint.
//
// When origin is empty it is derived from the local hostname, prefixed with
// https:// for endpoints starting with "wss" and http:// otherwise. An error
// is returned if the hostname cannot be determined or if websocket.NewConfig
// rejects the endpoint or origin.
//
// Credentials embedded in the endpoint URL are removed from the URL and sent
// as an HTTP Basic "Authorization" header instead.
func wsGetConfig(endpoint, origin string) (*websocket.Config, error) {
	if origin == "" {
		hostname, err := os.Hostname()
		if err != nil {
			return nil, err
		}
		scheme := "http://"
		if strings.HasPrefix(endpoint, "wss") {
			scheme = "https://"
		}
		origin = scheme + strings.ToLower(hostname)
	}
	config, err := websocket.NewConfig(endpoint, origin)
	if err != nil {
		return nil, err
	}
	if user := config.Location.User; user != nil {
		// Basic credentials are the raw "username:password" pair.
		// Userinfo.String returns the percent-encoded form, which would put
		// "p%40ss" on the wire for the password "p@ss", so the pair is built
		// from the decoded parts. The colon is always present, even when the
		// URL carries no password.
		password, _ := user.Password()
		credentials := user.Username() + ":" + password
		b64auth := base64.StdEncoding.EncodeToString([]byte(credentials))
		config.Header.Add("Authorization", "Basic "+b64auth)
		config.Location.User = nil
	}
	return config, nil
}
// DialWebsocket creates a new RPC client that talks to the JSON-RPC server
// listening on the given websocket endpoint. origin is forwarded to
// wsGetConfig, which derives a default from the local hostname when it is
// empty.
//
// The context is used for the initial connection establishment. It does not
// affect subsequent interactions with the client.
func DialWebsocket(ctx context.Context, endpoint, origin string) (*Client, error) {
	config, err := wsGetConfig(endpoint, origin)
	if err != nil {
		return nil, err
	}
	// connect dials the endpoint and wraps the resulting connection in a
	// codec; it is handed to newClient as the connection factory.
	connect := func(ctx context.Context) (ServerCodec, error) {
		conn, err := wsDialContext(ctx, config)
		if err != nil {
			return nil, err
		}
		return newWebsocketCodec(conn), nil
	}
	return newClient(ctx, connect)
}
// wsDialContext opens the transport connection for config.Location and then
// performs the websocket client handshake on it. "ws" locations are dialed
// over plain TCP, "wss" locations over TLS using config.TlsConfig; any other
// scheme yields websocket.ErrBadScheme. If the handshake fails, the transport
// connection is closed before the error is returned.
func wsDialContext(ctx context.Context, config *websocket.Config) (*websocket.Conn, error) {
	var (
		transport net.Conn
		err       error
	)
	switch config.Location.Scheme {
	case "ws":
		transport, err = dialContext(ctx, "tcp", wsDialAddress(config.Location))
	case "wss":
		transport, err = tls.DialWithDialer(contextDialer(ctx), "tcp", wsDialAddress(config.Location), config.TlsConfig)
	default:
		return nil, websocket.ErrBadScheme
	}
	if err != nil {
		return nil, err
	}

	client, err := websocket.NewClient(config, transport)
	if err != nil {
		transport.Close()
		return nil, err
	}
	return client, nil
}
// wsPortMap maps the websocket URL schemes to their default TCP ports.
var wsPortMap = map[string]string{"ws": "80", "wss": "443"}

// wsDialAddress returns the network address to dial for location.
//
// If the scheme is "ws" or "wss" and location.Host does not split into a host
// and a port, the default port of the scheme is appended. In every other case
// location.Host is returned unchanged.
func wsDialAddress(location *url.URL) string {
	defaultPort, known := wsPortMap[location.Scheme]
	if !known {
		return location.Host
	}
	if _, _, err := net.SplitHostPort(location.Host); err == nil {
		return location.Host
	}
	host := location.Host
	// An IPv6 literal without a port arrives in brackets ("[::1]").
	// JoinHostPort adds its own brackets around any host containing a colon,
	// so the existing pair is removed first; otherwise the result would be
	// the undialable "[[::1]]:80".
	if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") {
		host = host[1 : len(host)-1]
	}
	return net.JoinHostPort(host, defaultPort)
}
// dialContext connects to addr on the given network using a dialer with
// keep-alive set to tcpKeepAliveInterval. The dial is bound to ctx.
func dialContext(ctx context.Context, network, addr string) (net.Conn, error) {
	dialer := net.Dialer{KeepAlive: tcpKeepAliveInterval}
	return dialer.DialContext(ctx, network, addr)
}
func contextDialer(ctx context.Context) *net.Dialer {
dialer := &net.Dialer{Cancel: ctx.Done(), KeepAlive: tcpKeepAliveInterval}
if deadline, ok := ctx.Deadline(); ok {
dialer.Deadline = deadline
} else {
dialer.Deadline = time.Now().Add(defaultDialTimeout)
}
return dialer
}