From 5dcdbe64ab6ef2ca2692e75b8600b8726ac72178 Mon Sep 17 00:00:00 2001 From: Taylor Mullen Date: Fri, 16 May 2025 23:33:12 -0700 Subject: [PATCH] refactor: Unify file modification confirmation state - Modifies `EditTool` and `WriteFileTool` to share a single confirmation preference. - The "Always Proceed" choice for file modifications is now stored in `Config.alwaysSkipModificationConfirmation`. - This ensures that if a user chooses to always skip confirmation for one file modification tool, this preference is respected by the other. - `WriteFileTool` constructor now accepts `Config` instead of `targetDir` to facilitate this shared state. - Tests updated to reflect the new shared confirmation logic. Fixes https://b.corp.google.com/issues/415897960 --- packages/server/src/config/config.ts | 13 +++++- packages/server/src/tools/edit.test.ts | 36 ++++++++++++++-- packages/server/src/tools/edit.ts | 11 ++--- packages/server/src/tools/write-file.test.ts | 43 ++++++++++++++++++-- packages/server/src/tools/write-file.ts | 28 ++++++++----- 5 files changed, 106 insertions(+), 25 deletions(-) diff --git a/packages/server/src/config/config.ts b/packages/server/src/config/config.ts index 8b9648c4..fdd7973e 100644 --- a/packages/server/src/config/config.ts +++ b/packages/server/src/config/config.ts @@ -43,6 +43,7 @@ export class Config { private readonly userAgent: string, private userMemory: string = '', // Made mutable for refresh private geminiMdFileCount: number = 0, + private alwaysSkipModificationConfirmation: boolean = false, ) { // toolRegistry still needs initialization based on the instance this.toolRegistry = createToolRegistry(this); @@ -114,6 +115,14 @@ export class Config { setGeminiMdFileCount(count: number): void { this.geminiMdFileCount = count; } + + getAlwaysSkipModificationConfirmation(): boolean { + return this.alwaysSkipModificationConfirmation; + } + + setAlwaysSkipModificationConfirmation(skip: boolean): void { + this.alwaysSkipModificationConfirmation = skip; + } } function findEnvFile(startDir: string): string | null { @@ -159,6 +168,7 @@ export function createServerConfig( userAgent?: string, userMemory?: string, geminiMdFileCount?: number, + alwaysSkipModificationConfirmation?: boolean, ): Config { return new Config( apiKey, @@ -175,6 +185,7 @@ export function createServerConfig( userAgent ?? 'GeminiCLI/unknown', // Default user agent userMemory ?? '', geminiMdFileCount ?? 0, + alwaysSkipModificationConfirmation ?? false, ); } @@ -188,7 +199,7 @@ function createToolRegistry(config: Config): ToolRegistry { new GrepTool(targetDir), new GlobTool(targetDir), new EditTool(config), - new WriteFileTool(targetDir), + new WriteFileTool(config), new WebFetchTool(), new ReadManyFilesTool(targetDir), new ShellTool(config), diff --git a/packages/server/src/tools/edit.test.ts b/packages/server/src/tools/edit.test.ts index f0650d70..a552cf53 100644 --- a/packages/server/src/tools/edit.test.ts +++ b/packages/server/src/tools/edit.test.ts @@ -6,7 +6,7 @@ /* eslint-disable @typescript-eslint/no-explicit-any */ -import { describe, it, expect, beforeEach, afterEach, vi } from 'vitest'; +import { describe, it, expect, beforeEach, afterEach, vi, Mock } from 'vitest'; import { EditTool, EditToolParams } from './edit.js'; import { FileDiff } from './tools.js'; import path from 'path'; @@ -38,10 +38,37 @@ describe('EditTool', () => { mockConfig = { getTargetDir: () => rootDir, - getGeminiConfig: () => ({ apiKey: 'test-api-key' }), + getAlwaysSkipModificationConfirmation: vi.fn(() => false), + setAlwaysSkipModificationConfirmation: vi.fn(), + // getGeminiConfig: () => ({ apiKey: 'test-api-key' }), // This was not a real Config method // Add other properties/methods of Config if EditTool uses them + // Minimal other methods to satisfy Config type if needed by EditTool constructor or other direct uses: + getApiKey: () => 'test-api-key', + getModel: () => 'test-model', + getSandbox: () => false, + getDebugMode: () => false, + getQuestion: () => undefined, + getFullContext: () => false, + getToolDiscoveryCommand: () => undefined, + getToolCallCommand: () => undefined, + getMcpServerCommand: () => undefined, + getMcpServers: () => undefined, + getUserAgent: () => 'test-agent', + getUserMemory: () => '', + setUserMemory: vi.fn(), + getGeminiMdFileCount: () => 0, + setGeminiMdFileCount: vi.fn(), + getToolRegistry: () => ({}) as any, // Minimal mock for ToolRegistry } as unknown as Config; + // Reset mocks before each test + (mockConfig.getAlwaysSkipModificationConfirmation as Mock).mockClear(); + (mockConfig.setAlwaysSkipModificationConfirmation as Mock).mockClear(); + // Default to not skipping confirmation + (mockConfig.getAlwaysSkipModificationConfirmation as Mock).mockReturnValue( + false, + ); + // Reset mocks and set default implementation for ensureCorrectEdit mockEnsureCorrectEdit.mockReset(); mockEnsureCorrectEdit.mockImplementation(async (currentContent, params) => { @@ -290,9 +317,10 @@ describe('EditTool', () => { new_string: fileContent, }; - (tool as any).shouldAlwaysEdit = true; + ( + mockConfig.getAlwaysSkipModificationConfirmation as Mock + ).mockReturnValueOnce(true); const result = await tool.execute(params, new AbortController().signal); - (tool as any).shouldAlwaysEdit = false; expect(result.llmContent).toMatch(/Created new file/); expect(fs.existsSync(newFilePath)).toBe(true); diff --git a/packages/server/src/tools/edit.ts b/packages/server/src/tools/edit.ts index f7c911ec..7b327778 100644 --- a/packages/server/src/tools/edit.ts +++ b/packages/server/src/tools/edit.ts @@ -56,7 +56,7 @@ interface CalculatedEdit { */ export class EditTool extends BaseTool { static readonly Name = 'replace'; - private shouldAlwaysEdit = false; + private readonly config: Config; private readonly rootDirectory: string; private readonly client: GeminiClient; @@ -98,8 +98,9 @@ Expectation for parameters: type: 'object', }, ); - this.rootDirectory = path.resolve(config.getTargetDir()); - this.client = new GeminiClient(config); + this.config = config; + this.rootDirectory = path.resolve(this.config.getTargetDir()); + this.client = new GeminiClient(this.config); } /** @@ -234,7 +235,7 @@ Expectation for parameters: async shouldConfirmExecute( params: EditToolParams, ): Promise { - if (this.shouldAlwaysEdit) { + if (this.config.getAlwaysSkipModificationConfirmation()) { return false; } const validationError = this.validateToolParams(params); @@ -295,7 +296,7 @@ Expectation for parameters: fileDiff, onConfirm: async (outcome: ToolConfirmationOutcome) => { if (outcome === ToolConfirmationOutcome.ProceedAlways) { - this.shouldAlwaysEdit = true; + this.config.setAlwaysSkipModificationConfirmation(true); } }, }; diff --git a/packages/server/src/tools/write-file.test.ts b/packages/server/src/tools/write-file.test.ts index 25d1e998..fc2ca61b 100644 --- a/packages/server/src/tools/write-file.test.ts +++ b/packages/server/src/tools/write-file.test.ts @@ -4,18 +4,48 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { describe, it, expect, beforeEach, afterEach } from 'vitest'; +import { describe, it, expect, beforeEach, afterEach, vi } from 'vitest'; import { WriteFileTool } from './write-file.js'; import { FileDiff, ToolConfirmationOutcome } from './tools.js'; +import { Config } from '../config/config.js'; +import { ToolRegistry } from './tool-registry.js'; // Added import import path from 'path'; import fs from 'fs'; import os from 'os'; +const rootDir = path.resolve(os.tmpdir(), 'gemini-cli-test-root'); + +// Mock Config +const mockConfigInternal = { + getTargetDir: () => rootDir, + getAlwaysSkipModificationConfirmation: vi.fn(() => false), + setAlwaysSkipModificationConfirmation: vi.fn(), + getApiKey: () => 'test-key', + getModel: () => 'test-model', + getSandbox: () => false, + getDebugMode: () => false, + getQuestion: () => undefined, + getFullContext: () => false, + getToolDiscoveryCommand: () => undefined, + getToolCallCommand: () => undefined, + getMcpServerCommand: () => undefined, + getMcpServers: () => undefined, + getUserAgent: () => 'test-agent', + getUserMemory: () => '', + setUserMemory: vi.fn(), + getGeminiMdFileCount: () => 0, + setGeminiMdFileCount: vi.fn(), + getToolRegistry: () => + ({ + registerTool: vi.fn(), + discoverTools: vi.fn(), + }) as unknown as ToolRegistry, +}; +const mockConfig = mockConfigInternal as unknown as Config; + describe('WriteFileTool', () => { let tool: WriteFileTool; let tempDir: string; - // Using a subdirectory within the OS temp directory for the root to avoid potential permission issues. - const rootDir = path.resolve(os.tmpdir(), 'gemini-cli-test-root'); beforeEach(() => { // Create a unique temporary directory for files created outside the root (for testing boundary conditions) @@ -26,7 +56,12 @@ describe('WriteFileTool', () => { if (!fs.existsSync(rootDir)) { fs.mkdirSync(rootDir, { recursive: true }); } - tool = new WriteFileTool(rootDir); + tool = new WriteFileTool(mockConfig); + // Reset mocks before each test that might use them for confirmation logic + mockConfigInternal.getAlwaysSkipModificationConfirmation.mockReturnValue( + false, + ); + mockConfigInternal.setAlwaysSkipModificationConfirmation.mockClear(); }); afterEach(() => { diff --git a/packages/server/src/tools/write-file.ts b/packages/server/src/tools/write-file.ts index 21178b5b..2979ffb8 100644 --- a/packages/server/src/tools/write-file.ts +++ b/packages/server/src/tools/write-file.ts @@ -7,6 +7,7 @@ import fs from 'fs'; import path from 'path'; import * as Diff from 'diff'; +import { Config } from '../config/config.js'; import { BaseTool, ToolResult, @@ -15,9 +16,10 @@ import { ToolConfirmationOutcome, ToolCallConfirmationDetails, } from './tools.js'; -import { SchemaValidator } from '../utils/schemaValidator.js'; // Updated import -import { makeRelative, shortenPath } from '../utils/paths.js'; // Updated import +import { SchemaValidator } from '../utils/schemaValidator.js'; +import { makeRelative, shortenPath } from '../utils/paths.js'; import { isNodeError } from '../utils/errors.js'; + /** * Parameters for the WriteFile tool */ @@ -38,9 +40,8 @@ export interface WriteFileToolParams { */ export class WriteFileTool extends BaseTool { static readonly Name: string = 'write_file'; - private shouldAlwaysWrite = false; - constructor(private readonly rootDirectory: string) { + constructor(private readonly config: Config) { super( WriteFileTool.Name, 'WriteFile', @@ -61,12 +62,11 @@ export class WriteFileTool extends BaseTool { type: 'object', }, ); - this.rootDirectory = path.resolve(rootDirectory); } private isWithinRoot(pathToCheck: string): boolean { const normalizedPath = path.normalize(pathToCheck); - const normalizedRoot = path.normalize(this.rootDirectory); + const normalizedRoot = path.normalize(this.config.getTargetDir()); const rootWithSep = normalizedRoot.endsWith(path.sep) ? normalizedRoot : normalizedRoot + path.sep; @@ -90,13 +90,16 @@ export class WriteFileTool extends BaseTool { return `File path must be absolute: ${params.file_path}`; } if (!this.isWithinRoot(params.file_path)) { - return `File path must be within the root directory (${this.rootDirectory}): ${params.file_path}`; + return `File path must be within the root directory (${this.config.getTargetDir()}): ${params.file_path}`; } return null; } getDescription(params: WriteFileToolParams): string { - const relativePath = makeRelative(params.file_path, this.rootDirectory); + const relativePath = makeRelative( + params.file_path, + this.config.getTargetDir(), + ); return `Writing to ${shortenPath(relativePath)}`; } @@ -106,7 +109,7 @@ export class WriteFileTool extends BaseTool { async shouldConfirmExecute( params: WriteFileToolParams, ): Promise { - if (this.shouldAlwaysWrite) { + if (this.config.getAlwaysSkipModificationConfirmation()) { return false; } @@ -118,7 +121,10 @@ export class WriteFileTool extends BaseTool { return false; } - const relativePath = makeRelative(params.file_path, this.rootDirectory); + const relativePath = makeRelative( + params.file_path, + this.config.getTargetDir(), + ); const fileName = path.basename(params.file_path); let currentContent = ''; @@ -143,7 +149,7 @@ export class WriteFileTool extends BaseTool { fileDiff, onConfirm: async (outcome: ToolConfirmationOutcome) => { if (outcome === ToolConfirmationOutcome.ProceedAlways) { - this.shouldAlwaysWrite = true; + this.config.setAlwaysSkipModificationConfirmation(true); } }, };