332 lines
11 KiB
TypeScript
332 lines
11 KiB
TypeScript
/**
 * @license
 * Copyright 2025 Google LLC
 * SPDX-License-Identifier: Apache-2.0
 */
|
|
|
|
import {
|
|
Config,
|
|
executeToolCall,
|
|
ToolRegistry,
|
|
ToolErrorType,
|
|
shutdownTelemetry,
|
|
GeminiEventType,
|
|
ServerGeminiStreamEvent,
|
|
} from '@google/gemini-cli-core';
|
|
import { Part } from '@google/genai';
|
|
import { runNonInteractive } from './nonInteractiveCli.js';
|
|
import { vi } from 'vitest';
|
|
|
|
// Module mocks. The at-command processor is mocked without a factory; the
// core package keeps its real exports except for the three functions below.
vi.mock('./ui/hooks/atCommandProcessor.js');

vi.mock('@google/gemini-cli-core', async (importOriginal) => {
  const actualCore =
    await importOriginal<typeof import('@google/gemini-cli-core')>();

  // Only these exports are replaced; everything else comes from the real module.
  const overrides = {
    executeToolCall: vi.fn(),
    shutdownTelemetry: vi.fn(),
    isTelemetrySdkInitialized: vi.fn().mockReturnValue(true),
  };

  return { ...actualCore, ...overrides };
});
|
|
|
|
describe('runNonInteractive', () => {
|
|
let mockConfig: Config;
|
|
let mockToolRegistry: ToolRegistry;
|
|
let mockCoreExecuteToolCall: vi.Mock;
|
|
let mockShutdownTelemetry: vi.Mock;
|
|
let consoleErrorSpy: vi.SpyInstance;
|
|
let processExitSpy: vi.SpyInstance;
|
|
let processStdoutSpy: vi.SpyInstance;
|
|
let mockGeminiClient: {
|
|
sendMessageStream: vi.Mock;
|
|
};
|
|
|
|
/**
 * Rebuilds every test double before each test:
 * - resets the mocked core functions `executeToolCall` and `shutdownTelemetry`;
 * - puts spies with no-op implementations on `console.error`, `process.exit`
 *   and `process.stdout.write`;
 * - creates stub tool-registry, Gemini-client and `Config` objects;
 * - makes the mocked `handleAtCommand` pass the query through unchanged as a
 *   single text part.
 */
beforeEach(async () => {
  mockCoreExecuteToolCall = vi.mocked(executeToolCall);
  mockShutdownTelemetry = vi.mocked(shutdownTelemetry);

  // Both mocks are created once in the `vi.mock` factory and shared by every
  // test in this file. Reset them explicitly so call history and
  // `mockResolvedValue` settings from an earlier test cannot satisfy a later
  // test's assertions; `vi.restoreAllMocks()` is not guaranteed to reset
  // `vi.fn()` mocks on every Vitest version.
  mockCoreExecuteToolCall.mockReset();
  mockShutdownTelemetry.mockReset();

  // Keep test output quiet and stop `process.exit` from ending the test run.
  consoleErrorSpy = vi.spyOn(console, 'error').mockImplementation(() => {});
  processExitSpy = vi
    .spyOn(process, 'exit')
    .mockImplementation((() => {}) as (code?: number) => never);
  processStdoutSpy = vi
    .spyOn(process.stdout, 'write')
    .mockImplementation(() => true);

  // Partial stubs: only the members listed here exist on these objects.
  mockToolRegistry = {
    getTool: vi.fn(),
    getFunctionDeclarations: vi.fn().mockReturnValue([]),
  } as unknown as ToolRegistry;

  mockGeminiClient = {
    sendMessageStream: vi.fn(),
  };

  mockConfig = {
    initialize: vi.fn().mockResolvedValue(undefined),
    getGeminiClient: vi.fn().mockReturnValue(mockGeminiClient),
    getToolRegistry: vi.fn().mockReturnValue(mockToolRegistry),
    getMaxSessionTurns: vi.fn().mockReturnValue(10),
    getIdeMode: vi.fn().mockReturnValue(false),
    getFullContext: vi.fn().mockReturnValue(false),
    getContentGeneratorConfig: vi.fn().mockReturnValue({}),
    getDebugMode: vi.fn().mockReturnValue(false),
  } as unknown as Config;

  // Default at-command behavior: forward the raw query as one text part.
  const { handleAtCommand } = await import(
    './ui/hooks/atCommandProcessor.js'
  );
  vi.mocked(handleAtCommand).mockImplementation(async ({ query }) => ({
    processedQuery: [{ text: query }],
    shouldProceed: true,
  }));
});
|
|
|
|
// Undo the spies installed in `beforeEach` (console.error, process.exit,
// process.stdout.write) once each test has finished.
afterEach(() => {
  vi.restoreAllMocks();
});
|
|
|
|
/**
 * Wraps a fixed list of stream events in an async generator so it can be used
 * as the return value of the mocked `sendMessageStream`.
 *
 * @param events Events to emit, in array order.
 * @returns An async generator that yields each event once and then completes.
 */
async function* createStreamFromEvents(
  events: ServerGeminiStreamEvent[],
): AsyncGenerator<ServerGeminiStreamEvent> {
  // Delegate straight to the array's iterator instead of looping by hand.
  yield* events;
}
|
|
|
|
// Two content events are streamed; each chunk and a trailing newline must be
// written to stdout, and telemetry must be shut down.
it('should process input and write text output', async () => {
  const promptId = 'prompt-id-1';
  const userInput = 'Test input';
  const streamedEvents: ServerGeminiStreamEvent[] = [
    { type: GeminiEventType.Content, value: 'Hello' },
    { type: GeminiEventType.Content, value: ' World' },
  ];
  const stream = createStreamFromEvents(streamedEvents);
  mockGeminiClient.sendMessageStream.mockReturnValue(stream);

  await runNonInteractive(mockConfig, userInput, promptId);

  // The raw input reaches the client as a single text part.
  expect(mockGeminiClient.sendMessageStream).toHaveBeenCalledWith(
    [{ text: 'Test input' }],
    expect.any(AbortSignal),
    promptId,
  );

  for (const expectedWrite of ['Hello', ' World', '\n']) {
    expect(processStdoutSpy).toHaveBeenCalledWith(expectedWrite);
  }
  expect(mockShutdownTelemetry).toHaveBeenCalled();
});
|
|
|
|
// First stream requests a tool call; the tool's response parts must be sent
// back in a second call, whose content is then written to stdout.
it('should handle a single tool call and respond', async () => {
  const promptId = 'prompt-id-2';
  const toolName = 'testTool';

  const toolResponse: Part[] = [{ text: 'Tool response' }];
  mockCoreExecuteToolCall.mockResolvedValue({ responseParts: toolResponse });

  const toolRequest: ServerGeminiStreamEvent = {
    type: GeminiEventType.ToolCallRequest,
    value: {
      callId: 'tool-1',
      name: toolName,
      args: { arg1: 'value1' },
      isClientInitiated: false,
      prompt_id: promptId,
    },
  };
  const finalAnswer: ServerGeminiStreamEvent = {
    type: GeminiEventType.Content,
    value: 'Final answer',
  };

  // Call 1 yields the tool request, call 2 yields the final content.
  mockGeminiClient.sendMessageStream
    .mockReturnValueOnce(createStreamFromEvents([toolRequest]))
    .mockReturnValueOnce(createStreamFromEvents([finalAnswer]));

  await runNonInteractive(mockConfig, 'Use a tool', promptId);

  expect(mockGeminiClient.sendMessageStream).toHaveBeenCalledTimes(2);
  expect(mockCoreExecuteToolCall).toHaveBeenCalledWith(
    mockConfig,
    expect.objectContaining({ name: 'testTool' }),
    expect.any(AbortSignal),
  );
  // The second request must carry the tool's response parts.
  expect(mockGeminiClient.sendMessageStream).toHaveBeenNthCalledWith(
    2,
    [{ text: 'Tool response' }],
    expect.any(AbortSignal),
    promptId,
  );
  expect(processStdoutSpy).toHaveBeenCalledWith('Final answer');
  expect(processStdoutSpy).toHaveBeenCalledWith('\n');
});
|
|
|
|
// A tool call that resolves with an error must be logged, must not trigger
// process.exit, and its response parts must be forwarded in a second request.
it('should handle error during tool execution and should send error back to the model', async () => {
  const promptId = 'prompt-id-3';
  const toolName = 'errorTool';

  // Returns a fresh copy on every call so the expectation below never shares
  // an object with the value handed to the mock.
  const buildErrorParts = (): Part[] => [
    {
      functionResponse: {
        name: toolName,
        response: {
          output: 'Error: Execution failed',
        },
      },
    },
  ];

  mockCoreExecuteToolCall.mockResolvedValue({
    error: new Error('Execution failed'),
    errorType: ToolErrorType.EXECUTION_FAILED,
    responseParts: buildErrorParts(),
    resultDisplay: 'Execution failed',
  });

  const toolRequest: ServerGeminiStreamEvent = {
    type: GeminiEventType.ToolCallRequest,
    value: {
      callId: 'tool-1',
      name: toolName,
      args: {},
      isClientInitiated: false,
      prompt_id: promptId,
    },
  };
  const followUp: ServerGeminiStreamEvent = {
    type: GeminiEventType.Content,
    value: 'Sorry, let me try again.',
  };
  mockGeminiClient.sendMessageStream
    .mockReturnValueOnce(createStreamFromEvents([toolRequest]))
    .mockReturnValueOnce(createStreamFromEvents([followUp]));

  await runNonInteractive(mockConfig, 'Trigger tool error', promptId);

  expect(mockCoreExecuteToolCall).toHaveBeenCalled();
  expect(consoleErrorSpy).toHaveBeenCalledWith(
    'Error executing tool errorTool: Execution failed',
  );
  expect(processExitSpy).not.toHaveBeenCalled();
  expect(mockGeminiClient.sendMessageStream).toHaveBeenCalledTimes(2);
  expect(mockGeminiClient.sendMessageStream).toHaveBeenNthCalledWith(
    2,
    buildErrorParts(),
    expect.any(AbortSignal),
    promptId,
  );
  expect(processStdoutSpy).toHaveBeenCalledWith('Sorry, let me try again.');
});
|
|
|
|
// When the very first sendMessageStream call throws, the error must be logged
// in the "[API Error: …]" form and process.exit(1) must be called.
it('should exit with error if sendMessageStream throws initially', async () => {
  const connectionError = new Error('API connection failed');
  const throwConnectionError = () => {
    throw connectionError;
  };
  mockGeminiClient.sendMessageStream.mockImplementation(throwConnectionError);

  await runNonInteractive(mockConfig, 'Initial fail', 'prompt-id-4');

  expect(consoleErrorSpy).toHaveBeenCalledWith(
    '[API Error: API connection failed]',
  );
  expect(processExitSpy).toHaveBeenCalledWith(1);
});
|
|
|
|
// A "tool not found" result must be logged without calling process.exit, and
// a second request must still be made so the follow-up content gets printed.
it('should not exit if a tool is not found, and should send error back to model', async () => {
  const promptId = 'prompt-id-5';
  const notFoundMessage = 'Tool "nonexistentTool" not found in registry.';
  const apology = "Sorry, I can't find that tool.";

  // The mocked executor resolves with an error and no response parts.
  mockCoreExecuteToolCall.mockResolvedValue({
    error: new Error(notFoundMessage),
    resultDisplay: notFoundMessage,
  });

  const toolRequest: ServerGeminiStreamEvent = {
    type: GeminiEventType.ToolCallRequest,
    value: {
      callId: 'tool-1',
      name: 'nonexistentTool',
      args: {},
      isClientInitiated: false,
      prompt_id: promptId,
    },
  };
  const followUp: ServerGeminiStreamEvent = {
    type: GeminiEventType.Content,
    value: apology,
  };
  mockGeminiClient.sendMessageStream
    .mockReturnValueOnce(createStreamFromEvents([toolRequest]))
    .mockReturnValueOnce(createStreamFromEvents([followUp]));

  await runNonInteractive(mockConfig, 'Trigger tool not found', promptId);

  expect(mockCoreExecuteToolCall).toHaveBeenCalled();
  expect(consoleErrorSpy).toHaveBeenCalledWith(
    'Error executing tool nonexistentTool: Tool "nonexistentTool" not found in registry.',
  );
  expect(processExitSpy).not.toHaveBeenCalled();
  expect(mockGeminiClient.sendMessageStream).toHaveBeenCalledTimes(2);
  expect(processStdoutSpy).toHaveBeenCalledWith(apology);
});
|
|
|
|
// Overrides the default turn limit of 10 with 0 and checks that the
// "max session turns" notice is written to console.error.
it('should exit when max session turns are exceeded', async () => {
  const maxTurnsGetter = vi.mocked(mockConfig.getMaxSessionTurns);
  maxTurnsGetter.mockReturnValue(0);

  await runNonInteractive(mockConfig, 'Trigger loop', 'prompt-id-6');

  const expectedNotice =
    '\n Reached max session turns for this session. Increase the number of turns by specifying maxSessionTurns in settings.json.';
  expect(consoleErrorSpy).toHaveBeenCalledWith(expectedNotice);
});
|
|
|
|
// The parts returned by the mocked at-command processor, not the raw input,
// must be what sendMessageStream receives.
it('should preprocess @include commands before sending to the model', async () => {
  const promptId = 'prompt-id-7';
  const rawInput = 'Summarize @file.txt';

  // Parts the at-command processor is made to return for this test.
  const processedParts: Part[] = [
    { text: 'Summarize @file.txt' },
    { text: '\n--- Content from referenced files ---\n' },
    { text: 'This is the content of the file.' },
    { text: '\n--- End of content ---' },
  ];

  // Override the pass-through default set up in `beforeEach`.
  const atCommandModule = await import('./ui/hooks/atCommandProcessor.js');
  vi.mocked(atCommandModule.handleAtCommand).mockResolvedValue({
    processedQuery: processedParts,
    shouldProceed: true,
  });

  // The client answers with a single content event.
  const reply: ServerGeminiStreamEvent = {
    type: GeminiEventType.Content,
    value: 'Summary complete.',
  };
  mockGeminiClient.sendMessageStream.mockReturnValue(
    createStreamFromEvents([reply]),
  );

  await runNonInteractive(mockConfig, rawInput, promptId);

  // The request must carry the processed parts rather than the raw string.
  expect(mockGeminiClient.sendMessageStream).toHaveBeenCalledWith(
    processedParts,
    expect.any(AbortSignal),
    promptId,
  );

  // The streamed content must be written to stdout.
  expect(processStdoutSpy).toHaveBeenCalledWith('Summary complete.');
});
|
|
});
|