diff --git a/packages/core/src/ide/ide-client.test.ts b/packages/core/src/ide/ide-client.test.ts index f061156d..7ad71ba3 100644 --- a/packages/core/src/ide/ide-client.test.ts +++ b/packages/core/src/ide/ide-client.test.ts @@ -4,75 +4,224 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { describe, it, expect } from 'vitest'; -import * as path from 'path'; -import { IdeClient } from './ide-client.js'; +import { + describe, + it, + expect, + vi, + beforeEach, + afterEach, + type Mocked, +} from 'vitest'; +import { IdeClient, IDEConnectionStatus } from './ide-client.js'; +import * as fs from 'node:fs'; +import { getIdeProcessId } from './process-utils.js'; +import { Client } from '@modelcontextprotocol/sdk/client/index.js'; +import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js'; +import { StdioClientTransport } from '@modelcontextprotocol/sdk/client/stdio.js'; +import { + detectIde, + DetectedIde, + getIdeInfo, + type IdeInfo, +} from './detect-ide.js'; +import * as os from 'node:os'; +import * as path from 'node:path'; -describe('IdeClient.validateWorkspacePath', () => { - it('should return valid if cwd is a subpath of the IDE workspace path', () => { - const result = IdeClient.validateWorkspacePath( - '/Users/person/gemini-cli', - 'VS Code', - '/Users/person/gemini-cli/sub-dir', - ); - expect(result.isValid).toBe(true); +vi.mock('node:fs', async (importOriginal) => { + const actual = await importOriginal(); + return { + ...(actual as object), + promises: { + readFile: vi.fn(), + }, + realpathSync: (p: string) => p, + existsSync: () => false, + }; +}); +vi.mock('./process-utils.js'); +vi.mock('@modelcontextprotocol/sdk/client/index.js'); +vi.mock('@modelcontextprotocol/sdk/client/streamableHttp.js'); +vi.mock('@modelcontextprotocol/sdk/client/stdio.js'); +vi.mock('./detect-ide.js'); +vi.mock('node:os'); + +describe('IdeClient', () => { + let mockClient: Mocked; + let mockHttpTransport: Mocked; + let mockStdioTransport: Mocked; + + beforeEach(() => { + // Reset singleton instance for test isolation + (IdeClient as unknown as { instance: IdeClient | undefined }).instance = + undefined; + + // Mock environment variables + process.env['GEMINI_CLI_IDE_WORKSPACE_PATH'] = '/test/workspace'; + delete process.env['GEMINI_CLI_IDE_SERVER_PORT']; + delete process.env['GEMINI_CLI_IDE_SERVER_STDIO_COMMAND']; + delete process.env['GEMINI_CLI_IDE_SERVER_STDIO_ARGS']; + + // Mock dependencies + vi.spyOn(process, 'cwd').mockReturnValue('/test/workspace/sub-dir'); + vi.mocked(detectIde).mockReturnValue(DetectedIde.VSCode); + vi.mocked(getIdeInfo).mockReturnValue({ + displayName: 'VS Code', + } as IdeInfo); + vi.mocked(getIdeProcessId).mockResolvedValue(12345); + vi.mocked(os.tmpdir).mockReturnValue('/tmp'); + + // Mock MCP client and transports + mockClient = { + connect: vi.fn().mockResolvedValue(undefined), + close: vi.fn(), + setNotificationHandler: vi.fn(), + callTool: vi.fn(), + } as unknown as Mocked; + mockHttpTransport = { + close: vi.fn(), + } as unknown as Mocked; + mockStdioTransport = { + close: vi.fn(), + } as unknown as Mocked; + + vi.mocked(Client).mockReturnValue(mockClient); + vi.mocked(StreamableHTTPClientTransport).mockReturnValue(mockHttpTransport); + vi.mocked(StdioClientTransport).mockReturnValue(mockStdioTransport); }); - it('should return invalid if GEMINI_CLI_IDE_WORKSPACE_PATH is undefined', () => { - const result = IdeClient.validateWorkspacePath( - undefined, - 'VS Code', - '/Users/person/gemini-cli/sub-dir', - ); - expect(result.isValid).toBe(false); - expect(result.error).toContain('Failed to connect'); + afterEach(() => { + vi.restoreAllMocks(); }); - it('should return invalid if GEMINI_CLI_IDE_WORKSPACE_PATH is empty', () => { - const result = IdeClient.validateWorkspacePath( - '', - 'VS Code', - '/Users/person/gemini-cli/sub-dir', - ); - expect(result.isValid).toBe(false); - expect(result.error).toContain('please open a workspace folder'); - }); + describe('connect', () => { + it('should connect using HTTP when port is provided in config file', async () => { + const config = { port: '8080' }; + vi.mocked(fs.promises.readFile).mockResolvedValue(JSON.stringify(config)); - it('should return invalid if cwd is not within the IDE workspace path', () => { - const result = IdeClient.validateWorkspacePath( - '/some/other/path', - 'VS Code', - '/Users/person/gemini-cli/sub-dir', - ); - expect(result.isValid).toBe(false); - expect(result.error).toContain('Directory mismatch'); - }); + const ideClient = IdeClient.getInstance(); + await ideClient.connect(); - it('should handle multiple workspace paths and return valid', () => { - const result = IdeClient.validateWorkspacePath( - ['/some/other/path', '/Users/person/gemini-cli'].join(path.delimiter), - 'VS Code', - '/Users/person/gemini-cli/sub-dir', - ); - expect(result.isValid).toBe(true); - }); + expect(fs.promises.readFile).toHaveBeenCalledWith( + path.join('/tmp', 'gemini-ide-server-12345.json'), + 'utf8', + ); + expect(StreamableHTTPClientTransport).toHaveBeenCalledWith( + new URL('http://localhost:8080/mcp'), + expect.any(Object), + ); + expect(mockClient.connect).toHaveBeenCalledWith(mockHttpTransport); + expect(ideClient.getConnectionStatus().status).toBe( + IDEConnectionStatus.Connected, + ); + }); - it('should return invalid if cwd is not in any of the multiple workspace paths', () => { - const result = IdeClient.validateWorkspacePath( - ['/some/other/path', '/another/path'].join(path.delimiter), - 'VS Code', - '/Users/person/gemini-cli/sub-dir', - ); - expect(result.isValid).toBe(false); - expect(result.error).toContain('Directory mismatch'); - }); + it('should connect using stdio when stdio config is provided in file', async () => { + const config = { stdio: { command: 'test-cmd', args: ['--foo'] } }; + vi.mocked(fs.promises.readFile).mockResolvedValue(JSON.stringify(config)); - it.skipIf(process.platform !== 'win32')('should handle windows paths', () => { - const result = IdeClient.validateWorkspacePath( - 'c:/some/other/path;d:/Users/person/gemini-cli', - 'VS Code', - 'd:/Users/person/gemini-cli/sub-dir', - ); - expect(result.isValid).toBe(true); + const ideClient = IdeClient.getInstance(); + await ideClient.connect(); + + expect(StdioClientTransport).toHaveBeenCalledWith({ + command: 'test-cmd', + args: ['--foo'], + }); + expect(mockClient.connect).toHaveBeenCalledWith(mockStdioTransport); + expect(ideClient.getConnectionStatus().status).toBe( + IDEConnectionStatus.Connected, + ); + }); + + it('should prioritize port over stdio when both are in config file', async () => { + const config = { + port: '8080', + stdio: { command: 'test-cmd', args: ['--foo'] }, + }; + vi.mocked(fs.promises.readFile).mockResolvedValue(JSON.stringify(config)); + + const ideClient = IdeClient.getInstance(); + await ideClient.connect(); + + expect(StreamableHTTPClientTransport).toHaveBeenCalled(); + expect(StdioClientTransport).not.toHaveBeenCalled(); + expect(ideClient.getConnectionStatus().status).toBe( + IDEConnectionStatus.Connected, + ); + }); + + it('should connect using HTTP when port is provided in environment variables', async () => { + vi.mocked(fs.promises.readFile).mockRejectedValue( + new Error('File not found'), + ); + process.env['GEMINI_CLI_IDE_SERVER_PORT'] = '9090'; + + const ideClient = IdeClient.getInstance(); + await ideClient.connect(); + + expect(StreamableHTTPClientTransport).toHaveBeenCalledWith( + new URL('http://localhost:9090/mcp'), + expect.any(Object), + ); + expect(mockClient.connect).toHaveBeenCalledWith(mockHttpTransport); + expect(ideClient.getConnectionStatus().status).toBe( + IDEConnectionStatus.Connected, + ); + }); + + it('should connect using stdio when stdio config is in environment variables', async () => { + vi.mocked(fs.promises.readFile).mockRejectedValue( + new Error('File not found'), + ); + process.env['GEMINI_CLI_IDE_SERVER_STDIO_COMMAND'] = 'env-cmd'; + process.env['GEMINI_CLI_IDE_SERVER_STDIO_ARGS'] = '["--bar"]'; + + const ideClient = IdeClient.getInstance(); + await ideClient.connect(); + + expect(StdioClientTransport).toHaveBeenCalledWith({ + command: 'env-cmd', + args: ['--bar'], + }); + expect(mockClient.connect).toHaveBeenCalledWith(mockStdioTransport); + expect(ideClient.getConnectionStatus().status).toBe( + IDEConnectionStatus.Connected, + ); + }); + + it('should prioritize file config over environment variables', async () => { + const config = { port: '8080' }; + vi.mocked(fs.promises.readFile).mockResolvedValue(JSON.stringify(config)); + process.env['GEMINI_CLI_IDE_SERVER_PORT'] = '9090'; + + const ideClient = IdeClient.getInstance(); + await ideClient.connect(); + + expect(StreamableHTTPClientTransport).toHaveBeenCalledWith( + new URL('http://localhost:8080/mcp'), + expect.any(Object), + ); + expect(ideClient.getConnectionStatus().status).toBe( + IDEConnectionStatus.Connected, + ); + }); + + it('should be disconnected if no config is found', async () => { + vi.mocked(fs.promises.readFile).mockRejectedValue( + new Error('File not found'), + ); + + const ideClient = IdeClient.getInstance(); + await ideClient.connect(); + + expect(StreamableHTTPClientTransport).not.toHaveBeenCalled(); + expect(StdioClientTransport).not.toHaveBeenCalled(); + expect(ideClient.getConnectionStatus().status).toBe( + IDEConnectionStatus.Disconnected, + ); + expect(ideClient.getConnectionStatus().details).toContain( + 'Failed to connect', + ); + }); }); }); diff --git a/packages/core/src/ide/ide-client.ts b/packages/core/src/ide/ide-client.ts index d6b1d0d2..0f8536aa 100644 --- a/packages/core/src/ide/ide-client.ts +++ b/packages/core/src/ide/ide-client.ts @@ -18,6 +18,7 @@ import { import { getIdeProcessId } from './process-utils.js'; import { Client } from '@modelcontextprotocol/sdk/client/index.js'; import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js'; +import { StdioClientTransport } from '@modelcontextprotocol/sdk/client/stdio.js'; import * as os from 'node:os'; import * as path from 'node:path'; import { EnvHttpProxyAgent } from 'undici'; @@ -40,6 +41,16 @@ export enum IDEConnectionStatus { Connecting = 'connecting', } +type StdioConfig = { + command: string; + args: string[]; +}; + +type ConnectionConfig = { + port?: string; + stdio?: StdioConfig; +}; + function getRealPath(path: string): string { try { return fs.realpathSync(path); @@ -104,9 +115,9 @@ export class IdeClient { this.setState(IDEConnectionStatus.Connecting); - const ideInfoFromFile = await this.getIdeInfoFromFile(); + const configFromFile = await this.getConnectionConfigFromFile(); const workspacePath = - ideInfoFromFile.workspacePath ?? + configFromFile?.workspacePath ?? process.env['GEMINI_CLI_IDE_WORKSPACE_PATH']; const { isValid, error } = IdeClient.validateWorkspacePath( @@ -120,17 +131,36 @@ export class IdeClient { return; } - const portFromFile = ideInfoFromFile.port; - if (portFromFile) { - const connected = await this.establishConnection(portFromFile); - if (connected) { - return; + if (configFromFile) { + if (configFromFile.port) { + const connected = await this.establishHttpConnection( + configFromFile.port, + ); + if (connected) { + return; + } + } + if (configFromFile.stdio) { + const connected = await this.establishStdioConnection( + configFromFile.stdio, + ); + if (connected) { + return; + } } } const portFromEnv = this.getPortFromEnv(); if (portFromEnv) { - const connected = await this.establishConnection(portFromEnv); + const connected = await this.establishHttpConnection(portFromEnv); + if (connected) { + return; + } + } + + const stdioConfigFromEnv = this.getStdioConfigFromEnv(); + if (stdioConfigFromEnv) { + const connected = await this.establishStdioConnection(stdioConfigFromEnv); if (connected) { return; } @@ -316,10 +346,35 @@ export class IdeClient { return port; } - private async getIdeInfoFromFile(): Promise<{ - port?: string; - workspacePath?: string; - }> { + private getStdioConfigFromEnv(): StdioConfig | undefined { + const command = process.env['GEMINI_CLI_IDE_SERVER_STDIO_COMMAND']; + if (!command) { + return undefined; + } + + const argsStr = process.env['GEMINI_CLI_IDE_SERVER_STDIO_ARGS']; + let args: string[] = []; + if (argsStr) { + try { + const parsedArgs = JSON.parse(argsStr); + if (Array.isArray(parsedArgs)) { + args = parsedArgs; + } else { + logger.error( + 'GEMINI_CLI_IDE_SERVER_STDIO_ARGS must be a JSON array string.', + ); + } + } catch (e) { + logger.error('Failed to parse GEMINI_CLI_IDE_SERVER_STDIO_ARGS:', e); + } + } + + return { command, args }; + } + + private async getConnectionConfigFromFile(): Promise< + (ConnectionConfig & { workspacePath?: string }) | undefined + > { try { const ideProcessId = await getIdeProcessId(); const portFile = path.join( @@ -327,13 +382,9 @@ export class IdeClient { `gemini-ide-server-${ideProcessId}.json`, ); const portFileContents = await fs.promises.readFile(portFile, 'utf8'); - const ideInfo = JSON.parse(portFileContents); - return { - port: ideInfo?.port?.toString(), - workspacePath: ideInfo?.workspacePath, - }; + return JSON.parse(portFileContents); } catch (_) { - return {}; + return undefined; } } @@ -414,9 +465,10 @@ export class IdeClient { ); } - private async establishConnection(port: string): Promise { + private async establishHttpConnection(port: string): Promise { let transport: StreamableHTTPClientTransport | undefined; try { + logger.debug('Attempting to connect to IDE via HTTP SSE'); this.client = new Client({ name: 'streamable-http-client', // TODO(#3487): use the CLI version here. @@ -443,6 +495,39 @@ export class IdeClient { return false; } } + + private async establishStdioConnection({ + command, + args, + }: StdioConfig): Promise { + let transport: StdioClientTransport | undefined; + try { + logger.debug('Attempting to connect to IDE via stdio'); + this.client = new Client({ + name: 'stdio-client', + // TODO(#3487): use the CLI version here. + version: '1.0.0', + }); + + transport = new StdioClientTransport({ + command, + args, + }); + await this.client.connect(transport); + this.registerClientHandlers(); + this.setState(IDEConnectionStatus.Connected); + return true; + } catch (_error) { + if (transport) { + try { + await transport.close(); + } catch (closeError) { + logger.debug('Failed to close transport:', closeError); + } + } + return false; + } + } } function getIdeServerHost() {