From 8f2fa5a537b048c3d76c581279c8cd4ae52688b2 Mon Sep 17 00:00:00 2001 From: joshualitt Date: Fri, 15 Aug 2025 08:44:26 -0700 Subject: [PATCH] feat(core): Migrate MockTools to declarative pattern. (#6197) --- .../cli/src/ui/hooks/useToolScheduler.test.ts | 114 +++++++--- packages/cli/tsconfig.json | 1 - .../core/src/core/coreToolScheduler.test.ts | 213 +++++++----------- packages/core/src/core/coreToolScheduler.ts | 8 + packages/core/src/test-utils/tools.ts | 135 +++++++++-- packages/core/src/tools/tools.ts | 156 ------------- 6 files changed, 292 insertions(+), 335 deletions(-) diff --git a/packages/cli/src/ui/hooks/useToolScheduler.test.ts b/packages/cli/src/ui/hooks/useToolScheduler.test.ts index 36fa8825..c5d968fe 100644 --- a/packages/cli/src/ui/hooks/useToolScheduler.test.ts +++ b/packages/cli/src/ui/hooks/useToolScheduler.test.ts @@ -24,7 +24,9 @@ import { Status as ToolCallStatusType, ApprovalMode, Kind, - BaseTool, + BaseDeclarativeTool, + BaseToolInvocation, + ToolInvocation, AnyDeclarativeTool, AnyToolInvocation, } from '@google/gemini-cli-core'; @@ -62,7 +64,41 @@ const mockConfig = { getDebugMode: () => false, }; -class MockTool extends BaseTool { +class MockToolInvocation extends BaseToolInvocation { + constructor( + private readonly tool: MockTool, + params: object, + ) { + super(params); + } + + getDescription(): string { + return JSON.stringify(this.params); + } + + override shouldConfirmExecute( + abortSignal: AbortSignal, + ): Promise { + return this.tool.shouldConfirmExecute(this.params, abortSignal); + } + + execute( + signal: AbortSignal, + updateOutput?: (output: string) => void, + terminalColumns?: number, + terminalRows?: number, + ): Promise { + return this.tool.execute( + this.params, + signal, + updateOutput, + terminalColumns, + terminalRows, + ); + } +} + +class MockTool extends BaseDeclarativeTool { constructor( name: string, displayName: string, @@ -80,11 +116,12 @@ class MockTool extends BaseTool { canUpdateOutput, ); if (shouldConfirm) { - this.shouldConfirmExecute = vi.fn( + this.shouldConfirmExecute.mockImplementation( async (): Promise => ({ type: 'edit', title: 'Mock Tool Requires Confirmation', onConfirm: mockOnUserConfirmForToolConfirmation, + filePath: 'mock', fileName: 'mockToolRequiresConfirmation.ts', fileDiff: 'Mock tool requires confirmation', originalContent: 'Original content', @@ -96,6 +133,12 @@ class MockTool extends BaseTool { execute = vi.fn(); shouldConfirmExecute = vi.fn(); + + protected createInvocation( + params: object, + ): ToolInvocation { + return new MockToolInvocation(this, params); + } } const mockTool = new MockTool('mockTool', 'Mock Tool'); @@ -142,6 +185,8 @@ describe('useReactToolScheduler in YOLO Mode', () => { onComplete, mockConfig as unknown as Config, setPendingHistoryItem, + () => undefined, + () => {}, ), ); @@ -160,7 +205,7 @@ describe('useReactToolScheduler in YOLO Mode', () => { callId: 'yoloCall', name: 'mockToolRequiresConfirmation', args: { data: 'any data' }, - }; + } as any; act(() => { schedule(request, new AbortController().signal); @@ -270,13 +315,14 @@ describe('useReactToolScheduler', () => { ( mockToolRequiresConfirmation.shouldConfirmExecute as Mock ).mockImplementation( - async (): Promise => ({ - onConfirm: mockOnUserConfirmForToolConfirmation, - fileName: 'mockToolRequiresConfirmation.ts', - fileDiff: 'Mock tool requires confirmation', - type: 'edit', - title: 'Mock Tool Requires Confirmation', - }), + async (): Promise => + ({ + onConfirm: mockOnUserConfirmForToolConfirmation, + fileName: 'mockToolRequiresConfirmation.ts', + fileDiff: 'Mock tool requires confirmation', + type: 'edit', + title: 'Mock Tool Requires Confirmation', + }) as any, ); vi.useFakeTimers(); @@ -293,6 +339,8 @@ describe('useReactToolScheduler', () => { onComplete, mockConfig as unknown as Config, setPendingHistoryItem, + () => undefined, + () => {}, ), ); @@ -316,7 +364,7 @@ describe('useReactToolScheduler', () => { callId: 'call1', name: 'mockTool', args: { param: 'value' }, - }; + } as any; act(() => { schedule(request, new AbortController().signal); @@ -365,7 +413,7 @@ describe('useReactToolScheduler', () => { callId: 'call1', name: 'nonexistentTool', args: {}, - }; + } as any; act(() => { schedule(request, new AbortController().signal); @@ -402,7 +450,7 @@ describe('useReactToolScheduler', () => { callId: 'call1', name: 'mockTool', args: {}, - }; + } as any; act(() => { schedule(request, new AbortController().signal); @@ -438,7 +486,7 @@ describe('useReactToolScheduler', () => { callId: 'call1', name: 'mockTool', args: {}, - }; + } as any; act(() => { schedule(request, new AbortController().signal); @@ -480,7 +528,7 @@ describe('useReactToolScheduler', () => { callId: 'callConfirm', name: 'mockToolRequiresConfirmation', args: { data: 'sensitive' }, - }; + } as any; act(() => { schedule(request, new AbortController().signal); @@ -536,7 +584,7 @@ describe('useReactToolScheduler', () => { callId: 'callConfirmCancel', name: 'mockToolRequiresConfirmation', args: {}, - }; + } as any; act(() => { schedule(request, new AbortController().signal); @@ -608,7 +656,7 @@ describe('useReactToolScheduler', () => { callId: 'liveCall', name: 'mockToolWithLiveOutput', args: {}, - }; + } as any; act(() => { schedule(request, new AbortController().signal); @@ -693,8 +741,8 @@ describe('useReactToolScheduler', () => { const { result } = renderScheduler(); const schedule = result.current[1]; const requests: ToolCallRequestInfo[] = [ - { callId: 'multi1', name: 'tool1', args: { p: 1 } }, - { callId: 'multi2', name: 'tool2', args: { p: 2 } }, + { callId: 'multi1', name: 'tool1', args: { p: 1 } } as any, + { callId: 'multi2', name: 'tool2', args: { p: 2 } } as any, ]; act(() => { @@ -777,12 +825,12 @@ describe('useReactToolScheduler', () => { callId: 'run1', name: 'mockTool', args: {}, - }; + } as any; const request2: ToolCallRequestInfo = { callId: 'run2', name: 'mockTool', args: {}, - }; + } as any; act(() => { schedule(request1, new AbortController().signal); @@ -818,7 +866,7 @@ describe('mapToDisplay', () => { callId: 'testCallId', name: 'testTool', args: { foo: 'bar' }, - }; + } as any; const baseTool = new MockTool('testTool', 'Test Tool Display'); @@ -834,9 +882,8 @@ describe('mapToDisplay', () => { } as PartUnion, ], resultDisplay: 'Test display output', - summary: 'Test summary', error: undefined, - }; + } as any; // Define a more specific type for extraProps for these tests // This helps ensure that tool and confirmationDetails are only accessed when they are expected to exist. @@ -882,7 +929,7 @@ describe('mapToDisplay', () => { extraProps: { tool: baseTool, invocation: baseInvocation }, expectedStatus: ToolCallStatus.Executing, expectedName: baseTool.displayName, - expectedDescription: baseTool.getDescription(baseRequest.args), + expectedDescription: baseInvocation.getDescription(), }, { name: 'awaiting_approval', @@ -897,6 +944,7 @@ describe('mapToDisplay', () => { serverName: 'testTool', toolName: 'testTool', toolDisplayName: 'Test Tool Display', + filePath: 'mock', fileName: 'test.ts', fileDiff: 'Test diff', originalContent: 'Original content', @@ -905,7 +953,7 @@ describe('mapToDisplay', () => { }, expectedStatus: ToolCallStatus.Confirming, expectedName: baseTool.displayName, - expectedDescription: baseTool.getDescription(baseRequest.args), + expectedDescription: baseInvocation.getDescription(), }, { name: 'scheduled', @@ -913,7 +961,7 @@ describe('mapToDisplay', () => { extraProps: { tool: baseTool, invocation: baseInvocation }, expectedStatus: ToolCallStatus.Pending, expectedName: baseTool.displayName, - expectedDescription: baseTool.getDescription(baseRequest.args), + expectedDescription: baseInvocation.getDescription(), }, { name: 'executing no live output', @@ -921,7 +969,7 @@ describe('mapToDisplay', () => { extraProps: { tool: baseTool, invocation: baseInvocation }, expectedStatus: ToolCallStatus.Executing, expectedName: baseTool.displayName, - expectedDescription: baseTool.getDescription(baseRequest.args), + expectedDescription: baseInvocation.getDescription(), }, { name: 'executing with live output', @@ -934,7 +982,7 @@ describe('mapToDisplay', () => { expectedStatus: ToolCallStatus.Executing, expectedResultDisplay: 'Live test output', expectedName: baseTool.displayName, - expectedDescription: baseTool.getDescription(baseRequest.args), + expectedDescription: baseInvocation.getDescription(), }, { name: 'success', @@ -947,7 +995,7 @@ describe('mapToDisplay', () => { expectedStatus: ToolCallStatus.Success, expectedResultDisplay: baseResponse.resultDisplay as any, expectedName: baseTool.displayName, - expectedDescription: baseTool.getDescription(baseRequest.args), + expectedDescription: baseInvocation.getDescription(), }, { name: 'error tool not found', @@ -978,7 +1026,7 @@ describe('mapToDisplay', () => { expectedStatus: ToolCallStatus.Error, expectedResultDisplay: 'Execution failed display', expectedName: baseTool.displayName, // Changed from baseTool.name - expectedDescription: baseTool.getDescription(baseRequest.args), + expectedDescription: baseInvocation.getDescription(), }, { name: 'cancelled', @@ -994,7 +1042,7 @@ describe('mapToDisplay', () => { expectedStatus: ToolCallStatus.Canceled, expectedResultDisplay: 'Cancelled display', expectedName: baseTool.displayName, - expectedDescription: baseTool.getDescription(baseRequest.args), + expectedDescription: baseInvocation.getDescription(), }, ]; diff --git a/packages/cli/tsconfig.json b/packages/cli/tsconfig.json index 65324f37..46fce707 100644 --- a/packages/cli/tsconfig.json +++ b/packages/cli/tsconfig.json @@ -76,7 +76,6 @@ "src/ui/hooks/useGeminiStream.test.tsx", "src/ui/hooks/useKeypress.test.ts", "src/ui/hooks/usePhraseCycler.test.ts", - "src/ui/hooks/useToolScheduler.test.ts", "src/ui/hooks/vim.test.ts", "src/ui/utils/computeStats.test.ts", "src/ui/themes/theme.test.ts", diff --git a/packages/core/src/core/coreToolScheduler.test.ts b/packages/core/src/core/coreToolScheduler.test.ts index 9d7d45ea..6ba85b04 100644 --- a/packages/core/src/core/coreToolScheduler.test.ts +++ b/packages/core/src/core/coreToolScheduler.test.ts @@ -4,7 +4,6 @@ * SPDX-License-Identifier: Apache-2.0 */ -/* eslint-disable @typescript-eslint/no-explicit-any */ import { describe, it, expect, vi } from 'vitest'; import { CoreToolScheduler, @@ -12,65 +11,20 @@ import { convertToFunctionResponse, } from './coreToolScheduler.js'; import { - BaseTool, + BaseDeclarativeTool, + BaseToolInvocation, ToolCallConfirmationDetails, ToolConfirmationOutcome, ToolConfirmationPayload, + ToolInvocation, ToolResult, Config, Kind, ApprovalMode, + ToolRegistry, } from '../index.js'; import { Part, PartListUnion } from '@google/genai'; - -import { - ModifiableDeclarativeTool, - ModifyContext, -} from '../tools/modifiable-tool.js'; -import { MockTool } from '../test-utils/tools.js'; - -class MockModifiableTool - extends MockTool - implements ModifiableDeclarativeTool> -{ - constructor(name = 'mockModifiableTool') { - super(name); - this.shouldConfirm = true; - } - - getModifyContext( - _abortSignal: AbortSignal, - ): ModifyContext> { - return { - getFilePath: () => 'test.txt', - getCurrentContent: async () => 'old content', - getProposedContent: async () => 'new content', - createUpdatedParams: ( - _oldContent: string, - modifiedProposedContent: string, - _originalParams: Record, - ) => ({ newContent: modifiedProposedContent }), - }; - } - - override async shouldConfirmExecute(): Promise< - ToolCallConfirmationDetails | false - > { - if (this.shouldConfirm) { - return { - type: 'edit', - title: 'Confirm Mock Tool', - fileName: 'test.txt', - filePath: 'test.txt', - fileDiff: 'diff', - originalContent: 'originalContent', - newContent: 'newContent', - onConfirm: async () => {}, - }; - } - return false; - } -} +import { MockModifiableTool, MockTool } from '../test-utils/tools.js'; describe('CoreToolScheduler', () => { it('should cancel a tool call if the signal is aborted before confirmation', async () => { @@ -81,7 +35,7 @@ describe('CoreToolScheduler', () => { getTool: () => declarativeTool, getFunctionDeclarations: () => [], tools: new Map(), - discovery: {} as any, + discovery: {}, registerTool: () => {}, getToolByName: () => declarativeTool, getToolByDisplayName: () => declarativeTool, @@ -103,7 +57,7 @@ describe('CoreToolScheduler', () => { const scheduler = new CoreToolScheduler({ config: mockConfig, - toolRegistry: Promise.resolve(toolRegistry as any), + toolRegistry: Promise.resolve(toolRegistry as unknown as ToolRegistry), onAllToolCallsComplete, onToolCallsUpdate, getPreferredEditor: () => 'vscode', @@ -123,19 +77,6 @@ describe('CoreToolScheduler', () => { abortController.abort(); await scheduler.schedule([request], abortController.signal); - const confirmationDetails = await mockTool.shouldConfirmExecute( - {}, - abortController.signal, - ); - if (confirmationDetails) { - await scheduler.handleConfirmationResponse( - '1', - confirmationDetails.onConfirm, - ToolConfirmationOutcome.ProceedOnce, - abortController.signal, - ); - } - expect(onAllToolCallsComplete).toHaveBeenCalled(); const completedCalls = onAllToolCallsComplete.mock .calls[0][0] as ToolCall[]; @@ -151,7 +92,7 @@ describe('CoreToolScheduler with payload', () => { getTool: () => declarativeTool, getFunctionDeclarations: () => [], tools: new Map(), - discovery: {} as any, + discovery: {}, registerTool: () => {}, getToolByName: () => declarativeTool, getToolByDisplayName: () => declarativeTool, @@ -173,7 +114,7 @@ describe('CoreToolScheduler with payload', () => { const scheduler = new CoreToolScheduler({ config: mockConfig, - toolRegistry: Promise.resolve(toolRegistry as any), + toolRegistry: Promise.resolve(toolRegistry as unknown as ToolRegistry), onAllToolCallsComplete, onToolCallsUpdate, getPreferredEditor: () => 'vscode', @@ -192,15 +133,22 @@ describe('CoreToolScheduler with payload', () => { await scheduler.schedule([request], abortController.signal); - const confirmationDetails = await mockTool.shouldConfirmExecute(); + await vi.waitFor(() => { + const awaitingCall = onToolCallsUpdate.mock.calls.find( + (call) => call[0][0].status === 'awaiting_approval', + )?.[0][0]; + expect(awaitingCall).toBeDefined(); + }); + + const awaitingCall = onToolCallsUpdate.mock.calls.find( + (call) => call[0][0].status === 'awaiting_approval', + )?.[0][0]; + const confirmationDetails = awaitingCall.confirmationDetails; if (confirmationDetails) { const payload: ToolConfirmationPayload = { newContent: 'final version' }; - await scheduler.handleConfirmationResponse( - '1', - confirmationDetails.onConfirm, + await confirmationDetails.onConfirm( ToolConfirmationOutcome.ProceedOnce, - abortController.signal, payload, ); } @@ -382,54 +330,66 @@ describe('convertToFunctionResponse', () => { }); }); +class MockEditToolInvocation extends BaseToolInvocation< + Record, + ToolResult +> { + constructor(params: Record) { + super(params); + } + + getDescription(): string { + return 'A mock edit tool invocation'; + } + + override async shouldConfirmExecute( + _abortSignal: AbortSignal, + ): Promise { + return { + type: 'edit', + title: 'Confirm Edit', + fileName: 'test.txt', + filePath: 'test.txt', + fileDiff: + '--- test.txt\n+++ test.txt\n@@ -1,1 +1,1 @@\n-old content\n+new content', + originalContent: 'old content', + newContent: 'new content', + onConfirm: async () => {}, + }; + } + + async execute(_abortSignal: AbortSignal): Promise { + return { + llmContent: 'Edited successfully', + returnDisplay: 'Edited successfully', + }; + } +} + +class MockEditTool extends BaseDeclarativeTool< + Record, + ToolResult +> { + constructor() { + super('mockEditTool', 'mockEditTool', 'A mock edit tool', Kind.Edit, {}); + } + + protected createInvocation( + params: Record, + ): ToolInvocation, ToolResult> { + return new MockEditToolInvocation(params); + } +} + describe('CoreToolScheduler edit cancellation', () => { it('should preserve diff when an edit is cancelled', async () => { - class MockEditTool extends BaseTool, ToolResult> { - constructor() { - super( - 'mockEditTool', - 'mockEditTool', - 'A mock edit tool', - Kind.Edit, - {}, - ); - } - - override async shouldConfirmExecute( - _params: Record, - _abortSignal: AbortSignal, - ): Promise { - return { - type: 'edit', - title: 'Confirm Edit', - fileName: 'test.txt', - filePath: 'test.txt', - fileDiff: - '--- test.txt\n+++ test.txt\n@@ -1,1 +1,1 @@\n-old content\n+new content', - originalContent: 'old content', - newContent: 'new content', - onConfirm: async () => {}, - }; - } - - async execute( - _params: Record, - _abortSignal: AbortSignal, - ): Promise { - return { - llmContent: 'Edited successfully', - returnDisplay: 'Edited successfully', - }; - } - } - const mockEditTool = new MockEditTool(); const declarativeTool = mockEditTool; const toolRegistry = { getTool: () => declarativeTool, getFunctionDeclarations: () => [], tools: new Map(), - discovery: {} as any, + discovery: {}, registerTool: () => {}, getToolByName: () => declarativeTool, getToolByDisplayName: () => declarativeTool, @@ -451,7 +411,7 @@ describe('CoreToolScheduler edit cancellation', () => { const scheduler = new CoreToolScheduler({ config: mockConfig, - toolRegistry: Promise.resolve(toolRegistry as any), + toolRegistry: Promise.resolve(toolRegistry as unknown as ToolRegistry), onAllToolCallsComplete, onToolCallsUpdate, getPreferredEditor: () => 'vscode', @@ -478,17 +438,9 @@ describe('CoreToolScheduler edit cancellation', () => { expect(awaitingCall).toBeDefined(); // Cancel the edit - const confirmationDetails = await mockEditTool.shouldConfirmExecute( - {}, - abortController.signal, - ); + const confirmationDetails = awaitingCall.confirmationDetails; if (confirmationDetails) { - await scheduler.handleConfirmationResponse( - '1', - confirmationDetails.onConfirm, - ToolConfirmationOutcome.Cancel, - abortController.signal, - ); + await confirmationDetails.onConfirm(ToolConfirmationOutcome.Cancel); } expect(onAllToolCallsComplete).toHaveBeenCalled(); @@ -498,6 +450,7 @@ describe('CoreToolScheduler edit cancellation', () => { expect(completedCalls[0].status).toBe('cancelled'); // Check that the diff is preserved + // eslint-disable-next-line @typescript-eslint/no-explicit-any const cancelledCall = completedCalls[0] as any; expect(cancelledCall.response.resultDisplay).toBeDefined(); expect(cancelledCall.response.resultDisplay.fileDiff).toBe( @@ -525,7 +478,7 @@ describe('CoreToolScheduler YOLO mode', () => { // Other properties are not needed for this test but are included for type consistency. getFunctionDeclarations: () => [], tools: new Map(), - discovery: {} as any, + discovery: {}, registerTool: () => {}, getToolByDisplayName: () => declarativeTool, getTools: () => [], @@ -547,7 +500,7 @@ describe('CoreToolScheduler YOLO mode', () => { const scheduler = new CoreToolScheduler({ config: mockConfig, - toolRegistry: Promise.resolve(toolRegistry as any), + toolRegistry: Promise.resolve(toolRegistry as unknown as ToolRegistry), onAllToolCallsComplete, onToolCallsUpdate, getPreferredEditor: () => 'vscode', @@ -612,7 +565,7 @@ describe('CoreToolScheduler request queueing', () => { getToolByName: () => declarativeTool, getFunctionDeclarations: () => [], tools: new Map(), - discovery: {} as any, + discovery: {}, registerTool: () => {}, getToolByDisplayName: () => declarativeTool, getTools: () => [], @@ -633,7 +586,7 @@ describe('CoreToolScheduler request queueing', () => { const scheduler = new CoreToolScheduler({ config: mockConfig, - toolRegistry: Promise.resolve(toolRegistry as any), + toolRegistry: Promise.resolve(toolRegistry as unknown as ToolRegistry), onAllToolCallsComplete, onToolCallsUpdate, getPreferredEditor: () => 'vscode', @@ -722,7 +675,7 @@ describe('CoreToolScheduler request queueing', () => { getToolByName: () => declarativeTool, getFunctionDeclarations: () => [], tools: new Map(), - discovery: {} as any, + discovery: {}, registerTool: () => {}, getToolByDisplayName: () => declarativeTool, getTools: () => [], @@ -743,7 +696,7 @@ describe('CoreToolScheduler request queueing', () => { const scheduler = new CoreToolScheduler({ config: mockConfig, - toolRegistry: Promise.resolve(toolRegistry as any), + toolRegistry: Promise.resolve(toolRegistry as unknown as ToolRegistry), onAllToolCallsComplete, onToolCallsUpdate, getPreferredEditor: () => 'vscode', diff --git a/packages/core/src/core/coreToolScheduler.ts b/packages/core/src/core/coreToolScheduler.ts index aac8f9a6..bccb724a 100644 --- a/packages/core/src/core/coreToolScheduler.ts +++ b/packages/core/src/core/coreToolScheduler.ts @@ -594,6 +594,14 @@ export class CoreToolScheduler { const { request: reqInfo, invocation } = toolCall; try { + if (signal.aborted) { + this.setStatusInternal( + reqInfo.callId, + 'cancelled', + 'Tool call cancelled by user.', + ); + continue; + } if (this.config.getApprovalMode() === ApprovalMode.YOLO) { this.setToolCallOutcome( reqInfo.callId, diff --git a/packages/core/src/test-utils/tools.ts b/packages/core/src/test-utils/tools.ts index da642212..0e3e6b86 100644 --- a/packages/core/src/test-utils/tools.ts +++ b/packages/core/src/test-utils/tools.ts @@ -6,17 +6,67 @@ import { vi } from 'vitest'; import { - BaseTool, + BaseDeclarativeTool, + BaseToolInvocation, ToolCallConfirmationDetails, + ToolInvocation, ToolResult, Kind, } from '../tools/tools.js'; import { Schema, Type } from '@google/genai'; +import { + ModifiableDeclarativeTool, + ModifyContext, +} from '../tools/modifiable-tool.js'; + +class MockToolInvocation extends BaseToolInvocation< + { [key: string]: unknown }, + ToolResult +> { + constructor( + private readonly tool: MockTool, + params: { [key: string]: unknown }, + ) { + super(params); + } + + async execute(_abortSignal: AbortSignal): Promise { + const result = this.tool.executeFn(this.params); + return ( + result ?? { + llmContent: `Tool ${this.tool.name} executed successfully.`, + returnDisplay: `Tool ${this.tool.name} executed successfully.`, + } + ); + } + + override async shouldConfirmExecute( + _abortSignal: AbortSignal, + ): Promise { + if (this.tool.shouldConfirm) { + return { + type: 'exec' as const, + title: `Confirm ${this.tool.displayName}`, + command: this.tool.name, + rootCommand: this.tool.name, + onConfirm: async () => {}, + }; + } + return false; + } + + getDescription(): string { + return `A mock tool invocation for ${this.tool.name}`; + } +} /** * A highly configurable mock tool for testing purposes. */ -export class MockTool extends BaseTool<{ [key: string]: unknown }, ToolResult> { +export class MockTool extends BaseDeclarativeTool< + { [key: string]: unknown }, + ToolResult +> { executeFn = vi.fn(); shouldConfirm = false; @@ -32,32 +82,87 @@ export class MockTool extends BaseTool<{ [key: string]: unknown }, ToolResult> { super(name, displayName ?? name, description, Kind.Other, params); } - async execute( - params: { [key: string]: unknown }, - _abortSignal: AbortSignal, - ): Promise { - const result = this.executeFn(params); + protected createInvocation(params: { + [key: string]: unknown; + }): ToolInvocation<{ [key: string]: unknown }, ToolResult> { + return new MockToolInvocation(this, params); + } +} + +export class MockModifiableToolInvocation extends BaseToolInvocation< + Record, + ToolResult +> { + constructor( + private readonly tool: MockModifiableTool, + params: Record, + ) { + super(params); + } + + async execute(_abortSignal: AbortSignal): Promise { + const result = this.tool.executeFn(this.params); return ( result ?? { - llmContent: `Tool ${this.name} executed successfully.`, - returnDisplay: `Tool ${this.name} executed successfully.`, + llmContent: `Tool ${this.tool.name} executed successfully.`, + returnDisplay: `Tool ${this.tool.name} executed successfully.`, } ); } override async shouldConfirmExecute( - _params: { [key: string]: unknown }, _abortSignal: AbortSignal, ): Promise { - if (this.shouldConfirm) { + if (this.tool.shouldConfirm) { return { - type: 'exec' as const, - title: `Confirm ${this.displayName}`, - command: this.name, - rootCommand: this.name, + type: 'edit', + title: 'Confirm Mock Tool', + fileName: 'test.txt', + filePath: 'test.txt', + fileDiff: 'diff', + originalContent: 'originalContent', + newContent: 'newContent', onConfirm: async () => {}, }; } return false; } + + getDescription(): string { + return `A mock modifiable tool invocation for ${this.tool.name}`; + } +} + +/** + * Configurable mock modifiable tool for testing. + */ +export class MockModifiableTool + extends MockTool + implements ModifiableDeclarativeTool> +{ + constructor(name = 'mockModifiableTool') { + super(name); + this.shouldConfirm = true; + } + + getModifyContext( + _abortSignal: AbortSignal, + ): ModifyContext> { + return { + getFilePath: () => 'test.txt', + getCurrentContent: async () => 'old content', + getProposedContent: async () => 'new content', + createUpdatedParams: ( + _oldContent: string, + modifiedProposedContent: string, + _originalParams: Record, + ) => ({ newContent: modifiedProposedContent }), + }; + } + + protected override createInvocation( + params: Record, + ): ToolInvocation, ToolResult> { + return new MockModifiableToolInvocation(this, params); + } } diff --git a/packages/core/src/tools/tools.ts b/packages/core/src/tools/tools.ts index ee8b830b..5684e4ac 100644 --- a/packages/core/src/tools/tools.ts +++ b/packages/core/src/tools/tools.ts @@ -90,50 +90,6 @@ export abstract class BaseToolInvocation< */ export type AnyToolInvocation = ToolInvocation; -/** - * An adapter that wraps the legacy `Tool` interface to make it compatible - * with the new `ToolInvocation` pattern. - */ -export class LegacyToolInvocation< - TParams extends object, - TResult extends ToolResult, -> implements ToolInvocation -{ - constructor( - private readonly legacyTool: BaseTool, - readonly params: TParams, - ) {} - - getDescription(): string { - return this.legacyTool.getDescription(this.params); - } - - toolLocations(): ToolLocation[] { - return this.legacyTool.toolLocations(this.params); - } - - shouldConfirmExecute( - abortSignal: AbortSignal, - ): Promise { - return this.legacyTool.shouldConfirmExecute(this.params, abortSignal); - } - - execute( - signal: AbortSignal, - updateOutput?: (output: string) => void, - terminalColumns?: number, - terminalRows?: number, - ): Promise { - return this.legacyTool.execute( - this.params, - signal, - updateOutput, - terminalColumns, - terminalRows, - ); - } -} - /** * Interface for a tool builder that validates parameters and creates invocations. */ @@ -285,118 +241,6 @@ export abstract class BaseDeclarativeTool< */ export type AnyDeclarativeTool = DeclarativeTool; -/** - * Base implementation for tools with common functionality - * @deprecated Use `DeclarativeTool` for new tools. - */ -export abstract class BaseTool< - TParams extends object, - TResult extends ToolResult = ToolResult, -> extends DeclarativeTool { - /** - * Creates a new instance of BaseTool - * @param name Internal name of the tool (used for API calls) - * @param displayName User-friendly display name of the tool - * @param description Description of what the tool does - * @param isOutputMarkdown Whether the tool's output should be rendered as markdown - * @param canUpdateOutput Whether the tool supports live (streaming) output - * @param parameterSchema JSON Schema defining the parameters - */ - constructor( - override readonly name: string, - override readonly displayName: string, - override readonly description: string, - override readonly kind: Kind, - override readonly parameterSchema: unknown, - override readonly isOutputMarkdown: boolean = true, - override readonly canUpdateOutput: boolean = false, - ) { - super( - name, - displayName, - description, - kind, - parameterSchema, - isOutputMarkdown, - canUpdateOutput, - ); - } - - build(params: TParams): ToolInvocation { - const validationError = this.validateToolParams(params); - if (validationError) { - throw new Error(validationError); - } - return new LegacyToolInvocation(this, params); - } - - /** - * Validates the parameters for the tool - * This is a placeholder implementation and should be overridden - * Should be called from both `shouldConfirmExecute` and `execute` - * `shouldConfirmExecute` should return false immediately if invalid - * @param params Parameters to validate - * @returns An error message string if invalid, null otherwise - */ - // eslint-disable-next-line @typescript-eslint/no-unused-vars - override validateToolParams(params: TParams): string | null { - // Implementation would typically use a JSON Schema validator - // This is a placeholder that should be implemented by derived classes - return null; - } - - /** - * Gets a pre-execution description of the tool operation - * Default implementation that should be overridden by derived classes - * @param params Parameters for the tool execution - * @returns A markdown string describing what the tool will do - */ - getDescription(params: TParams): string { - return JSON.stringify(params); - } - - /** - * Determines if the tool should prompt for confirmation before execution - * @param params Parameters for the tool execution - * @returns Whether or not execute should be confirmed by the user. - */ - shouldConfirmExecute( - // eslint-disable-next-line @typescript-eslint/no-unused-vars - params: TParams, - // eslint-disable-next-line @typescript-eslint/no-unused-vars - abortSignal: AbortSignal, - ): Promise { - return Promise.resolve(false); - } - - /** - * Determines what file system paths the tool will affect - * @param params Parameters for the tool execution - * @returns A list of such paths - */ - toolLocations( - // eslint-disable-next-line @typescript-eslint/no-unused-vars - params: TParams, - ): ToolLocation[] { - return []; - } - - /** - * Abstract method to execute the tool with the given parameters - * Must be implemented by derived classes - * @param params Parameters for the tool execution - * @param signal AbortSignal for tool cancellation - * @returns Result of the tool execution - */ - abstract execute( - params: TParams, - signal: AbortSignal, - updateOutput?: (output: string) => void, - terminalColumns?: number, - terminalRows?: number, - ): Promise; -} - export interface ToolResult { /** * A short, one-line summary of the tool's action and result.