From fde9849d48e3b92377aca2eecfd390ebce288692 Mon Sep 17 00:00:00 2001 From: christine betts Date: Wed, 6 Aug 2025 17:36:05 +0000 Subject: [PATCH] [ide-mode] Add support for in-IDE diff handling in the CLI (#5603) Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- .../messages/ToolConfirmationMessage.tsx | 41 +++++-- .../core/src/core/coreToolScheduler.test.ts | 2 + packages/core/src/core/coreToolScheduler.ts | 24 ++++ packages/core/src/ide/ide-client.ts | 106 +++++++++++++++++- packages/core/src/ide/ideContext.ts | 59 ++++++++++ packages/core/src/tools/edit.ts | 1 + packages/core/src/tools/memoryTool.ts | 1 + packages/core/src/tools/tools.ts | 3 + packages/core/src/tools/write-file.test.ts | 29 ++++- packages/core/src/tools/write-file.ts | 18 +++ .../vscode-ide-companion/src/diff-manager.ts | 69 +++++++----- .../vscode-ide-companion/src/ide-server.ts | 22 ++-- 12 files changed, 323 insertions(+), 52 deletions(-) diff --git a/packages/cli/src/ui/components/messages/ToolConfirmationMessage.tsx b/packages/cli/src/ui/components/messages/ToolConfirmationMessage.tsx index 197a922c..8b7f93d1 100644 --- a/packages/cli/src/ui/components/messages/ToolConfirmationMessage.tsx +++ b/packages/cli/src/ui/components/messages/ToolConfirmationMessage.tsx @@ -33,6 +33,7 @@ export const ToolConfirmationMessage: React.FC< ToolConfirmationMessageProps > = ({ confirmationDetails, + config, isFocused = true, availableTerminalHeight, terminalWidth, @@ -40,14 +41,29 @@ export const ToolConfirmationMessage: React.FC< const { onConfirm } = confirmationDetails; const childWidth = terminalWidth - 2; // 2 for padding + const handleConfirm = async (outcome: ToolConfirmationOutcome) => { + if (confirmationDetails.type === 'edit') { + const ideClient = config?.getIdeClient(); + if (config?.getIdeMode() && config?.getIdeModeFeature()) { + const cliOutcome = + outcome === ToolConfirmationOutcome.Cancel ? 'rejected' : 'accepted'; + await ideClient?.resolveDiffFromCli( + confirmationDetails.filePath, + cliOutcome, + ); + } + } + onConfirm(outcome); + }; + useInput((_, key) => { if (!isFocused) return; if (key.escape) { - onConfirm(ToolConfirmationOutcome.Cancel); + handleConfirm(ToolConfirmationOutcome.Cancel); } }); - const handleSelect = (item: ToolConfirmationOutcome) => onConfirm(item); + const handleSelect = (item: ToolConfirmationOutcome) => handleConfirm(item); let bodyContent: React.ReactNode | null = null; // Removed contextDisplay here let question: string; @@ -85,6 +101,7 @@ export const ToolConfirmationMessage: React.FC< HEIGHT_OPTIONS; return Math.max(availableTerminalHeight - surroundingElementsHeight, 1); } + if (confirmationDetails.type === 'edit') { if (confirmationDetails.isModifying) { return ( @@ -114,15 +131,25 @@ export const ToolConfirmationMessage: React.FC< label: 'Yes, allow always', value: ToolConfirmationOutcome.ProceedAlways, }, - { + ); + if (config?.getIdeMode() && config?.getIdeModeFeature()) { + options.push({ + label: 'No', + value: ToolConfirmationOutcome.Cancel, + }); + } else { + // TODO(chrstnb): support edit tool in IDE mode. + + options.push({ label: 'Modify with external editor', value: ToolConfirmationOutcome.ModifyWithEditor, - }, - { + }); + options.push({ label: 'No, suggest changes (esc)', value: ToolConfirmationOutcome.Cancel, - }, - ); + }); + } + bodyContent = ( { type: 'edit', title: 'Confirm Edit', fileName: 'test.txt', + filePath: 'test.txt', fileDiff: '--- test.txt\n+++ test.txt\n@@ -1,1 +1,1 @@\n-old content\n+new content', originalContent: 'old content', diff --git a/packages/core/src/core/coreToolScheduler.ts b/packages/core/src/core/coreToolScheduler.ts index 5f2cc895..f54aa532 100644 --- a/packages/core/src/core/coreToolScheduler.ts +++ b/packages/core/src/core/coreToolScheduler.ts @@ -476,6 +476,30 @@ export class CoreToolScheduler { ); if (confirmationDetails) { + // Allow IDE to resolve confirmation + if ( + confirmationDetails.type === 'edit' && + confirmationDetails.ideConfirmation + ) { + confirmationDetails.ideConfirmation.then((resolution) => { + if (resolution.status === 'accepted') { + this.handleConfirmationResponse( + reqInfo.callId, + confirmationDetails.onConfirm, + ToolConfirmationOutcome.ProceedOnce, + signal, + ); + } else { + this.handleConfirmationResponse( + reqInfo.callId, + confirmationDetails.onConfirm, + ToolConfirmationOutcome.Cancel, + signal, + ); + } + }); + } + const originalOnConfirm = confirmationDetails.onConfirm; const wrappedConfirmationDetails: ToolCallConfirmationDetails = { ...confirmationDetails, diff --git a/packages/core/src/ide/ide-client.ts b/packages/core/src/ide/ide-client.ts index 8f967147..42b79c44 100644 --- a/packages/core/src/ide/ide-client.ts +++ b/packages/core/src/ide/ide-client.ts @@ -9,7 +9,14 @@ import { DetectedIde, getIdeDisplayName, } from '../ide/detect-ide.js'; -import { ideContext, IdeContextNotificationSchema } from '../ide/ideContext.js'; +import { + ideContext, + IdeContextNotificationSchema, + IdeDiffAcceptedNotificationSchema, + IdeDiffClosedNotificationSchema, + CloseDiffResponseSchema, + DiffUpdateResult, +} from '../ide/ideContext.js'; import { Client } from '@modelcontextprotocol/sdk/client/index.js'; import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js'; @@ -42,6 +49,7 @@ export class IdeClient { }; private readonly currentIde: DetectedIde | undefined; private readonly currentIdeDisplayName: string | undefined; + private diffResponses = new Map void>(); private constructor() { this.currentIde = detectIde(); @@ -77,6 +85,75 @@ export class IdeClient { await this.establishConnection(port); } + /** + * A diff is accepted with any modifications if the user performs one of the + * following actions: + * - Clicks the checkbox icon in the IDE to accept + * - Runs `command+shift+p` > "Gemini CLI: Accept Diff in IDE" to accept + * - Selects "accept" in the CLI UI + * - Saves the file via `ctrl/command+s` + * + * A diff is rejected if the user performs one of the following actions: + * - Clicks the "x" icon in the IDE + * - Runs "Gemini CLI: Close Diff in IDE" + * - Selects "no" in the CLI UI + * - Closes the file + */ + async openDiff( + filePath: string, + newContent?: string, + ): Promise { + return new Promise((resolve, reject) => { + this.diffResponses.set(filePath, resolve); + this.client + ?.callTool({ + name: `openDiff`, + arguments: { + filePath, + newContent, + }, + }) + .catch((err) => { + logger.debug(`callTool for ${filePath} failed:`, err); + reject(err); + }); + }); + } + + async closeDiff(filePath: string): Promise { + try { + const result = await this.client?.callTool({ + name: `closeDiff`, + arguments: { + filePath, + }, + }); + + if (result) { + const parsed = CloseDiffResponseSchema.parse(result); + return parsed.content; + } + } catch (err) { + logger.debug(`callTool for ${filePath} failed:`, err); + } + return; + } + + // Closes the diff. Instead of waiting for a notification, + // manually resolves the diff resolver as the desired outcome. + async resolveDiffFromCli(filePath: string, outcome: 'accepted' | 'rejected') { + const content = await this.closeDiff(filePath); + const resolver = this.diffResponses.get(filePath); + if (resolver) { + if (outcome === 'accepted') { + resolver({ status: 'accepted', content }); + } else { + resolver({ status: 'rejected', content: undefined }); + } + this.diffResponses.delete(filePath); + } + } + disconnect() { this.setState( IDEConnectionStatus.Disconnected, @@ -175,6 +252,33 @@ export class IdeClient { `IDE connection error. The connection was lost unexpectedly. Please try reconnecting by running /ide enable`, ); }; + this.client.setNotificationHandler( + IdeDiffAcceptedNotificationSchema, + (notification) => { + const { filePath, content } = notification.params; + const resolver = this.diffResponses.get(filePath); + if (resolver) { + resolver({ status: 'accepted', content }); + this.diffResponses.delete(filePath); + } else { + logger.debug(`No resolver found for ${filePath}`); + } + }, + ); + + this.client.setNotificationHandler( + IdeDiffClosedNotificationSchema, + (notification) => { + const { filePath } = notification.params; + const resolver = this.diffResponses.get(filePath); + if (resolver) { + resolver({ status: 'rejected', content: undefined }); + this.diffResponses.delete(filePath); + } else { + logger.debug(`No resolver found for ${filePath}`); + } + }, + ); } private async establishConnection(port: string) { diff --git a/packages/core/src/ide/ideContext.ts b/packages/core/src/ide/ideContext.ts index 588e25ee..3052c029 100644 --- a/packages/core/src/ide/ideContext.ts +++ b/packages/core/src/ide/ideContext.ts @@ -36,10 +36,69 @@ export type IdeContext = z.infer; * Zod schema for validating the 'ide/contextUpdate' notification from the IDE. */ export const IdeContextNotificationSchema = z.object({ + jsonrpc: z.literal('2.0'), method: z.literal('ide/contextUpdate'), params: IdeContextSchema, }); +export const IdeDiffAcceptedNotificationSchema = z.object({ + jsonrpc: z.literal('2.0'), + method: z.literal('ide/diffAccepted'), + params: z.object({ + filePath: z.string(), + content: z.string(), + }), +}); + +export const IdeDiffClosedNotificationSchema = z.object({ + jsonrpc: z.literal('2.0'), + method: z.literal('ide/diffClosed'), + params: z.object({ + filePath: z.string(), + content: z.string().optional(), + }), +}); + +export const CloseDiffResponseSchema = z + .object({ + content: z + .array( + z.object({ + text: z.string(), + type: z.literal('text'), + }), + ) + .min(1), + }) + .transform((val, ctx) => { + try { + const parsed = JSON.parse(val.content[0].text); + const innerSchema = z.object({ content: z.string().optional() }); + const validationResult = innerSchema.safeParse(parsed); + if (!validationResult.success) { + validationResult.error.issues.forEach((issue) => ctx.addIssue(issue)); + return z.NEVER; + } + return validationResult.data; + } catch (_) { + ctx.addIssue({ + code: z.ZodIssueCode.custom, + message: 'Invalid JSON in text content', + }); + return z.NEVER; + } + }); + +export type DiffUpdateResult = + | { + status: 'accepted'; + content?: string; + } + | { + status: 'rejected'; + content: undefined; + }; + type IdeContextSubscriber = (ideContext: IdeContext | undefined) => void; /** diff --git a/packages/core/src/tools/edit.ts b/packages/core/src/tools/edit.ts index 25da2292..0d129e42 100644 --- a/packages/core/src/tools/edit.ts +++ b/packages/core/src/tools/edit.ts @@ -332,6 +332,7 @@ Expectation for required parameters: type: 'edit', title: `Confirm Edit: ${shortenPath(makeRelative(params.file_path, this.config.getTargetDir()))}`, fileName, + filePath: params.file_path, fileDiff, originalContent: editData.currentContent, newContent: editData.newContent, diff --git a/packages/core/src/tools/memoryTool.ts b/packages/core/src/tools/memoryTool.ts index 96509f79..847ea5cf 100644 --- a/packages/core/src/tools/memoryTool.ts +++ b/packages/core/src/tools/memoryTool.ts @@ -220,6 +220,7 @@ export class MemoryTool type: 'edit', title: `Confirm Memory Save: ${tildeifyPath(memoryFilePath)}`, fileName: memoryFilePath, + filePath: memoryFilePath, fileDiff, originalContent: currentContent, newContent, diff --git a/packages/core/src/tools/tools.ts b/packages/core/src/tools/tools.ts index 5d9d9253..3404093f 100644 --- a/packages/core/src/tools/tools.ts +++ b/packages/core/src/tools/tools.ts @@ -6,6 +6,7 @@ import { FunctionDeclaration, PartListUnion, Schema } from '@google/genai'; import { ToolErrorType } from './tool-error.js'; +import { DiffUpdateResult } from '../ide/ideContext.js'; /** * Interface representing the base Tool functionality @@ -330,10 +331,12 @@ export interface ToolEditConfirmationDetails { payload?: ToolConfirmationPayload, ) => Promise; fileName: string; + filePath: string; fileDiff: string; originalContent: string | null; newContent: string; isModifying?: boolean; + ideConfirmation?: Promise; } export interface ToolConfirmationPayload { diff --git a/packages/core/src/tools/write-file.test.ts b/packages/core/src/tools/write-file.test.ts index fe662a02..563579bb 100644 --- a/packages/core/src/tools/write-file.test.ts +++ b/packages/core/src/tools/write-file.test.ts @@ -55,6 +55,9 @@ const mockConfigInternal = { getApprovalMode: vi.fn(() => ApprovalMode.DEFAULT), setApprovalMode: vi.fn(), getGeminiClient: vi.fn(), // Initialize as a plain mock function + getIdeClient: vi.fn(), + getIdeMode: vi.fn(() => false), + getIdeModeFeature: vi.fn(() => false), getWorkspaceContext: () => createMockWorkspaceContext(rootDir), getApiKey: () => 'test-key', getModel: () => 'test-model', @@ -110,6 +113,14 @@ describe('WriteFileTool', () => { mockConfigInternal.getGeminiClient.mockReturnValue( mockGeminiClientInstance, ); + mockConfigInternal.getIdeClient.mockReturnValue({ + openDiff: vi.fn(), + closeDiff: vi.fn(), + getIdeContext: vi.fn(), + subscribeToIdeContext: vi.fn(), + isCodeTrackerEnabled: vi.fn(), + getTrackedCode: vi.fn(), + }); tool = new WriteFileTool(mockConfig); @@ -500,7 +511,11 @@ describe('WriteFileTool', () => { params, abortSignal, ); - if (typeof confirmDetails === 'object' && confirmDetails.onConfirm) { + if ( + typeof confirmDetails === 'object' && + 'onConfirm' in confirmDetails && + confirmDetails.onConfirm + ) { await confirmDetails.onConfirm(ToolConfirmationOutcome.ProceedOnce); } @@ -554,7 +569,11 @@ describe('WriteFileTool', () => { params, abortSignal, ); - if (typeof confirmDetails === 'object' && confirmDetails.onConfirm) { + if ( + typeof confirmDetails === 'object' && + 'onConfirm' in confirmDetails && + confirmDetails.onConfirm + ) { await confirmDetails.onConfirm(ToolConfirmationOutcome.ProceedOnce); } @@ -595,7 +614,11 @@ describe('WriteFileTool', () => { params, abortSignal, ); - if (typeof confirmDetails === 'object' && confirmDetails.onConfirm) { + if ( + typeof confirmDetails === 'object' && + 'onConfirm' in confirmDetails && + confirmDetails.onConfirm + ) { await confirmDetails.onConfirm(ToolConfirmationOutcome.ProceedOnce); } diff --git a/packages/core/src/tools/write-file.ts b/packages/core/src/tools/write-file.ts index 1cb1a917..32ecc068 100644 --- a/packages/core/src/tools/write-file.ts +++ b/packages/core/src/tools/write-file.ts @@ -32,6 +32,7 @@ import { recordFileOperationMetric, FileOperation, } from '../telemetry/metrics.js'; +import { IDEConnectionStatus } from '../ide/ide-client.js'; /** * Parameters for the WriteFile tool @@ -184,10 +185,19 @@ export class WriteFileTool DEFAULT_DIFF_OPTIONS, ); + const ideClient = this.config.getIdeClient(); + const ideConfirmation = + this.config.getIdeModeFeature() && + this.config.getIdeMode() && + ideClient.getConnectionStatus().status === IDEConnectionStatus.Connected + ? ideClient.openDiff(params.file_path, correctedContent) + : undefined; + const confirmationDetails: ToolEditConfirmationDetails = { type: 'edit', title: `Confirm Write: ${shortenPath(relativePath)}`, fileName, + filePath: params.file_path, fileDiff, originalContent, newContent: correctedContent, @@ -195,7 +205,15 @@ export class WriteFileTool if (outcome === ToolConfirmationOutcome.ProceedAlways) { this.config.setApprovalMode(ApprovalMode.AUTO_EDIT); } + + if (ideConfirmation) { + const result = await ideConfirmation; + if (result.status === 'accepted' && result.content) { + params.content = result.content; + } + } }, + ideConfirmation, }; return confirmationDetails; } diff --git a/packages/vscode-ide-companion/src/diff-manager.ts b/packages/vscode-ide-companion/src/diff-manager.ts index 159a6101..0dad03a6 100644 --- a/packages/vscode-ide-companion/src/diff-manager.ts +++ b/packages/vscode-ide-companion/src/diff-manager.ts @@ -4,10 +4,14 @@ * SPDX-License-Identifier: Apache-2.0 */ -import * as vscode from 'vscode'; -import * as path from 'node:path'; -import { DIFF_SCHEME } from './extension.js'; +import { + IdeDiffAcceptedNotificationSchema, + IdeDiffClosedNotificationSchema, +} from '@google/gemini-cli-core'; import { type JSONRPCNotification } from '@modelcontextprotocol/sdk/types.js'; +import * as path from 'node:path'; +import * as vscode from 'vscode'; +import { DIFF_SCHEME } from './extension.js'; export class DiffContentProvider implements vscode.TextDocumentContentProvider { private content = new Map(); @@ -126,18 +130,19 @@ export class DiffManager { const rightDoc = await vscode.workspace.openTextDocument(uriToClose); const modifiedContent = rightDoc.getText(); await this.closeDiffEditor(uriToClose); - this.onDidChangeEmitter.fire({ - jsonrpc: '2.0', - method: 'ide/diffClosed', - params: { - filePath, - content: modifiedContent, - }, - }); - vscode.window.showInformationMessage(`Diff for ${filePath} closed.`); - } else { - vscode.window.showWarningMessage(`No open diff found for ${filePath}.`); + this.onDidChangeEmitter.fire( + IdeDiffClosedNotificationSchema.parse({ + jsonrpc: '2.0', + method: 'ide/diffClosed', + params: { + filePath, + content: modifiedContent, + }, + }), + ); + return modifiedContent; } + return; } /** @@ -156,14 +161,16 @@ export class DiffManager { const modifiedContent = rightDoc.getText(); await this.closeDiffEditor(rightDocUri); - this.onDidChangeEmitter.fire({ - jsonrpc: '2.0', - method: 'ide/diffAccepted', - params: { - filePath: diffInfo.originalFilePath, - content: modifiedContent, - }, - }); + this.onDidChangeEmitter.fire( + IdeDiffAcceptedNotificationSchema.parse({ + jsonrpc: '2.0', + method: 'ide/diffAccepted', + params: { + filePath: diffInfo.originalFilePath, + content: modifiedContent, + }, + }), + ); } /** @@ -184,14 +191,16 @@ export class DiffManager { const modifiedContent = rightDoc.getText(); await this.closeDiffEditor(rightDocUri); - this.onDidChangeEmitter.fire({ - jsonrpc: '2.0', - method: 'ide/diffClosed', - params: { - filePath: diffInfo.originalFilePath, - content: modifiedContent, - }, - }); + this.onDidChangeEmitter.fire( + IdeDiffClosedNotificationSchema.parse({ + jsonrpc: '2.0', + method: 'ide/diffClosed', + params: { + filePath: diffInfo.originalFilePath, + content: modifiedContent, + }, + }), + ); } private addDiffDocument(uri: vscode.Uri, diffInfo: DiffInfo) { diff --git a/packages/vscode-ide-companion/src/ide-server.ts b/packages/vscode-ide-companion/src/ide-server.ts index 30215ccc..eec99cb3 100644 --- a/packages/vscode-ide-companion/src/ide-server.ts +++ b/packages/vscode-ide-companion/src/ide-server.ts @@ -5,15 +5,13 @@ */ import * as vscode from 'vscode'; +import { IdeContextNotificationSchema } from '@google/gemini-cli-core'; +import { isInitializeRequest } from '@modelcontextprotocol/sdk/types.js'; import { McpServer } from '@modelcontextprotocol/sdk/server/mcp.js'; import { StreamableHTTPServerTransport } from '@modelcontextprotocol/sdk/server/streamableHttp.js'; -import express, { Request, Response } from 'express'; +import express, { type Request, type Response } from 'express'; import { randomUUID } from 'node:crypto'; -import { - isInitializeRequest, - type JSONRPCNotification, -} from '@modelcontextprotocol/sdk/types.js'; -import { Server as HTTPServer } from 'node:http'; +import { type Server as HTTPServer } from 'node:http'; import { z } from 'zod'; import { DiffManager } from './diff-manager.js'; import { OpenFilesManager } from './open-files-manager.js'; @@ -28,11 +26,12 @@ function sendIdeContextUpdateNotification( ) { const ideContext = openFilesManager.state; - const notification: JSONRPCNotification = { + const notification = IdeContextNotificationSchema.parse({ jsonrpc: '2.0', method: 'ide/contextUpdate', params: ideContext, - }; + }); + log( `Sending IDE context update notification: ${JSON.stringify( notification, @@ -76,7 +75,7 @@ export class IDEServer { }); context.subscriptions.push(onDidChangeSubscription); const onDidChangeDiffSubscription = this.diffManager.onDidChange( - (notification: JSONRPCNotification) => { + (notification) => { for (const transport of Object.values(transports)) { transport.send(notification); } @@ -269,12 +268,13 @@ const createMcpServer = (diffManager: DiffManager) => { }).shape, }, async ({ filePath }: { filePath: string }) => { - await diffManager.closeDiff(filePath); + const content = await diffManager.closeDiff(filePath); + const response = { content: content ?? undefined }; return { content: [ { type: 'text', - text: `Closed diff for ${filePath}`, + text: JSON.stringify(response), }, ], };