Reuse CoreToolScheduler for nonInteractiveToolExecutor (#6714)

This commit is contained in:
Tommaso Sciortino 2025-08-21 16:49:12 -07:00 committed by GitHub
parent 29699274bb
commit 15c62bade3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 93 additions and 258 deletions

View File

@ -13,7 +13,7 @@ import {
GeminiEventType,
parseAndFormatApiError,
} from '@google/gemini-cli-core';
import { Content, Part, FunctionCall } from '@google/genai';
import { Content, Part } from '@google/genai';
import { ConsolePatcher } from './ui/utils/ConsolePatcher.js';
import { handleAtCommand } from './ui/hooks/atCommandProcessor.js';
@ -74,7 +74,7 @@ export async function runNonInteractive(
);
return;
}
const functionCalls: FunctionCall[] = [];
const toolCallRequests: ToolCallRequestInfo[] = [];
const responseStream = geminiClient.sendMessageStream(
currentMessages[0]?.parts || [],
@ -91,29 +91,13 @@ export async function runNonInteractive(
if (event.type === GeminiEventType.Content) {
process.stdout.write(event.value);
} else if (event.type === GeminiEventType.ToolCallRequest) {
const toolCallRequest = event.value;
const fc: FunctionCall = {
name: toolCallRequest.name,
args: toolCallRequest.args,
id: toolCallRequest.callId,
};
functionCalls.push(fc);
toolCallRequests.push(event.value);
}
}
if (functionCalls.length > 0) {
if (toolCallRequests.length > 0) {
const toolResponseParts: Part[] = [];
for (const fc of functionCalls) {
const callId = fc.id ?? `${fc.name}-${Date.now()}`;
const requestInfo: ToolCallRequestInfo = {
callId,
name: fc.name as string,
args: (fc.args ?? {}) as Record<string, unknown>,
isClientInitiated: false,
prompt_id,
};
for (const requestInfo of toolCallRequests) {
const toolResponse = await executeToolCall(
config,
requestInfo,
@ -122,7 +106,7 @@ export async function runNonInteractive(
if (toolResponse.error) {
console.error(
`Error executing tool ${fc.name}: ${toolResponse.resultDisplay || toolResponse.error.message}`,
`Error executing tool ${requestInfo.name}: ${toolResponse.resultDisplay || toolResponse.error.message}`,
);
}

View File

@ -134,7 +134,6 @@ export function useReactToolScheduler(
const scheduler = useMemo(
() =>
new CoreToolScheduler({
toolRegistry: config.getToolRegistry(),
outputUpdateHandler,
onAllToolCallsComplete: allToolCallsCompleteHandler,
onToolCallsUpdate: toolCallsUpdateHandler,

View File

@ -129,11 +129,11 @@ describe('CoreToolScheduler', () => {
model: 'test-model',
authType: 'oauth-personal',
}),
getToolRegistry: () => mockToolRegistry,
} as unknown as Config;
const scheduler = new CoreToolScheduler({
config: mockConfig,
toolRegistry: mockToolRegistry,
onAllToolCallsComplete,
onToolCallsUpdate,
getPreferredEditor: () => 'vscode',
@ -189,11 +189,11 @@ describe('CoreToolScheduler with payload', () => {
model: 'test-model',
authType: 'oauth-personal',
}),
getToolRegistry: () => mockToolRegistry,
} as unknown as Config;
const scheduler = new CoreToolScheduler({
config: mockConfig,
toolRegistry: mockToolRegistry,
onAllToolCallsComplete,
onToolCallsUpdate,
getPreferredEditor: () => 'vscode',
@ -462,15 +462,14 @@ class MockEditTool extends BaseDeclarativeTool<
describe('CoreToolScheduler edit cancellation', () => {
it('should preserve diff when an edit is cancelled', async () => {
const mockEditTool = new MockEditTool();
const declarativeTool = mockEditTool;
const mockToolRegistry = {
getTool: () => declarativeTool,
getTool: () => mockEditTool,
getFunctionDeclarations: () => [],
tools: new Map(),
discovery: {},
registerTool: () => {},
getToolByName: () => declarativeTool,
getToolByDisplayName: () => declarativeTool,
getToolByName: () => mockEditTool,
getToolByDisplayName: () => mockEditTool,
getTools: () => [],
discoverTools: async () => {},
getAllTools: () => [],
@ -489,11 +488,11 @@ describe('CoreToolScheduler edit cancellation', () => {
model: 'test-model',
authType: 'oauth-personal',
}),
getToolRegistry: () => mockToolRegistry,
} as unknown as Config;
const scheduler = new CoreToolScheduler({
config: mockConfig,
toolRegistry: mockToolRegistry,
onAllToolCallsComplete,
onToolCallsUpdate,
getPreferredEditor: () => 'vscode',
@ -581,11 +580,11 @@ describe('CoreToolScheduler YOLO mode', () => {
model: 'test-model',
authType: 'oauth-personal',
}),
getToolRegistry: () => mockToolRegistry,
} as unknown as Config;
const scheduler = new CoreToolScheduler({
config: mockConfig,
toolRegistry: mockToolRegistry,
onAllToolCallsComplete,
onToolCallsUpdate,
getPreferredEditor: () => 'vscode',
@ -670,11 +669,11 @@ describe('CoreToolScheduler request queueing', () => {
model: 'test-model',
authType: 'oauth-personal',
}),
getToolRegistry: () => mockToolRegistry,
} as unknown as Config;
const scheduler = new CoreToolScheduler({
config: mockConfig,
toolRegistry: mockToolRegistry,
onAllToolCallsComplete,
onToolCallsUpdate,
getPreferredEditor: () => 'vscode',
@ -783,11 +782,11 @@ describe('CoreToolScheduler request queueing', () => {
model: 'test-model',
authType: 'oauth-personal',
}),
getToolRegistry: () => mockToolRegistry,
} as unknown as Config;
const scheduler = new CoreToolScheduler({
config: mockConfig,
toolRegistry: mockToolRegistry,
onAllToolCallsComplete,
onToolCallsUpdate,
getPreferredEditor: () => 'vscode',
@ -864,7 +863,9 @@ describe('CoreToolScheduler request queueing', () => {
getTools: () => [],
discoverTools: async () => {},
discovery: {},
};
} as unknown as ToolRegistry;
mockConfig.getToolRegistry = () => toolRegistry;
const onAllToolCallsComplete = vi.fn();
const onToolCallsUpdate = vi.fn();
@ -874,7 +875,6 @@ describe('CoreToolScheduler request queueing', () => {
const scheduler = new CoreToolScheduler({
config: mockConfig,
toolRegistry: toolRegistry as unknown as ToolRegistry,
onAllToolCallsComplete,
onToolCallsUpdate: (toolCalls) => {
onToolCallsUpdate(toolCalls);

View File

@ -226,12 +226,11 @@ const createErrorResponse = (
});
interface CoreToolSchedulerOptions {
toolRegistry: ToolRegistry;
config: Config;
outputUpdateHandler?: OutputUpdateHandler;
onAllToolCallsComplete?: AllToolCallsCompleteHandler;
onToolCallsUpdate?: ToolCallsUpdateHandler;
getPreferredEditor: () => EditorType | undefined;
config: Config;
onEditorClose: () => void;
}
@ -255,7 +254,7 @@ export class CoreToolScheduler {
constructor(options: CoreToolSchedulerOptions) {
this.config = options.config;
this.toolRegistry = options.toolRegistry;
this.toolRegistry = options.config.getToolRegistry();
this.outputUpdateHandler = options.outputUpdateHandler;
this.onAllToolCallsComplete = options.onAllToolCallsComplete;
this.onToolCallsUpdate = options.onToolCallsUpdate;

View File

@ -12,6 +12,7 @@ import {
ToolResult,
Config,
ToolErrorType,
ApprovalMode,
} from '../index.js';
import { Part } from '@google/genai';
import { MockTool } from '../test-utils/tools.js';
@ -27,10 +28,11 @@ describe('executeToolCall', () => {
mockToolRegistry = {
getTool: vi.fn(),
// Add other ToolRegistry methods if needed, or use a more complete mock
} as unknown as ToolRegistry;
mockConfig = {
getToolRegistry: () => mockToolRegistry,
getApprovalMode: () => ApprovalMode.DEFAULT,
getSessionId: () => 'test-session-id',
getUsageStatisticsEnabled: () => true,
getDebugMode: () => false,
@ -38,7 +40,6 @@ describe('executeToolCall', () => {
model: 'test-model',
authType: 'oauth-personal',
}),
getToolRegistry: () => mockToolRegistry,
} as unknown as Config;
abortController = new AbortController();
@ -57,7 +58,7 @@ describe('executeToolCall', () => {
returnDisplay: 'Success!',
};
vi.mocked(mockToolRegistry.getTool).mockReturnValue(mockTool);
vi.spyOn(mockTool, 'validateBuildAndExecute').mockResolvedValue(toolResult);
mockTool.executeFn.mockReturnValue(toolResult);
const response = await executeToolCall(
mockConfig,
@ -66,18 +67,18 @@ describe('executeToolCall', () => {
);
expect(mockToolRegistry.getTool).toHaveBeenCalledWith('testTool');
expect(mockTool.validateBuildAndExecute).toHaveBeenCalledWith(
request.args,
abortController.signal,
);
expect(response.callId).toBe('call1');
expect(response.error).toBeUndefined();
expect(response.resultDisplay).toBe('Success!');
expect(response.responseParts).toEqual({
functionResponse: {
name: 'testTool',
id: 'call1',
response: { output: 'Tool executed successfully' },
expect(mockTool.executeFn).toHaveBeenCalledWith(request.args);
expect(response).toStrictEqual({
callId: 'call1',
error: undefined,
errorType: undefined,
resultDisplay: 'Success!',
responseParts: {
functionResponse: {
name: 'testTool',
id: 'call1',
response: { output: 'Tool executed successfully' },
},
},
});
});
@ -98,23 +99,19 @@ describe('executeToolCall', () => {
abortController.signal,
);
expect(response.callId).toBe('call2');
expect(response.error).toBeInstanceOf(Error);
expect(response.error?.message).toBe(
'Tool "nonexistentTool" not found in registry.',
);
expect(response.resultDisplay).toBe(
'Tool "nonexistentTool" not found in registry.',
);
expect(response.responseParts).toEqual([
{
expect(response).toStrictEqual({
callId: 'call2',
error: new Error('Tool "nonexistentTool" not found in registry.'),
errorType: ToolErrorType.TOOL_NOT_REGISTERED,
resultDisplay: 'Tool "nonexistentTool" not found in registry.',
responseParts: {
functionResponse: {
name: 'nonexistentTool',
id: 'call2',
response: { error: 'Tool "nonexistentTool" not found in registry.' },
},
},
]);
});
});
it('should return an error if tool validation fails', async () => {
@ -125,24 +122,17 @@ describe('executeToolCall', () => {
isClientInitiated: false,
prompt_id: 'prompt-id-3',
};
const validationErrorResult: ToolResult = {
llmContent: 'Error: Invalid parameters',
returnDisplay: 'Invalid parameters',
error: {
message: 'Invalid parameters',
type: ToolErrorType.INVALID_TOOL_PARAMS,
},
};
vi.mocked(mockToolRegistry.getTool).mockReturnValue(mockTool);
vi.spyOn(mockTool, 'validateBuildAndExecute').mockResolvedValue(
validationErrorResult,
);
vi.spyOn(mockTool, 'build').mockImplementation(() => {
throw new Error('Invalid parameters');
});
const response = await executeToolCall(
mockConfig,
request,
abortController.signal,
);
expect(response).toStrictEqual({
callId: 'call3',
error: new Error('Invalid parameters'),
@ -152,7 +142,7 @@ describe('executeToolCall', () => {
id: 'call3',
name: 'testTool',
response: {
output: 'Error: Invalid parameters',
error: 'Invalid parameters',
},
},
},
@ -177,9 +167,7 @@ describe('executeToolCall', () => {
},
};
vi.mocked(mockToolRegistry.getTool).mockReturnValue(mockTool);
vi.spyOn(mockTool, 'validateBuildAndExecute').mockResolvedValue(
executionErrorResult,
);
mockTool.executeFn.mockReturnValue(executionErrorResult);
const response = await executeToolCall(
mockConfig,
@ -195,7 +183,7 @@ describe('executeToolCall', () => {
id: 'call4',
name: 'testTool',
response: {
output: 'Error: Execution failed',
error: 'Execution failed',
},
},
},
@ -211,11 +199,10 @@ describe('executeToolCall', () => {
isClientInitiated: false,
prompt_id: 'prompt-id-5',
};
const executionError = new Error('Something went very wrong');
vi.mocked(mockToolRegistry.getTool).mockReturnValue(mockTool);
vi.spyOn(mockTool, 'validateBuildAndExecute').mockRejectedValue(
executionError,
);
mockTool.executeFn.mockImplementation(() => {
throw new Error('Something went very wrong');
});
const response = await executeToolCall(
mockConfig,
@ -223,19 +210,19 @@ describe('executeToolCall', () => {
abortController.signal,
);
expect(response.callId).toBe('call5');
expect(response.error).toBe(executionError);
expect(response.errorType).toBe(ToolErrorType.UNHANDLED_EXCEPTION);
expect(response.resultDisplay).toBe('Something went very wrong');
expect(response.responseParts).toEqual([
{
expect(response).toStrictEqual({
callId: 'call5',
error: new Error('Something went very wrong'),
errorType: ToolErrorType.UNHANDLED_EXCEPTION,
resultDisplay: 'Something went very wrong',
responseParts: {
functionResponse: {
name: 'testTool',
id: 'call5',
response: { error: 'Something went very wrong' },
},
},
]);
});
});
it('should correctly format llmContent with inlineData', async () => {
@ -254,7 +241,7 @@ describe('executeToolCall', () => {
returnDisplay: 'Image processed',
};
vi.mocked(mockToolRegistry.getTool).mockReturnValue(mockTool);
vi.spyOn(mockTool, 'validateBuildAndExecute').mockResolvedValue(toolResult);
mockTool.executeFn.mockReturnValue(toolResult);
const response = await executeToolCall(
mockConfig,
@ -262,18 +249,23 @@ describe('executeToolCall', () => {
abortController.signal,
);
expect(response.resultDisplay).toBe('Image processed');
expect(response.responseParts).toEqual([
{
functionResponse: {
name: 'testTool',
id: 'call6',
response: {
output: 'Binary content of type image/png was processed.',
expect(response).toStrictEqual({
callId: 'call6',
error: undefined,
errorType: undefined,
resultDisplay: 'Image processed',
responseParts: [
{
functionResponse: {
name: 'testTool',
id: 'call6',
response: {
output: 'Binary content of type image/png was processed.',
},
},
},
},
imageDataPart,
]);
imageDataPart,
],
});
});
});

View File

@ -4,166 +4,27 @@
* SPDX-License-Identifier: Apache-2.0
*/
import {
FileDiff,
logToolCall,
ToolCallRequestInfo,
ToolCallResponseInfo,
ToolErrorType,
ToolResult,
} from '../index.js';
import { DiscoveredMCPTool } from '../tools/mcp-tool.js';
import { Config } from '../config/config.js';
import { convertToFunctionResponse } from './coreToolScheduler.js';
import { ToolCallDecision } from '../telemetry/tool-call-decision.js';
import { ToolCallRequestInfo, ToolCallResponseInfo, Config } from '../index.js';
import { CoreToolScheduler } from './coreToolScheduler.js';
/**
* Executes a single tool call non-interactively.
* It does not handle confirmations, multiple calls, or live updates.
* Executes a single tool call non-interactively by leveraging the CoreToolScheduler.
*/
export async function executeToolCall(
config: Config,
toolCallRequest: ToolCallRequestInfo,
abortSignal?: AbortSignal,
abortSignal: AbortSignal,
): Promise<ToolCallResponseInfo> {
const tool = config.getToolRegistry().getTool(toolCallRequest.name);
const startTime = Date.now();
if (!tool) {
const error = new Error(
`Tool "${toolCallRequest.name}" not found in registry.`,
);
const durationMs = Date.now() - startTime;
logToolCall(config, {
'event.name': 'tool_call',
'event.timestamp': new Date().toISOString(),
function_name: toolCallRequest.name,
function_args: toolCallRequest.args,
duration_ms: durationMs,
success: false,
error: error.message,
prompt_id: toolCallRequest.prompt_id,
tool_type: 'native',
});
// Ensure the response structure matches what the API expects for an error
return {
callId: toolCallRequest.callId,
responseParts: [
{
functionResponse: {
id: toolCallRequest.callId,
name: toolCallRequest.name,
response: { error: error.message },
},
},
],
resultDisplay: error.message,
error,
errorType: ToolErrorType.TOOL_NOT_REGISTERED,
};
}
try {
// Directly execute without confirmation or live output handling
const effectiveAbortSignal = abortSignal ?? new AbortController().signal;
const toolResult: ToolResult = await tool.validateBuildAndExecute(
toolCallRequest.args,
effectiveAbortSignal,
// No live output callback for non-interactive mode
);
const tool_output = toolResult.llmContent;
const tool_display = toolResult.returnDisplay;
// eslint-disable-next-line @typescript-eslint/no-explicit-any
let metadata: { [key: string]: any } = {};
if (
toolResult.error === undefined &&
typeof tool_display === 'object' &&
tool_display !== null &&
'diffStat' in tool_display
) {
const diffStat = (tool_display as FileDiff).diffStat;
if (diffStat) {
metadata = {
ai_added_lines: diffStat.ai_added_lines,
ai_removed_lines: diffStat.ai_removed_lines,
user_added_lines: diffStat.user_added_lines,
user_removed_lines: diffStat.user_removed_lines,
};
}
}
const durationMs = Date.now() - startTime;
logToolCall(config, {
'event.name': 'tool_call',
'event.timestamp': new Date().toISOString(),
function_name: toolCallRequest.name,
function_args: toolCallRequest.args,
duration_ms: durationMs,
success: toolResult.error === undefined,
error:
toolResult.error === undefined ? undefined : toolResult.error.message,
error_type:
toolResult.error === undefined ? undefined : toolResult.error.type,
prompt_id: toolCallRequest.prompt_id,
metadata,
decision: ToolCallDecision.AUTO_ACCEPT,
tool_type:
typeof tool !== 'undefined' && tool instanceof DiscoveredMCPTool
? 'mcp'
: 'native',
});
const response = convertToFunctionResponse(
toolCallRequest.name,
toolCallRequest.callId,
tool_output,
);
return {
callId: toolCallRequest.callId,
responseParts: response,
resultDisplay: tool_display,
error:
toolResult.error === undefined
? undefined
: new Error(toolResult.error.message),
errorType:
toolResult.error === undefined ? undefined : toolResult.error.type,
};
} catch (e) {
const error = e instanceof Error ? e : new Error(String(e));
const durationMs = Date.now() - startTime;
logToolCall(config, {
'event.name': 'tool_call',
'event.timestamp': new Date().toISOString(),
function_name: toolCallRequest.name,
function_args: toolCallRequest.args,
duration_ms: durationMs,
success: false,
error: error.message,
error_type: ToolErrorType.UNHANDLED_EXCEPTION,
prompt_id: toolCallRequest.prompt_id,
tool_type:
typeof tool !== 'undefined' && tool instanceof DiscoveredMCPTool
? 'mcp'
: 'native',
});
return {
callId: toolCallRequest.callId,
responseParts: [
{
functionResponse: {
id: toolCallRequest.callId,
name: toolCallRequest.name,
response: { error: error.message },
},
},
],
resultDisplay: error.message,
error,
errorType: ToolErrorType.UNHANDLED_EXCEPTION,
};
}
return new Promise<ToolCallResponseInfo>((resolve, reject) => {
new CoreToolScheduler({
config,
getPreferredEditor: () => undefined,
onEditorClose: () => {},
onAllToolCallsComplete: async (completedToolCalls) => {
resolve(completedToolCalls[0].response);
},
})
.schedule(toolCallRequest, abortSignal)
.catch(reject);
});
}