From 65be9cab478bbccd7a7f3937c11edd88dac1feb9 Mon Sep 17 00:00:00 2001 From: anj-s <32556631+anj-s@users.noreply.github.com> Date: Thu, 31 Jul 2025 05:36:12 -0700 Subject: [PATCH] Fix: Ensure that non interactive mode and interactive mode are calling the same entry points (#5137) --- packages/cli/src/nonInteractiveCli.test.ts | 372 ++++++++------------- packages/cli/src/nonInteractiveCli.ts | 65 +--- 2 files changed, 153 insertions(+), 284 deletions(-) diff --git a/packages/cli/src/nonInteractiveCli.test.ts b/packages/cli/src/nonInteractiveCli.test.ts index 8b0419f1..a0fc6f9f 100644 --- a/packages/cli/src/nonInteractiveCli.test.ts +++ b/packages/cli/src/nonInteractiveCli.test.ts @@ -4,196 +4,167 @@ * SPDX-License-Identifier: Apache-2.0 */ -/* eslint-disable @typescript-eslint/no-explicit-any */ -import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'; +import { + Config, + executeToolCall, + ToolRegistry, + shutdownTelemetry, + GeminiEventType, + ServerGeminiStreamEvent, +} from '@google/gemini-cli-core'; +import { Part } from '@google/genai'; import { runNonInteractive } from './nonInteractiveCli.js'; -import { Config, GeminiClient, ToolRegistry } from '@google/gemini-cli-core'; -import { GenerateContentResponse, Part, FunctionCall } from '@google/genai'; +import { vi } from 'vitest'; -// Mock dependencies -vi.mock('@google/gemini-cli-core', async () => { - const actualCore = await vi.importActual< - typeof import('@google/gemini-cli-core') - >('@google/gemini-cli-core'); +// Mock core modules +vi.mock('@google/gemini-cli-core', async (importOriginal) => { + const original = + await importOriginal(); return { - ...actualCore, - GeminiClient: vi.fn(), - ToolRegistry: vi.fn(), + ...original, executeToolCall: vi.fn(), + shutdownTelemetry: vi.fn(), + isTelemetrySdkInitialized: vi.fn().mockReturnValue(true), }; }); describe('runNonInteractive', () => { let mockConfig: Config; - let mockGeminiClient: GeminiClient; let mockToolRegistry: ToolRegistry; - let mockChat: { - sendMessageStream: ReturnType; + let mockCoreExecuteToolCall: vi.Mock; + let mockShutdownTelemetry: vi.Mock; + let consoleErrorSpy: vi.SpyInstance; + let processExitSpy: vi.SpyInstance; + let processStdoutSpy: vi.SpyInstance; + let mockGeminiClient: { + sendMessageStream: vi.Mock; }; - let mockProcessStdoutWrite: ReturnType; - let mockProcessExit: ReturnType; beforeEach(() => { - vi.resetAllMocks(); - mockChat = { - sendMessageStream: vi.fn(), - }; - mockGeminiClient = { - getChat: vi.fn().mockResolvedValue(mockChat), - } as unknown as GeminiClient; + mockCoreExecuteToolCall = vi.mocked(executeToolCall); + mockShutdownTelemetry = vi.mocked(shutdownTelemetry); + + consoleErrorSpy = vi.spyOn(console, 'error').mockImplementation(() => {}); + processExitSpy = vi + .spyOn(process, 'exit') + .mockImplementation((() => {}) as (code?: number) => never); + processStdoutSpy = vi + .spyOn(process.stdout, 'write') + .mockImplementation(() => true); + mockToolRegistry = { - getFunctionDeclarations: vi.fn().mockReturnValue([]), getTool: vi.fn(), + getFunctionDeclarations: vi.fn().mockReturnValue([]), } as unknown as ToolRegistry; - vi.mocked(GeminiClient).mockImplementation(() => mockGeminiClient); - vi.mocked(ToolRegistry).mockImplementation(() => mockToolRegistry); + mockGeminiClient = { + sendMessageStream: vi.fn(), + }; mockConfig = { - getToolRegistry: vi.fn().mockReturnValue(mockToolRegistry), + initialize: vi.fn().mockResolvedValue(undefined), getGeminiClient: vi.fn().mockReturnValue(mockGeminiClient), - getContentGeneratorConfig: vi.fn().mockReturnValue({}), + getToolRegistry: vi.fn().mockResolvedValue(mockToolRegistry), getMaxSessionTurns: vi.fn().mockReturnValue(10), - initialize: vi.fn(), + getIdeMode: vi.fn().mockReturnValue(false), + getFullContext: vi.fn().mockReturnValue(false), + getContentGeneratorConfig: vi.fn().mockReturnValue({}), } as unknown as Config; - - mockProcessStdoutWrite = vi.fn().mockImplementation(() => true); - process.stdout.write = mockProcessStdoutWrite as any; // Use any to bypass strict signature matching for mock - mockProcessExit = vi - .fn() - .mockImplementation((_code?: number) => undefined as never); - process.exit = mockProcessExit as any; // Use any for process.exit mock }); afterEach(() => { vi.restoreAllMocks(); - // Restore original process methods if they were globally patched - // This might require storing the original methods before patching them in beforeEach }); + async function* createStreamFromEvents( + events: ServerGeminiStreamEvent[], + ): AsyncGenerator { + for (const event of events) { + yield event; + } + } + it('should process input and write text output', async () => { - const inputStream = (async function* () { - yield { - candidates: [{ content: { parts: [{ text: 'Hello' }] } }], - } as GenerateContentResponse; - yield { - candidates: [{ content: { parts: [{ text: ' World' }] } }], - } as GenerateContentResponse; - })(); - mockChat.sendMessageStream.mockResolvedValue(inputStream); + const events: ServerGeminiStreamEvent[] = [ + { type: GeminiEventType.Content, value: 'Hello' }, + { type: GeminiEventType.Content, value: ' World' }, + ]; + mockGeminiClient.sendMessageStream.mockReturnValue( + createStreamFromEvents(events), + ); await runNonInteractive(mockConfig, 'Test input', 'prompt-id-1'); - expect(mockChat.sendMessageStream).toHaveBeenCalledWith( - { - message: [{ text: 'Test input' }], - config: { - abortSignal: expect.any(AbortSignal), - tools: [{ functionDeclarations: [] }], - }, - }, - expect.any(String), + expect(mockGeminiClient.sendMessageStream).toHaveBeenCalledWith( + [{ text: 'Test input' }], + expect.any(AbortSignal), + 'prompt-id-1', ); - expect(mockProcessStdoutWrite).toHaveBeenCalledWith('Hello'); - expect(mockProcessStdoutWrite).toHaveBeenCalledWith(' World'); - expect(mockProcessStdoutWrite).toHaveBeenCalledWith('\n'); + expect(processStdoutSpy).toHaveBeenCalledWith('Hello'); + expect(processStdoutSpy).toHaveBeenCalledWith(' World'); + expect(processStdoutSpy).toHaveBeenCalledWith('\n'); + expect(mockShutdownTelemetry).toHaveBeenCalled(); }); it('should handle a single tool call and respond', async () => { - const functionCall: FunctionCall = { - id: 'fc1', - name: 'testTool', - args: { p: 'v' }, - }; - const toolResponsePart: Part = { - functionResponse: { + const toolCallEvent: ServerGeminiStreamEvent = { + type: GeminiEventType.ToolCallRequest, + value: { + callId: 'tool-1', name: 'testTool', - id: 'fc1', - response: { result: 'tool success' }, + args: { arg1: 'value1' }, + isClientInitiated: false, + prompt_id: 'prompt-id-2', }, }; + const toolResponse: Part[] = [{ text: 'Tool response' }]; + mockCoreExecuteToolCall.mockResolvedValue({ responseParts: toolResponse }); - const { executeToolCall: mockCoreExecuteToolCall } = await import( - '@google/gemini-cli-core' - ); - vi.mocked(mockCoreExecuteToolCall).mockResolvedValue({ - callId: 'fc1', - responseParts: [toolResponsePart], - resultDisplay: 'Tool success display', - error: undefined, - }); + const firstCallEvents: ServerGeminiStreamEvent[] = [toolCallEvent]; + const secondCallEvents: ServerGeminiStreamEvent[] = [ + { type: GeminiEventType.Content, value: 'Final answer' }, + ]; - const stream1 = (async function* () { - yield { functionCalls: [functionCall] } as GenerateContentResponse; - })(); - const stream2 = (async function* () { - yield { - candidates: [{ content: { parts: [{ text: 'Final answer' }] } }], - } as GenerateContentResponse; - })(); - mockChat.sendMessageStream - .mockResolvedValueOnce(stream1) - .mockResolvedValueOnce(stream2); + mockGeminiClient.sendMessageStream + .mockReturnValueOnce(createStreamFromEvents(firstCallEvents)) + .mockReturnValueOnce(createStreamFromEvents(secondCallEvents)); await runNonInteractive(mockConfig, 'Use a tool', 'prompt-id-2'); - expect(mockChat.sendMessageStream).toHaveBeenCalledTimes(2); + expect(mockGeminiClient.sendMessageStream).toHaveBeenCalledTimes(2); expect(mockCoreExecuteToolCall).toHaveBeenCalledWith( mockConfig, - expect.objectContaining({ callId: 'fc1', name: 'testTool' }), + expect.objectContaining({ name: 'testTool' }), mockToolRegistry, expect.any(AbortSignal), ); - expect(mockChat.sendMessageStream).toHaveBeenLastCalledWith( - expect.objectContaining({ - message: [toolResponsePart], - }), - expect.any(String), + expect(mockGeminiClient.sendMessageStream).toHaveBeenNthCalledWith( + 2, + [{ text: 'Tool response' }], + expect.any(AbortSignal), + 'prompt-id-2', ); - expect(mockProcessStdoutWrite).toHaveBeenCalledWith('Final answer'); + expect(processStdoutSpy).toHaveBeenCalledWith('Final answer'); + expect(processStdoutSpy).toHaveBeenCalledWith('\n'); }); it('should handle error during tool execution', async () => { - const functionCall: FunctionCall = { - id: 'fcError', - name: 'errorTool', - args: {}, - }; - const errorResponsePart: Part = { - functionResponse: { + const toolCallEvent: ServerGeminiStreamEvent = { + type: GeminiEventType.ToolCallRequest, + value: { + callId: 'tool-1', name: 'errorTool', - id: 'fcError', - response: { error: 'Tool failed' }, + args: {}, + isClientInitiated: false, + prompt_id: 'prompt-id-3', }, }; - - const { executeToolCall: mockCoreExecuteToolCall } = await import( - '@google/gemini-cli-core' - ); - vi.mocked(mockCoreExecuteToolCall).mockResolvedValue({ - callId: 'fcError', - responseParts: [errorResponsePart], - resultDisplay: 'Tool execution failed badly', - error: new Error('Tool failed'), + mockCoreExecuteToolCall.mockResolvedValue({ + error: new Error('Tool execution failed badly'), }); - - const stream1 = (async function* () { - yield { functionCalls: [functionCall] } as GenerateContentResponse; - })(); - - const stream2 = (async function* () { - yield { - candidates: [ - { content: { parts: [{ text: 'Could not complete request.' }] } }, - ], - } as GenerateContentResponse; - })(); - mockChat.sendMessageStream - .mockResolvedValueOnce(stream1) - .mockResolvedValueOnce(stream2); - const consoleErrorSpy = vi - .spyOn(console, 'error') - .mockImplementation(() => {}); + mockGeminiClient.sendMessageStream.mockReturnValue( + createStreamFromEvents([toolCallEvent]), + ); await runNonInteractive(mockConfig, 'Trigger tool error', 'prompt-id-3'); @@ -201,75 +172,48 @@ describe('runNonInteractive', () => { expect(consoleErrorSpy).toHaveBeenCalledWith( 'Error executing tool errorTool: Tool execution failed badly', ); - expect(mockChat.sendMessageStream).toHaveBeenLastCalledWith( - expect.objectContaining({ - message: [errorResponsePart], - }), - expect.any(String), - ); - expect(mockProcessStdoutWrite).toHaveBeenCalledWith( - 'Could not complete request.', - ); + expect(processExitSpy).toHaveBeenCalledWith(1); }); it('should exit with error if sendMessageStream throws initially', async () => { const apiError = new Error('API connection failed'); - mockChat.sendMessageStream.mockRejectedValue(apiError); - const consoleErrorSpy = vi - .spyOn(console, 'error') - .mockImplementation(() => {}); + mockGeminiClient.sendMessageStream.mockImplementation(() => { + throw apiError; + }); await runNonInteractive(mockConfig, 'Initial fail', 'prompt-id-4'); expect(consoleErrorSpy).toHaveBeenCalledWith( '[API Error: API connection failed]', ); + expect(processExitSpy).toHaveBeenCalledWith(1); }); it('should not exit if a tool is not found, and should send error back to model', async () => { - const functionCall: FunctionCall = { - id: 'fcNotFound', - name: 'nonexistentTool', - args: {}, - }; - const errorResponsePart: Part = { - functionResponse: { + const toolCallEvent: ServerGeminiStreamEvent = { + type: GeminiEventType.ToolCallRequest, + value: { + callId: 'tool-1', name: 'nonexistentTool', - id: 'fcNotFound', - response: { error: 'Tool "nonexistentTool" not found in registry.' }, + args: {}, + isClientInitiated: false, + prompt_id: 'prompt-id-5', }, }; - - const { executeToolCall: mockCoreExecuteToolCall } = await import( - '@google/gemini-cli-core' - ); - vi.mocked(mockCoreExecuteToolCall).mockResolvedValue({ - callId: 'fcNotFound', - responseParts: [errorResponsePart], - resultDisplay: 'Tool "nonexistentTool" not found in registry.', + mockCoreExecuteToolCall.mockResolvedValue({ error: new Error('Tool "nonexistentTool" not found in registry.'), + resultDisplay: 'Tool "nonexistentTool" not found in registry.', }); + const finalResponse: ServerGeminiStreamEvent[] = [ + { + type: GeminiEventType.Content, + value: "Sorry, I can't find that tool.", + }, + ]; - const stream1 = (async function* () { - yield { functionCalls: [functionCall] } as GenerateContentResponse; - })(); - const stream2 = (async function* () { - yield { - candidates: [ - { - content: { - parts: [{ text: 'Unfortunately the tool does not exist.' }], - }, - }, - ], - } as GenerateContentResponse; - })(); - mockChat.sendMessageStream - .mockResolvedValueOnce(stream1) - .mockResolvedValueOnce(stream2); - const consoleErrorSpy = vi - .spyOn(console, 'error') - .mockImplementation(() => {}); + mockGeminiClient.sendMessageStream + .mockReturnValueOnce(createStreamFromEvents([toolCallEvent])) + .mockReturnValueOnce(createStreamFromEvents(finalResponse)); await runNonInteractive( mockConfig, @@ -277,68 +221,22 @@ describe('runNonInteractive', () => { 'prompt-id-5', ); + expect(mockCoreExecuteToolCall).toHaveBeenCalled(); expect(consoleErrorSpy).toHaveBeenCalledWith( 'Error executing tool nonexistentTool: Tool "nonexistentTool" not found in registry.', ); - - expect(mockProcessExit).not.toHaveBeenCalled(); - - expect(mockChat.sendMessageStream).toHaveBeenCalledTimes(2); - expect(mockChat.sendMessageStream).toHaveBeenLastCalledWith( - expect.objectContaining({ - message: [errorResponsePart], - }), - expect.any(String), - ); - - expect(mockProcessStdoutWrite).toHaveBeenCalledWith( - 'Unfortunately the tool does not exist.', + expect(processExitSpy).not.toHaveBeenCalled(); + expect(mockGeminiClient.sendMessageStream).toHaveBeenCalledTimes(2); + expect(processStdoutSpy).toHaveBeenCalledWith( + "Sorry, I can't find that tool.", ); }); it('should exit when max session turns are exceeded', async () => { - const functionCall: FunctionCall = { - id: 'fcLoop', - name: 'loopTool', - args: {}, - }; - const toolResponsePart: Part = { - functionResponse: { - name: 'loopTool', - id: 'fcLoop', - response: { result: 'still looping' }, - }, - }; - - // Config with a max turn of 1 - vi.mocked(mockConfig.getMaxSessionTurns).mockReturnValue(1); - - const { executeToolCall: mockCoreExecuteToolCall } = await import( - '@google/gemini-cli-core' - ); - vi.mocked(mockCoreExecuteToolCall).mockResolvedValue({ - callId: 'fcLoop', - responseParts: [toolResponsePart], - resultDisplay: 'Still looping', - error: undefined, - }); - - const stream = (async function* () { - yield { functionCalls: [functionCall] } as GenerateContentResponse; - })(); - - mockChat.sendMessageStream.mockResolvedValue(stream); - const consoleErrorSpy = vi - .spyOn(console, 'error') - .mockImplementation(() => {}); - - await runNonInteractive(mockConfig, 'Trigger loop'); - - expect(mockChat.sendMessageStream).toHaveBeenCalledTimes(1); + vi.mocked(mockConfig.getMaxSessionTurns).mockReturnValue(0); + await runNonInteractive(mockConfig, 'Trigger loop', 'prompt-id-6'); expect(consoleErrorSpy).toHaveBeenCalledWith( - ` - Reached max session turns for this session. Increase the number of turns by specifying maxSessionTurns in settings.json.`, + '\n Reached max session turns for this session. Increase the number of turns by specifying maxSessionTurns in settings.json.', ); - expect(mockProcessExit).not.toHaveBeenCalled(); }); }); diff --git a/packages/cli/src/nonInteractiveCli.ts b/packages/cli/src/nonInteractiveCli.ts index 7bc0f6aa..1d0a7f3d 100644 --- a/packages/cli/src/nonInteractiveCli.ts +++ b/packages/cli/src/nonInteractiveCli.ts @@ -11,38 +11,12 @@ import { ToolRegistry, shutdownTelemetry, isTelemetrySdkInitialized, + GeminiEventType, } from '@google/gemini-cli-core'; -import { - Content, - Part, - FunctionCall, - GenerateContentResponse, -} from '@google/genai'; +import { Content, Part, FunctionCall } from '@google/genai'; import { parseAndFormatApiError } from './ui/utils/errorParsing.js'; -function getResponseText(response: GenerateContentResponse): string | null { - if (response.candidates && response.candidates.length > 0) { - const candidate = response.candidates[0]; - if ( - candidate.content && - candidate.content.parts && - candidate.content.parts.length > 0 - ) { - // We are running in headless mode so we don't need to return thoughts to STDOUT. - const thoughtPart = candidate.content.parts[0]; - if (thoughtPart?.thought) { - return null; - } - return candidate.content.parts - .filter((part) => part.text) - .map((part) => part.text) - .join(''); - } - } - return null; -} - export async function runNonInteractive( config: Config, input: string, @@ -60,7 +34,6 @@ export async function runNonInteractive( const geminiClient = config.getGeminiClient(); const toolRegistry: ToolRegistry = await config.getToolRegistry(); - const chat = await geminiClient.getChat(); const abortController = new AbortController(); let currentMessages: Content[] = [{ role: 'user', parts: [{ text: input }] }]; let turnCount = 0; @@ -68,7 +41,7 @@ export async function runNonInteractive( while (true) { turnCount++; if ( - config.getMaxSessionTurns() > 0 && + config.getMaxSessionTurns() >= 0 && turnCount > config.getMaxSessionTurns() ) { console.error( @@ -78,30 +51,28 @@ export async function runNonInteractive( } const functionCalls: FunctionCall[] = []; - const responseStream = await chat.sendMessageStream( - { - message: currentMessages[0]?.parts || [], // Ensure parts are always provided - config: { - abortSignal: abortController.signal, - tools: [ - { functionDeclarations: toolRegistry.getFunctionDeclarations() }, - ], - }, - }, + const responseStream = geminiClient.sendMessageStream( + currentMessages[0]?.parts || [], + abortController.signal, prompt_id, ); - for await (const resp of responseStream) { + for await (const event of responseStream) { if (abortController.signal.aborted) { console.error('Operation cancelled.'); return; } - const textPart = getResponseText(resp); - if (textPart) { - process.stdout.write(textPart); - } - if (resp.functionCalls) { - functionCalls.push(...resp.functionCalls); + + if (event.type === GeminiEventType.Content) { + process.stdout.write(event.value); + } else if (event.type === GeminiEventType.ToolCallRequest) { + const toolCallRequest = event.value; + const fc: FunctionCall = { + name: toolCallRequest.name, + args: toolCallRequest.args, + id: toolCallRequest.callId, + }; + functionCalls.push(fc); } }