MCP OAuth Part 1 - OAuth Infrastructure (#4316)
This commit is contained in:
parent
de27ea6095
commit
c5761317f4
|
@ -60,6 +60,20 @@ export * from './tools/read-many-files.js';
|
|||
export * from './tools/mcp-client.js';
|
||||
export * from './tools/mcp-tool.js';
|
||||
|
||||
// MCP OAuth
|
||||
export { MCPOAuthProvider } from './mcp/oauth-provider.js';
|
||||
export {
|
||||
MCPOAuthToken,
|
||||
MCPOAuthCredentials,
|
||||
MCPOAuthTokenStorage,
|
||||
} from './mcp/oauth-token-storage.js';
|
||||
export type { MCPOAuthConfig } from './mcp/oauth-provider.js';
|
||||
export type {
|
||||
OAuthAuthorizationServerMetadata,
|
||||
OAuthProtectedResourceMetadata,
|
||||
} from './mcp/oauth-utils.js';
|
||||
export { OAuthUtils } from './mcp/oauth-utils.js';
|
||||
|
||||
// Export telemetry functions
|
||||
export * from './telemetry/index.js';
|
||||
export { sessionId } from './utils/session.js';
|
||||
|
|
|
@ -0,0 +1,720 @@
|
|||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
|
||||
import * as http from 'node:http';
|
||||
import * as crypto from 'node:crypto';
|
||||
import open from 'open';
|
||||
import {
|
||||
MCPOAuthProvider,
|
||||
MCPOAuthConfig,
|
||||
OAuthTokenResponse,
|
||||
OAuthClientRegistrationResponse,
|
||||
} from './oauth-provider.js';
|
||||
import { MCPOAuthTokenStorage, MCPOAuthToken } from './oauth-token-storage.js';
|
||||
|
||||
// Mock dependencies
|
||||
vi.mock('open');
|
||||
vi.mock('node:crypto');
|
||||
vi.mock('./oauth-token-storage.js');
|
||||
|
||||
// Mock fetch globally
|
||||
const mockFetch = vi.fn();
|
||||
global.fetch = mockFetch;
|
||||
|
||||
// Define a reusable mock server with .listen, .close, and .on methods
|
||||
const mockHttpServer = {
|
||||
listen: vi.fn(),
|
||||
close: vi.fn(),
|
||||
on: vi.fn(),
|
||||
};
|
||||
vi.mock('node:http', () => ({
|
||||
createServer: vi.fn(() => mockHttpServer),
|
||||
}));
|
||||
|
||||
describe('MCPOAuthProvider', () => {
|
||||
const mockConfig: MCPOAuthConfig = {
|
||||
enabled: true,
|
||||
clientId: 'test-client-id',
|
||||
clientSecret: 'test-client-secret',
|
||||
authorizationUrl: 'https://auth.example.com/authorize',
|
||||
tokenUrl: 'https://auth.example.com/token',
|
||||
scopes: ['read', 'write'],
|
||||
redirectUri: 'http://localhost:7777/oauth/callback',
|
||||
};
|
||||
|
||||
const mockToken: MCPOAuthToken = {
|
||||
accessToken: 'access_token_123',
|
||||
refreshToken: 'refresh_token_456',
|
||||
tokenType: 'Bearer',
|
||||
scope: 'read write',
|
||||
expiresAt: Date.now() + 3600000,
|
||||
};
|
||||
|
||||
const mockTokenResponse: OAuthTokenResponse = {
|
||||
access_token: 'access_token_123',
|
||||
token_type: 'Bearer',
|
||||
expires_in: 3600,
|
||||
refresh_token: 'refresh_token_456',
|
||||
scope: 'read write',
|
||||
};
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
vi.spyOn(console, 'log').mockImplementation(() => {});
|
||||
vi.spyOn(console, 'warn').mockImplementation(() => {});
|
||||
vi.spyOn(console, 'error').mockImplementation(() => {});
|
||||
|
||||
// Mock crypto functions
|
||||
vi.mocked(crypto.randomBytes).mockImplementation((size: number) => {
|
||||
if (size === 32) return Buffer.from('code_verifier_mock_32_bytes_long');
|
||||
if (size === 16) return Buffer.from('state_mock_16_by');
|
||||
return Buffer.alloc(size);
|
||||
});
|
||||
|
||||
vi.mocked(crypto.createHash).mockReturnValue({
|
||||
update: vi.fn().mockReturnThis(),
|
||||
digest: vi.fn().mockReturnValue('code_challenge_mock'),
|
||||
} as unknown as crypto.Hash);
|
||||
|
||||
// Mock randomBytes to return predictable values for state
|
||||
vi.mocked(crypto.randomBytes).mockImplementation((size) => {
|
||||
if (size === 32) {
|
||||
return Buffer.from('mock_code_verifier_32_bytes_long_string');
|
||||
} else if (size === 16) {
|
||||
return Buffer.from('mock_state_16_bytes');
|
||||
}
|
||||
return Buffer.alloc(size);
|
||||
});
|
||||
|
||||
// Mock token storage
|
||||
vi.mocked(MCPOAuthTokenStorage.saveToken).mockResolvedValue(undefined);
|
||||
vi.mocked(MCPOAuthTokenStorage.getToken).mockResolvedValue(null);
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.restoreAllMocks();
|
||||
});
|
||||
|
||||
describe('authenticate', () => {
|
||||
it('should perform complete OAuth flow with PKCE', async () => {
|
||||
// Mock HTTP server callback
|
||||
let callbackHandler: unknown;
|
||||
vi.mocked(http.createServer).mockImplementation((handler) => {
|
||||
callbackHandler = handler;
|
||||
return mockHttpServer as unknown as http.Server;
|
||||
});
|
||||
|
||||
mockHttpServer.listen.mockImplementation((port, callback) => {
|
||||
callback?.();
|
||||
// Simulate OAuth callback
|
||||
setTimeout(() => {
|
||||
const mockReq = {
|
||||
url: '/oauth/callback?code=auth_code_123&state=bW9ja19zdGF0ZV8xNl9ieXRlcw',
|
||||
};
|
||||
const mockRes = {
|
||||
writeHead: vi.fn(),
|
||||
end: vi.fn(),
|
||||
};
|
||||
(callbackHandler as (req: unknown, res: unknown) => void)(
|
||||
mockReq,
|
||||
mockRes,
|
||||
);
|
||||
}, 10);
|
||||
});
|
||||
|
||||
// Mock token exchange
|
||||
mockFetch.mockResolvedValueOnce({
|
||||
ok: true,
|
||||
json: () => Promise.resolve(mockTokenResponse),
|
||||
});
|
||||
|
||||
const result = await MCPOAuthProvider.authenticate(
|
||||
'test-server',
|
||||
mockConfig,
|
||||
);
|
||||
|
||||
expect(result).toEqual({
|
||||
accessToken: 'access_token_123',
|
||||
refreshToken: 'refresh_token_456',
|
||||
tokenType: 'Bearer',
|
||||
scope: 'read write',
|
||||
expiresAt: expect.any(Number),
|
||||
});
|
||||
|
||||
expect(open).toHaveBeenCalledWith(expect.stringContaining('authorize'));
|
||||
expect(MCPOAuthTokenStorage.saveToken).toHaveBeenCalledWith(
|
||||
'test-server',
|
||||
expect.objectContaining({ accessToken: 'access_token_123' }),
|
||||
'test-client-id',
|
||||
'https://auth.example.com/token',
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle OAuth discovery when no authorization URL provided', async () => {
|
||||
// Use a mutable config object
|
||||
const configWithoutAuth: MCPOAuthConfig = { ...mockConfig };
|
||||
delete configWithoutAuth.authorizationUrl;
|
||||
delete configWithoutAuth.tokenUrl;
|
||||
|
||||
const mockResourceMetadata = {
|
||||
authorization_servers: ['https://discovered.auth.com'],
|
||||
};
|
||||
|
||||
const mockAuthServerMetadata = {
|
||||
authorization_endpoint: 'https://discovered.auth.com/authorize',
|
||||
token_endpoint: 'https://discovered.auth.com/token',
|
||||
scopes_supported: ['read', 'write'],
|
||||
};
|
||||
|
||||
mockFetch
|
||||
.mockResolvedValueOnce({
|
||||
ok: true,
|
||||
json: () => Promise.resolve(mockResourceMetadata),
|
||||
})
|
||||
.mockResolvedValueOnce({
|
||||
ok: true,
|
||||
json: () => Promise.resolve(mockAuthServerMetadata),
|
||||
});
|
||||
|
||||
// Patch config after discovery
|
||||
configWithoutAuth.authorizationUrl =
|
||||
mockAuthServerMetadata.authorization_endpoint;
|
||||
configWithoutAuth.tokenUrl = mockAuthServerMetadata.token_endpoint;
|
||||
configWithoutAuth.scopes = mockAuthServerMetadata.scopes_supported;
|
||||
|
||||
// Setup callback handler
|
||||
let callbackHandler: unknown;
|
||||
vi.mocked(http.createServer).mockImplementation((handler) => {
|
||||
callbackHandler = handler;
|
||||
return mockHttpServer as unknown as http.Server;
|
||||
});
|
||||
|
||||
mockHttpServer.listen.mockImplementation((port, callback) => {
|
||||
callback?.();
|
||||
setTimeout(() => {
|
||||
const mockReq = {
|
||||
url: '/oauth/callback?code=auth_code_123&state=bW9ja19zdGF0ZV8xNl9ieXRlcw',
|
||||
};
|
||||
const mockRes = {
|
||||
writeHead: vi.fn(),
|
||||
end: vi.fn(),
|
||||
};
|
||||
(callbackHandler as (req: unknown, res: unknown) => void)(
|
||||
mockReq,
|
||||
mockRes,
|
||||
);
|
||||
}, 10);
|
||||
});
|
||||
|
||||
// Mock token exchange with discovered endpoint
|
||||
mockFetch.mockResolvedValueOnce({
|
||||
ok: true,
|
||||
json: () => Promise.resolve(mockTokenResponse),
|
||||
});
|
||||
|
||||
const result = await MCPOAuthProvider.authenticate(
|
||||
'test-server',
|
||||
configWithoutAuth,
|
||||
'https://api.example.com',
|
||||
);
|
||||
|
||||
expect(result).toBeDefined();
|
||||
expect(mockFetch).toHaveBeenCalledWith(
|
||||
'https://discovered.auth.com/token',
|
||||
expect.objectContaining({
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/x-www-form-urlencoded' },
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
it('should perform dynamic client registration when no client ID provided', async () => {
|
||||
const configWithoutClient = { ...mockConfig };
|
||||
delete configWithoutClient.clientId;
|
||||
|
||||
const mockRegistrationResponse: OAuthClientRegistrationResponse = {
|
||||
client_id: 'dynamic_client_id',
|
||||
client_secret: 'dynamic_client_secret',
|
||||
redirect_uris: ['http://localhost:7777/oauth/callback'],
|
||||
grant_types: ['authorization_code', 'refresh_token'],
|
||||
response_types: ['code'],
|
||||
token_endpoint_auth_method: 'none',
|
||||
};
|
||||
|
||||
const mockAuthServerMetadata = {
|
||||
authorization_endpoint: 'https://auth.example.com/authorize',
|
||||
token_endpoint: 'https://auth.example.com/token',
|
||||
registration_endpoint: 'https://auth.example.com/register',
|
||||
};
|
||||
|
||||
mockFetch
|
||||
.mockResolvedValueOnce({
|
||||
ok: true,
|
||||
json: () => Promise.resolve(mockAuthServerMetadata),
|
||||
})
|
||||
.mockResolvedValueOnce({
|
||||
ok: true,
|
||||
json: () => Promise.resolve(mockRegistrationResponse),
|
||||
});
|
||||
|
||||
// Setup callback handler
|
||||
let callbackHandler: unknown;
|
||||
vi.mocked(http.createServer).mockImplementation((handler) => {
|
||||
callbackHandler = handler;
|
||||
return mockHttpServer as unknown as http.Server;
|
||||
});
|
||||
|
||||
mockHttpServer.listen.mockImplementation((port, callback) => {
|
||||
callback?.();
|
||||
setTimeout(() => {
|
||||
const mockReq = {
|
||||
url: '/oauth/callback?code=auth_code_123&state=bW9ja19zdGF0ZV8xNl9ieXRlcw',
|
||||
};
|
||||
const mockRes = {
|
||||
writeHead: vi.fn(),
|
||||
end: vi.fn(),
|
||||
};
|
||||
(callbackHandler as (req: unknown, res: unknown) => void)(
|
||||
mockReq,
|
||||
mockRes,
|
||||
);
|
||||
}, 10);
|
||||
});
|
||||
|
||||
// Mock token exchange
|
||||
mockFetch.mockResolvedValueOnce({
|
||||
ok: true,
|
||||
json: () => Promise.resolve(mockTokenResponse),
|
||||
});
|
||||
|
||||
const result = await MCPOAuthProvider.authenticate(
|
||||
'test-server',
|
||||
configWithoutClient,
|
||||
);
|
||||
|
||||
expect(result).toBeDefined();
|
||||
expect(mockFetch).toHaveBeenCalledWith(
|
||||
'https://auth.example.com/register',
|
||||
expect.objectContaining({
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle OAuth callback errors', async () => {
|
||||
let callbackHandler: unknown;
|
||||
vi.mocked(http.createServer).mockImplementation((handler) => {
|
||||
callbackHandler = handler;
|
||||
return mockHttpServer as unknown as http.Server;
|
||||
});
|
||||
|
||||
mockHttpServer.listen.mockImplementation((port, callback) => {
|
||||
callback?.();
|
||||
setTimeout(() => {
|
||||
const mockReq = {
|
||||
url: '/oauth/callback?error=access_denied&error_description=User%20denied%20access',
|
||||
};
|
||||
const mockRes = {
|
||||
writeHead: vi.fn(),
|
||||
end: vi.fn(),
|
||||
};
|
||||
(callbackHandler as (req: unknown, res: unknown) => void)(
|
||||
mockReq,
|
||||
mockRes,
|
||||
);
|
||||
}, 10);
|
||||
});
|
||||
|
||||
await expect(
|
||||
MCPOAuthProvider.authenticate('test-server', mockConfig),
|
||||
).rejects.toThrow('OAuth error: access_denied');
|
||||
});
|
||||
|
||||
it('should handle state mismatch in callback', async () => {
|
||||
let callbackHandler: unknown;
|
||||
vi.mocked(http.createServer).mockImplementation((handler) => {
|
||||
callbackHandler = handler;
|
||||
return mockHttpServer as unknown as http.Server;
|
||||
});
|
||||
|
||||
mockHttpServer.listen.mockImplementation((port, callback) => {
|
||||
callback?.();
|
||||
setTimeout(() => {
|
||||
const mockReq = {
|
||||
url: '/oauth/callback?code=auth_code_123&state=wrong_state',
|
||||
};
|
||||
const mockRes = {
|
||||
writeHead: vi.fn(),
|
||||
end: vi.fn(),
|
||||
};
|
||||
(callbackHandler as (req: unknown, res: unknown) => void)(
|
||||
mockReq,
|
||||
mockRes,
|
||||
);
|
||||
}, 10);
|
||||
});
|
||||
|
||||
await expect(
|
||||
MCPOAuthProvider.authenticate('test-server', mockConfig),
|
||||
).rejects.toThrow('State mismatch - possible CSRF attack');
|
||||
});
|
||||
|
||||
it('should handle token exchange failure', async () => {
|
||||
let callbackHandler: unknown;
|
||||
vi.mocked(http.createServer).mockImplementation((handler) => {
|
||||
callbackHandler = handler;
|
||||
return mockHttpServer as unknown as http.Server;
|
||||
});
|
||||
|
||||
mockHttpServer.listen.mockImplementation((port, callback) => {
|
||||
callback?.();
|
||||
setTimeout(() => {
|
||||
const mockReq = {
|
||||
url: '/oauth/callback?code=auth_code_123&state=bW9ja19zdGF0ZV8xNl9ieXRlcw',
|
||||
};
|
||||
const mockRes = {
|
||||
writeHead: vi.fn(),
|
||||
end: vi.fn(),
|
||||
};
|
||||
(callbackHandler as (req: unknown, res: unknown) => void)(
|
||||
mockReq,
|
||||
mockRes,
|
||||
);
|
||||
}, 10);
|
||||
});
|
||||
|
||||
mockFetch.mockResolvedValueOnce({
|
||||
ok: false,
|
||||
status: 400,
|
||||
text: () => Promise.resolve('Invalid grant'),
|
||||
});
|
||||
|
||||
await expect(
|
||||
MCPOAuthProvider.authenticate('test-server', mockConfig),
|
||||
).rejects.toThrow('Token exchange failed: 400 - Invalid grant');
|
||||
});
|
||||
|
||||
it('should handle callback timeout', async () => {
|
||||
vi.mocked(http.createServer).mockImplementation(
|
||||
() => mockHttpServer as unknown as http.Server,
|
||||
);
|
||||
|
||||
mockHttpServer.listen.mockImplementation((port, callback) => {
|
||||
callback?.();
|
||||
// Don't trigger callback - simulate timeout
|
||||
});
|
||||
|
||||
// Mock setTimeout to trigger timeout immediately
|
||||
const originalSetTimeout = global.setTimeout;
|
||||
global.setTimeout = vi.fn((callback, delay) => {
|
||||
if (delay === 5 * 60 * 1000) {
|
||||
// 5 minute timeout
|
||||
callback();
|
||||
}
|
||||
return originalSetTimeout(callback, 0);
|
||||
}) as unknown as typeof setTimeout;
|
||||
|
||||
await expect(
|
||||
MCPOAuthProvider.authenticate('test-server', mockConfig),
|
||||
).rejects.toThrow('OAuth callback timeout');
|
||||
|
||||
global.setTimeout = originalSetTimeout;
|
||||
});
|
||||
});
|
||||
|
||||
describe('refreshAccessToken', () => {
|
||||
it('should refresh token successfully', async () => {
|
||||
const refreshResponse = {
|
||||
access_token: 'new_access_token',
|
||||
token_type: 'Bearer',
|
||||
expires_in: 3600,
|
||||
refresh_token: 'new_refresh_token',
|
||||
};
|
||||
|
||||
mockFetch.mockResolvedValueOnce({
|
||||
ok: true,
|
||||
json: () => Promise.resolve(refreshResponse),
|
||||
});
|
||||
|
||||
const result = await MCPOAuthProvider.refreshAccessToken(
|
||||
mockConfig,
|
||||
'old_refresh_token',
|
||||
'https://auth.example.com/token',
|
||||
);
|
||||
|
||||
expect(result).toEqual(refreshResponse);
|
||||
expect(mockFetch).toHaveBeenCalledWith(
|
||||
'https://auth.example.com/token',
|
||||
expect.objectContaining({
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/x-www-form-urlencoded' },
|
||||
body: expect.stringContaining('grant_type=refresh_token'),
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
it('should include client secret in refresh request when available', async () => {
|
||||
mockFetch.mockResolvedValueOnce({
|
||||
ok: true,
|
||||
json: () => Promise.resolve(mockTokenResponse),
|
||||
});
|
||||
|
||||
await MCPOAuthProvider.refreshAccessToken(
|
||||
mockConfig,
|
||||
'refresh_token',
|
||||
'https://auth.example.com/token',
|
||||
);
|
||||
|
||||
const fetchCall = mockFetch.mock.calls[0];
|
||||
expect(fetchCall[1].body).toContain('client_secret=test-client-secret');
|
||||
});
|
||||
|
||||
it('should handle refresh token failure', async () => {
|
||||
mockFetch.mockResolvedValueOnce({
|
||||
ok: false,
|
||||
status: 400,
|
||||
text: () => Promise.resolve('Invalid refresh token'),
|
||||
});
|
||||
|
||||
await expect(
|
||||
MCPOAuthProvider.refreshAccessToken(
|
||||
mockConfig,
|
||||
'invalid_refresh_token',
|
||||
'https://auth.example.com/token',
|
||||
),
|
||||
).rejects.toThrow('Token refresh failed: 400 - Invalid refresh token');
|
||||
});
|
||||
});
|
||||
|
||||
describe('getValidToken', () => {
|
||||
it('should return valid token when not expired', async () => {
|
||||
const validCredentials = {
|
||||
serverName: 'test-server',
|
||||
token: mockToken,
|
||||
clientId: 'test-client-id',
|
||||
tokenUrl: 'https://auth.example.com/token',
|
||||
updatedAt: Date.now(),
|
||||
};
|
||||
|
||||
vi.mocked(MCPOAuthTokenStorage.getToken).mockResolvedValue(
|
||||
validCredentials,
|
||||
);
|
||||
vi.mocked(MCPOAuthTokenStorage.isTokenExpired).mockReturnValue(false);
|
||||
|
||||
const result = await MCPOAuthProvider.getValidToken(
|
||||
'test-server',
|
||||
mockConfig,
|
||||
);
|
||||
|
||||
expect(result).toBe('access_token_123');
|
||||
});
|
||||
|
||||
it('should refresh expired token and return new token', async () => {
|
||||
const expiredCredentials = {
|
||||
serverName: 'test-server',
|
||||
token: { ...mockToken, expiresAt: Date.now() - 3600000 },
|
||||
clientId: 'test-client-id',
|
||||
tokenUrl: 'https://auth.example.com/token',
|
||||
updatedAt: Date.now(),
|
||||
};
|
||||
|
||||
vi.mocked(MCPOAuthTokenStorage.getToken).mockResolvedValue(
|
||||
expiredCredentials,
|
||||
);
|
||||
vi.mocked(MCPOAuthTokenStorage.isTokenExpired).mockReturnValue(true);
|
||||
|
||||
const refreshResponse = {
|
||||
access_token: 'new_access_token',
|
||||
token_type: 'Bearer',
|
||||
expires_in: 3600,
|
||||
refresh_token: 'new_refresh_token',
|
||||
};
|
||||
|
||||
mockFetch.mockResolvedValueOnce({
|
||||
ok: true,
|
||||
json: () => Promise.resolve(refreshResponse),
|
||||
});
|
||||
|
||||
const result = await MCPOAuthProvider.getValidToken(
|
||||
'test-server',
|
||||
mockConfig,
|
||||
);
|
||||
|
||||
expect(result).toBe('new_access_token');
|
||||
expect(MCPOAuthTokenStorage.saveToken).toHaveBeenCalledWith(
|
||||
'test-server',
|
||||
expect.objectContaining({ accessToken: 'new_access_token' }),
|
||||
'test-client-id',
|
||||
'https://auth.example.com/token',
|
||||
);
|
||||
});
|
||||
|
||||
it('should return null when no credentials exist', async () => {
|
||||
vi.mocked(MCPOAuthTokenStorage.getToken).mockResolvedValue(null);
|
||||
|
||||
const result = await MCPOAuthProvider.getValidToken(
|
||||
'test-server',
|
||||
mockConfig,
|
||||
);
|
||||
|
||||
expect(result).toBeNull();
|
||||
});
|
||||
|
||||
it('should handle refresh failure and remove invalid token', async () => {
|
||||
const expiredCredentials = {
|
||||
serverName: 'test-server',
|
||||
token: { ...mockToken, expiresAt: Date.now() - 3600000 },
|
||||
clientId: 'test-client-id',
|
||||
tokenUrl: 'https://auth.example.com/token',
|
||||
updatedAt: Date.now(),
|
||||
};
|
||||
|
||||
vi.mocked(MCPOAuthTokenStorage.getToken).mockResolvedValue(
|
||||
expiredCredentials,
|
||||
);
|
||||
vi.mocked(MCPOAuthTokenStorage.isTokenExpired).mockReturnValue(true);
|
||||
vi.mocked(MCPOAuthTokenStorage.removeToken).mockResolvedValue(undefined);
|
||||
|
||||
mockFetch.mockResolvedValueOnce({
|
||||
ok: false,
|
||||
status: 400,
|
||||
text: () => Promise.resolve('Invalid refresh token'),
|
||||
});
|
||||
|
||||
const result = await MCPOAuthProvider.getValidToken(
|
||||
'test-server',
|
||||
mockConfig,
|
||||
);
|
||||
|
||||
expect(result).toBeNull();
|
||||
expect(MCPOAuthTokenStorage.removeToken).toHaveBeenCalledWith(
|
||||
'test-server',
|
||||
);
|
||||
expect(console.error).toHaveBeenCalledWith(
|
||||
expect.stringContaining('Failed to refresh token'),
|
||||
);
|
||||
});
|
||||
|
||||
it('should return null for token without refresh capability', async () => {
|
||||
const tokenWithoutRefresh = {
|
||||
serverName: 'test-server',
|
||||
token: {
|
||||
...mockToken,
|
||||
refreshToken: undefined,
|
||||
expiresAt: Date.now() - 3600000,
|
||||
},
|
||||
clientId: 'test-client-id',
|
||||
tokenUrl: 'https://auth.example.com/token',
|
||||
updatedAt: Date.now(),
|
||||
};
|
||||
|
||||
vi.mocked(MCPOAuthTokenStorage.getToken).mockResolvedValue(
|
||||
tokenWithoutRefresh,
|
||||
);
|
||||
vi.mocked(MCPOAuthTokenStorage.isTokenExpired).mockReturnValue(true);
|
||||
|
||||
const result = await MCPOAuthProvider.getValidToken(
|
||||
'test-server',
|
||||
mockConfig,
|
||||
);
|
||||
|
||||
expect(result).toBeNull();
|
||||
});
|
||||
});
|
||||
|
||||
describe('PKCE parameter generation', () => {
|
||||
it('should generate valid PKCE parameters', async () => {
|
||||
// Test is implicit in the authenticate flow tests, but we can verify
|
||||
// the crypto mocks are called correctly
|
||||
let callbackHandler: unknown;
|
||||
vi.mocked(http.createServer).mockImplementation((handler) => {
|
||||
callbackHandler = handler;
|
||||
return mockHttpServer as unknown as http.Server;
|
||||
});
|
||||
|
||||
mockHttpServer.listen.mockImplementation((port, callback) => {
|
||||
callback?.();
|
||||
setTimeout(() => {
|
||||
const mockReq = {
|
||||
url: '/oauth/callback?code=auth_code_123&state=bW9ja19zdGF0ZV8xNl9ieXRlcw',
|
||||
};
|
||||
const mockRes = {
|
||||
writeHead: vi.fn(),
|
||||
end: vi.fn(),
|
||||
};
|
||||
(callbackHandler as (req: unknown, res: unknown) => void)(
|
||||
mockReq,
|
||||
mockRes,
|
||||
);
|
||||
}, 10);
|
||||
});
|
||||
|
||||
mockFetch.mockResolvedValueOnce({
|
||||
ok: true,
|
||||
json: () => Promise.resolve(mockTokenResponse),
|
||||
});
|
||||
|
||||
await MCPOAuthProvider.authenticate('test-server', mockConfig);
|
||||
|
||||
expect(crypto.randomBytes).toHaveBeenCalledWith(32); // code verifier
|
||||
expect(crypto.randomBytes).toHaveBeenCalledWith(16); // state
|
||||
expect(crypto.createHash).toHaveBeenCalledWith('sha256');
|
||||
});
|
||||
});
|
||||
|
||||
describe('Authorization URL building', () => {
|
||||
it('should build correct authorization URL with all parameters', async () => {
|
||||
// Mock to capture the URL that would be opened
|
||||
let capturedUrl: string;
|
||||
vi.mocked(open).mockImplementation((url) => {
|
||||
capturedUrl = url;
|
||||
// Return a minimal mock ChildProcess
|
||||
return Promise.resolve({
|
||||
pid: 1234,
|
||||
} as unknown as import('child_process').ChildProcess);
|
||||
});
|
||||
|
||||
let callbackHandler: unknown;
|
||||
vi.mocked(http.createServer).mockImplementation((handler) => {
|
||||
callbackHandler = handler;
|
||||
return mockHttpServer as unknown as http.Server;
|
||||
});
|
||||
|
||||
mockHttpServer.listen.mockImplementation((port, callback) => {
|
||||
callback?.();
|
||||
setTimeout(() => {
|
||||
const mockReq = {
|
||||
url: '/oauth/callback?code=auth_code_123&state=bW9ja19zdGF0ZV8xNl9ieXRlcw',
|
||||
};
|
||||
const mockRes = {
|
||||
writeHead: vi.fn(),
|
||||
end: vi.fn(),
|
||||
};
|
||||
(callbackHandler as (req: unknown, res: unknown) => void)(
|
||||
mockReq,
|
||||
mockRes,
|
||||
);
|
||||
}, 10);
|
||||
});
|
||||
|
||||
mockFetch.mockResolvedValueOnce({
|
||||
ok: true,
|
||||
json: () => Promise.resolve(mockTokenResponse),
|
||||
});
|
||||
|
||||
await MCPOAuthProvider.authenticate('test-server', mockConfig);
|
||||
|
||||
expect(capturedUrl!).toContain('response_type=code');
|
||||
expect(capturedUrl!).toContain('client_id=test-client-id');
|
||||
expect(capturedUrl!).toContain('code_challenge=code_challenge_mock');
|
||||
expect(capturedUrl!).toContain('code_challenge_method=S256');
|
||||
expect(capturedUrl!).toContain('scope=read+write');
|
||||
expect(capturedUrl!).toContain('resource=https%3A%2F%2Fauth.example.com');
|
||||
});
|
||||
});
|
||||
});
|
|
@ -0,0 +1,698 @@
|
|||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import * as http from 'node:http';
|
||||
import * as crypto from 'node:crypto';
|
||||
import { URL } from 'node:url';
|
||||
import open from 'open';
|
||||
import { MCPOAuthToken, MCPOAuthTokenStorage } from './oauth-token-storage.js';
|
||||
import { getErrorMessage } from '../utils/errors.js';
|
||||
import { OAuthUtils } from './oauth-utils.js';
|
||||
|
||||
/**
|
||||
* OAuth configuration for an MCP server.
|
||||
*/
|
||||
export interface MCPOAuthConfig {
|
||||
enabled?: boolean; // Whether OAuth is enabled for this server
|
||||
clientId?: string;
|
||||
clientSecret?: string;
|
||||
authorizationUrl?: string;
|
||||
tokenUrl?: string;
|
||||
scopes?: string[];
|
||||
redirectUri?: string;
|
||||
tokenParamName?: string; // For SSE connections, specifies the query parameter name for the token
|
||||
}
|
||||
|
||||
/**
|
||||
* OAuth authorization response.
|
||||
*/
|
||||
export interface OAuthAuthorizationResponse {
|
||||
code: string;
|
||||
state: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* OAuth token response from the authorization server.
|
||||
*/
|
||||
export interface OAuthTokenResponse {
|
||||
access_token: string;
|
||||
token_type: string;
|
||||
expires_in?: number;
|
||||
refresh_token?: string;
|
||||
scope?: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* Dynamic client registration request.
|
||||
*/
|
||||
export interface OAuthClientRegistrationRequest {
|
||||
client_name: string;
|
||||
redirect_uris: string[];
|
||||
grant_types: string[];
|
||||
response_types: string[];
|
||||
token_endpoint_auth_method: string;
|
||||
code_challenge_method?: string[];
|
||||
scope?: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* Dynamic client registration response.
|
||||
*/
|
||||
export interface OAuthClientRegistrationResponse {
|
||||
client_id: string;
|
||||
client_secret?: string;
|
||||
client_id_issued_at?: number;
|
||||
client_secret_expires_at?: number;
|
||||
redirect_uris: string[];
|
||||
grant_types: string[];
|
||||
response_types: string[];
|
||||
token_endpoint_auth_method: string;
|
||||
code_challenge_method?: string[];
|
||||
scope?: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* PKCE (Proof Key for Code Exchange) parameters.
|
||||
*/
|
||||
interface PKCEParams {
|
||||
codeVerifier: string;
|
||||
codeChallenge: string;
|
||||
state: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* Provider for handling OAuth authentication for MCP servers.
|
||||
*/
|
||||
export class MCPOAuthProvider {
|
||||
private static readonly REDIRECT_PORT = 7777;
|
||||
private static readonly REDIRECT_PATH = '/oauth/callback';
|
||||
private static readonly HTTP_OK = 200;
|
||||
private static readonly HTTP_REDIRECT = 302;
|
||||
|
||||
/**
|
||||
* Register a client dynamically with the OAuth server.
|
||||
*
|
||||
* @param registrationUrl The client registration endpoint URL
|
||||
* @param config OAuth configuration
|
||||
* @returns The registered client information
|
||||
*/
|
||||
private static async registerClient(
|
||||
registrationUrl: string,
|
||||
config: MCPOAuthConfig,
|
||||
): Promise<OAuthClientRegistrationResponse> {
|
||||
const redirectUri =
|
||||
config.redirectUri ||
|
||||
`http://localhost:${this.REDIRECT_PORT}${this.REDIRECT_PATH}`;
|
||||
|
||||
const registrationRequest: OAuthClientRegistrationRequest = {
|
||||
client_name: 'Gemini CLI MCP Client',
|
||||
redirect_uris: [redirectUri],
|
||||
grant_types: ['authorization_code', 'refresh_token'],
|
||||
response_types: ['code'],
|
||||
token_endpoint_auth_method: 'none', // Public client
|
||||
code_challenge_method: ['S256'],
|
||||
scope: config.scopes?.join(' ') || '',
|
||||
};
|
||||
|
||||
const response = await fetch(registrationUrl, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
body: JSON.stringify(registrationRequest),
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const errorText = await response.text();
|
||||
throw new Error(
|
||||
`Client registration failed: ${response.status} ${response.statusText} - ${errorText}`,
|
||||
);
|
||||
}
|
||||
|
||||
return (await response.json()) as OAuthClientRegistrationResponse;
|
||||
}
|
||||
|
||||
/**
|
||||
* Discover OAuth configuration from an MCP server URL.
|
||||
*
|
||||
* @param mcpServerUrl The MCP server URL
|
||||
* @returns OAuth configuration if discovered, null otherwise
|
||||
*/
|
||||
private static async discoverOAuthFromMCPServer(
|
||||
mcpServerUrl: string,
|
||||
): Promise<MCPOAuthConfig | null> {
|
||||
const baseUrl = OAuthUtils.extractBaseUrl(mcpServerUrl);
|
||||
return OAuthUtils.discoverOAuthConfig(baseUrl);
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate PKCE parameters for OAuth flow.
|
||||
*
|
||||
* @returns PKCE parameters including code verifier, challenge, and state
|
||||
*/
|
||||
private static generatePKCEParams(): PKCEParams {
|
||||
// Generate code verifier (43-128 characters)
|
||||
const codeVerifier = crypto.randomBytes(32).toString('base64url');
|
||||
|
||||
// Generate code challenge using SHA256
|
||||
const codeChallenge = crypto
|
||||
.createHash('sha256')
|
||||
.update(codeVerifier)
|
||||
.digest('base64url');
|
||||
|
||||
// Generate state for CSRF protection
|
||||
const state = crypto.randomBytes(16).toString('base64url');
|
||||
|
||||
return { codeVerifier, codeChallenge, state };
|
||||
}
|
||||
|
||||
/**
|
||||
* Start a local HTTP server to handle OAuth callback.
|
||||
*
|
||||
* @param expectedState The state parameter to validate
|
||||
* @returns Promise that resolves with the authorization code
|
||||
*/
|
||||
private static async startCallbackServer(
|
||||
expectedState: string,
|
||||
): Promise<OAuthAuthorizationResponse> {
|
||||
return new Promise((resolve, reject) => {
|
||||
const server = http.createServer(
|
||||
async (req: http.IncomingMessage, res: http.ServerResponse) => {
|
||||
try {
|
||||
const url = new URL(
|
||||
req.url!,
|
||||
`http://localhost:${this.REDIRECT_PORT}`,
|
||||
);
|
||||
|
||||
if (url.pathname !== this.REDIRECT_PATH) {
|
||||
res.writeHead(404);
|
||||
res.end('Not found');
|
||||
return;
|
||||
}
|
||||
|
||||
const code = url.searchParams.get('code');
|
||||
const state = url.searchParams.get('state');
|
||||
const error = url.searchParams.get('error');
|
||||
|
||||
if (error) {
|
||||
res.writeHead(this.HTTP_OK, { 'Content-Type': 'text/html' });
|
||||
res.end(`
|
||||
<html>
|
||||
<body>
|
||||
<h1>Authentication Failed</h1>
|
||||
<p>Error: ${(error as string).replace(/</g, '<').replace(/>/g, '>')}</p>
|
||||
<p>${((url.searchParams.get('error_description') || '') as string).replace(/</g, '<').replace(/>/g, '>')}</p>
|
||||
<p>You can close this window.</p>
|
||||
</body>
|
||||
</html>
|
||||
`);
|
||||
server.close();
|
||||
reject(new Error(`OAuth error: ${error}`));
|
||||
return;
|
||||
}
|
||||
|
||||
if (!code || !state) {
|
||||
res.writeHead(400);
|
||||
res.end('Missing code or state parameter');
|
||||
return;
|
||||
}
|
||||
|
||||
if (state !== expectedState) {
|
||||
res.writeHead(400);
|
||||
res.end('Invalid state parameter');
|
||||
server.close();
|
||||
reject(new Error('State mismatch - possible CSRF attack'));
|
||||
return;
|
||||
}
|
||||
|
||||
// Send success response to browser
|
||||
res.writeHead(this.HTTP_OK, { 'Content-Type': 'text/html' });
|
||||
res.end(`
|
||||
<html>
|
||||
<body>
|
||||
<h1>Authentication Successful!</h1>
|
||||
<p>You can close this window and return to Gemini CLI.</p>
|
||||
<script>window.close();</script>
|
||||
</body>
|
||||
</html>
|
||||
`);
|
||||
|
||||
server.close();
|
||||
resolve({ code, state });
|
||||
} catch (error) {
|
||||
server.close();
|
||||
reject(error);
|
||||
}
|
||||
},
|
||||
);
|
||||
|
||||
server.on('error', reject);
|
||||
server.listen(this.REDIRECT_PORT, () => {
|
||||
console.log(
|
||||
`OAuth callback server listening on port ${this.REDIRECT_PORT}`,
|
||||
);
|
||||
});
|
||||
|
||||
// Timeout after 5 minutes
|
||||
setTimeout(
|
||||
() => {
|
||||
server.close();
|
||||
reject(new Error('OAuth callback timeout'));
|
||||
},
|
||||
5 * 60 * 1000,
|
||||
);
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Build the authorization URL with PKCE parameters.
|
||||
*
|
||||
* @param config OAuth configuration
|
||||
* @param pkceParams PKCE parameters
|
||||
* @returns The authorization URL
|
||||
*/
|
||||
private static buildAuthorizationUrl(
|
||||
config: MCPOAuthConfig,
|
||||
pkceParams: PKCEParams,
|
||||
): string {
|
||||
const redirectUri =
|
||||
config.redirectUri ||
|
||||
`http://localhost:${this.REDIRECT_PORT}${this.REDIRECT_PATH}`;
|
||||
|
||||
const params = new URLSearchParams({
|
||||
client_id: config.clientId!,
|
||||
response_type: 'code',
|
||||
redirect_uri: redirectUri,
|
||||
state: pkceParams.state,
|
||||
code_challenge: pkceParams.codeChallenge,
|
||||
code_challenge_method: 'S256',
|
||||
});
|
||||
|
||||
if (config.scopes && config.scopes.length > 0) {
|
||||
params.append('scope', config.scopes.join(' '));
|
||||
}
|
||||
|
||||
// Add resource parameter for MCP OAuth spec compliance
|
||||
params.append(
|
||||
'resource',
|
||||
OAuthUtils.buildResourceParameter(config.authorizationUrl!),
|
||||
);
|
||||
|
||||
return `${config.authorizationUrl}?${params.toString()}`;
|
||||
}
|
||||
|
||||
/**
|
||||
* Exchange authorization code for tokens.
|
||||
*
|
||||
* @param config OAuth configuration
|
||||
* @param code Authorization code
|
||||
* @param codeVerifier PKCE code verifier
|
||||
* @returns The token response
|
||||
*/
|
||||
private static async exchangeCodeForToken(
|
||||
config: MCPOAuthConfig,
|
||||
code: string,
|
||||
codeVerifier: string,
|
||||
): Promise<OAuthTokenResponse> {
|
||||
const redirectUri =
|
||||
config.redirectUri ||
|
||||
`http://localhost:${this.REDIRECT_PORT}${this.REDIRECT_PATH}`;
|
||||
|
||||
const params = new URLSearchParams({
|
||||
grant_type: 'authorization_code',
|
||||
code,
|
||||
redirect_uri: redirectUri,
|
||||
code_verifier: codeVerifier,
|
||||
client_id: config.clientId!,
|
||||
});
|
||||
|
||||
if (config.clientSecret) {
|
||||
params.append('client_secret', config.clientSecret);
|
||||
}
|
||||
|
||||
// Add resource parameter for MCP OAuth spec compliance
|
||||
params.append(
|
||||
'resource',
|
||||
OAuthUtils.buildResourceParameter(config.tokenUrl!),
|
||||
);
|
||||
|
||||
const response = await fetch(config.tokenUrl!, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/x-www-form-urlencoded',
|
||||
},
|
||||
body: params.toString(),
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const errorText = await response.text();
|
||||
throw new Error(
|
||||
`Token exchange failed: ${response.status} - ${errorText}`,
|
||||
);
|
||||
}
|
||||
|
||||
return (await response.json()) as OAuthTokenResponse;
|
||||
}
|
||||
|
||||
/**
|
||||
* Refresh an access token using a refresh token.
|
||||
*
|
||||
* @param config OAuth configuration
|
||||
* @param refreshToken The refresh token
|
||||
* @returns The new token response
|
||||
*/
|
||||
static async refreshAccessToken(
|
||||
config: MCPOAuthConfig,
|
||||
refreshToken: string,
|
||||
tokenUrl: string,
|
||||
): Promise<OAuthTokenResponse> {
|
||||
const params = new URLSearchParams({
|
||||
grant_type: 'refresh_token',
|
||||
refresh_token: refreshToken,
|
||||
client_id: config.clientId!,
|
||||
});
|
||||
|
||||
if (config.clientSecret) {
|
||||
params.append('client_secret', config.clientSecret);
|
||||
}
|
||||
|
||||
if (config.scopes && config.scopes.length > 0) {
|
||||
params.append('scope', config.scopes.join(' '));
|
||||
}
|
||||
|
||||
// Add resource parameter for MCP OAuth spec compliance
|
||||
params.append('resource', OAuthUtils.buildResourceParameter(tokenUrl));
|
||||
|
||||
const response = await fetch(tokenUrl, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/x-www-form-urlencoded',
|
||||
},
|
||||
body: params.toString(),
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const errorText = await response.text();
|
||||
throw new Error(
|
||||
`Token refresh failed: ${response.status} - ${errorText}`,
|
||||
);
|
||||
}
|
||||
|
||||
return (await response.json()) as OAuthTokenResponse;
|
||||
}
|
||||
|
||||
/**
|
||||
* Perform the full OAuth authorization code flow with PKCE.
|
||||
*
|
||||
* @param serverName The name of the MCP server
|
||||
* @param config OAuth configuration
|
||||
* @param mcpServerUrl Optional MCP server URL for OAuth discovery
|
||||
* @returns The obtained OAuth token
|
||||
*/
|
||||
static async authenticate(
|
||||
serverName: string,
|
||||
config: MCPOAuthConfig,
|
||||
mcpServerUrl?: string,
|
||||
): Promise<MCPOAuthToken> {
|
||||
// If no authorization URL is provided, try to discover OAuth configuration
|
||||
if (!config.authorizationUrl && mcpServerUrl) {
|
||||
console.log(
|
||||
'No authorization URL provided, attempting OAuth discovery...',
|
||||
);
|
||||
|
||||
// For SSE URLs, first check if authentication is required
|
||||
if (OAuthUtils.isSSEEndpoint(mcpServerUrl)) {
|
||||
try {
|
||||
const response = await fetch(mcpServerUrl, {
|
||||
method: 'HEAD',
|
||||
headers: {
|
||||
Accept: 'text/event-stream',
|
||||
},
|
||||
});
|
||||
|
||||
if (response.status === 401 || response.status === 307) {
|
||||
const wwwAuthenticate = response.headers.get('www-authenticate');
|
||||
if (wwwAuthenticate) {
|
||||
const discoveredConfig =
|
||||
await OAuthUtils.discoverOAuthFromWWWAuthenticate(
|
||||
wwwAuthenticate,
|
||||
);
|
||||
if (discoveredConfig) {
|
||||
config = {
|
||||
...config,
|
||||
...discoveredConfig,
|
||||
scopes: discoveredConfig.scopes || config.scopes || [],
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
console.debug(
|
||||
`Failed to check SSE endpoint for authentication requirements: ${getErrorMessage(error)}`,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// If we still don't have OAuth config, try the standard discovery
|
||||
if (!config.authorizationUrl) {
|
||||
const discoveredConfig =
|
||||
await this.discoverOAuthFromMCPServer(mcpServerUrl);
|
||||
if (discoveredConfig) {
|
||||
config = { ...config, ...discoveredConfig };
|
||||
console.log('OAuth configuration discovered successfully');
|
||||
} else {
|
||||
throw new Error(
|
||||
'Failed to discover OAuth configuration from MCP server',
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If no client ID is provided, try dynamic client registration
|
||||
if (!config.clientId) {
|
||||
// Extract server URL from authorization URL
|
||||
if (!config.authorizationUrl) {
|
||||
throw new Error(
|
||||
'Cannot perform dynamic registration without authorization URL',
|
||||
);
|
||||
}
|
||||
|
||||
const authUrl = new URL(config.authorizationUrl);
|
||||
const serverUrl = `${authUrl.protocol}//${authUrl.host}`;
|
||||
|
||||
console.log(
|
||||
'No client ID provided, attempting dynamic client registration...',
|
||||
);
|
||||
|
||||
// Get the authorization server metadata for registration
|
||||
const authServerMetadataUrl = new URL(
|
||||
'/.well-known/oauth-authorization-server',
|
||||
serverUrl,
|
||||
).toString();
|
||||
|
||||
const authServerMetadata =
|
||||
await OAuthUtils.fetchAuthorizationServerMetadata(
|
||||
authServerMetadataUrl,
|
||||
);
|
||||
if (!authServerMetadata) {
|
||||
throw new Error(
|
||||
'Failed to fetch authorization server metadata for client registration',
|
||||
);
|
||||
}
|
||||
|
||||
// Register client if registration endpoint is available
|
||||
if (authServerMetadata.registration_endpoint) {
|
||||
const clientRegistration = await this.registerClient(
|
||||
authServerMetadata.registration_endpoint,
|
||||
config,
|
||||
);
|
||||
|
||||
config.clientId = clientRegistration.client_id;
|
||||
if (clientRegistration.client_secret) {
|
||||
config.clientSecret = clientRegistration.client_secret;
|
||||
}
|
||||
|
||||
console.log('Dynamic client registration successful');
|
||||
} else {
|
||||
throw new Error(
|
||||
'No client ID provided and dynamic registration not supported',
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// Validate configuration
|
||||
if (!config.clientId || !config.authorizationUrl || !config.tokenUrl) {
|
||||
throw new Error(
|
||||
'Missing required OAuth configuration after discovery and registration',
|
||||
);
|
||||
}
|
||||
|
||||
// Generate PKCE parameters
|
||||
const pkceParams = this.generatePKCEParams();
|
||||
|
||||
// Build authorization URL
|
||||
const authUrl = this.buildAuthorizationUrl(config, pkceParams);
|
||||
|
||||
console.log('\nOpening browser for OAuth authentication...');
|
||||
console.log('If the browser does not open, please visit:');
|
||||
console.log('');
|
||||
|
||||
// Get terminal width or default to 80
|
||||
const terminalWidth = process.stdout.columns || 80;
|
||||
const separatorLength = Math.min(terminalWidth - 2, 80);
|
||||
const separator = '━'.repeat(separatorLength);
|
||||
|
||||
console.log(separator);
|
||||
console.log(
|
||||
'COPY THE ENTIRE URL BELOW (select all text between the lines):',
|
||||
);
|
||||
console.log(separator);
|
||||
console.log(authUrl);
|
||||
console.log(separator);
|
||||
console.log('');
|
||||
console.log(
|
||||
'💡 TIP: Triple-click to select the entire URL, then copy and paste it into your browser.',
|
||||
);
|
||||
console.log(
|
||||
'⚠️ Make sure to copy the COMPLETE URL - it may wrap across multiple lines.',
|
||||
);
|
||||
console.log('');
|
||||
|
||||
// Start callback server
|
||||
const callbackPromise = this.startCallbackServer(pkceParams.state);
|
||||
|
||||
// Open browser
|
||||
try {
|
||||
await open(authUrl);
|
||||
} catch (error) {
|
||||
console.warn(
|
||||
'Failed to open browser automatically:',
|
||||
getErrorMessage(error),
|
||||
);
|
||||
}
|
||||
|
||||
// Wait for callback
|
||||
const { code } = await callbackPromise;
|
||||
|
||||
console.log('\nAuthorization code received, exchanging for tokens...');
|
||||
|
||||
// Exchange code for tokens
|
||||
const tokenResponse = await this.exchangeCodeForToken(
|
||||
config,
|
||||
code,
|
||||
pkceParams.codeVerifier,
|
||||
);
|
||||
|
||||
// Convert to our token format
|
||||
const token: MCPOAuthToken = {
|
||||
accessToken: tokenResponse.access_token,
|
||||
tokenType: tokenResponse.token_type,
|
||||
refreshToken: tokenResponse.refresh_token,
|
||||
scope: tokenResponse.scope,
|
||||
};
|
||||
|
||||
if (tokenResponse.expires_in) {
|
||||
token.expiresAt = Date.now() + tokenResponse.expires_in * 1000;
|
||||
}
|
||||
|
||||
// Save token
|
||||
try {
|
||||
await MCPOAuthTokenStorage.saveToken(
|
||||
serverName,
|
||||
token,
|
||||
config.clientId,
|
||||
config.tokenUrl,
|
||||
);
|
||||
console.log('Authentication successful! Token saved.');
|
||||
|
||||
// Verify token was saved
|
||||
const savedToken = await MCPOAuthTokenStorage.getToken(serverName);
|
||||
if (savedToken) {
|
||||
console.log(
|
||||
`Token verification successful: ${savedToken.token.accessToken.substring(0, 20)}...`,
|
||||
);
|
||||
} else {
|
||||
console.error('Token verification failed: token not found after save');
|
||||
}
|
||||
} catch (saveError) {
|
||||
console.error(`Failed to save token: ${getErrorMessage(saveError)}`);
|
||||
throw saveError;
|
||||
}
|
||||
|
||||
return token;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get a valid access token for an MCP server, refreshing if necessary.
|
||||
*
|
||||
* @param serverName The name of the MCP server
|
||||
* @param config OAuth configuration
|
||||
* @returns A valid access token or null if not authenticated
|
||||
*/
|
||||
static async getValidToken(
|
||||
serverName: string,
|
||||
config: MCPOAuthConfig,
|
||||
): Promise<string | null> {
|
||||
console.debug(`Getting valid token for server: ${serverName}`);
|
||||
const credentials = await MCPOAuthTokenStorage.getToken(serverName);
|
||||
|
||||
if (!credentials) {
|
||||
console.debug(`No credentials found for server: ${serverName}`);
|
||||
return null;
|
||||
}
|
||||
|
||||
const { token } = credentials;
|
||||
console.debug(
|
||||
`Found token for server: ${serverName}, expired: ${MCPOAuthTokenStorage.isTokenExpired(token)}`,
|
||||
);
|
||||
|
||||
// Check if token is expired
|
||||
if (!MCPOAuthTokenStorage.isTokenExpired(token)) {
|
||||
console.debug(`Returning valid token for server: ${serverName}`);
|
||||
return token.accessToken;
|
||||
}
|
||||
|
||||
// Try to refresh if we have a refresh token
|
||||
if (token.refreshToken && config.clientId && credentials.tokenUrl) {
|
||||
try {
|
||||
console.log(`Refreshing expired token for MCP server: ${serverName}`);
|
||||
|
||||
const newTokenResponse = await this.refreshAccessToken(
|
||||
config,
|
||||
token.refreshToken,
|
||||
credentials.tokenUrl,
|
||||
);
|
||||
|
||||
// Update stored token
|
||||
const newToken: MCPOAuthToken = {
|
||||
accessToken: newTokenResponse.access_token,
|
||||
tokenType: newTokenResponse.token_type,
|
||||
refreshToken: newTokenResponse.refresh_token || token.refreshToken,
|
||||
scope: newTokenResponse.scope || token.scope,
|
||||
};
|
||||
|
||||
if (newTokenResponse.expires_in) {
|
||||
newToken.expiresAt = Date.now() + newTokenResponse.expires_in * 1000;
|
||||
}
|
||||
|
||||
await MCPOAuthTokenStorage.saveToken(
|
||||
serverName,
|
||||
newToken,
|
||||
config.clientId,
|
||||
credentials.tokenUrl,
|
||||
);
|
||||
|
||||
return newToken.accessToken;
|
||||
} catch (error) {
|
||||
console.error(`Failed to refresh token: ${getErrorMessage(error)}`);
|
||||
// Remove invalid token
|
||||
await MCPOAuthTokenStorage.removeToken(serverName);
|
||||
}
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,325 @@
|
|||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
|
||||
import { promises as fs } from 'node:fs';
|
||||
import * as path from 'node:path';
|
||||
import {
|
||||
MCPOAuthTokenStorage,
|
||||
MCPOAuthToken,
|
||||
MCPOAuthCredentials,
|
||||
} from './oauth-token-storage.js';
|
||||
|
||||
// Mock file system operations
|
||||
vi.mock('node:fs', () => ({
|
||||
promises: {
|
||||
readFile: vi.fn(),
|
||||
writeFile: vi.fn(),
|
||||
mkdir: vi.fn(),
|
||||
unlink: vi.fn(),
|
||||
},
|
||||
}));
|
||||
|
||||
vi.mock('node:os', () => ({
|
||||
homedir: vi.fn(() => '/mock/home'),
|
||||
}));
|
||||
|
||||
describe('MCPOAuthTokenStorage', () => {
|
||||
const mockToken: MCPOAuthToken = {
|
||||
accessToken: 'access_token_123',
|
||||
refreshToken: 'refresh_token_456',
|
||||
tokenType: 'Bearer',
|
||||
scope: 'read write',
|
||||
expiresAt: Date.now() + 3600000, // 1 hour from now
|
||||
};
|
||||
|
||||
const mockCredentials: MCPOAuthCredentials = {
|
||||
serverName: 'test-server',
|
||||
token: mockToken,
|
||||
clientId: 'test-client-id',
|
||||
tokenUrl: 'https://auth.example.com/token',
|
||||
updatedAt: Date.now(),
|
||||
};
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
vi.spyOn(console, 'error').mockImplementation(() => {});
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.restoreAllMocks();
|
||||
});
|
||||
|
||||
describe('loadTokens', () => {
|
||||
it('should return empty map when token file does not exist', async () => {
|
||||
vi.mocked(fs.readFile).mockRejectedValue({ code: 'ENOENT' });
|
||||
|
||||
const tokens = await MCPOAuthTokenStorage.loadTokens();
|
||||
|
||||
expect(tokens.size).toBe(0);
|
||||
expect(console.error).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should load tokens from file successfully', async () => {
|
||||
const tokensArray = [mockCredentials];
|
||||
vi.mocked(fs.readFile).mockResolvedValue(JSON.stringify(tokensArray));
|
||||
|
||||
const tokens = await MCPOAuthTokenStorage.loadTokens();
|
||||
|
||||
expect(tokens.size).toBe(1);
|
||||
expect(tokens.get('test-server')).toEqual(mockCredentials);
|
||||
expect(fs.readFile).toHaveBeenCalledWith(
|
||||
path.join('/mock/home', '.gemini', 'mcp-oauth-tokens.json'),
|
||||
'utf-8',
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle corrupted token file gracefully', async () => {
|
||||
vi.mocked(fs.readFile).mockResolvedValue('invalid json');
|
||||
|
||||
const tokens = await MCPOAuthTokenStorage.loadTokens();
|
||||
|
||||
expect(tokens.size).toBe(0);
|
||||
expect(console.error).toHaveBeenCalledWith(
|
||||
expect.stringContaining('Failed to load MCP OAuth tokens'),
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle file read errors other than ENOENT', async () => {
|
||||
const error = new Error('Permission denied');
|
||||
vi.mocked(fs.readFile).mockRejectedValue(error);
|
||||
|
||||
const tokens = await MCPOAuthTokenStorage.loadTokens();
|
||||
|
||||
expect(tokens.size).toBe(0);
|
||||
expect(console.error).toHaveBeenCalledWith(
|
||||
expect.stringContaining('Failed to load MCP OAuth tokens'),
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('saveToken', () => {
|
||||
it('should save token with restricted permissions', async () => {
|
||||
vi.mocked(fs.readFile).mockRejectedValue({ code: 'ENOENT' });
|
||||
vi.mocked(fs.mkdir).mockResolvedValue(undefined);
|
||||
vi.mocked(fs.writeFile).mockResolvedValue(undefined);
|
||||
|
||||
await MCPOAuthTokenStorage.saveToken(
|
||||
'test-server',
|
||||
mockToken,
|
||||
'client-id',
|
||||
'https://token.url',
|
||||
);
|
||||
|
||||
expect(fs.mkdir).toHaveBeenCalledWith(
|
||||
path.join('/mock/home', '.gemini'),
|
||||
{ recursive: true },
|
||||
);
|
||||
expect(fs.writeFile).toHaveBeenCalledWith(
|
||||
path.join('/mock/home', '.gemini', 'mcp-oauth-tokens.json'),
|
||||
expect.stringContaining('test-server'),
|
||||
{ mode: 0o600 },
|
||||
);
|
||||
});
|
||||
|
||||
it('should update existing token for same server', async () => {
|
||||
const existingCredentials = {
|
||||
...mockCredentials,
|
||||
serverName: 'existing-server',
|
||||
};
|
||||
vi.mocked(fs.readFile).mockResolvedValue(
|
||||
JSON.stringify([existingCredentials]),
|
||||
);
|
||||
vi.mocked(fs.writeFile).mockResolvedValue(undefined);
|
||||
|
||||
const newToken = { ...mockToken, accessToken: 'new_access_token' };
|
||||
await MCPOAuthTokenStorage.saveToken('existing-server', newToken);
|
||||
|
||||
const writeCall = vi.mocked(fs.writeFile).mock.calls[0];
|
||||
const savedData = JSON.parse(writeCall[1] as string);
|
||||
|
||||
expect(savedData).toHaveLength(1);
|
||||
expect(savedData[0].token.accessToken).toBe('new_access_token');
|
||||
expect(savedData[0].serverName).toBe('existing-server');
|
||||
});
|
||||
|
||||
it('should handle write errors gracefully', async () => {
|
||||
vi.mocked(fs.readFile).mockRejectedValue({ code: 'ENOENT' });
|
||||
vi.mocked(fs.mkdir).mockResolvedValue(undefined);
|
||||
const writeError = new Error('Disk full');
|
||||
vi.mocked(fs.writeFile).mockRejectedValue(writeError);
|
||||
|
||||
await expect(
|
||||
MCPOAuthTokenStorage.saveToken('test-server', mockToken),
|
||||
).rejects.toThrow('Disk full');
|
||||
|
||||
expect(console.error).toHaveBeenCalledWith(
|
||||
expect.stringContaining('Failed to save MCP OAuth token'),
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('getToken', () => {
|
||||
it('should return token for existing server', async () => {
|
||||
vi.mocked(fs.readFile).mockResolvedValue(
|
||||
JSON.stringify([mockCredentials]),
|
||||
);
|
||||
|
||||
const result = await MCPOAuthTokenStorage.getToken('test-server');
|
||||
|
||||
expect(result).toEqual(mockCredentials);
|
||||
});
|
||||
|
||||
it('should return null for non-existent server', async () => {
|
||||
vi.mocked(fs.readFile).mockResolvedValue(
|
||||
JSON.stringify([mockCredentials]),
|
||||
);
|
||||
|
||||
const result = await MCPOAuthTokenStorage.getToken('non-existent');
|
||||
|
||||
expect(result).toBeNull();
|
||||
});
|
||||
|
||||
it('should return null when no tokens file exists', async () => {
|
||||
vi.mocked(fs.readFile).mockRejectedValue({ code: 'ENOENT' });
|
||||
|
||||
const result = await MCPOAuthTokenStorage.getToken('test-server');
|
||||
|
||||
expect(result).toBeNull();
|
||||
});
|
||||
});
|
||||
|
||||
describe('removeToken', () => {
|
||||
it('should remove token for specific server', async () => {
|
||||
const credentials1 = { ...mockCredentials, serverName: 'server1' };
|
||||
const credentials2 = { ...mockCredentials, serverName: 'server2' };
|
||||
vi.mocked(fs.readFile).mockResolvedValue(
|
||||
JSON.stringify([credentials1, credentials2]),
|
||||
);
|
||||
vi.mocked(fs.writeFile).mockResolvedValue(undefined);
|
||||
|
||||
await MCPOAuthTokenStorage.removeToken('server1');
|
||||
|
||||
const writeCall = vi.mocked(fs.writeFile).mock.calls[0];
|
||||
const savedData = JSON.parse(writeCall[1] as string);
|
||||
|
||||
expect(savedData).toHaveLength(1);
|
||||
expect(savedData[0].serverName).toBe('server2');
|
||||
});
|
||||
|
||||
it('should remove token file when no tokens remain', async () => {
|
||||
vi.mocked(fs.readFile).mockResolvedValue(
|
||||
JSON.stringify([mockCredentials]),
|
||||
);
|
||||
vi.mocked(fs.unlink).mockResolvedValue(undefined);
|
||||
|
||||
await MCPOAuthTokenStorage.removeToken('test-server');
|
||||
|
||||
expect(fs.unlink).toHaveBeenCalledWith(
|
||||
path.join('/mock/home', '.gemini', 'mcp-oauth-tokens.json'),
|
||||
);
|
||||
expect(fs.writeFile).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should handle removal of non-existent token gracefully', async () => {
|
||||
vi.mocked(fs.readFile).mockResolvedValue(
|
||||
JSON.stringify([mockCredentials]),
|
||||
);
|
||||
|
||||
await MCPOAuthTokenStorage.removeToken('non-existent');
|
||||
|
||||
expect(fs.writeFile).not.toHaveBeenCalled();
|
||||
expect(fs.unlink).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should handle file operation errors gracefully', async () => {
|
||||
vi.mocked(fs.readFile).mockResolvedValue(
|
||||
JSON.stringify([mockCredentials]),
|
||||
);
|
||||
vi.mocked(fs.unlink).mockRejectedValue(new Error('Permission denied'));
|
||||
|
||||
await MCPOAuthTokenStorage.removeToken('test-server');
|
||||
|
||||
expect(console.error).toHaveBeenCalledWith(
|
||||
expect.stringContaining('Failed to remove MCP OAuth token'),
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('isTokenExpired', () => {
|
||||
it('should return false for token without expiry', () => {
|
||||
const tokenWithoutExpiry = { ...mockToken };
|
||||
delete tokenWithoutExpiry.expiresAt;
|
||||
|
||||
const result = MCPOAuthTokenStorage.isTokenExpired(tokenWithoutExpiry);
|
||||
|
||||
expect(result).toBe(false);
|
||||
});
|
||||
|
||||
it('should return false for valid token', () => {
|
||||
const futureToken = {
|
||||
...mockToken,
|
||||
expiresAt: Date.now() + 3600000, // 1 hour from now
|
||||
};
|
||||
|
||||
const result = MCPOAuthTokenStorage.isTokenExpired(futureToken);
|
||||
|
||||
expect(result).toBe(false);
|
||||
});
|
||||
|
||||
it('should return true for expired token', () => {
|
||||
const expiredToken = {
|
||||
...mockToken,
|
||||
expiresAt: Date.now() - 3600000, // 1 hour ago
|
||||
};
|
||||
|
||||
const result = MCPOAuthTokenStorage.isTokenExpired(expiredToken);
|
||||
|
||||
expect(result).toBe(true);
|
||||
});
|
||||
|
||||
it('should return true for token expiring within buffer time', () => {
|
||||
const soonToExpireToken = {
|
||||
...mockToken,
|
||||
expiresAt: Date.now() + 60000, // 1 minute from now (within 5-minute buffer)
|
||||
};
|
||||
|
||||
const result = MCPOAuthTokenStorage.isTokenExpired(soonToExpireToken);
|
||||
|
||||
expect(result).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
describe('clearAllTokens', () => {
|
||||
it('should remove token file successfully', async () => {
|
||||
vi.mocked(fs.unlink).mockResolvedValue(undefined);
|
||||
|
||||
await MCPOAuthTokenStorage.clearAllTokens();
|
||||
|
||||
expect(fs.unlink).toHaveBeenCalledWith(
|
||||
path.join('/mock/home', '.gemini', 'mcp-oauth-tokens.json'),
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle non-existent file gracefully', async () => {
|
||||
vi.mocked(fs.unlink).mockRejectedValue({ code: 'ENOENT' });
|
||||
|
||||
await MCPOAuthTokenStorage.clearAllTokens();
|
||||
|
||||
expect(console.error).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should handle other file errors gracefully', async () => {
|
||||
vi.mocked(fs.unlink).mockRejectedValue(new Error('Permission denied'));
|
||||
|
||||
await MCPOAuthTokenStorage.clearAllTokens();
|
||||
|
||||
expect(console.error).toHaveBeenCalledWith(
|
||||
expect.stringContaining('Failed to clear MCP OAuth tokens'),
|
||||
);
|
||||
});
|
||||
});
|
||||
});
|
|
@ -0,0 +1,205 @@
|
|||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { promises as fs } from 'node:fs';
|
||||
import * as path from 'node:path';
|
||||
import * as os from 'node:os';
|
||||
import { getErrorMessage } from '../utils/errors.js';
|
||||
|
||||
/**
|
||||
* Interface for MCP OAuth tokens.
|
||||
*/
|
||||
export interface MCPOAuthToken {
|
||||
accessToken: string;
|
||||
refreshToken?: string;
|
||||
expiresAt?: number;
|
||||
tokenType: string;
|
||||
scope?: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* Interface for stored MCP OAuth credentials.
|
||||
*/
|
||||
export interface MCPOAuthCredentials {
|
||||
serverName: string;
|
||||
token: MCPOAuthToken;
|
||||
clientId?: string;
|
||||
tokenUrl?: string;
|
||||
updatedAt: number;
|
||||
}
|
||||
|
||||
/**
|
||||
* Class for managing MCP OAuth token storage and retrieval.
|
||||
*/
|
||||
export class MCPOAuthTokenStorage {
|
||||
private static readonly TOKEN_FILE = 'mcp-oauth-tokens.json';
|
||||
private static readonly CONFIG_DIR = '.gemini';
|
||||
|
||||
/**
|
||||
* Get the path to the token storage file.
|
||||
*
|
||||
* @returns The full path to the token storage file
|
||||
*/
|
||||
private static getTokenFilePath(): string {
|
||||
const homeDir = os.homedir();
|
||||
return path.join(homeDir, this.CONFIG_DIR, this.TOKEN_FILE);
|
||||
}
|
||||
|
||||
/**
|
||||
* Ensure the config directory exists.
|
||||
*/
|
||||
private static async ensureConfigDir(): Promise<void> {
|
||||
const configDir = path.dirname(this.getTokenFilePath());
|
||||
await fs.mkdir(configDir, { recursive: true });
|
||||
}
|
||||
|
||||
/**
|
||||
* Load all stored MCP OAuth tokens.
|
||||
*
|
||||
* @returns A map of server names to credentials
|
||||
*/
|
||||
static async loadTokens(): Promise<Map<string, MCPOAuthCredentials>> {
|
||||
const tokenMap = new Map<string, MCPOAuthCredentials>();
|
||||
|
||||
try {
|
||||
const tokenFile = this.getTokenFilePath();
|
||||
const data = await fs.readFile(tokenFile, 'utf-8');
|
||||
const tokens = JSON.parse(data) as MCPOAuthCredentials[];
|
||||
|
||||
for (const credential of tokens) {
|
||||
tokenMap.set(credential.serverName, credential);
|
||||
}
|
||||
} catch (error) {
|
||||
// File doesn't exist or is invalid, return empty map
|
||||
if ((error as NodeJS.ErrnoException).code !== 'ENOENT') {
|
||||
console.error(
|
||||
`Failed to load MCP OAuth tokens: ${getErrorMessage(error)}`,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
return tokenMap;
|
||||
}
|
||||
|
||||
/**
|
||||
* Save a token for a specific MCP server.
|
||||
*
|
||||
* @param serverName The name of the MCP server
|
||||
* @param token The OAuth token to save
|
||||
* @param clientId Optional client ID used for this token
|
||||
* @param tokenUrl Optional token URL used for this token
|
||||
*/
|
||||
static async saveToken(
|
||||
serverName: string,
|
||||
token: MCPOAuthToken,
|
||||
clientId?: string,
|
||||
tokenUrl?: string,
|
||||
): Promise<void> {
|
||||
await this.ensureConfigDir();
|
||||
|
||||
const tokens = await this.loadTokens();
|
||||
|
||||
const credential: MCPOAuthCredentials = {
|
||||
serverName,
|
||||
token,
|
||||
clientId,
|
||||
tokenUrl,
|
||||
updatedAt: Date.now(),
|
||||
};
|
||||
|
||||
tokens.set(serverName, credential);
|
||||
|
||||
const tokenArray = Array.from(tokens.values());
|
||||
const tokenFile = this.getTokenFilePath();
|
||||
|
||||
try {
|
||||
await fs.writeFile(
|
||||
tokenFile,
|
||||
JSON.stringify(tokenArray, null, 2),
|
||||
{ mode: 0o600 }, // Restrict file permissions
|
||||
);
|
||||
} catch (error) {
|
||||
console.error(
|
||||
`Failed to save MCP OAuth token: ${getErrorMessage(error)}`,
|
||||
);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get a token for a specific MCP server.
|
||||
*
|
||||
* @param serverName The name of the MCP server
|
||||
* @returns The stored credentials or null if not found
|
||||
*/
|
||||
static async getToken(
|
||||
serverName: string,
|
||||
): Promise<MCPOAuthCredentials | null> {
|
||||
const tokens = await this.loadTokens();
|
||||
return tokens.get(serverName) || null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Remove a token for a specific MCP server.
|
||||
*
|
||||
* @param serverName The name of the MCP server
|
||||
*/
|
||||
static async removeToken(serverName: string): Promise<void> {
|
||||
const tokens = await this.loadTokens();
|
||||
|
||||
if (tokens.delete(serverName)) {
|
||||
const tokenArray = Array.from(tokens.values());
|
||||
const tokenFile = this.getTokenFilePath();
|
||||
|
||||
try {
|
||||
if (tokenArray.length === 0) {
|
||||
// Remove file if no tokens left
|
||||
await fs.unlink(tokenFile);
|
||||
} else {
|
||||
await fs.writeFile(tokenFile, JSON.stringify(tokenArray, null, 2), {
|
||||
mode: 0o600,
|
||||
});
|
||||
}
|
||||
} catch (error) {
|
||||
console.error(
|
||||
`Failed to remove MCP OAuth token: ${getErrorMessage(error)}`,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if a token is expired.
|
||||
*
|
||||
* @param token The token to check
|
||||
* @returns True if the token is expired
|
||||
*/
|
||||
static isTokenExpired(token: MCPOAuthToken): boolean {
|
||||
if (!token.expiresAt) {
|
||||
return false; // No expiry, assume valid
|
||||
}
|
||||
|
||||
// Add a 5-minute buffer to account for clock skew
|
||||
const bufferMs = 5 * 60 * 1000;
|
||||
return Date.now() + bufferMs >= token.expiresAt;
|
||||
}
|
||||
|
||||
/**
|
||||
* Clear all stored MCP OAuth tokens.
|
||||
*/
|
||||
static async clearAllTokens(): Promise<void> {
|
||||
try {
|
||||
const tokenFile = this.getTokenFilePath();
|
||||
await fs.unlink(tokenFile);
|
||||
} catch (error) {
|
||||
if ((error as NodeJS.ErrnoException).code !== 'ENOENT') {
|
||||
console.error(
|
||||
`Failed to clear MCP OAuth tokens: ${getErrorMessage(error)}`,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,206 @@
|
|||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
|
||||
import {
|
||||
OAuthUtils,
|
||||
OAuthAuthorizationServerMetadata,
|
||||
OAuthProtectedResourceMetadata,
|
||||
} from './oauth-utils.js';
|
||||
|
||||
// Mock fetch globally
|
||||
const mockFetch = vi.fn();
|
||||
global.fetch = mockFetch;
|
||||
|
||||
describe('OAuthUtils', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
vi.spyOn(console, 'debug').mockImplementation(() => {});
|
||||
vi.spyOn(console, 'error').mockImplementation(() => {});
|
||||
vi.spyOn(console, 'log').mockImplementation(() => {});
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.restoreAllMocks();
|
||||
});
|
||||
|
||||
describe('buildWellKnownUrls', () => {
|
||||
it('should build correct well-known URLs', () => {
|
||||
const urls = OAuthUtils.buildWellKnownUrls('https://example.com/path');
|
||||
expect(urls.protectedResource).toBe(
|
||||
'https://example.com/.well-known/oauth-protected-resource',
|
||||
);
|
||||
expect(urls.authorizationServer).toBe(
|
||||
'https://example.com/.well-known/oauth-authorization-server',
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('fetchProtectedResourceMetadata', () => {
|
||||
const mockResourceMetadata: OAuthProtectedResourceMetadata = {
|
||||
resource: 'https://api.example.com',
|
||||
authorization_servers: ['https://auth.example.com'],
|
||||
bearer_methods_supported: ['header'],
|
||||
};
|
||||
|
||||
it('should fetch protected resource metadata successfully', async () => {
|
||||
mockFetch.mockResolvedValueOnce({
|
||||
ok: true,
|
||||
json: () => Promise.resolve(mockResourceMetadata),
|
||||
});
|
||||
|
||||
const result = await OAuthUtils.fetchProtectedResourceMetadata(
|
||||
'https://example.com/.well-known/oauth-protected-resource',
|
||||
);
|
||||
|
||||
expect(result).toEqual(mockResourceMetadata);
|
||||
});
|
||||
|
||||
it('should return null when fetch fails', async () => {
|
||||
mockFetch.mockResolvedValueOnce({
|
||||
ok: false,
|
||||
});
|
||||
|
||||
const result = await OAuthUtils.fetchProtectedResourceMetadata(
|
||||
'https://example.com/.well-known/oauth-protected-resource',
|
||||
);
|
||||
|
||||
expect(result).toBeNull();
|
||||
});
|
||||
});
|
||||
|
||||
describe('fetchAuthorizationServerMetadata', () => {
|
||||
const mockAuthServerMetadata: OAuthAuthorizationServerMetadata = {
|
||||
issuer: 'https://auth.example.com',
|
||||
authorization_endpoint: 'https://auth.example.com/authorize',
|
||||
token_endpoint: 'https://auth.example.com/token',
|
||||
scopes_supported: ['read', 'write'],
|
||||
};
|
||||
|
||||
it('should fetch authorization server metadata successfully', async () => {
|
||||
mockFetch.mockResolvedValueOnce({
|
||||
ok: true,
|
||||
json: () => Promise.resolve(mockAuthServerMetadata),
|
||||
});
|
||||
|
||||
const result = await OAuthUtils.fetchAuthorizationServerMetadata(
|
||||
'https://auth.example.com/.well-known/oauth-authorization-server',
|
||||
);
|
||||
|
||||
expect(result).toEqual(mockAuthServerMetadata);
|
||||
});
|
||||
|
||||
it('should return null when fetch fails', async () => {
|
||||
mockFetch.mockResolvedValueOnce({
|
||||
ok: false,
|
||||
});
|
||||
|
||||
const result = await OAuthUtils.fetchAuthorizationServerMetadata(
|
||||
'https://auth.example.com/.well-known/oauth-authorization-server',
|
||||
);
|
||||
|
||||
expect(result).toBeNull();
|
||||
});
|
||||
});
|
||||
|
||||
describe('metadataToOAuthConfig', () => {
|
||||
it('should convert metadata to OAuth config', () => {
|
||||
const metadata: OAuthAuthorizationServerMetadata = {
|
||||
issuer: 'https://auth.example.com',
|
||||
authorization_endpoint: 'https://auth.example.com/authorize',
|
||||
token_endpoint: 'https://auth.example.com/token',
|
||||
scopes_supported: ['read', 'write'],
|
||||
};
|
||||
|
||||
const config = OAuthUtils.metadataToOAuthConfig(metadata);
|
||||
|
||||
expect(config).toEqual({
|
||||
authorizationUrl: 'https://auth.example.com/authorize',
|
||||
tokenUrl: 'https://auth.example.com/token',
|
||||
scopes: ['read', 'write'],
|
||||
});
|
||||
});
|
||||
|
||||
it('should handle empty scopes', () => {
|
||||
const metadata: OAuthAuthorizationServerMetadata = {
|
||||
issuer: 'https://auth.example.com',
|
||||
authorization_endpoint: 'https://auth.example.com/authorize',
|
||||
token_endpoint: 'https://auth.example.com/token',
|
||||
};
|
||||
|
||||
const config = OAuthUtils.metadataToOAuthConfig(metadata);
|
||||
|
||||
expect(config.scopes).toEqual([]);
|
||||
});
|
||||
});
|
||||
|
||||
describe('parseWWWAuthenticateHeader', () => {
|
||||
it('should parse resource metadata URI from WWW-Authenticate header', () => {
|
||||
const header =
|
||||
'Bearer realm="example", resource_metadata_uri="https://example.com/.well-known/oauth-protected-resource"';
|
||||
const result = OAuthUtils.parseWWWAuthenticateHeader(header);
|
||||
expect(result).toBe(
|
||||
'https://example.com/.well-known/oauth-protected-resource',
|
||||
);
|
||||
});
|
||||
|
||||
it('should return null when no resource metadata URI is found', () => {
|
||||
const header = 'Bearer realm="example"';
|
||||
const result = OAuthUtils.parseWWWAuthenticateHeader(header);
|
||||
expect(result).toBeNull();
|
||||
});
|
||||
});
|
||||
|
||||
describe('extractBaseUrl', () => {
|
||||
it('should extract base URL from MCP server URL', () => {
|
||||
const result = OAuthUtils.extractBaseUrl('https://example.com/mcp/v1');
|
||||
expect(result).toBe('https://example.com');
|
||||
});
|
||||
|
||||
it('should handle URLs with ports', () => {
|
||||
const result = OAuthUtils.extractBaseUrl(
|
||||
'https://example.com:8080/mcp/v1',
|
||||
);
|
||||
expect(result).toBe('https://example.com:8080');
|
||||
});
|
||||
});
|
||||
|
||||
describe('isSSEEndpoint', () => {
|
||||
it('should return true for SSE endpoints', () => {
|
||||
expect(OAuthUtils.isSSEEndpoint('https://example.com/sse')).toBe(true);
|
||||
expect(OAuthUtils.isSSEEndpoint('https://example.com/api/v1/sse')).toBe(
|
||||
true,
|
||||
);
|
||||
});
|
||||
|
||||
it('should return true for non-MCP endpoints', () => {
|
||||
expect(OAuthUtils.isSSEEndpoint('https://example.com/api')).toBe(true);
|
||||
});
|
||||
|
||||
it('should return false for MCP endpoints', () => {
|
||||
expect(OAuthUtils.isSSEEndpoint('https://example.com/mcp')).toBe(false);
|
||||
expect(OAuthUtils.isSSEEndpoint('https://example.com/api/mcp/v1')).toBe(
|
||||
false,
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('buildResourceParameter', () => {
|
||||
it('should build resource parameter from endpoint URL', () => {
|
||||
const result = OAuthUtils.buildResourceParameter(
|
||||
'https://example.com/oauth/token',
|
||||
);
|
||||
expect(result).toBe('https://example.com');
|
||||
});
|
||||
|
||||
it('should handle URLs with ports', () => {
|
||||
const result = OAuthUtils.buildResourceParameter(
|
||||
'https://example.com:8080/oauth/token',
|
||||
);
|
||||
expect(result).toBe('https://example.com:8080');
|
||||
});
|
||||
});
|
||||
});
|
|
@ -0,0 +1,285 @@
|
|||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { MCPOAuthConfig } from './oauth-provider.js';
|
||||
import { getErrorMessage } from '../utils/errors.js';
|
||||
|
||||
/**
|
||||
* OAuth authorization server metadata as per RFC 8414.
|
||||
*/
|
||||
export interface OAuthAuthorizationServerMetadata {
|
||||
issuer: string;
|
||||
authorization_endpoint: string;
|
||||
token_endpoint: string;
|
||||
token_endpoint_auth_methods_supported?: string[];
|
||||
revocation_endpoint?: string;
|
||||
revocation_endpoint_auth_methods_supported?: string[];
|
||||
registration_endpoint?: string;
|
||||
response_types_supported?: string[];
|
||||
grant_types_supported?: string[];
|
||||
code_challenge_methods_supported?: string[];
|
||||
scopes_supported?: string[];
|
||||
}
|
||||
|
||||
/**
|
||||
* OAuth protected resource metadata as per RFC 9728.
|
||||
*/
|
||||
export interface OAuthProtectedResourceMetadata {
|
||||
resource: string;
|
||||
authorization_servers?: string[];
|
||||
bearer_methods_supported?: string[];
|
||||
resource_documentation?: string;
|
||||
resource_signing_alg_values_supported?: string[];
|
||||
resource_encryption_alg_values_supported?: string[];
|
||||
resource_encryption_enc_values_supported?: string[];
|
||||
}
|
||||
|
||||
/**
|
||||
* Utility class for common OAuth operations.
|
||||
*/
|
||||
export class OAuthUtils {
|
||||
/**
|
||||
* Construct well-known OAuth endpoint URLs.
|
||||
*/
|
||||
static buildWellKnownUrls(baseUrl: string) {
|
||||
const serverUrl = new URL(baseUrl);
|
||||
const base = `${serverUrl.protocol}//${serverUrl.host}`;
|
||||
|
||||
return {
|
||||
protectedResource: new URL(
|
||||
'/.well-known/oauth-protected-resource',
|
||||
base,
|
||||
).toString(),
|
||||
authorizationServer: new URL(
|
||||
'/.well-known/oauth-authorization-server',
|
||||
base,
|
||||
).toString(),
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Fetch OAuth protected resource metadata.
|
||||
*
|
||||
* @param resourceMetadataUrl The protected resource metadata URL
|
||||
* @returns The protected resource metadata or null if not available
|
||||
*/
|
||||
static async fetchProtectedResourceMetadata(
|
||||
resourceMetadataUrl: string,
|
||||
): Promise<OAuthProtectedResourceMetadata | null> {
|
||||
try {
|
||||
const response = await fetch(resourceMetadataUrl);
|
||||
if (!response.ok) {
|
||||
return null;
|
||||
}
|
||||
return (await response.json()) as OAuthProtectedResourceMetadata;
|
||||
} catch (error) {
|
||||
console.debug(
|
||||
`Failed to fetch protected resource metadata from ${resourceMetadataUrl}: ${getErrorMessage(error)}`,
|
||||
);
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Fetch OAuth authorization server metadata.
|
||||
*
|
||||
* @param authServerMetadataUrl The authorization server metadata URL
|
||||
* @returns The authorization server metadata or null if not available
|
||||
*/
|
||||
static async fetchAuthorizationServerMetadata(
|
||||
authServerMetadataUrl: string,
|
||||
): Promise<OAuthAuthorizationServerMetadata | null> {
|
||||
try {
|
||||
const response = await fetch(authServerMetadataUrl);
|
||||
if (!response.ok) {
|
||||
return null;
|
||||
}
|
||||
return (await response.json()) as OAuthAuthorizationServerMetadata;
|
||||
} catch (error) {
|
||||
console.debug(
|
||||
`Failed to fetch authorization server metadata from ${authServerMetadataUrl}: ${getErrorMessage(error)}`,
|
||||
);
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Convert authorization server metadata to OAuth configuration.
|
||||
*
|
||||
* @param metadata The authorization server metadata
|
||||
* @returns The OAuth configuration
|
||||
*/
|
||||
static metadataToOAuthConfig(
|
||||
metadata: OAuthAuthorizationServerMetadata,
|
||||
): MCPOAuthConfig {
|
||||
return {
|
||||
authorizationUrl: metadata.authorization_endpoint,
|
||||
tokenUrl: metadata.token_endpoint,
|
||||
scopes: metadata.scopes_supported || [],
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Discover OAuth configuration using the standard well-known endpoints.
|
||||
*
|
||||
* @param serverUrl The base URL of the server
|
||||
* @returns The discovered OAuth configuration or null if not available
|
||||
*/
|
||||
static async discoverOAuthConfig(
|
||||
serverUrl: string,
|
||||
): Promise<MCPOAuthConfig | null> {
|
||||
try {
|
||||
const wellKnownUrls = this.buildWellKnownUrls(serverUrl);
|
||||
|
||||
// First, try to get the protected resource metadata
|
||||
const resourceMetadata = await this.fetchProtectedResourceMetadata(
|
||||
wellKnownUrls.protectedResource,
|
||||
);
|
||||
|
||||
if (resourceMetadata?.authorization_servers?.length) {
|
||||
// Use the first authorization server
|
||||
const authServerUrl = resourceMetadata.authorization_servers[0];
|
||||
const authServerMetadataUrl = new URL(
|
||||
'/.well-known/oauth-authorization-server',
|
||||
authServerUrl,
|
||||
).toString();
|
||||
|
||||
const authServerMetadata = await this.fetchAuthorizationServerMetadata(
|
||||
authServerMetadataUrl,
|
||||
);
|
||||
|
||||
if (authServerMetadata) {
|
||||
const config = this.metadataToOAuthConfig(authServerMetadata);
|
||||
if (authServerMetadata.registration_endpoint) {
|
||||
console.log(
|
||||
'Dynamic client registration is supported at:',
|
||||
authServerMetadata.registration_endpoint,
|
||||
);
|
||||
}
|
||||
return config;
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback: try /.well-known/oauth-authorization-server at the base URL
|
||||
console.debug(
|
||||
`Trying OAuth discovery fallback at ${wellKnownUrls.authorizationServer}`,
|
||||
);
|
||||
const authServerMetadata = await this.fetchAuthorizationServerMetadata(
|
||||
wellKnownUrls.authorizationServer,
|
||||
);
|
||||
|
||||
if (authServerMetadata) {
|
||||
const config = this.metadataToOAuthConfig(authServerMetadata);
|
||||
if (authServerMetadata.registration_endpoint) {
|
||||
console.log(
|
||||
'Dynamic client registration is supported at:',
|
||||
authServerMetadata.registration_endpoint,
|
||||
);
|
||||
}
|
||||
return config;
|
||||
}
|
||||
|
||||
return null;
|
||||
} catch (error) {
|
||||
console.debug(
|
||||
`Failed to discover OAuth configuration: ${getErrorMessage(error)}`,
|
||||
);
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Parse WWW-Authenticate header to extract OAuth information.
|
||||
*
|
||||
* @param header The WWW-Authenticate header value
|
||||
* @returns The resource metadata URI if found
|
||||
*/
|
||||
static parseWWWAuthenticateHeader(header: string): string | null {
|
||||
// Parse Bearer realm and resource_metadata_uri
|
||||
const match = header.match(/resource_metadata_uri="([^"]+)"/);
|
||||
if (match) {
|
||||
return match[1];
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Discover OAuth configuration from WWW-Authenticate header.
|
||||
*
|
||||
* @param wwwAuthenticate The WWW-Authenticate header value
|
||||
* @returns The discovered OAuth configuration or null if not available
|
||||
*/
|
||||
static async discoverOAuthFromWWWAuthenticate(
|
||||
wwwAuthenticate: string,
|
||||
): Promise<MCPOAuthConfig | null> {
|
||||
const resourceMetadataUri =
|
||||
this.parseWWWAuthenticateHeader(wwwAuthenticate);
|
||||
if (!resourceMetadataUri) {
|
||||
return null;
|
||||
}
|
||||
|
||||
console.log(
|
||||
`Found resource metadata URI from www-authenticate header: ${resourceMetadataUri}`,
|
||||
);
|
||||
|
||||
const resourceMetadata =
|
||||
await this.fetchProtectedResourceMetadata(resourceMetadataUri);
|
||||
if (!resourceMetadata?.authorization_servers?.length) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const authServerUrl = resourceMetadata.authorization_servers[0];
|
||||
const authServerMetadataUrl = new URL(
|
||||
'/.well-known/oauth-authorization-server',
|
||||
authServerUrl,
|
||||
).toString();
|
||||
|
||||
const authServerMetadata = await this.fetchAuthorizationServerMetadata(
|
||||
authServerMetadataUrl,
|
||||
);
|
||||
|
||||
if (authServerMetadata) {
|
||||
console.log(
|
||||
'OAuth configuration discovered successfully from www-authenticate header',
|
||||
);
|
||||
return this.metadataToOAuthConfig(authServerMetadata);
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Extract base URL from an MCP server URL.
|
||||
*
|
||||
* @param mcpServerUrl The MCP server URL
|
||||
* @returns The base URL
|
||||
*/
|
||||
static extractBaseUrl(mcpServerUrl: string): string {
|
||||
const serverUrl = new URL(mcpServerUrl);
|
||||
return `${serverUrl.protocol}//${serverUrl.host}`;
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if a URL is an SSE endpoint.
|
||||
*
|
||||
* @param url The URL to check
|
||||
* @returns True if the URL appears to be an SSE endpoint
|
||||
*/
|
||||
static isSSEEndpoint(url: string): boolean {
|
||||
return url.includes('/sse') || !url.includes('/mcp');
|
||||
}
|
||||
|
||||
/**
|
||||
* Build a resource parameter for OAuth requests.
|
||||
*
|
||||
* @param endpointUrl The endpoint URL
|
||||
* @returns The resource parameter value
|
||||
*/
|
||||
static buildResourceParameter(endpointUrl: string): string {
|
||||
const url = new URL(endpointUrl);
|
||||
return `${url.protocol}//${url.host}`;
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue