diff --git a/packages/cli/src/nonInteractiveCli.ts b/packages/cli/src/nonInteractiveCli.ts index 36337c8f..d986c1eb 100644 --- a/packages/cli/src/nonInteractiveCli.ts +++ b/packages/cli/src/nonInteractiveCli.ts @@ -13,7 +13,7 @@ import { GeminiEventType, parseAndFormatApiError, } from '@google/gemini-cli-core'; -import { Content, Part, FunctionCall } from '@google/genai'; +import { Content, Part } from '@google/genai'; import { ConsolePatcher } from './ui/utils/ConsolePatcher.js'; import { handleAtCommand } from './ui/hooks/atCommandProcessor.js'; @@ -74,7 +74,7 @@ export async function runNonInteractive( ); return; } - const functionCalls: FunctionCall[] = []; + const toolCallRequests: ToolCallRequestInfo[] = []; const responseStream = geminiClient.sendMessageStream( currentMessages[0]?.parts || [], @@ -91,29 +91,13 @@ export async function runNonInteractive( if (event.type === GeminiEventType.Content) { process.stdout.write(event.value); } else if (event.type === GeminiEventType.ToolCallRequest) { - const toolCallRequest = event.value; - const fc: FunctionCall = { - name: toolCallRequest.name, - args: toolCallRequest.args, - id: toolCallRequest.callId, - }; - functionCalls.push(fc); + toolCallRequests.push(event.value); } } - if (functionCalls.length > 0) { + if (toolCallRequests.length > 0) { const toolResponseParts: Part[] = []; - - for (const fc of functionCalls) { - const callId = fc.id ?? `${fc.name}-${Date.now()}`; - const requestInfo: ToolCallRequestInfo = { - callId, - name: fc.name as string, - args: (fc.args ?? {}) as Record, - isClientInitiated: false, - prompt_id, - }; - + for (const requestInfo of toolCallRequests) { const toolResponse = await executeToolCall( config, requestInfo, @@ -122,7 +106,7 @@ export async function runNonInteractive( if (toolResponse.error) { console.error( - `Error executing tool ${fc.name}: ${toolResponse.resultDisplay || toolResponse.error.message}`, + `Error executing tool ${requestInfo.name}: ${toolResponse.resultDisplay || toolResponse.error.message}`, ); } diff --git a/packages/cli/src/ui/hooks/useReactToolScheduler.ts b/packages/cli/src/ui/hooks/useReactToolScheduler.ts index 93e05387..e4238f99 100644 --- a/packages/cli/src/ui/hooks/useReactToolScheduler.ts +++ b/packages/cli/src/ui/hooks/useReactToolScheduler.ts @@ -134,7 +134,6 @@ export function useReactToolScheduler( const scheduler = useMemo( () => new CoreToolScheduler({ - toolRegistry: config.getToolRegistry(), outputUpdateHandler, onAllToolCallsComplete: allToolCallsCompleteHandler, onToolCallsUpdate: toolCallsUpdateHandler, diff --git a/packages/core/src/core/coreToolScheduler.test.ts b/packages/core/src/core/coreToolScheduler.test.ts index 1c400d52..291c1862 100644 --- a/packages/core/src/core/coreToolScheduler.test.ts +++ b/packages/core/src/core/coreToolScheduler.test.ts @@ -129,11 +129,11 @@ describe('CoreToolScheduler', () => { model: 'test-model', authType: 'oauth-personal', }), + getToolRegistry: () => mockToolRegistry, } as unknown as Config; const scheduler = new CoreToolScheduler({ config: mockConfig, - toolRegistry: mockToolRegistry, onAllToolCallsComplete, onToolCallsUpdate, getPreferredEditor: () => 'vscode', @@ -189,11 +189,11 @@ describe('CoreToolScheduler with payload', () => { model: 'test-model', authType: 'oauth-personal', }), + getToolRegistry: () => mockToolRegistry, } as unknown as Config; const scheduler = new CoreToolScheduler({ config: mockConfig, - toolRegistry: mockToolRegistry, onAllToolCallsComplete, onToolCallsUpdate, getPreferredEditor: () => 'vscode', @@ -462,15 +462,14 @@ class MockEditTool extends BaseDeclarativeTool< describe('CoreToolScheduler edit cancellation', () => { it('should preserve diff when an edit is cancelled', async () => { const mockEditTool = new MockEditTool(); - const declarativeTool = mockEditTool; const mockToolRegistry = { - getTool: () => declarativeTool, + getTool: () => mockEditTool, getFunctionDeclarations: () => [], tools: new Map(), discovery: {}, registerTool: () => {}, - getToolByName: () => declarativeTool, - getToolByDisplayName: () => declarativeTool, + getToolByName: () => mockEditTool, + getToolByDisplayName: () => mockEditTool, getTools: () => [], discoverTools: async () => {}, getAllTools: () => [], @@ -489,11 +488,11 @@ describe('CoreToolScheduler edit cancellation', () => { model: 'test-model', authType: 'oauth-personal', }), + getToolRegistry: () => mockToolRegistry, } as unknown as Config; const scheduler = new CoreToolScheduler({ config: mockConfig, - toolRegistry: mockToolRegistry, onAllToolCallsComplete, onToolCallsUpdate, getPreferredEditor: () => 'vscode', @@ -581,11 +580,11 @@ describe('CoreToolScheduler YOLO mode', () => { model: 'test-model', authType: 'oauth-personal', }), + getToolRegistry: () => mockToolRegistry, } as unknown as Config; const scheduler = new CoreToolScheduler({ config: mockConfig, - toolRegistry: mockToolRegistry, onAllToolCallsComplete, onToolCallsUpdate, getPreferredEditor: () => 'vscode', @@ -670,11 +669,11 @@ describe('CoreToolScheduler request queueing', () => { model: 'test-model', authType: 'oauth-personal', }), + getToolRegistry: () => mockToolRegistry, } as unknown as Config; const scheduler = new CoreToolScheduler({ config: mockConfig, - toolRegistry: mockToolRegistry, onAllToolCallsComplete, onToolCallsUpdate, getPreferredEditor: () => 'vscode', @@ -783,11 +782,11 @@ describe('CoreToolScheduler request queueing', () => { model: 'test-model', authType: 'oauth-personal', }), + getToolRegistry: () => mockToolRegistry, } as unknown as Config; const scheduler = new CoreToolScheduler({ config: mockConfig, - toolRegistry: mockToolRegistry, onAllToolCallsComplete, onToolCallsUpdate, getPreferredEditor: () => 'vscode', @@ -864,7 +863,9 @@ describe('CoreToolScheduler request queueing', () => { getTools: () => [], discoverTools: async () => {}, discovery: {}, - }; + } as unknown as ToolRegistry; + + mockConfig.getToolRegistry = () => toolRegistry; const onAllToolCallsComplete = vi.fn(); const onToolCallsUpdate = vi.fn(); @@ -874,7 +875,6 @@ describe('CoreToolScheduler request queueing', () => { const scheduler = new CoreToolScheduler({ config: mockConfig, - toolRegistry: toolRegistry as unknown as ToolRegistry, onAllToolCallsComplete, onToolCallsUpdate: (toolCalls) => { onToolCallsUpdate(toolCalls); diff --git a/packages/core/src/core/coreToolScheduler.ts b/packages/core/src/core/coreToolScheduler.ts index 5a2bb85d..a7923647 100644 --- a/packages/core/src/core/coreToolScheduler.ts +++ b/packages/core/src/core/coreToolScheduler.ts @@ -226,12 +226,11 @@ const createErrorResponse = ( }); interface CoreToolSchedulerOptions { - toolRegistry: ToolRegistry; + config: Config; outputUpdateHandler?: OutputUpdateHandler; onAllToolCallsComplete?: AllToolCallsCompleteHandler; onToolCallsUpdate?: ToolCallsUpdateHandler; getPreferredEditor: () => EditorType | undefined; - config: Config; onEditorClose: () => void; } @@ -255,7 +254,7 @@ export class CoreToolScheduler { constructor(options: CoreToolSchedulerOptions) { this.config = options.config; - this.toolRegistry = options.toolRegistry; + this.toolRegistry = options.config.getToolRegistry(); this.outputUpdateHandler = options.outputUpdateHandler; this.onAllToolCallsComplete = options.onAllToolCallsComplete; this.onToolCallsUpdate = options.onToolCallsUpdate; diff --git a/packages/core/src/core/nonInteractiveToolExecutor.test.ts b/packages/core/src/core/nonInteractiveToolExecutor.test.ts index 38afa697..8f16aaa7 100644 --- a/packages/core/src/core/nonInteractiveToolExecutor.test.ts +++ b/packages/core/src/core/nonInteractiveToolExecutor.test.ts @@ -12,6 +12,7 @@ import { ToolResult, Config, ToolErrorType, + ApprovalMode, } from '../index.js'; import { Part } from '@google/genai'; import { MockTool } from '../test-utils/tools.js'; @@ -27,10 +28,11 @@ describe('executeToolCall', () => { mockToolRegistry = { getTool: vi.fn(), - // Add other ToolRegistry methods if needed, or use a more complete mock } as unknown as ToolRegistry; mockConfig = { + getToolRegistry: () => mockToolRegistry, + getApprovalMode: () => ApprovalMode.DEFAULT, getSessionId: () => 'test-session-id', getUsageStatisticsEnabled: () => true, getDebugMode: () => false, @@ -38,7 +40,6 @@ describe('executeToolCall', () => { model: 'test-model', authType: 'oauth-personal', }), - getToolRegistry: () => mockToolRegistry, } as unknown as Config; abortController = new AbortController(); @@ -57,7 +58,7 @@ describe('executeToolCall', () => { returnDisplay: 'Success!', }; vi.mocked(mockToolRegistry.getTool).mockReturnValue(mockTool); - vi.spyOn(mockTool, 'validateBuildAndExecute').mockResolvedValue(toolResult); + mockTool.executeFn.mockReturnValue(toolResult); const response = await executeToolCall( mockConfig, @@ -66,18 +67,18 @@ describe('executeToolCall', () => { ); expect(mockToolRegistry.getTool).toHaveBeenCalledWith('testTool'); - expect(mockTool.validateBuildAndExecute).toHaveBeenCalledWith( - request.args, - abortController.signal, - ); - expect(response.callId).toBe('call1'); - expect(response.error).toBeUndefined(); - expect(response.resultDisplay).toBe('Success!'); - expect(response.responseParts).toEqual({ - functionResponse: { - name: 'testTool', - id: 'call1', - response: { output: 'Tool executed successfully' }, + expect(mockTool.executeFn).toHaveBeenCalledWith(request.args); + expect(response).toStrictEqual({ + callId: 'call1', + error: undefined, + errorType: undefined, + resultDisplay: 'Success!', + responseParts: { + functionResponse: { + name: 'testTool', + id: 'call1', + response: { output: 'Tool executed successfully' }, + }, }, }); }); @@ -98,23 +99,19 @@ describe('executeToolCall', () => { abortController.signal, ); - expect(response.callId).toBe('call2'); - expect(response.error).toBeInstanceOf(Error); - expect(response.error?.message).toBe( - 'Tool "nonexistentTool" not found in registry.', - ); - expect(response.resultDisplay).toBe( - 'Tool "nonexistentTool" not found in registry.', - ); - expect(response.responseParts).toEqual([ - { + expect(response).toStrictEqual({ + callId: 'call2', + error: new Error('Tool "nonexistentTool" not found in registry.'), + errorType: ToolErrorType.TOOL_NOT_REGISTERED, + resultDisplay: 'Tool "nonexistentTool" not found in registry.', + responseParts: { functionResponse: { name: 'nonexistentTool', id: 'call2', response: { error: 'Tool "nonexistentTool" not found in registry.' }, }, }, - ]); + }); }); it('should return an error if tool validation fails', async () => { @@ -125,24 +122,17 @@ describe('executeToolCall', () => { isClientInitiated: false, prompt_id: 'prompt-id-3', }; - const validationErrorResult: ToolResult = { - llmContent: 'Error: Invalid parameters', - returnDisplay: 'Invalid parameters', - error: { - message: 'Invalid parameters', - type: ToolErrorType.INVALID_TOOL_PARAMS, - }, - }; vi.mocked(mockToolRegistry.getTool).mockReturnValue(mockTool); - vi.spyOn(mockTool, 'validateBuildAndExecute').mockResolvedValue( - validationErrorResult, - ); + vi.spyOn(mockTool, 'build').mockImplementation(() => { + throw new Error('Invalid parameters'); + }); const response = await executeToolCall( mockConfig, request, abortController.signal, ); + expect(response).toStrictEqual({ callId: 'call3', error: new Error('Invalid parameters'), @@ -152,7 +142,7 @@ describe('executeToolCall', () => { id: 'call3', name: 'testTool', response: { - output: 'Error: Invalid parameters', + error: 'Invalid parameters', }, }, }, @@ -177,9 +167,7 @@ describe('executeToolCall', () => { }, }; vi.mocked(mockToolRegistry.getTool).mockReturnValue(mockTool); - vi.spyOn(mockTool, 'validateBuildAndExecute').mockResolvedValue( - executionErrorResult, - ); + mockTool.executeFn.mockReturnValue(executionErrorResult); const response = await executeToolCall( mockConfig, @@ -195,7 +183,7 @@ describe('executeToolCall', () => { id: 'call4', name: 'testTool', response: { - output: 'Error: Execution failed', + error: 'Execution failed', }, }, }, @@ -211,11 +199,10 @@ describe('executeToolCall', () => { isClientInitiated: false, prompt_id: 'prompt-id-5', }; - const executionError = new Error('Something went very wrong'); vi.mocked(mockToolRegistry.getTool).mockReturnValue(mockTool); - vi.spyOn(mockTool, 'validateBuildAndExecute').mockRejectedValue( - executionError, - ); + mockTool.executeFn.mockImplementation(() => { + throw new Error('Something went very wrong'); + }); const response = await executeToolCall( mockConfig, @@ -223,19 +210,19 @@ describe('executeToolCall', () => { abortController.signal, ); - expect(response.callId).toBe('call5'); - expect(response.error).toBe(executionError); - expect(response.errorType).toBe(ToolErrorType.UNHANDLED_EXCEPTION); - expect(response.resultDisplay).toBe('Something went very wrong'); - expect(response.responseParts).toEqual([ - { + expect(response).toStrictEqual({ + callId: 'call5', + error: new Error('Something went very wrong'), + errorType: ToolErrorType.UNHANDLED_EXCEPTION, + resultDisplay: 'Something went very wrong', + responseParts: { functionResponse: { name: 'testTool', id: 'call5', response: { error: 'Something went very wrong' }, }, }, - ]); + }); }); it('should correctly format llmContent with inlineData', async () => { @@ -254,7 +241,7 @@ describe('executeToolCall', () => { returnDisplay: 'Image processed', }; vi.mocked(mockToolRegistry.getTool).mockReturnValue(mockTool); - vi.spyOn(mockTool, 'validateBuildAndExecute').mockResolvedValue(toolResult); + mockTool.executeFn.mockReturnValue(toolResult); const response = await executeToolCall( mockConfig, @@ -262,18 +249,23 @@ describe('executeToolCall', () => { abortController.signal, ); - expect(response.resultDisplay).toBe('Image processed'); - expect(response.responseParts).toEqual([ - { - functionResponse: { - name: 'testTool', - id: 'call6', - response: { - output: 'Binary content of type image/png was processed.', + expect(response).toStrictEqual({ + callId: 'call6', + error: undefined, + errorType: undefined, + resultDisplay: 'Image processed', + responseParts: [ + { + functionResponse: { + name: 'testTool', + id: 'call6', + response: { + output: 'Binary content of type image/png was processed.', + }, }, }, - }, - imageDataPart, - ]); + imageDataPart, + ], + }); }); }); diff --git a/packages/core/src/core/nonInteractiveToolExecutor.ts b/packages/core/src/core/nonInteractiveToolExecutor.ts index c116ca33..46ca71d2 100644 --- a/packages/core/src/core/nonInteractiveToolExecutor.ts +++ b/packages/core/src/core/nonInteractiveToolExecutor.ts @@ -4,166 +4,27 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { - FileDiff, - logToolCall, - ToolCallRequestInfo, - ToolCallResponseInfo, - ToolErrorType, - ToolResult, -} from '../index.js'; -import { DiscoveredMCPTool } from '../tools/mcp-tool.js'; -import { Config } from '../config/config.js'; -import { convertToFunctionResponse } from './coreToolScheduler.js'; -import { ToolCallDecision } from '../telemetry/tool-call-decision.js'; +import { ToolCallRequestInfo, ToolCallResponseInfo, Config } from '../index.js'; +import { CoreToolScheduler } from './coreToolScheduler.js'; /** - * Executes a single tool call non-interactively. - * It does not handle confirmations, multiple calls, or live updates. + * Executes a single tool call non-interactively by leveraging the CoreToolScheduler. */ export async function executeToolCall( config: Config, toolCallRequest: ToolCallRequestInfo, - abortSignal?: AbortSignal, + abortSignal: AbortSignal, ): Promise { - const tool = config.getToolRegistry().getTool(toolCallRequest.name); - - const startTime = Date.now(); - if (!tool) { - const error = new Error( - `Tool "${toolCallRequest.name}" not found in registry.`, - ); - const durationMs = Date.now() - startTime; - logToolCall(config, { - 'event.name': 'tool_call', - 'event.timestamp': new Date().toISOString(), - function_name: toolCallRequest.name, - function_args: toolCallRequest.args, - duration_ms: durationMs, - success: false, - error: error.message, - prompt_id: toolCallRequest.prompt_id, - tool_type: 'native', - }); - // Ensure the response structure matches what the API expects for an error - return { - callId: toolCallRequest.callId, - responseParts: [ - { - functionResponse: { - id: toolCallRequest.callId, - name: toolCallRequest.name, - response: { error: error.message }, - }, - }, - ], - resultDisplay: error.message, - error, - errorType: ToolErrorType.TOOL_NOT_REGISTERED, - }; - } - - try { - // Directly execute without confirmation or live output handling - const effectiveAbortSignal = abortSignal ?? new AbortController().signal; - const toolResult: ToolResult = await tool.validateBuildAndExecute( - toolCallRequest.args, - effectiveAbortSignal, - // No live output callback for non-interactive mode - ); - - const tool_output = toolResult.llmContent; - - const tool_display = toolResult.returnDisplay; - - // eslint-disable-next-line @typescript-eslint/no-explicit-any - let metadata: { [key: string]: any } = {}; - if ( - toolResult.error === undefined && - typeof tool_display === 'object' && - tool_display !== null && - 'diffStat' in tool_display - ) { - const diffStat = (tool_display as FileDiff).diffStat; - if (diffStat) { - metadata = { - ai_added_lines: diffStat.ai_added_lines, - ai_removed_lines: diffStat.ai_removed_lines, - user_added_lines: diffStat.user_added_lines, - user_removed_lines: diffStat.user_removed_lines, - }; - } - } - const durationMs = Date.now() - startTime; - logToolCall(config, { - 'event.name': 'tool_call', - 'event.timestamp': new Date().toISOString(), - function_name: toolCallRequest.name, - function_args: toolCallRequest.args, - duration_ms: durationMs, - success: toolResult.error === undefined, - error: - toolResult.error === undefined ? undefined : toolResult.error.message, - error_type: - toolResult.error === undefined ? undefined : toolResult.error.type, - prompt_id: toolCallRequest.prompt_id, - metadata, - decision: ToolCallDecision.AUTO_ACCEPT, - tool_type: - typeof tool !== 'undefined' && tool instanceof DiscoveredMCPTool - ? 'mcp' - : 'native', - }); - - const response = convertToFunctionResponse( - toolCallRequest.name, - toolCallRequest.callId, - tool_output, - ); - - return { - callId: toolCallRequest.callId, - responseParts: response, - resultDisplay: tool_display, - error: - toolResult.error === undefined - ? undefined - : new Error(toolResult.error.message), - errorType: - toolResult.error === undefined ? undefined : toolResult.error.type, - }; - } catch (e) { - const error = e instanceof Error ? e : new Error(String(e)); - const durationMs = Date.now() - startTime; - logToolCall(config, { - 'event.name': 'tool_call', - 'event.timestamp': new Date().toISOString(), - function_name: toolCallRequest.name, - function_args: toolCallRequest.args, - duration_ms: durationMs, - success: false, - error: error.message, - error_type: ToolErrorType.UNHANDLED_EXCEPTION, - prompt_id: toolCallRequest.prompt_id, - tool_type: - typeof tool !== 'undefined' && tool instanceof DiscoveredMCPTool - ? 'mcp' - : 'native', - }); - return { - callId: toolCallRequest.callId, - responseParts: [ - { - functionResponse: { - id: toolCallRequest.callId, - name: toolCallRequest.name, - response: { error: error.message }, - }, - }, - ], - resultDisplay: error.message, - error, - errorType: ToolErrorType.UNHANDLED_EXCEPTION, - }; - } + return new Promise((resolve, reject) => { + new CoreToolScheduler({ + config, + getPreferredEditor: () => undefined, + onEditorClose: () => {}, + onAllToolCallsComplete: async (completedToolCalls) => { + resolve(completedToolCalls[0].response); + }, + }) + .schedule(toolCallRequest, abortSignal) + .catch(reject); + }); }