feat: Implement non-interactive mode for CLI (#675)
This commit is contained in:
parent
c51d6cc9d3
commit
2828fc6d66
|
@ -9,7 +9,6 @@ import { render } from 'ink';
|
||||||
import { App } from './ui/App.js';
|
import { App } from './ui/App.js';
|
||||||
import { loadCliConfig } from './config/config.js';
|
import { loadCliConfig } from './config/config.js';
|
||||||
import { readStdin } from './utils/readStdin.js';
|
import { readStdin } from './utils/readStdin.js';
|
||||||
import { GeminiClient } from '@gemini-code/core';
|
|
||||||
import { readPackageUp } from 'read-package-up';
|
import { readPackageUp } from 'read-package-up';
|
||||||
import { fileURLToPath } from 'node:url';
|
import { fileURLToPath } from 'node:url';
|
||||||
import { dirname } from 'node:path';
|
import { dirname } from 'node:path';
|
||||||
|
@ -17,14 +16,25 @@ import { sandbox_command, start_sandbox } from './utils/sandbox.js';
|
||||||
import { loadSettings } from './config/settings.js';
|
import { loadSettings } from './config/settings.js';
|
||||||
import { themeManager } from './ui/themes/theme-manager.js';
|
import { themeManager } from './ui/themes/theme-manager.js';
|
||||||
import { getStartupWarnings } from './utils/startupWarnings.js';
|
import { getStartupWarnings } from './utils/startupWarnings.js';
|
||||||
|
import { runNonInteractive } from './nonInteractiveCli.js';
|
||||||
|
import {
|
||||||
|
EditTool,
|
||||||
|
GlobTool,
|
||||||
|
GrepTool,
|
||||||
|
LSTool,
|
||||||
|
MemoryTool,
|
||||||
|
ReadFileTool,
|
||||||
|
ReadManyFilesTool,
|
||||||
|
ShellTool,
|
||||||
|
WebFetchTool,
|
||||||
|
WebSearchTool,
|
||||||
|
WriteFileTool,
|
||||||
|
} from '@gemini-code/core';
|
||||||
|
|
||||||
const __filename = fileURLToPath(import.meta.url);
|
const __filename = fileURLToPath(import.meta.url);
|
||||||
const __dirname = dirname(__filename);
|
const __dirname = dirname(__filename);
|
||||||
|
|
||||||
async function main() {
|
async function main() {
|
||||||
const settings = loadSettings(process.cwd());
|
|
||||||
const config = await loadCliConfig(settings.merged);
|
|
||||||
|
|
||||||
// warn about deprecated environment variables
|
// warn about deprecated environment variables
|
||||||
if (process.env.GEMINI_CODE_MODEL) {
|
if (process.env.GEMINI_CODE_MODEL) {
|
||||||
console.warn('GEMINI_CODE_MODEL is deprecated. Use GEMINI_MODEL instead.');
|
console.warn('GEMINI_CODE_MODEL is deprecated. Use GEMINI_MODEL instead.');
|
||||||
|
@ -43,6 +53,9 @@ async function main() {
|
||||||
process.env.GEMINI_SANDBOX_IMAGE = process.env.GEMINI_CODE_SANDBOX_IMAGE;
|
process.env.GEMINI_SANDBOX_IMAGE = process.env.GEMINI_CODE_SANDBOX_IMAGE;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const settings = loadSettings(process.cwd());
|
||||||
|
const config = await loadCliConfig(settings.merged);
|
||||||
|
|
||||||
if (settings.merged.theme) {
|
if (settings.merged.theme) {
|
||||||
if (!themeManager.setActiveTheme(settings.merged.theme)) {
|
if (!themeManager.setActiveTheme(settings.merged.theme)) {
|
||||||
// If the theme is not found during initial load, log a warning and continue.
|
// If the theme is not found during initial load, log a warning and continue.
|
||||||
|
@ -92,26 +105,31 @@ async function main() {
|
||||||
process.exit(1);
|
process.exit(1);
|
||||||
}
|
}
|
||||||
|
|
||||||
// If not a TTY and we have initial input, process it directly
|
// Non-interactive mode handled by runNonInteractive
|
||||||
const geminiClient = new GeminiClient(config);
|
let existingCoreTools = config.getCoreTools();
|
||||||
const chat = await geminiClient.startChat();
|
existingCoreTools = existingCoreTools || [
|
||||||
try {
|
ReadFileTool.Name,
|
||||||
for await (const event of geminiClient.sendMessageStream(
|
LSTool.Name,
|
||||||
chat,
|
GrepTool.Name,
|
||||||
[{ text: input }],
|
GlobTool.Name,
|
||||||
new AbortController().signal,
|
EditTool.Name,
|
||||||
)) {
|
WriteFileTool.Name,
|
||||||
if (event.type === 'content') {
|
WebFetchTool.Name,
|
||||||
process.stdout.write(event.value);
|
WebSearchTool.Name,
|
||||||
}
|
ReadManyFilesTool.Name,
|
||||||
// We might need to handle other event types later, but for now, just content.
|
ShellTool.Name,
|
||||||
}
|
MemoryTool.Name,
|
||||||
process.stdout.write('\n'); // Add a newline at the end
|
];
|
||||||
process.exit(0);
|
const interactiveTools = [ShellTool.Name, EditTool.Name, WriteFileTool.Name];
|
||||||
} catch (error) {
|
const nonInteractiveTools = existingCoreTools.filter(
|
||||||
console.error('Error processing piped input:', error);
|
(tool) => !interactiveTools.includes(tool),
|
||||||
process.exit(1);
|
);
|
||||||
}
|
const nonInteractiveSettings = {
|
||||||
|
...settings.merged,
|
||||||
|
coreTools: nonInteractiveTools,
|
||||||
|
};
|
||||||
|
const nonInteractiveConfig = await loadCliConfig(nonInteractiveSettings);
|
||||||
|
await runNonInteractive(nonInteractiveConfig, input);
|
||||||
}
|
}
|
||||||
|
|
||||||
// --- Global Unhandled Rejection Handler ---
|
// --- Global Unhandled Rejection Handler ---
|
||||||
|
|
|
@ -0,0 +1,224 @@
|
||||||
|
/**
|
||||||
|
* @license
|
||||||
|
* Copyright 2025 Google LLC
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
*/
|
||||||
|
|
||||||
|
/* eslint-disable @typescript-eslint/no-explicit-any */
|
||||||
|
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
|
||||||
|
import { runNonInteractive } from './nonInteractiveCli.js';
|
||||||
|
import { Config, GeminiClient, ToolRegistry } from '@gemini-code/core';
|
||||||
|
import { GenerateContentResponse, Part, FunctionCall } from '@google/genai';
|
||||||
|
|
||||||
|
// Mock dependencies
|
||||||
|
vi.mock('@gemini-code/core', async () => {
|
||||||
|
const actualCore =
|
||||||
|
await vi.importActual<typeof import('@gemini-code/core')>(
|
||||||
|
'@gemini-code/core',
|
||||||
|
);
|
||||||
|
return {
|
||||||
|
...actualCore,
|
||||||
|
GeminiClient: vi.fn(),
|
||||||
|
ToolRegistry: vi.fn(),
|
||||||
|
executeToolCall: vi.fn(),
|
||||||
|
};
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('runNonInteractive', () => {
|
||||||
|
let mockConfig: Config;
|
||||||
|
let mockGeminiClient: GeminiClient;
|
||||||
|
let mockToolRegistry: ToolRegistry;
|
||||||
|
let mockChat: {
|
||||||
|
sendMessageStream: ReturnType<typeof vi.fn>;
|
||||||
|
};
|
||||||
|
let mockProcessStdoutWrite: ReturnType<typeof vi.fn>;
|
||||||
|
let mockProcessExit: ReturnType<typeof vi.fn>;
|
||||||
|
|
||||||
|
beforeEach(() => {
|
||||||
|
mockChat = {
|
||||||
|
sendMessageStream: vi.fn(),
|
||||||
|
};
|
||||||
|
mockGeminiClient = {
|
||||||
|
startChat: vi.fn().mockResolvedValue(mockChat),
|
||||||
|
} as unknown as GeminiClient;
|
||||||
|
mockToolRegistry = {
|
||||||
|
discoverTools: vi.fn().mockResolvedValue(undefined),
|
||||||
|
getFunctionDeclarations: vi.fn().mockReturnValue([]),
|
||||||
|
getTool: vi.fn(),
|
||||||
|
} as unknown as ToolRegistry;
|
||||||
|
|
||||||
|
vi.mocked(GeminiClient).mockImplementation(() => mockGeminiClient);
|
||||||
|
vi.mocked(ToolRegistry).mockImplementation(() => mockToolRegistry);
|
||||||
|
|
||||||
|
mockConfig = {
|
||||||
|
getToolRegistry: vi.fn().mockReturnValue(mockToolRegistry),
|
||||||
|
} as unknown as Config;
|
||||||
|
|
||||||
|
mockProcessStdoutWrite = vi.fn().mockImplementation(() => true);
|
||||||
|
process.stdout.write = mockProcessStdoutWrite as any; // Use any to bypass strict signature matching for mock
|
||||||
|
mockProcessExit = vi
|
||||||
|
.fn()
|
||||||
|
.mockImplementation((_code?: number) => undefined as never);
|
||||||
|
process.exit = mockProcessExit as any; // Use any for process.exit mock
|
||||||
|
});
|
||||||
|
|
||||||
|
afterEach(() => {
|
||||||
|
vi.restoreAllMocks();
|
||||||
|
// Restore original process methods if they were globally patched
|
||||||
|
// This might require storing the original methods before patching them in beforeEach
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should process input and write text output', async () => {
|
||||||
|
const inputStream = (async function* () {
|
||||||
|
yield {
|
||||||
|
candidates: [{ content: { parts: [{ text: 'Hello' }] } }],
|
||||||
|
} as GenerateContentResponse;
|
||||||
|
yield {
|
||||||
|
candidates: [{ content: { parts: [{ text: ' World' }] } }],
|
||||||
|
} as GenerateContentResponse;
|
||||||
|
})();
|
||||||
|
mockChat.sendMessageStream.mockResolvedValue(inputStream);
|
||||||
|
|
||||||
|
await runNonInteractive(mockConfig, 'Test input');
|
||||||
|
|
||||||
|
expect(mockGeminiClient.startChat).toHaveBeenCalled();
|
||||||
|
expect(mockToolRegistry.discoverTools).toHaveBeenCalled();
|
||||||
|
expect(mockChat.sendMessageStream).toHaveBeenCalledWith({
|
||||||
|
message: [{ text: 'Test input' }],
|
||||||
|
config: {
|
||||||
|
abortSignal: expect.any(AbortSignal),
|
||||||
|
tools: [{ functionDeclarations: [] }],
|
||||||
|
},
|
||||||
|
});
|
||||||
|
expect(mockProcessStdoutWrite).toHaveBeenCalledWith('Hello');
|
||||||
|
expect(mockProcessStdoutWrite).toHaveBeenCalledWith(' World');
|
||||||
|
expect(mockProcessStdoutWrite).toHaveBeenCalledWith('\n');
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should handle a single tool call and respond', async () => {
|
||||||
|
const functionCall: FunctionCall = {
|
||||||
|
id: 'fc1',
|
||||||
|
name: 'testTool',
|
||||||
|
args: { p: 'v' },
|
||||||
|
};
|
||||||
|
const toolResponsePart: Part = {
|
||||||
|
functionResponse: {
|
||||||
|
name: 'testTool',
|
||||||
|
id: 'fc1',
|
||||||
|
response: { result: 'tool success' },
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
const { executeToolCall: mockCoreExecuteToolCall } = await import(
|
||||||
|
'@gemini-code/core'
|
||||||
|
);
|
||||||
|
vi.mocked(mockCoreExecuteToolCall).mockResolvedValue({
|
||||||
|
callId: 'fc1',
|
||||||
|
responseParts: [toolResponsePart],
|
||||||
|
resultDisplay: 'Tool success display',
|
||||||
|
error: undefined,
|
||||||
|
});
|
||||||
|
|
||||||
|
const stream1 = (async function* () {
|
||||||
|
yield { functionCalls: [functionCall] } as GenerateContentResponse;
|
||||||
|
})();
|
||||||
|
const stream2 = (async function* () {
|
||||||
|
yield {
|
||||||
|
candidates: [{ content: { parts: [{ text: 'Final answer' }] } }],
|
||||||
|
} as GenerateContentResponse;
|
||||||
|
})();
|
||||||
|
mockChat.sendMessageStream
|
||||||
|
.mockResolvedValueOnce(stream1)
|
||||||
|
.mockResolvedValueOnce(stream2);
|
||||||
|
|
||||||
|
await runNonInteractive(mockConfig, 'Use a tool');
|
||||||
|
|
||||||
|
expect(mockChat.sendMessageStream).toHaveBeenCalledTimes(2);
|
||||||
|
expect(mockCoreExecuteToolCall).toHaveBeenCalledWith(
|
||||||
|
expect.objectContaining({ callId: 'fc1', name: 'testTool' }),
|
||||||
|
mockToolRegistry,
|
||||||
|
expect.any(AbortSignal),
|
||||||
|
);
|
||||||
|
expect(mockChat.sendMessageStream).toHaveBeenLastCalledWith(
|
||||||
|
expect.objectContaining({
|
||||||
|
message: [toolResponsePart],
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
expect(mockProcessStdoutWrite).toHaveBeenCalledWith('Final answer');
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should handle error during tool execution', async () => {
|
||||||
|
const functionCall: FunctionCall = {
|
||||||
|
id: 'fcError',
|
||||||
|
name: 'errorTool',
|
||||||
|
args: {},
|
||||||
|
};
|
||||||
|
const errorResponsePart: Part = {
|
||||||
|
functionResponse: {
|
||||||
|
name: 'errorTool',
|
||||||
|
id: 'fcError',
|
||||||
|
response: { error: 'Tool failed' },
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
const { executeToolCall: mockCoreExecuteToolCall } = await import(
|
||||||
|
'@gemini-code/core'
|
||||||
|
);
|
||||||
|
vi.mocked(mockCoreExecuteToolCall).mockResolvedValue({
|
||||||
|
callId: 'fcError',
|
||||||
|
responseParts: [errorResponsePart],
|
||||||
|
resultDisplay: 'Tool execution failed badly',
|
||||||
|
error: new Error('Tool failed'),
|
||||||
|
});
|
||||||
|
|
||||||
|
const stream1 = (async function* () {
|
||||||
|
yield { functionCalls: [functionCall] } as GenerateContentResponse;
|
||||||
|
})();
|
||||||
|
|
||||||
|
const stream2 = (async function* () {
|
||||||
|
yield {
|
||||||
|
candidates: [
|
||||||
|
{ content: { parts: [{ text: 'Could not complete request.' }] } },
|
||||||
|
],
|
||||||
|
} as GenerateContentResponse;
|
||||||
|
})();
|
||||||
|
mockChat.sendMessageStream
|
||||||
|
.mockResolvedValueOnce(stream1)
|
||||||
|
.mockResolvedValueOnce(stream2);
|
||||||
|
const consoleErrorSpy = vi
|
||||||
|
.spyOn(console, 'error')
|
||||||
|
.mockImplementation(() => {});
|
||||||
|
|
||||||
|
await runNonInteractive(mockConfig, 'Trigger tool error');
|
||||||
|
|
||||||
|
expect(mockCoreExecuteToolCall).toHaveBeenCalled();
|
||||||
|
expect(consoleErrorSpy).toHaveBeenCalledWith(
|
||||||
|
'Error executing tool errorTool: Tool execution failed badly',
|
||||||
|
);
|
||||||
|
expect(mockChat.sendMessageStream).toHaveBeenLastCalledWith(
|
||||||
|
expect.objectContaining({
|
||||||
|
message: [errorResponsePart],
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
expect(mockProcessStdoutWrite).toHaveBeenCalledWith(
|
||||||
|
'Could not complete request.',
|
||||||
|
);
|
||||||
|
consoleErrorSpy.mockRestore();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should exit with error if sendMessageStream throws initially', async () => {
|
||||||
|
const apiError = new Error('API connection failed');
|
||||||
|
mockChat.sendMessageStream.mockRejectedValue(apiError);
|
||||||
|
const consoleErrorSpy = vi
|
||||||
|
.spyOn(console, 'error')
|
||||||
|
.mockImplementation(() => {});
|
||||||
|
|
||||||
|
await runNonInteractive(mockConfig, 'Initial fail');
|
||||||
|
|
||||||
|
expect(consoleErrorSpy).toHaveBeenCalledWith(
|
||||||
|
'Error processing input:',
|
||||||
|
apiError,
|
||||||
|
);
|
||||||
|
consoleErrorSpy.mockRestore();
|
||||||
|
});
|
||||||
|
});
|
|
@ -0,0 +1,114 @@
|
||||||
|
/**
|
||||||
|
* @license
|
||||||
|
* Copyright 2025 Google LLC
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
*/
|
||||||
|
|
||||||
|
import {
|
||||||
|
Config,
|
||||||
|
GeminiClient,
|
||||||
|
ToolCallRequestInfo,
|
||||||
|
executeToolCall,
|
||||||
|
ToolRegistry,
|
||||||
|
} from '@gemini-code/core';
|
||||||
|
import {
|
||||||
|
Content,
|
||||||
|
Part,
|
||||||
|
FunctionCall,
|
||||||
|
GenerateContentResponse,
|
||||||
|
} from '@google/genai';
|
||||||
|
|
||||||
|
function getResponseText(response: GenerateContentResponse): string | null {
|
||||||
|
if (response.candidates && response.candidates.length > 0) {
|
||||||
|
const candidate = response.candidates[0];
|
||||||
|
if (
|
||||||
|
candidate.content &&
|
||||||
|
candidate.content.parts &&
|
||||||
|
candidate.content.parts.length > 0
|
||||||
|
) {
|
||||||
|
return candidate.content.parts
|
||||||
|
.filter((part) => part.text)
|
||||||
|
.map((part) => part.text)
|
||||||
|
.join('');
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
export async function runNonInteractive(
|
||||||
|
config: Config,
|
||||||
|
input: string,
|
||||||
|
): Promise<void> {
|
||||||
|
const geminiClient = new GeminiClient(config);
|
||||||
|
const toolRegistry: ToolRegistry = config.getToolRegistry();
|
||||||
|
await toolRegistry.discoverTools();
|
||||||
|
|
||||||
|
const chat = await geminiClient.startChat();
|
||||||
|
const abortController = new AbortController();
|
||||||
|
let currentMessages: Content[] = [{ role: 'user', parts: [{ text: input }] }];
|
||||||
|
|
||||||
|
try {
|
||||||
|
while (true) {
|
||||||
|
const functionCalls: FunctionCall[] = [];
|
||||||
|
|
||||||
|
const responseStream = await chat.sendMessageStream({
|
||||||
|
message: currentMessages[0]?.parts || [], // Ensure parts are always provided
|
||||||
|
config: {
|
||||||
|
abortSignal: abortController.signal,
|
||||||
|
tools: [
|
||||||
|
{ functionDeclarations: toolRegistry.getFunctionDeclarations() },
|
||||||
|
],
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
for await (const resp of responseStream) {
|
||||||
|
if (abortController.signal.aborted) {
|
||||||
|
console.error('Operation cancelled.');
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
const textPart = getResponseText(resp);
|
||||||
|
if (textPart) {
|
||||||
|
process.stdout.write(textPart);
|
||||||
|
}
|
||||||
|
if (resp.functionCalls) {
|
||||||
|
functionCalls.push(...resp.functionCalls);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (functionCalls.length > 0) {
|
||||||
|
const toolResponseParts: Part[] = [];
|
||||||
|
|
||||||
|
for (const fc of functionCalls) {
|
||||||
|
const callId = fc.id ?? `${fc.name}-${Date.now()}`;
|
||||||
|
const requestInfo: ToolCallRequestInfo = {
|
||||||
|
callId,
|
||||||
|
name: fc.name as string,
|
||||||
|
args: (fc.args ?? {}) as Record<string, unknown>,
|
||||||
|
};
|
||||||
|
|
||||||
|
const toolResponse = await executeToolCall(
|
||||||
|
requestInfo,
|
||||||
|
toolRegistry,
|
||||||
|
abortController.signal,
|
||||||
|
);
|
||||||
|
|
||||||
|
if (toolResponse.error) {
|
||||||
|
console.error(
|
||||||
|
`Error executing tool ${fc.name}: ${toolResponse.resultDisplay || toolResponse.error.message}`,
|
||||||
|
);
|
||||||
|
toolResponseParts.push(...(toolResponse.responseParts as Part[]));
|
||||||
|
} else {
|
||||||
|
toolResponseParts.push(...(toolResponse.responseParts as Part[]));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
currentMessages = [{ role: 'user', parts: toolResponseParts }];
|
||||||
|
} else {
|
||||||
|
process.stdout.write('\n'); // Ensure a final newline
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
console.error('Error processing input:', error);
|
||||||
|
process.exit(1);
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,235 @@
|
||||||
|
/**
|
||||||
|
* @license
|
||||||
|
* Copyright 2025 Google LLC
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
*/
|
||||||
|
|
||||||
|
import { describe, it, expect, vi, beforeEach } from 'vitest';
|
||||||
|
import { executeToolCall } from './nonInteractiveToolExecutor.js';
|
||||||
|
import {
|
||||||
|
ToolRegistry,
|
||||||
|
ToolCallRequestInfo,
|
||||||
|
ToolResult,
|
||||||
|
Tool,
|
||||||
|
ToolCallConfirmationDetails,
|
||||||
|
} from '../index.js';
|
||||||
|
import { Part, Type } from '@google/genai';
|
||||||
|
|
||||||
|
describe('executeToolCall', () => {
|
||||||
|
let mockToolRegistry: ToolRegistry;
|
||||||
|
let mockTool: Tool;
|
||||||
|
let abortController: AbortController;
|
||||||
|
|
||||||
|
beforeEach(() => {
|
||||||
|
mockTool = {
|
||||||
|
name: 'testTool',
|
||||||
|
displayName: 'Test Tool',
|
||||||
|
description: 'A tool for testing',
|
||||||
|
schema: {
|
||||||
|
name: 'testTool',
|
||||||
|
description: 'A tool for testing',
|
||||||
|
parameters: {
|
||||||
|
type: Type.OBJECT,
|
||||||
|
properties: {
|
||||||
|
param1: { type: Type.STRING },
|
||||||
|
},
|
||||||
|
required: ['param1'],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
execute: vi.fn(),
|
||||||
|
validateToolParams: vi.fn(() => null),
|
||||||
|
shouldConfirmExecute: vi.fn(() =>
|
||||||
|
Promise.resolve(false as false | ToolCallConfirmationDetails),
|
||||||
|
),
|
||||||
|
isOutputMarkdown: false,
|
||||||
|
canUpdateOutput: false,
|
||||||
|
getDescription: vi.fn(),
|
||||||
|
};
|
||||||
|
|
||||||
|
mockToolRegistry = {
|
||||||
|
getTool: vi.fn(),
|
||||||
|
// Add other ToolRegistry methods if needed, or use a more complete mock
|
||||||
|
} as unknown as ToolRegistry;
|
||||||
|
|
||||||
|
abortController = new AbortController();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should execute a tool successfully', async () => {
|
||||||
|
const request: ToolCallRequestInfo = {
|
||||||
|
callId: 'call1',
|
||||||
|
name: 'testTool',
|
||||||
|
args: { param1: 'value1' },
|
||||||
|
};
|
||||||
|
const toolResult: ToolResult = {
|
||||||
|
llmContent: 'Tool executed successfully',
|
||||||
|
returnDisplay: 'Success!',
|
||||||
|
};
|
||||||
|
vi.mocked(mockToolRegistry.getTool).mockReturnValue(mockTool);
|
||||||
|
vi.mocked(mockTool.execute).mockResolvedValue(toolResult);
|
||||||
|
|
||||||
|
const response = await executeToolCall(
|
||||||
|
request,
|
||||||
|
mockToolRegistry,
|
||||||
|
abortController.signal,
|
||||||
|
);
|
||||||
|
|
||||||
|
expect(mockToolRegistry.getTool).toHaveBeenCalledWith('testTool');
|
||||||
|
expect(mockTool.execute).toHaveBeenCalledWith(
|
||||||
|
request.args,
|
||||||
|
abortController.signal,
|
||||||
|
);
|
||||||
|
expect(response.callId).toBe('call1');
|
||||||
|
expect(response.error).toBeUndefined();
|
||||||
|
expect(response.resultDisplay).toBe('Success!');
|
||||||
|
expect(response.responseParts).toEqual([
|
||||||
|
{
|
||||||
|
functionResponse: {
|
||||||
|
name: 'testTool',
|
||||||
|
id: 'call1',
|
||||||
|
response: { output: 'Tool executed successfully' },
|
||||||
|
},
|
||||||
|
},
|
||||||
|
]);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should return an error if tool is not found', async () => {
|
||||||
|
const request: ToolCallRequestInfo = {
|
||||||
|
callId: 'call2',
|
||||||
|
name: 'nonExistentTool',
|
||||||
|
args: {},
|
||||||
|
};
|
||||||
|
vi.mocked(mockToolRegistry.getTool).mockReturnValue(undefined);
|
||||||
|
|
||||||
|
const response = await executeToolCall(
|
||||||
|
request,
|
||||||
|
mockToolRegistry,
|
||||||
|
abortController.signal,
|
||||||
|
);
|
||||||
|
|
||||||
|
expect(response.callId).toBe('call2');
|
||||||
|
expect(response.error).toBeInstanceOf(Error);
|
||||||
|
expect(response.error?.message).toBe(
|
||||||
|
'Tool "nonExistentTool" not found in registry.',
|
||||||
|
);
|
||||||
|
expect(response.resultDisplay).toBe(
|
||||||
|
'Tool "nonExistentTool" not found in registry.',
|
||||||
|
);
|
||||||
|
expect(response.responseParts).toEqual([
|
||||||
|
{
|
||||||
|
functionResponse: {
|
||||||
|
name: 'nonExistentTool',
|
||||||
|
id: 'call2',
|
||||||
|
response: { error: 'Tool "nonExistentTool" not found in registry.' },
|
||||||
|
},
|
||||||
|
},
|
||||||
|
]);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should return an error if tool execution fails', async () => {
|
||||||
|
const request: ToolCallRequestInfo = {
|
||||||
|
callId: 'call3',
|
||||||
|
name: 'testTool',
|
||||||
|
args: { param1: 'value1' },
|
||||||
|
};
|
||||||
|
const executionError = new Error('Tool execution failed');
|
||||||
|
vi.mocked(mockToolRegistry.getTool).mockReturnValue(mockTool);
|
||||||
|
vi.mocked(mockTool.execute).mockRejectedValue(executionError);
|
||||||
|
|
||||||
|
const response = await executeToolCall(
|
||||||
|
request,
|
||||||
|
mockToolRegistry,
|
||||||
|
abortController.signal,
|
||||||
|
);
|
||||||
|
|
||||||
|
expect(response.callId).toBe('call3');
|
||||||
|
expect(response.error).toBe(executionError);
|
||||||
|
expect(response.resultDisplay).toBe('Tool execution failed');
|
||||||
|
expect(response.responseParts).toEqual([
|
||||||
|
{
|
||||||
|
functionResponse: {
|
||||||
|
name: 'testTool',
|
||||||
|
id: 'call3',
|
||||||
|
response: { error: 'Tool execution failed' },
|
||||||
|
},
|
||||||
|
},
|
||||||
|
]);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should handle cancellation during tool execution', async () => {
|
||||||
|
const request: ToolCallRequestInfo = {
|
||||||
|
callId: 'call4',
|
||||||
|
name: 'testTool',
|
||||||
|
args: { param1: 'value1' },
|
||||||
|
};
|
||||||
|
const cancellationError = new Error('Operation cancelled');
|
||||||
|
vi.mocked(mockToolRegistry.getTool).mockReturnValue(mockTool);
|
||||||
|
|
||||||
|
vi.mocked(mockTool.execute).mockImplementation(async (_args, signal) => {
|
||||||
|
if (signal?.aborted) {
|
||||||
|
return Promise.reject(cancellationError);
|
||||||
|
}
|
||||||
|
return new Promise((_resolve, reject) => {
|
||||||
|
signal?.addEventListener('abort', () => {
|
||||||
|
reject(cancellationError);
|
||||||
|
});
|
||||||
|
// Simulate work that might happen if not aborted immediately
|
||||||
|
const timeoutId = setTimeout(
|
||||||
|
() =>
|
||||||
|
reject(
|
||||||
|
new Error('Should have been cancelled if not aborted prior'),
|
||||||
|
),
|
||||||
|
100,
|
||||||
|
);
|
||||||
|
signal?.addEventListener('abort', () => clearTimeout(timeoutId));
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
abortController.abort(); // Abort before calling
|
||||||
|
const response = await executeToolCall(
|
||||||
|
request,
|
||||||
|
mockToolRegistry,
|
||||||
|
abortController.signal,
|
||||||
|
);
|
||||||
|
|
||||||
|
expect(response.callId).toBe('call4');
|
||||||
|
expect(response.error?.message).toBe(cancellationError.message);
|
||||||
|
expect(response.resultDisplay).toBe('Operation cancelled');
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should correctly format llmContent with inlineData', async () => {
|
||||||
|
const request: ToolCallRequestInfo = {
|
||||||
|
callId: 'call5',
|
||||||
|
name: 'testTool',
|
||||||
|
args: {},
|
||||||
|
};
|
||||||
|
const imageDataPart: Part = {
|
||||||
|
inlineData: { mimeType: 'image/png', data: 'base64data' },
|
||||||
|
};
|
||||||
|
const toolResult: ToolResult = {
|
||||||
|
llmContent: [imageDataPart],
|
||||||
|
returnDisplay: 'Image processed',
|
||||||
|
};
|
||||||
|
vi.mocked(mockToolRegistry.getTool).mockReturnValue(mockTool);
|
||||||
|
vi.mocked(mockTool.execute).mockResolvedValue(toolResult);
|
||||||
|
|
||||||
|
const response = await executeToolCall(
|
||||||
|
request,
|
||||||
|
mockToolRegistry,
|
||||||
|
abortController.signal,
|
||||||
|
);
|
||||||
|
|
||||||
|
expect(response.resultDisplay).toBe('Image processed');
|
||||||
|
expect(response.responseParts).toEqual([
|
||||||
|
{
|
||||||
|
functionResponse: {
|
||||||
|
name: 'testTool',
|
||||||
|
id: 'call5',
|
||||||
|
response: {
|
||||||
|
status: 'Binary content of type image/png was processed.',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
imageDataPart,
|
||||||
|
]);
|
||||||
|
});
|
||||||
|
});
|
|
@ -0,0 +1,91 @@
|
||||||
|
/**
|
||||||
|
* @license
|
||||||
|
* Copyright 2025 Google LLC
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
*/
|
||||||
|
|
||||||
|
import { Part } from '@google/genai';
|
||||||
|
import {
|
||||||
|
ToolCallRequestInfo,
|
||||||
|
ToolCallResponseInfo,
|
||||||
|
ToolRegistry,
|
||||||
|
ToolResult,
|
||||||
|
} from '../index.js';
|
||||||
|
import { formatLlmContentForFunctionResponse } from './coreToolScheduler.js';
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Executes a single tool call non-interactively.
|
||||||
|
* It does not handle confirmations, multiple calls, or live updates.
|
||||||
|
*/
|
||||||
|
export async function executeToolCall(
|
||||||
|
toolCallRequest: ToolCallRequestInfo,
|
||||||
|
toolRegistry: ToolRegistry,
|
||||||
|
abortSignal?: AbortSignal,
|
||||||
|
): Promise<ToolCallResponseInfo> {
|
||||||
|
const tool = toolRegistry.getTool(toolCallRequest.name);
|
||||||
|
|
||||||
|
if (!tool) {
|
||||||
|
const error = new Error(
|
||||||
|
`Tool "${toolCallRequest.name}" not found in registry.`,
|
||||||
|
);
|
||||||
|
// Ensure the response structure matches what the API expects for an error
|
||||||
|
return {
|
||||||
|
callId: toolCallRequest.callId,
|
||||||
|
responseParts: [
|
||||||
|
{
|
||||||
|
functionResponse: {
|
||||||
|
id: toolCallRequest.callId,
|
||||||
|
name: toolCallRequest.name,
|
||||||
|
response: { error: error.message },
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
resultDisplay: error.message,
|
||||||
|
error,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
// Directly execute without confirmation or live output handling
|
||||||
|
const effectiveAbortSignal = abortSignal ?? new AbortController().signal;
|
||||||
|
const toolResult: ToolResult = await tool.execute(
|
||||||
|
toolCallRequest.args,
|
||||||
|
effectiveAbortSignal,
|
||||||
|
// No live output callback for non-interactive mode
|
||||||
|
);
|
||||||
|
|
||||||
|
const { functionResponseJson, additionalParts } =
|
||||||
|
formatLlmContentForFunctionResponse(toolResult.llmContent);
|
||||||
|
|
||||||
|
const functionResponsePart: Part = {
|
||||||
|
functionResponse: {
|
||||||
|
name: toolCallRequest.name,
|
||||||
|
id: toolCallRequest.callId,
|
||||||
|
response: functionResponseJson,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
return {
|
||||||
|
callId: toolCallRequest.callId,
|
||||||
|
responseParts: [functionResponsePart, ...additionalParts],
|
||||||
|
resultDisplay: toolResult.returnDisplay,
|
||||||
|
error: undefined,
|
||||||
|
};
|
||||||
|
} catch (e) {
|
||||||
|
const error = e instanceof Error ? e : new Error(String(e));
|
||||||
|
return {
|
||||||
|
callId: toolCallRequest.callId,
|
||||||
|
responseParts: [
|
||||||
|
{
|
||||||
|
functionResponse: {
|
||||||
|
id: toolCallRequest.callId,
|
||||||
|
name: toolCallRequest.name,
|
||||||
|
response: { error: error.message },
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
resultDisplay: error.message,
|
||||||
|
error,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
|
@ -14,6 +14,7 @@ export * from './core/prompts.js';
|
||||||
export * from './core/turn.js';
|
export * from './core/turn.js';
|
||||||
export * from './core/geminiRequest.js';
|
export * from './core/geminiRequest.js';
|
||||||
export * from './core/coreToolScheduler.js';
|
export * from './core/coreToolScheduler.js';
|
||||||
|
export * from './core/nonInteractiveToolExecutor.js';
|
||||||
|
|
||||||
// Export utilities
|
// Export utilities
|
||||||
export * from './utils/paths.js';
|
export * from './utils/paths.js';
|
||||||
|
@ -35,3 +36,6 @@ export * from './tools/edit.js';
|
||||||
export * from './tools/write-file.js';
|
export * from './tools/write-file.js';
|
||||||
export * from './tools/web-fetch.js';
|
export * from './tools/web-fetch.js';
|
||||||
export * from './tools/memoryTool.js';
|
export * from './tools/memoryTool.js';
|
||||||
|
export * from './tools/shell.js';
|
||||||
|
export * from './tools/web-search.js';
|
||||||
|
export * from './tools/read-many-files.js';
|
||||||
|
|
Loading…
Reference in New Issue