feat(core): Introduce `DeclarativeTool` and `ToolInvocation`. (#5613)

This commit is contained in:
joshualitt 2025-08-06 10:50:02 -07:00 committed by GitHub
parent 882a97aff9
commit 6133bea388
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
24 changed files with 991 additions and 681 deletions

View File

@ -239,65 +239,62 @@ class GeminiAgent implements Agent {
); );
} }
let toolCallId; let toolCallId: number | undefined = undefined;
const confirmationDetails = await tool.shouldConfirmExecute(
args,
abortSignal,
);
if (confirmationDetails) {
let content: acp.ToolCallContent | null = null;
if (confirmationDetails.type === 'edit') {
content = {
type: 'diff',
path: confirmationDetails.fileName,
oldText: confirmationDetails.originalContent,
newText: confirmationDetails.newContent,
};
}
const result = await this.client.requestToolCallConfirmation({
label: tool.getDescription(args),
icon: tool.icon,
content,
confirmation: toAcpToolCallConfirmation(confirmationDetails),
locations: tool.toolLocations(args),
});
await confirmationDetails.onConfirm(toToolCallOutcome(result.outcome));
switch (result.outcome) {
case 'reject':
return errorResponse(
new Error(`Tool "${fc.name}" not allowed to run by the user.`),
);
case 'cancel':
return errorResponse(
new Error(`Tool "${fc.name}" was canceled by the user.`),
);
case 'allow':
case 'alwaysAllow':
case 'alwaysAllowMcpServer':
case 'alwaysAllowTool':
break;
default: {
const resultOutcome: never = result.outcome;
throw new Error(`Unexpected: ${resultOutcome}`);
}
}
toolCallId = result.id;
} else {
const result = await this.client.pushToolCall({
icon: tool.icon,
label: tool.getDescription(args),
locations: tool.toolLocations(args),
});
toolCallId = result.id;
}
try { try {
const toolResult: ToolResult = await tool.execute(args, abortSignal); const invocation = tool.build(args);
const confirmationDetails =
await invocation.shouldConfirmExecute(abortSignal);
if (confirmationDetails) {
let content: acp.ToolCallContent | null = null;
if (confirmationDetails.type === 'edit') {
content = {
type: 'diff',
path: confirmationDetails.fileName,
oldText: confirmationDetails.originalContent,
newText: confirmationDetails.newContent,
};
}
const result = await this.client.requestToolCallConfirmation({
label: invocation.getDescription(),
icon: tool.icon,
content,
confirmation: toAcpToolCallConfirmation(confirmationDetails),
locations: invocation.toolLocations(),
});
await confirmationDetails.onConfirm(toToolCallOutcome(result.outcome));
switch (result.outcome) {
case 'reject':
return errorResponse(
new Error(`Tool "${fc.name}" not allowed to run by the user.`),
);
case 'cancel':
return errorResponse(
new Error(`Tool "${fc.name}" was canceled by the user.`),
);
case 'allow':
case 'alwaysAllow':
case 'alwaysAllowMcpServer':
case 'alwaysAllowTool':
break;
default: {
const resultOutcome: never = result.outcome;
throw new Error(`Unexpected: ${resultOutcome}`);
}
}
toolCallId = result.id;
} else {
const result = await this.client.pushToolCall({
icon: tool.icon,
label: invocation.getDescription(),
locations: invocation.toolLocations(),
});
toolCallId = result.id;
}
const toolResult: ToolResult = await invocation.execute(abortSignal);
const toolCallContent = toToolCallContent(toolResult); const toolCallContent = toToolCallContent(toolResult);
await this.client.updateToolCall({ await this.client.updateToolCall({
@ -320,12 +317,13 @@ class GeminiAgent implements Agent {
return convertToFunctionResponse(fc.name, callId, toolResult.llmContent); return convertToFunctionResponse(fc.name, callId, toolResult.llmContent);
} catch (e) { } catch (e) {
const error = e instanceof Error ? e : new Error(String(e)); const error = e instanceof Error ? e : new Error(String(e));
await this.client.updateToolCall({ if (toolCallId) {
toolCallId, await this.client.updateToolCall({
status: 'error', toolCallId,
content: { type: 'markdown', markdown: error.message }, status: 'error',
}); content: { type: 'markdown', markdown: error.message },
});
}
return errorResponse(error); return errorResponse(error);
} }
} }
@ -408,7 +406,7 @@ class GeminiAgent implements Agent {
`Path ${pathName} not found directly, attempting glob search.`, `Path ${pathName} not found directly, attempting glob search.`,
); );
try { try {
const globResult = await globTool.execute( const globResult = await globTool.buildAndExecute(
{ {
pattern: `**/*${pathName}*`, pattern: `**/*${pathName}*`,
path: this.config.getTargetDir(), path: this.config.getTargetDir(),
@ -530,12 +528,15 @@ class GeminiAgent implements Agent {
respectGitIgnore, // Use configuration setting respectGitIgnore, // Use configuration setting
}; };
const toolCall = await this.client.pushToolCall({ let toolCallId: number | undefined = undefined;
icon: readManyFilesTool.icon,
label: readManyFilesTool.getDescription(toolArgs),
});
try { try {
const result = await readManyFilesTool.execute(toolArgs, abortSignal); const invocation = readManyFilesTool.build(toolArgs);
const toolCall = await this.client.pushToolCall({
icon: readManyFilesTool.icon,
label: invocation.getDescription(),
});
toolCallId = toolCall.id;
const result = await invocation.execute(abortSignal);
const content = toToolCallContent(result) || { const content = toToolCallContent(result) || {
type: 'markdown', type: 'markdown',
markdown: `Successfully read: ${contentLabelsForDisplay.join(', ')}`, markdown: `Successfully read: ${contentLabelsForDisplay.join(', ')}`,
@ -578,14 +579,16 @@ class GeminiAgent implements Agent {
return processedQueryParts; return processedQueryParts;
} catch (error: unknown) { } catch (error: unknown) {
await this.client.updateToolCall({ if (toolCallId) {
toolCallId: toolCall.id, await this.client.updateToolCall({
status: 'error', toolCallId,
content: { status: 'error',
type: 'markdown', content: {
markdown: `Error reading files (${contentLabelsForDisplay.join(', ')}): ${getErrorMessage(error)}`, type: 'markdown',
}, markdown: `Error reading files (${contentLabelsForDisplay.join(', ')}): ${getErrorMessage(error)}`,
}); },
});
}
throw error; throw error;
} }
} }

View File

@ -8,6 +8,7 @@ import * as fs from 'fs/promises';
import * as path from 'path'; import * as path from 'path';
import { PartListUnion, PartUnion } from '@google/genai'; import { PartListUnion, PartUnion } from '@google/genai';
import { import {
AnyToolInvocation,
Config, Config,
getErrorMessage, getErrorMessage,
isNodeError, isNodeError,
@ -254,7 +255,7 @@ export async function handleAtCommand({
`Path ${pathName} not found directly, attempting glob search.`, `Path ${pathName} not found directly, attempting glob search.`,
); );
try { try {
const globResult = await globTool.execute( const globResult = await globTool.buildAndExecute(
{ {
pattern: `**/*${pathName}*`, pattern: `**/*${pathName}*`,
path: dir, path: dir,
@ -411,12 +412,14 @@ export async function handleAtCommand({
}; };
let toolCallDisplay: IndividualToolCallDisplay; let toolCallDisplay: IndividualToolCallDisplay;
let invocation: AnyToolInvocation | undefined = undefined;
try { try {
const result = await readManyFilesTool.execute(toolArgs, signal); invocation = readManyFilesTool.build(toolArgs);
const result = await invocation.execute(signal);
toolCallDisplay = { toolCallDisplay = {
callId: `client-read-${userMessageTimestamp}`, callId: `client-read-${userMessageTimestamp}`,
name: readManyFilesTool.displayName, name: readManyFilesTool.displayName,
description: readManyFilesTool.getDescription(toolArgs), description: invocation.getDescription(),
status: ToolCallStatus.Success, status: ToolCallStatus.Success,
resultDisplay: resultDisplay:
result.returnDisplay || result.returnDisplay ||
@ -466,7 +469,9 @@ export async function handleAtCommand({
toolCallDisplay = { toolCallDisplay = {
callId: `client-read-${userMessageTimestamp}`, callId: `client-read-${userMessageTimestamp}`,
name: readManyFilesTool.displayName, name: readManyFilesTool.displayName,
description: readManyFilesTool.getDescription(toolArgs), description:
invocation?.getDescription() ??
'Error attempting to execute tool to read files',
status: ToolCallStatus.Error, status: ToolCallStatus.Error,
resultDisplay: `Error reading files (${contentLabelsForDisplay.join(', ')}): ${getErrorMessage(error)}`, resultDisplay: `Error reading files (${contentLabelsForDisplay.join(', ')}): ${getErrorMessage(error)}`,
confirmationDetails: undefined, confirmationDetails: undefined,

View File

@ -21,6 +21,7 @@ import {
EditorType, EditorType,
AuthType, AuthType,
GeminiEventType as ServerGeminiEventType, GeminiEventType as ServerGeminiEventType,
AnyToolInvocation,
} from '@google/gemini-cli-core'; } from '@google/gemini-cli-core';
import { Part, PartListUnion } from '@google/genai'; import { Part, PartListUnion } from '@google/genai';
import { UseHistoryManagerReturn } from './useHistoryManager.js'; import { UseHistoryManagerReturn } from './useHistoryManager.js';
@ -452,9 +453,13 @@ describe('useGeminiStream', () => {
}, },
tool: { tool: {
name: 'tool1', name: 'tool1',
displayName: 'tool1',
description: 'desc1', description: 'desc1',
getDescription: vi.fn(), build: vi.fn(),
} as any, } as any,
invocation: {
getDescription: () => `Mock description`,
} as unknown as AnyToolInvocation,
startTime: Date.now(), startTime: Date.now(),
endTime: Date.now(), endTime: Date.now(),
} as TrackedCompletedToolCall, } as TrackedCompletedToolCall,
@ -469,9 +474,13 @@ describe('useGeminiStream', () => {
responseSubmittedToGemini: false, responseSubmittedToGemini: false,
tool: { tool: {
name: 'tool2', name: 'tool2',
displayName: 'tool2',
description: 'desc2', description: 'desc2',
getDescription: vi.fn(), build: vi.fn(),
} as any, } as any,
invocation: {
getDescription: () => `Mock description`,
} as unknown as AnyToolInvocation,
startTime: Date.now(), startTime: Date.now(),
liveOutput: '...', liveOutput: '...',
} as TrackedExecutingToolCall, } as TrackedExecutingToolCall,
@ -506,6 +515,12 @@ describe('useGeminiStream', () => {
status: 'success', status: 'success',
responseSubmittedToGemini: false, responseSubmittedToGemini: false,
response: { callId: 'call1', responseParts: toolCall1ResponseParts }, response: { callId: 'call1', responseParts: toolCall1ResponseParts },
tool: {
displayName: 'MockTool',
},
invocation: {
getDescription: () => `Mock description`,
} as unknown as AnyToolInvocation,
} as TrackedCompletedToolCall, } as TrackedCompletedToolCall,
{ {
request: { request: {
@ -584,6 +599,12 @@ describe('useGeminiStream', () => {
status: 'cancelled', status: 'cancelled',
response: { callId: '1', responseParts: [{ text: 'cancelled' }] }, response: { callId: '1', responseParts: [{ text: 'cancelled' }] },
responseSubmittedToGemini: false, responseSubmittedToGemini: false,
tool: {
displayName: 'mock tool',
},
invocation: {
getDescription: () => `Mock description`,
} as unknown as AnyToolInvocation,
} as TrackedCancelledToolCall, } as TrackedCancelledToolCall,
]; ];
const client = new MockedGeminiClientClass(mockConfig); const client = new MockedGeminiClientClass(mockConfig);
@ -644,9 +665,13 @@ describe('useGeminiStream', () => {
}, },
tool: { tool: {
name: 'toolA', name: 'toolA',
displayName: 'toolA',
description: 'descA', description: 'descA',
getDescription: vi.fn(), build: vi.fn(),
} as any, } as any,
invocation: {
getDescription: () => `Mock description`,
} as unknown as AnyToolInvocation,
status: 'cancelled', status: 'cancelled',
response: { response: {
callId: 'cancel-1', callId: 'cancel-1',
@ -668,9 +693,13 @@ describe('useGeminiStream', () => {
}, },
tool: { tool: {
name: 'toolB', name: 'toolB',
displayName: 'toolB',
description: 'descB', description: 'descB',
getDescription: vi.fn(), build: vi.fn(),
} as any, } as any,
invocation: {
getDescription: () => `Mock description`,
} as unknown as AnyToolInvocation,
status: 'cancelled', status: 'cancelled',
response: { response: {
callId: 'cancel-2', callId: 'cancel-2',
@ -760,9 +789,13 @@ describe('useGeminiStream', () => {
responseSubmittedToGemini: false, responseSubmittedToGemini: false,
tool: { tool: {
name: 'tool1', name: 'tool1',
displayName: 'tool1',
description: 'desc', description: 'desc',
getDescription: vi.fn(), build: vi.fn(),
} as any, } as any,
invocation: {
getDescription: () => `Mock description`,
} as unknown as AnyToolInvocation,
startTime: Date.now(), startTime: Date.now(),
} as TrackedExecutingToolCall, } as TrackedExecutingToolCall,
]; ];
@ -980,8 +1013,13 @@ describe('useGeminiStream', () => {
tool: { tool: {
name: 'tool1', name: 'tool1',
description: 'desc1', description: 'desc1',
getDescription: vi.fn(), build: vi.fn().mockImplementation((_) => ({
getDescription: () => `Mock description`,
})),
} as any, } as any,
invocation: {
getDescription: () => `Mock description`,
},
startTime: Date.now(), startTime: Date.now(),
liveOutput: '...', liveOutput: '...',
} as TrackedExecutingToolCall, } as TrackedExecutingToolCall,
@ -1131,9 +1169,13 @@ describe('useGeminiStream', () => {
}, },
tool: { tool: {
name: 'save_memory', name: 'save_memory',
displayName: 'save_memory',
description: 'Saves memory', description: 'Saves memory',
getDescription: vi.fn(), build: vi.fn(),
} as any, } as any,
invocation: {
getDescription: () => `Mock description`,
} as unknown as AnyToolInvocation,
}; };
// Capture the onComplete callback // Capture the onComplete callback

View File

@ -17,7 +17,6 @@ import {
OutputUpdateHandler, OutputUpdateHandler,
AllToolCallsCompleteHandler, AllToolCallsCompleteHandler,
ToolCallsUpdateHandler, ToolCallsUpdateHandler,
Tool,
ToolCall, ToolCall,
Status as CoreStatus, Status as CoreStatus,
EditorType, EditorType,
@ -216,23 +215,20 @@ export function mapToDisplay(
const toolDisplays = toolCalls.map( const toolDisplays = toolCalls.map(
(trackedCall): IndividualToolCallDisplay => { (trackedCall): IndividualToolCallDisplay => {
let displayName = trackedCall.request.name; let displayName: string;
let description = ''; let description: string;
let renderOutputAsMarkdown = false; let renderOutputAsMarkdown = false;
const currentToolInstance = if (trackedCall.status === 'error') {
'tool' in trackedCall && trackedCall.tool displayName =
? (trackedCall as { tool: Tool }).tool trackedCall.tool === undefined
: undefined; ? trackedCall.request.name
: trackedCall.tool.displayName;
if (currentToolInstance) {
displayName = currentToolInstance.displayName;
description = currentToolInstance.getDescription(
trackedCall.request.args,
);
renderOutputAsMarkdown = currentToolInstance.isOutputMarkdown;
} else if ('request' in trackedCall && 'args' in trackedCall.request) {
description = JSON.stringify(trackedCall.request.args); description = JSON.stringify(trackedCall.request.args);
} else {
displayName = trackedCall.tool.displayName;
description = trackedCall.invocation.getDescription();
renderOutputAsMarkdown = trackedCall.tool.isOutputMarkdown;
} }
const baseDisplayProperties: Omit< const baseDisplayProperties: Omit<
@ -256,7 +252,6 @@ export function mapToDisplay(
case 'error': case 'error':
return { return {
...baseDisplayProperties, ...baseDisplayProperties,
name: currentToolInstance?.displayName ?? trackedCall.request.name,
status: mapCoreStatusToDisplayStatus(trackedCall.status), status: mapCoreStatusToDisplayStatus(trackedCall.status),
resultDisplay: trackedCall.response.resultDisplay, resultDisplay: trackedCall.response.resultDisplay,
confirmationDetails: undefined, confirmationDetails: undefined,

View File

@ -15,7 +15,6 @@ import { PartUnion, FunctionResponse } from '@google/genai';
import { import {
Config, Config,
ToolCallRequestInfo, ToolCallRequestInfo,
Tool,
ToolRegistry, ToolRegistry,
ToolResult, ToolResult,
ToolCallConfirmationDetails, ToolCallConfirmationDetails,
@ -25,6 +24,9 @@ import {
Status as ToolCallStatusType, Status as ToolCallStatusType,
ApprovalMode, ApprovalMode,
Icon, Icon,
BaseTool,
AnyDeclarativeTool,
AnyToolInvocation,
} from '@google/gemini-cli-core'; } from '@google/gemini-cli-core';
import { import {
HistoryItemWithoutId, HistoryItemWithoutId,
@ -53,46 +55,55 @@ const mockConfig = {
getDebugMode: () => false, getDebugMode: () => false,
}; };
const mockTool: Tool = { class MockTool extends BaseTool<object, ToolResult> {
name: 'mockTool', constructor(
displayName: 'Mock Tool', name: string,
description: 'A mock tool for testing', displayName: string,
icon: Icon.Hammer, canUpdateOutput = false,
toolLocations: vi.fn(), shouldConfirm = false,
isOutputMarkdown: false, isOutputMarkdown = false,
canUpdateOutput: false, ) {
schema: {}, super(
validateToolParams: vi.fn(), name,
execute: vi.fn(), displayName,
shouldConfirmExecute: vi.fn(), 'A mock tool for testing',
getDescription: vi.fn((args) => `Description for ${JSON.stringify(args)}`), Icon.Hammer,
}; {},
isOutputMarkdown,
canUpdateOutput,
);
if (shouldConfirm) {
this.shouldConfirmExecute = vi.fn(
async (): Promise<ToolCallConfirmationDetails | false> => ({
type: 'edit',
title: 'Mock Tool Requires Confirmation',
onConfirm: mockOnUserConfirmForToolConfirmation,
fileName: 'mockToolRequiresConfirmation.ts',
fileDiff: 'Mock tool requires confirmation',
originalContent: 'Original content',
newContent: 'New content',
}),
);
}
}
const mockToolWithLiveOutput: Tool = { execute = vi.fn();
...mockTool, shouldConfirmExecute = vi.fn();
name: 'mockToolWithLiveOutput', }
displayName: 'Mock Tool With Live Output',
canUpdateOutput: true,
};
const mockTool = new MockTool('mockTool', 'Mock Tool');
const mockToolWithLiveOutput = new MockTool(
'mockToolWithLiveOutput',
'Mock Tool With Live Output',
true,
);
let mockOnUserConfirmForToolConfirmation: Mock; let mockOnUserConfirmForToolConfirmation: Mock;
const mockToolRequiresConfirmation = new MockTool(
const mockToolRequiresConfirmation: Tool = { 'mockToolRequiresConfirmation',
...mockTool, 'Mock Tool Requires Confirmation',
name: 'mockToolRequiresConfirmation', false,
displayName: 'Mock Tool Requires Confirmation', true,
shouldConfirmExecute: vi.fn( );
async (): Promise<ToolCallConfirmationDetails | false> => ({
type: 'edit',
title: 'Mock Tool Requires Confirmation',
onConfirm: mockOnUserConfirmForToolConfirmation,
fileName: 'mockToolRequiresConfirmation.ts',
fileDiff: 'Mock tool requires confirmation',
originalContent: 'Original content',
newContent: 'New content',
}),
),
};
describe('useReactToolScheduler in YOLO Mode', () => { describe('useReactToolScheduler in YOLO Mode', () => {
let onComplete: Mock; let onComplete: Mock;
@ -646,28 +657,21 @@ describe('useReactToolScheduler', () => {
}); });
it('should schedule and execute multiple tool calls', async () => { it('should schedule and execute multiple tool calls', async () => {
const tool1 = { const tool1 = new MockTool('tool1', 'Tool 1');
...mockTool, tool1.execute.mockResolvedValue({
name: 'tool1', llmContent: 'Output 1',
displayName: 'Tool 1', returnDisplay: 'Display 1',
execute: vi.fn().mockResolvedValue({ summary: 'Summary 1',
llmContent: 'Output 1', } as ToolResult);
returnDisplay: 'Display 1', tool1.shouldConfirmExecute.mockResolvedValue(null);
summary: 'Summary 1',
} as ToolResult), const tool2 = new MockTool('tool2', 'Tool 2');
shouldConfirmExecute: vi.fn().mockResolvedValue(null), tool2.execute.mockResolvedValue({
}; llmContent: 'Output 2',
const tool2 = { returnDisplay: 'Display 2',
...mockTool, summary: 'Summary 2',
name: 'tool2', } as ToolResult);
displayName: 'Tool 2', tool2.shouldConfirmExecute.mockResolvedValue(null);
execute: vi.fn().mockResolvedValue({
llmContent: 'Output 2',
returnDisplay: 'Display 2',
summary: 'Summary 2',
} as ToolResult),
shouldConfirmExecute: vi.fn().mockResolvedValue(null),
};
mockToolRegistry.getTool.mockImplementation((name) => { mockToolRegistry.getTool.mockImplementation((name) => {
if (name === 'tool1') return tool1; if (name === 'tool1') return tool1;
@ -805,20 +809,7 @@ describe('mapToDisplay', () => {
args: { foo: 'bar' }, args: { foo: 'bar' },
}; };
const baseTool: Tool = { const baseTool = new MockTool('testTool', 'Test Tool Display');
name: 'testTool',
displayName: 'Test Tool Display',
description: 'Test Description',
isOutputMarkdown: false,
canUpdateOutput: false,
schema: {},
icon: Icon.Hammer,
toolLocations: vi.fn(),
validateToolParams: vi.fn(),
execute: vi.fn(),
shouldConfirmExecute: vi.fn(),
getDescription: vi.fn((args) => `Desc: ${JSON.stringify(args)}`),
};
const baseResponse: ToolCallResponseInfo = { const baseResponse: ToolCallResponseInfo = {
callId: 'testCallId', callId: 'testCallId',
@ -840,13 +831,15 @@ describe('mapToDisplay', () => {
// This helps ensure that tool and confirmationDetails are only accessed when they are expected to exist. // This helps ensure that tool and confirmationDetails are only accessed when they are expected to exist.
type MapToDisplayExtraProps = type MapToDisplayExtraProps =
| { | {
tool?: Tool; tool?: AnyDeclarativeTool;
invocation?: AnyToolInvocation;
liveOutput?: string; liveOutput?: string;
response?: ToolCallResponseInfo; response?: ToolCallResponseInfo;
confirmationDetails?: ToolCallConfirmationDetails; confirmationDetails?: ToolCallConfirmationDetails;
} }
| { | {
tool: Tool; tool: AnyDeclarativeTool;
invocation?: AnyToolInvocation;
response?: ToolCallResponseInfo; response?: ToolCallResponseInfo;
confirmationDetails?: ToolCallConfirmationDetails; confirmationDetails?: ToolCallConfirmationDetails;
} }
@ -857,10 +850,12 @@ describe('mapToDisplay', () => {
} }
| { | {
confirmationDetails: ToolCallConfirmationDetails; confirmationDetails: ToolCallConfirmationDetails;
tool?: Tool; tool?: AnyDeclarativeTool;
invocation?: AnyToolInvocation;
response?: ToolCallResponseInfo; response?: ToolCallResponseInfo;
}; };
const baseInvocation = baseTool.build(baseRequest.args);
const testCases: Array<{ const testCases: Array<{
name: string; name: string;
status: ToolCallStatusType; status: ToolCallStatusType;
@ -873,7 +868,7 @@ describe('mapToDisplay', () => {
{ {
name: 'validating', name: 'validating',
status: 'validating', status: 'validating',
extraProps: { tool: baseTool }, extraProps: { tool: baseTool, invocation: baseInvocation },
expectedStatus: ToolCallStatus.Executing, expectedStatus: ToolCallStatus.Executing,
expectedName: baseTool.displayName, expectedName: baseTool.displayName,
expectedDescription: baseTool.getDescription(baseRequest.args), expectedDescription: baseTool.getDescription(baseRequest.args),
@ -883,6 +878,7 @@ describe('mapToDisplay', () => {
status: 'awaiting_approval', status: 'awaiting_approval',
extraProps: { extraProps: {
tool: baseTool, tool: baseTool,
invocation: baseInvocation,
confirmationDetails: { confirmationDetails: {
onConfirm: vi.fn(), onConfirm: vi.fn(),
type: 'edit', type: 'edit',
@ -903,7 +899,7 @@ describe('mapToDisplay', () => {
{ {
name: 'scheduled', name: 'scheduled',
status: 'scheduled', status: 'scheduled',
extraProps: { tool: baseTool }, extraProps: { tool: baseTool, invocation: baseInvocation },
expectedStatus: ToolCallStatus.Pending, expectedStatus: ToolCallStatus.Pending,
expectedName: baseTool.displayName, expectedName: baseTool.displayName,
expectedDescription: baseTool.getDescription(baseRequest.args), expectedDescription: baseTool.getDescription(baseRequest.args),
@ -911,7 +907,7 @@ describe('mapToDisplay', () => {
{ {
name: 'executing no live output', name: 'executing no live output',
status: 'executing', status: 'executing',
extraProps: { tool: baseTool }, extraProps: { tool: baseTool, invocation: baseInvocation },
expectedStatus: ToolCallStatus.Executing, expectedStatus: ToolCallStatus.Executing,
expectedName: baseTool.displayName, expectedName: baseTool.displayName,
expectedDescription: baseTool.getDescription(baseRequest.args), expectedDescription: baseTool.getDescription(baseRequest.args),
@ -919,7 +915,11 @@ describe('mapToDisplay', () => {
{ {
name: 'executing with live output', name: 'executing with live output',
status: 'executing', status: 'executing',
extraProps: { tool: baseTool, liveOutput: 'Live test output' }, extraProps: {
tool: baseTool,
invocation: baseInvocation,
liveOutput: 'Live test output',
},
expectedStatus: ToolCallStatus.Executing, expectedStatus: ToolCallStatus.Executing,
expectedResultDisplay: 'Live test output', expectedResultDisplay: 'Live test output',
expectedName: baseTool.displayName, expectedName: baseTool.displayName,
@ -928,7 +928,11 @@ describe('mapToDisplay', () => {
{ {
name: 'success', name: 'success',
status: 'success', status: 'success',
extraProps: { tool: baseTool, response: baseResponse }, extraProps: {
tool: baseTool,
invocation: baseInvocation,
response: baseResponse,
},
expectedStatus: ToolCallStatus.Success, expectedStatus: ToolCallStatus.Success,
expectedResultDisplay: baseResponse.resultDisplay as any, expectedResultDisplay: baseResponse.resultDisplay as any,
expectedName: baseTool.displayName, expectedName: baseTool.displayName,
@ -970,6 +974,7 @@ describe('mapToDisplay', () => {
status: 'cancelled', status: 'cancelled',
extraProps: { extraProps: {
tool: baseTool, tool: baseTool,
invocation: baseInvocation,
response: { response: {
...baseResponse, ...baseResponse,
resultDisplay: 'Cancelled display', resultDisplay: 'Cancelled display',
@ -1030,12 +1035,21 @@ describe('mapToDisplay', () => {
request: { ...baseRequest, callId: 'call1' }, request: { ...baseRequest, callId: 'call1' },
status: 'success', status: 'success',
tool: baseTool, tool: baseTool,
invocation: baseTool.build(baseRequest.args),
response: { ...baseResponse, callId: 'call1' }, response: { ...baseResponse, callId: 'call1' },
} as ToolCall; } as ToolCall;
const toolForCall2 = new MockTool(
baseTool.name,
baseTool.displayName,
false,
false,
true,
);
const toolCall2: ToolCall = { const toolCall2: ToolCall = {
request: { ...baseRequest, callId: 'call2' }, request: { ...baseRequest, callId: 'call2' },
status: 'executing', status: 'executing',
tool: { ...baseTool, isOutputMarkdown: true }, tool: toolForCall2,
invocation: toolForCall2.build(baseRequest.args),
liveOutput: 'markdown output', liveOutput: 'markdown output',
} as ToolCall; } as ToolCall;

View File

@ -24,7 +24,6 @@ import {
import { Config } from '../config/config.js'; import { Config } from '../config/config.js';
import { UserTierId } from '../code_assist/types.js'; import { UserTierId } from '../code_assist/types.js';
import { getCoreSystemPrompt, getCompressionPrompt } from './prompts.js'; import { getCoreSystemPrompt, getCompressionPrompt } from './prompts.js';
import { ReadManyFilesTool } from '../tools/read-many-files.js';
import { getResponseText } from '../utils/generateContentResponseUtilities.js'; import { getResponseText } from '../utils/generateContentResponseUtilities.js';
import { checkNextSpeaker } from '../utils/nextSpeakerChecker.js'; import { checkNextSpeaker } from '../utils/nextSpeakerChecker.js';
import { reportError } from '../utils/errorReporting.js'; import { reportError } from '../utils/errorReporting.js';
@ -252,18 +251,15 @@ export class GeminiClient {
// Add full file context if the flag is set // Add full file context if the flag is set
if (this.config.getFullContext()) { if (this.config.getFullContext()) {
try { try {
const readManyFilesTool = toolRegistry.getTool( const readManyFilesTool = toolRegistry.getTool('read_many_files');
'read_many_files',
) as ReadManyFilesTool;
if (readManyFilesTool) { if (readManyFilesTool) {
const invocation = readManyFilesTool.build({
paths: ['**/*'], // Read everything recursively
useDefaultExcludes: true, // Use default excludes
});
// Read all files in the target directory // Read all files in the target directory
const result = await readManyFilesTool.execute( const result = await invocation.execute(AbortSignal.timeout(30000));
{
paths: ['**/*'], // Read everything recursively
useDefaultExcludes: true, // Use default excludes
},
AbortSignal.timeout(30000),
);
if (result.llmContent) { if (result.llmContent) {
initialParts.push({ initialParts.push({
text: `\n--- Full File Context ---\n${result.llmContent}`, text: `\n--- Full File Context ---\n${result.llmContent}`,

View File

@ -24,44 +24,15 @@ import {
} from '../index.js'; } from '../index.js';
import { Part, PartListUnion } from '@google/genai'; import { Part, PartListUnion } from '@google/genai';
import { ModifiableTool, ModifyContext } from '../tools/modifiable-tool.js'; import {
ModifiableDeclarativeTool,
class MockTool extends BaseTool<Record<string, unknown>, ToolResult> { ModifyContext,
shouldConfirm = false; } from '../tools/modifiable-tool.js';
executeFn = vi.fn(); import { MockTool } from '../test-utils/tools.js';
constructor(name = 'mockTool') {
super(name, name, 'A mock tool', Icon.Hammer, {});
}
async shouldConfirmExecute(
_params: Record<string, unknown>,
_abortSignal: AbortSignal,
): Promise<ToolCallConfirmationDetails | false> {
if (this.shouldConfirm) {
return {
type: 'exec',
title: 'Confirm Mock Tool',
command: 'do_thing',
rootCommand: 'do_thing',
onConfirm: async () => {},
};
}
return false;
}
async execute(
params: Record<string, unknown>,
_abortSignal: AbortSignal,
): Promise<ToolResult> {
this.executeFn(params);
return { llmContent: 'Tool executed', returnDisplay: 'Tool executed' };
}
}
class MockModifiableTool class MockModifiableTool
extends MockTool extends MockTool
implements ModifiableTool<Record<string, unknown>> implements ModifiableDeclarativeTool<Record<string, unknown>>
{ {
constructor(name = 'mockModifiableTool') { constructor(name = 'mockModifiableTool') {
super(name); super(name);
@ -83,10 +54,7 @@ class MockModifiableTool
}; };
} }
async shouldConfirmExecute( async shouldConfirmExecute(): Promise<ToolCallConfirmationDetails | false> {
_params: Record<string, unknown>,
_abortSignal: AbortSignal,
): Promise<ToolCallConfirmationDetails | false> {
if (this.shouldConfirm) { if (this.shouldConfirm) {
return { return {
type: 'edit', type: 'edit',
@ -107,14 +75,15 @@ describe('CoreToolScheduler', () => {
it('should cancel a tool call if the signal is aborted before confirmation', async () => { it('should cancel a tool call if the signal is aborted before confirmation', async () => {
const mockTool = new MockTool(); const mockTool = new MockTool();
mockTool.shouldConfirm = true; mockTool.shouldConfirm = true;
const declarativeTool = mockTool;
const toolRegistry = { const toolRegistry = {
getTool: () => mockTool, getTool: () => declarativeTool,
getFunctionDeclarations: () => [], getFunctionDeclarations: () => [],
tools: new Map(), tools: new Map(),
discovery: {} as any, discovery: {} as any,
registerTool: () => {}, registerTool: () => {},
getToolByName: () => mockTool, getToolByName: () => declarativeTool,
getToolByDisplayName: () => mockTool, getToolByDisplayName: () => declarativeTool,
getTools: () => [], getTools: () => [],
discoverTools: async () => {}, discoverTools: async () => {},
getAllTools: () => [], getAllTools: () => [],
@ -177,14 +146,15 @@ describe('CoreToolScheduler', () => {
describe('CoreToolScheduler with payload', () => { describe('CoreToolScheduler with payload', () => {
it('should update args and diff and execute tool when payload is provided', async () => { it('should update args and diff and execute tool when payload is provided', async () => {
const mockTool = new MockModifiableTool(); const mockTool = new MockModifiableTool();
const declarativeTool = mockTool;
const toolRegistry = { const toolRegistry = {
getTool: () => mockTool, getTool: () => declarativeTool,
getFunctionDeclarations: () => [], getFunctionDeclarations: () => [],
tools: new Map(), tools: new Map(),
discovery: {} as any, discovery: {} as any,
registerTool: () => {}, registerTool: () => {},
getToolByName: () => mockTool, getToolByName: () => declarativeTool,
getToolByDisplayName: () => mockTool, getToolByDisplayName: () => declarativeTool,
getTools: () => [], getTools: () => [],
discoverTools: async () => {}, discoverTools: async () => {},
getAllTools: () => [], getAllTools: () => [],
@ -221,10 +191,7 @@ describe('CoreToolScheduler with payload', () => {
await scheduler.schedule([request], abortController.signal); await scheduler.schedule([request], abortController.signal);
const confirmationDetails = await mockTool.shouldConfirmExecute( const confirmationDetails = await mockTool.shouldConfirmExecute();
{},
abortController.signal,
);
if (confirmationDetails) { if (confirmationDetails) {
const payload: ToolConfirmationPayload = { newContent: 'final version' }; const payload: ToolConfirmationPayload = { newContent: 'final version' };
@ -456,14 +423,15 @@ describe('CoreToolScheduler edit cancellation', () => {
} }
const mockEditTool = new MockEditTool(); const mockEditTool = new MockEditTool();
const declarativeTool = mockEditTool;
const toolRegistry = { const toolRegistry = {
getTool: () => mockEditTool, getTool: () => declarativeTool,
getFunctionDeclarations: () => [], getFunctionDeclarations: () => [],
tools: new Map(), tools: new Map(),
discovery: {} as any, discovery: {} as any,
registerTool: () => {}, registerTool: () => {},
getToolByName: () => mockEditTool, getToolByName: () => declarativeTool,
getToolByDisplayName: () => mockEditTool, getToolByDisplayName: () => declarativeTool,
getTools: () => [], getTools: () => [],
discoverTools: async () => {}, discoverTools: async () => {},
getAllTools: () => [], getAllTools: () => [],
@ -541,18 +509,23 @@ describe('CoreToolScheduler YOLO mode', () => {
it('should execute tool requiring confirmation directly without waiting', async () => { it('should execute tool requiring confirmation directly without waiting', async () => {
// Arrange // Arrange
const mockTool = new MockTool(); const mockTool = new MockTool();
mockTool.executeFn.mockReturnValue({
llmContent: 'Tool executed',
returnDisplay: 'Tool executed',
});
// This tool would normally require confirmation. // This tool would normally require confirmation.
mockTool.shouldConfirm = true; mockTool.shouldConfirm = true;
const declarativeTool = mockTool;
const toolRegistry = { const toolRegistry = {
getTool: () => mockTool, getTool: () => declarativeTool,
getToolByName: () => mockTool, getToolByName: () => declarativeTool,
// Other properties are not needed for this test but are included for type consistency. // Other properties are not needed for this test but are included for type consistency.
getFunctionDeclarations: () => [], getFunctionDeclarations: () => [],
tools: new Map(), tools: new Map(),
discovery: {} as any, discovery: {} as any,
registerTool: () => {}, registerTool: () => {},
getToolByDisplayName: () => mockTool, getToolByDisplayName: () => declarativeTool,
getTools: () => [], getTools: () => [],
discoverTools: async () => {}, discoverTools: async () => {},
getAllTools: () => [], getAllTools: () => [],

View File

@ -8,7 +8,6 @@ import {
ToolCallRequestInfo, ToolCallRequestInfo,
ToolCallResponseInfo, ToolCallResponseInfo,
ToolConfirmationOutcome, ToolConfirmationOutcome,
Tool,
ToolCallConfirmationDetails, ToolCallConfirmationDetails,
ToolResult, ToolResult,
ToolResultDisplay, ToolResultDisplay,
@ -20,11 +19,13 @@ import {
ToolCallEvent, ToolCallEvent,
ToolConfirmationPayload, ToolConfirmationPayload,
ToolErrorType, ToolErrorType,
AnyDeclarativeTool,
AnyToolInvocation,
} from '../index.js'; } from '../index.js';
import { Part, PartListUnion } from '@google/genai'; import { Part, PartListUnion } from '@google/genai';
import { getResponseTextFromParts } from '../utils/generateContentResponseUtilities.js'; import { getResponseTextFromParts } from '../utils/generateContentResponseUtilities.js';
import { import {
isModifiableTool, isModifiableDeclarativeTool,
ModifyContext, ModifyContext,
modifyWithEditor, modifyWithEditor,
} from '../tools/modifiable-tool.js'; } from '../tools/modifiable-tool.js';
@ -33,7 +34,8 @@ import * as Diff from 'diff';
export type ValidatingToolCall = { export type ValidatingToolCall = {
status: 'validating'; status: 'validating';
request: ToolCallRequestInfo; request: ToolCallRequestInfo;
tool: Tool; tool: AnyDeclarativeTool;
invocation: AnyToolInvocation;
startTime?: number; startTime?: number;
outcome?: ToolConfirmationOutcome; outcome?: ToolConfirmationOutcome;
}; };
@ -41,7 +43,8 @@ export type ValidatingToolCall = {
export type ScheduledToolCall = { export type ScheduledToolCall = {
status: 'scheduled'; status: 'scheduled';
request: ToolCallRequestInfo; request: ToolCallRequestInfo;
tool: Tool; tool: AnyDeclarativeTool;
invocation: AnyToolInvocation;
startTime?: number; startTime?: number;
outcome?: ToolConfirmationOutcome; outcome?: ToolConfirmationOutcome;
}; };
@ -50,6 +53,7 @@ export type ErroredToolCall = {
status: 'error'; status: 'error';
request: ToolCallRequestInfo; request: ToolCallRequestInfo;
response: ToolCallResponseInfo; response: ToolCallResponseInfo;
tool?: AnyDeclarativeTool;
durationMs?: number; durationMs?: number;
outcome?: ToolConfirmationOutcome; outcome?: ToolConfirmationOutcome;
}; };
@ -57,8 +61,9 @@ export type ErroredToolCall = {
export type SuccessfulToolCall = { export type SuccessfulToolCall = {
status: 'success'; status: 'success';
request: ToolCallRequestInfo; request: ToolCallRequestInfo;
tool: Tool; tool: AnyDeclarativeTool;
response: ToolCallResponseInfo; response: ToolCallResponseInfo;
invocation: AnyToolInvocation;
durationMs?: number; durationMs?: number;
outcome?: ToolConfirmationOutcome; outcome?: ToolConfirmationOutcome;
}; };
@ -66,7 +71,8 @@ export type SuccessfulToolCall = {
export type ExecutingToolCall = { export type ExecutingToolCall = {
status: 'executing'; status: 'executing';
request: ToolCallRequestInfo; request: ToolCallRequestInfo;
tool: Tool; tool: AnyDeclarativeTool;
invocation: AnyToolInvocation;
liveOutput?: string; liveOutput?: string;
startTime?: number; startTime?: number;
outcome?: ToolConfirmationOutcome; outcome?: ToolConfirmationOutcome;
@ -76,7 +82,8 @@ export type CancelledToolCall = {
status: 'cancelled'; status: 'cancelled';
request: ToolCallRequestInfo; request: ToolCallRequestInfo;
response: ToolCallResponseInfo; response: ToolCallResponseInfo;
tool: Tool; tool: AnyDeclarativeTool;
invocation: AnyToolInvocation;
durationMs?: number; durationMs?: number;
outcome?: ToolConfirmationOutcome; outcome?: ToolConfirmationOutcome;
}; };
@ -84,7 +91,8 @@ export type CancelledToolCall = {
export type WaitingToolCall = { export type WaitingToolCall = {
status: 'awaiting_approval'; status: 'awaiting_approval';
request: ToolCallRequestInfo; request: ToolCallRequestInfo;
tool: Tool; tool: AnyDeclarativeTool;
invocation: AnyToolInvocation;
confirmationDetails: ToolCallConfirmationDetails; confirmationDetails: ToolCallConfirmationDetails;
startTime?: number; startTime?: number;
outcome?: ToolConfirmationOutcome; outcome?: ToolConfirmationOutcome;
@ -289,6 +297,7 @@ export class CoreToolScheduler {
// currentCall is a non-terminal state here and should have startTime and tool. // currentCall is a non-terminal state here and should have startTime and tool.
const existingStartTime = currentCall.startTime; const existingStartTime = currentCall.startTime;
const toolInstance = currentCall.tool; const toolInstance = currentCall.tool;
const invocation = currentCall.invocation;
const outcome = currentCall.outcome; const outcome = currentCall.outcome;
@ -300,6 +309,7 @@ export class CoreToolScheduler {
return { return {
request: currentCall.request, request: currentCall.request,
tool: toolInstance, tool: toolInstance,
invocation,
status: 'success', status: 'success',
response: auxiliaryData as ToolCallResponseInfo, response: auxiliaryData as ToolCallResponseInfo,
durationMs, durationMs,
@ -313,6 +323,7 @@ export class CoreToolScheduler {
return { return {
request: currentCall.request, request: currentCall.request,
status: 'error', status: 'error',
tool: toolInstance,
response: auxiliaryData as ToolCallResponseInfo, response: auxiliaryData as ToolCallResponseInfo,
durationMs, durationMs,
outcome, outcome,
@ -326,6 +337,7 @@ export class CoreToolScheduler {
confirmationDetails: auxiliaryData as ToolCallConfirmationDetails, confirmationDetails: auxiliaryData as ToolCallConfirmationDetails,
startTime: existingStartTime, startTime: existingStartTime,
outcome, outcome,
invocation,
} as WaitingToolCall; } as WaitingToolCall;
case 'scheduled': case 'scheduled':
return { return {
@ -334,6 +346,7 @@ export class CoreToolScheduler {
status: 'scheduled', status: 'scheduled',
startTime: existingStartTime, startTime: existingStartTime,
outcome, outcome,
invocation,
} as ScheduledToolCall; } as ScheduledToolCall;
case 'cancelled': { case 'cancelled': {
const durationMs = existingStartTime const durationMs = existingStartTime
@ -358,6 +371,7 @@ export class CoreToolScheduler {
return { return {
request: currentCall.request, request: currentCall.request,
tool: toolInstance, tool: toolInstance,
invocation,
status: 'cancelled', status: 'cancelled',
response: { response: {
callId: currentCall.request.callId, callId: currentCall.request.callId,
@ -385,6 +399,7 @@ export class CoreToolScheduler {
status: 'validating', status: 'validating',
startTime: existingStartTime, startTime: existingStartTime,
outcome, outcome,
invocation,
} as ValidatingToolCall; } as ValidatingToolCall;
case 'executing': case 'executing':
return { return {
@ -393,6 +408,7 @@ export class CoreToolScheduler {
status: 'executing', status: 'executing',
startTime: existingStartTime, startTime: existingStartTime,
outcome, outcome,
invocation,
} as ExecutingToolCall; } as ExecutingToolCall;
default: { default: {
const exhaustiveCheck: never = newStatus; const exhaustiveCheck: never = newStatus;
@ -406,10 +422,34 @@ export class CoreToolScheduler {
private setArgsInternal(targetCallId: string, args: unknown): void { private setArgsInternal(targetCallId: string, args: unknown): void {
this.toolCalls = this.toolCalls.map((call) => { this.toolCalls = this.toolCalls.map((call) => {
if (call.request.callId !== targetCallId) return call; // We should never be asked to set args on an ErroredToolCall, but
// we guard for the case anyways.
if (call.request.callId !== targetCallId || call.status === 'error') {
return call;
}
const invocationOrError = this.buildInvocation(
call.tool,
args as Record<string, unknown>,
);
if (invocationOrError instanceof Error) {
const response = createErrorResponse(
call.request,
invocationOrError,
ToolErrorType.INVALID_TOOL_PARAMS,
);
return {
request: { ...call.request, args: args as Record<string, unknown> },
status: 'error',
tool: call.tool,
response,
} as ErroredToolCall;
}
return { return {
...call, ...call,
request: { ...call.request, args: args as Record<string, unknown> }, request: { ...call.request, args: args as Record<string, unknown> },
invocation: invocationOrError,
}; };
}); });
} }
@ -421,6 +461,20 @@ export class CoreToolScheduler {
); );
} }
private buildInvocation(
tool: AnyDeclarativeTool,
args: object,
): AnyToolInvocation | Error {
try {
return tool.build(args);
} catch (e) {
if (e instanceof Error) {
return e;
}
return new Error(String(e));
}
}
async schedule( async schedule(
request: ToolCallRequestInfo | ToolCallRequestInfo[], request: ToolCallRequestInfo | ToolCallRequestInfo[],
signal: AbortSignal, signal: AbortSignal,
@ -448,10 +502,30 @@ export class CoreToolScheduler {
durationMs: 0, durationMs: 0,
}; };
} }
const invocationOrError = this.buildInvocation(
toolInstance,
reqInfo.args,
);
if (invocationOrError instanceof Error) {
return {
status: 'error',
request: reqInfo,
tool: toolInstance,
response: createErrorResponse(
reqInfo,
invocationOrError,
ToolErrorType.INVALID_TOOL_PARAMS,
),
durationMs: 0,
};
}
return { return {
status: 'validating', status: 'validating',
request: reqInfo, request: reqInfo,
tool: toolInstance, tool: toolInstance,
invocation: invocationOrError,
startTime: Date.now(), startTime: Date.now(),
}; };
}, },
@ -465,7 +539,8 @@ export class CoreToolScheduler {
continue; continue;
} }
const { request: reqInfo, tool: toolInstance } = toolCall; const { request: reqInfo, invocation } = toolCall;
try { try {
if (this.config.getApprovalMode() === ApprovalMode.YOLO) { if (this.config.getApprovalMode() === ApprovalMode.YOLO) {
this.setToolCallOutcome( this.setToolCallOutcome(
@ -474,10 +549,8 @@ export class CoreToolScheduler {
); );
this.setStatusInternal(reqInfo.callId, 'scheduled'); this.setStatusInternal(reqInfo.callId, 'scheduled');
} else { } else {
const confirmationDetails = await toolInstance.shouldConfirmExecute( const confirmationDetails =
reqInfo.args, await invocation.shouldConfirmExecute(signal);
signal,
);
if (confirmationDetails) { if (confirmationDetails) {
// Allow IDE to resolve confirmation // Allow IDE to resolve confirmation
@ -573,7 +646,7 @@ export class CoreToolScheduler {
); );
} else if (outcome === ToolConfirmationOutcome.ModifyWithEditor) { } else if (outcome === ToolConfirmationOutcome.ModifyWithEditor) {
const waitingToolCall = toolCall as WaitingToolCall; const waitingToolCall = toolCall as WaitingToolCall;
if (isModifiableTool(waitingToolCall.tool)) { if (isModifiableDeclarativeTool(waitingToolCall.tool)) {
const modifyContext = waitingToolCall.tool.getModifyContext(signal); const modifyContext = waitingToolCall.tool.getModifyContext(signal);
const editorType = this.getPreferredEditor(); const editorType = this.getPreferredEditor();
if (!editorType) { if (!editorType) {
@ -628,7 +701,7 @@ export class CoreToolScheduler {
): Promise<void> { ): Promise<void> {
if ( if (
toolCall.confirmationDetails.type !== 'edit' || toolCall.confirmationDetails.type !== 'edit' ||
!isModifiableTool(toolCall.tool) !isModifiableDeclarativeTool(toolCall.tool)
) { ) {
return; return;
} }
@ -677,6 +750,7 @@ export class CoreToolScheduler {
const scheduledCall = toolCall; const scheduledCall = toolCall;
const { callId, name: toolName } = scheduledCall.request; const { callId, name: toolName } = scheduledCall.request;
const invocation = scheduledCall.invocation;
this.setStatusInternal(callId, 'executing'); this.setStatusInternal(callId, 'executing');
const liveOutputCallback = const liveOutputCallback =
@ -694,8 +768,8 @@ export class CoreToolScheduler {
} }
: undefined; : undefined;
scheduledCall.tool invocation
.execute(scheduledCall.request.args, signal, liveOutputCallback) .execute(signal, liveOutputCallback)
.then(async (toolResult: ToolResult) => { .then(async (toolResult: ToolResult) => {
if (signal.aborted) { if (signal.aborted) {
this.setStatusInternal( this.setStatusInternal(

View File

@ -10,12 +10,10 @@ import {
ToolRegistry, ToolRegistry,
ToolCallRequestInfo, ToolCallRequestInfo,
ToolResult, ToolResult,
Tool,
ToolCallConfirmationDetails,
Config, Config,
Icon,
} from '../index.js'; } from '../index.js';
import { Part, Type } from '@google/genai'; import { Part } from '@google/genai';
import { MockTool } from '../test-utils/tools.js';
const mockConfig = { const mockConfig = {
getSessionId: () => 'test-session-id', getSessionId: () => 'test-session-id',
@ -25,36 +23,11 @@ const mockConfig = {
describe('executeToolCall', () => { describe('executeToolCall', () => {
let mockToolRegistry: ToolRegistry; let mockToolRegistry: ToolRegistry;
let mockTool: Tool; let mockTool: MockTool;
let abortController: AbortController; let abortController: AbortController;
beforeEach(() => { beforeEach(() => {
mockTool = { mockTool = new MockTool();
name: 'testTool',
displayName: 'Test Tool',
description: 'A tool for testing',
icon: Icon.Hammer,
schema: {
name: 'testTool',
description: 'A tool for testing',
parameters: {
type: Type.OBJECT,
properties: {
param1: { type: Type.STRING },
},
required: ['param1'],
},
},
execute: vi.fn(),
validateToolParams: vi.fn(() => null),
shouldConfirmExecute: vi.fn(() =>
Promise.resolve(false as false | ToolCallConfirmationDetails),
),
isOutputMarkdown: false,
canUpdateOutput: false,
getDescription: vi.fn(),
toolLocations: vi.fn(() => []),
};
mockToolRegistry = { mockToolRegistry = {
getTool: vi.fn(), getTool: vi.fn(),
@ -77,7 +50,7 @@ describe('executeToolCall', () => {
returnDisplay: 'Success!', returnDisplay: 'Success!',
}; };
vi.mocked(mockToolRegistry.getTool).mockReturnValue(mockTool); vi.mocked(mockToolRegistry.getTool).mockReturnValue(mockTool);
vi.mocked(mockTool.execute).mockResolvedValue(toolResult); vi.spyOn(mockTool, 'buildAndExecute').mockResolvedValue(toolResult);
const response = await executeToolCall( const response = await executeToolCall(
mockConfig, mockConfig,
@ -87,7 +60,7 @@ describe('executeToolCall', () => {
); );
expect(mockToolRegistry.getTool).toHaveBeenCalledWith('testTool'); expect(mockToolRegistry.getTool).toHaveBeenCalledWith('testTool');
expect(mockTool.execute).toHaveBeenCalledWith( expect(mockTool.buildAndExecute).toHaveBeenCalledWith(
request.args, request.args,
abortController.signal, abortController.signal,
); );
@ -149,7 +122,7 @@ describe('executeToolCall', () => {
}; };
const executionError = new Error('Tool execution failed'); const executionError = new Error('Tool execution failed');
vi.mocked(mockToolRegistry.getTool).mockReturnValue(mockTool); vi.mocked(mockToolRegistry.getTool).mockReturnValue(mockTool);
vi.mocked(mockTool.execute).mockRejectedValue(executionError); vi.spyOn(mockTool, 'buildAndExecute').mockRejectedValue(executionError);
const response = await executeToolCall( const response = await executeToolCall(
mockConfig, mockConfig,
@ -183,25 +156,27 @@ describe('executeToolCall', () => {
const cancellationError = new Error('Operation cancelled'); const cancellationError = new Error('Operation cancelled');
vi.mocked(mockToolRegistry.getTool).mockReturnValue(mockTool); vi.mocked(mockToolRegistry.getTool).mockReturnValue(mockTool);
vi.mocked(mockTool.execute).mockImplementation(async (_args, signal) => { vi.spyOn(mockTool, 'buildAndExecute').mockImplementation(
if (signal?.aborted) { async (_args, signal) => {
return Promise.reject(cancellationError); if (signal?.aborted) {
} return Promise.reject(cancellationError);
return new Promise((_resolve, reject) => { }
signal?.addEventListener('abort', () => { return new Promise((_resolve, reject) => {
reject(cancellationError); signal?.addEventListener('abort', () => {
reject(cancellationError);
});
// Simulate work that might happen if not aborted immediately
const timeoutId = setTimeout(
() =>
reject(
new Error('Should have been cancelled if not aborted prior'),
),
100,
);
signal?.addEventListener('abort', () => clearTimeout(timeoutId));
}); });
// Simulate work that might happen if not aborted immediately },
const timeoutId = setTimeout( );
() =>
reject(
new Error('Should have been cancelled if not aborted prior'),
),
100,
);
signal?.addEventListener('abort', () => clearTimeout(timeoutId));
});
});
abortController.abort(); // Abort before calling abortController.abort(); // Abort before calling
const response = await executeToolCall( const response = await executeToolCall(
@ -232,7 +207,7 @@ describe('executeToolCall', () => {
returnDisplay: 'Image processed', returnDisplay: 'Image processed',
}; };
vi.mocked(mockToolRegistry.getTool).mockReturnValue(mockTool); vi.mocked(mockToolRegistry.getTool).mockReturnValue(mockTool);
vi.mocked(mockTool.execute).mockResolvedValue(toolResult); vi.spyOn(mockTool, 'buildAndExecute').mockResolvedValue(toolResult);
const response = await executeToolCall( const response = await executeToolCall(
mockConfig, mockConfig,

View File

@ -65,7 +65,7 @@ export async function executeToolCall(
try { try {
// Directly execute without confirmation or live output handling // Directly execute without confirmation or live output handling
const effectiveAbortSignal = abortSignal ?? new AbortController().signal; const effectiveAbortSignal = abortSignal ?? new AbortController().signal;
const toolResult: ToolResult = await tool.execute( const toolResult: ToolResult = await tool.buildAndExecute(
toolCallRequest.args, toolCallRequest.args,
effectiveAbortSignal, effectiveAbortSignal,
// No live output callback for non-interactive mode // No live output callback for non-interactive mode

View File

@ -14,7 +14,7 @@ import { ToolCallEvent } from './types.js';
import { Config } from '../config/config.js'; import { Config } from '../config/config.js';
import { CompletedToolCall } from '../core/coreToolScheduler.js'; import { CompletedToolCall } from '../core/coreToolScheduler.js';
import { ToolCallRequestInfo, ToolCallResponseInfo } from '../core/turn.js'; import { ToolCallRequestInfo, ToolCallResponseInfo } from '../core/turn.js';
import { Tool } from '../tools/tools.js'; import { MockTool } from '../test-utils/tools.js';
describe('Circular Reference Handling', () => { describe('Circular Reference Handling', () => {
it('should handle circular references in tool function arguments', () => { it('should handle circular references in tool function arguments', () => {
@ -56,11 +56,13 @@ describe('Circular Reference Handling', () => {
errorType: undefined, errorType: undefined,
}; };
const tool = new MockTool('mock-tool');
const mockCompletedToolCall: CompletedToolCall = { const mockCompletedToolCall: CompletedToolCall = {
status: 'success', status: 'success',
request: mockRequest, request: mockRequest,
response: mockResponse, response: mockResponse,
tool: {} as Tool, tool,
invocation: tool.build({}),
durationMs: 100, durationMs: 100,
}; };
@ -104,11 +106,13 @@ describe('Circular Reference Handling', () => {
errorType: undefined, errorType: undefined,
}; };
const tool = new MockTool('mock-tool');
const mockCompletedToolCall: CompletedToolCall = { const mockCompletedToolCall: CompletedToolCall = {
status: 'success', status: 'success',
request: mockRequest, request: mockRequest,
response: mockResponse, response: mockResponse,
tool: {} as Tool, tool,
invocation: tool.build({}),
durationMs: 100, durationMs: 100,
}; };

View File

@ -5,6 +5,7 @@
*/ */
import { import {
AnyToolInvocation,
AuthType, AuthType,
CompletedToolCall, CompletedToolCall,
ContentGeneratorConfig, ContentGeneratorConfig,
@ -432,6 +433,7 @@ describe('loggers', () => {
}); });
it('should log a tool call with all fields', () => { it('should log a tool call with all fields', () => {
const tool = new EditTool(mockConfig);
const call: CompletedToolCall = { const call: CompletedToolCall = {
status: 'success', status: 'success',
request: { request: {
@ -451,7 +453,8 @@ describe('loggers', () => {
error: undefined, error: undefined,
errorType: undefined, errorType: undefined,
}, },
tool: new EditTool(mockConfig), tool,
invocation: {} as AnyToolInvocation,
durationMs: 100, durationMs: 100,
outcome: ToolConfirmationOutcome.ProceedOnce, outcome: ToolConfirmationOutcome.ProceedOnce,
}; };
@ -581,6 +584,7 @@ describe('loggers', () => {
}, },
outcome: ToolConfirmationOutcome.ModifyWithEditor, outcome: ToolConfirmationOutcome.ModifyWithEditor,
tool: new EditTool(mockConfig), tool: new EditTool(mockConfig),
invocation: {} as AnyToolInvocation,
durationMs: 100, durationMs: 100,
}; };
const event = new ToolCallEvent(call); const event = new ToolCallEvent(call);
@ -645,6 +649,7 @@ describe('loggers', () => {
errorType: undefined, errorType: undefined,
}, },
tool: new EditTool(mockConfig), tool: new EditTool(mockConfig),
invocation: {} as AnyToolInvocation,
durationMs: 100, durationMs: 100,
}; };
const event = new ToolCallEvent(call); const event = new ToolCallEvent(call);

View File

@ -23,7 +23,8 @@ import {
SuccessfulToolCall, SuccessfulToolCall,
} from '../core/coreToolScheduler.js'; } from '../core/coreToolScheduler.js';
import { ToolErrorType } from '../tools/tool-error.js'; import { ToolErrorType } from '../tools/tool-error.js';
import { Tool, ToolConfirmationOutcome } from '../tools/tools.js'; import { ToolConfirmationOutcome } from '../tools/tools.js';
import { MockTool } from '../test-utils/tools.js';
const createFakeCompletedToolCall = ( const createFakeCompletedToolCall = (
name: string, name: string,
@ -39,12 +40,14 @@ const createFakeCompletedToolCall = (
isClientInitiated: false, isClientInitiated: false,
prompt_id: 'prompt-id-1', prompt_id: 'prompt-id-1',
}; };
const tool = new MockTool(name);
if (success) { if (success) {
return { return {
status: 'success', status: 'success',
request, request,
tool: { name } as Tool, // Mock tool tool,
invocation: tool.build({}),
response: { response: {
callId: request.callId, callId: request.callId,
responseParts: { responseParts: {
@ -65,6 +68,7 @@ const createFakeCompletedToolCall = (
return { return {
status: 'error', status: 'error',
request, request,
tool,
response: { response: {
callId: request.callId, callId: request.callId,
responseParts: { responseParts: {

View File

@ -0,0 +1,63 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { vi } from 'vitest';
import {
BaseTool,
Icon,
ToolCallConfirmationDetails,
ToolResult,
} from '../tools/tools.js';
import { Schema, Type } from '@google/genai';
/**
* A highly configurable mock tool for testing purposes.
*/
export class MockTool extends BaseTool<{ [key: string]: unknown }, ToolResult> {
executeFn = vi.fn();
shouldConfirm = false;
constructor(
name = 'mock-tool',
displayName?: string,
description = 'A mock tool for testing.',
params: Schema = {
type: Type.OBJECT,
properties: { param: { type: Type.STRING } },
},
) {
super(name, displayName ?? name, description, Icon.Hammer, params);
}
async execute(
params: { [key: string]: unknown },
_abortSignal: AbortSignal,
): Promise<ToolResult> {
const result = this.executeFn(params);
return (
result ?? {
llmContent: `Tool ${this.name} executed successfully.`,
returnDisplay: `Tool ${this.name} executed successfully.`,
}
);
}
async shouldConfirmExecute(
_params: { [key: string]: unknown },
_abortSignal: AbortSignal,
): Promise<ToolCallConfirmationDetails | false> {
if (this.shouldConfirm) {
return {
type: 'exec' as const,
title: `Confirm ${this.displayName}`,
command: this.name,
rootCommand: this.name,
onConfirm: async () => {},
};
}
return false;
}
}

View File

@ -26,7 +26,7 @@ import { Config, ApprovalMode } from '../config/config.js';
import { ensureCorrectEdit } from '../utils/editCorrector.js'; import { ensureCorrectEdit } from '../utils/editCorrector.js';
import { DEFAULT_DIFF_OPTIONS } from './diffOptions.js'; import { DEFAULT_DIFF_OPTIONS } from './diffOptions.js';
import { ReadFileTool } from './read-file.js'; import { ReadFileTool } from './read-file.js';
import { ModifiableTool, ModifyContext } from './modifiable-tool.js'; import { ModifiableDeclarativeTool, ModifyContext } from './modifiable-tool.js';
/** /**
* Parameters for the Edit tool * Parameters for the Edit tool
@ -72,7 +72,7 @@ interface CalculatedEdit {
*/ */
export class EditTool export class EditTool
extends BaseTool<EditToolParams, ToolResult> extends BaseTool<EditToolParams, ToolResult>
implements ModifiableTool<EditToolParams> implements ModifiableDeclarativeTool<EditToolParams>
{ {
static readonly Name = 'replace'; static readonly Name = 'replace';

View File

@ -18,7 +18,7 @@ import { homedir } from 'os';
import * as Diff from 'diff'; import * as Diff from 'diff';
import { DEFAULT_DIFF_OPTIONS } from './diffOptions.js'; import { DEFAULT_DIFF_OPTIONS } from './diffOptions.js';
import { tildeifyPath } from '../utils/paths.js'; import { tildeifyPath } from '../utils/paths.js';
import { ModifiableTool, ModifyContext } from './modifiable-tool.js'; import { ModifiableDeclarativeTool, ModifyContext } from './modifiable-tool.js';
const memoryToolSchemaData: FunctionDeclaration = { const memoryToolSchemaData: FunctionDeclaration = {
name: 'save_memory', name: 'save_memory',
@ -112,7 +112,7 @@ function ensureNewlineSeparation(currentContent: string): string {
export class MemoryTool export class MemoryTool
extends BaseTool<SaveMemoryParams, ToolResult> extends BaseTool<SaveMemoryParams, ToolResult>
implements ModifiableTool<SaveMemoryParams> implements ModifiableDeclarativeTool<SaveMemoryParams>
{ {
private static readonly allowlist: Set<string> = new Set(); private static readonly allowlist: Set<string> = new Set();

View File

@ -8,8 +8,8 @@ import { vi, describe, it, expect, beforeEach, afterEach } from 'vitest';
import { import {
modifyWithEditor, modifyWithEditor,
ModifyContext, ModifyContext,
ModifiableTool, ModifiableDeclarativeTool,
isModifiableTool, isModifiableDeclarativeTool,
} from './modifiable-tool.js'; } from './modifiable-tool.js';
import { EditorType } from '../utils/editor.js'; import { EditorType } from '../utils/editor.js';
import fs from 'fs'; import fs from 'fs';
@ -338,16 +338,16 @@ describe('isModifiableTool', () => {
const mockTool = { const mockTool = {
name: 'test-tool', name: 'test-tool',
getModifyContext: vi.fn(), getModifyContext: vi.fn(),
} as unknown as ModifiableTool<TestParams>; } as unknown as ModifiableDeclarativeTool<TestParams>;
expect(isModifiableTool(mockTool)).toBe(true); expect(isModifiableDeclarativeTool(mockTool)).toBe(true);
}); });
it('should return false for objects without getModifyContext method', () => { it('should return false for objects without getModifyContext method', () => {
const mockTool = { const mockTool = {
name: 'test-tool', name: 'test-tool',
} as unknown as ModifiableTool<TestParams>; } as unknown as ModifiableDeclarativeTool<TestParams>;
expect(isModifiableTool(mockTool)).toBe(false); expect(isModifiableDeclarativeTool(mockTool)).toBe(false);
}); });
}); });

View File

@ -11,13 +11,14 @@ import fs from 'fs';
import * as Diff from 'diff'; import * as Diff from 'diff';
import { DEFAULT_DIFF_OPTIONS } from './diffOptions.js'; import { DEFAULT_DIFF_OPTIONS } from './diffOptions.js';
import { isNodeError } from '../utils/errors.js'; import { isNodeError } from '../utils/errors.js';
import { Tool } from './tools.js'; import { AnyDeclarativeTool, DeclarativeTool, ToolResult } from './tools.js';
/** /**
* A tool that supports a modify operation. * A declarative tool that supports a modify operation.
*/ */
export interface ModifiableTool<ToolParams> extends Tool<ToolParams> { export interface ModifiableDeclarativeTool<TParams extends object>
getModifyContext(abortSignal: AbortSignal): ModifyContext<ToolParams>; extends DeclarativeTool<TParams, ToolResult> {
getModifyContext(abortSignal: AbortSignal): ModifyContext<TParams>;
} }
export interface ModifyContext<ToolParams> { export interface ModifyContext<ToolParams> {
@ -39,9 +40,12 @@ export interface ModifyResult<ToolParams> {
updatedDiff: string; updatedDiff: string;
} }
export function isModifiableTool<TParams>( /**
tool: Tool<TParams>, * Type guard to check if a declarative tool is modifiable.
): tool is ModifiableTool<TParams> { */
export function isModifiableDeclarativeTool(
tool: AnyDeclarativeTool,
): tool is ModifiableDeclarativeTool<object> {
return 'getModifyContext' in tool; return 'getModifyContext' in tool;
} }

View File

@ -13,6 +13,7 @@ import fsp from 'fs/promises';
import { Config } from '../config/config.js'; import { Config } from '../config/config.js';
import { FileDiscoveryService } from '../services/fileDiscoveryService.js'; import { FileDiscoveryService } from '../services/fileDiscoveryService.js';
import { createMockWorkspaceContext } from '../test-utils/mockWorkspaceContext.js'; import { createMockWorkspaceContext } from '../test-utils/mockWorkspaceContext.js';
import { ToolInvocation, ToolResult } from './tools.js';
describe('ReadFileTool', () => { describe('ReadFileTool', () => {
let tempRootDir: string; let tempRootDir: string;
@ -40,57 +41,62 @@ describe('ReadFileTool', () => {
} }
}); });
describe('validateToolParams', () => { describe('build', () => {
it('should return null for valid params (absolute path within root)', () => { it('should return an invocation for valid params (absolute path within root)', () => {
const params: ReadFileToolParams = { const params: ReadFileToolParams = {
absolute_path: path.join(tempRootDir, 'test.txt'), absolute_path: path.join(tempRootDir, 'test.txt'),
}; };
expect(tool.validateToolParams(params)).toBeNull(); const result = tool.build(params);
expect(result).not.toBeTypeOf('string');
expect(typeof result).toBe('object');
expect(
(result as ToolInvocation<ReadFileToolParams, ToolResult>).params,
).toEqual(params);
}); });
it('should return null for valid params with offset and limit', () => { it('should return an invocation for valid params with offset and limit', () => {
const params: ReadFileToolParams = { const params: ReadFileToolParams = {
absolute_path: path.join(tempRootDir, 'test.txt'), absolute_path: path.join(tempRootDir, 'test.txt'),
offset: 0, offset: 0,
limit: 10, limit: 10,
}; };
expect(tool.validateToolParams(params)).toBeNull(); const result = tool.build(params);
expect(result).not.toBeTypeOf('string');
}); });
it('should return error for relative path', () => { it('should throw error for relative path', () => {
const params: ReadFileToolParams = { absolute_path: 'test.txt' }; const params: ReadFileToolParams = { absolute_path: 'test.txt' };
expect(tool.validateToolParams(params)).toBe( expect(() => tool.build(params)).toThrow(
`File path must be absolute, but was relative: test.txt. You must provide an absolute path.`, `File path must be absolute, but was relative: test.txt. You must provide an absolute path.`,
); );
}); });
it('should return error for path outside root', () => { it('should throw error for path outside root', () => {
const outsidePath = path.resolve(os.tmpdir(), 'outside-root.txt'); const outsidePath = path.resolve(os.tmpdir(), 'outside-root.txt');
const params: ReadFileToolParams = { absolute_path: outsidePath }; const params: ReadFileToolParams = { absolute_path: outsidePath };
const error = tool.validateToolParams(params); expect(() => tool.build(params)).toThrow(
expect(error).toContain(
'File path must be within one of the workspace directories', 'File path must be within one of the workspace directories',
); );
}); });
it('should return error for negative offset', () => { it('should throw error for negative offset', () => {
const params: ReadFileToolParams = { const params: ReadFileToolParams = {
absolute_path: path.join(tempRootDir, 'test.txt'), absolute_path: path.join(tempRootDir, 'test.txt'),
offset: -1, offset: -1,
limit: 10, limit: 10,
}; };
expect(tool.validateToolParams(params)).toBe( expect(() => tool.build(params)).toThrow(
'Offset must be a non-negative number', 'Offset must be a non-negative number',
); );
}); });
it('should return error for non-positive limit', () => { it('should throw error for non-positive limit', () => {
const paramsZero: ReadFileToolParams = { const paramsZero: ReadFileToolParams = {
absolute_path: path.join(tempRootDir, 'test.txt'), absolute_path: path.join(tempRootDir, 'test.txt'),
offset: 0, offset: 0,
limit: 0, limit: 0,
}; };
expect(tool.validateToolParams(paramsZero)).toBe( expect(() => tool.build(paramsZero)).toThrow(
'Limit must be a positive number', 'Limit must be a positive number',
); );
const paramsNegative: ReadFileToolParams = { const paramsNegative: ReadFileToolParams = {
@ -98,168 +104,182 @@ describe('ReadFileTool', () => {
offset: 0, offset: 0,
limit: -5, limit: -5,
}; };
expect(tool.validateToolParams(paramsNegative)).toBe( expect(() => tool.build(paramsNegative)).toThrow(
'Limit must be a positive number', 'Limit must be a positive number',
); );
}); });
it('should return error for schema validation failure (e.g. missing path)', () => { it('should throw error for schema validation failure (e.g. missing path)', () => {
const params = { offset: 0 } as unknown as ReadFileToolParams; const params = { offset: 0 } as unknown as ReadFileToolParams;
expect(tool.validateToolParams(params)).toBe( expect(() => tool.build(params)).toThrow(
`params must have required property 'absolute_path'`, `params must have required property 'absolute_path'`,
); );
}); });
}); });
describe('getDescription', () => { describe('ToolInvocation', () => {
it('should return a shortened, relative path', () => { describe('getDescription', () => {
const filePath = path.join(tempRootDir, 'sub', 'dir', 'file.txt'); it('should return a shortened, relative path', () => {
const params: ReadFileToolParams = { absolute_path: filePath }; const filePath = path.join(tempRootDir, 'sub', 'dir', 'file.txt');
expect(tool.getDescription(params)).toBe( const params: ReadFileToolParams = { absolute_path: filePath };
path.join('sub', 'dir', 'file.txt'), const invocation = tool.build(params);
); expect(typeof invocation).not.toBe('string');
}); expect(
(
invocation as ToolInvocation<ReadFileToolParams, ToolResult>
).getDescription(),
).toBe(path.join('sub', 'dir', 'file.txt'));
});
it('should return . if path is the root directory', () => { it('should return . if path is the root directory', () => {
const params: ReadFileToolParams = { absolute_path: tempRootDir }; const params: ReadFileToolParams = { absolute_path: tempRootDir };
expect(tool.getDescription(params)).toBe('.'); const invocation = tool.build(params);
}); expect(typeof invocation).not.toBe('string');
}); expect(
(
describe('execute', () => { invocation as ToolInvocation<ReadFileToolParams, ToolResult>
it('should return validation error if params are invalid', async () => { ).getDescription(),
const params: ReadFileToolParams = { ).toBe('.');
absolute_path: 'relative/path.txt',
};
expect(await tool.execute(params, abortSignal)).toEqual({
llmContent:
'Error: Invalid parameters provided. Reason: File path must be absolute, but was relative: relative/path.txt. You must provide an absolute path.',
returnDisplay:
'File path must be absolute, but was relative: relative/path.txt. You must provide an absolute path.',
}); });
}); });
it('should return error if file does not exist', async () => { describe('execute', () => {
const filePath = path.join(tempRootDir, 'nonexistent.txt'); it('should return error if file does not exist', async () => {
const params: ReadFileToolParams = { absolute_path: filePath }; const filePath = path.join(tempRootDir, 'nonexistent.txt');
const params: ReadFileToolParams = { absolute_path: filePath };
const invocation = tool.build(params) as ToolInvocation<
ReadFileToolParams,
ToolResult
>;
expect(await tool.execute(params, abortSignal)).toEqual({ expect(await invocation.execute(abortSignal)).toEqual({
llmContent: `File not found: ${filePath}`, llmContent: `File not found: ${filePath}`,
returnDisplay: 'File not found.', returnDisplay: 'File not found.',
});
});
it('should return success result for a text file', async () => {
const filePath = path.join(tempRootDir, 'textfile.txt');
const fileContent = 'This is a test file.';
await fsp.writeFile(filePath, fileContent, 'utf-8');
const params: ReadFileToolParams = { absolute_path: filePath };
expect(await tool.execute(params, abortSignal)).toEqual({
llmContent: fileContent,
returnDisplay: '',
});
});
it('should return success result for an image file', async () => {
// A minimal 1x1 transparent PNG file.
const pngContent = Buffer.from([
137, 80, 78, 71, 13, 10, 26, 10, 0, 0, 0, 13, 73, 72, 68, 82, 0, 0, 0,
1, 0, 0, 0, 1, 8, 6, 0, 0, 0, 31, 21, 196, 137, 0, 0, 0, 10, 73, 68, 65,
84, 120, 156, 99, 0, 1, 0, 0, 5, 0, 1, 13, 10, 45, 180, 0, 0, 0, 0, 73,
69, 78, 68, 174, 66, 96, 130,
]);
const filePath = path.join(tempRootDir, 'image.png');
await fsp.writeFile(filePath, pngContent);
const params: ReadFileToolParams = { absolute_path: filePath };
expect(await tool.execute(params, abortSignal)).toEqual({
llmContent: {
inlineData: {
mimeType: 'image/png',
data: pngContent.toString('base64'),
},
},
returnDisplay: `Read image file: image.png`,
});
});
it('should treat a non-image file with image extension as an image', async () => {
const filePath = path.join(tempRootDir, 'fake-image.png');
const fileContent = 'This is not a real png.';
await fsp.writeFile(filePath, fileContent, 'utf-8');
const params: ReadFileToolParams = { absolute_path: filePath };
expect(await tool.execute(params, abortSignal)).toEqual({
llmContent: {
inlineData: {
mimeType: 'image/png',
data: Buffer.from(fileContent).toString('base64'),
},
},
returnDisplay: `Read image file: fake-image.png`,
});
});
it('should pass offset and limit to read a slice of a text file', async () => {
const filePath = path.join(tempRootDir, 'paginated.txt');
const fileContent = Array.from(
{ length: 20 },
(_, i) => `Line ${i + 1}`,
).join('\n');
await fsp.writeFile(filePath, fileContent, 'utf-8');
const params: ReadFileToolParams = {
absolute_path: filePath,
offset: 5, // Start from line 6
limit: 3,
};
expect(await tool.execute(params, abortSignal)).toEqual({
llmContent: [
'[File content truncated: showing lines 6-8 of 20 total lines. Use offset/limit parameters to view more.]',
'Line 6',
'Line 7',
'Line 8',
].join('\n'),
returnDisplay: 'Read lines 6-8 of 20 from paginated.txt',
});
});
describe('with .geminiignore', () => {
beforeEach(async () => {
await fsp.writeFile(
path.join(tempRootDir, '.geminiignore'),
['foo.*', 'ignored/'].join('\n'),
);
});
it('should return error if path is ignored by a .geminiignore pattern', async () => {
const ignoredFilePath = path.join(tempRootDir, 'foo.bar');
await fsp.writeFile(ignoredFilePath, 'content', 'utf-8');
const params: ReadFileToolParams = {
absolute_path: ignoredFilePath,
};
const expectedError = `File path '${ignoredFilePath}' is ignored by .geminiignore pattern(s).`;
expect(await tool.execute(params, abortSignal)).toEqual({
llmContent: `Error: Invalid parameters provided. Reason: ${expectedError}`,
returnDisplay: expectedError,
}); });
}); });
it('should return error if path is in an ignored directory', async () => { it('should return success result for a text file', async () => {
const ignoredDirPath = path.join(tempRootDir, 'ignored'); const filePath = path.join(tempRootDir, 'textfile.txt');
await fsp.mkdir(ignoredDirPath); const fileContent = 'This is a test file.';
const filePath = path.join(ignoredDirPath, 'somefile.txt'); await fsp.writeFile(filePath, fileContent, 'utf-8');
await fsp.writeFile(filePath, 'content', 'utf-8'); const params: ReadFileToolParams = { absolute_path: filePath };
const invocation = tool.build(params) as ToolInvocation<
ReadFileToolParams,
ToolResult
>;
expect(await invocation.execute(abortSignal)).toEqual({
llmContent: fileContent,
returnDisplay: '',
});
});
it('should return success result for an image file', async () => {
// A minimal 1x1 transparent PNG file.
const pngContent = Buffer.from([
137, 80, 78, 71, 13, 10, 26, 10, 0, 0, 0, 13, 73, 72, 68, 82, 0, 0, 0,
1, 0, 0, 0, 1, 8, 6, 0, 0, 0, 31, 21, 196, 137, 0, 0, 0, 10, 73, 68,
65, 84, 120, 156, 99, 0, 1, 0, 0, 5, 0, 1, 13, 10, 45, 180, 0, 0, 0,
0, 73, 69, 78, 68, 174, 66, 96, 130,
]);
const filePath = path.join(tempRootDir, 'image.png');
await fsp.writeFile(filePath, pngContent);
const params: ReadFileToolParams = { absolute_path: filePath };
const invocation = tool.build(params) as ToolInvocation<
ReadFileToolParams,
ToolResult
>;
expect(await invocation.execute(abortSignal)).toEqual({
llmContent: {
inlineData: {
mimeType: 'image/png',
data: pngContent.toString('base64'),
},
},
returnDisplay: `Read image file: image.png`,
});
});
it('should treat a non-image file with image extension as an image', async () => {
const filePath = path.join(tempRootDir, 'fake-image.png');
const fileContent = 'This is not a real png.';
await fsp.writeFile(filePath, fileContent, 'utf-8');
const params: ReadFileToolParams = { absolute_path: filePath };
const invocation = tool.build(params) as ToolInvocation<
ReadFileToolParams,
ToolResult
>;
expect(await invocation.execute(abortSignal)).toEqual({
llmContent: {
inlineData: {
mimeType: 'image/png',
data: Buffer.from(fileContent).toString('base64'),
},
},
returnDisplay: `Read image file: fake-image.png`,
});
});
it('should pass offset and limit to read a slice of a text file', async () => {
const filePath = path.join(tempRootDir, 'paginated.txt');
const fileContent = Array.from(
{ length: 20 },
(_, i) => `Line ${i + 1}`,
).join('\n');
await fsp.writeFile(filePath, fileContent, 'utf-8');
const params: ReadFileToolParams = { const params: ReadFileToolParams = {
absolute_path: filePath, absolute_path: filePath,
offset: 5, // Start from line 6
limit: 3,
}; };
const expectedError = `File path '${filePath}' is ignored by .geminiignore pattern(s).`; const invocation = tool.build(params) as ToolInvocation<
expect(await tool.execute(params, abortSignal)).toEqual({ ReadFileToolParams,
llmContent: `Error: Invalid parameters provided. Reason: ${expectedError}`, ToolResult
returnDisplay: expectedError, >;
expect(await invocation.execute(abortSignal)).toEqual({
llmContent: [
'[File content truncated: showing lines 6-8 of 20 total lines. Use offset/limit parameters to view more.]',
'Line 6',
'Line 7',
'Line 8',
].join('\n'),
returnDisplay: 'Read lines 6-8 of 20 from paginated.txt',
});
});
describe('with .geminiignore', () => {
beforeEach(async () => {
await fsp.writeFile(
path.join(tempRootDir, '.geminiignore'),
['foo.*', 'ignored/'].join('\n'),
);
});
it('should throw error if path is ignored by a .geminiignore pattern', async () => {
const ignoredFilePath = path.join(tempRootDir, 'foo.bar');
await fsp.writeFile(ignoredFilePath, 'content', 'utf-8');
const params: ReadFileToolParams = {
absolute_path: ignoredFilePath,
};
const expectedError = `File path '${ignoredFilePath}' is ignored by .geminiignore pattern(s).`;
expect(() => tool.build(params)).toThrow(expectedError);
});
it('should throw error if path is in an ignored directory', async () => {
const ignoredDirPath = path.join(tempRootDir, 'ignored');
await fsp.mkdir(ignoredDirPath);
const filePath = path.join(ignoredDirPath, 'somefile.txt');
await fsp.writeFile(filePath, 'content', 'utf-8');
const params: ReadFileToolParams = {
absolute_path: filePath,
};
const expectedError = `File path '${filePath}' is ignored by .geminiignore pattern(s).`;
expect(() => tool.build(params)).toThrow(expectedError);
}); });
}); });
}); });
@ -270,18 +290,16 @@ describe('ReadFileTool', () => {
const params: ReadFileToolParams = { const params: ReadFileToolParams = {
absolute_path: path.join(tempRootDir, 'file.txt'), absolute_path: path.join(tempRootDir, 'file.txt'),
}; };
expect(tool.validateToolParams(params)).toBeNull(); expect(() => tool.build(params)).not.toThrow();
}); });
it('should reject paths outside workspace root', () => { it('should reject paths outside workspace root', () => {
const params: ReadFileToolParams = { const params: ReadFileToolParams = {
absolute_path: '/etc/passwd', absolute_path: '/etc/passwd',
}; };
const error = tool.validateToolParams(params); expect(() => tool.build(params)).toThrow(
expect(error).toContain(
'File path must be within one of the workspace directories', 'File path must be within one of the workspace directories',
); );
expect(error).toContain(tempRootDir);
}); });
it('should provide clear error message with workspace directories', () => { it('should provide clear error message with workspace directories', () => {
@ -289,11 +307,9 @@ describe('ReadFileTool', () => {
const params: ReadFileToolParams = { const params: ReadFileToolParams = {
absolute_path: outsidePath, absolute_path: outsidePath,
}; };
const error = tool.validateToolParams(params); expect(() => tool.build(params)).toThrow(
expect(error).toContain(
'File path must be within one of the workspace directories', 'File path must be within one of the workspace directories',
); );
expect(error).toContain(tempRootDir);
}); });
}); });
}); });

View File

@ -7,7 +7,13 @@
import path from 'path'; import path from 'path';
import { SchemaValidator } from '../utils/schemaValidator.js'; import { SchemaValidator } from '../utils/schemaValidator.js';
import { makeRelative, shortenPath } from '../utils/paths.js'; import { makeRelative, shortenPath } from '../utils/paths.js';
import { BaseTool, Icon, ToolLocation, ToolResult } from './tools.js'; import {
BaseDeclarativeTool,
Icon,
ToolInvocation,
ToolLocation,
ToolResult,
} from './tools.js';
import { Type } from '@google/genai'; import { Type } from '@google/genai';
import { import {
processSingleFileContent, processSingleFileContent,
@ -39,10 +45,72 @@ export interface ReadFileToolParams {
limit?: number; limit?: number;
} }
class ReadFileToolInvocation
implements ToolInvocation<ReadFileToolParams, ToolResult>
{
constructor(
private config: Config,
public params: ReadFileToolParams,
) {}
getDescription(): string {
const relativePath = makeRelative(
this.params.absolute_path,
this.config.getTargetDir(),
);
return shortenPath(relativePath);
}
toolLocations(): ToolLocation[] {
return [{ path: this.params.absolute_path, line: this.params.offset }];
}
shouldConfirmExecute(): Promise<false> {
return Promise.resolve(false);
}
async execute(): Promise<ToolResult> {
const result = await processSingleFileContent(
this.params.absolute_path,
this.config.getTargetDir(),
this.params.offset,
this.params.limit,
);
if (result.error) {
return {
llmContent: result.error, // The detailed error for LLM
returnDisplay: result.returnDisplay || 'Error reading file', // User-friendly error
};
}
const lines =
typeof result.llmContent === 'string'
? result.llmContent.split('\n').length
: undefined;
const mimetype = getSpecificMimeType(this.params.absolute_path);
recordFileOperationMetric(
this.config,
FileOperation.READ,
lines,
mimetype,
path.extname(this.params.absolute_path),
);
return {
llmContent: result.llmContent || '',
returnDisplay: result.returnDisplay || '',
};
}
}
/** /**
* Implementation of the ReadFile tool logic * Implementation of the ReadFile tool logic
*/ */
export class ReadFileTool extends BaseTool<ReadFileToolParams, ToolResult> { export class ReadFileTool extends BaseDeclarativeTool<
ReadFileToolParams,
ToolResult
> {
static readonly Name: string = 'read_file'; static readonly Name: string = 'read_file';
constructor(private config: Config) { constructor(private config: Config) {
@ -75,7 +143,7 @@ export class ReadFileTool extends BaseTool<ReadFileToolParams, ToolResult> {
); );
} }
validateToolParams(params: ReadFileToolParams): string | null { protected validateToolParams(params: ReadFileToolParams): string | null {
const errors = SchemaValidator.validate(this.schema.parameters, params); const errors = SchemaValidator.validate(this.schema.parameters, params);
if (errors) { if (errors) {
return errors; return errors;
@ -106,67 +174,9 @@ export class ReadFileTool extends BaseTool<ReadFileToolParams, ToolResult> {
return null; return null;
} }
getDescription(params: ReadFileToolParams): string { protected createInvocation(
if (
!params ||
typeof params.absolute_path !== 'string' ||
params.absolute_path.trim() === ''
) {
return `Path unavailable`;
}
const relativePath = makeRelative(
params.absolute_path,
this.config.getTargetDir(),
);
return shortenPath(relativePath);
}
toolLocations(params: ReadFileToolParams): ToolLocation[] {
return [{ path: params.absolute_path, line: params.offset }];
}
async execute(
params: ReadFileToolParams, params: ReadFileToolParams,
_signal: AbortSignal, ): ToolInvocation<ReadFileToolParams, ToolResult> {
): Promise<ToolResult> { return new ReadFileToolInvocation(this.config, params);
const validationError = this.validateToolParams(params);
if (validationError) {
return {
llmContent: `Error: Invalid parameters provided. Reason: ${validationError}`,
returnDisplay: validationError,
};
}
const result = await processSingleFileContent(
params.absolute_path,
this.config.getTargetDir(),
params.offset,
params.limit,
);
if (result.error) {
return {
llmContent: result.error, // The detailed error for LLM
returnDisplay: result.returnDisplay || 'Error reading file', // User-friendly error
};
}
const lines =
typeof result.llmContent === 'string'
? result.llmContent.split('\n').length
: undefined;
const mimetype = getSpecificMimeType(params.absolute_path);
recordFileOperationMetric(
this.config,
FileOperation.READ,
lines,
mimetype,
path.extname(params.absolute_path),
);
return {
llmContent: result.llmContent || '',
returnDisplay: result.returnDisplay || '',
};
} }
} }

View File

@ -21,7 +21,6 @@ import {
sanitizeParameters, sanitizeParameters,
} from './tool-registry.js'; } from './tool-registry.js';
import { DiscoveredMCPTool } from './mcp-tool.js'; import { DiscoveredMCPTool } from './mcp-tool.js';
import { BaseTool, Icon, ToolResult } from './tools.js';
import { import {
FunctionDeclaration, FunctionDeclaration,
CallableTool, CallableTool,
@ -32,6 +31,7 @@ import {
import { spawn } from 'node:child_process'; import { spawn } from 'node:child_process';
import fs from 'node:fs'; import fs from 'node:fs';
import { MockTool } from '../test-utils/tools.js';
vi.mock('node:fs'); vi.mock('node:fs');
@ -107,28 +107,6 @@ const createMockCallableTool = (
callTool: vi.fn(), callTool: vi.fn(),
}); });
class MockTool extends BaseTool<{ param: string }, ToolResult> {
constructor(
name = 'mock-tool',
displayName = 'A mock tool',
description = 'A mock tool description',
) {
super(name, displayName, description, Icon.Hammer, {
type: Type.OBJECT,
properties: {
param: { type: Type.STRING },
},
required: ['param'],
});
}
async execute(params: { param: string }): Promise<ToolResult> {
return {
llmContent: `Executed with ${params.param}`,
returnDisplay: `Executed with ${params.param}`,
};
}
}
const baseConfigParams: ConfigParameters = { const baseConfigParams: ConfigParameters = {
cwd: '/tmp', cwd: '/tmp',
model: 'test-model', model: 'test-model',

View File

@ -5,7 +5,7 @@
*/ */
import { FunctionDeclaration, Schema, Type } from '@google/genai'; import { FunctionDeclaration, Schema, Type } from '@google/genai';
import { Tool, ToolResult, BaseTool, Icon } from './tools.js'; import { AnyDeclarativeTool, Icon, ToolResult, BaseTool } from './tools.js';
import { Config } from '../config/config.js'; import { Config } from '../config/config.js';
import { spawn } from 'node:child_process'; import { spawn } from 'node:child_process';
import { StringDecoder } from 'node:string_decoder'; import { StringDecoder } from 'node:string_decoder';
@ -125,7 +125,7 @@ Signal: Signal number or \`(none)\` if no signal was received.
} }
export class ToolRegistry { export class ToolRegistry {
private tools: Map<string, Tool> = new Map(); private tools: Map<string, AnyDeclarativeTool> = new Map();
private config: Config; private config: Config;
constructor(config: Config) { constructor(config: Config) {
@ -136,7 +136,7 @@ export class ToolRegistry {
* Registers a tool definition. * Registers a tool definition.
* @param tool - The tool object containing schema and execution logic. * @param tool - The tool object containing schema and execution logic.
*/ */
registerTool(tool: Tool): void { registerTool(tool: AnyDeclarativeTool): void {
if (this.tools.has(tool.name)) { if (this.tools.has(tool.name)) {
if (tool instanceof DiscoveredMCPTool) { if (tool instanceof DiscoveredMCPTool) {
tool = tool.asFullyQualifiedTool(); tool = tool.asFullyQualifiedTool();
@ -368,7 +368,7 @@ export class ToolRegistry {
/** /**
* Returns an array of all registered and discovered tool instances. * Returns an array of all registered and discovered tool instances.
*/ */
getAllTools(): Tool[] { getAllTools(): AnyDeclarativeTool[] {
return Array.from(this.tools.values()).sort((a, b) => return Array.from(this.tools.values()).sort((a, b) =>
a.displayName.localeCompare(b.displayName), a.displayName.localeCompare(b.displayName),
); );
@ -377,8 +377,8 @@ export class ToolRegistry {
/** /**
* Returns an array of tools registered from a specific MCP server. * Returns an array of tools registered from a specific MCP server.
*/ */
getToolsByServer(serverName: string): Tool[] { getToolsByServer(serverName: string): AnyDeclarativeTool[] {
const serverTools: Tool[] = []; const serverTools: AnyDeclarativeTool[] = [];
for (const tool of this.tools.values()) { for (const tool of this.tools.values()) {
if ((tool as DiscoveredMCPTool)?.serverName === serverName) { if ((tool as DiscoveredMCPTool)?.serverName === serverName) {
serverTools.push(tool); serverTools.push(tool);
@ -390,7 +390,7 @@ export class ToolRegistry {
/** /**
* Get the definition of a specific tool. * Get the definition of a specific tool.
*/ */
getTool(name: string): Tool | undefined { getTool(name: string): AnyDeclarativeTool | undefined {
return this.tools.get(name); return this.tools.get(name);
} }
} }

View File

@ -9,101 +9,243 @@ import { ToolErrorType } from './tool-error.js';
import { DiffUpdateResult } from '../ide/ideContext.js'; import { DiffUpdateResult } from '../ide/ideContext.js';
/** /**
* Interface representing the base Tool functionality * Represents a validated and ready-to-execute tool call.
* An instance of this is created by a `ToolBuilder`.
*/ */
export interface Tool< export interface ToolInvocation<
TParams = unknown, TParams extends object,
TResult extends ToolResult = ToolResult, TResult extends ToolResult,
> { > {
/** /**
* The internal name of the tool (used for API calls) * The validated parameters for this specific invocation.
*/ */
name: string; params: TParams;
/** /**
* The user-friendly display name of the tool * Gets a pre-execution description of the tool operation.
* @returns A markdown string describing what the tool will do.
*/ */
displayName: string; getDescription(): string;
/** /**
* Description of what the tool does * Determines what file system paths the tool will affect.
* @returns A list of such paths.
*/ */
description: string; toolLocations(): ToolLocation[];
/** /**
* The icon to display when interacting via ACP * Determines if the tool should prompt for confirmation before execution.
*/ * @returns Confirmation details or false if no confirmation is needed.
icon: Icon;
/**
* Function declaration schema from @google/genai
*/
schema: FunctionDeclaration;
/**
* Whether the tool's output should be rendered as markdown
*/
isOutputMarkdown: boolean;
/**
* Whether the tool supports live (streaming) output
*/
canUpdateOutput: boolean;
/**
* Validates the parameters for the tool
* Should be called from both `shouldConfirmExecute` and `execute`
* `shouldConfirmExecute` should return false immediately if invalid
* @param params Parameters to validate
* @returns An error message string if invalid, null otherwise
*/
validateToolParams(params: TParams): string | null;
/**
* Gets a pre-execution description of the tool operation
* @param params Parameters for the tool execution
* @returns A markdown string describing what the tool will do
* Optional for backward compatibility
*/
getDescription(params: TParams): string;
/**
* Determines what file system paths the tool will affect
* @param params Parameters for the tool execution
* @returns A list of such paths
*/
toolLocations(params: TParams): ToolLocation[];
/**
* Determines if the tool should prompt for confirmation before execution
* @param params Parameters for the tool execution
* @returns Whether execute should be confirmed.
*/ */
shouldConfirmExecute( shouldConfirmExecute(
params: TParams,
abortSignal: AbortSignal, abortSignal: AbortSignal,
): Promise<ToolCallConfirmationDetails | false>; ): Promise<ToolCallConfirmationDetails | false>;
/** /**
* Executes the tool with the given parameters * Executes the tool with the validated parameters.
* @param params Parameters for the tool execution * @param signal AbortSignal for tool cancellation.
* @returns Result of the tool execution * @param updateOutput Optional callback to stream output.
* @returns Result of the tool execution.
*/ */
execute( execute(
params: TParams,
signal: AbortSignal, signal: AbortSignal,
updateOutput?: (output: string) => void, updateOutput?: (output: string) => void,
): Promise<TResult>; ): Promise<TResult>;
} }
/**
* A type alias for a tool invocation where the specific parameter and result types are not known.
*/
export type AnyToolInvocation = ToolInvocation<object, ToolResult>;
/**
* An adapter that wraps the legacy `Tool` interface to make it compatible
* with the new `ToolInvocation` pattern.
*/
export class LegacyToolInvocation<
TParams extends object,
TResult extends ToolResult,
> implements ToolInvocation<TParams, TResult>
{
constructor(
private readonly legacyTool: BaseTool<TParams, TResult>,
readonly params: TParams,
) {}
getDescription(): string {
return this.legacyTool.getDescription(this.params);
}
toolLocations(): ToolLocation[] {
return this.legacyTool.toolLocations(this.params);
}
shouldConfirmExecute(
abortSignal: AbortSignal,
): Promise<ToolCallConfirmationDetails | false> {
return this.legacyTool.shouldConfirmExecute(this.params, abortSignal);
}
execute(
signal: AbortSignal,
updateOutput?: (output: string) => void,
): Promise<TResult> {
return this.legacyTool.execute(this.params, signal, updateOutput);
}
}
/**
* Interface for a tool builder that validates parameters and creates invocations.
*/
export interface ToolBuilder<
TParams extends object,
TResult extends ToolResult,
> {
/**
* The internal name of the tool (used for API calls).
*/
name: string;
/**
* The user-friendly display name of the tool.
*/
displayName: string;
/**
* Description of what the tool does.
*/
description: string;
/**
* The icon to display when interacting via ACP.
*/
icon: Icon;
/**
* Function declaration schema from @google/genai.
*/
schema: FunctionDeclaration;
/**
* Whether the tool's output should be rendered as markdown.
*/
isOutputMarkdown: boolean;
/**
* Whether the tool supports live (streaming) output.
*/
canUpdateOutput: boolean;
/**
* Validates raw parameters and builds a ready-to-execute invocation.
* @param params The raw, untrusted parameters from the model.
* @returns A valid `ToolInvocation` if successful. Throws an error if validation fails.
*/
build(params: TParams): ToolInvocation<TParams, TResult>;
}
/**
* New base class for tools that separates validation from execution.
* New tools should extend this class.
*/
export abstract class DeclarativeTool<
TParams extends object,
TResult extends ToolResult,
> implements ToolBuilder<TParams, TResult>
{
constructor(
readonly name: string,
readonly displayName: string,
readonly description: string,
readonly icon: Icon,
readonly parameterSchema: Schema,
readonly isOutputMarkdown: boolean = true,
readonly canUpdateOutput: boolean = false,
) {}
get schema(): FunctionDeclaration {
return {
name: this.name,
description: this.description,
parameters: this.parameterSchema,
};
}
/**
* Validates the raw tool parameters.
* Subclasses should override this to add custom validation logic
* beyond the JSON schema check.
* @param params The raw parameters from the model.
* @returns An error message string if invalid, null otherwise.
*/
protected validateToolParams(_params: TParams): string | null {
// Base implementation can be extended by subclasses.
return null;
}
/**
* The core of the new pattern. It validates parameters and, if successful,
* returns a `ToolInvocation` object that encapsulates the logic for the
* specific, validated call.
* @param params The raw, untrusted parameters from the model.
* @returns A `ToolInvocation` instance.
*/
abstract build(params: TParams): ToolInvocation<TParams, TResult>;
/**
* A convenience method that builds and executes the tool in one step.
* Throws an error if validation fails.
* @param params The raw, untrusted parameters from the model.
* @param signal AbortSignal for tool cancellation.
* @param updateOutput Optional callback to stream output.
* @returns The result of the tool execution.
*/
async buildAndExecute(
params: TParams,
signal: AbortSignal,
updateOutput?: (output: string) => void,
): Promise<TResult> {
const invocation = this.build(params);
return invocation.execute(signal, updateOutput);
}
}
/**
* New base class for declarative tools that separates validation from execution.
* New tools should extend this class, which provides a `build` method that
* validates parameters before deferring to a `createInvocation` method for
* the final `ToolInvocation` object instantiation.
*/
export abstract class BaseDeclarativeTool<
TParams extends object,
TResult extends ToolResult,
> extends DeclarativeTool<TParams, TResult> {
build(params: TParams): ToolInvocation<TParams, TResult> {
const validationError = this.validateToolParams(params);
if (validationError) {
throw new Error(validationError);
}
return this.createInvocation(params);
}
protected abstract createInvocation(
params: TParams,
): ToolInvocation<TParams, TResult>;
}
/**
* A type alias for a declarative tool where the specific parameter and result types are not known.
*/
export type AnyDeclarativeTool = DeclarativeTool<object, ToolResult>;
/** /**
* Base implementation for tools with common functionality * Base implementation for tools with common functionality
* @deprecated Use `DeclarativeTool` for new tools.
*/ */
export abstract class BaseTool< export abstract class BaseTool<
TParams = unknown, TParams extends object,
TResult extends ToolResult = ToolResult, TResult extends ToolResult = ToolResult,
> implements Tool<TParams, TResult> > extends DeclarativeTool<TParams, TResult> {
{
/** /**
* Creates a new instance of BaseTool * Creates a new instance of BaseTool
* @param name Internal name of the tool (used for API calls) * @param name Internal name of the tool (used for API calls)
@ -121,17 +263,24 @@ export abstract class BaseTool<
readonly parameterSchema: Schema, readonly parameterSchema: Schema,
readonly isOutputMarkdown: boolean = true, readonly isOutputMarkdown: boolean = true,
readonly canUpdateOutput: boolean = false, readonly canUpdateOutput: boolean = false,
) {} ) {
super(
name,
displayName,
description,
icon,
parameterSchema,
isOutputMarkdown,
canUpdateOutput,
);
}
/** build(params: TParams): ToolInvocation<TParams, TResult> {
* Function declaration schema computed from name, description, and parameterSchema const validationError = this.validateToolParams(params);
*/ if (validationError) {
get schema(): FunctionDeclaration { throw new Error(validationError);
return { }
name: this.name, return new LegacyToolInvocation(this, params);
description: this.description,
parameters: this.parameterSchema,
};
} }
/** /**

View File

@ -26,7 +26,7 @@ import {
ensureCorrectFileContent, ensureCorrectFileContent,
} from '../utils/editCorrector.js'; } from '../utils/editCorrector.js';
import { DEFAULT_DIFF_OPTIONS } from './diffOptions.js'; import { DEFAULT_DIFF_OPTIONS } from './diffOptions.js';
import { ModifiableTool, ModifyContext } from './modifiable-tool.js'; import { ModifiableDeclarativeTool, ModifyContext } from './modifiable-tool.js';
import { getSpecificMimeType } from '../utils/fileUtils.js'; import { getSpecificMimeType } from '../utils/fileUtils.js';
import { import {
recordFileOperationMetric, recordFileOperationMetric,
@ -66,7 +66,7 @@ interface GetCorrectedFileContentResult {
*/ */
export class WriteFileTool export class WriteFileTool
extends BaseTool<WriteFileToolParams, ToolResult> extends BaseTool<WriteFileToolParams, ToolResult>
implements ModifiableTool<WriteFileToolParams> implements ModifiableDeclarativeTool<WriteFileToolParams>
{ {
static readonly Name: string = 'write_file'; static readonly Name: string = 'write_file';