From c5761317f4918545d8c5598c5d6204ded534f51e Mon Sep 17 00:00:00 2001 From: Brian Ray <62354532+emeryray2002@users.noreply.github.com> Date: Fri, 18 Jul 2025 10:14:23 -0400 Subject: [PATCH] MCP OAuth Part 1 - OAuth Infrastructure (#4316) --- packages/core/src/index.ts | 14 + packages/core/src/mcp/oauth-provider.test.ts | 720 ++++++++++++++++++ packages/core/src/mcp/oauth-provider.ts | 698 +++++++++++++++++ .../core/src/mcp/oauth-token-storage.test.ts | 325 ++++++++ packages/core/src/mcp/oauth-token-storage.ts | 205 +++++ packages/core/src/mcp/oauth-utils.test.ts | 206 +++++ packages/core/src/mcp/oauth-utils.ts | 285 +++++++ 7 files changed, 2453 insertions(+) create mode 100644 packages/core/src/mcp/oauth-provider.test.ts create mode 100644 packages/core/src/mcp/oauth-provider.ts create mode 100644 packages/core/src/mcp/oauth-token-storage.test.ts create mode 100644 packages/core/src/mcp/oauth-token-storage.ts create mode 100644 packages/core/src/mcp/oauth-utils.test.ts create mode 100644 packages/core/src/mcp/oauth-utils.ts diff --git a/packages/core/src/index.ts b/packages/core/src/index.ts index ffc06866..0aab6106 100644 --- a/packages/core/src/index.ts +++ b/packages/core/src/index.ts @@ -60,6 +60,20 @@ export * from './tools/read-many-files.js'; export * from './tools/mcp-client.js'; export * from './tools/mcp-tool.js'; +// MCP OAuth +export { MCPOAuthProvider } from './mcp/oauth-provider.js'; +export { + MCPOAuthToken, + MCPOAuthCredentials, + MCPOAuthTokenStorage, +} from './mcp/oauth-token-storage.js'; +export type { MCPOAuthConfig } from './mcp/oauth-provider.js'; +export type { + OAuthAuthorizationServerMetadata, + OAuthProtectedResourceMetadata, +} from './mcp/oauth-utils.js'; +export { OAuthUtils } from './mcp/oauth-utils.js'; + // Export telemetry functions export * from './telemetry/index.js'; export { sessionId } from './utils/session.js'; diff --git a/packages/core/src/mcp/oauth-provider.test.ts b/packages/core/src/mcp/oauth-provider.test.ts new file mode 100644 index 00000000..41938969 --- /dev/null +++ b/packages/core/src/mcp/oauth-provider.test.ts @@ -0,0 +1,720 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'; +import * as http from 'node:http'; +import * as crypto from 'node:crypto'; +import open from 'open'; +import { + MCPOAuthProvider, + MCPOAuthConfig, + OAuthTokenResponse, + OAuthClientRegistrationResponse, +} from './oauth-provider.js'; +import { MCPOAuthTokenStorage, MCPOAuthToken } from './oauth-token-storage.js'; + +// Mock dependencies +vi.mock('open'); +vi.mock('node:crypto'); +vi.mock('./oauth-token-storage.js'); + +// Mock fetch globally +const mockFetch = vi.fn(); +global.fetch = mockFetch; + +// Define a reusable mock server with .listen, .close, and .on methods +const mockHttpServer = { + listen: vi.fn(), + close: vi.fn(), + on: vi.fn(), +}; +vi.mock('node:http', () => ({ + createServer: vi.fn(() => mockHttpServer), +})); + +describe('MCPOAuthProvider', () => { + const mockConfig: MCPOAuthConfig = { + enabled: true, + clientId: 'test-client-id', + clientSecret: 'test-client-secret', + authorizationUrl: 'https://auth.example.com/authorize', + tokenUrl: 'https://auth.example.com/token', + scopes: ['read', 'write'], + redirectUri: 'http://localhost:7777/oauth/callback', + }; + + const mockToken: MCPOAuthToken = { + accessToken: 'access_token_123', + refreshToken: 'refresh_token_456', + tokenType: 'Bearer', + scope: 'read write', + expiresAt: Date.now() + 3600000, + }; + + const mockTokenResponse: OAuthTokenResponse = { + access_token: 'access_token_123', + token_type: 'Bearer', + expires_in: 3600, + refresh_token: 'refresh_token_456', + scope: 'read write', + }; + + beforeEach(() => { + vi.clearAllMocks(); + vi.spyOn(console, 'log').mockImplementation(() => {}); + vi.spyOn(console, 'warn').mockImplementation(() => {}); + vi.spyOn(console, 'error').mockImplementation(() => {}); + + // Mock crypto functions + vi.mocked(crypto.randomBytes).mockImplementation((size: number) => { + if (size === 32) return Buffer.from('code_verifier_mock_32_bytes_long'); + if (size === 16) return Buffer.from('state_mock_16_by'); + return Buffer.alloc(size); + }); + + vi.mocked(crypto.createHash).mockReturnValue({ + update: vi.fn().mockReturnThis(), + digest: vi.fn().mockReturnValue('code_challenge_mock'), + } as unknown as crypto.Hash); + + // Mock randomBytes to return predictable values for state + vi.mocked(crypto.randomBytes).mockImplementation((size) => { + if (size === 32) { + return Buffer.from('mock_code_verifier_32_bytes_long_string'); + } else if (size === 16) { + return Buffer.from('mock_state_16_bytes'); + } + return Buffer.alloc(size); + }); + + // Mock token storage + vi.mocked(MCPOAuthTokenStorage.saveToken).mockResolvedValue(undefined); + vi.mocked(MCPOAuthTokenStorage.getToken).mockResolvedValue(null); + }); + + afterEach(() => { + vi.restoreAllMocks(); + }); + + describe('authenticate', () => { + it('should perform complete OAuth flow with PKCE', async () => { + // Mock HTTP server callback + let callbackHandler: unknown; + vi.mocked(http.createServer).mockImplementation((handler) => { + callbackHandler = handler; + return mockHttpServer as unknown as http.Server; + }); + + mockHttpServer.listen.mockImplementation((port, callback) => { + callback?.(); + // Simulate OAuth callback + setTimeout(() => { + const mockReq = { + url: '/oauth/callback?code=auth_code_123&state=bW9ja19zdGF0ZV8xNl9ieXRlcw', + }; + const mockRes = { + writeHead: vi.fn(), + end: vi.fn(), + }; + (callbackHandler as (req: unknown, res: unknown) => void)( + mockReq, + mockRes, + ); + }, 10); + }); + + // Mock token exchange + mockFetch.mockResolvedValueOnce({ + ok: true, + json: () => Promise.resolve(mockTokenResponse), + }); + + const result = await MCPOAuthProvider.authenticate( + 'test-server', + mockConfig, + ); + + expect(result).toEqual({ + accessToken: 'access_token_123', + refreshToken: 'refresh_token_456', + tokenType: 'Bearer', + scope: 'read write', + expiresAt: expect.any(Number), + }); + + expect(open).toHaveBeenCalledWith(expect.stringContaining('authorize')); + expect(MCPOAuthTokenStorage.saveToken).toHaveBeenCalledWith( + 'test-server', + expect.objectContaining({ accessToken: 'access_token_123' }), + 'test-client-id', + 'https://auth.example.com/token', + ); + }); + + it('should handle OAuth discovery when no authorization URL provided', async () => { + // Use a mutable config object + const configWithoutAuth: MCPOAuthConfig = { ...mockConfig }; + delete configWithoutAuth.authorizationUrl; + delete configWithoutAuth.tokenUrl; + + const mockResourceMetadata = { + authorization_servers: ['https://discovered.auth.com'], + }; + + const mockAuthServerMetadata = { + authorization_endpoint: 'https://discovered.auth.com/authorize', + token_endpoint: 'https://discovered.auth.com/token', + scopes_supported: ['read', 'write'], + }; + + mockFetch + .mockResolvedValueOnce({ + ok: true, + json: () => Promise.resolve(mockResourceMetadata), + }) + .mockResolvedValueOnce({ + ok: true, + json: () => Promise.resolve(mockAuthServerMetadata), + }); + + // Patch config after discovery + configWithoutAuth.authorizationUrl = + mockAuthServerMetadata.authorization_endpoint; + configWithoutAuth.tokenUrl = mockAuthServerMetadata.token_endpoint; + configWithoutAuth.scopes = mockAuthServerMetadata.scopes_supported; + + // Setup callback handler + let callbackHandler: unknown; + vi.mocked(http.createServer).mockImplementation((handler) => { + callbackHandler = handler; + return mockHttpServer as unknown as http.Server; + }); + + mockHttpServer.listen.mockImplementation((port, callback) => { + callback?.(); + setTimeout(() => { + const mockReq = { + url: '/oauth/callback?code=auth_code_123&state=bW9ja19zdGF0ZV8xNl9ieXRlcw', + }; + const mockRes = { + writeHead: vi.fn(), + end: vi.fn(), + }; + (callbackHandler as (req: unknown, res: unknown) => void)( + mockReq, + mockRes, + ); + }, 10); + }); + + // Mock token exchange with discovered endpoint + mockFetch.mockResolvedValueOnce({ + ok: true, + json: () => Promise.resolve(mockTokenResponse), + }); + + const result = await MCPOAuthProvider.authenticate( + 'test-server', + configWithoutAuth, + 'https://api.example.com', + ); + + expect(result).toBeDefined(); + expect(mockFetch).toHaveBeenCalledWith( + 'https://discovered.auth.com/token', + expect.objectContaining({ + method: 'POST', + headers: { 'Content-Type': 'application/x-www-form-urlencoded' }, + }), + ); + }); + + it('should perform dynamic client registration when no client ID provided', async () => { + const configWithoutClient = { ...mockConfig }; + delete configWithoutClient.clientId; + + const mockRegistrationResponse: OAuthClientRegistrationResponse = { + client_id: 'dynamic_client_id', + client_secret: 'dynamic_client_secret', + redirect_uris: ['http://localhost:7777/oauth/callback'], + grant_types: ['authorization_code', 'refresh_token'], + response_types: ['code'], + token_endpoint_auth_method: 'none', + }; + + const mockAuthServerMetadata = { + authorization_endpoint: 'https://auth.example.com/authorize', + token_endpoint: 'https://auth.example.com/token', + registration_endpoint: 'https://auth.example.com/register', + }; + + mockFetch + .mockResolvedValueOnce({ + ok: true, + json: () => Promise.resolve(mockAuthServerMetadata), + }) + .mockResolvedValueOnce({ + ok: true, + json: () => Promise.resolve(mockRegistrationResponse), + }); + + // Setup callback handler + let callbackHandler: unknown; + vi.mocked(http.createServer).mockImplementation((handler) => { + callbackHandler = handler; + return mockHttpServer as unknown as http.Server; + }); + + mockHttpServer.listen.mockImplementation((port, callback) => { + callback?.(); + setTimeout(() => { + const mockReq = { + url: '/oauth/callback?code=auth_code_123&state=bW9ja19zdGF0ZV8xNl9ieXRlcw', + }; + const mockRes = { + writeHead: vi.fn(), + end: vi.fn(), + }; + (callbackHandler as (req: unknown, res: unknown) => void)( + mockReq, + mockRes, + ); + }, 10); + }); + + // Mock token exchange + mockFetch.mockResolvedValueOnce({ + ok: true, + json: () => Promise.resolve(mockTokenResponse), + }); + + const result = await MCPOAuthProvider.authenticate( + 'test-server', + configWithoutClient, + ); + + expect(result).toBeDefined(); + expect(mockFetch).toHaveBeenCalledWith( + 'https://auth.example.com/register', + expect.objectContaining({ + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + }), + ); + }); + + it('should handle OAuth callback errors', async () => { + let callbackHandler: unknown; + vi.mocked(http.createServer).mockImplementation((handler) => { + callbackHandler = handler; + return mockHttpServer as unknown as http.Server; + }); + + mockHttpServer.listen.mockImplementation((port, callback) => { + callback?.(); + setTimeout(() => { + const mockReq = { + url: '/oauth/callback?error=access_denied&error_description=User%20denied%20access', + }; + const mockRes = { + writeHead: vi.fn(), + end: vi.fn(), + }; + (callbackHandler as (req: unknown, res: unknown) => void)( + mockReq, + mockRes, + ); + }, 10); + }); + + await expect( + MCPOAuthProvider.authenticate('test-server', mockConfig), + ).rejects.toThrow('OAuth error: access_denied'); + }); + + it('should handle state mismatch in callback', async () => { + let callbackHandler: unknown; + vi.mocked(http.createServer).mockImplementation((handler) => { + callbackHandler = handler; + return mockHttpServer as unknown as http.Server; + }); + + mockHttpServer.listen.mockImplementation((port, callback) => { + callback?.(); + setTimeout(() => { + const mockReq = { + url: '/oauth/callback?code=auth_code_123&state=wrong_state', + }; + const mockRes = { + writeHead: vi.fn(), + end: vi.fn(), + }; + (callbackHandler as (req: unknown, res: unknown) => void)( + mockReq, + mockRes, + ); + }, 10); + }); + + await expect( + MCPOAuthProvider.authenticate('test-server', mockConfig), + ).rejects.toThrow('State mismatch - possible CSRF attack'); + }); + + it('should handle token exchange failure', async () => { + let callbackHandler: unknown; + vi.mocked(http.createServer).mockImplementation((handler) => { + callbackHandler = handler; + return mockHttpServer as unknown as http.Server; + }); + + mockHttpServer.listen.mockImplementation((port, callback) => { + callback?.(); + setTimeout(() => { + const mockReq = { + url: '/oauth/callback?code=auth_code_123&state=bW9ja19zdGF0ZV8xNl9ieXRlcw', + }; + const mockRes = { + writeHead: vi.fn(), + end: vi.fn(), + }; + (callbackHandler as (req: unknown, res: unknown) => void)( + mockReq, + mockRes, + ); + }, 10); + }); + + mockFetch.mockResolvedValueOnce({ + ok: false, + status: 400, + text: () => Promise.resolve('Invalid grant'), + }); + + await expect( + MCPOAuthProvider.authenticate('test-server', mockConfig), + ).rejects.toThrow('Token exchange failed: 400 - Invalid grant'); + }); + + it('should handle callback timeout', async () => { + vi.mocked(http.createServer).mockImplementation( + () => mockHttpServer as unknown as http.Server, + ); + + mockHttpServer.listen.mockImplementation((port, callback) => { + callback?.(); + // Don't trigger callback - simulate timeout + }); + + // Mock setTimeout to trigger timeout immediately + const originalSetTimeout = global.setTimeout; + global.setTimeout = vi.fn((callback, delay) => { + if (delay === 5 * 60 * 1000) { + // 5 minute timeout + callback(); + } + return originalSetTimeout(callback, 0); + }) as unknown as typeof setTimeout; + + await expect( + MCPOAuthProvider.authenticate('test-server', mockConfig), + ).rejects.toThrow('OAuth callback timeout'); + + global.setTimeout = originalSetTimeout; + }); + }); + + describe('refreshAccessToken', () => { + it('should refresh token successfully', async () => { + const refreshResponse = { + access_token: 'new_access_token', + token_type: 'Bearer', + expires_in: 3600, + refresh_token: 'new_refresh_token', + }; + + mockFetch.mockResolvedValueOnce({ + ok: true, + json: () => Promise.resolve(refreshResponse), + }); + + const result = await MCPOAuthProvider.refreshAccessToken( + mockConfig, + 'old_refresh_token', + 'https://auth.example.com/token', + ); + + expect(result).toEqual(refreshResponse); + expect(mockFetch).toHaveBeenCalledWith( + 'https://auth.example.com/token', + expect.objectContaining({ + method: 'POST', + headers: { 'Content-Type': 'application/x-www-form-urlencoded' }, + body: expect.stringContaining('grant_type=refresh_token'), + }), + ); + }); + + it('should include client secret in refresh request when available', async () => { + mockFetch.mockResolvedValueOnce({ + ok: true, + json: () => Promise.resolve(mockTokenResponse), + }); + + await MCPOAuthProvider.refreshAccessToken( + mockConfig, + 'refresh_token', + 'https://auth.example.com/token', + ); + + const fetchCall = mockFetch.mock.calls[0]; + expect(fetchCall[1].body).toContain('client_secret=test-client-secret'); + }); + + it('should handle refresh token failure', async () => { + mockFetch.mockResolvedValueOnce({ + ok: false, + status: 400, + text: () => Promise.resolve('Invalid refresh token'), + }); + + await expect( + MCPOAuthProvider.refreshAccessToken( + mockConfig, + 'invalid_refresh_token', + 'https://auth.example.com/token', + ), + ).rejects.toThrow('Token refresh failed: 400 - Invalid refresh token'); + }); + }); + + describe('getValidToken', () => { + it('should return valid token when not expired', async () => { + const validCredentials = { + serverName: 'test-server', + token: mockToken, + clientId: 'test-client-id', + tokenUrl: 'https://auth.example.com/token', + updatedAt: Date.now(), + }; + + vi.mocked(MCPOAuthTokenStorage.getToken).mockResolvedValue( + validCredentials, + ); + vi.mocked(MCPOAuthTokenStorage.isTokenExpired).mockReturnValue(false); + + const result = await MCPOAuthProvider.getValidToken( + 'test-server', + mockConfig, + ); + + expect(result).toBe('access_token_123'); + }); + + it('should refresh expired token and return new token', async () => { + const expiredCredentials = { + serverName: 'test-server', + token: { ...mockToken, expiresAt: Date.now() - 3600000 }, + clientId: 'test-client-id', + tokenUrl: 'https://auth.example.com/token', + updatedAt: Date.now(), + }; + + vi.mocked(MCPOAuthTokenStorage.getToken).mockResolvedValue( + expiredCredentials, + ); + vi.mocked(MCPOAuthTokenStorage.isTokenExpired).mockReturnValue(true); + + const refreshResponse = { + access_token: 'new_access_token', + token_type: 'Bearer', + expires_in: 3600, + refresh_token: 'new_refresh_token', + }; + + mockFetch.mockResolvedValueOnce({ + ok: true, + json: () => Promise.resolve(refreshResponse), + }); + + const result = await MCPOAuthProvider.getValidToken( + 'test-server', + mockConfig, + ); + + expect(result).toBe('new_access_token'); + expect(MCPOAuthTokenStorage.saveToken).toHaveBeenCalledWith( + 'test-server', + expect.objectContaining({ accessToken: 'new_access_token' }), + 'test-client-id', + 'https://auth.example.com/token', + ); + }); + + it('should return null when no credentials exist', async () => { + vi.mocked(MCPOAuthTokenStorage.getToken).mockResolvedValue(null); + + const result = await MCPOAuthProvider.getValidToken( + 'test-server', + mockConfig, + ); + + expect(result).toBeNull(); + }); + + it('should handle refresh failure and remove invalid token', async () => { + const expiredCredentials = { + serverName: 'test-server', + token: { ...mockToken, expiresAt: Date.now() - 3600000 }, + clientId: 'test-client-id', + tokenUrl: 'https://auth.example.com/token', + updatedAt: Date.now(), + }; + + vi.mocked(MCPOAuthTokenStorage.getToken).mockResolvedValue( + expiredCredentials, + ); + vi.mocked(MCPOAuthTokenStorage.isTokenExpired).mockReturnValue(true); + vi.mocked(MCPOAuthTokenStorage.removeToken).mockResolvedValue(undefined); + + mockFetch.mockResolvedValueOnce({ + ok: false, + status: 400, + text: () => Promise.resolve('Invalid refresh token'), + }); + + const result = await MCPOAuthProvider.getValidToken( + 'test-server', + mockConfig, + ); + + expect(result).toBeNull(); + expect(MCPOAuthTokenStorage.removeToken).toHaveBeenCalledWith( + 'test-server', + ); + expect(console.error).toHaveBeenCalledWith( + expect.stringContaining('Failed to refresh token'), + ); + }); + + it('should return null for token without refresh capability', async () => { + const tokenWithoutRefresh = { + serverName: 'test-server', + token: { + ...mockToken, + refreshToken: undefined, + expiresAt: Date.now() - 3600000, + }, + clientId: 'test-client-id', + tokenUrl: 'https://auth.example.com/token', + updatedAt: Date.now(), + }; + + vi.mocked(MCPOAuthTokenStorage.getToken).mockResolvedValue( + tokenWithoutRefresh, + ); + vi.mocked(MCPOAuthTokenStorage.isTokenExpired).mockReturnValue(true); + + const result = await MCPOAuthProvider.getValidToken( + 'test-server', + mockConfig, + ); + + expect(result).toBeNull(); + }); + }); + + describe('PKCE parameter generation', () => { + it('should generate valid PKCE parameters', async () => { + // Test is implicit in the authenticate flow tests, but we can verify + // the crypto mocks are called correctly + let callbackHandler: unknown; + vi.mocked(http.createServer).mockImplementation((handler) => { + callbackHandler = handler; + return mockHttpServer as unknown as http.Server; + }); + + mockHttpServer.listen.mockImplementation((port, callback) => { + callback?.(); + setTimeout(() => { + const mockReq = { + url: '/oauth/callback?code=auth_code_123&state=bW9ja19zdGF0ZV8xNl9ieXRlcw', + }; + const mockRes = { + writeHead: vi.fn(), + end: vi.fn(), + }; + (callbackHandler as (req: unknown, res: unknown) => void)( + mockReq, + mockRes, + ); + }, 10); + }); + + mockFetch.mockResolvedValueOnce({ + ok: true, + json: () => Promise.resolve(mockTokenResponse), + }); + + await MCPOAuthProvider.authenticate('test-server', mockConfig); + + expect(crypto.randomBytes).toHaveBeenCalledWith(32); // code verifier + expect(crypto.randomBytes).toHaveBeenCalledWith(16); // state + expect(crypto.createHash).toHaveBeenCalledWith('sha256'); + }); + }); + + describe('Authorization URL building', () => { + it('should build correct authorization URL with all parameters', async () => { + // Mock to capture the URL that would be opened + let capturedUrl: string; + vi.mocked(open).mockImplementation((url) => { + capturedUrl = url; + // Return a minimal mock ChildProcess + return Promise.resolve({ + pid: 1234, + } as unknown as import('child_process').ChildProcess); + }); + + let callbackHandler: unknown; + vi.mocked(http.createServer).mockImplementation((handler) => { + callbackHandler = handler; + return mockHttpServer as unknown as http.Server; + }); + + mockHttpServer.listen.mockImplementation((port, callback) => { + callback?.(); + setTimeout(() => { + const mockReq = { + url: '/oauth/callback?code=auth_code_123&state=bW9ja19zdGF0ZV8xNl9ieXRlcw', + }; + const mockRes = { + writeHead: vi.fn(), + end: vi.fn(), + }; + (callbackHandler as (req: unknown, res: unknown) => void)( + mockReq, + mockRes, + ); + }, 10); + }); + + mockFetch.mockResolvedValueOnce({ + ok: true, + json: () => Promise.resolve(mockTokenResponse), + }); + + await MCPOAuthProvider.authenticate('test-server', mockConfig); + + expect(capturedUrl!).toContain('response_type=code'); + expect(capturedUrl!).toContain('client_id=test-client-id'); + expect(capturedUrl!).toContain('code_challenge=code_challenge_mock'); + expect(capturedUrl!).toContain('code_challenge_method=S256'); + expect(capturedUrl!).toContain('scope=read+write'); + expect(capturedUrl!).toContain('resource=https%3A%2F%2Fauth.example.com'); + }); + }); +}); diff --git a/packages/core/src/mcp/oauth-provider.ts b/packages/core/src/mcp/oauth-provider.ts new file mode 100644 index 00000000..51f5b2d6 --- /dev/null +++ b/packages/core/src/mcp/oauth-provider.ts @@ -0,0 +1,698 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import * as http from 'node:http'; +import * as crypto from 'node:crypto'; +import { URL } from 'node:url'; +import open from 'open'; +import { MCPOAuthToken, MCPOAuthTokenStorage } from './oauth-token-storage.js'; +import { getErrorMessage } from '../utils/errors.js'; +import { OAuthUtils } from './oauth-utils.js'; + +/** + * OAuth configuration for an MCP server. + */ +export interface MCPOAuthConfig { + enabled?: boolean; // Whether OAuth is enabled for this server + clientId?: string; + clientSecret?: string; + authorizationUrl?: string; + tokenUrl?: string; + scopes?: string[]; + redirectUri?: string; + tokenParamName?: string; // For SSE connections, specifies the query parameter name for the token +} + +/** + * OAuth authorization response. + */ +export interface OAuthAuthorizationResponse { + code: string; + state: string; +} + +/** + * OAuth token response from the authorization server. + */ +export interface OAuthTokenResponse { + access_token: string; + token_type: string; + expires_in?: number; + refresh_token?: string; + scope?: string; +} + +/** + * Dynamic client registration request. + */ +export interface OAuthClientRegistrationRequest { + client_name: string; + redirect_uris: string[]; + grant_types: string[]; + response_types: string[]; + token_endpoint_auth_method: string; + code_challenge_method?: string[]; + scope?: string; +} + +/** + * Dynamic client registration response. + */ +export interface OAuthClientRegistrationResponse { + client_id: string; + client_secret?: string; + client_id_issued_at?: number; + client_secret_expires_at?: number; + redirect_uris: string[]; + grant_types: string[]; + response_types: string[]; + token_endpoint_auth_method: string; + code_challenge_method?: string[]; + scope?: string; +} + +/** + * PKCE (Proof Key for Code Exchange) parameters. + */ +interface PKCEParams { + codeVerifier: string; + codeChallenge: string; + state: string; +} + +/** + * Provider for handling OAuth authentication for MCP servers. + */ +export class MCPOAuthProvider { + private static readonly REDIRECT_PORT = 7777; + private static readonly REDIRECT_PATH = '/oauth/callback'; + private static readonly HTTP_OK = 200; + private static readonly HTTP_REDIRECT = 302; + + /** + * Register a client dynamically with the OAuth server. + * + * @param registrationUrl The client registration endpoint URL + * @param config OAuth configuration + * @returns The registered client information + */ + private static async registerClient( + registrationUrl: string, + config: MCPOAuthConfig, + ): Promise { + const redirectUri = + config.redirectUri || + `http://localhost:${this.REDIRECT_PORT}${this.REDIRECT_PATH}`; + + const registrationRequest: OAuthClientRegistrationRequest = { + client_name: 'Gemini CLI MCP Client', + redirect_uris: [redirectUri], + grant_types: ['authorization_code', 'refresh_token'], + response_types: ['code'], + token_endpoint_auth_method: 'none', // Public client + code_challenge_method: ['S256'], + scope: config.scopes?.join(' ') || '', + }; + + const response = await fetch(registrationUrl, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + }, + body: JSON.stringify(registrationRequest), + }); + + if (!response.ok) { + const errorText = await response.text(); + throw new Error( + `Client registration failed: ${response.status} ${response.statusText} - ${errorText}`, + ); + } + + return (await response.json()) as OAuthClientRegistrationResponse; + } + + /** + * Discover OAuth configuration from an MCP server URL. + * + * @param mcpServerUrl The MCP server URL + * @returns OAuth configuration if discovered, null otherwise + */ + private static async discoverOAuthFromMCPServer( + mcpServerUrl: string, + ): Promise { + const baseUrl = OAuthUtils.extractBaseUrl(mcpServerUrl); + return OAuthUtils.discoverOAuthConfig(baseUrl); + } + + /** + * Generate PKCE parameters for OAuth flow. + * + * @returns PKCE parameters including code verifier, challenge, and state + */ + private static generatePKCEParams(): PKCEParams { + // Generate code verifier (43-128 characters) + const codeVerifier = crypto.randomBytes(32).toString('base64url'); + + // Generate code challenge using SHA256 + const codeChallenge = crypto + .createHash('sha256') + .update(codeVerifier) + .digest('base64url'); + + // Generate state for CSRF protection + const state = crypto.randomBytes(16).toString('base64url'); + + return { codeVerifier, codeChallenge, state }; + } + + /** + * Start a local HTTP server to handle OAuth callback. + * + * @param expectedState The state parameter to validate + * @returns Promise that resolves with the authorization code + */ + private static async startCallbackServer( + expectedState: string, + ): Promise { + return new Promise((resolve, reject) => { + const server = http.createServer( + async (req: http.IncomingMessage, res: http.ServerResponse) => { + try { + const url = new URL( + req.url!, + `http://localhost:${this.REDIRECT_PORT}`, + ); + + if (url.pathname !== this.REDIRECT_PATH) { + res.writeHead(404); + res.end('Not found'); + return; + } + + const code = url.searchParams.get('code'); + const state = url.searchParams.get('state'); + const error = url.searchParams.get('error'); + + if (error) { + res.writeHead(this.HTTP_OK, { 'Content-Type': 'text/html' }); + res.end(` + + +

Authentication Failed

+

Error: ${(error as string).replace(//g, '>')}

+

${((url.searchParams.get('error_description') || '') as string).replace(//g, '>')}

+

You can close this window.

+ + + `); + server.close(); + reject(new Error(`OAuth error: ${error}`)); + return; + } + + if (!code || !state) { + res.writeHead(400); + res.end('Missing code or state parameter'); + return; + } + + if (state !== expectedState) { + res.writeHead(400); + res.end('Invalid state parameter'); + server.close(); + reject(new Error('State mismatch - possible CSRF attack')); + return; + } + + // Send success response to browser + res.writeHead(this.HTTP_OK, { 'Content-Type': 'text/html' }); + res.end(` + + +

Authentication Successful!

+

You can close this window and return to Gemini CLI.

+ + + + `); + + server.close(); + resolve({ code, state }); + } catch (error) { + server.close(); + reject(error); + } + }, + ); + + server.on('error', reject); + server.listen(this.REDIRECT_PORT, () => { + console.log( + `OAuth callback server listening on port ${this.REDIRECT_PORT}`, + ); + }); + + // Timeout after 5 minutes + setTimeout( + () => { + server.close(); + reject(new Error('OAuth callback timeout')); + }, + 5 * 60 * 1000, + ); + }); + } + + /** + * Build the authorization URL with PKCE parameters. + * + * @param config OAuth configuration + * @param pkceParams PKCE parameters + * @returns The authorization URL + */ + private static buildAuthorizationUrl( + config: MCPOAuthConfig, + pkceParams: PKCEParams, + ): string { + const redirectUri = + config.redirectUri || + `http://localhost:${this.REDIRECT_PORT}${this.REDIRECT_PATH}`; + + const params = new URLSearchParams({ + client_id: config.clientId!, + response_type: 'code', + redirect_uri: redirectUri, + state: pkceParams.state, + code_challenge: pkceParams.codeChallenge, + code_challenge_method: 'S256', + }); + + if (config.scopes && config.scopes.length > 0) { + params.append('scope', config.scopes.join(' ')); + } + + // Add resource parameter for MCP OAuth spec compliance + params.append( + 'resource', + OAuthUtils.buildResourceParameter(config.authorizationUrl!), + ); + + return `${config.authorizationUrl}?${params.toString()}`; + } + + /** + * Exchange authorization code for tokens. + * + * @param config OAuth configuration + * @param code Authorization code + * @param codeVerifier PKCE code verifier + * @returns The token response + */ + private static async exchangeCodeForToken( + config: MCPOAuthConfig, + code: string, + codeVerifier: string, + ): Promise { + const redirectUri = + config.redirectUri || + `http://localhost:${this.REDIRECT_PORT}${this.REDIRECT_PATH}`; + + const params = new URLSearchParams({ + grant_type: 'authorization_code', + code, + redirect_uri: redirectUri, + code_verifier: codeVerifier, + client_id: config.clientId!, + }); + + if (config.clientSecret) { + params.append('client_secret', config.clientSecret); + } + + // Add resource parameter for MCP OAuth spec compliance + params.append( + 'resource', + OAuthUtils.buildResourceParameter(config.tokenUrl!), + ); + + const response = await fetch(config.tokenUrl!, { + method: 'POST', + headers: { + 'Content-Type': 'application/x-www-form-urlencoded', + }, + body: params.toString(), + }); + + if (!response.ok) { + const errorText = await response.text(); + throw new Error( + `Token exchange failed: ${response.status} - ${errorText}`, + ); + } + + return (await response.json()) as OAuthTokenResponse; + } + + /** + * Refresh an access token using a refresh token. + * + * @param config OAuth configuration + * @param refreshToken The refresh token + * @returns The new token response + */ + static async refreshAccessToken( + config: MCPOAuthConfig, + refreshToken: string, + tokenUrl: string, + ): Promise { + const params = new URLSearchParams({ + grant_type: 'refresh_token', + refresh_token: refreshToken, + client_id: config.clientId!, + }); + + if (config.clientSecret) { + params.append('client_secret', config.clientSecret); + } + + if (config.scopes && config.scopes.length > 0) { + params.append('scope', config.scopes.join(' ')); + } + + // Add resource parameter for MCP OAuth spec compliance + params.append('resource', OAuthUtils.buildResourceParameter(tokenUrl)); + + const response = await fetch(tokenUrl, { + method: 'POST', + headers: { + 'Content-Type': 'application/x-www-form-urlencoded', + }, + body: params.toString(), + }); + + if (!response.ok) { + const errorText = await response.text(); + throw new Error( + `Token refresh failed: ${response.status} - ${errorText}`, + ); + } + + return (await response.json()) as OAuthTokenResponse; + } + + /** + * Perform the full OAuth authorization code flow with PKCE. + * + * @param serverName The name of the MCP server + * @param config OAuth configuration + * @param mcpServerUrl Optional MCP server URL for OAuth discovery + * @returns The obtained OAuth token + */ + static async authenticate( + serverName: string, + config: MCPOAuthConfig, + mcpServerUrl?: string, + ): Promise { + // If no authorization URL is provided, try to discover OAuth configuration + if (!config.authorizationUrl && mcpServerUrl) { + console.log( + 'No authorization URL provided, attempting OAuth discovery...', + ); + + // For SSE URLs, first check if authentication is required + if (OAuthUtils.isSSEEndpoint(mcpServerUrl)) { + try { + const response = await fetch(mcpServerUrl, { + method: 'HEAD', + headers: { + Accept: 'text/event-stream', + }, + }); + + if (response.status === 401 || response.status === 307) { + const wwwAuthenticate = response.headers.get('www-authenticate'); + if (wwwAuthenticate) { + const discoveredConfig = + await OAuthUtils.discoverOAuthFromWWWAuthenticate( + wwwAuthenticate, + ); + if (discoveredConfig) { + config = { + ...config, + ...discoveredConfig, + scopes: discoveredConfig.scopes || config.scopes || [], + }; + } + } + } + } catch (error) { + console.debug( + `Failed to check SSE endpoint for authentication requirements: ${getErrorMessage(error)}`, + ); + } + } + + // If we still don't have OAuth config, try the standard discovery + if (!config.authorizationUrl) { + const discoveredConfig = + await this.discoverOAuthFromMCPServer(mcpServerUrl); + if (discoveredConfig) { + config = { ...config, ...discoveredConfig }; + console.log('OAuth configuration discovered successfully'); + } else { + throw new Error( + 'Failed to discover OAuth configuration from MCP server', + ); + } + } + } + + // If no client ID is provided, try dynamic client registration + if (!config.clientId) { + // Extract server URL from authorization URL + if (!config.authorizationUrl) { + throw new Error( + 'Cannot perform dynamic registration without authorization URL', + ); + } + + const authUrl = new URL(config.authorizationUrl); + const serverUrl = `${authUrl.protocol}//${authUrl.host}`; + + console.log( + 'No client ID provided, attempting dynamic client registration...', + ); + + // Get the authorization server metadata for registration + const authServerMetadataUrl = new URL( + '/.well-known/oauth-authorization-server', + serverUrl, + ).toString(); + + const authServerMetadata = + await OAuthUtils.fetchAuthorizationServerMetadata( + authServerMetadataUrl, + ); + if (!authServerMetadata) { + throw new Error( + 'Failed to fetch authorization server metadata for client registration', + ); + } + + // Register client if registration endpoint is available + if (authServerMetadata.registration_endpoint) { + const clientRegistration = await this.registerClient( + authServerMetadata.registration_endpoint, + config, + ); + + config.clientId = clientRegistration.client_id; + if (clientRegistration.client_secret) { + config.clientSecret = clientRegistration.client_secret; + } + + console.log('Dynamic client registration successful'); + } else { + throw new Error( + 'No client ID provided and dynamic registration not supported', + ); + } + } + + // Validate configuration + if (!config.clientId || !config.authorizationUrl || !config.tokenUrl) { + throw new Error( + 'Missing required OAuth configuration after discovery and registration', + ); + } + + // Generate PKCE parameters + const pkceParams = this.generatePKCEParams(); + + // Build authorization URL + const authUrl = this.buildAuthorizationUrl(config, pkceParams); + + console.log('\nOpening browser for OAuth authentication...'); + console.log('If the browser does not open, please visit:'); + console.log(''); + + // Get terminal width or default to 80 + const terminalWidth = process.stdout.columns || 80; + const separatorLength = Math.min(terminalWidth - 2, 80); + const separator = '━'.repeat(separatorLength); + + console.log(separator); + console.log( + 'COPY THE ENTIRE URL BELOW (select all text between the lines):', + ); + console.log(separator); + console.log(authUrl); + console.log(separator); + console.log(''); + console.log( + '💡 TIP: Triple-click to select the entire URL, then copy and paste it into your browser.', + ); + console.log( + '⚠️ Make sure to copy the COMPLETE URL - it may wrap across multiple lines.', + ); + console.log(''); + + // Start callback server + const callbackPromise = this.startCallbackServer(pkceParams.state); + + // Open browser + try { + await open(authUrl); + } catch (error) { + console.warn( + 'Failed to open browser automatically:', + getErrorMessage(error), + ); + } + + // Wait for callback + const { code } = await callbackPromise; + + console.log('\nAuthorization code received, exchanging for tokens...'); + + // Exchange code for tokens + const tokenResponse = await this.exchangeCodeForToken( + config, + code, + pkceParams.codeVerifier, + ); + + // Convert to our token format + const token: MCPOAuthToken = { + accessToken: tokenResponse.access_token, + tokenType: tokenResponse.token_type, + refreshToken: tokenResponse.refresh_token, + scope: tokenResponse.scope, + }; + + if (tokenResponse.expires_in) { + token.expiresAt = Date.now() + tokenResponse.expires_in * 1000; + } + + // Save token + try { + await MCPOAuthTokenStorage.saveToken( + serverName, + token, + config.clientId, + config.tokenUrl, + ); + console.log('Authentication successful! Token saved.'); + + // Verify token was saved + const savedToken = await MCPOAuthTokenStorage.getToken(serverName); + if (savedToken) { + console.log( + `Token verification successful: ${savedToken.token.accessToken.substring(0, 20)}...`, + ); + } else { + console.error('Token verification failed: token not found after save'); + } + } catch (saveError) { + console.error(`Failed to save token: ${getErrorMessage(saveError)}`); + throw saveError; + } + + return token; + } + + /** + * Get a valid access token for an MCP server, refreshing if necessary. + * + * @param serverName The name of the MCP server + * @param config OAuth configuration + * @returns A valid access token or null if not authenticated + */ + static async getValidToken( + serverName: string, + config: MCPOAuthConfig, + ): Promise { + console.debug(`Getting valid token for server: ${serverName}`); + const credentials = await MCPOAuthTokenStorage.getToken(serverName); + + if (!credentials) { + console.debug(`No credentials found for server: ${serverName}`); + return null; + } + + const { token } = credentials; + console.debug( + `Found token for server: ${serverName}, expired: ${MCPOAuthTokenStorage.isTokenExpired(token)}`, + ); + + // Check if token is expired + if (!MCPOAuthTokenStorage.isTokenExpired(token)) { + console.debug(`Returning valid token for server: ${serverName}`); + return token.accessToken; + } + + // Try to refresh if we have a refresh token + if (token.refreshToken && config.clientId && credentials.tokenUrl) { + try { + console.log(`Refreshing expired token for MCP server: ${serverName}`); + + const newTokenResponse = await this.refreshAccessToken( + config, + token.refreshToken, + credentials.tokenUrl, + ); + + // Update stored token + const newToken: MCPOAuthToken = { + accessToken: newTokenResponse.access_token, + tokenType: newTokenResponse.token_type, + refreshToken: newTokenResponse.refresh_token || token.refreshToken, + scope: newTokenResponse.scope || token.scope, + }; + + if (newTokenResponse.expires_in) { + newToken.expiresAt = Date.now() + newTokenResponse.expires_in * 1000; + } + + await MCPOAuthTokenStorage.saveToken( + serverName, + newToken, + config.clientId, + credentials.tokenUrl, + ); + + return newToken.accessToken; + } catch (error) { + console.error(`Failed to refresh token: ${getErrorMessage(error)}`); + // Remove invalid token + await MCPOAuthTokenStorage.removeToken(serverName); + } + } + + return null; + } +} diff --git a/packages/core/src/mcp/oauth-token-storage.test.ts b/packages/core/src/mcp/oauth-token-storage.test.ts new file mode 100644 index 00000000..5fe2f3f5 --- /dev/null +++ b/packages/core/src/mcp/oauth-token-storage.test.ts @@ -0,0 +1,325 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'; +import { promises as fs } from 'node:fs'; +import * as path from 'node:path'; +import { + MCPOAuthTokenStorage, + MCPOAuthToken, + MCPOAuthCredentials, +} from './oauth-token-storage.js'; + +// Mock file system operations +vi.mock('node:fs', () => ({ + promises: { + readFile: vi.fn(), + writeFile: vi.fn(), + mkdir: vi.fn(), + unlink: vi.fn(), + }, +})); + +vi.mock('node:os', () => ({ + homedir: vi.fn(() => '/mock/home'), +})); + +describe('MCPOAuthTokenStorage', () => { + const mockToken: MCPOAuthToken = { + accessToken: 'access_token_123', + refreshToken: 'refresh_token_456', + tokenType: 'Bearer', + scope: 'read write', + expiresAt: Date.now() + 3600000, // 1 hour from now + }; + + const mockCredentials: MCPOAuthCredentials = { + serverName: 'test-server', + token: mockToken, + clientId: 'test-client-id', + tokenUrl: 'https://auth.example.com/token', + updatedAt: Date.now(), + }; + + beforeEach(() => { + vi.clearAllMocks(); + vi.spyOn(console, 'error').mockImplementation(() => {}); + }); + + afterEach(() => { + vi.restoreAllMocks(); + }); + + describe('loadTokens', () => { + it('should return empty map when token file does not exist', async () => { + vi.mocked(fs.readFile).mockRejectedValue({ code: 'ENOENT' }); + + const tokens = await MCPOAuthTokenStorage.loadTokens(); + + expect(tokens.size).toBe(0); + expect(console.error).not.toHaveBeenCalled(); + }); + + it('should load tokens from file successfully', async () => { + const tokensArray = [mockCredentials]; + vi.mocked(fs.readFile).mockResolvedValue(JSON.stringify(tokensArray)); + + const tokens = await MCPOAuthTokenStorage.loadTokens(); + + expect(tokens.size).toBe(1); + expect(tokens.get('test-server')).toEqual(mockCredentials); + expect(fs.readFile).toHaveBeenCalledWith( + path.join('/mock/home', '.gemini', 'mcp-oauth-tokens.json'), + 'utf-8', + ); + }); + + it('should handle corrupted token file gracefully', async () => { + vi.mocked(fs.readFile).mockResolvedValue('invalid json'); + + const tokens = await MCPOAuthTokenStorage.loadTokens(); + + expect(tokens.size).toBe(0); + expect(console.error).toHaveBeenCalledWith( + expect.stringContaining('Failed to load MCP OAuth tokens'), + ); + }); + + it('should handle file read errors other than ENOENT', async () => { + const error = new Error('Permission denied'); + vi.mocked(fs.readFile).mockRejectedValue(error); + + const tokens = await MCPOAuthTokenStorage.loadTokens(); + + expect(tokens.size).toBe(0); + expect(console.error).toHaveBeenCalledWith( + expect.stringContaining('Failed to load MCP OAuth tokens'), + ); + }); + }); + + describe('saveToken', () => { + it('should save token with restricted permissions', async () => { + vi.mocked(fs.readFile).mockRejectedValue({ code: 'ENOENT' }); + vi.mocked(fs.mkdir).mockResolvedValue(undefined); + vi.mocked(fs.writeFile).mockResolvedValue(undefined); + + await MCPOAuthTokenStorage.saveToken( + 'test-server', + mockToken, + 'client-id', + 'https://token.url', + ); + + expect(fs.mkdir).toHaveBeenCalledWith( + path.join('/mock/home', '.gemini'), + { recursive: true }, + ); + expect(fs.writeFile).toHaveBeenCalledWith( + path.join('/mock/home', '.gemini', 'mcp-oauth-tokens.json'), + expect.stringContaining('test-server'), + { mode: 0o600 }, + ); + }); + + it('should update existing token for same server', async () => { + const existingCredentials = { + ...mockCredentials, + serverName: 'existing-server', + }; + vi.mocked(fs.readFile).mockResolvedValue( + JSON.stringify([existingCredentials]), + ); + vi.mocked(fs.writeFile).mockResolvedValue(undefined); + + const newToken = { ...mockToken, accessToken: 'new_access_token' }; + await MCPOAuthTokenStorage.saveToken('existing-server', newToken); + + const writeCall = vi.mocked(fs.writeFile).mock.calls[0]; + const savedData = JSON.parse(writeCall[1] as string); + + expect(savedData).toHaveLength(1); + expect(savedData[0].token.accessToken).toBe('new_access_token'); + expect(savedData[0].serverName).toBe('existing-server'); + }); + + it('should handle write errors gracefully', async () => { + vi.mocked(fs.readFile).mockRejectedValue({ code: 'ENOENT' }); + vi.mocked(fs.mkdir).mockResolvedValue(undefined); + const writeError = new Error('Disk full'); + vi.mocked(fs.writeFile).mockRejectedValue(writeError); + + await expect( + MCPOAuthTokenStorage.saveToken('test-server', mockToken), + ).rejects.toThrow('Disk full'); + + expect(console.error).toHaveBeenCalledWith( + expect.stringContaining('Failed to save MCP OAuth token'), + ); + }); + }); + + describe('getToken', () => { + it('should return token for existing server', async () => { + vi.mocked(fs.readFile).mockResolvedValue( + JSON.stringify([mockCredentials]), + ); + + const result = await MCPOAuthTokenStorage.getToken('test-server'); + + expect(result).toEqual(mockCredentials); + }); + + it('should return null for non-existent server', async () => { + vi.mocked(fs.readFile).mockResolvedValue( + JSON.stringify([mockCredentials]), + ); + + const result = await MCPOAuthTokenStorage.getToken('non-existent'); + + expect(result).toBeNull(); + }); + + it('should return null when no tokens file exists', async () => { + vi.mocked(fs.readFile).mockRejectedValue({ code: 'ENOENT' }); + + const result = await MCPOAuthTokenStorage.getToken('test-server'); + + expect(result).toBeNull(); + }); + }); + + describe('removeToken', () => { + it('should remove token for specific server', async () => { + const credentials1 = { ...mockCredentials, serverName: 'server1' }; + const credentials2 = { ...mockCredentials, serverName: 'server2' }; + vi.mocked(fs.readFile).mockResolvedValue( + JSON.stringify([credentials1, credentials2]), + ); + vi.mocked(fs.writeFile).mockResolvedValue(undefined); + + await MCPOAuthTokenStorage.removeToken('server1'); + + const writeCall = vi.mocked(fs.writeFile).mock.calls[0]; + const savedData = JSON.parse(writeCall[1] as string); + + expect(savedData).toHaveLength(1); + expect(savedData[0].serverName).toBe('server2'); + }); + + it('should remove token file when no tokens remain', async () => { + vi.mocked(fs.readFile).mockResolvedValue( + JSON.stringify([mockCredentials]), + ); + vi.mocked(fs.unlink).mockResolvedValue(undefined); + + await MCPOAuthTokenStorage.removeToken('test-server'); + + expect(fs.unlink).toHaveBeenCalledWith( + path.join('/mock/home', '.gemini', 'mcp-oauth-tokens.json'), + ); + expect(fs.writeFile).not.toHaveBeenCalled(); + }); + + it('should handle removal of non-existent token gracefully', async () => { + vi.mocked(fs.readFile).mockResolvedValue( + JSON.stringify([mockCredentials]), + ); + + await MCPOAuthTokenStorage.removeToken('non-existent'); + + expect(fs.writeFile).not.toHaveBeenCalled(); + expect(fs.unlink).not.toHaveBeenCalled(); + }); + + it('should handle file operation errors gracefully', async () => { + vi.mocked(fs.readFile).mockResolvedValue( + JSON.stringify([mockCredentials]), + ); + vi.mocked(fs.unlink).mockRejectedValue(new Error('Permission denied')); + + await MCPOAuthTokenStorage.removeToken('test-server'); + + expect(console.error).toHaveBeenCalledWith( + expect.stringContaining('Failed to remove MCP OAuth token'), + ); + }); + }); + + describe('isTokenExpired', () => { + it('should return false for token without expiry', () => { + const tokenWithoutExpiry = { ...mockToken }; + delete tokenWithoutExpiry.expiresAt; + + const result = MCPOAuthTokenStorage.isTokenExpired(tokenWithoutExpiry); + + expect(result).toBe(false); + }); + + it('should return false for valid token', () => { + const futureToken = { + ...mockToken, + expiresAt: Date.now() + 3600000, // 1 hour from now + }; + + const result = MCPOAuthTokenStorage.isTokenExpired(futureToken); + + expect(result).toBe(false); + }); + + it('should return true for expired token', () => { + const expiredToken = { + ...mockToken, + expiresAt: Date.now() - 3600000, // 1 hour ago + }; + + const result = MCPOAuthTokenStorage.isTokenExpired(expiredToken); + + expect(result).toBe(true); + }); + + it('should return true for token expiring within buffer time', () => { + const soonToExpireToken = { + ...mockToken, + expiresAt: Date.now() + 60000, // 1 minute from now (within 5-minute buffer) + }; + + const result = MCPOAuthTokenStorage.isTokenExpired(soonToExpireToken); + + expect(result).toBe(true); + }); + }); + + describe('clearAllTokens', () => { + it('should remove token file successfully', async () => { + vi.mocked(fs.unlink).mockResolvedValue(undefined); + + await MCPOAuthTokenStorage.clearAllTokens(); + + expect(fs.unlink).toHaveBeenCalledWith( + path.join('/mock/home', '.gemini', 'mcp-oauth-tokens.json'), + ); + }); + + it('should handle non-existent file gracefully', async () => { + vi.mocked(fs.unlink).mockRejectedValue({ code: 'ENOENT' }); + + await MCPOAuthTokenStorage.clearAllTokens(); + + expect(console.error).not.toHaveBeenCalled(); + }); + + it('should handle other file errors gracefully', async () => { + vi.mocked(fs.unlink).mockRejectedValue(new Error('Permission denied')); + + await MCPOAuthTokenStorage.clearAllTokens(); + + expect(console.error).toHaveBeenCalledWith( + expect.stringContaining('Failed to clear MCP OAuth tokens'), + ); + }); + }); +}); diff --git a/packages/core/src/mcp/oauth-token-storage.ts b/packages/core/src/mcp/oauth-token-storage.ts new file mode 100644 index 00000000..fc9da8af --- /dev/null +++ b/packages/core/src/mcp/oauth-token-storage.ts @@ -0,0 +1,205 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { promises as fs } from 'node:fs'; +import * as path from 'node:path'; +import * as os from 'node:os'; +import { getErrorMessage } from '../utils/errors.js'; + +/** + * Interface for MCP OAuth tokens. + */ +export interface MCPOAuthToken { + accessToken: string; + refreshToken?: string; + expiresAt?: number; + tokenType: string; + scope?: string; +} + +/** + * Interface for stored MCP OAuth credentials. + */ +export interface MCPOAuthCredentials { + serverName: string; + token: MCPOAuthToken; + clientId?: string; + tokenUrl?: string; + updatedAt: number; +} + +/** + * Class for managing MCP OAuth token storage and retrieval. + */ +export class MCPOAuthTokenStorage { + private static readonly TOKEN_FILE = 'mcp-oauth-tokens.json'; + private static readonly CONFIG_DIR = '.gemini'; + + /** + * Get the path to the token storage file. + * + * @returns The full path to the token storage file + */ + private static getTokenFilePath(): string { + const homeDir = os.homedir(); + return path.join(homeDir, this.CONFIG_DIR, this.TOKEN_FILE); + } + + /** + * Ensure the config directory exists. + */ + private static async ensureConfigDir(): Promise { + const configDir = path.dirname(this.getTokenFilePath()); + await fs.mkdir(configDir, { recursive: true }); + } + + /** + * Load all stored MCP OAuth tokens. + * + * @returns A map of server names to credentials + */ + static async loadTokens(): Promise> { + const tokenMap = new Map(); + + try { + const tokenFile = this.getTokenFilePath(); + const data = await fs.readFile(tokenFile, 'utf-8'); + const tokens = JSON.parse(data) as MCPOAuthCredentials[]; + + for (const credential of tokens) { + tokenMap.set(credential.serverName, credential); + } + } catch (error) { + // File doesn't exist or is invalid, return empty map + if ((error as NodeJS.ErrnoException).code !== 'ENOENT') { + console.error( + `Failed to load MCP OAuth tokens: ${getErrorMessage(error)}`, + ); + } + } + + return tokenMap; + } + + /** + * Save a token for a specific MCP server. + * + * @param serverName The name of the MCP server + * @param token The OAuth token to save + * @param clientId Optional client ID used for this token + * @param tokenUrl Optional token URL used for this token + */ + static async saveToken( + serverName: string, + token: MCPOAuthToken, + clientId?: string, + tokenUrl?: string, + ): Promise { + await this.ensureConfigDir(); + + const tokens = await this.loadTokens(); + + const credential: MCPOAuthCredentials = { + serverName, + token, + clientId, + tokenUrl, + updatedAt: Date.now(), + }; + + tokens.set(serverName, credential); + + const tokenArray = Array.from(tokens.values()); + const tokenFile = this.getTokenFilePath(); + + try { + await fs.writeFile( + tokenFile, + JSON.stringify(tokenArray, null, 2), + { mode: 0o600 }, // Restrict file permissions + ); + } catch (error) { + console.error( + `Failed to save MCP OAuth token: ${getErrorMessage(error)}`, + ); + throw error; + } + } + + /** + * Get a token for a specific MCP server. + * + * @param serverName The name of the MCP server + * @returns The stored credentials or null if not found + */ + static async getToken( + serverName: string, + ): Promise { + const tokens = await this.loadTokens(); + return tokens.get(serverName) || null; + } + + /** + * Remove a token for a specific MCP server. + * + * @param serverName The name of the MCP server + */ + static async removeToken(serverName: string): Promise { + const tokens = await this.loadTokens(); + + if (tokens.delete(serverName)) { + const tokenArray = Array.from(tokens.values()); + const tokenFile = this.getTokenFilePath(); + + try { + if (tokenArray.length === 0) { + // Remove file if no tokens left + await fs.unlink(tokenFile); + } else { + await fs.writeFile(tokenFile, JSON.stringify(tokenArray, null, 2), { + mode: 0o600, + }); + } + } catch (error) { + console.error( + `Failed to remove MCP OAuth token: ${getErrorMessage(error)}`, + ); + } + } + } + + /** + * Check if a token is expired. + * + * @param token The token to check + * @returns True if the token is expired + */ + static isTokenExpired(token: MCPOAuthToken): boolean { + if (!token.expiresAt) { + return false; // No expiry, assume valid + } + + // Add a 5-minute buffer to account for clock skew + const bufferMs = 5 * 60 * 1000; + return Date.now() + bufferMs >= token.expiresAt; + } + + /** + * Clear all stored MCP OAuth tokens. + */ + static async clearAllTokens(): Promise { + try { + const tokenFile = this.getTokenFilePath(); + await fs.unlink(tokenFile); + } catch (error) { + if ((error as NodeJS.ErrnoException).code !== 'ENOENT') { + console.error( + `Failed to clear MCP OAuth tokens: ${getErrorMessage(error)}`, + ); + } + } + } +} diff --git a/packages/core/src/mcp/oauth-utils.test.ts b/packages/core/src/mcp/oauth-utils.test.ts new file mode 100644 index 00000000..b27d97b3 --- /dev/null +++ b/packages/core/src/mcp/oauth-utils.test.ts @@ -0,0 +1,206 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'; +import { + OAuthUtils, + OAuthAuthorizationServerMetadata, + OAuthProtectedResourceMetadata, +} from './oauth-utils.js'; + +// Mock fetch globally +const mockFetch = vi.fn(); +global.fetch = mockFetch; + +describe('OAuthUtils', () => { + beforeEach(() => { + vi.clearAllMocks(); + vi.spyOn(console, 'debug').mockImplementation(() => {}); + vi.spyOn(console, 'error').mockImplementation(() => {}); + vi.spyOn(console, 'log').mockImplementation(() => {}); + }); + + afterEach(() => { + vi.restoreAllMocks(); + }); + + describe('buildWellKnownUrls', () => { + it('should build correct well-known URLs', () => { + const urls = OAuthUtils.buildWellKnownUrls('https://example.com/path'); + expect(urls.protectedResource).toBe( + 'https://example.com/.well-known/oauth-protected-resource', + ); + expect(urls.authorizationServer).toBe( + 'https://example.com/.well-known/oauth-authorization-server', + ); + }); + }); + + describe('fetchProtectedResourceMetadata', () => { + const mockResourceMetadata: OAuthProtectedResourceMetadata = { + resource: 'https://api.example.com', + authorization_servers: ['https://auth.example.com'], + bearer_methods_supported: ['header'], + }; + + it('should fetch protected resource metadata successfully', async () => { + mockFetch.mockResolvedValueOnce({ + ok: true, + json: () => Promise.resolve(mockResourceMetadata), + }); + + const result = await OAuthUtils.fetchProtectedResourceMetadata( + 'https://example.com/.well-known/oauth-protected-resource', + ); + + expect(result).toEqual(mockResourceMetadata); + }); + + it('should return null when fetch fails', async () => { + mockFetch.mockResolvedValueOnce({ + ok: false, + }); + + const result = await OAuthUtils.fetchProtectedResourceMetadata( + 'https://example.com/.well-known/oauth-protected-resource', + ); + + expect(result).toBeNull(); + }); + }); + + describe('fetchAuthorizationServerMetadata', () => { + const mockAuthServerMetadata: OAuthAuthorizationServerMetadata = { + issuer: 'https://auth.example.com', + authorization_endpoint: 'https://auth.example.com/authorize', + token_endpoint: 'https://auth.example.com/token', + scopes_supported: ['read', 'write'], + }; + + it('should fetch authorization server metadata successfully', async () => { + mockFetch.mockResolvedValueOnce({ + ok: true, + json: () => Promise.resolve(mockAuthServerMetadata), + }); + + const result = await OAuthUtils.fetchAuthorizationServerMetadata( + 'https://auth.example.com/.well-known/oauth-authorization-server', + ); + + expect(result).toEqual(mockAuthServerMetadata); + }); + + it('should return null when fetch fails', async () => { + mockFetch.mockResolvedValueOnce({ + ok: false, + }); + + const result = await OAuthUtils.fetchAuthorizationServerMetadata( + 'https://auth.example.com/.well-known/oauth-authorization-server', + ); + + expect(result).toBeNull(); + }); + }); + + describe('metadataToOAuthConfig', () => { + it('should convert metadata to OAuth config', () => { + const metadata: OAuthAuthorizationServerMetadata = { + issuer: 'https://auth.example.com', + authorization_endpoint: 'https://auth.example.com/authorize', + token_endpoint: 'https://auth.example.com/token', + scopes_supported: ['read', 'write'], + }; + + const config = OAuthUtils.metadataToOAuthConfig(metadata); + + expect(config).toEqual({ + authorizationUrl: 'https://auth.example.com/authorize', + tokenUrl: 'https://auth.example.com/token', + scopes: ['read', 'write'], + }); + }); + + it('should handle empty scopes', () => { + const metadata: OAuthAuthorizationServerMetadata = { + issuer: 'https://auth.example.com', + authorization_endpoint: 'https://auth.example.com/authorize', + token_endpoint: 'https://auth.example.com/token', + }; + + const config = OAuthUtils.metadataToOAuthConfig(metadata); + + expect(config.scopes).toEqual([]); + }); + }); + + describe('parseWWWAuthenticateHeader', () => { + it('should parse resource metadata URI from WWW-Authenticate header', () => { + const header = + 'Bearer realm="example", resource_metadata_uri="https://example.com/.well-known/oauth-protected-resource"'; + const result = OAuthUtils.parseWWWAuthenticateHeader(header); + expect(result).toBe( + 'https://example.com/.well-known/oauth-protected-resource', + ); + }); + + it('should return null when no resource metadata URI is found', () => { + const header = 'Bearer realm="example"'; + const result = OAuthUtils.parseWWWAuthenticateHeader(header); + expect(result).toBeNull(); + }); + }); + + describe('extractBaseUrl', () => { + it('should extract base URL from MCP server URL', () => { + const result = OAuthUtils.extractBaseUrl('https://example.com/mcp/v1'); + expect(result).toBe('https://example.com'); + }); + + it('should handle URLs with ports', () => { + const result = OAuthUtils.extractBaseUrl( + 'https://example.com:8080/mcp/v1', + ); + expect(result).toBe('https://example.com:8080'); + }); + }); + + describe('isSSEEndpoint', () => { + it('should return true for SSE endpoints', () => { + expect(OAuthUtils.isSSEEndpoint('https://example.com/sse')).toBe(true); + expect(OAuthUtils.isSSEEndpoint('https://example.com/api/v1/sse')).toBe( + true, + ); + }); + + it('should return true for non-MCP endpoints', () => { + expect(OAuthUtils.isSSEEndpoint('https://example.com/api')).toBe(true); + }); + + it('should return false for MCP endpoints', () => { + expect(OAuthUtils.isSSEEndpoint('https://example.com/mcp')).toBe(false); + expect(OAuthUtils.isSSEEndpoint('https://example.com/api/mcp/v1')).toBe( + false, + ); + }); + }); + + describe('buildResourceParameter', () => { + it('should build resource parameter from endpoint URL', () => { + const result = OAuthUtils.buildResourceParameter( + 'https://example.com/oauth/token', + ); + expect(result).toBe('https://example.com'); + }); + + it('should handle URLs with ports', () => { + const result = OAuthUtils.buildResourceParameter( + 'https://example.com:8080/oauth/token', + ); + expect(result).toBe('https://example.com:8080'); + }); + }); +}); diff --git a/packages/core/src/mcp/oauth-utils.ts b/packages/core/src/mcp/oauth-utils.ts new file mode 100644 index 00000000..6dad17c8 --- /dev/null +++ b/packages/core/src/mcp/oauth-utils.ts @@ -0,0 +1,285 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { MCPOAuthConfig } from './oauth-provider.js'; +import { getErrorMessage } from '../utils/errors.js'; + +/** + * OAuth authorization server metadata as per RFC 8414. + */ +export interface OAuthAuthorizationServerMetadata { + issuer: string; + authorization_endpoint: string; + token_endpoint: string; + token_endpoint_auth_methods_supported?: string[]; + revocation_endpoint?: string; + revocation_endpoint_auth_methods_supported?: string[]; + registration_endpoint?: string; + response_types_supported?: string[]; + grant_types_supported?: string[]; + code_challenge_methods_supported?: string[]; + scopes_supported?: string[]; +} + +/** + * OAuth protected resource metadata as per RFC 9728. + */ +export interface OAuthProtectedResourceMetadata { + resource: string; + authorization_servers?: string[]; + bearer_methods_supported?: string[]; + resource_documentation?: string; + resource_signing_alg_values_supported?: string[]; + resource_encryption_alg_values_supported?: string[]; + resource_encryption_enc_values_supported?: string[]; +} + +/** + * Utility class for common OAuth operations. + */ +export class OAuthUtils { + /** + * Construct well-known OAuth endpoint URLs. + */ + static buildWellKnownUrls(baseUrl: string) { + const serverUrl = new URL(baseUrl); + const base = `${serverUrl.protocol}//${serverUrl.host}`; + + return { + protectedResource: new URL( + '/.well-known/oauth-protected-resource', + base, + ).toString(), + authorizationServer: new URL( + '/.well-known/oauth-authorization-server', + base, + ).toString(), + }; + } + + /** + * Fetch OAuth protected resource metadata. + * + * @param resourceMetadataUrl The protected resource metadata URL + * @returns The protected resource metadata or null if not available + */ + static async fetchProtectedResourceMetadata( + resourceMetadataUrl: string, + ): Promise { + try { + const response = await fetch(resourceMetadataUrl); + if (!response.ok) { + return null; + } + return (await response.json()) as OAuthProtectedResourceMetadata; + } catch (error) { + console.debug( + `Failed to fetch protected resource metadata from ${resourceMetadataUrl}: ${getErrorMessage(error)}`, + ); + return null; + } + } + + /** + * Fetch OAuth authorization server metadata. + * + * @param authServerMetadataUrl The authorization server metadata URL + * @returns The authorization server metadata or null if not available + */ + static async fetchAuthorizationServerMetadata( + authServerMetadataUrl: string, + ): Promise { + try { + const response = await fetch(authServerMetadataUrl); + if (!response.ok) { + return null; + } + return (await response.json()) as OAuthAuthorizationServerMetadata; + } catch (error) { + console.debug( + `Failed to fetch authorization server metadata from ${authServerMetadataUrl}: ${getErrorMessage(error)}`, + ); + return null; + } + } + + /** + * Convert authorization server metadata to OAuth configuration. + * + * @param metadata The authorization server metadata + * @returns The OAuth configuration + */ + static metadataToOAuthConfig( + metadata: OAuthAuthorizationServerMetadata, + ): MCPOAuthConfig { + return { + authorizationUrl: metadata.authorization_endpoint, + tokenUrl: metadata.token_endpoint, + scopes: metadata.scopes_supported || [], + }; + } + + /** + * Discover OAuth configuration using the standard well-known endpoints. + * + * @param serverUrl The base URL of the server + * @returns The discovered OAuth configuration or null if not available + */ + static async discoverOAuthConfig( + serverUrl: string, + ): Promise { + try { + const wellKnownUrls = this.buildWellKnownUrls(serverUrl); + + // First, try to get the protected resource metadata + const resourceMetadata = await this.fetchProtectedResourceMetadata( + wellKnownUrls.protectedResource, + ); + + if (resourceMetadata?.authorization_servers?.length) { + // Use the first authorization server + const authServerUrl = resourceMetadata.authorization_servers[0]; + const authServerMetadataUrl = new URL( + '/.well-known/oauth-authorization-server', + authServerUrl, + ).toString(); + + const authServerMetadata = await this.fetchAuthorizationServerMetadata( + authServerMetadataUrl, + ); + + if (authServerMetadata) { + const config = this.metadataToOAuthConfig(authServerMetadata); + if (authServerMetadata.registration_endpoint) { + console.log( + 'Dynamic client registration is supported at:', + authServerMetadata.registration_endpoint, + ); + } + return config; + } + } + + // Fallback: try /.well-known/oauth-authorization-server at the base URL + console.debug( + `Trying OAuth discovery fallback at ${wellKnownUrls.authorizationServer}`, + ); + const authServerMetadata = await this.fetchAuthorizationServerMetadata( + wellKnownUrls.authorizationServer, + ); + + if (authServerMetadata) { + const config = this.metadataToOAuthConfig(authServerMetadata); + if (authServerMetadata.registration_endpoint) { + console.log( + 'Dynamic client registration is supported at:', + authServerMetadata.registration_endpoint, + ); + } + return config; + } + + return null; + } catch (error) { + console.debug( + `Failed to discover OAuth configuration: ${getErrorMessage(error)}`, + ); + return null; + } + } + + /** + * Parse WWW-Authenticate header to extract OAuth information. + * + * @param header The WWW-Authenticate header value + * @returns The resource metadata URI if found + */ + static parseWWWAuthenticateHeader(header: string): string | null { + // Parse Bearer realm and resource_metadata_uri + const match = header.match(/resource_metadata_uri="([^"]+)"/); + if (match) { + return match[1]; + } + return null; + } + + /** + * Discover OAuth configuration from WWW-Authenticate header. + * + * @param wwwAuthenticate The WWW-Authenticate header value + * @returns The discovered OAuth configuration or null if not available + */ + static async discoverOAuthFromWWWAuthenticate( + wwwAuthenticate: string, + ): Promise { + const resourceMetadataUri = + this.parseWWWAuthenticateHeader(wwwAuthenticate); + if (!resourceMetadataUri) { + return null; + } + + console.log( + `Found resource metadata URI from www-authenticate header: ${resourceMetadataUri}`, + ); + + const resourceMetadata = + await this.fetchProtectedResourceMetadata(resourceMetadataUri); + if (!resourceMetadata?.authorization_servers?.length) { + return null; + } + + const authServerUrl = resourceMetadata.authorization_servers[0]; + const authServerMetadataUrl = new URL( + '/.well-known/oauth-authorization-server', + authServerUrl, + ).toString(); + + const authServerMetadata = await this.fetchAuthorizationServerMetadata( + authServerMetadataUrl, + ); + + if (authServerMetadata) { + console.log( + 'OAuth configuration discovered successfully from www-authenticate header', + ); + return this.metadataToOAuthConfig(authServerMetadata); + } + + return null; + } + + /** + * Extract base URL from an MCP server URL. + * + * @param mcpServerUrl The MCP server URL + * @returns The base URL + */ + static extractBaseUrl(mcpServerUrl: string): string { + const serverUrl = new URL(mcpServerUrl); + return `${serverUrl.protocol}//${serverUrl.host}`; + } + + /** + * Check if a URL is an SSE endpoint. + * + * @param url The URL to check + * @returns True if the URL appears to be an SSE endpoint + */ + static isSSEEndpoint(url: string): boolean { + return url.includes('/sse') || !url.includes('/mcp'); + } + + /** + * Build a resource parameter for OAuth requests. + * + * @param endpointUrl The endpoint URL + * @returns The resource parameter value + */ + static buildResourceParameter(endpointUrl: string): string { + const url = new URL(endpointUrl); + return `${url.protocol}//${url.host}`; + } +}