From 6133bea388a2de69c71a6be6f1450707f2ce4dfb Mon Sep 17 00:00:00 2001 From: joshualitt Date: Wed, 6 Aug 2025 10:50:02 -0700 Subject: [PATCH] feat(core): Introduce `DeclarativeTool` and `ToolInvocation`. (#5613) --- packages/cli/src/acp/acpPeer.ts | 159 ++++---- .../cli/src/ui/hooks/atCommandProcessor.ts | 13 +- .../cli/src/ui/hooks/useGeminiStream.test.tsx | 56 ++- .../cli/src/ui/hooks/useReactToolScheduler.ts | 27 +- .../cli/src/ui/hooks/useToolScheduler.test.ts | 180 +++++----- packages/core/src/core/client.ts | 18 +- .../core/src/core/coreToolScheduler.test.ts | 83 ++--- packages/core/src/core/coreToolScheduler.ts | 110 +++++- .../core/nonInteractiveToolExecutor.test.ts | 81 ++--- .../src/core/nonInteractiveToolExecutor.ts | 2 +- .../src/telemetry/loggers.test.circular.ts | 10 +- packages/core/src/telemetry/loggers.test.ts | 7 +- .../core/src/telemetry/uiTelemetry.test.ts | 8 +- packages/core/src/test-utils/tools.ts | 63 ++++ packages/core/src/tools/edit.ts | 4 +- packages/core/src/tools/memoryTool.ts | 4 +- .../core/src/tools/modifiable-tool.test.ts | 12 +- packages/core/src/tools/modifiable-tool.ts | 18 +- packages/core/src/tools/read-file.test.ts | 338 +++++++++--------- packages/core/src/tools/read-file.ts | 138 +++---- packages/core/src/tools/tool-registry.test.ts | 24 +- packages/core/src/tools/tool-registry.ts | 14 +- packages/core/src/tools/tools.ts | 299 ++++++++++++---- packages/core/src/tools/write-file.ts | 4 +- 24 files changed, 991 insertions(+), 681 deletions(-) create mode 100644 packages/core/src/test-utils/tools.ts diff --git a/packages/cli/src/acp/acpPeer.ts b/packages/cli/src/acp/acpPeer.ts index 90952b7f..40d8753f 100644 --- a/packages/cli/src/acp/acpPeer.ts +++ b/packages/cli/src/acp/acpPeer.ts @@ -239,65 +239,62 @@ class GeminiAgent implements Agent { ); } - let toolCallId; - const confirmationDetails = await tool.shouldConfirmExecute( - args, - abortSignal, - ); - if (confirmationDetails) { - let content: acp.ToolCallContent | null = null; - if (confirmationDetails.type === 'edit') { - content = { - type: 'diff', - path: confirmationDetails.fileName, - oldText: confirmationDetails.originalContent, - newText: confirmationDetails.newContent, - }; - } - - const result = await this.client.requestToolCallConfirmation({ - label: tool.getDescription(args), - icon: tool.icon, - content, - confirmation: toAcpToolCallConfirmation(confirmationDetails), - locations: tool.toolLocations(args), - }); - - await confirmationDetails.onConfirm(toToolCallOutcome(result.outcome)); - switch (result.outcome) { - case 'reject': - return errorResponse( - new Error(`Tool "${fc.name}" not allowed to run by the user.`), - ); - - case 'cancel': - return errorResponse( - new Error(`Tool "${fc.name}" was canceled by the user.`), - ); - case 'allow': - case 'alwaysAllow': - case 'alwaysAllowMcpServer': - case 'alwaysAllowTool': - break; - default: { - const resultOutcome: never = result.outcome; - throw new Error(`Unexpected: ${resultOutcome}`); - } - } - - toolCallId = result.id; - } else { - const result = await this.client.pushToolCall({ - icon: tool.icon, - label: tool.getDescription(args), - locations: tool.toolLocations(args), - }); - - toolCallId = result.id; - } - + let toolCallId: number | undefined = undefined; try { - const toolResult: ToolResult = await tool.execute(args, abortSignal); + const invocation = tool.build(args); + const confirmationDetails = + await invocation.shouldConfirmExecute(abortSignal); + if (confirmationDetails) { + let content: acp.ToolCallContent | null = null; + if (confirmationDetails.type === 'edit') { + content = { + type: 'diff', + path: confirmationDetails.fileName, + oldText: confirmationDetails.originalContent, + newText: confirmationDetails.newContent, + }; + } + + const result = await this.client.requestToolCallConfirmation({ + label: invocation.getDescription(), + icon: tool.icon, + content, + confirmation: toAcpToolCallConfirmation(confirmationDetails), + locations: invocation.toolLocations(), + }); + + await confirmationDetails.onConfirm(toToolCallOutcome(result.outcome)); + switch (result.outcome) { + case 'reject': + return errorResponse( + new Error(`Tool "${fc.name}" not allowed to run by the user.`), + ); + + case 'cancel': + return errorResponse( + new Error(`Tool "${fc.name}" was canceled by the user.`), + ); + case 'allow': + case 'alwaysAllow': + case 'alwaysAllowMcpServer': + case 'alwaysAllowTool': + break; + default: { + const resultOutcome: never = result.outcome; + throw new Error(`Unexpected: ${resultOutcome}`); + } + } + toolCallId = result.id; + } else { + const result = await this.client.pushToolCall({ + icon: tool.icon, + label: invocation.getDescription(), + locations: invocation.toolLocations(), + }); + toolCallId = result.id; + } + + const toolResult: ToolResult = await invocation.execute(abortSignal); const toolCallContent = toToolCallContent(toolResult); await this.client.updateToolCall({ @@ -320,12 +317,13 @@ class GeminiAgent implements Agent { return convertToFunctionResponse(fc.name, callId, toolResult.llmContent); } catch (e) { const error = e instanceof Error ? e : new Error(String(e)); - await this.client.updateToolCall({ - toolCallId, - status: 'error', - content: { type: 'markdown', markdown: error.message }, - }); - + if (toolCallId) { + await this.client.updateToolCall({ + toolCallId, + status: 'error', + content: { type: 'markdown', markdown: error.message }, + }); + } return errorResponse(error); } } @@ -408,7 +406,7 @@ class GeminiAgent implements Agent { `Path ${pathName} not found directly, attempting glob search.`, ); try { - const globResult = await globTool.execute( + const globResult = await globTool.buildAndExecute( { pattern: `**/*${pathName}*`, path: this.config.getTargetDir(), @@ -530,12 +528,15 @@ class GeminiAgent implements Agent { respectGitIgnore, // Use configuration setting }; - const toolCall = await this.client.pushToolCall({ - icon: readManyFilesTool.icon, - label: readManyFilesTool.getDescription(toolArgs), - }); + let toolCallId: number | undefined = undefined; try { - const result = await readManyFilesTool.execute(toolArgs, abortSignal); + const invocation = readManyFilesTool.build(toolArgs); + const toolCall = await this.client.pushToolCall({ + icon: readManyFilesTool.icon, + label: invocation.getDescription(), + }); + toolCallId = toolCall.id; + const result = await invocation.execute(abortSignal); const content = toToolCallContent(result) || { type: 'markdown', markdown: `Successfully read: ${contentLabelsForDisplay.join(', ')}`, @@ -578,14 +579,16 @@ class GeminiAgent implements Agent { return processedQueryParts; } catch (error: unknown) { - await this.client.updateToolCall({ - toolCallId: toolCall.id, - status: 'error', - content: { - type: 'markdown', - markdown: `Error reading files (${contentLabelsForDisplay.join(', ')}): ${getErrorMessage(error)}`, - }, - }); + if (toolCallId) { + await this.client.updateToolCall({ + toolCallId, + status: 'error', + content: { + type: 'markdown', + markdown: `Error reading files (${contentLabelsForDisplay.join(', ')}): ${getErrorMessage(error)}`, + }, + }); + } throw error; } } diff --git a/packages/cli/src/ui/hooks/atCommandProcessor.ts b/packages/cli/src/ui/hooks/atCommandProcessor.ts index 165b7b30..cef2f811 100644 --- a/packages/cli/src/ui/hooks/atCommandProcessor.ts +++ b/packages/cli/src/ui/hooks/atCommandProcessor.ts @@ -8,6 +8,7 @@ import * as fs from 'fs/promises'; import * as path from 'path'; import { PartListUnion, PartUnion } from '@google/genai'; import { + AnyToolInvocation, Config, getErrorMessage, isNodeError, @@ -254,7 +255,7 @@ export async function handleAtCommand({ `Path ${pathName} not found directly, attempting glob search.`, ); try { - const globResult = await globTool.execute( + const globResult = await globTool.buildAndExecute( { pattern: `**/*${pathName}*`, path: dir, @@ -411,12 +412,14 @@ export async function handleAtCommand({ }; let toolCallDisplay: IndividualToolCallDisplay; + let invocation: AnyToolInvocation | undefined = undefined; try { - const result = await readManyFilesTool.execute(toolArgs, signal); + invocation = readManyFilesTool.build(toolArgs); + const result = await invocation.execute(signal); toolCallDisplay = { callId: `client-read-${userMessageTimestamp}`, name: readManyFilesTool.displayName, - description: readManyFilesTool.getDescription(toolArgs), + description: invocation.getDescription(), status: ToolCallStatus.Success, resultDisplay: result.returnDisplay || @@ -466,7 +469,9 @@ export async function handleAtCommand({ toolCallDisplay = { callId: `client-read-${userMessageTimestamp}`, name: readManyFilesTool.displayName, - description: readManyFilesTool.getDescription(toolArgs), + description: + invocation?.getDescription() ?? + 'Error attempting to execute tool to read files', status: ToolCallStatus.Error, resultDisplay: `Error reading files (${contentLabelsForDisplay.join(', ')}): ${getErrorMessage(error)}`, confirmationDetails: undefined, diff --git a/packages/cli/src/ui/hooks/useGeminiStream.test.tsx b/packages/cli/src/ui/hooks/useGeminiStream.test.tsx index 062c1687..dd2428bb 100644 --- a/packages/cli/src/ui/hooks/useGeminiStream.test.tsx +++ b/packages/cli/src/ui/hooks/useGeminiStream.test.tsx @@ -21,6 +21,7 @@ import { EditorType, AuthType, GeminiEventType as ServerGeminiEventType, + AnyToolInvocation, } from '@google/gemini-cli-core'; import { Part, PartListUnion } from '@google/genai'; import { UseHistoryManagerReturn } from './useHistoryManager.js'; @@ -452,9 +453,13 @@ describe('useGeminiStream', () => { }, tool: { name: 'tool1', + displayName: 'tool1', description: 'desc1', - getDescription: vi.fn(), + build: vi.fn(), } as any, + invocation: { + getDescription: () => `Mock description`, + } as unknown as AnyToolInvocation, startTime: Date.now(), endTime: Date.now(), } as TrackedCompletedToolCall, @@ -469,9 +474,13 @@ describe('useGeminiStream', () => { responseSubmittedToGemini: false, tool: { name: 'tool2', + displayName: 'tool2', description: 'desc2', - getDescription: vi.fn(), + build: vi.fn(), } as any, + invocation: { + getDescription: () => `Mock description`, + } as unknown as AnyToolInvocation, startTime: Date.now(), liveOutput: '...', } as TrackedExecutingToolCall, @@ -506,6 +515,12 @@ describe('useGeminiStream', () => { status: 'success', responseSubmittedToGemini: false, response: { callId: 'call1', responseParts: toolCall1ResponseParts }, + tool: { + displayName: 'MockTool', + }, + invocation: { + getDescription: () => `Mock description`, + } as unknown as AnyToolInvocation, } as TrackedCompletedToolCall, { request: { @@ -584,6 +599,12 @@ describe('useGeminiStream', () => { status: 'cancelled', response: { callId: '1', responseParts: [{ text: 'cancelled' }] }, responseSubmittedToGemini: false, + tool: { + displayName: 'mock tool', + }, + invocation: { + getDescription: () => `Mock description`, + } as unknown as AnyToolInvocation, } as TrackedCancelledToolCall, ]; const client = new MockedGeminiClientClass(mockConfig); @@ -644,9 +665,13 @@ describe('useGeminiStream', () => { }, tool: { name: 'toolA', + displayName: 'toolA', description: 'descA', - getDescription: vi.fn(), + build: vi.fn(), } as any, + invocation: { + getDescription: () => `Mock description`, + } as unknown as AnyToolInvocation, status: 'cancelled', response: { callId: 'cancel-1', @@ -668,9 +693,13 @@ describe('useGeminiStream', () => { }, tool: { name: 'toolB', + displayName: 'toolB', description: 'descB', - getDescription: vi.fn(), + build: vi.fn(), } as any, + invocation: { + getDescription: () => `Mock description`, + } as unknown as AnyToolInvocation, status: 'cancelled', response: { callId: 'cancel-2', @@ -760,9 +789,13 @@ describe('useGeminiStream', () => { responseSubmittedToGemini: false, tool: { name: 'tool1', + displayName: 'tool1', description: 'desc', - getDescription: vi.fn(), + build: vi.fn(), } as any, + invocation: { + getDescription: () => `Mock description`, + } as unknown as AnyToolInvocation, startTime: Date.now(), } as TrackedExecutingToolCall, ]; @@ -980,8 +1013,13 @@ describe('useGeminiStream', () => { tool: { name: 'tool1', description: 'desc1', - getDescription: vi.fn(), + build: vi.fn().mockImplementation((_) => ({ + getDescription: () => `Mock description`, + })), } as any, + invocation: { + getDescription: () => `Mock description`, + }, startTime: Date.now(), liveOutput: '...', } as TrackedExecutingToolCall, @@ -1131,9 +1169,13 @@ describe('useGeminiStream', () => { }, tool: { name: 'save_memory', + displayName: 'save_memory', description: 'Saves memory', - getDescription: vi.fn(), + build: vi.fn(), } as any, + invocation: { + getDescription: () => `Mock description`, + } as unknown as AnyToolInvocation, }; // Capture the onComplete callback diff --git a/packages/cli/src/ui/hooks/useReactToolScheduler.ts b/packages/cli/src/ui/hooks/useReactToolScheduler.ts index 01993650..c6b802fc 100644 --- a/packages/cli/src/ui/hooks/useReactToolScheduler.ts +++ b/packages/cli/src/ui/hooks/useReactToolScheduler.ts @@ -17,7 +17,6 @@ import { OutputUpdateHandler, AllToolCallsCompleteHandler, ToolCallsUpdateHandler, - Tool, ToolCall, Status as CoreStatus, EditorType, @@ -216,23 +215,20 @@ export function mapToDisplay( const toolDisplays = toolCalls.map( (trackedCall): IndividualToolCallDisplay => { - let displayName = trackedCall.request.name; - let description = ''; + let displayName: string; + let description: string; let renderOutputAsMarkdown = false; - const currentToolInstance = - 'tool' in trackedCall && trackedCall.tool - ? (trackedCall as { tool: Tool }).tool - : undefined; - - if (currentToolInstance) { - displayName = currentToolInstance.displayName; - description = currentToolInstance.getDescription( - trackedCall.request.args, - ); - renderOutputAsMarkdown = currentToolInstance.isOutputMarkdown; - } else if ('request' in trackedCall && 'args' in trackedCall.request) { + if (trackedCall.status === 'error') { + displayName = + trackedCall.tool === undefined + ? trackedCall.request.name + : trackedCall.tool.displayName; description = JSON.stringify(trackedCall.request.args); + } else { + displayName = trackedCall.tool.displayName; + description = trackedCall.invocation.getDescription(); + renderOutputAsMarkdown = trackedCall.tool.isOutputMarkdown; } const baseDisplayProperties: Omit< @@ -256,7 +252,6 @@ export function mapToDisplay( case 'error': return { ...baseDisplayProperties, - name: currentToolInstance?.displayName ?? trackedCall.request.name, status: mapCoreStatusToDisplayStatus(trackedCall.status), resultDisplay: trackedCall.response.resultDisplay, confirmationDetails: undefined, diff --git a/packages/cli/src/ui/hooks/useToolScheduler.test.ts b/packages/cli/src/ui/hooks/useToolScheduler.test.ts index 5395d18a..ee5251d3 100644 --- a/packages/cli/src/ui/hooks/useToolScheduler.test.ts +++ b/packages/cli/src/ui/hooks/useToolScheduler.test.ts @@ -15,7 +15,6 @@ import { PartUnion, FunctionResponse } from '@google/genai'; import { Config, ToolCallRequestInfo, - Tool, ToolRegistry, ToolResult, ToolCallConfirmationDetails, @@ -25,6 +24,9 @@ import { Status as ToolCallStatusType, ApprovalMode, Icon, + BaseTool, + AnyDeclarativeTool, + AnyToolInvocation, } from '@google/gemini-cli-core'; import { HistoryItemWithoutId, @@ -53,46 +55,55 @@ const mockConfig = { getDebugMode: () => false, }; -const mockTool: Tool = { - name: 'mockTool', - displayName: 'Mock Tool', - description: 'A mock tool for testing', - icon: Icon.Hammer, - toolLocations: vi.fn(), - isOutputMarkdown: false, - canUpdateOutput: false, - schema: {}, - validateToolParams: vi.fn(), - execute: vi.fn(), - shouldConfirmExecute: vi.fn(), - getDescription: vi.fn((args) => `Description for ${JSON.stringify(args)}`), -}; +class MockTool extends BaseTool { + constructor( + name: string, + displayName: string, + canUpdateOutput = false, + shouldConfirm = false, + isOutputMarkdown = false, + ) { + super( + name, + displayName, + 'A mock tool for testing', + Icon.Hammer, + {}, + isOutputMarkdown, + canUpdateOutput, + ); + if (shouldConfirm) { + this.shouldConfirmExecute = vi.fn( + async (): Promise => ({ + type: 'edit', + title: 'Mock Tool Requires Confirmation', + onConfirm: mockOnUserConfirmForToolConfirmation, + fileName: 'mockToolRequiresConfirmation.ts', + fileDiff: 'Mock tool requires confirmation', + originalContent: 'Original content', + newContent: 'New content', + }), + ); + } + } -const mockToolWithLiveOutput: Tool = { - ...mockTool, - name: 'mockToolWithLiveOutput', - displayName: 'Mock Tool With Live Output', - canUpdateOutput: true, -}; + execute = vi.fn(); + shouldConfirmExecute = vi.fn(); +} +const mockTool = new MockTool('mockTool', 'Mock Tool'); +const mockToolWithLiveOutput = new MockTool( + 'mockToolWithLiveOutput', + 'Mock Tool With Live Output', + true, +); let mockOnUserConfirmForToolConfirmation: Mock; - -const mockToolRequiresConfirmation: Tool = { - ...mockTool, - name: 'mockToolRequiresConfirmation', - displayName: 'Mock Tool Requires Confirmation', - shouldConfirmExecute: vi.fn( - async (): Promise => ({ - type: 'edit', - title: 'Mock Tool Requires Confirmation', - onConfirm: mockOnUserConfirmForToolConfirmation, - fileName: 'mockToolRequiresConfirmation.ts', - fileDiff: 'Mock tool requires confirmation', - originalContent: 'Original content', - newContent: 'New content', - }), - ), -}; +const mockToolRequiresConfirmation = new MockTool( + 'mockToolRequiresConfirmation', + 'Mock Tool Requires Confirmation', + false, + true, +); describe('useReactToolScheduler in YOLO Mode', () => { let onComplete: Mock; @@ -646,28 +657,21 @@ describe('useReactToolScheduler', () => { }); it('should schedule and execute multiple tool calls', async () => { - const tool1 = { - ...mockTool, - name: 'tool1', - displayName: 'Tool 1', - execute: vi.fn().mockResolvedValue({ - llmContent: 'Output 1', - returnDisplay: 'Display 1', - summary: 'Summary 1', - } as ToolResult), - shouldConfirmExecute: vi.fn().mockResolvedValue(null), - }; - const tool2 = { - ...mockTool, - name: 'tool2', - displayName: 'Tool 2', - execute: vi.fn().mockResolvedValue({ - llmContent: 'Output 2', - returnDisplay: 'Display 2', - summary: 'Summary 2', - } as ToolResult), - shouldConfirmExecute: vi.fn().mockResolvedValue(null), - }; + const tool1 = new MockTool('tool1', 'Tool 1'); + tool1.execute.mockResolvedValue({ + llmContent: 'Output 1', + returnDisplay: 'Display 1', + summary: 'Summary 1', + } as ToolResult); + tool1.shouldConfirmExecute.mockResolvedValue(null); + + const tool2 = new MockTool('tool2', 'Tool 2'); + tool2.execute.mockResolvedValue({ + llmContent: 'Output 2', + returnDisplay: 'Display 2', + summary: 'Summary 2', + } as ToolResult); + tool2.shouldConfirmExecute.mockResolvedValue(null); mockToolRegistry.getTool.mockImplementation((name) => { if (name === 'tool1') return tool1; @@ -805,20 +809,7 @@ describe('mapToDisplay', () => { args: { foo: 'bar' }, }; - const baseTool: Tool = { - name: 'testTool', - displayName: 'Test Tool Display', - description: 'Test Description', - isOutputMarkdown: false, - canUpdateOutput: false, - schema: {}, - icon: Icon.Hammer, - toolLocations: vi.fn(), - validateToolParams: vi.fn(), - execute: vi.fn(), - shouldConfirmExecute: vi.fn(), - getDescription: vi.fn((args) => `Desc: ${JSON.stringify(args)}`), - }; + const baseTool = new MockTool('testTool', 'Test Tool Display'); const baseResponse: ToolCallResponseInfo = { callId: 'testCallId', @@ -840,13 +831,15 @@ describe('mapToDisplay', () => { // This helps ensure that tool and confirmationDetails are only accessed when they are expected to exist. type MapToDisplayExtraProps = | { - tool?: Tool; + tool?: AnyDeclarativeTool; + invocation?: AnyToolInvocation; liveOutput?: string; response?: ToolCallResponseInfo; confirmationDetails?: ToolCallConfirmationDetails; } | { - tool: Tool; + tool: AnyDeclarativeTool; + invocation?: AnyToolInvocation; response?: ToolCallResponseInfo; confirmationDetails?: ToolCallConfirmationDetails; } @@ -857,10 +850,12 @@ describe('mapToDisplay', () => { } | { confirmationDetails: ToolCallConfirmationDetails; - tool?: Tool; + tool?: AnyDeclarativeTool; + invocation?: AnyToolInvocation; response?: ToolCallResponseInfo; }; + const baseInvocation = baseTool.build(baseRequest.args); const testCases: Array<{ name: string; status: ToolCallStatusType; @@ -873,7 +868,7 @@ describe('mapToDisplay', () => { { name: 'validating', status: 'validating', - extraProps: { tool: baseTool }, + extraProps: { tool: baseTool, invocation: baseInvocation }, expectedStatus: ToolCallStatus.Executing, expectedName: baseTool.displayName, expectedDescription: baseTool.getDescription(baseRequest.args), @@ -883,6 +878,7 @@ describe('mapToDisplay', () => { status: 'awaiting_approval', extraProps: { tool: baseTool, + invocation: baseInvocation, confirmationDetails: { onConfirm: vi.fn(), type: 'edit', @@ -903,7 +899,7 @@ describe('mapToDisplay', () => { { name: 'scheduled', status: 'scheduled', - extraProps: { tool: baseTool }, + extraProps: { tool: baseTool, invocation: baseInvocation }, expectedStatus: ToolCallStatus.Pending, expectedName: baseTool.displayName, expectedDescription: baseTool.getDescription(baseRequest.args), @@ -911,7 +907,7 @@ describe('mapToDisplay', () => { { name: 'executing no live output', status: 'executing', - extraProps: { tool: baseTool }, + extraProps: { tool: baseTool, invocation: baseInvocation }, expectedStatus: ToolCallStatus.Executing, expectedName: baseTool.displayName, expectedDescription: baseTool.getDescription(baseRequest.args), @@ -919,7 +915,11 @@ describe('mapToDisplay', () => { { name: 'executing with live output', status: 'executing', - extraProps: { tool: baseTool, liveOutput: 'Live test output' }, + extraProps: { + tool: baseTool, + invocation: baseInvocation, + liveOutput: 'Live test output', + }, expectedStatus: ToolCallStatus.Executing, expectedResultDisplay: 'Live test output', expectedName: baseTool.displayName, @@ -928,7 +928,11 @@ describe('mapToDisplay', () => { { name: 'success', status: 'success', - extraProps: { tool: baseTool, response: baseResponse }, + extraProps: { + tool: baseTool, + invocation: baseInvocation, + response: baseResponse, + }, expectedStatus: ToolCallStatus.Success, expectedResultDisplay: baseResponse.resultDisplay as any, expectedName: baseTool.displayName, @@ -970,6 +974,7 @@ describe('mapToDisplay', () => { status: 'cancelled', extraProps: { tool: baseTool, + invocation: baseInvocation, response: { ...baseResponse, resultDisplay: 'Cancelled display', @@ -1030,12 +1035,21 @@ describe('mapToDisplay', () => { request: { ...baseRequest, callId: 'call1' }, status: 'success', tool: baseTool, + invocation: baseTool.build(baseRequest.args), response: { ...baseResponse, callId: 'call1' }, } as ToolCall; + const toolForCall2 = new MockTool( + baseTool.name, + baseTool.displayName, + false, + false, + true, + ); const toolCall2: ToolCall = { request: { ...baseRequest, callId: 'call2' }, status: 'executing', - tool: { ...baseTool, isOutputMarkdown: true }, + tool: toolForCall2, + invocation: toolForCall2.build(baseRequest.args), liveOutput: 'markdown output', } as ToolCall; diff --git a/packages/core/src/core/client.ts b/packages/core/src/core/client.ts index 3b6b57f9..f8b9a7de 100644 --- a/packages/core/src/core/client.ts +++ b/packages/core/src/core/client.ts @@ -24,7 +24,6 @@ import { import { Config } from '../config/config.js'; import { UserTierId } from '../code_assist/types.js'; import { getCoreSystemPrompt, getCompressionPrompt } from './prompts.js'; -import { ReadManyFilesTool } from '../tools/read-many-files.js'; import { getResponseText } from '../utils/generateContentResponseUtilities.js'; import { checkNextSpeaker } from '../utils/nextSpeakerChecker.js'; import { reportError } from '../utils/errorReporting.js'; @@ -252,18 +251,15 @@ export class GeminiClient { // Add full file context if the flag is set if (this.config.getFullContext()) { try { - const readManyFilesTool = toolRegistry.getTool( - 'read_many_files', - ) as ReadManyFilesTool; + const readManyFilesTool = toolRegistry.getTool('read_many_files'); if (readManyFilesTool) { + const invocation = readManyFilesTool.build({ + paths: ['**/*'], // Read everything recursively + useDefaultExcludes: true, // Use default excludes + }); + // Read all files in the target directory - const result = await readManyFilesTool.execute( - { - paths: ['**/*'], // Read everything recursively - useDefaultExcludes: true, // Use default excludes - }, - AbortSignal.timeout(30000), - ); + const result = await invocation.execute(AbortSignal.timeout(30000)); if (result.llmContent) { initialParts.push({ text: `\n--- Full File Context ---\n${result.llmContent}`, diff --git a/packages/core/src/core/coreToolScheduler.test.ts b/packages/core/src/core/coreToolScheduler.test.ts index 4d786d00..a65443f8 100644 --- a/packages/core/src/core/coreToolScheduler.test.ts +++ b/packages/core/src/core/coreToolScheduler.test.ts @@ -24,44 +24,15 @@ import { } from '../index.js'; import { Part, PartListUnion } from '@google/genai'; -import { ModifiableTool, ModifyContext } from '../tools/modifiable-tool.js'; - -class MockTool extends BaseTool, ToolResult> { - shouldConfirm = false; - executeFn = vi.fn(); - - constructor(name = 'mockTool') { - super(name, name, 'A mock tool', Icon.Hammer, {}); - } - - async shouldConfirmExecute( - _params: Record, - _abortSignal: AbortSignal, - ): Promise { - if (this.shouldConfirm) { - return { - type: 'exec', - title: 'Confirm Mock Tool', - command: 'do_thing', - rootCommand: 'do_thing', - onConfirm: async () => {}, - }; - } - return false; - } - - async execute( - params: Record, - _abortSignal: AbortSignal, - ): Promise { - this.executeFn(params); - return { llmContent: 'Tool executed', returnDisplay: 'Tool executed' }; - } -} +import { + ModifiableDeclarativeTool, + ModifyContext, +} from '../tools/modifiable-tool.js'; +import { MockTool } from '../test-utils/tools.js'; class MockModifiableTool extends MockTool - implements ModifiableTool> + implements ModifiableDeclarativeTool> { constructor(name = 'mockModifiableTool') { super(name); @@ -83,10 +54,7 @@ class MockModifiableTool }; } - async shouldConfirmExecute( - _params: Record, - _abortSignal: AbortSignal, - ): Promise { + async shouldConfirmExecute(): Promise { if (this.shouldConfirm) { return { type: 'edit', @@ -107,14 +75,15 @@ describe('CoreToolScheduler', () => { it('should cancel a tool call if the signal is aborted before confirmation', async () => { const mockTool = new MockTool(); mockTool.shouldConfirm = true; + const declarativeTool = mockTool; const toolRegistry = { - getTool: () => mockTool, + getTool: () => declarativeTool, getFunctionDeclarations: () => [], tools: new Map(), discovery: {} as any, registerTool: () => {}, - getToolByName: () => mockTool, - getToolByDisplayName: () => mockTool, + getToolByName: () => declarativeTool, + getToolByDisplayName: () => declarativeTool, getTools: () => [], discoverTools: async () => {}, getAllTools: () => [], @@ -177,14 +146,15 @@ describe('CoreToolScheduler', () => { describe('CoreToolScheduler with payload', () => { it('should update args and diff and execute tool when payload is provided', async () => { const mockTool = new MockModifiableTool(); + const declarativeTool = mockTool; const toolRegistry = { - getTool: () => mockTool, + getTool: () => declarativeTool, getFunctionDeclarations: () => [], tools: new Map(), discovery: {} as any, registerTool: () => {}, - getToolByName: () => mockTool, - getToolByDisplayName: () => mockTool, + getToolByName: () => declarativeTool, + getToolByDisplayName: () => declarativeTool, getTools: () => [], discoverTools: async () => {}, getAllTools: () => [], @@ -221,10 +191,7 @@ describe('CoreToolScheduler with payload', () => { await scheduler.schedule([request], abortController.signal); - const confirmationDetails = await mockTool.shouldConfirmExecute( - {}, - abortController.signal, - ); + const confirmationDetails = await mockTool.shouldConfirmExecute(); if (confirmationDetails) { const payload: ToolConfirmationPayload = { newContent: 'final version' }; @@ -456,14 +423,15 @@ describe('CoreToolScheduler edit cancellation', () => { } const mockEditTool = new MockEditTool(); + const declarativeTool = mockEditTool; const toolRegistry = { - getTool: () => mockEditTool, + getTool: () => declarativeTool, getFunctionDeclarations: () => [], tools: new Map(), discovery: {} as any, registerTool: () => {}, - getToolByName: () => mockEditTool, - getToolByDisplayName: () => mockEditTool, + getToolByName: () => declarativeTool, + getToolByDisplayName: () => declarativeTool, getTools: () => [], discoverTools: async () => {}, getAllTools: () => [], @@ -541,18 +509,23 @@ describe('CoreToolScheduler YOLO mode', () => { it('should execute tool requiring confirmation directly without waiting', async () => { // Arrange const mockTool = new MockTool(); + mockTool.executeFn.mockReturnValue({ + llmContent: 'Tool executed', + returnDisplay: 'Tool executed', + }); // This tool would normally require confirmation. mockTool.shouldConfirm = true; + const declarativeTool = mockTool; const toolRegistry = { - getTool: () => mockTool, - getToolByName: () => mockTool, + getTool: () => declarativeTool, + getToolByName: () => declarativeTool, // Other properties are not needed for this test but are included for type consistency. getFunctionDeclarations: () => [], tools: new Map(), discovery: {} as any, registerTool: () => {}, - getToolByDisplayName: () => mockTool, + getToolByDisplayName: () => declarativeTool, getTools: () => [], discoverTools: async () => {}, getAllTools: () => [], diff --git a/packages/core/src/core/coreToolScheduler.ts b/packages/core/src/core/coreToolScheduler.ts index 9b999b6b..6f098ae3 100644 --- a/packages/core/src/core/coreToolScheduler.ts +++ b/packages/core/src/core/coreToolScheduler.ts @@ -8,7 +8,6 @@ import { ToolCallRequestInfo, ToolCallResponseInfo, ToolConfirmationOutcome, - Tool, ToolCallConfirmationDetails, ToolResult, ToolResultDisplay, @@ -20,11 +19,13 @@ import { ToolCallEvent, ToolConfirmationPayload, ToolErrorType, + AnyDeclarativeTool, + AnyToolInvocation, } from '../index.js'; import { Part, PartListUnion } from '@google/genai'; import { getResponseTextFromParts } from '../utils/generateContentResponseUtilities.js'; import { - isModifiableTool, + isModifiableDeclarativeTool, ModifyContext, modifyWithEditor, } from '../tools/modifiable-tool.js'; @@ -33,7 +34,8 @@ import * as Diff from 'diff'; export type ValidatingToolCall = { status: 'validating'; request: ToolCallRequestInfo; - tool: Tool; + tool: AnyDeclarativeTool; + invocation: AnyToolInvocation; startTime?: number; outcome?: ToolConfirmationOutcome; }; @@ -41,7 +43,8 @@ export type ValidatingToolCall = { export type ScheduledToolCall = { status: 'scheduled'; request: ToolCallRequestInfo; - tool: Tool; + tool: AnyDeclarativeTool; + invocation: AnyToolInvocation; startTime?: number; outcome?: ToolConfirmationOutcome; }; @@ -50,6 +53,7 @@ export type ErroredToolCall = { status: 'error'; request: ToolCallRequestInfo; response: ToolCallResponseInfo; + tool?: AnyDeclarativeTool; durationMs?: number; outcome?: ToolConfirmationOutcome; }; @@ -57,8 +61,9 @@ export type ErroredToolCall = { export type SuccessfulToolCall = { status: 'success'; request: ToolCallRequestInfo; - tool: Tool; + tool: AnyDeclarativeTool; response: ToolCallResponseInfo; + invocation: AnyToolInvocation; durationMs?: number; outcome?: ToolConfirmationOutcome; }; @@ -66,7 +71,8 @@ export type SuccessfulToolCall = { export type ExecutingToolCall = { status: 'executing'; request: ToolCallRequestInfo; - tool: Tool; + tool: AnyDeclarativeTool; + invocation: AnyToolInvocation; liveOutput?: string; startTime?: number; outcome?: ToolConfirmationOutcome; @@ -76,7 +82,8 @@ export type CancelledToolCall = { status: 'cancelled'; request: ToolCallRequestInfo; response: ToolCallResponseInfo; - tool: Tool; + tool: AnyDeclarativeTool; + invocation: AnyToolInvocation; durationMs?: number; outcome?: ToolConfirmationOutcome; }; @@ -84,7 +91,8 @@ export type CancelledToolCall = { export type WaitingToolCall = { status: 'awaiting_approval'; request: ToolCallRequestInfo; - tool: Tool; + tool: AnyDeclarativeTool; + invocation: AnyToolInvocation; confirmationDetails: ToolCallConfirmationDetails; startTime?: number; outcome?: ToolConfirmationOutcome; @@ -289,6 +297,7 @@ export class CoreToolScheduler { // currentCall is a non-terminal state here and should have startTime and tool. const existingStartTime = currentCall.startTime; const toolInstance = currentCall.tool; + const invocation = currentCall.invocation; const outcome = currentCall.outcome; @@ -300,6 +309,7 @@ export class CoreToolScheduler { return { request: currentCall.request, tool: toolInstance, + invocation, status: 'success', response: auxiliaryData as ToolCallResponseInfo, durationMs, @@ -313,6 +323,7 @@ export class CoreToolScheduler { return { request: currentCall.request, status: 'error', + tool: toolInstance, response: auxiliaryData as ToolCallResponseInfo, durationMs, outcome, @@ -326,6 +337,7 @@ export class CoreToolScheduler { confirmationDetails: auxiliaryData as ToolCallConfirmationDetails, startTime: existingStartTime, outcome, + invocation, } as WaitingToolCall; case 'scheduled': return { @@ -334,6 +346,7 @@ export class CoreToolScheduler { status: 'scheduled', startTime: existingStartTime, outcome, + invocation, } as ScheduledToolCall; case 'cancelled': { const durationMs = existingStartTime @@ -358,6 +371,7 @@ export class CoreToolScheduler { return { request: currentCall.request, tool: toolInstance, + invocation, status: 'cancelled', response: { callId: currentCall.request.callId, @@ -385,6 +399,7 @@ export class CoreToolScheduler { status: 'validating', startTime: existingStartTime, outcome, + invocation, } as ValidatingToolCall; case 'executing': return { @@ -393,6 +408,7 @@ export class CoreToolScheduler { status: 'executing', startTime: existingStartTime, outcome, + invocation, } as ExecutingToolCall; default: { const exhaustiveCheck: never = newStatus; @@ -406,10 +422,34 @@ export class CoreToolScheduler { private setArgsInternal(targetCallId: string, args: unknown): void { this.toolCalls = this.toolCalls.map((call) => { - if (call.request.callId !== targetCallId) return call; + // We should never be asked to set args on an ErroredToolCall, but + // we guard for the case anyways. + if (call.request.callId !== targetCallId || call.status === 'error') { + return call; + } + + const invocationOrError = this.buildInvocation( + call.tool, + args as Record, + ); + if (invocationOrError instanceof Error) { + const response = createErrorResponse( + call.request, + invocationOrError, + ToolErrorType.INVALID_TOOL_PARAMS, + ); + return { + request: { ...call.request, args: args as Record }, + status: 'error', + tool: call.tool, + response, + } as ErroredToolCall; + } + return { ...call, request: { ...call.request, args: args as Record }, + invocation: invocationOrError, }; }); } @@ -421,6 +461,20 @@ export class CoreToolScheduler { ); } + private buildInvocation( + tool: AnyDeclarativeTool, + args: object, + ): AnyToolInvocation | Error { + try { + return tool.build(args); + } catch (e) { + if (e instanceof Error) { + return e; + } + return new Error(String(e)); + } + } + async schedule( request: ToolCallRequestInfo | ToolCallRequestInfo[], signal: AbortSignal, @@ -448,10 +502,30 @@ export class CoreToolScheduler { durationMs: 0, }; } + + const invocationOrError = this.buildInvocation( + toolInstance, + reqInfo.args, + ); + if (invocationOrError instanceof Error) { + return { + status: 'error', + request: reqInfo, + tool: toolInstance, + response: createErrorResponse( + reqInfo, + invocationOrError, + ToolErrorType.INVALID_TOOL_PARAMS, + ), + durationMs: 0, + }; + } + return { status: 'validating', request: reqInfo, tool: toolInstance, + invocation: invocationOrError, startTime: Date.now(), }; }, @@ -465,7 +539,8 @@ export class CoreToolScheduler { continue; } - const { request: reqInfo, tool: toolInstance } = toolCall; + const { request: reqInfo, invocation } = toolCall; + try { if (this.config.getApprovalMode() === ApprovalMode.YOLO) { this.setToolCallOutcome( @@ -474,10 +549,8 @@ export class CoreToolScheduler { ); this.setStatusInternal(reqInfo.callId, 'scheduled'); } else { - const confirmationDetails = await toolInstance.shouldConfirmExecute( - reqInfo.args, - signal, - ); + const confirmationDetails = + await invocation.shouldConfirmExecute(signal); if (confirmationDetails) { // Allow IDE to resolve confirmation @@ -573,7 +646,7 @@ export class CoreToolScheduler { ); } else if (outcome === ToolConfirmationOutcome.ModifyWithEditor) { const waitingToolCall = toolCall as WaitingToolCall; - if (isModifiableTool(waitingToolCall.tool)) { + if (isModifiableDeclarativeTool(waitingToolCall.tool)) { const modifyContext = waitingToolCall.tool.getModifyContext(signal); const editorType = this.getPreferredEditor(); if (!editorType) { @@ -628,7 +701,7 @@ export class CoreToolScheduler { ): Promise { if ( toolCall.confirmationDetails.type !== 'edit' || - !isModifiableTool(toolCall.tool) + !isModifiableDeclarativeTool(toolCall.tool) ) { return; } @@ -677,6 +750,7 @@ export class CoreToolScheduler { const scheduledCall = toolCall; const { callId, name: toolName } = scheduledCall.request; + const invocation = scheduledCall.invocation; this.setStatusInternal(callId, 'executing'); const liveOutputCallback = @@ -694,8 +768,8 @@ export class CoreToolScheduler { } : undefined; - scheduledCall.tool - .execute(scheduledCall.request.args, signal, liveOutputCallback) + invocation + .execute(signal, liveOutputCallback) .then(async (toolResult: ToolResult) => { if (signal.aborted) { this.setStatusInternal( diff --git a/packages/core/src/core/nonInteractiveToolExecutor.test.ts b/packages/core/src/core/nonInteractiveToolExecutor.test.ts index 1bbb9209..b0ed7107 100644 --- a/packages/core/src/core/nonInteractiveToolExecutor.test.ts +++ b/packages/core/src/core/nonInteractiveToolExecutor.test.ts @@ -10,12 +10,10 @@ import { ToolRegistry, ToolCallRequestInfo, ToolResult, - Tool, - ToolCallConfirmationDetails, Config, - Icon, } from '../index.js'; -import { Part, Type } from '@google/genai'; +import { Part } from '@google/genai'; +import { MockTool } from '../test-utils/tools.js'; const mockConfig = { getSessionId: () => 'test-session-id', @@ -25,36 +23,11 @@ const mockConfig = { describe('executeToolCall', () => { let mockToolRegistry: ToolRegistry; - let mockTool: Tool; + let mockTool: MockTool; let abortController: AbortController; beforeEach(() => { - mockTool = { - name: 'testTool', - displayName: 'Test Tool', - description: 'A tool for testing', - icon: Icon.Hammer, - schema: { - name: 'testTool', - description: 'A tool for testing', - parameters: { - type: Type.OBJECT, - properties: { - param1: { type: Type.STRING }, - }, - required: ['param1'], - }, - }, - execute: vi.fn(), - validateToolParams: vi.fn(() => null), - shouldConfirmExecute: vi.fn(() => - Promise.resolve(false as false | ToolCallConfirmationDetails), - ), - isOutputMarkdown: false, - canUpdateOutput: false, - getDescription: vi.fn(), - toolLocations: vi.fn(() => []), - }; + mockTool = new MockTool(); mockToolRegistry = { getTool: vi.fn(), @@ -77,7 +50,7 @@ describe('executeToolCall', () => { returnDisplay: 'Success!', }; vi.mocked(mockToolRegistry.getTool).mockReturnValue(mockTool); - vi.mocked(mockTool.execute).mockResolvedValue(toolResult); + vi.spyOn(mockTool, 'buildAndExecute').mockResolvedValue(toolResult); const response = await executeToolCall( mockConfig, @@ -87,7 +60,7 @@ describe('executeToolCall', () => { ); expect(mockToolRegistry.getTool).toHaveBeenCalledWith('testTool'); - expect(mockTool.execute).toHaveBeenCalledWith( + expect(mockTool.buildAndExecute).toHaveBeenCalledWith( request.args, abortController.signal, ); @@ -149,7 +122,7 @@ describe('executeToolCall', () => { }; const executionError = new Error('Tool execution failed'); vi.mocked(mockToolRegistry.getTool).mockReturnValue(mockTool); - vi.mocked(mockTool.execute).mockRejectedValue(executionError); + vi.spyOn(mockTool, 'buildAndExecute').mockRejectedValue(executionError); const response = await executeToolCall( mockConfig, @@ -183,25 +156,27 @@ describe('executeToolCall', () => { const cancellationError = new Error('Operation cancelled'); vi.mocked(mockToolRegistry.getTool).mockReturnValue(mockTool); - vi.mocked(mockTool.execute).mockImplementation(async (_args, signal) => { - if (signal?.aborted) { - return Promise.reject(cancellationError); - } - return new Promise((_resolve, reject) => { - signal?.addEventListener('abort', () => { - reject(cancellationError); + vi.spyOn(mockTool, 'buildAndExecute').mockImplementation( + async (_args, signal) => { + if (signal?.aborted) { + return Promise.reject(cancellationError); + } + return new Promise((_resolve, reject) => { + signal?.addEventListener('abort', () => { + reject(cancellationError); + }); + // Simulate work that might happen if not aborted immediately + const timeoutId = setTimeout( + () => + reject( + new Error('Should have been cancelled if not aborted prior'), + ), + 100, + ); + signal?.addEventListener('abort', () => clearTimeout(timeoutId)); }); - // Simulate work that might happen if not aborted immediately - const timeoutId = setTimeout( - () => - reject( - new Error('Should have been cancelled if not aborted prior'), - ), - 100, - ); - signal?.addEventListener('abort', () => clearTimeout(timeoutId)); - }); - }); + }, + ); abortController.abort(); // Abort before calling const response = await executeToolCall( @@ -232,7 +207,7 @@ describe('executeToolCall', () => { returnDisplay: 'Image processed', }; vi.mocked(mockToolRegistry.getTool).mockReturnValue(mockTool); - vi.mocked(mockTool.execute).mockResolvedValue(toolResult); + vi.spyOn(mockTool, 'buildAndExecute').mockResolvedValue(toolResult); const response = await executeToolCall( mockConfig, diff --git a/packages/core/src/core/nonInteractiveToolExecutor.ts b/packages/core/src/core/nonInteractiveToolExecutor.ts index ed235cd3..43061f83 100644 --- a/packages/core/src/core/nonInteractiveToolExecutor.ts +++ b/packages/core/src/core/nonInteractiveToolExecutor.ts @@ -65,7 +65,7 @@ export async function executeToolCall( try { // Directly execute without confirmation or live output handling const effectiveAbortSignal = abortSignal ?? new AbortController().signal; - const toolResult: ToolResult = await tool.execute( + const toolResult: ToolResult = await tool.buildAndExecute( toolCallRequest.args, effectiveAbortSignal, // No live output callback for non-interactive mode diff --git a/packages/core/src/telemetry/loggers.test.circular.ts b/packages/core/src/telemetry/loggers.test.circular.ts index 80444a0d..3cf85e46 100644 --- a/packages/core/src/telemetry/loggers.test.circular.ts +++ b/packages/core/src/telemetry/loggers.test.circular.ts @@ -14,7 +14,7 @@ import { ToolCallEvent } from './types.js'; import { Config } from '../config/config.js'; import { CompletedToolCall } from '../core/coreToolScheduler.js'; import { ToolCallRequestInfo, ToolCallResponseInfo } from '../core/turn.js'; -import { Tool } from '../tools/tools.js'; +import { MockTool } from '../test-utils/tools.js'; describe('Circular Reference Handling', () => { it('should handle circular references in tool function arguments', () => { @@ -56,11 +56,13 @@ describe('Circular Reference Handling', () => { errorType: undefined, }; + const tool = new MockTool('mock-tool'); const mockCompletedToolCall: CompletedToolCall = { status: 'success', request: mockRequest, response: mockResponse, - tool: {} as Tool, + tool, + invocation: tool.build({}), durationMs: 100, }; @@ -104,11 +106,13 @@ describe('Circular Reference Handling', () => { errorType: undefined, }; + const tool = new MockTool('mock-tool'); const mockCompletedToolCall: CompletedToolCall = { status: 'success', request: mockRequest, response: mockResponse, - tool: {} as Tool, + tool, + invocation: tool.build({}), durationMs: 100, }; diff --git a/packages/core/src/telemetry/loggers.test.ts b/packages/core/src/telemetry/loggers.test.ts index 3d8116cc..14de83a9 100644 --- a/packages/core/src/telemetry/loggers.test.ts +++ b/packages/core/src/telemetry/loggers.test.ts @@ -5,6 +5,7 @@ */ import { + AnyToolInvocation, AuthType, CompletedToolCall, ContentGeneratorConfig, @@ -432,6 +433,7 @@ describe('loggers', () => { }); it('should log a tool call with all fields', () => { + const tool = new EditTool(mockConfig); const call: CompletedToolCall = { status: 'success', request: { @@ -451,7 +453,8 @@ describe('loggers', () => { error: undefined, errorType: undefined, }, - tool: new EditTool(mockConfig), + tool, + invocation: {} as AnyToolInvocation, durationMs: 100, outcome: ToolConfirmationOutcome.ProceedOnce, }; @@ -581,6 +584,7 @@ describe('loggers', () => { }, outcome: ToolConfirmationOutcome.ModifyWithEditor, tool: new EditTool(mockConfig), + invocation: {} as AnyToolInvocation, durationMs: 100, }; const event = new ToolCallEvent(call); @@ -645,6 +649,7 @@ describe('loggers', () => { errorType: undefined, }, tool: new EditTool(mockConfig), + invocation: {} as AnyToolInvocation, durationMs: 100, }; const event = new ToolCallEvent(call); diff --git a/packages/core/src/telemetry/uiTelemetry.test.ts b/packages/core/src/telemetry/uiTelemetry.test.ts index 221804d2..ac9727f1 100644 --- a/packages/core/src/telemetry/uiTelemetry.test.ts +++ b/packages/core/src/telemetry/uiTelemetry.test.ts @@ -23,7 +23,8 @@ import { SuccessfulToolCall, } from '../core/coreToolScheduler.js'; import { ToolErrorType } from '../tools/tool-error.js'; -import { Tool, ToolConfirmationOutcome } from '../tools/tools.js'; +import { ToolConfirmationOutcome } from '../tools/tools.js'; +import { MockTool } from '../test-utils/tools.js'; const createFakeCompletedToolCall = ( name: string, @@ -39,12 +40,14 @@ const createFakeCompletedToolCall = ( isClientInitiated: false, prompt_id: 'prompt-id-1', }; + const tool = new MockTool(name); if (success) { return { status: 'success', request, - tool: { name } as Tool, // Mock tool + tool, + invocation: tool.build({}), response: { callId: request.callId, responseParts: { @@ -65,6 +68,7 @@ const createFakeCompletedToolCall = ( return { status: 'error', request, + tool, response: { callId: request.callId, responseParts: { diff --git a/packages/core/src/test-utils/tools.ts b/packages/core/src/test-utils/tools.ts new file mode 100644 index 00000000..b168db9c --- /dev/null +++ b/packages/core/src/test-utils/tools.ts @@ -0,0 +1,63 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { vi } from 'vitest'; +import { + BaseTool, + Icon, + ToolCallConfirmationDetails, + ToolResult, +} from '../tools/tools.js'; +import { Schema, Type } from '@google/genai'; + +/** + * A highly configurable mock tool for testing purposes. + */ +export class MockTool extends BaseTool<{ [key: string]: unknown }, ToolResult> { + executeFn = vi.fn(); + shouldConfirm = false; + + constructor( + name = 'mock-tool', + displayName?: string, + description = 'A mock tool for testing.', + params: Schema = { + type: Type.OBJECT, + properties: { param: { type: Type.STRING } }, + }, + ) { + super(name, displayName ?? name, description, Icon.Hammer, params); + } + + async execute( + params: { [key: string]: unknown }, + _abortSignal: AbortSignal, + ): Promise { + const result = this.executeFn(params); + return ( + result ?? { + llmContent: `Tool ${this.name} executed successfully.`, + returnDisplay: `Tool ${this.name} executed successfully.`, + } + ); + } + + async shouldConfirmExecute( + _params: { [key: string]: unknown }, + _abortSignal: AbortSignal, + ): Promise { + if (this.shouldConfirm) { + return { + type: 'exec' as const, + title: `Confirm ${this.displayName}`, + command: this.name, + rootCommand: this.name, + onConfirm: async () => {}, + }; + } + return false; + } +} diff --git a/packages/core/src/tools/edit.ts b/packages/core/src/tools/edit.ts index 0d129e42..853ad4c1 100644 --- a/packages/core/src/tools/edit.ts +++ b/packages/core/src/tools/edit.ts @@ -26,7 +26,7 @@ import { Config, ApprovalMode } from '../config/config.js'; import { ensureCorrectEdit } from '../utils/editCorrector.js'; import { DEFAULT_DIFF_OPTIONS } from './diffOptions.js'; import { ReadFileTool } from './read-file.js'; -import { ModifiableTool, ModifyContext } from './modifiable-tool.js'; +import { ModifiableDeclarativeTool, ModifyContext } from './modifiable-tool.js'; /** * Parameters for the Edit tool @@ -72,7 +72,7 @@ interface CalculatedEdit { */ export class EditTool extends BaseTool - implements ModifiableTool + implements ModifiableDeclarativeTool { static readonly Name = 'replace'; diff --git a/packages/core/src/tools/memoryTool.ts b/packages/core/src/tools/memoryTool.ts index 847ea5cf..f3bf315b 100644 --- a/packages/core/src/tools/memoryTool.ts +++ b/packages/core/src/tools/memoryTool.ts @@ -18,7 +18,7 @@ import { homedir } from 'os'; import * as Diff from 'diff'; import { DEFAULT_DIFF_OPTIONS } from './diffOptions.js'; import { tildeifyPath } from '../utils/paths.js'; -import { ModifiableTool, ModifyContext } from './modifiable-tool.js'; +import { ModifiableDeclarativeTool, ModifyContext } from './modifiable-tool.js'; const memoryToolSchemaData: FunctionDeclaration = { name: 'save_memory', @@ -112,7 +112,7 @@ function ensureNewlineSeparation(currentContent: string): string { export class MemoryTool extends BaseTool - implements ModifiableTool + implements ModifiableDeclarativeTool { private static readonly allowlist: Set = new Set(); diff --git a/packages/core/src/tools/modifiable-tool.test.ts b/packages/core/src/tools/modifiable-tool.test.ts index eb7e8dbf..dc68640a 100644 --- a/packages/core/src/tools/modifiable-tool.test.ts +++ b/packages/core/src/tools/modifiable-tool.test.ts @@ -8,8 +8,8 @@ import { vi, describe, it, expect, beforeEach, afterEach } from 'vitest'; import { modifyWithEditor, ModifyContext, - ModifiableTool, - isModifiableTool, + ModifiableDeclarativeTool, + isModifiableDeclarativeTool, } from './modifiable-tool.js'; import { EditorType } from '../utils/editor.js'; import fs from 'fs'; @@ -338,16 +338,16 @@ describe('isModifiableTool', () => { const mockTool = { name: 'test-tool', getModifyContext: vi.fn(), - } as unknown as ModifiableTool; + } as unknown as ModifiableDeclarativeTool; - expect(isModifiableTool(mockTool)).toBe(true); + expect(isModifiableDeclarativeTool(mockTool)).toBe(true); }); it('should return false for objects without getModifyContext method', () => { const mockTool = { name: 'test-tool', - } as unknown as ModifiableTool; + } as unknown as ModifiableDeclarativeTool; - expect(isModifiableTool(mockTool)).toBe(false); + expect(isModifiableDeclarativeTool(mockTool)).toBe(false); }); }); diff --git a/packages/core/src/tools/modifiable-tool.ts b/packages/core/src/tools/modifiable-tool.ts index 42de3eb6..25a2906b 100644 --- a/packages/core/src/tools/modifiable-tool.ts +++ b/packages/core/src/tools/modifiable-tool.ts @@ -11,13 +11,14 @@ import fs from 'fs'; import * as Diff from 'diff'; import { DEFAULT_DIFF_OPTIONS } from './diffOptions.js'; import { isNodeError } from '../utils/errors.js'; -import { Tool } from './tools.js'; +import { AnyDeclarativeTool, DeclarativeTool, ToolResult } from './tools.js'; /** - * A tool that supports a modify operation. + * A declarative tool that supports a modify operation. */ -export interface ModifiableTool extends Tool { - getModifyContext(abortSignal: AbortSignal): ModifyContext; +export interface ModifiableDeclarativeTool + extends DeclarativeTool { + getModifyContext(abortSignal: AbortSignal): ModifyContext; } export interface ModifyContext { @@ -39,9 +40,12 @@ export interface ModifyResult { updatedDiff: string; } -export function isModifiableTool( - tool: Tool, -): tool is ModifiableTool { +/** + * Type guard to check if a declarative tool is modifiable. + */ +export function isModifiableDeclarativeTool( + tool: AnyDeclarativeTool, +): tool is ModifiableDeclarativeTool { return 'getModifyContext' in tool; } diff --git a/packages/core/src/tools/read-file.test.ts b/packages/core/src/tools/read-file.test.ts index fa1e458c..bb9317fd 100644 --- a/packages/core/src/tools/read-file.test.ts +++ b/packages/core/src/tools/read-file.test.ts @@ -13,6 +13,7 @@ import fsp from 'fs/promises'; import { Config } from '../config/config.js'; import { FileDiscoveryService } from '../services/fileDiscoveryService.js'; import { createMockWorkspaceContext } from '../test-utils/mockWorkspaceContext.js'; +import { ToolInvocation, ToolResult } from './tools.js'; describe('ReadFileTool', () => { let tempRootDir: string; @@ -40,57 +41,62 @@ describe('ReadFileTool', () => { } }); - describe('validateToolParams', () => { - it('should return null for valid params (absolute path within root)', () => { + describe('build', () => { + it('should return an invocation for valid params (absolute path within root)', () => { const params: ReadFileToolParams = { absolute_path: path.join(tempRootDir, 'test.txt'), }; - expect(tool.validateToolParams(params)).toBeNull(); + const result = tool.build(params); + expect(result).not.toBeTypeOf('string'); + expect(typeof result).toBe('object'); + expect( + (result as ToolInvocation).params, + ).toEqual(params); }); - it('should return null for valid params with offset and limit', () => { + it('should return an invocation for valid params with offset and limit', () => { const params: ReadFileToolParams = { absolute_path: path.join(tempRootDir, 'test.txt'), offset: 0, limit: 10, }; - expect(tool.validateToolParams(params)).toBeNull(); + const result = tool.build(params); + expect(result).not.toBeTypeOf('string'); }); - it('should return error for relative path', () => { + it('should throw error for relative path', () => { const params: ReadFileToolParams = { absolute_path: 'test.txt' }; - expect(tool.validateToolParams(params)).toBe( + expect(() => tool.build(params)).toThrow( `File path must be absolute, but was relative: test.txt. You must provide an absolute path.`, ); }); - it('should return error for path outside root', () => { + it('should throw error for path outside root', () => { const outsidePath = path.resolve(os.tmpdir(), 'outside-root.txt'); const params: ReadFileToolParams = { absolute_path: outsidePath }; - const error = tool.validateToolParams(params); - expect(error).toContain( + expect(() => tool.build(params)).toThrow( 'File path must be within one of the workspace directories', ); }); - it('should return error for negative offset', () => { + it('should throw error for negative offset', () => { const params: ReadFileToolParams = { absolute_path: path.join(tempRootDir, 'test.txt'), offset: -1, limit: 10, }; - expect(tool.validateToolParams(params)).toBe( + expect(() => tool.build(params)).toThrow( 'Offset must be a non-negative number', ); }); - it('should return error for non-positive limit', () => { + it('should throw error for non-positive limit', () => { const paramsZero: ReadFileToolParams = { absolute_path: path.join(tempRootDir, 'test.txt'), offset: 0, limit: 0, }; - expect(tool.validateToolParams(paramsZero)).toBe( + expect(() => tool.build(paramsZero)).toThrow( 'Limit must be a positive number', ); const paramsNegative: ReadFileToolParams = { @@ -98,168 +104,182 @@ describe('ReadFileTool', () => { offset: 0, limit: -5, }; - expect(tool.validateToolParams(paramsNegative)).toBe( + expect(() => tool.build(paramsNegative)).toThrow( 'Limit must be a positive number', ); }); - it('should return error for schema validation failure (e.g. missing path)', () => { + it('should throw error for schema validation failure (e.g. missing path)', () => { const params = { offset: 0 } as unknown as ReadFileToolParams; - expect(tool.validateToolParams(params)).toBe( + expect(() => tool.build(params)).toThrow( `params must have required property 'absolute_path'`, ); }); }); - describe('getDescription', () => { - it('should return a shortened, relative path', () => { - const filePath = path.join(tempRootDir, 'sub', 'dir', 'file.txt'); - const params: ReadFileToolParams = { absolute_path: filePath }; - expect(tool.getDescription(params)).toBe( - path.join('sub', 'dir', 'file.txt'), - ); - }); + describe('ToolInvocation', () => { + describe('getDescription', () => { + it('should return a shortened, relative path', () => { + const filePath = path.join(tempRootDir, 'sub', 'dir', 'file.txt'); + const params: ReadFileToolParams = { absolute_path: filePath }; + const invocation = tool.build(params); + expect(typeof invocation).not.toBe('string'); + expect( + ( + invocation as ToolInvocation + ).getDescription(), + ).toBe(path.join('sub', 'dir', 'file.txt')); + }); - it('should return . if path is the root directory', () => { - const params: ReadFileToolParams = { absolute_path: tempRootDir }; - expect(tool.getDescription(params)).toBe('.'); - }); - }); - - describe('execute', () => { - it('should return validation error if params are invalid', async () => { - const params: ReadFileToolParams = { - absolute_path: 'relative/path.txt', - }; - expect(await tool.execute(params, abortSignal)).toEqual({ - llmContent: - 'Error: Invalid parameters provided. Reason: File path must be absolute, but was relative: relative/path.txt. You must provide an absolute path.', - returnDisplay: - 'File path must be absolute, but was relative: relative/path.txt. You must provide an absolute path.', + it('should return . if path is the root directory', () => { + const params: ReadFileToolParams = { absolute_path: tempRootDir }; + const invocation = tool.build(params); + expect(typeof invocation).not.toBe('string'); + expect( + ( + invocation as ToolInvocation + ).getDescription(), + ).toBe('.'); }); }); - it('should return error if file does not exist', async () => { - const filePath = path.join(tempRootDir, 'nonexistent.txt'); - const params: ReadFileToolParams = { absolute_path: filePath }; + describe('execute', () => { + it('should return error if file does not exist', async () => { + const filePath = path.join(tempRootDir, 'nonexistent.txt'); + const params: ReadFileToolParams = { absolute_path: filePath }; + const invocation = tool.build(params) as ToolInvocation< + ReadFileToolParams, + ToolResult + >; - expect(await tool.execute(params, abortSignal)).toEqual({ - llmContent: `File not found: ${filePath}`, - returnDisplay: 'File not found.', - }); - }); - - it('should return success result for a text file', async () => { - const filePath = path.join(tempRootDir, 'textfile.txt'); - const fileContent = 'This is a test file.'; - await fsp.writeFile(filePath, fileContent, 'utf-8'); - const params: ReadFileToolParams = { absolute_path: filePath }; - - expect(await tool.execute(params, abortSignal)).toEqual({ - llmContent: fileContent, - returnDisplay: '', - }); - }); - - it('should return success result for an image file', async () => { - // A minimal 1x1 transparent PNG file. - const pngContent = Buffer.from([ - 137, 80, 78, 71, 13, 10, 26, 10, 0, 0, 0, 13, 73, 72, 68, 82, 0, 0, 0, - 1, 0, 0, 0, 1, 8, 6, 0, 0, 0, 31, 21, 196, 137, 0, 0, 0, 10, 73, 68, 65, - 84, 120, 156, 99, 0, 1, 0, 0, 5, 0, 1, 13, 10, 45, 180, 0, 0, 0, 0, 73, - 69, 78, 68, 174, 66, 96, 130, - ]); - const filePath = path.join(tempRootDir, 'image.png'); - await fsp.writeFile(filePath, pngContent); - const params: ReadFileToolParams = { absolute_path: filePath }; - - expect(await tool.execute(params, abortSignal)).toEqual({ - llmContent: { - inlineData: { - mimeType: 'image/png', - data: pngContent.toString('base64'), - }, - }, - returnDisplay: `Read image file: image.png`, - }); - }); - - it('should treat a non-image file with image extension as an image', async () => { - const filePath = path.join(tempRootDir, 'fake-image.png'); - const fileContent = 'This is not a real png.'; - await fsp.writeFile(filePath, fileContent, 'utf-8'); - const params: ReadFileToolParams = { absolute_path: filePath }; - - expect(await tool.execute(params, abortSignal)).toEqual({ - llmContent: { - inlineData: { - mimeType: 'image/png', - data: Buffer.from(fileContent).toString('base64'), - }, - }, - returnDisplay: `Read image file: fake-image.png`, - }); - }); - - it('should pass offset and limit to read a slice of a text file', async () => { - const filePath = path.join(tempRootDir, 'paginated.txt'); - const fileContent = Array.from( - { length: 20 }, - (_, i) => `Line ${i + 1}`, - ).join('\n'); - await fsp.writeFile(filePath, fileContent, 'utf-8'); - - const params: ReadFileToolParams = { - absolute_path: filePath, - offset: 5, // Start from line 6 - limit: 3, - }; - - expect(await tool.execute(params, abortSignal)).toEqual({ - llmContent: [ - '[File content truncated: showing lines 6-8 of 20 total lines. Use offset/limit parameters to view more.]', - 'Line 6', - 'Line 7', - 'Line 8', - ].join('\n'), - returnDisplay: 'Read lines 6-8 of 20 from paginated.txt', - }); - }); - - describe('with .geminiignore', () => { - beforeEach(async () => { - await fsp.writeFile( - path.join(tempRootDir, '.geminiignore'), - ['foo.*', 'ignored/'].join('\n'), - ); - }); - - it('should return error if path is ignored by a .geminiignore pattern', async () => { - const ignoredFilePath = path.join(tempRootDir, 'foo.bar'); - await fsp.writeFile(ignoredFilePath, 'content', 'utf-8'); - const params: ReadFileToolParams = { - absolute_path: ignoredFilePath, - }; - const expectedError = `File path '${ignoredFilePath}' is ignored by .geminiignore pattern(s).`; - expect(await tool.execute(params, abortSignal)).toEqual({ - llmContent: `Error: Invalid parameters provided. Reason: ${expectedError}`, - returnDisplay: expectedError, + expect(await invocation.execute(abortSignal)).toEqual({ + llmContent: `File not found: ${filePath}`, + returnDisplay: 'File not found.', }); }); - it('should return error if path is in an ignored directory', async () => { - const ignoredDirPath = path.join(tempRootDir, 'ignored'); - await fsp.mkdir(ignoredDirPath); - const filePath = path.join(ignoredDirPath, 'somefile.txt'); - await fsp.writeFile(filePath, 'content', 'utf-8'); + it('should return success result for a text file', async () => { + const filePath = path.join(tempRootDir, 'textfile.txt'); + const fileContent = 'This is a test file.'; + await fsp.writeFile(filePath, fileContent, 'utf-8'); + const params: ReadFileToolParams = { absolute_path: filePath }; + const invocation = tool.build(params) as ToolInvocation< + ReadFileToolParams, + ToolResult + >; + + expect(await invocation.execute(abortSignal)).toEqual({ + llmContent: fileContent, + returnDisplay: '', + }); + }); + + it('should return success result for an image file', async () => { + // A minimal 1x1 transparent PNG file. + const pngContent = Buffer.from([ + 137, 80, 78, 71, 13, 10, 26, 10, 0, 0, 0, 13, 73, 72, 68, 82, 0, 0, 0, + 1, 0, 0, 0, 1, 8, 6, 0, 0, 0, 31, 21, 196, 137, 0, 0, 0, 10, 73, 68, + 65, 84, 120, 156, 99, 0, 1, 0, 0, 5, 0, 1, 13, 10, 45, 180, 0, 0, 0, + 0, 73, 69, 78, 68, 174, 66, 96, 130, + ]); + const filePath = path.join(tempRootDir, 'image.png'); + await fsp.writeFile(filePath, pngContent); + const params: ReadFileToolParams = { absolute_path: filePath }; + const invocation = tool.build(params) as ToolInvocation< + ReadFileToolParams, + ToolResult + >; + + expect(await invocation.execute(abortSignal)).toEqual({ + llmContent: { + inlineData: { + mimeType: 'image/png', + data: pngContent.toString('base64'), + }, + }, + returnDisplay: `Read image file: image.png`, + }); + }); + + it('should treat a non-image file with image extension as an image', async () => { + const filePath = path.join(tempRootDir, 'fake-image.png'); + const fileContent = 'This is not a real png.'; + await fsp.writeFile(filePath, fileContent, 'utf-8'); + const params: ReadFileToolParams = { absolute_path: filePath }; + const invocation = tool.build(params) as ToolInvocation< + ReadFileToolParams, + ToolResult + >; + + expect(await invocation.execute(abortSignal)).toEqual({ + llmContent: { + inlineData: { + mimeType: 'image/png', + data: Buffer.from(fileContent).toString('base64'), + }, + }, + returnDisplay: `Read image file: fake-image.png`, + }); + }); + + it('should pass offset and limit to read a slice of a text file', async () => { + const filePath = path.join(tempRootDir, 'paginated.txt'); + const fileContent = Array.from( + { length: 20 }, + (_, i) => `Line ${i + 1}`, + ).join('\n'); + await fsp.writeFile(filePath, fileContent, 'utf-8'); const params: ReadFileToolParams = { absolute_path: filePath, + offset: 5, // Start from line 6 + limit: 3, }; - const expectedError = `File path '${filePath}' is ignored by .geminiignore pattern(s).`; - expect(await tool.execute(params, abortSignal)).toEqual({ - llmContent: `Error: Invalid parameters provided. Reason: ${expectedError}`, - returnDisplay: expectedError, + const invocation = tool.build(params) as ToolInvocation< + ReadFileToolParams, + ToolResult + >; + + expect(await invocation.execute(abortSignal)).toEqual({ + llmContent: [ + '[File content truncated: showing lines 6-8 of 20 total lines. Use offset/limit parameters to view more.]', + 'Line 6', + 'Line 7', + 'Line 8', + ].join('\n'), + returnDisplay: 'Read lines 6-8 of 20 from paginated.txt', + }); + }); + + describe('with .geminiignore', () => { + beforeEach(async () => { + await fsp.writeFile( + path.join(tempRootDir, '.geminiignore'), + ['foo.*', 'ignored/'].join('\n'), + ); + }); + + it('should throw error if path is ignored by a .geminiignore pattern', async () => { + const ignoredFilePath = path.join(tempRootDir, 'foo.bar'); + await fsp.writeFile(ignoredFilePath, 'content', 'utf-8'); + const params: ReadFileToolParams = { + absolute_path: ignoredFilePath, + }; + const expectedError = `File path '${ignoredFilePath}' is ignored by .geminiignore pattern(s).`; + expect(() => tool.build(params)).toThrow(expectedError); + }); + + it('should throw error if path is in an ignored directory', async () => { + const ignoredDirPath = path.join(tempRootDir, 'ignored'); + await fsp.mkdir(ignoredDirPath); + const filePath = path.join(ignoredDirPath, 'somefile.txt'); + await fsp.writeFile(filePath, 'content', 'utf-8'); + + const params: ReadFileToolParams = { + absolute_path: filePath, + }; + const expectedError = `File path '${filePath}' is ignored by .geminiignore pattern(s).`; + expect(() => tool.build(params)).toThrow(expectedError); }); }); }); @@ -270,18 +290,16 @@ describe('ReadFileTool', () => { const params: ReadFileToolParams = { absolute_path: path.join(tempRootDir, 'file.txt'), }; - expect(tool.validateToolParams(params)).toBeNull(); + expect(() => tool.build(params)).not.toThrow(); }); it('should reject paths outside workspace root', () => { const params: ReadFileToolParams = { absolute_path: '/etc/passwd', }; - const error = tool.validateToolParams(params); - expect(error).toContain( + expect(() => tool.build(params)).toThrow( 'File path must be within one of the workspace directories', ); - expect(error).toContain(tempRootDir); }); it('should provide clear error message with workspace directories', () => { @@ -289,11 +307,9 @@ describe('ReadFileTool', () => { const params: ReadFileToolParams = { absolute_path: outsidePath, }; - const error = tool.validateToolParams(params); - expect(error).toContain( + expect(() => tool.build(params)).toThrow( 'File path must be within one of the workspace directories', ); - expect(error).toContain(tempRootDir); }); }); }); diff --git a/packages/core/src/tools/read-file.ts b/packages/core/src/tools/read-file.ts index 31282c20..3a05da06 100644 --- a/packages/core/src/tools/read-file.ts +++ b/packages/core/src/tools/read-file.ts @@ -7,7 +7,13 @@ import path from 'path'; import { SchemaValidator } from '../utils/schemaValidator.js'; import { makeRelative, shortenPath } from '../utils/paths.js'; -import { BaseTool, Icon, ToolLocation, ToolResult } from './tools.js'; +import { + BaseDeclarativeTool, + Icon, + ToolInvocation, + ToolLocation, + ToolResult, +} from './tools.js'; import { Type } from '@google/genai'; import { processSingleFileContent, @@ -39,10 +45,72 @@ export interface ReadFileToolParams { limit?: number; } +class ReadFileToolInvocation + implements ToolInvocation +{ + constructor( + private config: Config, + public params: ReadFileToolParams, + ) {} + + getDescription(): string { + const relativePath = makeRelative( + this.params.absolute_path, + this.config.getTargetDir(), + ); + return shortenPath(relativePath); + } + + toolLocations(): ToolLocation[] { + return [{ path: this.params.absolute_path, line: this.params.offset }]; + } + + shouldConfirmExecute(): Promise { + return Promise.resolve(false); + } + + async execute(): Promise { + const result = await processSingleFileContent( + this.params.absolute_path, + this.config.getTargetDir(), + this.params.offset, + this.params.limit, + ); + + if (result.error) { + return { + llmContent: result.error, // The detailed error for LLM + returnDisplay: result.returnDisplay || 'Error reading file', // User-friendly error + }; + } + + const lines = + typeof result.llmContent === 'string' + ? result.llmContent.split('\n').length + : undefined; + const mimetype = getSpecificMimeType(this.params.absolute_path); + recordFileOperationMetric( + this.config, + FileOperation.READ, + lines, + mimetype, + path.extname(this.params.absolute_path), + ); + + return { + llmContent: result.llmContent || '', + returnDisplay: result.returnDisplay || '', + }; + } +} + /** * Implementation of the ReadFile tool logic */ -export class ReadFileTool extends BaseTool { +export class ReadFileTool extends BaseDeclarativeTool< + ReadFileToolParams, + ToolResult +> { static readonly Name: string = 'read_file'; constructor(private config: Config) { @@ -75,7 +143,7 @@ export class ReadFileTool extends BaseTool { ); } - validateToolParams(params: ReadFileToolParams): string | null { + protected validateToolParams(params: ReadFileToolParams): string | null { const errors = SchemaValidator.validate(this.schema.parameters, params); if (errors) { return errors; @@ -106,67 +174,9 @@ export class ReadFileTool extends BaseTool { return null; } - getDescription(params: ReadFileToolParams): string { - if ( - !params || - typeof params.absolute_path !== 'string' || - params.absolute_path.trim() === '' - ) { - return `Path unavailable`; - } - const relativePath = makeRelative( - params.absolute_path, - this.config.getTargetDir(), - ); - return shortenPath(relativePath); - } - - toolLocations(params: ReadFileToolParams): ToolLocation[] { - return [{ path: params.absolute_path, line: params.offset }]; - } - - async execute( + protected createInvocation( params: ReadFileToolParams, - _signal: AbortSignal, - ): Promise { - const validationError = this.validateToolParams(params); - if (validationError) { - return { - llmContent: `Error: Invalid parameters provided. Reason: ${validationError}`, - returnDisplay: validationError, - }; - } - - const result = await processSingleFileContent( - params.absolute_path, - this.config.getTargetDir(), - params.offset, - params.limit, - ); - - if (result.error) { - return { - llmContent: result.error, // The detailed error for LLM - returnDisplay: result.returnDisplay || 'Error reading file', // User-friendly error - }; - } - - const lines = - typeof result.llmContent === 'string' - ? result.llmContent.split('\n').length - : undefined; - const mimetype = getSpecificMimeType(params.absolute_path); - recordFileOperationMetric( - this.config, - FileOperation.READ, - lines, - mimetype, - path.extname(params.absolute_path), - ); - - return { - llmContent: result.llmContent || '', - returnDisplay: result.returnDisplay || '', - }; + ): ToolInvocation { + return new ReadFileToolInvocation(this.config, params); } } diff --git a/packages/core/src/tools/tool-registry.test.ts b/packages/core/src/tools/tool-registry.test.ts index 24b6ca5f..e7c71e14 100644 --- a/packages/core/src/tools/tool-registry.test.ts +++ b/packages/core/src/tools/tool-registry.test.ts @@ -21,7 +21,6 @@ import { sanitizeParameters, } from './tool-registry.js'; import { DiscoveredMCPTool } from './mcp-tool.js'; -import { BaseTool, Icon, ToolResult } from './tools.js'; import { FunctionDeclaration, CallableTool, @@ -32,6 +31,7 @@ import { import { spawn } from 'node:child_process'; import fs from 'node:fs'; +import { MockTool } from '../test-utils/tools.js'; vi.mock('node:fs'); @@ -107,28 +107,6 @@ const createMockCallableTool = ( callTool: vi.fn(), }); -class MockTool extends BaseTool<{ param: string }, ToolResult> { - constructor( - name = 'mock-tool', - displayName = 'A mock tool', - description = 'A mock tool description', - ) { - super(name, displayName, description, Icon.Hammer, { - type: Type.OBJECT, - properties: { - param: { type: Type.STRING }, - }, - required: ['param'], - }); - } - async execute(params: { param: string }): Promise { - return { - llmContent: `Executed with ${params.param}`, - returnDisplay: `Executed with ${params.param}`, - }; - } -} - const baseConfigParams: ConfigParameters = { cwd: '/tmp', model: 'test-model', diff --git a/packages/core/src/tools/tool-registry.ts b/packages/core/src/tools/tool-registry.ts index e60b8f74..73b427d4 100644 --- a/packages/core/src/tools/tool-registry.ts +++ b/packages/core/src/tools/tool-registry.ts @@ -5,7 +5,7 @@ */ import { FunctionDeclaration, Schema, Type } from '@google/genai'; -import { Tool, ToolResult, BaseTool, Icon } from './tools.js'; +import { AnyDeclarativeTool, Icon, ToolResult, BaseTool } from './tools.js'; import { Config } from '../config/config.js'; import { spawn } from 'node:child_process'; import { StringDecoder } from 'node:string_decoder'; @@ -125,7 +125,7 @@ Signal: Signal number or \`(none)\` if no signal was received. } export class ToolRegistry { - private tools: Map = new Map(); + private tools: Map = new Map(); private config: Config; constructor(config: Config) { @@ -136,7 +136,7 @@ export class ToolRegistry { * Registers a tool definition. * @param tool - The tool object containing schema and execution logic. */ - registerTool(tool: Tool): void { + registerTool(tool: AnyDeclarativeTool): void { if (this.tools.has(tool.name)) { if (tool instanceof DiscoveredMCPTool) { tool = tool.asFullyQualifiedTool(); @@ -368,7 +368,7 @@ export class ToolRegistry { /** * Returns an array of all registered and discovered tool instances. */ - getAllTools(): Tool[] { + getAllTools(): AnyDeclarativeTool[] { return Array.from(this.tools.values()).sort((a, b) => a.displayName.localeCompare(b.displayName), ); @@ -377,8 +377,8 @@ export class ToolRegistry { /** * Returns an array of tools registered from a specific MCP server. */ - getToolsByServer(serverName: string): Tool[] { - const serverTools: Tool[] = []; + getToolsByServer(serverName: string): AnyDeclarativeTool[] { + const serverTools: AnyDeclarativeTool[] = []; for (const tool of this.tools.values()) { if ((tool as DiscoveredMCPTool)?.serverName === serverName) { serverTools.push(tool); @@ -390,7 +390,7 @@ export class ToolRegistry { /** * Get the definition of a specific tool. */ - getTool(name: string): Tool | undefined { + getTool(name: string): AnyDeclarativeTool | undefined { return this.tools.get(name); } } diff --git a/packages/core/src/tools/tools.ts b/packages/core/src/tools/tools.ts index 3404093f..79e6f010 100644 --- a/packages/core/src/tools/tools.ts +++ b/packages/core/src/tools/tools.ts @@ -9,101 +9,243 @@ import { ToolErrorType } from './tool-error.js'; import { DiffUpdateResult } from '../ide/ideContext.js'; /** - * Interface representing the base Tool functionality + * Represents a validated and ready-to-execute tool call. + * An instance of this is created by a `ToolBuilder`. */ -export interface Tool< - TParams = unknown, - TResult extends ToolResult = ToolResult, +export interface ToolInvocation< + TParams extends object, + TResult extends ToolResult, > { /** - * The internal name of the tool (used for API calls) + * The validated parameters for this specific invocation. */ - name: string; + params: TParams; /** - * The user-friendly display name of the tool + * Gets a pre-execution description of the tool operation. + * @returns A markdown string describing what the tool will do. */ - displayName: string; + getDescription(): string; /** - * Description of what the tool does + * Determines what file system paths the tool will affect. + * @returns A list of such paths. */ - description: string; + toolLocations(): ToolLocation[]; /** - * The icon to display when interacting via ACP - */ - icon: Icon; - - /** - * Function declaration schema from @google/genai - */ - schema: FunctionDeclaration; - - /** - * Whether the tool's output should be rendered as markdown - */ - isOutputMarkdown: boolean; - - /** - * Whether the tool supports live (streaming) output - */ - canUpdateOutput: boolean; - - /** - * Validates the parameters for the tool - * Should be called from both `shouldConfirmExecute` and `execute` - * `shouldConfirmExecute` should return false immediately if invalid - * @param params Parameters to validate - * @returns An error message string if invalid, null otherwise - */ - validateToolParams(params: TParams): string | null; - - /** - * Gets a pre-execution description of the tool operation - * @param params Parameters for the tool execution - * @returns A markdown string describing what the tool will do - * Optional for backward compatibility - */ - getDescription(params: TParams): string; - - /** - * Determines what file system paths the tool will affect - * @param params Parameters for the tool execution - * @returns A list of such paths - */ - toolLocations(params: TParams): ToolLocation[]; - - /** - * Determines if the tool should prompt for confirmation before execution - * @param params Parameters for the tool execution - * @returns Whether execute should be confirmed. + * Determines if the tool should prompt for confirmation before execution. + * @returns Confirmation details or false if no confirmation is needed. */ shouldConfirmExecute( - params: TParams, abortSignal: AbortSignal, ): Promise; /** - * Executes the tool with the given parameters - * @param params Parameters for the tool execution - * @returns Result of the tool execution + * Executes the tool with the validated parameters. + * @param signal AbortSignal for tool cancellation. + * @param updateOutput Optional callback to stream output. + * @returns Result of the tool execution. */ execute( - params: TParams, signal: AbortSignal, updateOutput?: (output: string) => void, ): Promise; } +/** + * A type alias for a tool invocation where the specific parameter and result types are not known. + */ +export type AnyToolInvocation = ToolInvocation; + +/** + * An adapter that wraps the legacy `Tool` interface to make it compatible + * with the new `ToolInvocation` pattern. + */ +export class LegacyToolInvocation< + TParams extends object, + TResult extends ToolResult, +> implements ToolInvocation +{ + constructor( + private readonly legacyTool: BaseTool, + readonly params: TParams, + ) {} + + getDescription(): string { + return this.legacyTool.getDescription(this.params); + } + + toolLocations(): ToolLocation[] { + return this.legacyTool.toolLocations(this.params); + } + + shouldConfirmExecute( + abortSignal: AbortSignal, + ): Promise { + return this.legacyTool.shouldConfirmExecute(this.params, abortSignal); + } + + execute( + signal: AbortSignal, + updateOutput?: (output: string) => void, + ): Promise { + return this.legacyTool.execute(this.params, signal, updateOutput); + } +} + +/** + * Interface for a tool builder that validates parameters and creates invocations. + */ +export interface ToolBuilder< + TParams extends object, + TResult extends ToolResult, +> { + /** + * The internal name of the tool (used for API calls). + */ + name: string; + + /** + * The user-friendly display name of the tool. + */ + displayName: string; + + /** + * Description of what the tool does. + */ + description: string; + + /** + * The icon to display when interacting via ACP. + */ + icon: Icon; + + /** + * Function declaration schema from @google/genai. + */ + schema: FunctionDeclaration; + + /** + * Whether the tool's output should be rendered as markdown. + */ + isOutputMarkdown: boolean; + + /** + * Whether the tool supports live (streaming) output. + */ + canUpdateOutput: boolean; + + /** + * Validates raw parameters and builds a ready-to-execute invocation. + * @param params The raw, untrusted parameters from the model. + * @returns A valid `ToolInvocation` if successful. Throws an error if validation fails. + */ + build(params: TParams): ToolInvocation; +} + +/** + * New base class for tools that separates validation from execution. + * New tools should extend this class. + */ +export abstract class DeclarativeTool< + TParams extends object, + TResult extends ToolResult, +> implements ToolBuilder +{ + constructor( + readonly name: string, + readonly displayName: string, + readonly description: string, + readonly icon: Icon, + readonly parameterSchema: Schema, + readonly isOutputMarkdown: boolean = true, + readonly canUpdateOutput: boolean = false, + ) {} + + get schema(): FunctionDeclaration { + return { + name: this.name, + description: this.description, + parameters: this.parameterSchema, + }; + } + + /** + * Validates the raw tool parameters. + * Subclasses should override this to add custom validation logic + * beyond the JSON schema check. + * @param params The raw parameters from the model. + * @returns An error message string if invalid, null otherwise. + */ + protected validateToolParams(_params: TParams): string | null { + // Base implementation can be extended by subclasses. + return null; + } + + /** + * The core of the new pattern. It validates parameters and, if successful, + * returns a `ToolInvocation` object that encapsulates the logic for the + * specific, validated call. + * @param params The raw, untrusted parameters from the model. + * @returns A `ToolInvocation` instance. + */ + abstract build(params: TParams): ToolInvocation; + + /** + * A convenience method that builds and executes the tool in one step. + * Throws an error if validation fails. + * @param params The raw, untrusted parameters from the model. + * @param signal AbortSignal for tool cancellation. + * @param updateOutput Optional callback to stream output. + * @returns The result of the tool execution. + */ + async buildAndExecute( + params: TParams, + signal: AbortSignal, + updateOutput?: (output: string) => void, + ): Promise { + const invocation = this.build(params); + return invocation.execute(signal, updateOutput); + } +} + +/** + * New base class for declarative tools that separates validation from execution. + * New tools should extend this class, which provides a `build` method that + * validates parameters before deferring to a `createInvocation` method for + * the final `ToolInvocation` object instantiation. + */ +export abstract class BaseDeclarativeTool< + TParams extends object, + TResult extends ToolResult, +> extends DeclarativeTool { + build(params: TParams): ToolInvocation { + const validationError = this.validateToolParams(params); + if (validationError) { + throw new Error(validationError); + } + return this.createInvocation(params); + } + + protected abstract createInvocation( + params: TParams, + ): ToolInvocation; +} + +/** + * A type alias for a declarative tool where the specific parameter and result types are not known. + */ +export type AnyDeclarativeTool = DeclarativeTool; + /** * Base implementation for tools with common functionality + * @deprecated Use `DeclarativeTool` for new tools. */ export abstract class BaseTool< - TParams = unknown, + TParams extends object, TResult extends ToolResult = ToolResult, -> implements Tool -{ +> extends DeclarativeTool { /** * Creates a new instance of BaseTool * @param name Internal name of the tool (used for API calls) @@ -121,17 +263,24 @@ export abstract class BaseTool< readonly parameterSchema: Schema, readonly isOutputMarkdown: boolean = true, readonly canUpdateOutput: boolean = false, - ) {} + ) { + super( + name, + displayName, + description, + icon, + parameterSchema, + isOutputMarkdown, + canUpdateOutput, + ); + } - /** - * Function declaration schema computed from name, description, and parameterSchema - */ - get schema(): FunctionDeclaration { - return { - name: this.name, - description: this.description, - parameters: this.parameterSchema, - }; + build(params: TParams): ToolInvocation { + const validationError = this.validateToolParams(params); + if (validationError) { + throw new Error(validationError); + } + return new LegacyToolInvocation(this, params); } /** diff --git a/packages/core/src/tools/write-file.ts b/packages/core/src/tools/write-file.ts index 32ecc068..9e7e3813 100644 --- a/packages/core/src/tools/write-file.ts +++ b/packages/core/src/tools/write-file.ts @@ -26,7 +26,7 @@ import { ensureCorrectFileContent, } from '../utils/editCorrector.js'; import { DEFAULT_DIFF_OPTIONS } from './diffOptions.js'; -import { ModifiableTool, ModifyContext } from './modifiable-tool.js'; +import { ModifiableDeclarativeTool, ModifyContext } from './modifiable-tool.js'; import { getSpecificMimeType } from '../utils/fileUtils.js'; import { recordFileOperationMetric, @@ -66,7 +66,7 @@ interface GetCorrectedFileContentResult { */ export class WriteFileTool extends BaseTool - implements ModifiableTool + implements ModifiableDeclarativeTool { static readonly Name: string = 'write_file';