diff --git a/cmd/clef/main.go b/cmd/clef/main.go index 3aaf898db2..f7c3adebc4 100644 --- a/cmd/clef/main.go +++ b/cmd/clef/main.go @@ -661,7 +661,7 @@ func signer(c *cli.Context) error { if err != nil { utils.Fatalf("Could not register API: %w", err) } - handler := node.NewHTTPHandlerStack(srv, cors, vhosts) + handler := node.NewHTTPHandlerStack(srv, cors, vhosts, nil) // set port port := c.Int(rpcPortFlag.Name) diff --git a/cmd/geth/main.go b/cmd/geth/main.go index 94a0b16a8d..288ff28efc 100644 --- a/cmd/geth/main.go +++ b/cmd/geth/main.go @@ -164,6 +164,8 @@ var ( utils.HTTPListenAddrFlag, utils.HTTPPortFlag, utils.HTTPCORSDomainFlag, + utils.AuthPortFlag, + utils.JWTSecretFlag, utils.HTTPVirtualHostsFlag, utils.GraphQLEnabledFlag, utils.GraphQLCORSDomainFlag, diff --git a/cmd/geth/usage.go b/cmd/geth/usage.go index 417fba6892..6378107695 100644 --- a/cmd/geth/usage.go +++ b/cmd/geth/usage.go @@ -135,6 +135,7 @@ var AppHelpFlagGroups = []flags.FlagGroup{ Flags: []cli.Flag{ utils.IPCDisabledFlag, utils.IPCPathFlag, + utils.JWTSecretFlag, utils.HTTPEnabledFlag, utils.HTTPListenAddrFlag, utils.HTTPPortFlag, diff --git a/cmd/utils/flags.go b/cmd/utils/flags.go index 7d11b0631a..c1d226df09 100644 --- a/cmd/utils/flags.go +++ b/cmd/utils/flags.go @@ -518,6 +518,16 @@ var ( Usage: "Sets a cap on transaction fee (in ether) that can be sent via the RPC APIs (0 = no cap)", Value: ethconfig.Defaults.RPCTxFeeCap, } + // Authenticated port settings + AuthPortFlag = cli.IntFlag{ + Name: "authrpc.port", + Usage: "Listening port for authenticated APIs", + Value: node.DefaultAuthPort, + } + JWTSecretFlag = cli.StringFlag{ + Name: "authrpc.jwtsecret", + Usage: "JWT secret (or path to a jwt secret) to use for authenticated RPC endpoints", + } // Logging and debug settings EthStatsURLFlag = cli.StringFlag{ Name: "ethstats", @@ -951,6 +961,10 @@ func setHTTP(ctx *cli.Context, cfg *node.Config) { cfg.HTTPPort = ctx.GlobalInt(HTTPPortFlag.Name) } + if ctx.GlobalIsSet(AuthPortFlag.Name) { + cfg.AuthPort = ctx.GlobalInt(AuthPortFlag.Name) + } + if ctx.GlobalIsSet(HTTPCORSDomainFlag.Name) { cfg.HTTPCors = SplitAndTrim(ctx.GlobalString(HTTPCORSDomainFlag.Name)) } @@ -1218,6 +1232,10 @@ func SetNodeConfig(ctx *cli.Context, cfg *node.Config) { setDataDir(ctx, cfg) setSmartCard(ctx, cfg) + if ctx.GlobalIsSet(JWTSecretFlag.Name) { + cfg.JWTSecret = ctx.GlobalString(JWTSecretFlag.Name) + } + if ctx.GlobalIsSet(ExternalSignerFlag.Name) { cfg.ExternalSigner = ctx.GlobalString(ExternalSignerFlag.Name) } diff --git a/eth/catalyst/api.go b/eth/catalyst/api.go index a8b20d7589..aa1193d2fc 100644 --- a/eth/catalyst/api.go +++ b/eth/catalyst/api.go @@ -36,10 +36,18 @@ func Register(stack *node.Node, backend *eth.Ethereum) error { log.Warn("Catalyst mode enabled", "protocol", "eth") stack.RegisterAPIs([]rpc.API{ { - Namespace: "engine", - Version: "1.0", - Service: NewConsensusAPI(backend), - Public: true, + Namespace: "engine", + Version: "1.0", + Service: NewConsensusAPI(backend), + Public: true, + Authenticated: true, + }, + { + Namespace: "engine", + Version: "1.0", + Service: NewConsensusAPI(backend), + Public: true, + Authenticated: false, }, }) return nil diff --git a/go.mod b/go.mod index 79802085a6..b02d8ca832 100644 --- a/go.mod +++ b/go.mod @@ -25,6 +25,7 @@ require ( github.com/gballet/go-libpcsclite v0.0.0-20190607065134-2772fd86a8ff github.com/go-ole/go-ole v1.2.1 // indirect github.com/go-stack/stack v1.8.0 + github.com/golang-jwt/jwt/v4 v4.3.0 // indirect github.com/golang/protobuf v1.4.3 github.com/golang/snappy v0.0.4 github.com/google/gofuzz v1.1.1-0.20200604201612-c04b05f3adfa diff --git a/go.sum b/go.sum index e27d662108..ad936f8282 100644 --- a/go.sum +++ b/go.sum @@ -154,6 +154,10 @@ github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/me github.com/gofrs/uuid v3.3.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= github.com/gogo/protobuf v1.3.1/go.mod h1:SlYgWuQ5SjCEi6WLHjHCa1yvBfUnHcTbrrZtXPKa29o= +github.com/golang-jwt/jwt/v4 v4.2.0 h1:besgBTC8w8HjP6NzQdxwKH9Z5oQMZ24ThTrHp3cZ8eU= +github.com/golang-jwt/jwt/v4 v4.2.0/go.mod h1:/xlHOz8bRuivTWchD4jCa+NbatV+wEUSzwAxVc6locg= +github.com/golang-jwt/jwt/v4 v4.3.0 h1:kHL1vqdqWNfATmA0FNMdmZNMyZI1U6O31X4rlIPoBog= +github.com/golang-jwt/jwt/v4 v4.3.0/go.mod h1:/xlHOz8bRuivTWchD4jCa+NbatV+wEUSzwAxVc6locg= github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0/go.mod h1:E/TSTwGwJL78qG/PmXZO1EjYhfJinVAhrmmHX6Z8B9k= github.com/golang/geo v0.0.0-20190916061304-5b978397cfec/go.mod h1:QZ0nwyI2jOfgRAoBvP+ab5aRr7c9x7lhGEJrKvBwjWI= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= diff --git a/graphql/service.go b/graphql/service.go index bcb0a4990d..29d98ad746 100644 --- a/graphql/service.go +++ b/graphql/service.go @@ -74,7 +74,7 @@ func newHandler(stack *node.Node, backend ethapi.Backend, cors, vhosts []string) return err } h := handler{Schema: s} - handler := node.NewHTTPHandlerStack(h, cors, vhosts) + handler := node.NewHTTPHandlerStack(h, cors, vhosts, nil) stack.RegisterHandler("GraphQL UI", "/graphql/ui", GraphiQL{}) stack.RegisterHandler("GraphQL", "/graphql", handler) diff --git a/les/catalyst/api.go b/les/catalyst/api.go index 5f5193c3bb..ea5f9af28d 100644 --- a/les/catalyst/api.go +++ b/les/catalyst/api.go @@ -34,10 +34,11 @@ func Register(stack *node.Node, backend *les.LightEthereum) error { log.Warn("Catalyst mode enabled", "protocol", "les") stack.RegisterAPIs([]rpc.API{ { - Namespace: "engine", - Version: "1.0", - Service: NewConsensusAPI(backend), - Public: true, + Namespace: "engine", + Version: "1.0", + Service: NewConsensusAPI(backend), + Public: true, + Authenticated: true, }, }) return nil diff --git a/node/api.go b/node/api.go index a685ecd6b3..1b32399f63 100644 --- a/node/api.go +++ b/node/api.go @@ -274,11 +274,12 @@ func (api *privateAdminAPI) StartWS(host *string, port *int, allowedOrigins *str } // Enable WebSocket on the server. - server := api.node.wsServerForPort(*port) + server := api.node.wsServerForPort(*port, false) if err := server.setListenAddr(*host, *port); err != nil { return false, err } - if err := server.enableWS(api.node.rpcAPIs, config); err != nil { + openApis, _ := api.node.GetAPIs() + if err := server.enableWS(openApis, config); err != nil { return false, err } if err := server.start(); err != nil { diff --git a/node/config.go b/node/config.go index 26f00cd678..97853530a6 100644 --- a/node/config.go +++ b/node/config.go @@ -36,6 +36,7 @@ import ( const ( datadirPrivateKey = "nodekey" // Path within the datadir to the node's private key + datadirJWTKey = "jwtsecret" // Path within the datadir to the node's jwt secret datadirDefaultKeyStore = "keystore" // Path within the datadir to the keystore datadirStaticNodes = "static-nodes.json" // Path within the datadir to the static node list datadirTrustedNodes = "trusted-nodes.json" // Path within the datadir to the trusted node list @@ -112,6 +113,9 @@ type Config struct { // for ephemeral nodes). HTTPPort int `toml:",omitempty"` + // Authport is the port number on which the authenticated API is provided. + AuthPort int `toml:",omitempty"` + // HTTPCors is the Cross-Origin Resource Sharing header to send to requesting // clients. Please be aware that CORS is a browser enforced security, it's fully // useless for custom HTTP clients. @@ -190,6 +194,9 @@ type Config struct { // AllowUnprotectedTxs allows non EIP-155 protected transactions to be send over RPC. AllowUnprotectedTxs bool `toml:",omitempty"` + + // JWTSecret is the hex-encoded jwt secret. + JWTSecret string `toml:",omitempty"` } // IPCEndpoint resolves an IPC endpoint based on a configured value, taking into @@ -248,7 +255,7 @@ func (c *Config) HTTPEndpoint() string { // DefaultHTTPEndpoint returns the HTTP endpoint used by default. func DefaultHTTPEndpoint() string { - config := &Config{HTTPHost: DefaultHTTPHost, HTTPPort: DefaultHTTPPort} + config := &Config{HTTPHost: DefaultHTTPHost, HTTPPort: DefaultHTTPPort, AuthPort: DefaultAuthPort} return config.HTTPEndpoint() } diff --git a/node/defaults.go b/node/defaults.go index c685dde5d1..318d907fcc 100644 --- a/node/defaults.go +++ b/node/defaults.go @@ -34,12 +34,23 @@ const ( DefaultWSPort = 8546 // Default TCP port for the websocket RPC server DefaultGraphQLHost = "localhost" // Default host interface for the GraphQL server DefaultGraphQLPort = 8547 // Default TCP port for the GraphQL server + DefaultAuthHost = "localhost" // Default host interface for the authenticated apis + DefaultAuthPort = 8551 // Default port for the authenticated apis +) + +var ( + DefaultAuthCors = []string{"localhost"} // Default cors domain for the authenticated apis + DefaultAuthVhosts = []string{"localhost"} // Default virtual hosts for the authenticated apis + DefaultAuthOrigins = []string{"localhost"} // Default origins for the authenticated apis + DefaultAuthPrefix = "" // Default prefix for the authenticated apis + DefaultAuthModules = []string{"eth", "engine"} ) // DefaultConfig contains reasonable default settings. var DefaultConfig = Config{ DataDir: DefaultDataDir(), HTTPPort: DefaultHTTPPort, + AuthPort: DefaultAuthPort, HTTPModules: []string{"net", "web3"}, HTTPVirtualHosts: []string{"localhost"}, HTTPTimeouts: rpc.DefaultHTTPTimeouts, diff --git a/node/endpoints.go b/node/endpoints.go index 1f85a52131..166e39adb4 100644 --- a/node/endpoints.go +++ b/node/endpoints.go @@ -60,8 +60,10 @@ func checkModuleAvailability(modules []string, apis []rpc.API) (bad, available [ } } for _, name := range modules { - if _, ok := availableSet[name]; !ok && name != rpc.MetadataApi { - bad = append(bad, name) + if _, ok := availableSet[name]; !ok { + if name != rpc.MetadataApi && name != rpc.EngineApi { + bad = append(bad, name) + } } } return bad, available diff --git a/node/jwt_handler.go b/node/jwt_handler.go new file mode 100644 index 0000000000..28d5b87c60 --- /dev/null +++ b/node/jwt_handler.go @@ -0,0 +1,78 @@ +// Copyright 2022 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package node + +import ( + "net/http" + "strings" + "time" + + "github.com/golang-jwt/jwt/v4" +) + +type jwtHandler struct { + keyFunc func(token *jwt.Token) (interface{}, error) + next http.Handler +} + +// newJWTHandler creates a http.Handler with jwt authentication support. +func newJWTHandler(secret []byte, next http.Handler) http.Handler { + return &jwtHandler{ + keyFunc: func(token *jwt.Token) (interface{}, error) { + return secret, nil + }, + next: next, + } +} + +// ServeHTTP implements http.Handler +func (handler *jwtHandler) ServeHTTP(out http.ResponseWriter, r *http.Request) { + var ( + strToken string + claims jwt.RegisteredClaims + ) + if auth := r.Header.Get("Authorization"); strings.HasPrefix(auth, "Bearer ") { + strToken = strings.TrimPrefix(auth, "Bearer ") + } + if len(strToken) == 0 { + http.Error(out, "missing token", http.StatusForbidden) + return + } + // We explicitly set only HS256 allowed, and also disables the + // claim-check: the RegisteredClaims internally requires 'iat' to + // be no later than 'now', but we allow for a bit of drift. + token, err := jwt.ParseWithClaims(strToken, &claims, handler.keyFunc, + jwt.WithValidMethods([]string{"HS256"}), + jwt.WithoutClaimsValidation()) + + switch { + case err != nil: + http.Error(out, err.Error(), http.StatusForbidden) + case !token.Valid: + http.Error(out, "invalid token", http.StatusForbidden) + case !claims.VerifyExpiresAt(time.Now(), false): // optional + http.Error(out, "token is expired", http.StatusForbidden) + case claims.IssuedAt == nil: + http.Error(out, "missing issued-at", http.StatusForbidden) + case time.Since(claims.IssuedAt.Time) > 5*time.Second: + http.Error(out, "stale token", http.StatusForbidden) + case time.Until(claims.IssuedAt.Time) > 5*time.Second: + http.Error(out, "future token", http.StatusForbidden) + default: + handler.next.ServeHTTP(out, r) + } +} diff --git a/node/node.go b/node/node.go index ceab1c9090..135fae7942 100644 --- a/node/node.go +++ b/node/node.go @@ -17,6 +17,7 @@ package node import ( + crand "crypto/rand" "errors" "fmt" "net/http" @@ -27,6 +28,8 @@ import ( "sync" "github.com/ethereum/go-ethereum/accounts" + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/common/hexutil" "github.com/ethereum/go-ethereum/core/rawdb" "github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/event" @@ -55,6 +58,8 @@ type Node struct { rpcAPIs []rpc.API // List of APIs currently provided by the node http *httpServer // ws *httpServer // + httpAuth *httpServer // + wsAuth *httpServer // ipc *ipcServer // Stores information about the ipc http server inprocHandler *rpc.Server // In-process RPC request handler to process the API requests @@ -147,7 +152,9 @@ func New(conf *Config) (*Node, error) { // Configure RPC servers. node.http = newHTTPServer(node.log, conf.HTTPTimeouts) + node.httpAuth = newHTTPServer(node.log, conf.HTTPTimeouts) node.ws = newHTTPServer(node.log, rpc.DefaultHTTPTimeouts) + node.wsAuth = newHTTPServer(node.log, rpc.DefaultHTTPTimeouts) node.ipc = newIPCServer(node.log, conf.IPCEndpoint()) return node, nil @@ -335,7 +342,50 @@ func (n *Node) closeDataDir() { } } -// configureRPC is a helper method to configure all the various RPC endpoints during node +// obtainJWTSecret loads the jwt-secret, either from the provided config, +// or from the default location. If neither of those are present, it generates +// a new secret and stores to the default location. +func (n *Node) obtainJWTSecret(cliParam string) ([]byte, error) { + var fileName string + if len(cliParam) > 0 { + // If a plaintext secret was provided via cli flags, use that + jwtSecret := common.FromHex(cliParam) + if len(jwtSecret) == 32 && strings.HasPrefix(cliParam, "0x") { + log.Warn("Plaintext JWT secret provided, please consider passing via file") + return jwtSecret, nil + } + // path provided + fileName = cliParam + } else { + // no path provided, use default + fileName = n.ResolvePath(datadirJWTKey) + } + // try reading from file + log.Debug("Reading JWT secret", "path", fileName) + if data, err := os.ReadFile(fileName); err == nil { + jwtSecret := common.FromHex(strings.TrimSpace(string(data))) + if len(jwtSecret) == 32 { + return jwtSecret, nil + } + log.Error("Invalid JWT secret", "path", fileName, "length", len(jwtSecret)) + return nil, errors.New("invalid JWT secret") + } + // Need to generate one + jwtSecret := make([]byte, 32) + crand.Read(jwtSecret) + // if we're in --dev mode, don't bother saving, just show it + if fileName == "" { + log.Info("Generated ephemeral JWT secret", "secret", hexutil.Encode(jwtSecret)) + return jwtSecret, nil + } + if err := os.WriteFile(fileName, []byte(hexutil.Encode(jwtSecret)), 0600); err != nil { + return nil, err + } + log.Info("Generated JWT secret", "path", fileName) + return jwtSecret, nil +} + +// startRPC is a helper method to configure all the various RPC endpoints during node // startup. It's not meant to be called at any time afterwards as it makes certain // assumptions about the state of the node. func (n *Node) startRPC() error { @@ -349,55 +399,123 @@ func (n *Node) startRPC() error { return err } } + var ( + servers []*httpServer + open, all = n.GetAPIs() + ) - // Configure HTTP. - if n.config.HTTPHost != "" { - config := httpConfig{ + initHttp := func(server *httpServer, apis []rpc.API, port int) error { + if err := server.setListenAddr(n.config.HTTPHost, port); err != nil { + return err + } + if err := server.enableRPC(apis, httpConfig{ CorsAllowedOrigins: n.config.HTTPCors, Vhosts: n.config.HTTPVirtualHosts, Modules: n.config.HTTPModules, prefix: n.config.HTTPPathPrefix, - } - if err := n.http.setListenAddr(n.config.HTTPHost, n.config.HTTPPort); err != nil { - return err - } - if err := n.http.enableRPC(n.rpcAPIs, config); err != nil { + }); err != nil { return err } + servers = append(servers, server) + return nil } - - // Configure WebSocket. - if n.config.WSHost != "" { - server := n.wsServerForPort(n.config.WSPort) - config := wsConfig{ + initWS := func(apis []rpc.API, port int) error { + server := n.wsServerForPort(port, false) + if err := server.setListenAddr(n.config.WSHost, port); err != nil { + return err + } + if err := server.enableWS(n.rpcAPIs, wsConfig{ Modules: n.config.WSModules, Origins: n.config.WSOrigins, prefix: n.config.WSPathPrefix, - } - if err := server.setListenAddr(n.config.WSHost, n.config.WSPort); err != nil { - return err - } - if err := server.enableWS(n.rpcAPIs, config); err != nil { + }); err != nil { return err } + servers = append(servers, server) + return nil } - if err := n.http.start(); err != nil { - return err + initAuth := func(apis []rpc.API, port int, secret []byte) error { + // Enable auth via HTTP + server := n.httpAuth + if err := server.setListenAddr(DefaultAuthHost, port); err != nil { + return err + } + if err := server.enableRPC(apis, httpConfig{ + CorsAllowedOrigins: DefaultAuthCors, + Vhosts: DefaultAuthVhosts, + Modules: DefaultAuthModules, + prefix: DefaultAuthPrefix, + jwtSecret: secret, + }); err != nil { + return err + } + servers = append(servers, server) + // Enable auth via WS + server = n.wsServerForPort(port, true) + if err := server.setListenAddr(DefaultAuthHost, port); err != nil { + return err + } + if err := server.enableWS(apis, wsConfig{ + Modules: DefaultAuthModules, + Origins: DefaultAuthOrigins, + prefix: DefaultAuthPrefix, + jwtSecret: secret, + }); err != nil { + return err + } + servers = append(servers, server) + return nil } - return n.ws.start() + // Set up HTTP. + if n.config.HTTPHost != "" { + // Configure legacy unauthenticated HTTP. + if err := initHttp(n.http, open, n.config.HTTPPort); err != nil { + return err + } + } + // Configure WebSocket. + if n.config.WSHost != "" { + // legacy unauthenticated + if err := initWS(open, n.config.WSPort); err != nil { + return err + } + } + // Configure authenticated API + if len(open) != len(all) { + jwtSecret, err := n.obtainJWTSecret(n.config.JWTSecret) + if err != nil { + return err + } + if err := initAuth(all, n.config.AuthPort, jwtSecret); err != nil { + return err + } + } + // Start the servers + for _, server := range servers { + if err := server.start(); err != nil { + return err + } + } + return nil } -func (n *Node) wsServerForPort(port int) *httpServer { - if n.config.HTTPHost == "" || n.http.port == port { - return n.http +func (n *Node) wsServerForPort(port int, authenticated bool) *httpServer { + httpServer, wsServer := n.http, n.ws + if authenticated { + httpServer, wsServer = n.httpAuth, n.wsAuth } - return n.ws + if n.config.HTTPHost == "" || httpServer.port == port { + return httpServer + } + return wsServer } func (n *Node) stopRPC() { n.http.stop() n.ws.stop() + n.httpAuth.stop() + n.wsAuth.stop() n.ipc.stop() n.stopInProc() } @@ -458,6 +576,17 @@ func (n *Node) RegisterAPIs(apis []rpc.API) { n.rpcAPIs = append(n.rpcAPIs, apis...) } +// GetAPIs return two sets of APIs, both the ones that do not require +// authentication, and the complete set +func (n *Node) GetAPIs() (unauthenticated, all []rpc.API) { + for _, api := range n.rpcAPIs { + if !api.Authenticated { + unauthenticated = append(unauthenticated, api) + } + } + return unauthenticated, n.rpcAPIs +} + // RegisterHandler mounts a handler on the given path on the canonical HTTP server. // // The name of the handler is shown in a log message when the HTTP server starts diff --git a/node/node_test.go b/node/node_test.go index 25cfa9d38d..84f61f0c44 100644 --- a/node/node_test.go +++ b/node/node_test.go @@ -577,13 +577,13 @@ func (test rpcPrefixTest) check(t *testing.T, node *Node) { } } for _, path := range test.wantWS { - err := wsRequest(t, wsBase+path, "") + err := wsRequest(t, wsBase+path) if err != nil { t.Errorf("Error: %s: WebSocket connection failed: %v", path, err) } } for _, path := range test.wantNoWS { - err := wsRequest(t, wsBase+path, "") + err := wsRequest(t, wsBase+path) if err == nil { t.Errorf("Error: %s: WebSocket connection succeeded for path in wantNoWS", path) } diff --git a/node/rpcstack.go b/node/rpcstack.go index 2c55a070b2..d9c41cca57 100644 --- a/node/rpcstack.go +++ b/node/rpcstack.go @@ -40,13 +40,15 @@ type httpConfig struct { CorsAllowedOrigins []string Vhosts []string prefix string // path prefix on which to mount http handler + jwtSecret []byte // optional JWT secret } // wsConfig is the JSON-RPC/Websocket configuration type wsConfig struct { - Origins []string - Modules []string - prefix string // path prefix on which to mount ws handler + Origins []string + Modules []string + prefix string // path prefix on which to mount ws handler + jwtSecret []byte // optional JWT secret } type rpcHandler struct { @@ -157,7 +159,7 @@ func (h *httpServer) start() error { } // Log http endpoint. h.log.Info("HTTP server started", - "endpoint", listener.Addr(), + "endpoint", listener.Addr(), "auth", (h.httpConfig.jwtSecret != nil), "prefix", h.httpConfig.prefix, "cors", strings.Join(h.httpConfig.CorsAllowedOrigins, ","), "vhosts", strings.Join(h.httpConfig.Vhosts, ","), @@ -285,7 +287,7 @@ func (h *httpServer) enableRPC(apis []rpc.API, config httpConfig) error { } h.httpConfig = config h.httpHandler.Store(&rpcHandler{ - Handler: NewHTTPHandlerStack(srv, config.CorsAllowedOrigins, config.Vhosts), + Handler: NewHTTPHandlerStack(srv, config.CorsAllowedOrigins, config.Vhosts, config.jwtSecret), server: srv, }) return nil @@ -309,7 +311,6 @@ func (h *httpServer) enableWS(apis []rpc.API, config wsConfig) error { if h.wsAllowed() { return fmt.Errorf("JSON-RPC over WebSocket is already enabled") } - // Create RPC server and handler. srv := rpc.NewServer() if err := RegisterApis(apis, config.Modules, srv, false); err != nil { @@ -317,7 +318,7 @@ func (h *httpServer) enableWS(apis []rpc.API, config wsConfig) error { } h.wsConfig = config h.wsHandler.Store(&rpcHandler{ - Handler: srv.WebsocketHandler(config.Origins), + Handler: NewWSHandlerStack(srv.WebsocketHandler(config.Origins), config.jwtSecret), server: srv, }) return nil @@ -362,13 +363,24 @@ func isWebsocket(r *http.Request) bool { } // NewHTTPHandlerStack returns wrapped http-related handlers -func NewHTTPHandlerStack(srv http.Handler, cors []string, vhosts []string) http.Handler { +func NewHTTPHandlerStack(srv http.Handler, cors []string, vhosts []string, jwtSecret []byte) http.Handler { // Wrap the CORS-handler within a host-handler handler := newCorsHandler(srv, cors) handler = newVHostHandler(vhosts, handler) + if len(jwtSecret) != 0 { + handler = newJWTHandler(jwtSecret, handler) + } return newGzipHandler(handler) } +// NewWSHandlerStack returns a wrapped ws-related handler. +func NewWSHandlerStack(srv http.Handler, jwtSecret []byte) http.Handler { + if len(jwtSecret) != 0 { + return newJWTHandler(jwtSecret, srv) + } + return srv +} + func newCorsHandler(srv http.Handler, allowedOrigins []string) http.Handler { // disable CORS support if user has not specified a custom CORS configuration if len(allowedOrigins) == 0 { diff --git a/node/rpcstack_test.go b/node/rpcstack_test.go index f92f0ba396..60fcab5a90 100644 --- a/node/rpcstack_test.go +++ b/node/rpcstack_test.go @@ -24,10 +24,12 @@ import ( "strconv" "strings" "testing" + "time" "github.com/ethereum/go-ethereum/internal/testlog" "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/rpc" + "github.com/golang-jwt/jwt/v4" "github.com/gorilla/websocket" "github.com/stretchr/testify/assert" ) @@ -146,12 +148,12 @@ func TestWebsocketOrigins(t *testing.T) { srv := createAndStartServer(t, &httpConfig{}, true, &wsConfig{Origins: splitAndTrim(tc.spec)}) url := fmt.Sprintf("ws://%v", srv.listenAddr()) for _, origin := range tc.expOk { - if err := wsRequest(t, url, origin); err != nil { + if err := wsRequest(t, url, "Origin", origin); err != nil { t.Errorf("spec '%v', origin '%v': expected ok, got %v", tc.spec, origin, err) } } for _, origin := range tc.expFail { - if err := wsRequest(t, url, origin); err == nil { + if err := wsRequest(t, url, "Origin", origin); err == nil { t.Errorf("spec '%v', origin '%v': expected not to allow, got ok", tc.spec, origin) } } @@ -243,13 +245,18 @@ func createAndStartServer(t *testing.T, conf *httpConfig, ws bool, wsConf *wsCon } // wsRequest attempts to open a WebSocket connection to the given URL. -func wsRequest(t *testing.T, url, browserOrigin string) error { +func wsRequest(t *testing.T, url string, extraHeaders ...string) error { t.Helper() - t.Logf("checking WebSocket on %s (origin %q)", url, browserOrigin) + //t.Logf("checking WebSocket on %s (origin %q)", url, browserOrigin) headers := make(http.Header) - if browserOrigin != "" { - headers.Set("Origin", browserOrigin) + // Apply extra headers. + if len(extraHeaders)%2 != 0 { + panic("odd extraHeaders length") + } + for i := 0; i < len(extraHeaders); i += 2 { + key, value := extraHeaders[i], extraHeaders[i+1] + headers.Set(key, value) } conn, _, err := websocket.DefaultDialer.Dial(url, headers) if conn != nil { @@ -291,3 +298,79 @@ func rpcRequest(t *testing.T, url string, extraHeaders ...string) *http.Response } return resp } + +type testClaim map[string]interface{} + +func (testClaim) Valid() error { + return nil +} + +func TestJWT(t *testing.T) { + var secret = []byte("secret") + issueToken := func(secret []byte, method jwt.SigningMethod, input map[string]interface{}) string { + if method == nil { + method = jwt.SigningMethodHS256 + } + ss, _ := jwt.NewWithClaims(method, testClaim(input)).SignedString(secret) + return ss + } + expOk := []string{ + fmt.Sprintf("Bearer %v", issueToken(secret, nil, testClaim{"iat": time.Now().Unix()})), + fmt.Sprintf("Bearer %v", issueToken(secret, nil, testClaim{"iat": time.Now().Unix() + 4})), + fmt.Sprintf("Bearer %v", issueToken(secret, nil, testClaim{"iat": time.Now().Unix() - 4})), + fmt.Sprintf("Bearer %v", issueToken(secret, nil, testClaim{ + "iat": time.Now().Unix(), + "exp": time.Now().Unix() + 2, + })), + fmt.Sprintf("Bearer %v", issueToken(secret, nil, testClaim{ + "iat": time.Now().Unix(), + "bar": "baz", + })), + } + expFail := []string{ + // future + fmt.Sprintf("Bearer %v", issueToken(secret, nil, testClaim{"iat": time.Now().Unix() + 6})), + // stale + fmt.Sprintf("Bearer %v", issueToken(secret, nil, testClaim{"iat": time.Now().Unix() - 6})), + // wrong algo + fmt.Sprintf("Bearer %v", issueToken(secret, jwt.SigningMethodHS512, testClaim{"iat": time.Now().Unix() + 4})), + // expired + fmt.Sprintf("Bearer %v", issueToken(secret, nil, testClaim{"iat": time.Now().Unix(), "exp": time.Now().Unix()})), + // missing mandatory iat + fmt.Sprintf("Bearer %v", issueToken(secret, nil, testClaim{})), + // wrong secret + fmt.Sprintf("Bearer %v", issueToken([]byte("wrong"), nil, testClaim{"iat": time.Now().Unix()})), + fmt.Sprintf("Bearer %v", issueToken([]byte{}, nil, testClaim{"iat": time.Now().Unix()})), + fmt.Sprintf("Bearer %v", issueToken(nil, nil, testClaim{"iat": time.Now().Unix()})), + // Various malformed syntax + fmt.Sprintf("%v", issueToken(secret, nil, testClaim{"iat": time.Now().Unix()})), + fmt.Sprintf("Bearer %v", issueToken(secret, nil, testClaim{"iat": time.Now().Unix()})), + fmt.Sprintf("bearer %v", issueToken(secret, nil, testClaim{"iat": time.Now().Unix()})), + fmt.Sprintf("Bearer: %v", issueToken(secret, nil, testClaim{"iat": time.Now().Unix()})), + fmt.Sprintf("Bearer:%v", issueToken(secret, nil, testClaim{"iat": time.Now().Unix()})), + fmt.Sprintf("Bearer\t%v", issueToken(secret, nil, testClaim{"iat": time.Now().Unix()})), + fmt.Sprintf("Bearer \t%v", issueToken(secret, nil, testClaim{"iat": time.Now().Unix()})), + } + srv := createAndStartServer(t, &httpConfig{jwtSecret: []byte("secret")}, + true, &wsConfig{Origins: []string{"*"}, jwtSecret: []byte("secret")}) + wsUrl := fmt.Sprintf("ws://%v", srv.listenAddr()) + htUrl := fmt.Sprintf("http://%v", srv.listenAddr()) + + for i, token := range expOk { + if err := wsRequest(t, wsUrl, "Authorization", token); err != nil { + t.Errorf("test %d-ws, token '%v': expected ok, got %v", i, token, err) + } + if resp := rpcRequest(t, htUrl, "Authorization", token); resp.StatusCode != 200 { + t.Errorf("test %d-http, token '%v': expected ok, got %v", i, token, resp.StatusCode) + } + } + for i, token := range expFail { + if err := wsRequest(t, wsUrl, "Authorization", token); err == nil { + t.Errorf("tc %d-ws, token '%v': expected not to allow, got ok", i, token) + } + if resp := rpcRequest(t, htUrl, "Authorization", token); resp.StatusCode != 403 { + t.Errorf("tc %d-http, token '%v': expected not to allow, got %v", i, token, resp.StatusCode) + } + } + srv.stop() +} diff --git a/rpc/server.go b/rpc/server.go index e2d5c03835..babc5688e2 100644 --- a/rpc/server.go +++ b/rpc/server.go @@ -26,6 +26,7 @@ import ( ) const MetadataApi = "rpc" +const EngineApi = "engine" // CodecOption specifies which type of messages a codec supports. // diff --git a/rpc/types.go b/rpc/types.go index 959e383723..46b08caf68 100644 --- a/rpc/types.go +++ b/rpc/types.go @@ -30,10 +30,11 @@ import ( // API describes the set of methods offered over the RPC interface type API struct { - Namespace string // namespace under which the rpc methods of Service are exposed - Version string // api version for DApp's - Service interface{} // receiver instance which holds the methods - Public bool // indication if the methods must be considered safe for public use + Namespace string // namespace under which the rpc methods of Service are exposed + Version string // api version for DApp's + Service interface{} // receiver instance which holds the methods + Public bool // indication if the methods must be considered safe for public use + Authenticated bool // whether the api should only be available behind authentication. } // ServerCodec implements reading, parsing and writing RPC messages for the server side of diff --git a/rpc/websocket_test.go b/rpc/websocket_test.go index 8659f798e4..f74b7fd08b 100644 --- a/rpc/websocket_test.go +++ b/rpc/websocket_test.go @@ -76,7 +76,7 @@ func TestWebsocketOriginCheck(t *testing.T) { // Connections without origin header should work. client, err = DialWebsocket(context.Background(), wsURL, "") if err != nil { - t.Fatal("error for empty origin") + t.Fatalf("error for empty origin: %v", err) } client.Close() }