MCP OAuth Part 1 - OAuth Infrastructure (#4316)

This commit is contained in:
Brian Ray 2025-07-18 10:14:23 -04:00 committed by GitHub
parent de27ea6095
commit c5761317f4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 2453 additions and 0 deletions

View File

@ -60,6 +60,20 @@ export * from './tools/read-many-files.js';
export * from './tools/mcp-client.js';
export * from './tools/mcp-tool.js';
// MCP OAuth
export { MCPOAuthProvider } from './mcp/oauth-provider.js';
export {
MCPOAuthToken,
MCPOAuthCredentials,
MCPOAuthTokenStorage,
} from './mcp/oauth-token-storage.js';
export type { MCPOAuthConfig } from './mcp/oauth-provider.js';
export type {
OAuthAuthorizationServerMetadata,
OAuthProtectedResourceMetadata,
} from './mcp/oauth-utils.js';
export { OAuthUtils } from './mcp/oauth-utils.js';
// Export telemetry functions
export * from './telemetry/index.js';
export { sessionId } from './utils/session.js';

View File

@ -0,0 +1,720 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
import * as http from 'node:http';
import * as crypto from 'node:crypto';
import open from 'open';
import {
MCPOAuthProvider,
MCPOAuthConfig,
OAuthTokenResponse,
OAuthClientRegistrationResponse,
} from './oauth-provider.js';
import { MCPOAuthTokenStorage, MCPOAuthToken } from './oauth-token-storage.js';
// Mock dependencies
vi.mock('open');
vi.mock('node:crypto');
vi.mock('./oauth-token-storage.js');
// Mock fetch globally
const mockFetch = vi.fn();
global.fetch = mockFetch;
// Define a reusable mock server with .listen, .close, and .on methods
const mockHttpServer = {
listen: vi.fn(),
close: vi.fn(),
on: vi.fn(),
};
vi.mock('node:http', () => ({
createServer: vi.fn(() => mockHttpServer),
}));
describe('MCPOAuthProvider', () => {
const mockConfig: MCPOAuthConfig = {
enabled: true,
clientId: 'test-client-id',
clientSecret: 'test-client-secret',
authorizationUrl: 'https://auth.example.com/authorize',
tokenUrl: 'https://auth.example.com/token',
scopes: ['read', 'write'],
redirectUri: 'http://localhost:7777/oauth/callback',
};
const mockToken: MCPOAuthToken = {
accessToken: 'access_token_123',
refreshToken: 'refresh_token_456',
tokenType: 'Bearer',
scope: 'read write',
expiresAt: Date.now() + 3600000,
};
const mockTokenResponse: OAuthTokenResponse = {
access_token: 'access_token_123',
token_type: 'Bearer',
expires_in: 3600,
refresh_token: 'refresh_token_456',
scope: 'read write',
};
beforeEach(() => {
vi.clearAllMocks();
vi.spyOn(console, 'log').mockImplementation(() => {});
vi.spyOn(console, 'warn').mockImplementation(() => {});
vi.spyOn(console, 'error').mockImplementation(() => {});
// Mock crypto functions
vi.mocked(crypto.randomBytes).mockImplementation((size: number) => {
if (size === 32) return Buffer.from('code_verifier_mock_32_bytes_long');
if (size === 16) return Buffer.from('state_mock_16_by');
return Buffer.alloc(size);
});
vi.mocked(crypto.createHash).mockReturnValue({
update: vi.fn().mockReturnThis(),
digest: vi.fn().mockReturnValue('code_challenge_mock'),
} as unknown as crypto.Hash);
// Mock randomBytes to return predictable values for state
vi.mocked(crypto.randomBytes).mockImplementation((size) => {
if (size === 32) {
return Buffer.from('mock_code_verifier_32_bytes_long_string');
} else if (size === 16) {
return Buffer.from('mock_state_16_bytes');
}
return Buffer.alloc(size);
});
// Mock token storage
vi.mocked(MCPOAuthTokenStorage.saveToken).mockResolvedValue(undefined);
vi.mocked(MCPOAuthTokenStorage.getToken).mockResolvedValue(null);
});
afterEach(() => {
vi.restoreAllMocks();
});
describe('authenticate', () => {
it('should perform complete OAuth flow with PKCE', async () => {
// Mock HTTP server callback
let callbackHandler: unknown;
vi.mocked(http.createServer).mockImplementation((handler) => {
callbackHandler = handler;
return mockHttpServer as unknown as http.Server;
});
mockHttpServer.listen.mockImplementation((port, callback) => {
callback?.();
// Simulate OAuth callback
setTimeout(() => {
const mockReq = {
url: '/oauth/callback?code=auth_code_123&state=bW9ja19zdGF0ZV8xNl9ieXRlcw',
};
const mockRes = {
writeHead: vi.fn(),
end: vi.fn(),
};
(callbackHandler as (req: unknown, res: unknown) => void)(
mockReq,
mockRes,
);
}, 10);
});
// Mock token exchange
mockFetch.mockResolvedValueOnce({
ok: true,
json: () => Promise.resolve(mockTokenResponse),
});
const result = await MCPOAuthProvider.authenticate(
'test-server',
mockConfig,
);
expect(result).toEqual({
accessToken: 'access_token_123',
refreshToken: 'refresh_token_456',
tokenType: 'Bearer',
scope: 'read write',
expiresAt: expect.any(Number),
});
expect(open).toHaveBeenCalledWith(expect.stringContaining('authorize'));
expect(MCPOAuthTokenStorage.saveToken).toHaveBeenCalledWith(
'test-server',
expect.objectContaining({ accessToken: 'access_token_123' }),
'test-client-id',
'https://auth.example.com/token',
);
});
it('should handle OAuth discovery when no authorization URL provided', async () => {
// Use a mutable config object
const configWithoutAuth: MCPOAuthConfig = { ...mockConfig };
delete configWithoutAuth.authorizationUrl;
delete configWithoutAuth.tokenUrl;
const mockResourceMetadata = {
authorization_servers: ['https://discovered.auth.com'],
};
const mockAuthServerMetadata = {
authorization_endpoint: 'https://discovered.auth.com/authorize',
token_endpoint: 'https://discovered.auth.com/token',
scopes_supported: ['read', 'write'],
};
mockFetch
.mockResolvedValueOnce({
ok: true,
json: () => Promise.resolve(mockResourceMetadata),
})
.mockResolvedValueOnce({
ok: true,
json: () => Promise.resolve(mockAuthServerMetadata),
});
// Patch config after discovery
configWithoutAuth.authorizationUrl =
mockAuthServerMetadata.authorization_endpoint;
configWithoutAuth.tokenUrl = mockAuthServerMetadata.token_endpoint;
configWithoutAuth.scopes = mockAuthServerMetadata.scopes_supported;
// Setup callback handler
let callbackHandler: unknown;
vi.mocked(http.createServer).mockImplementation((handler) => {
callbackHandler = handler;
return mockHttpServer as unknown as http.Server;
});
mockHttpServer.listen.mockImplementation((port, callback) => {
callback?.();
setTimeout(() => {
const mockReq = {
url: '/oauth/callback?code=auth_code_123&state=bW9ja19zdGF0ZV8xNl9ieXRlcw',
};
const mockRes = {
writeHead: vi.fn(),
end: vi.fn(),
};
(callbackHandler as (req: unknown, res: unknown) => void)(
mockReq,
mockRes,
);
}, 10);
});
// Mock token exchange with discovered endpoint
mockFetch.mockResolvedValueOnce({
ok: true,
json: () => Promise.resolve(mockTokenResponse),
});
const result = await MCPOAuthProvider.authenticate(
'test-server',
configWithoutAuth,
'https://api.example.com',
);
expect(result).toBeDefined();
expect(mockFetch).toHaveBeenCalledWith(
'https://discovered.auth.com/token',
expect.objectContaining({
method: 'POST',
headers: { 'Content-Type': 'application/x-www-form-urlencoded' },
}),
);
});
it('should perform dynamic client registration when no client ID provided', async () => {
const configWithoutClient = { ...mockConfig };
delete configWithoutClient.clientId;
const mockRegistrationResponse: OAuthClientRegistrationResponse = {
client_id: 'dynamic_client_id',
client_secret: 'dynamic_client_secret',
redirect_uris: ['http://localhost:7777/oauth/callback'],
grant_types: ['authorization_code', 'refresh_token'],
response_types: ['code'],
token_endpoint_auth_method: 'none',
};
const mockAuthServerMetadata = {
authorization_endpoint: 'https://auth.example.com/authorize',
token_endpoint: 'https://auth.example.com/token',
registration_endpoint: 'https://auth.example.com/register',
};
mockFetch
.mockResolvedValueOnce({
ok: true,
json: () => Promise.resolve(mockAuthServerMetadata),
})
.mockResolvedValueOnce({
ok: true,
json: () => Promise.resolve(mockRegistrationResponse),
});
// Setup callback handler
let callbackHandler: unknown;
vi.mocked(http.createServer).mockImplementation((handler) => {
callbackHandler = handler;
return mockHttpServer as unknown as http.Server;
});
mockHttpServer.listen.mockImplementation((port, callback) => {
callback?.();
setTimeout(() => {
const mockReq = {
url: '/oauth/callback?code=auth_code_123&state=bW9ja19zdGF0ZV8xNl9ieXRlcw',
};
const mockRes = {
writeHead: vi.fn(),
end: vi.fn(),
};
(callbackHandler as (req: unknown, res: unknown) => void)(
mockReq,
mockRes,
);
}, 10);
});
// Mock token exchange
mockFetch.mockResolvedValueOnce({
ok: true,
json: () => Promise.resolve(mockTokenResponse),
});
const result = await MCPOAuthProvider.authenticate(
'test-server',
configWithoutClient,
);
expect(result).toBeDefined();
expect(mockFetch).toHaveBeenCalledWith(
'https://auth.example.com/register',
expect.objectContaining({
method: 'POST',
headers: { 'Content-Type': 'application/json' },
}),
);
});
it('should handle OAuth callback errors', async () => {
let callbackHandler: unknown;
vi.mocked(http.createServer).mockImplementation((handler) => {
callbackHandler = handler;
return mockHttpServer as unknown as http.Server;
});
mockHttpServer.listen.mockImplementation((port, callback) => {
callback?.();
setTimeout(() => {
const mockReq = {
url: '/oauth/callback?error=access_denied&error_description=User%20denied%20access',
};
const mockRes = {
writeHead: vi.fn(),
end: vi.fn(),
};
(callbackHandler as (req: unknown, res: unknown) => void)(
mockReq,
mockRes,
);
}, 10);
});
await expect(
MCPOAuthProvider.authenticate('test-server', mockConfig),
).rejects.toThrow('OAuth error: access_denied');
});
it('should handle state mismatch in callback', async () => {
let callbackHandler: unknown;
vi.mocked(http.createServer).mockImplementation((handler) => {
callbackHandler = handler;
return mockHttpServer as unknown as http.Server;
});
mockHttpServer.listen.mockImplementation((port, callback) => {
callback?.();
setTimeout(() => {
const mockReq = {
url: '/oauth/callback?code=auth_code_123&state=wrong_state',
};
const mockRes = {
writeHead: vi.fn(),
end: vi.fn(),
};
(callbackHandler as (req: unknown, res: unknown) => void)(
mockReq,
mockRes,
);
}, 10);
});
await expect(
MCPOAuthProvider.authenticate('test-server', mockConfig),
).rejects.toThrow('State mismatch - possible CSRF attack');
});
it('should handle token exchange failure', async () => {
let callbackHandler: unknown;
vi.mocked(http.createServer).mockImplementation((handler) => {
callbackHandler = handler;
return mockHttpServer as unknown as http.Server;
});
mockHttpServer.listen.mockImplementation((port, callback) => {
callback?.();
setTimeout(() => {
const mockReq = {
url: '/oauth/callback?code=auth_code_123&state=bW9ja19zdGF0ZV8xNl9ieXRlcw',
};
const mockRes = {
writeHead: vi.fn(),
end: vi.fn(),
};
(callbackHandler as (req: unknown, res: unknown) => void)(
mockReq,
mockRes,
);
}, 10);
});
mockFetch.mockResolvedValueOnce({
ok: false,
status: 400,
text: () => Promise.resolve('Invalid grant'),
});
await expect(
MCPOAuthProvider.authenticate('test-server', mockConfig),
).rejects.toThrow('Token exchange failed: 400 - Invalid grant');
});
it('should handle callback timeout', async () => {
vi.mocked(http.createServer).mockImplementation(
() => mockHttpServer as unknown as http.Server,
);
mockHttpServer.listen.mockImplementation((port, callback) => {
callback?.();
// Don't trigger callback - simulate timeout
});
// Mock setTimeout to trigger timeout immediately
const originalSetTimeout = global.setTimeout;
global.setTimeout = vi.fn((callback, delay) => {
if (delay === 5 * 60 * 1000) {
// 5 minute timeout
callback();
}
return originalSetTimeout(callback, 0);
}) as unknown as typeof setTimeout;
await expect(
MCPOAuthProvider.authenticate('test-server', mockConfig),
).rejects.toThrow('OAuth callback timeout');
global.setTimeout = originalSetTimeout;
});
});
describe('refreshAccessToken', () => {
it('should refresh token successfully', async () => {
const refreshResponse = {
access_token: 'new_access_token',
token_type: 'Bearer',
expires_in: 3600,
refresh_token: 'new_refresh_token',
};
mockFetch.mockResolvedValueOnce({
ok: true,
json: () => Promise.resolve(refreshResponse),
});
const result = await MCPOAuthProvider.refreshAccessToken(
mockConfig,
'old_refresh_token',
'https://auth.example.com/token',
);
expect(result).toEqual(refreshResponse);
expect(mockFetch).toHaveBeenCalledWith(
'https://auth.example.com/token',
expect.objectContaining({
method: 'POST',
headers: { 'Content-Type': 'application/x-www-form-urlencoded' },
body: expect.stringContaining('grant_type=refresh_token'),
}),
);
});
it('should include client secret in refresh request when available', async () => {
mockFetch.mockResolvedValueOnce({
ok: true,
json: () => Promise.resolve(mockTokenResponse),
});
await MCPOAuthProvider.refreshAccessToken(
mockConfig,
'refresh_token',
'https://auth.example.com/token',
);
const fetchCall = mockFetch.mock.calls[0];
expect(fetchCall[1].body).toContain('client_secret=test-client-secret');
});
it('should handle refresh token failure', async () => {
mockFetch.mockResolvedValueOnce({
ok: false,
status: 400,
text: () => Promise.resolve('Invalid refresh token'),
});
await expect(
MCPOAuthProvider.refreshAccessToken(
mockConfig,
'invalid_refresh_token',
'https://auth.example.com/token',
),
).rejects.toThrow('Token refresh failed: 400 - Invalid refresh token');
});
});
describe('getValidToken', () => {
it('should return valid token when not expired', async () => {
const validCredentials = {
serverName: 'test-server',
token: mockToken,
clientId: 'test-client-id',
tokenUrl: 'https://auth.example.com/token',
updatedAt: Date.now(),
};
vi.mocked(MCPOAuthTokenStorage.getToken).mockResolvedValue(
validCredentials,
);
vi.mocked(MCPOAuthTokenStorage.isTokenExpired).mockReturnValue(false);
const result = await MCPOAuthProvider.getValidToken(
'test-server',
mockConfig,
);
expect(result).toBe('access_token_123');
});
it('should refresh expired token and return new token', async () => {
const expiredCredentials = {
serverName: 'test-server',
token: { ...mockToken, expiresAt: Date.now() - 3600000 },
clientId: 'test-client-id',
tokenUrl: 'https://auth.example.com/token',
updatedAt: Date.now(),
};
vi.mocked(MCPOAuthTokenStorage.getToken).mockResolvedValue(
expiredCredentials,
);
vi.mocked(MCPOAuthTokenStorage.isTokenExpired).mockReturnValue(true);
const refreshResponse = {
access_token: 'new_access_token',
token_type: 'Bearer',
expires_in: 3600,
refresh_token: 'new_refresh_token',
};
mockFetch.mockResolvedValueOnce({
ok: true,
json: () => Promise.resolve(refreshResponse),
});
const result = await MCPOAuthProvider.getValidToken(
'test-server',
mockConfig,
);
expect(result).toBe('new_access_token');
expect(MCPOAuthTokenStorage.saveToken).toHaveBeenCalledWith(
'test-server',
expect.objectContaining({ accessToken: 'new_access_token' }),
'test-client-id',
'https://auth.example.com/token',
);
});
it('should return null when no credentials exist', async () => {
vi.mocked(MCPOAuthTokenStorage.getToken).mockResolvedValue(null);
const result = await MCPOAuthProvider.getValidToken(
'test-server',
mockConfig,
);
expect(result).toBeNull();
});
it('should handle refresh failure and remove invalid token', async () => {
const expiredCredentials = {
serverName: 'test-server',
token: { ...mockToken, expiresAt: Date.now() - 3600000 },
clientId: 'test-client-id',
tokenUrl: 'https://auth.example.com/token',
updatedAt: Date.now(),
};
vi.mocked(MCPOAuthTokenStorage.getToken).mockResolvedValue(
expiredCredentials,
);
vi.mocked(MCPOAuthTokenStorage.isTokenExpired).mockReturnValue(true);
vi.mocked(MCPOAuthTokenStorage.removeToken).mockResolvedValue(undefined);
mockFetch.mockResolvedValueOnce({
ok: false,
status: 400,
text: () => Promise.resolve('Invalid refresh token'),
});
const result = await MCPOAuthProvider.getValidToken(
'test-server',
mockConfig,
);
expect(result).toBeNull();
expect(MCPOAuthTokenStorage.removeToken).toHaveBeenCalledWith(
'test-server',
);
expect(console.error).toHaveBeenCalledWith(
expect.stringContaining('Failed to refresh token'),
);
});
it('should return null for token without refresh capability', async () => {
const tokenWithoutRefresh = {
serverName: 'test-server',
token: {
...mockToken,
refreshToken: undefined,
expiresAt: Date.now() - 3600000,
},
clientId: 'test-client-id',
tokenUrl: 'https://auth.example.com/token',
updatedAt: Date.now(),
};
vi.mocked(MCPOAuthTokenStorage.getToken).mockResolvedValue(
tokenWithoutRefresh,
);
vi.mocked(MCPOAuthTokenStorage.isTokenExpired).mockReturnValue(true);
const result = await MCPOAuthProvider.getValidToken(
'test-server',
mockConfig,
);
expect(result).toBeNull();
});
});
describe('PKCE parameter generation', () => {
it('should generate valid PKCE parameters', async () => {
// Test is implicit in the authenticate flow tests, but we can verify
// the crypto mocks are called correctly
let callbackHandler: unknown;
vi.mocked(http.createServer).mockImplementation((handler) => {
callbackHandler = handler;
return mockHttpServer as unknown as http.Server;
});
mockHttpServer.listen.mockImplementation((port, callback) => {
callback?.();
setTimeout(() => {
const mockReq = {
url: '/oauth/callback?code=auth_code_123&state=bW9ja19zdGF0ZV8xNl9ieXRlcw',
};
const mockRes = {
writeHead: vi.fn(),
end: vi.fn(),
};
(callbackHandler as (req: unknown, res: unknown) => void)(
mockReq,
mockRes,
);
}, 10);
});
mockFetch.mockResolvedValueOnce({
ok: true,
json: () => Promise.resolve(mockTokenResponse),
});
await MCPOAuthProvider.authenticate('test-server', mockConfig);
expect(crypto.randomBytes).toHaveBeenCalledWith(32); // code verifier
expect(crypto.randomBytes).toHaveBeenCalledWith(16); // state
expect(crypto.createHash).toHaveBeenCalledWith('sha256');
});
});
describe('Authorization URL building', () => {
it('should build correct authorization URL with all parameters', async () => {
// Mock to capture the URL that would be opened
let capturedUrl: string;
vi.mocked(open).mockImplementation((url) => {
capturedUrl = url;
// Return a minimal mock ChildProcess
return Promise.resolve({
pid: 1234,
} as unknown as import('child_process').ChildProcess);
});
let callbackHandler: unknown;
vi.mocked(http.createServer).mockImplementation((handler) => {
callbackHandler = handler;
return mockHttpServer as unknown as http.Server;
});
mockHttpServer.listen.mockImplementation((port, callback) => {
callback?.();
setTimeout(() => {
const mockReq = {
url: '/oauth/callback?code=auth_code_123&state=bW9ja19zdGF0ZV8xNl9ieXRlcw',
};
const mockRes = {
writeHead: vi.fn(),
end: vi.fn(),
};
(callbackHandler as (req: unknown, res: unknown) => void)(
mockReq,
mockRes,
);
}, 10);
});
mockFetch.mockResolvedValueOnce({
ok: true,
json: () => Promise.resolve(mockTokenResponse),
});
await MCPOAuthProvider.authenticate('test-server', mockConfig);
expect(capturedUrl!).toContain('response_type=code');
expect(capturedUrl!).toContain('client_id=test-client-id');
expect(capturedUrl!).toContain('code_challenge=code_challenge_mock');
expect(capturedUrl!).toContain('code_challenge_method=S256');
expect(capturedUrl!).toContain('scope=read+write');
expect(capturedUrl!).toContain('resource=https%3A%2F%2Fauth.example.com');
});
});
});

View File

@ -0,0 +1,698 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import * as http from 'node:http';
import * as crypto from 'node:crypto';
import { URL } from 'node:url';
import open from 'open';
import { MCPOAuthToken, MCPOAuthTokenStorage } from './oauth-token-storage.js';
import { getErrorMessage } from '../utils/errors.js';
import { OAuthUtils } from './oauth-utils.js';
/**
* OAuth configuration for an MCP server.
*/
export interface MCPOAuthConfig {
enabled?: boolean; // Whether OAuth is enabled for this server
clientId?: string;
clientSecret?: string;
authorizationUrl?: string;
tokenUrl?: string;
scopes?: string[];
redirectUri?: string;
tokenParamName?: string; // For SSE connections, specifies the query parameter name for the token
}
/**
* OAuth authorization response.
*/
export interface OAuthAuthorizationResponse {
code: string;
state: string;
}
/**
* OAuth token response from the authorization server.
*/
export interface OAuthTokenResponse {
access_token: string;
token_type: string;
expires_in?: number;
refresh_token?: string;
scope?: string;
}
/**
* Dynamic client registration request.
*/
export interface OAuthClientRegistrationRequest {
client_name: string;
redirect_uris: string[];
grant_types: string[];
response_types: string[];
token_endpoint_auth_method: string;
code_challenge_method?: string[];
scope?: string;
}
/**
* Dynamic client registration response.
*/
export interface OAuthClientRegistrationResponse {
client_id: string;
client_secret?: string;
client_id_issued_at?: number;
client_secret_expires_at?: number;
redirect_uris: string[];
grant_types: string[];
response_types: string[];
token_endpoint_auth_method: string;
code_challenge_method?: string[];
scope?: string;
}
/**
* PKCE (Proof Key for Code Exchange) parameters.
*/
interface PKCEParams {
codeVerifier: string;
codeChallenge: string;
state: string;
}
/**
* Provider for handling OAuth authentication for MCP servers.
*/
export class MCPOAuthProvider {
private static readonly REDIRECT_PORT = 7777;
private static readonly REDIRECT_PATH = '/oauth/callback';
private static readonly HTTP_OK = 200;
private static readonly HTTP_REDIRECT = 302;
/**
* Register a client dynamically with the OAuth server.
*
* @param registrationUrl The client registration endpoint URL
* @param config OAuth configuration
* @returns The registered client information
*/
private static async registerClient(
registrationUrl: string,
config: MCPOAuthConfig,
): Promise<OAuthClientRegistrationResponse> {
const redirectUri =
config.redirectUri ||
`http://localhost:${this.REDIRECT_PORT}${this.REDIRECT_PATH}`;
const registrationRequest: OAuthClientRegistrationRequest = {
client_name: 'Gemini CLI MCP Client',
redirect_uris: [redirectUri],
grant_types: ['authorization_code', 'refresh_token'],
response_types: ['code'],
token_endpoint_auth_method: 'none', // Public client
code_challenge_method: ['S256'],
scope: config.scopes?.join(' ') || '',
};
const response = await fetch(registrationUrl, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify(registrationRequest),
});
if (!response.ok) {
const errorText = await response.text();
throw new Error(
`Client registration failed: ${response.status} ${response.statusText} - ${errorText}`,
);
}
return (await response.json()) as OAuthClientRegistrationResponse;
}
/**
* Discover OAuth configuration from an MCP server URL.
*
* @param mcpServerUrl The MCP server URL
* @returns OAuth configuration if discovered, null otherwise
*/
private static async discoverOAuthFromMCPServer(
mcpServerUrl: string,
): Promise<MCPOAuthConfig | null> {
const baseUrl = OAuthUtils.extractBaseUrl(mcpServerUrl);
return OAuthUtils.discoverOAuthConfig(baseUrl);
}
/**
* Generate PKCE parameters for OAuth flow.
*
* @returns PKCE parameters including code verifier, challenge, and state
*/
private static generatePKCEParams(): PKCEParams {
// Generate code verifier (43-128 characters)
const codeVerifier = crypto.randomBytes(32).toString('base64url');
// Generate code challenge using SHA256
const codeChallenge = crypto
.createHash('sha256')
.update(codeVerifier)
.digest('base64url');
// Generate state for CSRF protection
const state = crypto.randomBytes(16).toString('base64url');
return { codeVerifier, codeChallenge, state };
}
/**
* Start a local HTTP server to handle OAuth callback.
*
* @param expectedState The state parameter to validate
* @returns Promise that resolves with the authorization code
*/
private static async startCallbackServer(
expectedState: string,
): Promise<OAuthAuthorizationResponse> {
return new Promise((resolve, reject) => {
const server = http.createServer(
async (req: http.IncomingMessage, res: http.ServerResponse) => {
try {
const url = new URL(
req.url!,
`http://localhost:${this.REDIRECT_PORT}`,
);
if (url.pathname !== this.REDIRECT_PATH) {
res.writeHead(404);
res.end('Not found');
return;
}
const code = url.searchParams.get('code');
const state = url.searchParams.get('state');
const error = url.searchParams.get('error');
if (error) {
res.writeHead(this.HTTP_OK, { 'Content-Type': 'text/html' });
res.end(`
<html>
<body>
<h1>Authentication Failed</h1>
<p>Error: ${(error as string).replace(/</g, '&lt;').replace(/>/g, '&gt;')}</p>
<p>${((url.searchParams.get('error_description') || '') as string).replace(/</g, '&lt;').replace(/>/g, '&gt;')}</p>
<p>You can close this window.</p>
</body>
</html>
`);
server.close();
reject(new Error(`OAuth error: ${error}`));
return;
}
if (!code || !state) {
res.writeHead(400);
res.end('Missing code or state parameter');
return;
}
if (state !== expectedState) {
res.writeHead(400);
res.end('Invalid state parameter');
server.close();
reject(new Error('State mismatch - possible CSRF attack'));
return;
}
// Send success response to browser
res.writeHead(this.HTTP_OK, { 'Content-Type': 'text/html' });
res.end(`
<html>
<body>
<h1>Authentication Successful!</h1>
<p>You can close this window and return to Gemini CLI.</p>
<script>window.close();</script>
</body>
</html>
`);
server.close();
resolve({ code, state });
} catch (error) {
server.close();
reject(error);
}
},
);
server.on('error', reject);
server.listen(this.REDIRECT_PORT, () => {
console.log(
`OAuth callback server listening on port ${this.REDIRECT_PORT}`,
);
});
// Timeout after 5 minutes
setTimeout(
() => {
server.close();
reject(new Error('OAuth callback timeout'));
},
5 * 60 * 1000,
);
});
}
/**
* Build the authorization URL with PKCE parameters.
*
* @param config OAuth configuration
* @param pkceParams PKCE parameters
* @returns The authorization URL
*/
private static buildAuthorizationUrl(
config: MCPOAuthConfig,
pkceParams: PKCEParams,
): string {
const redirectUri =
config.redirectUri ||
`http://localhost:${this.REDIRECT_PORT}${this.REDIRECT_PATH}`;
const params = new URLSearchParams({
client_id: config.clientId!,
response_type: 'code',
redirect_uri: redirectUri,
state: pkceParams.state,
code_challenge: pkceParams.codeChallenge,
code_challenge_method: 'S256',
});
if (config.scopes && config.scopes.length > 0) {
params.append('scope', config.scopes.join(' '));
}
// Add resource parameter for MCP OAuth spec compliance
params.append(
'resource',
OAuthUtils.buildResourceParameter(config.authorizationUrl!),
);
return `${config.authorizationUrl}?${params.toString()}`;
}
/**
* Exchange authorization code for tokens.
*
* @param config OAuth configuration
* @param code Authorization code
* @param codeVerifier PKCE code verifier
* @returns The token response
*/
private static async exchangeCodeForToken(
config: MCPOAuthConfig,
code: string,
codeVerifier: string,
): Promise<OAuthTokenResponse> {
const redirectUri =
config.redirectUri ||
`http://localhost:${this.REDIRECT_PORT}${this.REDIRECT_PATH}`;
const params = new URLSearchParams({
grant_type: 'authorization_code',
code,
redirect_uri: redirectUri,
code_verifier: codeVerifier,
client_id: config.clientId!,
});
if (config.clientSecret) {
params.append('client_secret', config.clientSecret);
}
// Add resource parameter for MCP OAuth spec compliance
params.append(
'resource',
OAuthUtils.buildResourceParameter(config.tokenUrl!),
);
const response = await fetch(config.tokenUrl!, {
method: 'POST',
headers: {
'Content-Type': 'application/x-www-form-urlencoded',
},
body: params.toString(),
});
if (!response.ok) {
const errorText = await response.text();
throw new Error(
`Token exchange failed: ${response.status} - ${errorText}`,
);
}
return (await response.json()) as OAuthTokenResponse;
}
/**
* Refresh an access token using a refresh token.
*
* @param config OAuth configuration
* @param refreshToken The refresh token
* @returns The new token response
*/
static async refreshAccessToken(
config: MCPOAuthConfig,
refreshToken: string,
tokenUrl: string,
): Promise<OAuthTokenResponse> {
const params = new URLSearchParams({
grant_type: 'refresh_token',
refresh_token: refreshToken,
client_id: config.clientId!,
});
if (config.clientSecret) {
params.append('client_secret', config.clientSecret);
}
if (config.scopes && config.scopes.length > 0) {
params.append('scope', config.scopes.join(' '));
}
// Add resource parameter for MCP OAuth spec compliance
params.append('resource', OAuthUtils.buildResourceParameter(tokenUrl));
const response = await fetch(tokenUrl, {
method: 'POST',
headers: {
'Content-Type': 'application/x-www-form-urlencoded',
},
body: params.toString(),
});
if (!response.ok) {
const errorText = await response.text();
throw new Error(
`Token refresh failed: ${response.status} - ${errorText}`,
);
}
return (await response.json()) as OAuthTokenResponse;
}
/**
* Perform the full OAuth authorization code flow with PKCE.
*
* @param serverName The name of the MCP server
* @param config OAuth configuration
* @param mcpServerUrl Optional MCP server URL for OAuth discovery
* @returns The obtained OAuth token
*/
static async authenticate(
serverName: string,
config: MCPOAuthConfig,
mcpServerUrl?: string,
): Promise<MCPOAuthToken> {
// If no authorization URL is provided, try to discover OAuth configuration
if (!config.authorizationUrl && mcpServerUrl) {
console.log(
'No authorization URL provided, attempting OAuth discovery...',
);
// For SSE URLs, first check if authentication is required
if (OAuthUtils.isSSEEndpoint(mcpServerUrl)) {
try {
const response = await fetch(mcpServerUrl, {
method: 'HEAD',
headers: {
Accept: 'text/event-stream',
},
});
if (response.status === 401 || response.status === 307) {
const wwwAuthenticate = response.headers.get('www-authenticate');
if (wwwAuthenticate) {
const discoveredConfig =
await OAuthUtils.discoverOAuthFromWWWAuthenticate(
wwwAuthenticate,
);
if (discoveredConfig) {
config = {
...config,
...discoveredConfig,
scopes: discoveredConfig.scopes || config.scopes || [],
};
}
}
}
} catch (error) {
console.debug(
`Failed to check SSE endpoint for authentication requirements: ${getErrorMessage(error)}`,
);
}
}
// If we still don't have OAuth config, try the standard discovery
if (!config.authorizationUrl) {
const discoveredConfig =
await this.discoverOAuthFromMCPServer(mcpServerUrl);
if (discoveredConfig) {
config = { ...config, ...discoveredConfig };
console.log('OAuth configuration discovered successfully');
} else {
throw new Error(
'Failed to discover OAuth configuration from MCP server',
);
}
}
}
// If no client ID is provided, try dynamic client registration
if (!config.clientId) {
// Extract server URL from authorization URL
if (!config.authorizationUrl) {
throw new Error(
'Cannot perform dynamic registration without authorization URL',
);
}
const authUrl = new URL(config.authorizationUrl);
const serverUrl = `${authUrl.protocol}//${authUrl.host}`;
console.log(
'No client ID provided, attempting dynamic client registration...',
);
// Get the authorization server metadata for registration
const authServerMetadataUrl = new URL(
'/.well-known/oauth-authorization-server',
serverUrl,
).toString();
const authServerMetadata =
await OAuthUtils.fetchAuthorizationServerMetadata(
authServerMetadataUrl,
);
if (!authServerMetadata) {
throw new Error(
'Failed to fetch authorization server metadata for client registration',
);
}
// Register client if registration endpoint is available
if (authServerMetadata.registration_endpoint) {
const clientRegistration = await this.registerClient(
authServerMetadata.registration_endpoint,
config,
);
config.clientId = clientRegistration.client_id;
if (clientRegistration.client_secret) {
config.clientSecret = clientRegistration.client_secret;
}
console.log('Dynamic client registration successful');
} else {
throw new Error(
'No client ID provided and dynamic registration not supported',
);
}
}
// Validate configuration
if (!config.clientId || !config.authorizationUrl || !config.tokenUrl) {
throw new Error(
'Missing required OAuth configuration after discovery and registration',
);
}
// Generate PKCE parameters
const pkceParams = this.generatePKCEParams();
// Build authorization URL
const authUrl = this.buildAuthorizationUrl(config, pkceParams);
console.log('\nOpening browser for OAuth authentication...');
console.log('If the browser does not open, please visit:');
console.log('');
// Get terminal width or default to 80
const terminalWidth = process.stdout.columns || 80;
const separatorLength = Math.min(terminalWidth - 2, 80);
const separator = '━'.repeat(separatorLength);
console.log(separator);
console.log(
'COPY THE ENTIRE URL BELOW (select all text between the lines):',
);
console.log(separator);
console.log(authUrl);
console.log(separator);
console.log('');
console.log(
'💡 TIP: Triple-click to select the entire URL, then copy and paste it into your browser.',
);
console.log(
'⚠️ Make sure to copy the COMPLETE URL - it may wrap across multiple lines.',
);
console.log('');
// Start callback server
const callbackPromise = this.startCallbackServer(pkceParams.state);
// Open browser
try {
await open(authUrl);
} catch (error) {
console.warn(
'Failed to open browser automatically:',
getErrorMessage(error),
);
}
// Wait for callback
const { code } = await callbackPromise;
console.log('\nAuthorization code received, exchanging for tokens...');
// Exchange code for tokens
const tokenResponse = await this.exchangeCodeForToken(
config,
code,
pkceParams.codeVerifier,
);
// Convert to our token format
const token: MCPOAuthToken = {
accessToken: tokenResponse.access_token,
tokenType: tokenResponse.token_type,
refreshToken: tokenResponse.refresh_token,
scope: tokenResponse.scope,
};
if (tokenResponse.expires_in) {
token.expiresAt = Date.now() + tokenResponse.expires_in * 1000;
}
// Save token
try {
await MCPOAuthTokenStorage.saveToken(
serverName,
token,
config.clientId,
config.tokenUrl,
);
console.log('Authentication successful! Token saved.');
// Verify token was saved
const savedToken = await MCPOAuthTokenStorage.getToken(serverName);
if (savedToken) {
console.log(
`Token verification successful: ${savedToken.token.accessToken.substring(0, 20)}...`,
);
} else {
console.error('Token verification failed: token not found after save');
}
} catch (saveError) {
console.error(`Failed to save token: ${getErrorMessage(saveError)}`);
throw saveError;
}
return token;
}
/**
* Get a valid access token for an MCP server, refreshing if necessary.
*
* @param serverName The name of the MCP server
* @param config OAuth configuration
* @returns A valid access token or null if not authenticated
*/
static async getValidToken(
serverName: string,
config: MCPOAuthConfig,
): Promise<string | null> {
console.debug(`Getting valid token for server: ${serverName}`);
const credentials = await MCPOAuthTokenStorage.getToken(serverName);
if (!credentials) {
console.debug(`No credentials found for server: ${serverName}`);
return null;
}
const { token } = credentials;
console.debug(
`Found token for server: ${serverName}, expired: ${MCPOAuthTokenStorage.isTokenExpired(token)}`,
);
// Check if token is expired
if (!MCPOAuthTokenStorage.isTokenExpired(token)) {
console.debug(`Returning valid token for server: ${serverName}`);
return token.accessToken;
}
// Try to refresh if we have a refresh token
if (token.refreshToken && config.clientId && credentials.tokenUrl) {
try {
console.log(`Refreshing expired token for MCP server: ${serverName}`);
const newTokenResponse = await this.refreshAccessToken(
config,
token.refreshToken,
credentials.tokenUrl,
);
// Update stored token
const newToken: MCPOAuthToken = {
accessToken: newTokenResponse.access_token,
tokenType: newTokenResponse.token_type,
refreshToken: newTokenResponse.refresh_token || token.refreshToken,
scope: newTokenResponse.scope || token.scope,
};
if (newTokenResponse.expires_in) {
newToken.expiresAt = Date.now() + newTokenResponse.expires_in * 1000;
}
await MCPOAuthTokenStorage.saveToken(
serverName,
newToken,
config.clientId,
credentials.tokenUrl,
);
return newToken.accessToken;
} catch (error) {
console.error(`Failed to refresh token: ${getErrorMessage(error)}`);
// Remove invalid token
await MCPOAuthTokenStorage.removeToken(serverName);
}
}
return null;
}
}

View File

@ -0,0 +1,325 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
import { promises as fs } from 'node:fs';
import * as path from 'node:path';
import {
MCPOAuthTokenStorage,
MCPOAuthToken,
MCPOAuthCredentials,
} from './oauth-token-storage.js';
// Mock file system operations
vi.mock('node:fs', () => ({
promises: {
readFile: vi.fn(),
writeFile: vi.fn(),
mkdir: vi.fn(),
unlink: vi.fn(),
},
}));
vi.mock('node:os', () => ({
homedir: vi.fn(() => '/mock/home'),
}));
describe('MCPOAuthTokenStorage', () => {
const mockToken: MCPOAuthToken = {
accessToken: 'access_token_123',
refreshToken: 'refresh_token_456',
tokenType: 'Bearer',
scope: 'read write',
expiresAt: Date.now() + 3600000, // 1 hour from now
};
const mockCredentials: MCPOAuthCredentials = {
serverName: 'test-server',
token: mockToken,
clientId: 'test-client-id',
tokenUrl: 'https://auth.example.com/token',
updatedAt: Date.now(),
};
beforeEach(() => {
vi.clearAllMocks();
vi.spyOn(console, 'error').mockImplementation(() => {});
});
afterEach(() => {
vi.restoreAllMocks();
});
describe('loadTokens', () => {
it('should return empty map when token file does not exist', async () => {
vi.mocked(fs.readFile).mockRejectedValue({ code: 'ENOENT' });
const tokens = await MCPOAuthTokenStorage.loadTokens();
expect(tokens.size).toBe(0);
expect(console.error).not.toHaveBeenCalled();
});
it('should load tokens from file successfully', async () => {
const tokensArray = [mockCredentials];
vi.mocked(fs.readFile).mockResolvedValue(JSON.stringify(tokensArray));
const tokens = await MCPOAuthTokenStorage.loadTokens();
expect(tokens.size).toBe(1);
expect(tokens.get('test-server')).toEqual(mockCredentials);
expect(fs.readFile).toHaveBeenCalledWith(
path.join('/mock/home', '.gemini', 'mcp-oauth-tokens.json'),
'utf-8',
);
});
it('should handle corrupted token file gracefully', async () => {
vi.mocked(fs.readFile).mockResolvedValue('invalid json');
const tokens = await MCPOAuthTokenStorage.loadTokens();
expect(tokens.size).toBe(0);
expect(console.error).toHaveBeenCalledWith(
expect.stringContaining('Failed to load MCP OAuth tokens'),
);
});
it('should handle file read errors other than ENOENT', async () => {
const error = new Error('Permission denied');
vi.mocked(fs.readFile).mockRejectedValue(error);
const tokens = await MCPOAuthTokenStorage.loadTokens();
expect(tokens.size).toBe(0);
expect(console.error).toHaveBeenCalledWith(
expect.stringContaining('Failed to load MCP OAuth tokens'),
);
});
});
describe('saveToken', () => {
it('should save token with restricted permissions', async () => {
vi.mocked(fs.readFile).mockRejectedValue({ code: 'ENOENT' });
vi.mocked(fs.mkdir).mockResolvedValue(undefined);
vi.mocked(fs.writeFile).mockResolvedValue(undefined);
await MCPOAuthTokenStorage.saveToken(
'test-server',
mockToken,
'client-id',
'https://token.url',
);
expect(fs.mkdir).toHaveBeenCalledWith(
path.join('/mock/home', '.gemini'),
{ recursive: true },
);
expect(fs.writeFile).toHaveBeenCalledWith(
path.join('/mock/home', '.gemini', 'mcp-oauth-tokens.json'),
expect.stringContaining('test-server'),
{ mode: 0o600 },
);
});
it('should update existing token for same server', async () => {
const existingCredentials = {
...mockCredentials,
serverName: 'existing-server',
};
vi.mocked(fs.readFile).mockResolvedValue(
JSON.stringify([existingCredentials]),
);
vi.mocked(fs.writeFile).mockResolvedValue(undefined);
const newToken = { ...mockToken, accessToken: 'new_access_token' };
await MCPOAuthTokenStorage.saveToken('existing-server', newToken);
const writeCall = vi.mocked(fs.writeFile).mock.calls[0];
const savedData = JSON.parse(writeCall[1] as string);
expect(savedData).toHaveLength(1);
expect(savedData[0].token.accessToken).toBe('new_access_token');
expect(savedData[0].serverName).toBe('existing-server');
});
it('should handle write errors gracefully', async () => {
vi.mocked(fs.readFile).mockRejectedValue({ code: 'ENOENT' });
vi.mocked(fs.mkdir).mockResolvedValue(undefined);
const writeError = new Error('Disk full');
vi.mocked(fs.writeFile).mockRejectedValue(writeError);
await expect(
MCPOAuthTokenStorage.saveToken('test-server', mockToken),
).rejects.toThrow('Disk full');
expect(console.error).toHaveBeenCalledWith(
expect.stringContaining('Failed to save MCP OAuth token'),
);
});
});
describe('getToken', () => {
it('should return token for existing server', async () => {
vi.mocked(fs.readFile).mockResolvedValue(
JSON.stringify([mockCredentials]),
);
const result = await MCPOAuthTokenStorage.getToken('test-server');
expect(result).toEqual(mockCredentials);
});
it('should return null for non-existent server', async () => {
vi.mocked(fs.readFile).mockResolvedValue(
JSON.stringify([mockCredentials]),
);
const result = await MCPOAuthTokenStorage.getToken('non-existent');
expect(result).toBeNull();
});
it('should return null when no tokens file exists', async () => {
vi.mocked(fs.readFile).mockRejectedValue({ code: 'ENOENT' });
const result = await MCPOAuthTokenStorage.getToken('test-server');
expect(result).toBeNull();
});
});
describe('removeToken', () => {
it('should remove token for specific server', async () => {
const credentials1 = { ...mockCredentials, serverName: 'server1' };
const credentials2 = { ...mockCredentials, serverName: 'server2' };
vi.mocked(fs.readFile).mockResolvedValue(
JSON.stringify([credentials1, credentials2]),
);
vi.mocked(fs.writeFile).mockResolvedValue(undefined);
await MCPOAuthTokenStorage.removeToken('server1');
const writeCall = vi.mocked(fs.writeFile).mock.calls[0];
const savedData = JSON.parse(writeCall[1] as string);
expect(savedData).toHaveLength(1);
expect(savedData[0].serverName).toBe('server2');
});
it('should remove token file when no tokens remain', async () => {
vi.mocked(fs.readFile).mockResolvedValue(
JSON.stringify([mockCredentials]),
);
vi.mocked(fs.unlink).mockResolvedValue(undefined);
await MCPOAuthTokenStorage.removeToken('test-server');
expect(fs.unlink).toHaveBeenCalledWith(
path.join('/mock/home', '.gemini', 'mcp-oauth-tokens.json'),
);
expect(fs.writeFile).not.toHaveBeenCalled();
});
it('should handle removal of non-existent token gracefully', async () => {
vi.mocked(fs.readFile).mockResolvedValue(
JSON.stringify([mockCredentials]),
);
await MCPOAuthTokenStorage.removeToken('non-existent');
expect(fs.writeFile).not.toHaveBeenCalled();
expect(fs.unlink).not.toHaveBeenCalled();
});
it('should handle file operation errors gracefully', async () => {
vi.mocked(fs.readFile).mockResolvedValue(
JSON.stringify([mockCredentials]),
);
vi.mocked(fs.unlink).mockRejectedValue(new Error('Permission denied'));
await MCPOAuthTokenStorage.removeToken('test-server');
expect(console.error).toHaveBeenCalledWith(
expect.stringContaining('Failed to remove MCP OAuth token'),
);
});
});
describe('isTokenExpired', () => {
it('should return false for token without expiry', () => {
const tokenWithoutExpiry = { ...mockToken };
delete tokenWithoutExpiry.expiresAt;
const result = MCPOAuthTokenStorage.isTokenExpired(tokenWithoutExpiry);
expect(result).toBe(false);
});
it('should return false for valid token', () => {
const futureToken = {
...mockToken,
expiresAt: Date.now() + 3600000, // 1 hour from now
};
const result = MCPOAuthTokenStorage.isTokenExpired(futureToken);
expect(result).toBe(false);
});
it('should return true for expired token', () => {
const expiredToken = {
...mockToken,
expiresAt: Date.now() - 3600000, // 1 hour ago
};
const result = MCPOAuthTokenStorage.isTokenExpired(expiredToken);
expect(result).toBe(true);
});
it('should return true for token expiring within buffer time', () => {
const soonToExpireToken = {
...mockToken,
expiresAt: Date.now() + 60000, // 1 minute from now (within 5-minute buffer)
};
const result = MCPOAuthTokenStorage.isTokenExpired(soonToExpireToken);
expect(result).toBe(true);
});
});
describe('clearAllTokens', () => {
it('should remove token file successfully', async () => {
vi.mocked(fs.unlink).mockResolvedValue(undefined);
await MCPOAuthTokenStorage.clearAllTokens();
expect(fs.unlink).toHaveBeenCalledWith(
path.join('/mock/home', '.gemini', 'mcp-oauth-tokens.json'),
);
});
it('should handle non-existent file gracefully', async () => {
vi.mocked(fs.unlink).mockRejectedValue({ code: 'ENOENT' });
await MCPOAuthTokenStorage.clearAllTokens();
expect(console.error).not.toHaveBeenCalled();
});
it('should handle other file errors gracefully', async () => {
vi.mocked(fs.unlink).mockRejectedValue(new Error('Permission denied'));
await MCPOAuthTokenStorage.clearAllTokens();
expect(console.error).toHaveBeenCalledWith(
expect.stringContaining('Failed to clear MCP OAuth tokens'),
);
});
});
});

View File

@ -0,0 +1,205 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { promises as fs } from 'node:fs';
import * as path from 'node:path';
import * as os from 'node:os';
import { getErrorMessage } from '../utils/errors.js';
/**
* Interface for MCP OAuth tokens.
*/
export interface MCPOAuthToken {
accessToken: string;
refreshToken?: string;
expiresAt?: number;
tokenType: string;
scope?: string;
}
/**
* Interface for stored MCP OAuth credentials.
*/
export interface MCPOAuthCredentials {
serverName: string;
token: MCPOAuthToken;
clientId?: string;
tokenUrl?: string;
updatedAt: number;
}
/**
* Class for managing MCP OAuth token storage and retrieval.
*/
export class MCPOAuthTokenStorage {
private static readonly TOKEN_FILE = 'mcp-oauth-tokens.json';
private static readonly CONFIG_DIR = '.gemini';
/**
* Get the path to the token storage file.
*
* @returns The full path to the token storage file
*/
private static getTokenFilePath(): string {
const homeDir = os.homedir();
return path.join(homeDir, this.CONFIG_DIR, this.TOKEN_FILE);
}
/**
* Ensure the config directory exists.
*/
private static async ensureConfigDir(): Promise<void> {
const configDir = path.dirname(this.getTokenFilePath());
await fs.mkdir(configDir, { recursive: true });
}
/**
* Load all stored MCP OAuth tokens.
*
* @returns A map of server names to credentials
*/
static async loadTokens(): Promise<Map<string, MCPOAuthCredentials>> {
const tokenMap = new Map<string, MCPOAuthCredentials>();
try {
const tokenFile = this.getTokenFilePath();
const data = await fs.readFile(tokenFile, 'utf-8');
const tokens = JSON.parse(data) as MCPOAuthCredentials[];
for (const credential of tokens) {
tokenMap.set(credential.serverName, credential);
}
} catch (error) {
// File doesn't exist or is invalid, return empty map
if ((error as NodeJS.ErrnoException).code !== 'ENOENT') {
console.error(
`Failed to load MCP OAuth tokens: ${getErrorMessage(error)}`,
);
}
}
return tokenMap;
}
/**
* Save a token for a specific MCP server.
*
* @param serverName The name of the MCP server
* @param token The OAuth token to save
* @param clientId Optional client ID used for this token
* @param tokenUrl Optional token URL used for this token
*/
static async saveToken(
serverName: string,
token: MCPOAuthToken,
clientId?: string,
tokenUrl?: string,
): Promise<void> {
await this.ensureConfigDir();
const tokens = await this.loadTokens();
const credential: MCPOAuthCredentials = {
serverName,
token,
clientId,
tokenUrl,
updatedAt: Date.now(),
};
tokens.set(serverName, credential);
const tokenArray = Array.from(tokens.values());
const tokenFile = this.getTokenFilePath();
try {
await fs.writeFile(
tokenFile,
JSON.stringify(tokenArray, null, 2),
{ mode: 0o600 }, // Restrict file permissions
);
} catch (error) {
console.error(
`Failed to save MCP OAuth token: ${getErrorMessage(error)}`,
);
throw error;
}
}
/**
* Get a token for a specific MCP server.
*
* @param serverName The name of the MCP server
* @returns The stored credentials or null if not found
*/
static async getToken(
serverName: string,
): Promise<MCPOAuthCredentials | null> {
const tokens = await this.loadTokens();
return tokens.get(serverName) || null;
}
/**
* Remove a token for a specific MCP server.
*
* @param serverName The name of the MCP server
*/
static async removeToken(serverName: string): Promise<void> {
const tokens = await this.loadTokens();
if (tokens.delete(serverName)) {
const tokenArray = Array.from(tokens.values());
const tokenFile = this.getTokenFilePath();
try {
if (tokenArray.length === 0) {
// Remove file if no tokens left
await fs.unlink(tokenFile);
} else {
await fs.writeFile(tokenFile, JSON.stringify(tokenArray, null, 2), {
mode: 0o600,
});
}
} catch (error) {
console.error(
`Failed to remove MCP OAuth token: ${getErrorMessage(error)}`,
);
}
}
}
/**
* Check if a token is expired.
*
* @param token The token to check
* @returns True if the token is expired
*/
static isTokenExpired(token: MCPOAuthToken): boolean {
if (!token.expiresAt) {
return false; // No expiry, assume valid
}
// Add a 5-minute buffer to account for clock skew
const bufferMs = 5 * 60 * 1000;
return Date.now() + bufferMs >= token.expiresAt;
}
/**
* Clear all stored MCP OAuth tokens.
*/
static async clearAllTokens(): Promise<void> {
try {
const tokenFile = this.getTokenFilePath();
await fs.unlink(tokenFile);
} catch (error) {
if ((error as NodeJS.ErrnoException).code !== 'ENOENT') {
console.error(
`Failed to clear MCP OAuth tokens: ${getErrorMessage(error)}`,
);
}
}
}
}

View File

@ -0,0 +1,206 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
import {
OAuthUtils,
OAuthAuthorizationServerMetadata,
OAuthProtectedResourceMetadata,
} from './oauth-utils.js';
// Mock fetch globally
const mockFetch = vi.fn();
global.fetch = mockFetch;
describe('OAuthUtils', () => {
beforeEach(() => {
vi.clearAllMocks();
vi.spyOn(console, 'debug').mockImplementation(() => {});
vi.spyOn(console, 'error').mockImplementation(() => {});
vi.spyOn(console, 'log').mockImplementation(() => {});
});
afterEach(() => {
vi.restoreAllMocks();
});
describe('buildWellKnownUrls', () => {
it('should build correct well-known URLs', () => {
const urls = OAuthUtils.buildWellKnownUrls('https://example.com/path');
expect(urls.protectedResource).toBe(
'https://example.com/.well-known/oauth-protected-resource',
);
expect(urls.authorizationServer).toBe(
'https://example.com/.well-known/oauth-authorization-server',
);
});
});
describe('fetchProtectedResourceMetadata', () => {
const mockResourceMetadata: OAuthProtectedResourceMetadata = {
resource: 'https://api.example.com',
authorization_servers: ['https://auth.example.com'],
bearer_methods_supported: ['header'],
};
it('should fetch protected resource metadata successfully', async () => {
mockFetch.mockResolvedValueOnce({
ok: true,
json: () => Promise.resolve(mockResourceMetadata),
});
const result = await OAuthUtils.fetchProtectedResourceMetadata(
'https://example.com/.well-known/oauth-protected-resource',
);
expect(result).toEqual(mockResourceMetadata);
});
it('should return null when fetch fails', async () => {
mockFetch.mockResolvedValueOnce({
ok: false,
});
const result = await OAuthUtils.fetchProtectedResourceMetadata(
'https://example.com/.well-known/oauth-protected-resource',
);
expect(result).toBeNull();
});
});
describe('fetchAuthorizationServerMetadata', () => {
const mockAuthServerMetadata: OAuthAuthorizationServerMetadata = {
issuer: 'https://auth.example.com',
authorization_endpoint: 'https://auth.example.com/authorize',
token_endpoint: 'https://auth.example.com/token',
scopes_supported: ['read', 'write'],
};
it('should fetch authorization server metadata successfully', async () => {
mockFetch.mockResolvedValueOnce({
ok: true,
json: () => Promise.resolve(mockAuthServerMetadata),
});
const result = await OAuthUtils.fetchAuthorizationServerMetadata(
'https://auth.example.com/.well-known/oauth-authorization-server',
);
expect(result).toEqual(mockAuthServerMetadata);
});
it('should return null when fetch fails', async () => {
mockFetch.mockResolvedValueOnce({
ok: false,
});
const result = await OAuthUtils.fetchAuthorizationServerMetadata(
'https://auth.example.com/.well-known/oauth-authorization-server',
);
expect(result).toBeNull();
});
});
describe('metadataToOAuthConfig', () => {
it('should convert metadata to OAuth config', () => {
const metadata: OAuthAuthorizationServerMetadata = {
issuer: 'https://auth.example.com',
authorization_endpoint: 'https://auth.example.com/authorize',
token_endpoint: 'https://auth.example.com/token',
scopes_supported: ['read', 'write'],
};
const config = OAuthUtils.metadataToOAuthConfig(metadata);
expect(config).toEqual({
authorizationUrl: 'https://auth.example.com/authorize',
tokenUrl: 'https://auth.example.com/token',
scopes: ['read', 'write'],
});
});
it('should handle empty scopes', () => {
const metadata: OAuthAuthorizationServerMetadata = {
issuer: 'https://auth.example.com',
authorization_endpoint: 'https://auth.example.com/authorize',
token_endpoint: 'https://auth.example.com/token',
};
const config = OAuthUtils.metadataToOAuthConfig(metadata);
expect(config.scopes).toEqual([]);
});
});
describe('parseWWWAuthenticateHeader', () => {
it('should parse resource metadata URI from WWW-Authenticate header', () => {
const header =
'Bearer realm="example", resource_metadata_uri="https://example.com/.well-known/oauth-protected-resource"';
const result = OAuthUtils.parseWWWAuthenticateHeader(header);
expect(result).toBe(
'https://example.com/.well-known/oauth-protected-resource',
);
});
it('should return null when no resource metadata URI is found', () => {
const header = 'Bearer realm="example"';
const result = OAuthUtils.parseWWWAuthenticateHeader(header);
expect(result).toBeNull();
});
});
describe('extractBaseUrl', () => {
it('should extract base URL from MCP server URL', () => {
const result = OAuthUtils.extractBaseUrl('https://example.com/mcp/v1');
expect(result).toBe('https://example.com');
});
it('should handle URLs with ports', () => {
const result = OAuthUtils.extractBaseUrl(
'https://example.com:8080/mcp/v1',
);
expect(result).toBe('https://example.com:8080');
});
});
describe('isSSEEndpoint', () => {
it('should return true for SSE endpoints', () => {
expect(OAuthUtils.isSSEEndpoint('https://example.com/sse')).toBe(true);
expect(OAuthUtils.isSSEEndpoint('https://example.com/api/v1/sse')).toBe(
true,
);
});
it('should return true for non-MCP endpoints', () => {
expect(OAuthUtils.isSSEEndpoint('https://example.com/api')).toBe(true);
});
it('should return false for MCP endpoints', () => {
expect(OAuthUtils.isSSEEndpoint('https://example.com/mcp')).toBe(false);
expect(OAuthUtils.isSSEEndpoint('https://example.com/api/mcp/v1')).toBe(
false,
);
});
});
describe('buildResourceParameter', () => {
it('should build resource parameter from endpoint URL', () => {
const result = OAuthUtils.buildResourceParameter(
'https://example.com/oauth/token',
);
expect(result).toBe('https://example.com');
});
it('should handle URLs with ports', () => {
const result = OAuthUtils.buildResourceParameter(
'https://example.com:8080/oauth/token',
);
expect(result).toBe('https://example.com:8080');
});
});
});

View File

@ -0,0 +1,285 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { MCPOAuthConfig } from './oauth-provider.js';
import { getErrorMessage } from '../utils/errors.js';
/**
* OAuth authorization server metadata as per RFC 8414.
*/
export interface OAuthAuthorizationServerMetadata {
issuer: string;
authorization_endpoint: string;
token_endpoint: string;
token_endpoint_auth_methods_supported?: string[];
revocation_endpoint?: string;
revocation_endpoint_auth_methods_supported?: string[];
registration_endpoint?: string;
response_types_supported?: string[];
grant_types_supported?: string[];
code_challenge_methods_supported?: string[];
scopes_supported?: string[];
}
/**
* OAuth protected resource metadata as per RFC 9728.
*/
export interface OAuthProtectedResourceMetadata {
resource: string;
authorization_servers?: string[];
bearer_methods_supported?: string[];
resource_documentation?: string;
resource_signing_alg_values_supported?: string[];
resource_encryption_alg_values_supported?: string[];
resource_encryption_enc_values_supported?: string[];
}
/**
* Utility class for common OAuth operations.
*/
export class OAuthUtils {
/**
* Construct well-known OAuth endpoint URLs.
*/
static buildWellKnownUrls(baseUrl: string) {
const serverUrl = new URL(baseUrl);
const base = `${serverUrl.protocol}//${serverUrl.host}`;
return {
protectedResource: new URL(
'/.well-known/oauth-protected-resource',
base,
).toString(),
authorizationServer: new URL(
'/.well-known/oauth-authorization-server',
base,
).toString(),
};
}
/**
* Fetch OAuth protected resource metadata.
*
* @param resourceMetadataUrl The protected resource metadata URL
* @returns The protected resource metadata or null if not available
*/
static async fetchProtectedResourceMetadata(
resourceMetadataUrl: string,
): Promise<OAuthProtectedResourceMetadata | null> {
try {
const response = await fetch(resourceMetadataUrl);
if (!response.ok) {
return null;
}
return (await response.json()) as OAuthProtectedResourceMetadata;
} catch (error) {
console.debug(
`Failed to fetch protected resource metadata from ${resourceMetadataUrl}: ${getErrorMessage(error)}`,
);
return null;
}
}
/**
* Fetch OAuth authorization server metadata.
*
* @param authServerMetadataUrl The authorization server metadata URL
* @returns The authorization server metadata or null if not available
*/
static async fetchAuthorizationServerMetadata(
authServerMetadataUrl: string,
): Promise<OAuthAuthorizationServerMetadata | null> {
try {
const response = await fetch(authServerMetadataUrl);
if (!response.ok) {
return null;
}
return (await response.json()) as OAuthAuthorizationServerMetadata;
} catch (error) {
console.debug(
`Failed to fetch authorization server metadata from ${authServerMetadataUrl}: ${getErrorMessage(error)}`,
);
return null;
}
}
/**
* Convert authorization server metadata to OAuth configuration.
*
* @param metadata The authorization server metadata
* @returns The OAuth configuration
*/
static metadataToOAuthConfig(
metadata: OAuthAuthorizationServerMetadata,
): MCPOAuthConfig {
return {
authorizationUrl: metadata.authorization_endpoint,
tokenUrl: metadata.token_endpoint,
scopes: metadata.scopes_supported || [],
};
}
/**
* Discover OAuth configuration using the standard well-known endpoints.
*
* @param serverUrl The base URL of the server
* @returns The discovered OAuth configuration or null if not available
*/
static async discoverOAuthConfig(
serverUrl: string,
): Promise<MCPOAuthConfig | null> {
try {
const wellKnownUrls = this.buildWellKnownUrls(serverUrl);
// First, try to get the protected resource metadata
const resourceMetadata = await this.fetchProtectedResourceMetadata(
wellKnownUrls.protectedResource,
);
if (resourceMetadata?.authorization_servers?.length) {
// Use the first authorization server
const authServerUrl = resourceMetadata.authorization_servers[0];
const authServerMetadataUrl = new URL(
'/.well-known/oauth-authorization-server',
authServerUrl,
).toString();
const authServerMetadata = await this.fetchAuthorizationServerMetadata(
authServerMetadataUrl,
);
if (authServerMetadata) {
const config = this.metadataToOAuthConfig(authServerMetadata);
if (authServerMetadata.registration_endpoint) {
console.log(
'Dynamic client registration is supported at:',
authServerMetadata.registration_endpoint,
);
}
return config;
}
}
// Fallback: try /.well-known/oauth-authorization-server at the base URL
console.debug(
`Trying OAuth discovery fallback at ${wellKnownUrls.authorizationServer}`,
);
const authServerMetadata = await this.fetchAuthorizationServerMetadata(
wellKnownUrls.authorizationServer,
);
if (authServerMetadata) {
const config = this.metadataToOAuthConfig(authServerMetadata);
if (authServerMetadata.registration_endpoint) {
console.log(
'Dynamic client registration is supported at:',
authServerMetadata.registration_endpoint,
);
}
return config;
}
return null;
} catch (error) {
console.debug(
`Failed to discover OAuth configuration: ${getErrorMessage(error)}`,
);
return null;
}
}
/**
* Parse WWW-Authenticate header to extract OAuth information.
*
* @param header The WWW-Authenticate header value
* @returns The resource metadata URI if found
*/
static parseWWWAuthenticateHeader(header: string): string | null {
// Parse Bearer realm and resource_metadata_uri
const match = header.match(/resource_metadata_uri="([^"]+)"/);
if (match) {
return match[1];
}
return null;
}
/**
* Discover OAuth configuration from WWW-Authenticate header.
*
* @param wwwAuthenticate The WWW-Authenticate header value
* @returns The discovered OAuth configuration or null if not available
*/
static async discoverOAuthFromWWWAuthenticate(
wwwAuthenticate: string,
): Promise<MCPOAuthConfig | null> {
const resourceMetadataUri =
this.parseWWWAuthenticateHeader(wwwAuthenticate);
if (!resourceMetadataUri) {
return null;
}
console.log(
`Found resource metadata URI from www-authenticate header: ${resourceMetadataUri}`,
);
const resourceMetadata =
await this.fetchProtectedResourceMetadata(resourceMetadataUri);
if (!resourceMetadata?.authorization_servers?.length) {
return null;
}
const authServerUrl = resourceMetadata.authorization_servers[0];
const authServerMetadataUrl = new URL(
'/.well-known/oauth-authorization-server',
authServerUrl,
).toString();
const authServerMetadata = await this.fetchAuthorizationServerMetadata(
authServerMetadataUrl,
);
if (authServerMetadata) {
console.log(
'OAuth configuration discovered successfully from www-authenticate header',
);
return this.metadataToOAuthConfig(authServerMetadata);
}
return null;
}
/**
* Extract base URL from an MCP server URL.
*
* @param mcpServerUrl The MCP server URL
* @returns The base URL
*/
static extractBaseUrl(mcpServerUrl: string): string {
const serverUrl = new URL(mcpServerUrl);
return `${serverUrl.protocol}//${serverUrl.host}`;
}
/**
* Check if a URL is an SSE endpoint.
*
* @param url The URL to check
* @returns True if the URL appears to be an SSE endpoint
*/
static isSSEEndpoint(url: string): boolean {
return url.includes('/sse') || !url.includes('/mcp');
}
/**
* Build a resource parameter for OAuth requests.
*
* @param endpointUrl The endpoint URL
* @returns The resource parameter value
*/
static buildResourceParameter(endpointUrl: string): string {
const url = new URL(endpointUrl);
return `${url.protocol}//${url.host}`;
}
}