Fix: Ensure that non interactive mode and interactive mode are calling the same entry points (#5137)

This commit is contained in:
anj-s 2025-07-31 05:36:12 -07:00 committed by GitHub
parent 23c014e29c
commit 65be9cab47
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 153 additions and 284 deletions

View File

@ -4,196 +4,167 @@
* SPDX-License-Identifier: Apache-2.0
*/
/* eslint-disable @typescript-eslint/no-explicit-any */
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
import {
Config,
executeToolCall,
ToolRegistry,
shutdownTelemetry,
GeminiEventType,
ServerGeminiStreamEvent,
} from '@google/gemini-cli-core';
import { Part } from '@google/genai';
import { runNonInteractive } from './nonInteractiveCli.js';
import { Config, GeminiClient, ToolRegistry } from '@google/gemini-cli-core';
import { GenerateContentResponse, Part, FunctionCall } from '@google/genai';
import { vi } from 'vitest';
// Mock dependencies
vi.mock('@google/gemini-cli-core', async () => {
const actualCore = await vi.importActual<
typeof import('@google/gemini-cli-core')
>('@google/gemini-cli-core');
// Mock core modules
vi.mock('@google/gemini-cli-core', async (importOriginal) => {
const original =
await importOriginal<typeof import('@google/gemini-cli-core')>();
return {
...actualCore,
GeminiClient: vi.fn(),
ToolRegistry: vi.fn(),
...original,
executeToolCall: vi.fn(),
shutdownTelemetry: vi.fn(),
isTelemetrySdkInitialized: vi.fn().mockReturnValue(true),
};
});
describe('runNonInteractive', () => {
let mockConfig: Config;
let mockGeminiClient: GeminiClient;
let mockToolRegistry: ToolRegistry;
let mockChat: {
sendMessageStream: ReturnType<typeof vi.fn>;
let mockCoreExecuteToolCall: vi.Mock;
let mockShutdownTelemetry: vi.Mock;
let consoleErrorSpy: vi.SpyInstance;
let processExitSpy: vi.SpyInstance;
let processStdoutSpy: vi.SpyInstance;
let mockGeminiClient: {
sendMessageStream: vi.Mock;
};
let mockProcessStdoutWrite: ReturnType<typeof vi.fn>;
let mockProcessExit: ReturnType<typeof vi.fn>;
beforeEach(() => {
vi.resetAllMocks();
mockChat = {
sendMessageStream: vi.fn(),
};
mockGeminiClient = {
getChat: vi.fn().mockResolvedValue(mockChat),
} as unknown as GeminiClient;
mockCoreExecuteToolCall = vi.mocked(executeToolCall);
mockShutdownTelemetry = vi.mocked(shutdownTelemetry);
consoleErrorSpy = vi.spyOn(console, 'error').mockImplementation(() => {});
processExitSpy = vi
.spyOn(process, 'exit')
.mockImplementation((() => {}) as (code?: number) => never);
processStdoutSpy = vi
.spyOn(process.stdout, 'write')
.mockImplementation(() => true);
mockToolRegistry = {
getFunctionDeclarations: vi.fn().mockReturnValue([]),
getTool: vi.fn(),
getFunctionDeclarations: vi.fn().mockReturnValue([]),
} as unknown as ToolRegistry;
vi.mocked(GeminiClient).mockImplementation(() => mockGeminiClient);
vi.mocked(ToolRegistry).mockImplementation(() => mockToolRegistry);
mockGeminiClient = {
sendMessageStream: vi.fn(),
};
mockConfig = {
getToolRegistry: vi.fn().mockReturnValue(mockToolRegistry),
initialize: vi.fn().mockResolvedValue(undefined),
getGeminiClient: vi.fn().mockReturnValue(mockGeminiClient),
getContentGeneratorConfig: vi.fn().mockReturnValue({}),
getToolRegistry: vi.fn().mockResolvedValue(mockToolRegistry),
getMaxSessionTurns: vi.fn().mockReturnValue(10),
initialize: vi.fn(),
getIdeMode: vi.fn().mockReturnValue(false),
getFullContext: vi.fn().mockReturnValue(false),
getContentGeneratorConfig: vi.fn().mockReturnValue({}),
} as unknown as Config;
mockProcessStdoutWrite = vi.fn().mockImplementation(() => true);
process.stdout.write = mockProcessStdoutWrite as any; // Use any to bypass strict signature matching for mock
mockProcessExit = vi
.fn()
.mockImplementation((_code?: number) => undefined as never);
process.exit = mockProcessExit as any; // Use any for process.exit mock
});
afterEach(() => {
vi.restoreAllMocks();
// Restore original process methods if they were globally patched
// This might require storing the original methods before patching them in beforeEach
});
async function* createStreamFromEvents(
events: ServerGeminiStreamEvent[],
): AsyncGenerator<ServerGeminiStreamEvent> {
for (const event of events) {
yield event;
}
}
it('should process input and write text output', async () => {
const inputStream = (async function* () {
yield {
candidates: [{ content: { parts: [{ text: 'Hello' }] } }],
} as GenerateContentResponse;
yield {
candidates: [{ content: { parts: [{ text: ' World' }] } }],
} as GenerateContentResponse;
})();
mockChat.sendMessageStream.mockResolvedValue(inputStream);
const events: ServerGeminiStreamEvent[] = [
{ type: GeminiEventType.Content, value: 'Hello' },
{ type: GeminiEventType.Content, value: ' World' },
];
mockGeminiClient.sendMessageStream.mockReturnValue(
createStreamFromEvents(events),
);
await runNonInteractive(mockConfig, 'Test input', 'prompt-id-1');
expect(mockChat.sendMessageStream).toHaveBeenCalledWith(
{
message: [{ text: 'Test input' }],
config: {
abortSignal: expect.any(AbortSignal),
tools: [{ functionDeclarations: [] }],
},
},
expect.any(String),
expect(mockGeminiClient.sendMessageStream).toHaveBeenCalledWith(
[{ text: 'Test input' }],
expect.any(AbortSignal),
'prompt-id-1',
);
expect(mockProcessStdoutWrite).toHaveBeenCalledWith('Hello');
expect(mockProcessStdoutWrite).toHaveBeenCalledWith(' World');
expect(mockProcessStdoutWrite).toHaveBeenCalledWith('\n');
expect(processStdoutSpy).toHaveBeenCalledWith('Hello');
expect(processStdoutSpy).toHaveBeenCalledWith(' World');
expect(processStdoutSpy).toHaveBeenCalledWith('\n');
expect(mockShutdownTelemetry).toHaveBeenCalled();
});
it('should handle a single tool call and respond', async () => {
const functionCall: FunctionCall = {
id: 'fc1',
name: 'testTool',
args: { p: 'v' },
};
const toolResponsePart: Part = {
functionResponse: {
const toolCallEvent: ServerGeminiStreamEvent = {
type: GeminiEventType.ToolCallRequest,
value: {
callId: 'tool-1',
name: 'testTool',
id: 'fc1',
response: { result: 'tool success' },
args: { arg1: 'value1' },
isClientInitiated: false,
prompt_id: 'prompt-id-2',
},
};
const toolResponse: Part[] = [{ text: 'Tool response' }];
mockCoreExecuteToolCall.mockResolvedValue({ responseParts: toolResponse });
const { executeToolCall: mockCoreExecuteToolCall } = await import(
'@google/gemini-cli-core'
);
vi.mocked(mockCoreExecuteToolCall).mockResolvedValue({
callId: 'fc1',
responseParts: [toolResponsePart],
resultDisplay: 'Tool success display',
error: undefined,
});
const firstCallEvents: ServerGeminiStreamEvent[] = [toolCallEvent];
const secondCallEvents: ServerGeminiStreamEvent[] = [
{ type: GeminiEventType.Content, value: 'Final answer' },
];
const stream1 = (async function* () {
yield { functionCalls: [functionCall] } as GenerateContentResponse;
})();
const stream2 = (async function* () {
yield {
candidates: [{ content: { parts: [{ text: 'Final answer' }] } }],
} as GenerateContentResponse;
})();
mockChat.sendMessageStream
.mockResolvedValueOnce(stream1)
.mockResolvedValueOnce(stream2);
mockGeminiClient.sendMessageStream
.mockReturnValueOnce(createStreamFromEvents(firstCallEvents))
.mockReturnValueOnce(createStreamFromEvents(secondCallEvents));
await runNonInteractive(mockConfig, 'Use a tool', 'prompt-id-2');
expect(mockChat.sendMessageStream).toHaveBeenCalledTimes(2);
expect(mockGeminiClient.sendMessageStream).toHaveBeenCalledTimes(2);
expect(mockCoreExecuteToolCall).toHaveBeenCalledWith(
mockConfig,
expect.objectContaining({ callId: 'fc1', name: 'testTool' }),
expect.objectContaining({ name: 'testTool' }),
mockToolRegistry,
expect.any(AbortSignal),
);
expect(mockChat.sendMessageStream).toHaveBeenLastCalledWith(
expect.objectContaining({
message: [toolResponsePart],
}),
expect.any(String),
expect(mockGeminiClient.sendMessageStream).toHaveBeenNthCalledWith(
2,
[{ text: 'Tool response' }],
expect.any(AbortSignal),
'prompt-id-2',
);
expect(mockProcessStdoutWrite).toHaveBeenCalledWith('Final answer');
expect(processStdoutSpy).toHaveBeenCalledWith('Final answer');
expect(processStdoutSpy).toHaveBeenCalledWith('\n');
});
it('should handle error during tool execution', async () => {
const functionCall: FunctionCall = {
id: 'fcError',
name: 'errorTool',
args: {},
};
const errorResponsePart: Part = {
functionResponse: {
const toolCallEvent: ServerGeminiStreamEvent = {
type: GeminiEventType.ToolCallRequest,
value: {
callId: 'tool-1',
name: 'errorTool',
id: 'fcError',
response: { error: 'Tool failed' },
args: {},
isClientInitiated: false,
prompt_id: 'prompt-id-3',
},
};
const { executeToolCall: mockCoreExecuteToolCall } = await import(
'@google/gemini-cli-core'
);
vi.mocked(mockCoreExecuteToolCall).mockResolvedValue({
callId: 'fcError',
responseParts: [errorResponsePart],
resultDisplay: 'Tool execution failed badly',
error: new Error('Tool failed'),
mockCoreExecuteToolCall.mockResolvedValue({
error: new Error('Tool execution failed badly'),
});
const stream1 = (async function* () {
yield { functionCalls: [functionCall] } as GenerateContentResponse;
})();
const stream2 = (async function* () {
yield {
candidates: [
{ content: { parts: [{ text: 'Could not complete request.' }] } },
],
} as GenerateContentResponse;
})();
mockChat.sendMessageStream
.mockResolvedValueOnce(stream1)
.mockResolvedValueOnce(stream2);
const consoleErrorSpy = vi
.spyOn(console, 'error')
.mockImplementation(() => {});
mockGeminiClient.sendMessageStream.mockReturnValue(
createStreamFromEvents([toolCallEvent]),
);
await runNonInteractive(mockConfig, 'Trigger tool error', 'prompt-id-3');
@ -201,75 +172,48 @@ describe('runNonInteractive', () => {
expect(consoleErrorSpy).toHaveBeenCalledWith(
'Error executing tool errorTool: Tool execution failed badly',
);
expect(mockChat.sendMessageStream).toHaveBeenLastCalledWith(
expect.objectContaining({
message: [errorResponsePart],
}),
expect.any(String),
);
expect(mockProcessStdoutWrite).toHaveBeenCalledWith(
'Could not complete request.',
);
expect(processExitSpy).toHaveBeenCalledWith(1);
});
it('should exit with error if sendMessageStream throws initially', async () => {
const apiError = new Error('API connection failed');
mockChat.sendMessageStream.mockRejectedValue(apiError);
const consoleErrorSpy = vi
.spyOn(console, 'error')
.mockImplementation(() => {});
mockGeminiClient.sendMessageStream.mockImplementation(() => {
throw apiError;
});
await runNonInteractive(mockConfig, 'Initial fail', 'prompt-id-4');
expect(consoleErrorSpy).toHaveBeenCalledWith(
'[API Error: API connection failed]',
);
expect(processExitSpy).toHaveBeenCalledWith(1);
});
it('should not exit if a tool is not found, and should send error back to model', async () => {
const functionCall: FunctionCall = {
id: 'fcNotFound',
name: 'nonexistentTool',
args: {},
};
const errorResponsePart: Part = {
functionResponse: {
const toolCallEvent: ServerGeminiStreamEvent = {
type: GeminiEventType.ToolCallRequest,
value: {
callId: 'tool-1',
name: 'nonexistentTool',
id: 'fcNotFound',
response: { error: 'Tool "nonexistentTool" not found in registry.' },
args: {},
isClientInitiated: false,
prompt_id: 'prompt-id-5',
},
};
const { executeToolCall: mockCoreExecuteToolCall } = await import(
'@google/gemini-cli-core'
);
vi.mocked(mockCoreExecuteToolCall).mockResolvedValue({
callId: 'fcNotFound',
responseParts: [errorResponsePart],
resultDisplay: 'Tool "nonexistentTool" not found in registry.',
mockCoreExecuteToolCall.mockResolvedValue({
error: new Error('Tool "nonexistentTool" not found in registry.'),
resultDisplay: 'Tool "nonexistentTool" not found in registry.',
});
const finalResponse: ServerGeminiStreamEvent[] = [
{
type: GeminiEventType.Content,
value: "Sorry, I can't find that tool.",
},
];
const stream1 = (async function* () {
yield { functionCalls: [functionCall] } as GenerateContentResponse;
})();
const stream2 = (async function* () {
yield {
candidates: [
{
content: {
parts: [{ text: 'Unfortunately the tool does not exist.' }],
},
},
],
} as GenerateContentResponse;
})();
mockChat.sendMessageStream
.mockResolvedValueOnce(stream1)
.mockResolvedValueOnce(stream2);
const consoleErrorSpy = vi
.spyOn(console, 'error')
.mockImplementation(() => {});
mockGeminiClient.sendMessageStream
.mockReturnValueOnce(createStreamFromEvents([toolCallEvent]))
.mockReturnValueOnce(createStreamFromEvents(finalResponse));
await runNonInteractive(
mockConfig,
@ -277,68 +221,22 @@ describe('runNonInteractive', () => {
'prompt-id-5',
);
expect(mockCoreExecuteToolCall).toHaveBeenCalled();
expect(consoleErrorSpy).toHaveBeenCalledWith(
'Error executing tool nonexistentTool: Tool "nonexistentTool" not found in registry.',
);
expect(mockProcessExit).not.toHaveBeenCalled();
expect(mockChat.sendMessageStream).toHaveBeenCalledTimes(2);
expect(mockChat.sendMessageStream).toHaveBeenLastCalledWith(
expect.objectContaining({
message: [errorResponsePart],
}),
expect.any(String),
);
expect(mockProcessStdoutWrite).toHaveBeenCalledWith(
'Unfortunately the tool does not exist.',
expect(processExitSpy).not.toHaveBeenCalled();
expect(mockGeminiClient.sendMessageStream).toHaveBeenCalledTimes(2);
expect(processStdoutSpy).toHaveBeenCalledWith(
"Sorry, I can't find that tool.",
);
});
it('should exit when max session turns are exceeded', async () => {
const functionCall: FunctionCall = {
id: 'fcLoop',
name: 'loopTool',
args: {},
};
const toolResponsePart: Part = {
functionResponse: {
name: 'loopTool',
id: 'fcLoop',
response: { result: 'still looping' },
},
};
// Config with a max turn of 1
vi.mocked(mockConfig.getMaxSessionTurns).mockReturnValue(1);
const { executeToolCall: mockCoreExecuteToolCall } = await import(
'@google/gemini-cli-core'
);
vi.mocked(mockCoreExecuteToolCall).mockResolvedValue({
callId: 'fcLoop',
responseParts: [toolResponsePart],
resultDisplay: 'Still looping',
error: undefined,
});
const stream = (async function* () {
yield { functionCalls: [functionCall] } as GenerateContentResponse;
})();
mockChat.sendMessageStream.mockResolvedValue(stream);
const consoleErrorSpy = vi
.spyOn(console, 'error')
.mockImplementation(() => {});
await runNonInteractive(mockConfig, 'Trigger loop');
expect(mockChat.sendMessageStream).toHaveBeenCalledTimes(1);
vi.mocked(mockConfig.getMaxSessionTurns).mockReturnValue(0);
await runNonInteractive(mockConfig, 'Trigger loop', 'prompt-id-6');
expect(consoleErrorSpy).toHaveBeenCalledWith(
`
Reached max session turns for this session. Increase the number of turns by specifying maxSessionTurns in settings.json.`,
'\n Reached max session turns for this session. Increase the number of turns by specifying maxSessionTurns in settings.json.',
);
expect(mockProcessExit).not.toHaveBeenCalled();
});
});

View File

@ -11,38 +11,12 @@ import {
ToolRegistry,
shutdownTelemetry,
isTelemetrySdkInitialized,
GeminiEventType,
} from '@google/gemini-cli-core';
import {
Content,
Part,
FunctionCall,
GenerateContentResponse,
} from '@google/genai';
import { Content, Part, FunctionCall } from '@google/genai';
import { parseAndFormatApiError } from './ui/utils/errorParsing.js';
function getResponseText(response: GenerateContentResponse): string | null {
if (response.candidates && response.candidates.length > 0) {
const candidate = response.candidates[0];
if (
candidate.content &&
candidate.content.parts &&
candidate.content.parts.length > 0
) {
// We are running in headless mode so we don't need to return thoughts to STDOUT.
const thoughtPart = candidate.content.parts[0];
if (thoughtPart?.thought) {
return null;
}
return candidate.content.parts
.filter((part) => part.text)
.map((part) => part.text)
.join('');
}
}
return null;
}
export async function runNonInteractive(
config: Config,
input: string,
@ -60,7 +34,6 @@ export async function runNonInteractive(
const geminiClient = config.getGeminiClient();
const toolRegistry: ToolRegistry = await config.getToolRegistry();
const chat = await geminiClient.getChat();
const abortController = new AbortController();
let currentMessages: Content[] = [{ role: 'user', parts: [{ text: input }] }];
let turnCount = 0;
@ -68,7 +41,7 @@ export async function runNonInteractive(
while (true) {
turnCount++;
if (
config.getMaxSessionTurns() > 0 &&
config.getMaxSessionTurns() >= 0 &&
turnCount > config.getMaxSessionTurns()
) {
console.error(
@ -78,30 +51,28 @@ export async function runNonInteractive(
}
const functionCalls: FunctionCall[] = [];
const responseStream = await chat.sendMessageStream(
{
message: currentMessages[0]?.parts || [], // Ensure parts are always provided
config: {
abortSignal: abortController.signal,
tools: [
{ functionDeclarations: toolRegistry.getFunctionDeclarations() },
],
},
},
const responseStream = geminiClient.sendMessageStream(
currentMessages[0]?.parts || [],
abortController.signal,
prompt_id,
);
for await (const resp of responseStream) {
for await (const event of responseStream) {
if (abortController.signal.aborted) {
console.error('Operation cancelled.');
return;
}
const textPart = getResponseText(resp);
if (textPart) {
process.stdout.write(textPart);
}
if (resp.functionCalls) {
functionCalls.push(...resp.functionCalls);
if (event.type === GeminiEventType.Content) {
process.stdout.write(event.value);
} else if (event.type === GeminiEventType.ToolCallRequest) {
const toolCallRequest = event.value;
const fc: FunctionCall = {
name: toolCallRequest.name,
args: toolCallRequest.args,
id: toolCallRequest.callId,
};
functionCalls.push(fc);
}
}