Fix: Ensure that non interactive mode and interactive mode are calling the same entry points (#5137)

This commit is contained in:
anj-s 2025-07-31 05:36:12 -07:00 committed by GitHub
parent 23c014e29c
commit 65be9cab47
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 153 additions and 284 deletions

View File

@ -4,196 +4,167 @@
* SPDX-License-Identifier: Apache-2.0 * SPDX-License-Identifier: Apache-2.0
*/ */
/* eslint-disable @typescript-eslint/no-explicit-any */ import {
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'; Config,
executeToolCall,
ToolRegistry,
shutdownTelemetry,
GeminiEventType,
ServerGeminiStreamEvent,
} from '@google/gemini-cli-core';
import { Part } from '@google/genai';
import { runNonInteractive } from './nonInteractiveCli.js'; import { runNonInteractive } from './nonInteractiveCli.js';
import { Config, GeminiClient, ToolRegistry } from '@google/gemini-cli-core'; import { vi } from 'vitest';
import { GenerateContentResponse, Part, FunctionCall } from '@google/genai';
// Mock dependencies // Mock core modules
vi.mock('@google/gemini-cli-core', async () => { vi.mock('@google/gemini-cli-core', async (importOriginal) => {
const actualCore = await vi.importActual< const original =
typeof import('@google/gemini-cli-core') await importOriginal<typeof import('@google/gemini-cli-core')>();
>('@google/gemini-cli-core');
return { return {
...actualCore, ...original,
GeminiClient: vi.fn(),
ToolRegistry: vi.fn(),
executeToolCall: vi.fn(), executeToolCall: vi.fn(),
shutdownTelemetry: vi.fn(),
isTelemetrySdkInitialized: vi.fn().mockReturnValue(true),
}; };
}); });
describe('runNonInteractive', () => { describe('runNonInteractive', () => {
let mockConfig: Config; let mockConfig: Config;
let mockGeminiClient: GeminiClient;
let mockToolRegistry: ToolRegistry; let mockToolRegistry: ToolRegistry;
let mockChat: { let mockCoreExecuteToolCall: vi.Mock;
sendMessageStream: ReturnType<typeof vi.fn>; let mockShutdownTelemetry: vi.Mock;
let consoleErrorSpy: vi.SpyInstance;
let processExitSpy: vi.SpyInstance;
let processStdoutSpy: vi.SpyInstance;
let mockGeminiClient: {
sendMessageStream: vi.Mock;
}; };
let mockProcessStdoutWrite: ReturnType<typeof vi.fn>;
let mockProcessExit: ReturnType<typeof vi.fn>;
beforeEach(() => { beforeEach(() => {
vi.resetAllMocks(); mockCoreExecuteToolCall = vi.mocked(executeToolCall);
mockChat = { mockShutdownTelemetry = vi.mocked(shutdownTelemetry);
sendMessageStream: vi.fn(),
}; consoleErrorSpy = vi.spyOn(console, 'error').mockImplementation(() => {});
mockGeminiClient = { processExitSpy = vi
getChat: vi.fn().mockResolvedValue(mockChat), .spyOn(process, 'exit')
} as unknown as GeminiClient; .mockImplementation((() => {}) as (code?: number) => never);
processStdoutSpy = vi
.spyOn(process.stdout, 'write')
.mockImplementation(() => true);
mockToolRegistry = { mockToolRegistry = {
getFunctionDeclarations: vi.fn().mockReturnValue([]),
getTool: vi.fn(), getTool: vi.fn(),
getFunctionDeclarations: vi.fn().mockReturnValue([]),
} as unknown as ToolRegistry; } as unknown as ToolRegistry;
vi.mocked(GeminiClient).mockImplementation(() => mockGeminiClient); mockGeminiClient = {
vi.mocked(ToolRegistry).mockImplementation(() => mockToolRegistry); sendMessageStream: vi.fn(),
};
mockConfig = { mockConfig = {
getToolRegistry: vi.fn().mockReturnValue(mockToolRegistry), initialize: vi.fn().mockResolvedValue(undefined),
getGeminiClient: vi.fn().mockReturnValue(mockGeminiClient), getGeminiClient: vi.fn().mockReturnValue(mockGeminiClient),
getContentGeneratorConfig: vi.fn().mockReturnValue({}), getToolRegistry: vi.fn().mockResolvedValue(mockToolRegistry),
getMaxSessionTurns: vi.fn().mockReturnValue(10), getMaxSessionTurns: vi.fn().mockReturnValue(10),
initialize: vi.fn(), getIdeMode: vi.fn().mockReturnValue(false),
getFullContext: vi.fn().mockReturnValue(false),
getContentGeneratorConfig: vi.fn().mockReturnValue({}),
} as unknown as Config; } as unknown as Config;
mockProcessStdoutWrite = vi.fn().mockImplementation(() => true);
process.stdout.write = mockProcessStdoutWrite as any; // Use any to bypass strict signature matching for mock
mockProcessExit = vi
.fn()
.mockImplementation((_code?: number) => undefined as never);
process.exit = mockProcessExit as any; // Use any for process.exit mock
}); });
afterEach(() => { afterEach(() => {
vi.restoreAllMocks(); vi.restoreAllMocks();
// Restore original process methods if they were globally patched
// This might require storing the original methods before patching them in beforeEach
}); });
async function* createStreamFromEvents(
events: ServerGeminiStreamEvent[],
): AsyncGenerator<ServerGeminiStreamEvent> {
for (const event of events) {
yield event;
}
}
it('should process input and write text output', async () => { it('should process input and write text output', async () => {
const inputStream = (async function* () { const events: ServerGeminiStreamEvent[] = [
yield { { type: GeminiEventType.Content, value: 'Hello' },
candidates: [{ content: { parts: [{ text: 'Hello' }] } }], { type: GeminiEventType.Content, value: ' World' },
} as GenerateContentResponse; ];
yield { mockGeminiClient.sendMessageStream.mockReturnValue(
candidates: [{ content: { parts: [{ text: ' World' }] } }], createStreamFromEvents(events),
} as GenerateContentResponse; );
})();
mockChat.sendMessageStream.mockResolvedValue(inputStream);
await runNonInteractive(mockConfig, 'Test input', 'prompt-id-1'); await runNonInteractive(mockConfig, 'Test input', 'prompt-id-1');
expect(mockChat.sendMessageStream).toHaveBeenCalledWith( expect(mockGeminiClient.sendMessageStream).toHaveBeenCalledWith(
{ [{ text: 'Test input' }],
message: [{ text: 'Test input' }], expect.any(AbortSignal),
config: { 'prompt-id-1',
abortSignal: expect.any(AbortSignal),
tools: [{ functionDeclarations: [] }],
},
},
expect.any(String),
); );
expect(mockProcessStdoutWrite).toHaveBeenCalledWith('Hello'); expect(processStdoutSpy).toHaveBeenCalledWith('Hello');
expect(mockProcessStdoutWrite).toHaveBeenCalledWith(' World'); expect(processStdoutSpy).toHaveBeenCalledWith(' World');
expect(mockProcessStdoutWrite).toHaveBeenCalledWith('\n'); expect(processStdoutSpy).toHaveBeenCalledWith('\n');
expect(mockShutdownTelemetry).toHaveBeenCalled();
}); });
it('should handle a single tool call and respond', async () => { it('should handle a single tool call and respond', async () => {
const functionCall: FunctionCall = { const toolCallEvent: ServerGeminiStreamEvent = {
id: 'fc1', type: GeminiEventType.ToolCallRequest,
name: 'testTool', value: {
args: { p: 'v' }, callId: 'tool-1',
};
const toolResponsePart: Part = {
functionResponse: {
name: 'testTool', name: 'testTool',
id: 'fc1', args: { arg1: 'value1' },
response: { result: 'tool success' }, isClientInitiated: false,
prompt_id: 'prompt-id-2',
}, },
}; };
const toolResponse: Part[] = [{ text: 'Tool response' }];
mockCoreExecuteToolCall.mockResolvedValue({ responseParts: toolResponse });
const { executeToolCall: mockCoreExecuteToolCall } = await import( const firstCallEvents: ServerGeminiStreamEvent[] = [toolCallEvent];
'@google/gemini-cli-core' const secondCallEvents: ServerGeminiStreamEvent[] = [
); { type: GeminiEventType.Content, value: 'Final answer' },
vi.mocked(mockCoreExecuteToolCall).mockResolvedValue({ ];
callId: 'fc1',
responseParts: [toolResponsePart],
resultDisplay: 'Tool success display',
error: undefined,
});
const stream1 = (async function* () { mockGeminiClient.sendMessageStream
yield { functionCalls: [functionCall] } as GenerateContentResponse; .mockReturnValueOnce(createStreamFromEvents(firstCallEvents))
})(); .mockReturnValueOnce(createStreamFromEvents(secondCallEvents));
const stream2 = (async function* () {
yield {
candidates: [{ content: { parts: [{ text: 'Final answer' }] } }],
} as GenerateContentResponse;
})();
mockChat.sendMessageStream
.mockResolvedValueOnce(stream1)
.mockResolvedValueOnce(stream2);
await runNonInteractive(mockConfig, 'Use a tool', 'prompt-id-2'); await runNonInteractive(mockConfig, 'Use a tool', 'prompt-id-2');
expect(mockChat.sendMessageStream).toHaveBeenCalledTimes(2); expect(mockGeminiClient.sendMessageStream).toHaveBeenCalledTimes(2);
expect(mockCoreExecuteToolCall).toHaveBeenCalledWith( expect(mockCoreExecuteToolCall).toHaveBeenCalledWith(
mockConfig, mockConfig,
expect.objectContaining({ callId: 'fc1', name: 'testTool' }), expect.objectContaining({ name: 'testTool' }),
mockToolRegistry, mockToolRegistry,
expect.any(AbortSignal), expect.any(AbortSignal),
); );
expect(mockChat.sendMessageStream).toHaveBeenLastCalledWith( expect(mockGeminiClient.sendMessageStream).toHaveBeenNthCalledWith(
expect.objectContaining({ 2,
message: [toolResponsePart], [{ text: 'Tool response' }],
}), expect.any(AbortSignal),
expect.any(String), 'prompt-id-2',
); );
expect(mockProcessStdoutWrite).toHaveBeenCalledWith('Final answer'); expect(processStdoutSpy).toHaveBeenCalledWith('Final answer');
expect(processStdoutSpy).toHaveBeenCalledWith('\n');
}); });
it('should handle error during tool execution', async () => { it('should handle error during tool execution', async () => {
const functionCall: FunctionCall = { const toolCallEvent: ServerGeminiStreamEvent = {
id: 'fcError', type: GeminiEventType.ToolCallRequest,
name: 'errorTool', value: {
args: {}, callId: 'tool-1',
};
const errorResponsePart: Part = {
functionResponse: {
name: 'errorTool', name: 'errorTool',
id: 'fcError', args: {},
response: { error: 'Tool failed' }, isClientInitiated: false,
prompt_id: 'prompt-id-3',
}, },
}; };
mockCoreExecuteToolCall.mockResolvedValue({
const { executeToolCall: mockCoreExecuteToolCall } = await import( error: new Error('Tool execution failed badly'),
'@google/gemini-cli-core'
);
vi.mocked(mockCoreExecuteToolCall).mockResolvedValue({
callId: 'fcError',
responseParts: [errorResponsePart],
resultDisplay: 'Tool execution failed badly',
error: new Error('Tool failed'),
}); });
mockGeminiClient.sendMessageStream.mockReturnValue(
const stream1 = (async function* () { createStreamFromEvents([toolCallEvent]),
yield { functionCalls: [functionCall] } as GenerateContentResponse; );
})();
const stream2 = (async function* () {
yield {
candidates: [
{ content: { parts: [{ text: 'Could not complete request.' }] } },
],
} as GenerateContentResponse;
})();
mockChat.sendMessageStream
.mockResolvedValueOnce(stream1)
.mockResolvedValueOnce(stream2);
const consoleErrorSpy = vi
.spyOn(console, 'error')
.mockImplementation(() => {});
await runNonInteractive(mockConfig, 'Trigger tool error', 'prompt-id-3'); await runNonInteractive(mockConfig, 'Trigger tool error', 'prompt-id-3');
@ -201,75 +172,48 @@ describe('runNonInteractive', () => {
expect(consoleErrorSpy).toHaveBeenCalledWith( expect(consoleErrorSpy).toHaveBeenCalledWith(
'Error executing tool errorTool: Tool execution failed badly', 'Error executing tool errorTool: Tool execution failed badly',
); );
expect(mockChat.sendMessageStream).toHaveBeenLastCalledWith( expect(processExitSpy).toHaveBeenCalledWith(1);
expect.objectContaining({
message: [errorResponsePart],
}),
expect.any(String),
);
expect(mockProcessStdoutWrite).toHaveBeenCalledWith(
'Could not complete request.',
);
}); });
it('should exit with error if sendMessageStream throws initially', async () => { it('should exit with error if sendMessageStream throws initially', async () => {
const apiError = new Error('API connection failed'); const apiError = new Error('API connection failed');
mockChat.sendMessageStream.mockRejectedValue(apiError); mockGeminiClient.sendMessageStream.mockImplementation(() => {
const consoleErrorSpy = vi throw apiError;
.spyOn(console, 'error') });
.mockImplementation(() => {});
await runNonInteractive(mockConfig, 'Initial fail', 'prompt-id-4'); await runNonInteractive(mockConfig, 'Initial fail', 'prompt-id-4');
expect(consoleErrorSpy).toHaveBeenCalledWith( expect(consoleErrorSpy).toHaveBeenCalledWith(
'[API Error: API connection failed]', '[API Error: API connection failed]',
); );
expect(processExitSpy).toHaveBeenCalledWith(1);
}); });
it('should not exit if a tool is not found, and should send error back to model', async () => { it('should not exit if a tool is not found, and should send error back to model', async () => {
const functionCall: FunctionCall = { const toolCallEvent: ServerGeminiStreamEvent = {
id: 'fcNotFound', type: GeminiEventType.ToolCallRequest,
name: 'nonexistentTool', value: {
args: {}, callId: 'tool-1',
};
const errorResponsePart: Part = {
functionResponse: {
name: 'nonexistentTool', name: 'nonexistentTool',
id: 'fcNotFound', args: {},
response: { error: 'Tool "nonexistentTool" not found in registry.' }, isClientInitiated: false,
prompt_id: 'prompt-id-5',
}, },
}; };
mockCoreExecuteToolCall.mockResolvedValue({
const { executeToolCall: mockCoreExecuteToolCall } = await import(
'@google/gemini-cli-core'
);
vi.mocked(mockCoreExecuteToolCall).mockResolvedValue({
callId: 'fcNotFound',
responseParts: [errorResponsePart],
resultDisplay: 'Tool "nonexistentTool" not found in registry.',
error: new Error('Tool "nonexistentTool" not found in registry.'), error: new Error('Tool "nonexistentTool" not found in registry.'),
resultDisplay: 'Tool "nonexistentTool" not found in registry.',
}); });
const finalResponse: ServerGeminiStreamEvent[] = [
{
type: GeminiEventType.Content,
value: "Sorry, I can't find that tool.",
},
];
const stream1 = (async function* () { mockGeminiClient.sendMessageStream
yield { functionCalls: [functionCall] } as GenerateContentResponse; .mockReturnValueOnce(createStreamFromEvents([toolCallEvent]))
})(); .mockReturnValueOnce(createStreamFromEvents(finalResponse));
const stream2 = (async function* () {
yield {
candidates: [
{
content: {
parts: [{ text: 'Unfortunately the tool does not exist.' }],
},
},
],
} as GenerateContentResponse;
})();
mockChat.sendMessageStream
.mockResolvedValueOnce(stream1)
.mockResolvedValueOnce(stream2);
const consoleErrorSpy = vi
.spyOn(console, 'error')
.mockImplementation(() => {});
await runNonInteractive( await runNonInteractive(
mockConfig, mockConfig,
@ -277,68 +221,22 @@ describe('runNonInteractive', () => {
'prompt-id-5', 'prompt-id-5',
); );
expect(mockCoreExecuteToolCall).toHaveBeenCalled();
expect(consoleErrorSpy).toHaveBeenCalledWith( expect(consoleErrorSpy).toHaveBeenCalledWith(
'Error executing tool nonexistentTool: Tool "nonexistentTool" not found in registry.', 'Error executing tool nonexistentTool: Tool "nonexistentTool" not found in registry.',
); );
expect(processExitSpy).not.toHaveBeenCalled();
expect(mockProcessExit).not.toHaveBeenCalled(); expect(mockGeminiClient.sendMessageStream).toHaveBeenCalledTimes(2);
expect(processStdoutSpy).toHaveBeenCalledWith(
expect(mockChat.sendMessageStream).toHaveBeenCalledTimes(2); "Sorry, I can't find that tool.",
expect(mockChat.sendMessageStream).toHaveBeenLastCalledWith(
expect.objectContaining({
message: [errorResponsePart],
}),
expect.any(String),
);
expect(mockProcessStdoutWrite).toHaveBeenCalledWith(
'Unfortunately the tool does not exist.',
); );
}); });
it('should exit when max session turns are exceeded', async () => { it('should exit when max session turns are exceeded', async () => {
const functionCall: FunctionCall = { vi.mocked(mockConfig.getMaxSessionTurns).mockReturnValue(0);
id: 'fcLoop', await runNonInteractive(mockConfig, 'Trigger loop', 'prompt-id-6');
name: 'loopTool',
args: {},
};
const toolResponsePart: Part = {
functionResponse: {
name: 'loopTool',
id: 'fcLoop',
response: { result: 'still looping' },
},
};
// Config with a max turn of 1
vi.mocked(mockConfig.getMaxSessionTurns).mockReturnValue(1);
const { executeToolCall: mockCoreExecuteToolCall } = await import(
'@google/gemini-cli-core'
);
vi.mocked(mockCoreExecuteToolCall).mockResolvedValue({
callId: 'fcLoop',
responseParts: [toolResponsePart],
resultDisplay: 'Still looping',
error: undefined,
});
const stream = (async function* () {
yield { functionCalls: [functionCall] } as GenerateContentResponse;
})();
mockChat.sendMessageStream.mockResolvedValue(stream);
const consoleErrorSpy = vi
.spyOn(console, 'error')
.mockImplementation(() => {});
await runNonInteractive(mockConfig, 'Trigger loop');
expect(mockChat.sendMessageStream).toHaveBeenCalledTimes(1);
expect(consoleErrorSpy).toHaveBeenCalledWith( expect(consoleErrorSpy).toHaveBeenCalledWith(
` '\n Reached max session turns for this session. Increase the number of turns by specifying maxSessionTurns in settings.json.',
Reached max session turns for this session. Increase the number of turns by specifying maxSessionTurns in settings.json.`,
); );
expect(mockProcessExit).not.toHaveBeenCalled();
}); });
}); });

View File

@ -11,38 +11,12 @@ import {
ToolRegistry, ToolRegistry,
shutdownTelemetry, shutdownTelemetry,
isTelemetrySdkInitialized, isTelemetrySdkInitialized,
GeminiEventType,
} from '@google/gemini-cli-core'; } from '@google/gemini-cli-core';
import { import { Content, Part, FunctionCall } from '@google/genai';
Content,
Part,
FunctionCall,
GenerateContentResponse,
} from '@google/genai';
import { parseAndFormatApiError } from './ui/utils/errorParsing.js'; import { parseAndFormatApiError } from './ui/utils/errorParsing.js';
function getResponseText(response: GenerateContentResponse): string | null {
if (response.candidates && response.candidates.length > 0) {
const candidate = response.candidates[0];
if (
candidate.content &&
candidate.content.parts &&
candidate.content.parts.length > 0
) {
// We are running in headless mode so we don't need to return thoughts to STDOUT.
const thoughtPart = candidate.content.parts[0];
if (thoughtPart?.thought) {
return null;
}
return candidate.content.parts
.filter((part) => part.text)
.map((part) => part.text)
.join('');
}
}
return null;
}
export async function runNonInteractive( export async function runNonInteractive(
config: Config, config: Config,
input: string, input: string,
@ -60,7 +34,6 @@ export async function runNonInteractive(
const geminiClient = config.getGeminiClient(); const geminiClient = config.getGeminiClient();
const toolRegistry: ToolRegistry = await config.getToolRegistry(); const toolRegistry: ToolRegistry = await config.getToolRegistry();
const chat = await geminiClient.getChat();
const abortController = new AbortController(); const abortController = new AbortController();
let currentMessages: Content[] = [{ role: 'user', parts: [{ text: input }] }]; let currentMessages: Content[] = [{ role: 'user', parts: [{ text: input }] }];
let turnCount = 0; let turnCount = 0;
@ -68,7 +41,7 @@ export async function runNonInteractive(
while (true) { while (true) {
turnCount++; turnCount++;
if ( if (
config.getMaxSessionTurns() > 0 && config.getMaxSessionTurns() >= 0 &&
turnCount > config.getMaxSessionTurns() turnCount > config.getMaxSessionTurns()
) { ) {
console.error( console.error(
@ -78,30 +51,28 @@ export async function runNonInteractive(
} }
const functionCalls: FunctionCall[] = []; const functionCalls: FunctionCall[] = [];
const responseStream = await chat.sendMessageStream( const responseStream = geminiClient.sendMessageStream(
{ currentMessages[0]?.parts || [],
message: currentMessages[0]?.parts || [], // Ensure parts are always provided abortController.signal,
config: {
abortSignal: abortController.signal,
tools: [
{ functionDeclarations: toolRegistry.getFunctionDeclarations() },
],
},
},
prompt_id, prompt_id,
); );
for await (const resp of responseStream) { for await (const event of responseStream) {
if (abortController.signal.aborted) { if (abortController.signal.aborted) {
console.error('Operation cancelled.'); console.error('Operation cancelled.');
return; return;
} }
const textPart = getResponseText(resp);
if (textPart) { if (event.type === GeminiEventType.Content) {
process.stdout.write(textPart); process.stdout.write(event.value);
} } else if (event.type === GeminiEventType.ToolCallRequest) {
if (resp.functionCalls) { const toolCallRequest = event.value;
functionCalls.push(...resp.functionCalls); const fc: FunctionCall = {
name: toolCallRequest.name,
args: toolCallRequest.args,
id: toolCallRequest.callId,
};
functionCalls.push(fc);
} }
} }