301 lines
8.3 KiB
Go
301 lines
8.3 KiB
Go
// Copyright 2015 The go-ethereum Authors
|
|
// This file is part of the go-ethereum library.
|
|
//
|
|
// The go-ethereum library is free software: you can redistribute it and/or modify
|
|
// it under the terms of the GNU Lesser General Public License as published by
|
|
// the Free Software Foundation, either version 3 of the License, or
|
|
// (at your option) any later version.
|
|
//
|
|
// The go-ethereum library is distributed in the hope that it will be useful,
|
|
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
// GNU Lesser General Public License for more details.
|
|
//
|
|
// You should have received a copy of the GNU Lesser General Public License
|
|
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
|
|
|
|
package rpc
|
|
|
|
import (
|
|
"context"
|
|
"encoding/base64"
|
|
"fmt"
|
|
"net/http"
|
|
"net/url"
|
|
"os"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
mapset "github.com/deckarep/golang-set"
|
|
"github.com/ethereum/go-ethereum/log"
|
|
"github.com/gorilla/websocket"
|
|
)
|
|
|
|
// Tunables shared by the server and client sides of the WebSocket transport.
const (
	// Buffer sizes handed to the upgrader and the dialer as
	// ReadBufferSize / WriteBufferSize.
	wsReadBuffer  = 1024
	wsWriteBuffer = 1024

	// Idle period after which pingLoop sends a ping frame.
	wsPingInterval = 60 * time.Second
	// Write deadline applied right before a ping frame is written.
	wsPingWriteTimeout = 5 * time.Second
	// Read deadline installed after a ping; the pong handler clears it again.
	wsPongTimeout = 30 * time.Second

	// Read limit installed on every connection via SetReadLimit.
	wsMessageSizeLimit = 15 * 1024 * 1024
)

// wsBufferPool is passed to the upgrader and the dialer as their WriteBufferPool.
var wsBufferPool = &sync.Pool{}
|
|
|
|
// WebsocketHandler returns a handler that serves JSON-RPC to WebSocket connections.
|
|
//
|
|
// allowedOrigins should be a comma-separated list of allowed origin URLs.
|
|
// To allow connections with any origin, pass "*".
|
|
func (s *Server) WebsocketHandler(allowedOrigins []string) http.Handler {
|
|
var upgrader = websocket.Upgrader{
|
|
ReadBufferSize: wsReadBuffer,
|
|
WriteBufferSize: wsWriteBuffer,
|
|
WriteBufferPool: wsBufferPool,
|
|
CheckOrigin: wsHandshakeValidator(allowedOrigins),
|
|
}
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
conn, err := upgrader.Upgrade(w, r, nil)
|
|
if err != nil {
|
|
log.Debug("WebSocket upgrade failed", "err", err)
|
|
return
|
|
}
|
|
codec := newWebsocketCodec(conn)
|
|
s.ServeCodec(codec, 0)
|
|
})
|
|
}
|
|
|
|
// wsHandshakeValidator returns a handler that verifies the origin during the
|
|
// websocket upgrade process. When a '*' is specified as an allowed origins all
|
|
// connections are accepted.
|
|
func wsHandshakeValidator(allowedOrigins []string) func(*http.Request) bool {
|
|
origins := mapset.NewSet()
|
|
allowAllOrigins := false
|
|
|
|
for _, origin := range allowedOrigins {
|
|
if origin == "*" {
|
|
allowAllOrigins = true
|
|
}
|
|
if origin != "" {
|
|
origins.Add(origin)
|
|
}
|
|
}
|
|
// allow localhost if no allowedOrigins are specified.
|
|
if len(origins.ToSlice()) == 0 {
|
|
origins.Add("http://localhost")
|
|
if hostname, err := os.Hostname(); err == nil {
|
|
origins.Add("http://" + hostname)
|
|
}
|
|
}
|
|
log.Debug(fmt.Sprintf("Allowed origin(s) for WS RPC interface %v", origins.ToSlice()))
|
|
|
|
f := func(req *http.Request) bool {
|
|
// Skip origin verification if no Origin header is present. The origin check
|
|
// is supposed to protect against browser based attacks. Browsers always set
|
|
// Origin. Non-browser software can put anything in origin and checking it doesn't
|
|
// provide additional security.
|
|
if _, ok := req.Header["Origin"]; !ok {
|
|
return true
|
|
}
|
|
// Verify origin against allow list.
|
|
origin := strings.ToLower(req.Header.Get("Origin"))
|
|
if allowAllOrigins || originIsAllowed(origins, origin) {
|
|
return true
|
|
}
|
|
log.Warn("Rejected WebSocket connection", "origin", origin)
|
|
return false
|
|
}
|
|
|
|
return f
|
|
}
|
|
|
|
// wsHandshakeError wraps an error from a failed WebSocket dial together with
// the HTTP status line of the server's response, when one was received.
type wsHandshakeError struct {
	err    error  // underlying dial error
	status string // HTTP status of the handshake response; empty if unknown
}

// Error returns the underlying error message, followed by the HTTP status in
// parentheses when a status is set.
func (e wsHandshakeError) Error() string {
	if e.status == "" {
		return e.err.Error()
	}
	return e.err.Error() + " (HTTP status " + e.status + ")"
}
|
|
|
|
func originIsAllowed(allowedOrigins mapset.Set, browserOrigin string) bool {
|
|
it := allowedOrigins.Iterator()
|
|
for origin := range it.C {
|
|
if ruleAllowsOrigin(origin.(string), browserOrigin) {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
func ruleAllowsOrigin(allowedOrigin string, browserOrigin string) bool {
|
|
var (
|
|
allowedScheme, allowedHostname, allowedPort string
|
|
browserScheme, browserHostname, browserPort string
|
|
err error
|
|
)
|
|
allowedScheme, allowedHostname, allowedPort, err = parseOriginURL(allowedOrigin)
|
|
if err != nil {
|
|
log.Warn("Error parsing allowed origin specification", "spec", allowedOrigin, "error", err)
|
|
return false
|
|
}
|
|
browserScheme, browserHostname, browserPort, err = parseOriginURL(browserOrigin)
|
|
if err != nil {
|
|
log.Warn("Error parsing browser 'Origin' field", "Origin", browserOrigin, "error", err)
|
|
return false
|
|
}
|
|
if allowedScheme != "" && allowedScheme != browserScheme {
|
|
return false
|
|
}
|
|
if allowedHostname != "" && allowedHostname != browserHostname {
|
|
return false
|
|
}
|
|
if allowedPort != "" && allowedPort != browserPort {
|
|
return false
|
|
}
|
|
return true
|
|
}
|
|
|
|
// parseOriginURL splits an origin specification into its scheme, hostname and
// port. The input is lowercased first, so all returned components are
// lowercase.
//
// Two forms are handled:
//   - "scheme://host[:port]": the components are taken from the parsed URL.
//   - "host[:port]" (no "://"): the scheme is returned empty. url.Parse reads
//     such input as "scheme:opaque", so the hostname is taken from the parsed
//     scheme and the port from the opaque part. A bare "host" parses with an
//     empty scheme, in which case the whole (lowercased) input is the hostname.
//
// An error is returned only when url.Parse rejects the input.
func parseOriginURL(origin string) (string, string, string, error) {
	// Lowercase once and use the same string everywhere, including the bare
	// hostname fallback below, so that comparisons against lowercased browser
	// origins cannot fail on letter case.
	origin = strings.ToLower(origin)
	parsedURL, err := url.Parse(origin)
	if err != nil {
		return "", "", "", err
	}
	if strings.Contains(origin, "://") {
		return parsedURL.Scheme, parsedURL.Hostname(), parsedURL.Port(), nil
	}
	hostname, port := parsedURL.Scheme, parsedURL.Opaque
	if hostname == "" {
		hostname = origin
	}
	return "", hostname, port, nil
}
|
|
|
|
// DialWebsocketWithDialer creates a new RPC client that communicates with a JSON-RPC server
|
|
// that is listening on the given endpoint using the provided dialer.
|
|
func DialWebsocketWithDialer(ctx context.Context, endpoint, origin string, dialer websocket.Dialer) (*Client, error) {
|
|
endpoint, header, err := wsClientHeaders(endpoint, origin)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return newClient(ctx, func(ctx context.Context) (ServerCodec, error) {
|
|
conn, resp, err := dialer.DialContext(ctx, endpoint, header)
|
|
if err != nil {
|
|
hErr := wsHandshakeError{err: err}
|
|
if resp != nil {
|
|
hErr.status = resp.Status
|
|
}
|
|
return nil, hErr
|
|
}
|
|
return newWebsocketCodec(conn), nil
|
|
})
|
|
}
|
|
|
|
// DialWebsocket creates a new RPC client that communicates with a JSON-RPC server
|
|
// that is listening on the given endpoint.
|
|
//
|
|
// The context is used for the initial connection establishment. It does not
|
|
// affect subsequent interactions with the client.
|
|
func DialWebsocket(ctx context.Context, endpoint, origin string) (*Client, error) {
|
|
dialer := websocket.Dialer{
|
|
ReadBufferSize: wsReadBuffer,
|
|
WriteBufferSize: wsWriteBuffer,
|
|
WriteBufferPool: wsBufferPool,
|
|
}
|
|
return DialWebsocketWithDialer(ctx, endpoint, origin, dialer)
|
|
}
|
|
|
|
// wsClientHeaders derives the dial URL and the HTTP request headers for a
// WebSocket connection to endpoint.
//
// A non-empty origin is sent as the Origin header. If endpoint contains
// userinfo, it is removed from the returned URL and sent as an HTTP Basic
// Authorization header instead.
//
// If endpoint cannot be parsed, it is returned unchanged together with a nil
// header and the parse error.
func wsClientHeaders(endpoint, origin string) (string, http.Header, error) {
	endpointURL, err := url.Parse(endpoint)
	if err != nil {
		return endpoint, nil, err
	}
	header := make(http.Header)
	if origin != "" {
		header.Add("origin", origin)
	}
	if user := endpointURL.User; user != nil {
		// Basic credentials are "user-id:password" built from the decoded
		// values. Userinfo.String() must not be used here: it yields the
		// percent-encoded form and drops the colon when no password is set.
		password, _ := user.Password()
		credentials := user.Username() + ":" + password
		b64auth := base64.StdEncoding.EncodeToString([]byte(credentials))
		header.Add("authorization", "Basic "+b64auth)
		endpointURL.User = nil
	}
	return endpointURL.String(), header, nil
}
|
|
|
|
type websocketCodec struct {
|
|
*jsonCodec
|
|
conn *websocket.Conn
|
|
|
|
wg sync.WaitGroup
|
|
pingReset chan struct{}
|
|
}
|
|
|
|
func newWebsocketCodec(conn *websocket.Conn) ServerCodec {
|
|
conn.SetReadLimit(wsMessageSizeLimit)
|
|
conn.SetPongHandler(func(appData string) error {
|
|
conn.SetReadDeadline(time.Time{})
|
|
return nil
|
|
})
|
|
wc := &websocketCodec{
|
|
jsonCodec: NewFuncCodec(conn, conn.WriteJSON, conn.ReadJSON).(*jsonCodec),
|
|
conn: conn,
|
|
pingReset: make(chan struct{}, 1),
|
|
}
|
|
wc.wg.Add(1)
|
|
go wc.pingLoop()
|
|
return wc
|
|
}
|
|
|
|
func (wc *websocketCodec) close() {
|
|
wc.jsonCodec.close()
|
|
wc.wg.Wait()
|
|
}
|
|
|
|
func (wc *websocketCodec) writeJSON(ctx context.Context, v interface{}) error {
|
|
err := wc.jsonCodec.writeJSON(ctx, v)
|
|
if err == nil {
|
|
// Notify pingLoop to delay the next idle ping.
|
|
select {
|
|
case wc.pingReset <- struct{}{}:
|
|
default:
|
|
}
|
|
}
|
|
return err
|
|
}
|
|
|
|
// pingLoop sends periodic ping frames when the connection is idle.
|
|
func (wc *websocketCodec) pingLoop() {
|
|
var timer = time.NewTimer(wsPingInterval)
|
|
defer wc.wg.Done()
|
|
defer timer.Stop()
|
|
|
|
for {
|
|
select {
|
|
case <-wc.closed():
|
|
return
|
|
case <-wc.pingReset:
|
|
if !timer.Stop() {
|
|
<-timer.C
|
|
}
|
|
timer.Reset(wsPingInterval)
|
|
case <-timer.C:
|
|
wc.jsonCodec.encMu.Lock()
|
|
wc.conn.SetWriteDeadline(time.Now().Add(wsPingWriteTimeout))
|
|
wc.conn.WriteMessage(websocket.PingMessage, nil)
|
|
wc.conn.SetReadDeadline(time.Now().Add(wsPongTimeout))
|
|
wc.jsonCodec.encMu.Unlock()
|
|
timer.Reset(wsPingInterval)
|
|
}
|
|
}
|
|
}
|