refactor: Unify file modification confirmation state

- Modifies `EditTool` and `WriteFileTool` to share a single confirmation preference.
- The "Always Proceed" choice for file modifications is now stored in `Config.alwaysSkipModificationConfirmation`.
- This ensures that if a user chooses to always skip confirmation for one file modification tool, this preference is respected by the other.
- `WriteFileTool` constructor now accepts `Config` instead of `targetDir` to facilitate this shared state.
- Tests updated to reflect the new shared confirmation logic.

Fixes https://b.corp.google.com/issues/415897960
This commit is contained in:
Taylor Mullen 2025-05-16 23:33:12 -07:00 committed by N. Taylor Mullen
parent 58e0224061
commit 5dcdbe64ab
5 changed files with 106 additions and 25 deletions

View File

@ -43,6 +43,7 @@ export class Config {
private readonly userAgent: string,
private userMemory: string = '', // Made mutable for refresh
private geminiMdFileCount: number = 0,
private alwaysSkipModificationConfirmation: boolean = false,
) {
// toolRegistry still needs initialization based on the instance
this.toolRegistry = createToolRegistry(this);
@ -114,6 +115,14 @@ export class Config {
setGeminiMdFileCount(count: number): void {
this.geminiMdFileCount = count;
}
getAlwaysSkipModificationConfirmation(): boolean {
return this.alwaysSkipModificationConfirmation;
}
setAlwaysSkipModificationConfirmation(skip: boolean): void {
this.alwaysSkipModificationConfirmation = skip;
}
}
function findEnvFile(startDir: string): string | null {
@ -159,6 +168,7 @@ export function createServerConfig(
userAgent?: string,
userMemory?: string,
geminiMdFileCount?: number,
alwaysSkipModificationConfirmation?: boolean,
): Config {
return new Config(
apiKey,
@ -175,6 +185,7 @@ export function createServerConfig(
userAgent ?? 'GeminiCLI/unknown', // Default user agent
userMemory ?? '',
geminiMdFileCount ?? 0,
alwaysSkipModificationConfirmation ?? false,
);
}
@ -188,7 +199,7 @@ function createToolRegistry(config: Config): ToolRegistry {
new GrepTool(targetDir),
new GlobTool(targetDir),
new EditTool(config),
new WriteFileTool(targetDir),
new WriteFileTool(config),
new WebFetchTool(),
new ReadManyFilesTool(targetDir),
new ShellTool(config),

View File

@ -6,7 +6,7 @@
/* eslint-disable @typescript-eslint/no-explicit-any */
import { describe, it, expect, beforeEach, afterEach, vi } from 'vitest';
import { describe, it, expect, beforeEach, afterEach, vi, Mock } from 'vitest';
import { EditTool, EditToolParams } from './edit.js';
import { FileDiff } from './tools.js';
import path from 'path';
@ -38,10 +38,37 @@ describe('EditTool', () => {
mockConfig = {
getTargetDir: () => rootDir,
getGeminiConfig: () => ({ apiKey: 'test-api-key' }),
getAlwaysSkipModificationConfirmation: vi.fn(() => false),
setAlwaysSkipModificationConfirmation: vi.fn(),
// getGeminiConfig: () => ({ apiKey: 'test-api-key' }), // This was not a real Config method
// Add other properties/methods of Config if EditTool uses them
// Minimal other methods to satisfy Config type if needed by EditTool constructor or other direct uses:
getApiKey: () => 'test-api-key',
getModel: () => 'test-model',
getSandbox: () => false,
getDebugMode: () => false,
getQuestion: () => undefined,
getFullContext: () => false,
getToolDiscoveryCommand: () => undefined,
getToolCallCommand: () => undefined,
getMcpServerCommand: () => undefined,
getMcpServers: () => undefined,
getUserAgent: () => 'test-agent',
getUserMemory: () => '',
setUserMemory: vi.fn(),
getGeminiMdFileCount: () => 0,
setGeminiMdFileCount: vi.fn(),
getToolRegistry: () => ({}) as any, // Minimal mock for ToolRegistry
} as unknown as Config;
// Reset mocks before each test
(mockConfig.getAlwaysSkipModificationConfirmation as Mock).mockClear();
(mockConfig.setAlwaysSkipModificationConfirmation as Mock).mockClear();
// Default to not skipping confirmation
(mockConfig.getAlwaysSkipModificationConfirmation as Mock).mockReturnValue(
false,
);
// Reset mocks and set default implementation for ensureCorrectEdit
mockEnsureCorrectEdit.mockReset();
mockEnsureCorrectEdit.mockImplementation(async (currentContent, params) => {
@ -290,9 +317,10 @@ describe('EditTool', () => {
new_string: fileContent,
};
(tool as any).shouldAlwaysEdit = true;
(
mockConfig.getAlwaysSkipModificationConfirmation as Mock
).mockReturnValueOnce(true);
const result = await tool.execute(params, new AbortController().signal);
(tool as any).shouldAlwaysEdit = false;
expect(result.llmContent).toMatch(/Created new file/);
expect(fs.existsSync(newFilePath)).toBe(true);

View File

@ -56,7 +56,7 @@ interface CalculatedEdit {
*/
export class EditTool extends BaseTool<EditToolParams, ToolResult> {
static readonly Name = 'replace';
private shouldAlwaysEdit = false;
private readonly config: Config;
private readonly rootDirectory: string;
private readonly client: GeminiClient;
@ -98,8 +98,9 @@ Expectation for parameters:
type: 'object',
},
);
this.rootDirectory = path.resolve(config.getTargetDir());
this.client = new GeminiClient(config);
this.config = config;
this.rootDirectory = path.resolve(this.config.getTargetDir());
this.client = new GeminiClient(this.config);
}
/**
@ -234,7 +235,7 @@ Expectation for parameters:
async shouldConfirmExecute(
params: EditToolParams,
): Promise<ToolCallConfirmationDetails | false> {
if (this.shouldAlwaysEdit) {
if (this.config.getAlwaysSkipModificationConfirmation()) {
return false;
}
const validationError = this.validateToolParams(params);
@ -295,7 +296,7 @@ Expectation for parameters:
fileDiff,
onConfirm: async (outcome: ToolConfirmationOutcome) => {
if (outcome === ToolConfirmationOutcome.ProceedAlways) {
this.shouldAlwaysEdit = true;
this.config.setAlwaysSkipModificationConfirmation(true);
}
},
};

View File

@ -4,18 +4,48 @@
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect, beforeEach, afterEach } from 'vitest';
import { describe, it, expect, beforeEach, afterEach, vi } from 'vitest';
import { WriteFileTool } from './write-file.js';
import { FileDiff, ToolConfirmationOutcome } from './tools.js';
import { Config } from '../config/config.js';
import { ToolRegistry } from './tool-registry.js'; // Added import
import path from 'path';
import fs from 'fs';
import os from 'os';
const rootDir = path.resolve(os.tmpdir(), 'gemini-cli-test-root');
// Mock Config
const mockConfigInternal = {
getTargetDir: () => rootDir,
getAlwaysSkipModificationConfirmation: vi.fn(() => false),
setAlwaysSkipModificationConfirmation: vi.fn(),
getApiKey: () => 'test-key',
getModel: () => 'test-model',
getSandbox: () => false,
getDebugMode: () => false,
getQuestion: () => undefined,
getFullContext: () => false,
getToolDiscoveryCommand: () => undefined,
getToolCallCommand: () => undefined,
getMcpServerCommand: () => undefined,
getMcpServers: () => undefined,
getUserAgent: () => 'test-agent',
getUserMemory: () => '',
setUserMemory: vi.fn(),
getGeminiMdFileCount: () => 0,
setGeminiMdFileCount: vi.fn(),
getToolRegistry: () =>
({
registerTool: vi.fn(),
discoverTools: vi.fn(),
}) as unknown as ToolRegistry,
};
const mockConfig = mockConfigInternal as unknown as Config;
describe('WriteFileTool', () => {
let tool: WriteFileTool;
let tempDir: string;
// Using a subdirectory within the OS temp directory for the root to avoid potential permission issues.
const rootDir = path.resolve(os.tmpdir(), 'gemini-cli-test-root');
beforeEach(() => {
// Create a unique temporary directory for files created outside the root (for testing boundary conditions)
@ -26,7 +56,12 @@ describe('WriteFileTool', () => {
if (!fs.existsSync(rootDir)) {
fs.mkdirSync(rootDir, { recursive: true });
}
tool = new WriteFileTool(rootDir);
tool = new WriteFileTool(mockConfig);
// Reset mocks before each test that might use them for confirmation logic
mockConfigInternal.getAlwaysSkipModificationConfirmation.mockReturnValue(
false,
);
mockConfigInternal.setAlwaysSkipModificationConfirmation.mockClear();
});
afterEach(() => {

View File

@ -7,6 +7,7 @@
import fs from 'fs';
import path from 'path';
import * as Diff from 'diff';
import { Config } from '../config/config.js';
import {
BaseTool,
ToolResult,
@ -15,9 +16,10 @@ import {
ToolConfirmationOutcome,
ToolCallConfirmationDetails,
} from './tools.js';
import { SchemaValidator } from '../utils/schemaValidator.js'; // Updated import
import { makeRelative, shortenPath } from '../utils/paths.js'; // Updated import
import { SchemaValidator } from '../utils/schemaValidator.js';
import { makeRelative, shortenPath } from '../utils/paths.js';
import { isNodeError } from '../utils/errors.js';
/**
* Parameters for the WriteFile tool
*/
@ -38,9 +40,8 @@ export interface WriteFileToolParams {
*/
export class WriteFileTool extends BaseTool<WriteFileToolParams, ToolResult> {
static readonly Name: string = 'write_file';
private shouldAlwaysWrite = false;
constructor(private readonly rootDirectory: string) {
constructor(private readonly config: Config) {
super(
WriteFileTool.Name,
'WriteFile',
@ -61,12 +62,11 @@ export class WriteFileTool extends BaseTool<WriteFileToolParams, ToolResult> {
type: 'object',
},
);
this.rootDirectory = path.resolve(rootDirectory);
}
private isWithinRoot(pathToCheck: string): boolean {
const normalizedPath = path.normalize(pathToCheck);
const normalizedRoot = path.normalize(this.rootDirectory);
const normalizedRoot = path.normalize(this.config.getTargetDir());
const rootWithSep = normalizedRoot.endsWith(path.sep)
? normalizedRoot
: normalizedRoot + path.sep;
@ -90,13 +90,16 @@ export class WriteFileTool extends BaseTool<WriteFileToolParams, ToolResult> {
return `File path must be absolute: ${params.file_path}`;
}
if (!this.isWithinRoot(params.file_path)) {
return `File path must be within the root directory (${this.rootDirectory}): ${params.file_path}`;
return `File path must be within the root directory (${this.config.getTargetDir()}): ${params.file_path}`;
}
return null;
}
getDescription(params: WriteFileToolParams): string {
const relativePath = makeRelative(params.file_path, this.rootDirectory);
const relativePath = makeRelative(
params.file_path,
this.config.getTargetDir(),
);
return `Writing to ${shortenPath(relativePath)}`;
}
@ -106,7 +109,7 @@ export class WriteFileTool extends BaseTool<WriteFileToolParams, ToolResult> {
async shouldConfirmExecute(
params: WriteFileToolParams,
): Promise<ToolCallConfirmationDetails | false> {
if (this.shouldAlwaysWrite) {
if (this.config.getAlwaysSkipModificationConfirmation()) {
return false;
}
@ -118,7 +121,10 @@ export class WriteFileTool extends BaseTool<WriteFileToolParams, ToolResult> {
return false;
}
const relativePath = makeRelative(params.file_path, this.rootDirectory);
const relativePath = makeRelative(
params.file_path,
this.config.getTargetDir(),
);
const fileName = path.basename(params.file_path);
let currentContent = '';
@ -143,7 +149,7 @@ export class WriteFileTool extends BaseTool<WriteFileToolParams, ToolResult> {
fileDiff,
onConfirm: async (outcome: ToolConfirmationOutcome) => {
if (outcome === ToolConfirmationOutcome.ProceedAlways) {
this.shouldAlwaysWrite = true;
this.config.setAlwaysSkipModificationConfirmation(true);
}
},
};