diff --git a/packages/core/src/tools/tool-registry.ts b/packages/core/src/tools/tool-registry.ts index 416ee99e..ff155679 100644 --- a/packages/core/src/tools/tool-registry.ts +++ b/packages/core/src/tools/tool-registry.ts @@ -5,7 +5,14 @@ */ import { FunctionDeclaration } from '@google/genai'; -import { AnyDeclarativeTool, Kind, ToolResult, BaseTool } from './tools.js'; +import { + AnyDeclarativeTool, + Kind, + ToolResult, + BaseDeclarativeTool, + BaseToolInvocation, + ToolInvocation, +} from './tools.js'; import { Config } from '../config/config.js'; import { spawn } from 'node:child_process'; import { StringDecoder } from 'node:string_decoder'; @@ -15,46 +22,29 @@ import { parse } from 'shell-quote'; type ToolParams = Record; -export class DiscoveredTool extends BaseTool { +class DiscoveredToolInvocation extends BaseToolInvocation< + ToolParams, + ToolResult +> { constructor( private readonly config: Config, - name: string, - override readonly description: string, - override readonly parameterSchema: Record, + private readonly toolName: string, + params: ToolParams, ) { - const discoveryCmd = config.getToolDiscoveryCommand()!; - const callCommand = config.getToolCallCommand()!; - description += ` - -This tool was discovered from the project by executing the command \`${discoveryCmd}\` on project root. -When called, this tool will execute the command \`${callCommand} ${name}\` on project root. -Tool discovery and call commands can be configured in project or user settings. - -When called, the tool call command is executed as a subprocess. -On success, tool output is returned as a json string. -Otherwise, the following information is returned: - -Stdout: Output on stdout stream. Can be \`(empty)\` or partial. -Stderr: Output on stderr stream. Can be \`(empty)\` or partial. -Error: Error or \`(none)\` if no error was reported for the subprocess. -Exit Code: Exit code or \`(none)\` if terminated by signal. -Signal: Signal number or \`(none)\` if no signal was received. -`; - super( - name, - name, - description, - Kind.Other, - parameterSchema, - false, // isOutputMarkdown - false, // canUpdateOutput - ); + super(params); } - async execute(params: ToolParams): Promise { + getDescription(): string { + return `Calling discovered tool: ${this.toolName}`; + } + + async execute( + _signal: AbortSignal, + _updateOutput?: (output: string) => void, + ): Promise { const callCommand = this.config.getToolCallCommand()!; - const child = spawn(callCommand, [this.name]); - child.stdin.write(JSON.stringify(params)); + const child = spawn(callCommand, [this.toolName]); + child.stdin.write(JSON.stringify(this.params)); child.stdin.end(); let stdout = ''; @@ -124,6 +114,52 @@ Signal: Signal number or \`(none)\` if no signal was received. } } +export class DiscoveredTool extends BaseDeclarativeTool< + ToolParams, + ToolResult +> { + constructor( + private readonly config: Config, + name: string, + override readonly description: string, + override readonly parameterSchema: Record, + ) { + const discoveryCmd = config.getToolDiscoveryCommand()!; + const callCommand = config.getToolCallCommand()!; + description += ` + +This tool was discovered from the project by executing the command \`${discoveryCmd}\` on project root. +When called, this tool will execute the command \`${callCommand} ${name}\` on project root. +Tool discovery and call commands can be configured in project or user settings. + +When called, the tool call command is executed as a subprocess. +On success, tool output is returned as a json string. +Otherwise, the following information is returned: + +Stdout: Output on stdout stream. Can be \`(empty)\` or partial. +Stderr: Output on stderr stream. Can be \`(empty)\` or partial. +Error: Error or \`(none)\` if no error was reported for the subprocess. +Exit Code: Exit code or \`(none)\` if terminated by signal. +Signal: Signal number or \`(none)\` if no signal was received. +`; + super( + name, + name, + description, + Kind.Other, + parameterSchema, + false, // isOutputMarkdown + false, // canUpdateOutput + ); + } + + protected createInvocation( + params: ToolParams, + ): ToolInvocation { + return new DiscoveredToolInvocation(this.config, this.name, params); + } +} + export class ToolRegistry { private tools: Map = new Map(); private config: Config; diff --git a/packages/core/src/tools/web-search.test.ts b/packages/core/src/tools/web-search.test.ts new file mode 100644 index 00000000..c0620c08 --- /dev/null +++ b/packages/core/src/tools/web-search.test.ts @@ -0,0 +1,175 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { describe, it, expect, vi, beforeEach, afterEach, Mock } from 'vitest'; +import { WebSearchTool, WebSearchToolParams } from './web-search.js'; +import { Config } from '../config/config.js'; +import { GeminiClient } from '../core/client.js'; + +// Mock GeminiClient and Config constructor +vi.mock('../core/client.js'); +vi.mock('../config/config.js'); + +describe('WebSearchTool', () => { + const abortSignal = new AbortController().signal; + let mockGeminiClient: GeminiClient; + let tool: WebSearchTool; + + beforeEach(() => { + const mockConfigInstance = { + getGeminiClient: () => mockGeminiClient, + getProxy: () => undefined, + } as unknown as Config; + mockGeminiClient = new GeminiClient(mockConfigInstance); + tool = new WebSearchTool(mockConfigInstance); + }); + + afterEach(() => { + vi.restoreAllMocks(); + }); + + describe('build', () => { + it('should return an invocation for a valid query', () => { + const params: WebSearchToolParams = { query: 'test query' }; + const invocation = tool.build(params); + expect(invocation).toBeDefined(); + expect(invocation.params).toEqual(params); + }); + + it('should throw an error for an empty query', () => { + const params: WebSearchToolParams = { query: '' }; + expect(() => tool.build(params)).toThrow( + "The 'query' parameter cannot be empty.", + ); + }); + + it('should throw an error for a query with only whitespace', () => { + const params: WebSearchToolParams = { query: ' ' }; + expect(() => tool.build(params)).toThrow( + "The 'query' parameter cannot be empty.", + ); + }); + }); + + describe('getDescription', () => { + it('should return a description of the search', () => { + const params: WebSearchToolParams = { query: 'test query' }; + const invocation = tool.build(params); + expect(invocation.getDescription()).toBe( + 'Searching the web for: "test query"', + ); + }); + }); + + describe('execute', () => { + it('should return search results for a successful query', async () => { + const params: WebSearchToolParams = { query: 'successful query' }; + (mockGeminiClient.generateContent as Mock).mockResolvedValue({ + candidates: [ + { + content: { + role: 'model', + parts: [{ text: 'Here are your results.' }], + }, + }, + ], + }); + + const invocation = tool.build(params); + const result = await invocation.execute(abortSignal); + + expect(result.llmContent).toBe( + 'Web search results for "successful query":\n\nHere are your results.', + ); + expect(result.returnDisplay).toBe( + 'Search results for "successful query" returned.', + ); + expect(result.sources).toBeUndefined(); + }); + + it('should handle no search results found', async () => { + const params: WebSearchToolParams = { query: 'no results query' }; + (mockGeminiClient.generateContent as Mock).mockResolvedValue({ + candidates: [ + { + content: { + role: 'model', + parts: [{ text: '' }], + }, + }, + ], + }); + + const invocation = tool.build(params); + const result = await invocation.execute(abortSignal); + + expect(result.llmContent).toBe( + 'No search results or information found for query: "no results query"', + ); + expect(result.returnDisplay).toBe('No information found.'); + }); + + it('should handle API errors gracefully', async () => { + const params: WebSearchToolParams = { query: 'error query' }; + const testError = new Error('API Failure'); + (mockGeminiClient.generateContent as Mock).mockRejectedValue(testError); + + const invocation = tool.build(params); + const result = await invocation.execute(abortSignal); + + expect(result.llmContent).toContain('Error:'); + expect(result.llmContent).toContain('API Failure'); + expect(result.returnDisplay).toBe('Error performing web search.'); + }); + + it('should correctly format results with sources and citations', async () => { + const params: WebSearchToolParams = { query: 'grounding query' }; + (mockGeminiClient.generateContent as Mock).mockResolvedValue({ + candidates: [ + { + content: { + role: 'model', + parts: [{ text: 'This is a test response.' }], + }, + groundingMetadata: { + groundingChunks: [ + { web: { uri: 'https://example.com', title: 'Example Site' } }, + { web: { uri: 'https://google.com', title: 'Google' } }, + ], + groundingSupports: [ + { + segment: { startIndex: 5, endIndex: 14 }, + groundingChunkIndices: [0], + }, + { + segment: { startIndex: 15, endIndex: 24 }, + groundingChunkIndices: [0, 1], + }, + ], + }, + }, + ], + }); + + const invocation = tool.build(params); + const result = await invocation.execute(abortSignal); + + const expectedLlmContent = `Web search results for "grounding query": + +This is a test[1] response.[1][2] + +Sources: +[1] Example Site (https://example.com) +[2] Google (https://google.com)`; + + expect(result.llmContent).toBe(expectedLlmContent); + expect(result.returnDisplay).toBe( + 'Search results for "grounding query" returned.', + ); + expect(result.sources).toHaveLength(2); + }); + }); +}); diff --git a/packages/core/src/tools/web-search.ts b/packages/core/src/tools/web-search.ts index a2306894..3ecaf21e 100644 --- a/packages/core/src/tools/web-search.ts +++ b/packages/core/src/tools/web-search.ts @@ -5,8 +5,13 @@ */ import { GroundingMetadata } from '@google/genai'; -import { BaseTool, Kind, ToolResult } from './tools.js'; -import { Type } from '@google/genai'; +import { + BaseDeclarativeTool, + BaseToolInvocation, + Kind, + ToolInvocation, + ToolResult, +} from './tools.js'; import { SchemaValidator } from '../utils/schemaValidator.js'; import { getErrorMessage } from '../utils/errors.js'; @@ -55,74 +60,27 @@ export interface WebSearchToolResult extends ToolResult { : GroundingChunkItem[]; } -/** - * A tool to perform web searches using Google Search via the Gemini API. - */ -export class WebSearchTool extends BaseTool< +class WebSearchToolInvocation extends BaseToolInvocation< WebSearchToolParams, WebSearchToolResult > { - static readonly Name: string = 'google_web_search'; - - constructor(private readonly config: Config) { - super( - WebSearchTool.Name, - 'GoogleSearch', - 'Performs a web search using Google Search (via the Gemini API) and returns the results. This tool is useful for finding information on the internet based on a query.', - Kind.Search, - { - type: Type.OBJECT, - properties: { - query: { - type: Type.STRING, - description: 'The search query to find information on the web.', - }, - }, - required: ['query'], - }, - ); - } - - /** - * Validates the parameters for the WebSearchTool. - * @param params The parameters to validate - * @returns An error message string if validation fails, null if valid - */ - validateParams(params: WebSearchToolParams): string | null { - const errors = SchemaValidator.validate( - this.schema.parametersJsonSchema, - params, - ); - if (errors) { - return errors; - } - - if (!params.query || params.query.trim() === '') { - return "The 'query' parameter cannot be empty."; - } - return null; - } - - override getDescription(params: WebSearchToolParams): string { - return `Searching the web for: "${params.query}"`; - } - - async execute( + constructor( + private readonly config: Config, params: WebSearchToolParams, - signal: AbortSignal, - ): Promise { - const validationError = this.validateToolParams(params); - if (validationError) { - return { - llmContent: `Error: Invalid parameters provided. Reason: ${validationError}`, - returnDisplay: validationError, - }; - } + ) { + super(params); + } + + override getDescription(): string { + return `Searching the web for: "${this.params.query}"`; + } + + async execute(signal: AbortSignal): Promise { const geminiClient = this.config.getGeminiClient(); try { const response = await geminiClient.generateContent( - [{ role: 'user', parts: [{ text: params.query }] }], + [{ role: 'user', parts: [{ text: this.params.query }] }], { tools: [{ googleSearch: {} }] }, signal, ); @@ -138,7 +96,7 @@ export class WebSearchTool extends BaseTool< if (!responseText || !responseText.trim()) { return { - llmContent: `No search results or information found for query: "${params.query}"`, + llmContent: `No search results or information found for query: "${this.params.query}"`, returnDisplay: 'No information found.', }; } @@ -172,7 +130,6 @@ export class WebSearchTool extends BaseTool< const responseChars = modifiedResponseText.split(''); // Use new variable insertions.forEach((insertion) => { - // Fixed arrow function syntax responseChars.splice(insertion.index, 0, insertion.marker); }); modifiedResponseText = responseChars.join(''); // Assign back to modifiedResponseText @@ -180,17 +137,19 @@ export class WebSearchTool extends BaseTool< if (sourceListFormatted.length > 0) { modifiedResponseText += - '\n\nSources:\n' + sourceListFormatted.join('\n'); // Fixed string concatenation + '\n\nSources:\n' + sourceListFormatted.join('\n'); } } return { - llmContent: `Web search results for "${params.query}":\n\n${modifiedResponseText}`, - returnDisplay: `Search results for "${params.query}" returned.`, + llmContent: `Web search results for "${this.params.query}":\n\n${modifiedResponseText}`, + returnDisplay: `Search results for "${this.params.query}" returned.`, sources, }; } catch (error: unknown) { - const errorMessage = `Error during web search for query "${params.query}": ${getErrorMessage(error)}`; + const errorMessage = `Error during web search for query "${ + this.params.query + }": ${getErrorMessage(error)}`; console.error(errorMessage, error); return { llmContent: `Error: ${errorMessage}`, @@ -199,3 +158,60 @@ export class WebSearchTool extends BaseTool< } } } + +/** + * A tool to perform web searches using Google Search via the Gemini API. + */ +export class WebSearchTool extends BaseDeclarativeTool< + WebSearchToolParams, + WebSearchToolResult +> { + static readonly Name: string = 'google_web_search'; + + constructor(private readonly config: Config) { + super( + WebSearchTool.Name, + 'GoogleSearch', + 'Performs a web search using Google Search (via the Gemini API) and returns the results. This tool is useful for finding information on the internet based on a query.', + Kind.Search, + { + type: 'object', + properties: { + query: { + type: 'string', + description: 'The search query to find information on the web.', + }, + }, + required: ['query'], + }, + ); + } + + /** + * Validates the parameters for the WebSearchTool. + * @param params The parameters to validate + * @returns An error message string if validation fails, null if valid + */ + protected override validateToolParams( + params: WebSearchToolParams, + ): string | null { + const errors = SchemaValidator.validate( + this.schema.parametersJsonSchema, + params, + ); + if (errors) { + return errors; + } + + if (!params.query || params.query.trim() === '') { + return "The 'query' parameter cannot be empty."; + } + return null; + } + + protected createInvocation( + params: WebSearchToolParams, + ): ToolInvocation { + return new WebSearchToolInvocation(this.config, params); + } +} diff --git a/packages/core/src/tools/write-file.test.ts b/packages/core/src/tools/write-file.test.ts index 06561602..2d877115 100644 --- a/packages/core/src/tools/write-file.test.ts +++ b/packages/core/src/tools/write-file.test.ts @@ -13,7 +13,11 @@ import { vi, type Mocked, } from 'vitest'; -import { WriteFileTool, WriteFileToolParams } from './write-file.js'; +import { + getCorrectedFileContent, + WriteFileTool, + WriteFileToolParams, +} from './write-file.js'; import { ToolErrorType } from './tool-error.js'; import { FileDiff, @@ -174,74 +178,67 @@ describe('WriteFileTool', () => { vi.clearAllMocks(); }); - describe('validateToolParams', () => { - it('should return null for valid absolute path within root', () => { + describe('build', () => { + it('should return an invocation for a valid absolute path within root', () => { const params = { file_path: path.join(rootDir, 'test.txt'), content: 'hello', }; - expect(tool.validateToolParams(params)).toBeNull(); + const invocation = tool.build(params); + expect(invocation).toBeDefined(); + expect(invocation.params).toEqual(params); }); - it('should return error for relative path', () => { + it('should throw an error for a relative path', () => { const params = { file_path: 'test.txt', content: 'hello' }; - expect(tool.validateToolParams(params)).toMatch( - /File path must be absolute/, - ); + expect(() => tool.build(params)).toThrow(/File path must be absolute/); }); - it('should return error for path outside root', () => { + it('should throw an error for a path outside root', () => { const outsidePath = path.resolve(tempDir, 'outside-root.txt'); const params = { file_path: outsidePath, content: 'hello', }; - const error = tool.validateToolParams(params); - expect(error).toContain( - 'File path must be within one of the workspace directories', + expect(() => tool.build(params)).toThrow( + /File path must be within one of the workspace directories/, ); }); - it('should return error if path is a directory', () => { + it('should throw an error if path is a directory', () => { const dirAsFilePath = path.join(rootDir, 'a_directory'); fs.mkdirSync(dirAsFilePath); const params = { file_path: dirAsFilePath, content: 'hello', }; - expect(tool.validateToolParams(params)).toMatch( + expect(() => tool.build(params)).toThrow( `Path is a directory, not a file: ${dirAsFilePath}`, ); }); - it('should return error if the content is null', () => { + it('should throw an error if the content is null', () => { const dirAsFilePath = path.join(rootDir, 'a_directory'); fs.mkdirSync(dirAsFilePath); const params = { file_path: dirAsFilePath, content: null, } as unknown as WriteFileToolParams; // Intentionally non-conforming - expect(tool.validateToolParams(params)).toMatch( - `params/content must be string`, - ); + expect(() => tool.build(params)).toThrow('params/content must be string'); }); - }); - describe('getDescription', () => { - it('should return error if the file_path is empty', () => { + it('should throw error if the file_path is empty', () => { const dirAsFilePath = path.join(rootDir, 'a_directory'); fs.mkdirSync(dirAsFilePath); const params = { file_path: '', content: '', }; - expect(tool.getDescription(params)).toMatch( - `Model did not provide valid parameters for write file tool, missing or empty "file_path"`, - ); + expect(() => tool.build(params)).toThrow(`Missing or empty "file_path"`); }); }); - describe('_getCorrectedFileContent', () => { + describe('getCorrectedFileContent', () => { it('should call ensureCorrectFileContent for a new file', async () => { const filePath = path.join(rootDir, 'new_corrected_file.txt'); const proposedContent = 'Proposed new content.'; @@ -250,8 +247,8 @@ describe('WriteFileTool', () => { // Ensure the mock is set for this specific test case if needed, or rely on beforeEach mockEnsureCorrectFileContent.mockResolvedValue(correctedContent); - // @ts-expect-error _getCorrectedFileContent is private - const result = await tool._getCorrectedFileContent( + const result = await getCorrectedFileContent( + mockConfig, filePath, proposedContent, abortSignal, @@ -287,8 +284,8 @@ describe('WriteFileTool', () => { occurrences: 1, } as CorrectedEditResult); - // @ts-expect-error _getCorrectedFileContent is private - const result = await tool._getCorrectedFileContent( + const result = await getCorrectedFileContent( + mockConfig, filePath, proposedContent, abortSignal, @@ -324,8 +321,8 @@ describe('WriteFileTool', () => { throw readError; }); - // @ts-expect-error _getCorrectedFileContent is private - const result = await tool._getCorrectedFileContent( + const result = await getCorrectedFileContent( + mockConfig, filePath, proposedContent, abortSignal, @@ -349,18 +346,6 @@ describe('WriteFileTool', () => { describe('shouldConfirmExecute', () => { const abortSignal = new AbortController().signal; - it('should return false if params are invalid (relative path)', async () => { - const params = { file_path: 'relative.txt', content: 'test' }; - const confirmation = await tool.shouldConfirmExecute(params, abortSignal); - expect(confirmation).toBe(false); - }); - - it('should return false if params are invalid (outside root)', async () => { - const outsidePath = path.resolve(tempDir, 'outside-root.txt'); - const params = { file_path: outsidePath, content: 'test' }; - const confirmation = await tool.shouldConfirmExecute(params, abortSignal); - expect(confirmation).toBe(false); - }); it('should return false if _getCorrectedFileContent returns an error', async () => { const filePath = path.join(rootDir, 'confirm_error_file.txt'); @@ -373,7 +358,8 @@ describe('WriteFileTool', () => { throw readError; }); - const confirmation = await tool.shouldConfirmExecute(params, abortSignal); + const invocation = tool.build(params); + const confirmation = await invocation.shouldConfirmExecute(abortSignal); expect(confirmation).toBe(false); vi.spyOn(fs, 'readFileSync').mockImplementation(originalReadFileSync); @@ -387,8 +373,8 @@ describe('WriteFileTool', () => { mockEnsureCorrectFileContent.mockResolvedValue(correctedContent); // Ensure this mock is active const params = { file_path: filePath, content: proposedContent }; - const confirmation = (await tool.shouldConfirmExecute( - params, + const invocation = tool.build(params); + const confirmation = (await invocation.shouldConfirmExecute( abortSignal, )) as ToolEditConfirmationDetails; @@ -430,8 +416,8 @@ describe('WriteFileTool', () => { }); const params = { file_path: filePath, content: proposedContent }; - const confirmation = (await tool.shouldConfirmExecute( - params, + const invocation = tool.build(params); + const confirmation = (await invocation.shouldConfirmExecute( abortSignal, )) as ToolEditConfirmationDetails; @@ -461,31 +447,6 @@ describe('WriteFileTool', () => { describe('execute', () => { const abortSignal = new AbortController().signal; - it('should return error if params are invalid (relative path)', async () => { - const params = { file_path: 'relative.txt', content: 'test' }; - const result = await tool.execute(params, abortSignal); - expect(result.llmContent).toContain( - 'Could not write file due to invalid parameters:', - ); - expect(result.returnDisplay).toMatch(/File path must be absolute/); - expect(result.error).toEqual({ - message: 'File path must be absolute: relative.txt', - type: ToolErrorType.INVALID_TOOL_PARAMS, - }); - }); - - it('should return error if params are invalid (path outside root)', async () => { - const outsidePath = path.resolve(tempDir, 'outside-root.txt'); - const params = { file_path: outsidePath, content: 'test' }; - const result = await tool.execute(params, abortSignal); - expect(result.llmContent).toContain( - 'Could not write file due to invalid parameters:', - ); - expect(result.returnDisplay).toContain( - 'File path must be within one of the workspace directories', - ); - expect(result.error?.type).toBe(ToolErrorType.INVALID_TOOL_PARAMS); - }); it('should return error if _getCorrectedFileContent returns an error during execute', async () => { const filePath = path.join(rootDir, 'execute_error_file.txt'); @@ -498,7 +459,8 @@ describe('WriteFileTool', () => { throw readError; }); - const result = await tool.execute(params, abortSignal); + const invocation = tool.build(params); + const result = await invocation.execute(abortSignal); expect(result.llmContent).toContain('Error checking existing file:'); expect(result.returnDisplay).toMatch( /Error checking existing file: Simulated read error for execute/, @@ -520,11 +482,9 @@ describe('WriteFileTool', () => { mockEnsureCorrectFileContent.mockResolvedValue(correctedContent); const params = { file_path: filePath, content: proposedContent }; + const invocation = tool.build(params); - const confirmDetails = await tool.shouldConfirmExecute( - params, - abortSignal, - ); + const confirmDetails = await invocation.shouldConfirmExecute(abortSignal); if ( typeof confirmDetails === 'object' && 'onConfirm' in confirmDetails && @@ -533,7 +493,7 @@ describe('WriteFileTool', () => { await confirmDetails.onConfirm(ToolConfirmationOutcome.ProceedOnce); } - const result = await tool.execute(params, abortSignal); + const result = await invocation.execute(abortSignal); expect(mockEnsureCorrectFileContent).toHaveBeenCalledWith( proposedContent, @@ -578,11 +538,9 @@ describe('WriteFileTool', () => { }); const params = { file_path: filePath, content: proposedContent }; + const invocation = tool.build(params); - const confirmDetails = await tool.shouldConfirmExecute( - params, - abortSignal, - ); + const confirmDetails = await invocation.shouldConfirmExecute(abortSignal); if ( typeof confirmDetails === 'object' && 'onConfirm' in confirmDetails && @@ -591,7 +549,7 @@ describe('WriteFileTool', () => { await confirmDetails.onConfirm(ToolConfirmationOutcome.ProceedOnce); } - const result = await tool.execute(params, abortSignal); + const result = await invocation.execute(abortSignal); expect(mockEnsureCorrectEdit).toHaveBeenCalledWith( filePath, @@ -623,11 +581,9 @@ describe('WriteFileTool', () => { mockEnsureCorrectFileContent.mockResolvedValue(content); // Ensure this mock is active const params = { file_path: filePath, content }; + const invocation = tool.build(params); // Simulate confirmation if your logic requires it before execute, or remove if not needed for this path - const confirmDetails = await tool.shouldConfirmExecute( - params, - abortSignal, - ); + const confirmDetails = await invocation.shouldConfirmExecute(abortSignal); if ( typeof confirmDetails === 'object' && 'onConfirm' in confirmDetails && @@ -636,7 +592,7 @@ describe('WriteFileTool', () => { await confirmDetails.onConfirm(ToolConfirmationOutcome.ProceedOnce); } - await tool.execute(params, abortSignal); + await invocation.execute(abortSignal); expect(fs.existsSync(dirPath)).toBe(true); expect(fs.statSync(dirPath).isDirectory()).toBe(true); @@ -654,7 +610,8 @@ describe('WriteFileTool', () => { content, modified_by_user: true, }; - const result = await tool.execute(params, abortSignal); + const invocation = tool.build(params); + const result = await invocation.execute(abortSignal); expect(result.llmContent).toMatch(/User modified the `content`/); }); @@ -669,7 +626,8 @@ describe('WriteFileTool', () => { content, modified_by_user: false, }; - const result = await tool.execute(params, abortSignal); + const invocation = tool.build(params); + const result = await invocation.execute(abortSignal); expect(result.llmContent).not.toMatch(/User modified the `content`/); }); @@ -683,7 +641,8 @@ describe('WriteFileTool', () => { file_path: filePath, content, }; - const result = await tool.execute(params, abortSignal); + const invocation = tool.build(params); + const result = await invocation.execute(abortSignal); expect(result.llmContent).not.toMatch(/User modified the `content`/); }); @@ -695,7 +654,7 @@ describe('WriteFileTool', () => { file_path: path.join(rootDir, 'file.txt'), content: 'test content', }; - expect(tool.validateToolParams(params)).toBeNull(); + expect(() => tool.build(params)).not.toThrow(); }); it('should reject paths outside workspace root', () => { @@ -703,24 +662,9 @@ describe('WriteFileTool', () => { file_path: '/etc/passwd', content: 'malicious', }; - const error = tool.validateToolParams(params); - expect(error).toContain( - 'File path must be within one of the workspace directories', + expect(() => tool.build(params)).toThrow( + /File path must be within one of the workspace directories/, ); - expect(error).toContain(rootDir); - }); - - it('should provide clear error message with workspace directories', () => { - const outsidePath = path.join(tempDir, 'outside-root.txt'); - const params = { - file_path: outsidePath, - content: 'test', - }; - const error = tool.validateToolParams(params); - expect(error).toContain( - 'File path must be within one of the workspace directories', - ); - expect(error).toContain(rootDir); }); }); @@ -740,13 +684,16 @@ describe('WriteFileTool', () => { }); const params = { file_path: filePath, content }; - const result = await tool.execute(params, abortSignal); + const invocation = tool.build(params); + const result = await invocation.execute(abortSignal); expect(result.error?.type).toBe(ToolErrorType.PERMISSION_DENIED); expect(result.llmContent).toContain( `Permission denied writing to file: ${filePath} (EACCES)`, ); - expect(result.returnDisplay).toContain('Permission denied'); + expect(result.returnDisplay).toContain( + `Permission denied writing to file: ${filePath} (EACCES)`, + ); vi.spyOn(fs, 'writeFileSync').mockImplementation(originalWriteFileSync); }); @@ -766,13 +713,16 @@ describe('WriteFileTool', () => { }); const params = { file_path: filePath, content }; - const result = await tool.execute(params, abortSignal); + const invocation = tool.build(params); + const result = await invocation.execute(abortSignal); expect(result.error?.type).toBe(ToolErrorType.NO_SPACE_LEFT); expect(result.llmContent).toContain( `No space left on device: ${filePath} (ENOSPC)`, ); - expect(result.returnDisplay).toContain('No space left'); + expect(result.returnDisplay).toContain( + `No space left on device: ${filePath} (ENOSPC)`, + ); vi.spyOn(fs, 'writeFileSync').mockImplementation(originalWriteFileSync); }); @@ -799,13 +749,16 @@ describe('WriteFileTool', () => { }); const params = { file_path: dirPath, content }; - const result = await tool.execute(params, abortSignal); + const invocation = tool.build(params); + const result = await invocation.execute(abortSignal); expect(result.error?.type).toBe(ToolErrorType.TARGET_IS_DIRECTORY); expect(result.llmContent).toContain( `Target is a directory, not a file: ${dirPath} (EISDIR)`, ); - expect(result.returnDisplay).toContain('Target is a directory'); + expect(result.returnDisplay).toContain( + `Target is a directory, not a file: ${dirPath} (EISDIR)`, + ); vi.spyOn(fs, 'existsSync').mockImplementation(originalExistsSync); vi.spyOn(fs, 'writeFileSync').mockImplementation(originalWriteFileSync); @@ -824,13 +777,16 @@ describe('WriteFileTool', () => { }); const params = { file_path: filePath, content }; - const result = await tool.execute(params, abortSignal); + const invocation = tool.build(params); + const result = await invocation.execute(abortSignal); expect(result.error?.type).toBe(ToolErrorType.FILE_WRITE_FAILURE); expect(result.llmContent).toContain( 'Error writing to file: Generic write error', ); - expect(result.returnDisplay).toContain('Generic write error'); + expect(result.returnDisplay).toContain( + 'Error writing to file: Generic write error', + ); }); }); }); diff --git a/packages/core/src/tools/write-file.ts b/packages/core/src/tools/write-file.ts index 01c92865..c889d6a3 100644 --- a/packages/core/src/tools/write-file.ts +++ b/packages/core/src/tools/write-file.ts @@ -9,14 +9,16 @@ import path from 'path'; import * as Diff from 'diff'; import { Config, ApprovalMode } from '../config/config.js'; import { - BaseTool, - ToolResult, + BaseDeclarativeTool, + BaseToolInvocation, FileDiff, - ToolEditConfirmationDetails, - ToolConfirmationOutcome, - ToolCallConfirmationDetails, Kind, + ToolCallConfirmationDetails, + ToolConfirmationOutcome, + ToolEditConfirmationDetails, + ToolInvocation, ToolLocation, + ToolResult, } from './tools.js'; import { ToolErrorType } from './tool-error.js'; import { SchemaValidator } from '../utils/schemaValidator.js'; @@ -67,113 +69,99 @@ interface GetCorrectedFileContentResult { error?: { message: string; code?: string }; } -/** - * Implementation of the WriteFile tool logic - */ -export class WriteFileTool - extends BaseTool - implements ModifiableDeclarativeTool -{ - static readonly Name: string = 'write_file'; +export async function getCorrectedFileContent( + config: Config, + filePath: string, + proposedContent: string, + abortSignal: AbortSignal, +): Promise { + let originalContent = ''; + let fileExists = false; + let correctedContent = proposedContent; - constructor(private readonly config: Config) { - super( - WriteFileTool.Name, - 'WriteFile', - `Writes content to a specified file in the local filesystem. + try { + originalContent = fs.readFileSync(filePath, 'utf8'); + fileExists = true; // File exists and was read + } catch (err) { + if (isNodeError(err) && err.code === 'ENOENT') { + fileExists = false; + originalContent = ''; + } else { + // File exists but could not be read (permissions, etc.) + fileExists = true; // Mark as existing but problematic + originalContent = ''; // Can't use its content + const error = { + message: getErrorMessage(err), + code: isNodeError(err) ? err.code : undefined, + }; + // Return early as we can't proceed with content correction meaningfully + return { originalContent, correctedContent, fileExists, error }; + } + } - The user has the ability to modify \`content\`. If modified, this will be stated in the response.`, - Kind.Edit, + // If readError is set, we have returned. + // So, file was either read successfully (fileExists=true, originalContent set) + // or it was ENOENT (fileExists=false, originalContent=''). + + if (fileExists) { + // This implies originalContent is available + const { params: correctedParams } = await ensureCorrectEdit( + filePath, + originalContent, { - properties: { - file_path: { - description: - "The absolute path to the file to write to (e.g., '/home/user/project/file.txt'). Relative paths are not supported.", - type: 'string', - }, - content: { - description: 'The content to write to the file.', - type: 'string', - }, - }, - required: ['file_path', 'content'], - type: 'object', + old_string: originalContent, // Treat entire current content as old_string + new_string: proposedContent, + file_path: filePath, }, + config.getGeminiClient(), + abortSignal, + ); + correctedContent = correctedParams.new_string; + } else { + // This implies new file (ENOENT) + correctedContent = await ensureCorrectFileContent( + proposedContent, + config.getGeminiClient(), + abortSignal, ); } + return { originalContent, correctedContent, fileExists }; +} - override toolLocations(params: WriteFileToolParams): ToolLocation[] { - return [{ path: params.file_path }]; +class WriteFileToolInvocation extends BaseToolInvocation< + WriteFileToolParams, + ToolResult +> { + constructor( + private readonly config: Config, + params: WriteFileToolParams, + ) { + super(params); } - override validateToolParams(params: WriteFileToolParams): string | null { - const errors = SchemaValidator.validate( - this.schema.parametersJsonSchema, - params, - ); - if (errors) { - return errors; - } - - const filePath = params.file_path; - if (!path.isAbsolute(filePath)) { - return `File path must be absolute: ${filePath}`; - } - - const workspaceContext = this.config.getWorkspaceContext(); - if (!workspaceContext.isPathWithinWorkspace(filePath)) { - const directories = workspaceContext.getDirectories(); - return `File path must be within one of the workspace directories: ${directories.join(', ')}`; - } - - try { - // This check should be performed only if the path exists. - // If it doesn't exist, it's a new file, which is valid for writing. - if (fs.existsSync(filePath)) { - const stats = fs.lstatSync(filePath); - if (stats.isDirectory()) { - return `Path is a directory, not a file: ${filePath}`; - } - } - } catch (statError: unknown) { - // If fs.existsSync is true but lstatSync fails (e.g., permissions, race condition where file is deleted) - // this indicates an issue with accessing the path that should be reported. - return `Error accessing path properties for validation: ${filePath}. Reason: ${statError instanceof Error ? statError.message : String(statError)}`; - } - - return null; + override toolLocations(): ToolLocation[] { + return [{ path: this.params.file_path }]; } - override getDescription(params: WriteFileToolParams): string { - if (!params.file_path) { - return `Model did not provide valid parameters for write file tool, missing or empty "file_path"`; - } + override getDescription(): string { const relativePath = makeRelative( - params.file_path, + this.params.file_path, this.config.getTargetDir(), ); return `Writing to ${shortenPath(relativePath)}`; } - /** - * Handles the confirmation prompt for the WriteFile tool. - */ override async shouldConfirmExecute( - params: WriteFileToolParams, abortSignal: AbortSignal, ): Promise { if (this.config.getApprovalMode() === ApprovalMode.AUTO_EDIT) { return false; } - const validationError = this.validateToolParams(params); - if (validationError) { - return false; - } - - const correctedContentResult = await this._getCorrectedFileContent( - params.file_path, - params.content, + const correctedContentResult = await getCorrectedFileContent( + this.config, + this.params.file_path, + this.params.content, abortSignal, ); @@ -184,10 +172,10 @@ export class WriteFileTool const { originalContent, correctedContent } = correctedContentResult; const relativePath = makeRelative( - params.file_path, + this.params.file_path, this.config.getTargetDir(), ); - const fileName = path.basename(params.file_path); + const fileName = path.basename(this.params.file_path); const fileDiff = Diff.createPatch( fileName, @@ -202,14 +190,14 @@ export class WriteFileTool const ideConfirmation = this.config.getIdeMode() && ideClient.getConnectionStatus().status === IDEConnectionStatus.Connected - ? ideClient.openDiff(params.file_path, correctedContent) + ? ideClient.openDiff(this.params.file_path, correctedContent) : undefined; const confirmationDetails: ToolEditConfirmationDetails = { type: 'edit', title: `Confirm Write: ${shortenPath(relativePath)}`, fileName, - filePath: params.file_path, + filePath: this.params.file_path, fileDiff, originalContent, newContent: correctedContent, @@ -221,7 +209,7 @@ export class WriteFileTool if (ideConfirmation) { const result = await ideConfirmation; if (result.status === 'accepted' && result.content) { - params.content = result.content; + this.params.content = result.content; } } }, @@ -230,32 +218,20 @@ export class WriteFileTool return confirmationDetails; } - async execute( - params: WriteFileToolParams, - abortSignal: AbortSignal, - ): Promise { - const validationError = this.validateToolParams(params); - if (validationError) { - return { - llmContent: `Could not write file due to invalid parameters: ${validationError}`, - returnDisplay: validationError, - error: { - message: validationError, - type: ToolErrorType.INVALID_TOOL_PARAMS, - }, - }; - } - - const correctedContentResult = await this._getCorrectedFileContent( - params.file_path, - params.content, + async execute(abortSignal: AbortSignal): Promise { + const { file_path, content, ai_proposed_content, modified_by_user } = + this.params; + const correctedContentResult = await getCorrectedFileContent( + this.config, + file_path, + content, abortSignal, ); if (correctedContentResult.error) { const errDetails = correctedContentResult.error; const errorMsg = errDetails.code - ? `Error checking existing file '${params.file_path}': ${errDetails.message} (${errDetails.code})` + ? `Error checking existing file '${file_path}': ${errDetails.message} (${errDetails.code})` : `Error checking existing file: ${errDetails.message}`; return { llmContent: errorMsg, @@ -280,15 +256,15 @@ export class WriteFileTool !correctedContentResult.fileExists); try { - const dirName = path.dirname(params.file_path); + const dirName = path.dirname(file_path); if (!fs.existsSync(dirName)) { fs.mkdirSync(dirName, { recursive: true }); } - fs.writeFileSync(params.file_path, fileContent, 'utf8'); + fs.writeFileSync(file_path, fileContent, 'utf8'); // Generate diff for display result - const fileName = path.basename(params.file_path); + const fileName = path.basename(file_path); // If there was a readError, originalContent in correctedContentResult is '', // but for the diff, we want to show the original content as it was before the write if possible. // However, if it was unreadable, currentContentForDiff will be empty. @@ -305,23 +281,22 @@ export class WriteFileTool DEFAULT_DIFF_OPTIONS, ); - const originallyProposedContent = - params.ai_proposed_content || params.content; + const originallyProposedContent = ai_proposed_content || content; const diffStat = getDiffStat( fileName, currentContentForDiff, originallyProposedContent, - params.content, + content, ); const llmSuccessMessageParts = [ isNewFile - ? `Successfully created and wrote to new file: ${params.file_path}.` - : `Successfully overwrote file: ${params.file_path}.`, + ? `Successfully created and wrote to new file: ${file_path}.` + : `Successfully overwrote file: ${file_path}.`, ]; - if (params.modified_by_user) { + if (modified_by_user) { llmSuccessMessageParts.push( - `User modified the \`content\` to be: ${params.content}`, + `User modified the \`content\` to be: ${content}`, ); } @@ -334,8 +309,8 @@ export class WriteFileTool }; const lines = fileContent.split('\n').length; - const mimetype = getSpecificMimeType(params.file_path); - const extension = path.extname(params.file_path); // Get extension + const mimetype = getSpecificMimeType(file_path); + const extension = path.extname(file_path); // Get extension if (isNewFile) { recordFileOperationMetric( this.config, @@ -367,17 +342,17 @@ export class WriteFileTool if (isNodeError(error)) { // Handle specific Node.js errors with their error codes - errorMsg = `Error writing to file '${params.file_path}': ${error.message} (${error.code})`; + errorMsg = `Error writing to file '${file_path}': ${error.message} (${error.code})`; // Log specific error types for better debugging if (error.code === 'EACCES') { - errorMsg = `Permission denied writing to file: ${params.file_path} (${error.code})`; + errorMsg = `Permission denied writing to file: ${file_path} (${error.code})`; errorType = ToolErrorType.PERMISSION_DENIED; } else if (error.code === 'ENOSPC') { - errorMsg = `No space left on device: ${params.file_path} (${error.code})`; + errorMsg = `No space left on device: ${file_path} (${error.code})`; errorType = ToolErrorType.NO_SPACE_LEFT; } else if (error.code === 'EISDIR') { - errorMsg = `Target is a directory, not a file: ${params.file_path} (${error.code})`; + errorMsg = `Target is a directory, not a file: ${file_path} (${error.code})`; errorType = ToolErrorType.TARGET_IS_DIRECTORY; } @@ -401,63 +376,92 @@ export class WriteFileTool }; } } +} - private async _getCorrectedFileContent( - filePath: string, - proposedContent: string, - abortSignal: AbortSignal, - ): Promise { - let originalContent = ''; - let fileExists = false; - let correctedContent = proposedContent; +/** + * Implementation of the WriteFile tool logic + */ +export class WriteFileTool + extends BaseDeclarativeTool + implements ModifiableDeclarativeTool +{ + static readonly Name: string = 'write_file'; + + constructor(private readonly config: Config) { + super( + WriteFileTool.Name, + 'WriteFile', + `Writes content to a specified file in the local filesystem. + + The user has the ability to modify \`content\`. If modified, this will be stated in the response.`, + Kind.Edit, + { + properties: { + file_path: { + description: + "The absolute path to the file to write to (e.g., '/home/user/project/file.txt'). Relative paths are not supported.", + type: 'string', + }, + content: { + description: 'The content to write to the file.', + type: 'string', + }, + }, + required: ['file_path', 'content'], + type: 'object', + }, + ); + } + + protected override validateToolParams( + params: WriteFileToolParams, + ): string | null { + const errors = SchemaValidator.validate( + this.schema.parametersJsonSchema, + params, + ); + if (errors) { + return errors; + } + + const filePath = params.file_path; + + if (!filePath) { + return `Missing or empty "file_path"`; + } + + if (!path.isAbsolute(filePath)) { + return `File path must be absolute: ${filePath}`; + } + + const workspaceContext = this.config.getWorkspaceContext(); + if (!workspaceContext.isPathWithinWorkspace(filePath)) { + const directories = workspaceContext.getDirectories(); + return `File path must be within one of the workspace directories: ${directories.join( + ', ', + )}`; + } try { - originalContent = fs.readFileSync(filePath, 'utf8'); - fileExists = true; // File exists and was read - } catch (err) { - if (isNodeError(err) && err.code === 'ENOENT') { - fileExists = false; - originalContent = ''; - } else { - // File exists but could not be read (permissions, etc.) - fileExists = true; // Mark as existing but problematic - originalContent = ''; // Can't use its content - const error = { - message: getErrorMessage(err), - code: isNodeError(err) ? err.code : undefined, - }; - // Return early as we can't proceed with content correction meaningfully - return { originalContent, correctedContent, fileExists, error }; + if (fs.existsSync(filePath)) { + const stats = fs.lstatSync(filePath); + if (stats.isDirectory()) { + return `Path is a directory, not a file: ${filePath}`; + } } + } catch (statError: unknown) { + return `Error accessing path properties for validation: ${filePath}. Reason: ${ + statError instanceof Error ? statError.message : String(statError) + }`; } - // If readError is set, we have returned. - // So, file was either read successfully (fileExists=true, originalContent set) - // or it was ENOENT (fileExists=false, originalContent=''). + return null; + } - if (fileExists) { - // This implies originalContent is available - const { params: correctedParams } = await ensureCorrectEdit( - filePath, - originalContent, - { - old_string: originalContent, // Treat entire current content as old_string - new_string: proposedContent, - file_path: filePath, - }, - this.config.getGeminiClient(), - abortSignal, - ); - correctedContent = correctedParams.new_string; - } else { - // This implies new file (ENOENT) - correctedContent = await ensureCorrectFileContent( - proposedContent, - this.config.getGeminiClient(), - abortSignal, - ); - } - return { originalContent, correctedContent, fileExists }; + protected createInvocation( + params: WriteFileToolParams, + ): ToolInvocation { + return new WriteFileToolInvocation(this.config, params); } getModifyContext( @@ -466,7 +470,8 @@ export class WriteFileTool return { getFilePath: (params: WriteFileToolParams) => params.file_path, getCurrentContent: async (params: WriteFileToolParams) => { - const correctedContentResult = await this._getCorrectedFileContent( + const correctedContentResult = await getCorrectedFileContent( + this.config, params.file_path, params.content, abortSignal, @@ -474,7 +479,8 @@ export class WriteFileTool return correctedContentResult.originalContent; }, getProposedContent: async (params: WriteFileToolParams) => { - const correctedContentResult = await this._getCorrectedFileContent( + const correctedContentResult = await getCorrectedFileContent( + this.config, params.file_path, params.content, abortSignal,