From 2828fc6d66cd7a74db231143183bd7c44e55148d Mon Sep 17 00:00:00 2001 From: "N. Taylor Mullen" Date: Sun, 1 Jun 2025 16:11:37 -0700 Subject: [PATCH] feat: Implement non-interactive mode for CLI (#675) --- packages/cli/src/gemini.tsx | 66 +++-- packages/cli/src/nonInteractiveCli.test.ts | 224 +++++++++++++++++ packages/cli/src/nonInteractiveCli.ts | 114 +++++++++ .../core/nonInteractiveToolExecutor.test.ts | 235 ++++++++++++++++++ .../src/core/nonInteractiveToolExecutor.ts | 91 +++++++ packages/core/src/index.ts | 4 + 6 files changed, 710 insertions(+), 24 deletions(-) create mode 100644 packages/cli/src/nonInteractiveCli.test.ts create mode 100644 packages/cli/src/nonInteractiveCli.ts create mode 100644 packages/core/src/core/nonInteractiveToolExecutor.test.ts create mode 100644 packages/core/src/core/nonInteractiveToolExecutor.ts diff --git a/packages/cli/src/gemini.tsx b/packages/cli/src/gemini.tsx index 0ed27a99..07551813 100644 --- a/packages/cli/src/gemini.tsx +++ b/packages/cli/src/gemini.tsx @@ -9,7 +9,6 @@ import { render } from 'ink'; import { App } from './ui/App.js'; import { loadCliConfig } from './config/config.js'; import { readStdin } from './utils/readStdin.js'; -import { GeminiClient } from '@gemini-code/core'; import { readPackageUp } from 'read-package-up'; import { fileURLToPath } from 'node:url'; import { dirname } from 'node:path'; @@ -17,14 +16,25 @@ import { sandbox_command, start_sandbox } from './utils/sandbox.js'; import { loadSettings } from './config/settings.js'; import { themeManager } from './ui/themes/theme-manager.js'; import { getStartupWarnings } from './utils/startupWarnings.js'; +import { runNonInteractive } from './nonInteractiveCli.js'; +import { + EditTool, + GlobTool, + GrepTool, + LSTool, + MemoryTool, + ReadFileTool, + ReadManyFilesTool, + ShellTool, + WebFetchTool, + WebSearchTool, + WriteFileTool, +} from '@gemini-code/core'; const __filename = fileURLToPath(import.meta.url); const __dirname = dirname(__filename); async function main() { - const settings = loadSettings(process.cwd()); - const config = await loadCliConfig(settings.merged); - // warn about deprecated environment variables if (process.env.GEMINI_CODE_MODEL) { console.warn('GEMINI_CODE_MODEL is deprecated. Use GEMINI_MODEL instead.'); @@ -43,6 +53,9 @@ async function main() { process.env.GEMINI_SANDBOX_IMAGE = process.env.GEMINI_CODE_SANDBOX_IMAGE; } + const settings = loadSettings(process.cwd()); + const config = await loadCliConfig(settings.merged); + if (settings.merged.theme) { if (!themeManager.setActiveTheme(settings.merged.theme)) { // If the theme is not found during initial load, log a warning and continue. @@ -92,26 +105,31 @@ async function main() { process.exit(1); } - // If not a TTY and we have initial input, process it directly - const geminiClient = new GeminiClient(config); - const chat = await geminiClient.startChat(); - try { - for await (const event of geminiClient.sendMessageStream( - chat, - [{ text: input }], - new AbortController().signal, - )) { - if (event.type === 'content') { - process.stdout.write(event.value); - } - // We might need to handle other event types later, but for now, just content. - } - process.stdout.write('\n'); // Add a newline at the end - process.exit(0); - } catch (error) { - console.error('Error processing piped input:', error); - process.exit(1); - } + // Non-interactive mode handled by runNonInteractive + let existingCoreTools = config.getCoreTools(); + existingCoreTools = existingCoreTools || [ + ReadFileTool.Name, + LSTool.Name, + GrepTool.Name, + GlobTool.Name, + EditTool.Name, + WriteFileTool.Name, + WebFetchTool.Name, + WebSearchTool.Name, + ReadManyFilesTool.Name, + ShellTool.Name, + MemoryTool.Name, + ]; + const interactiveTools = [ShellTool.Name, EditTool.Name, WriteFileTool.Name]; + const nonInteractiveTools = existingCoreTools.filter( + (tool) => !interactiveTools.includes(tool), + ); + const nonInteractiveSettings = { + ...settings.merged, + coreTools: nonInteractiveTools, + }; + const nonInteractiveConfig = await loadCliConfig(nonInteractiveSettings); + await runNonInteractive(nonInteractiveConfig, input); } // --- Global Unhandled Rejection Handler --- diff --git a/packages/cli/src/nonInteractiveCli.test.ts b/packages/cli/src/nonInteractiveCli.test.ts new file mode 100644 index 00000000..dca3b855 --- /dev/null +++ b/packages/cli/src/nonInteractiveCli.test.ts @@ -0,0 +1,224 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +/* eslint-disable @typescript-eslint/no-explicit-any */ +import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'; +import { runNonInteractive } from './nonInteractiveCli.js'; +import { Config, GeminiClient, ToolRegistry } from '@gemini-code/core'; +import { GenerateContentResponse, Part, FunctionCall } from '@google/genai'; + +// Mock dependencies +vi.mock('@gemini-code/core', async () => { + const actualCore = + await vi.importActual( + '@gemini-code/core', + ); + return { + ...actualCore, + GeminiClient: vi.fn(), + ToolRegistry: vi.fn(), + executeToolCall: vi.fn(), + }; +}); + +describe('runNonInteractive', () => { + let mockConfig: Config; + let mockGeminiClient: GeminiClient; + let mockToolRegistry: ToolRegistry; + let mockChat: { + sendMessageStream: ReturnType; + }; + let mockProcessStdoutWrite: ReturnType; + let mockProcessExit: ReturnType; + + beforeEach(() => { + mockChat = { + sendMessageStream: vi.fn(), + }; + mockGeminiClient = { + startChat: vi.fn().mockResolvedValue(mockChat), + } as unknown as GeminiClient; + mockToolRegistry = { + discoverTools: vi.fn().mockResolvedValue(undefined), + getFunctionDeclarations: vi.fn().mockReturnValue([]), + getTool: vi.fn(), + } as unknown as ToolRegistry; + + vi.mocked(GeminiClient).mockImplementation(() => mockGeminiClient); + vi.mocked(ToolRegistry).mockImplementation(() => mockToolRegistry); + + mockConfig = { + getToolRegistry: vi.fn().mockReturnValue(mockToolRegistry), + } as unknown as Config; + + mockProcessStdoutWrite = vi.fn().mockImplementation(() => true); + process.stdout.write = mockProcessStdoutWrite as any; // Use any to bypass strict signature matching for mock + mockProcessExit = vi + .fn() + .mockImplementation((_code?: number) => undefined as never); + process.exit = mockProcessExit as any; // Use any for process.exit mock + }); + + afterEach(() => { + vi.restoreAllMocks(); + // Restore original process methods if they were globally patched + // This might require storing the original methods before patching them in beforeEach + }); + + it('should process input and write text output', async () => { + const inputStream = (async function* () { + yield { + candidates: [{ content: { parts: [{ text: 'Hello' }] } }], + } as GenerateContentResponse; + yield { + candidates: [{ content: { parts: [{ text: ' World' }] } }], + } as GenerateContentResponse; + })(); + mockChat.sendMessageStream.mockResolvedValue(inputStream); + + await runNonInteractive(mockConfig, 'Test input'); + + expect(mockGeminiClient.startChat).toHaveBeenCalled(); + expect(mockToolRegistry.discoverTools).toHaveBeenCalled(); + expect(mockChat.sendMessageStream).toHaveBeenCalledWith({ + message: [{ text: 'Test input' }], + config: { + abortSignal: expect.any(AbortSignal), + tools: [{ functionDeclarations: [] }], + }, + }); + expect(mockProcessStdoutWrite).toHaveBeenCalledWith('Hello'); + expect(mockProcessStdoutWrite).toHaveBeenCalledWith(' World'); + expect(mockProcessStdoutWrite).toHaveBeenCalledWith('\n'); + }); + + it('should handle a single tool call and respond', async () => { + const functionCall: FunctionCall = { + id: 'fc1', + name: 'testTool', + args: { p: 'v' }, + }; + const toolResponsePart: Part = { + functionResponse: { + name: 'testTool', + id: 'fc1', + response: { result: 'tool success' }, + }, + }; + + const { executeToolCall: mockCoreExecuteToolCall } = await import( + '@gemini-code/core' + ); + vi.mocked(mockCoreExecuteToolCall).mockResolvedValue({ + callId: 'fc1', + responseParts: [toolResponsePart], + resultDisplay: 'Tool success display', + error: undefined, + }); + + const stream1 = (async function* () { + yield { functionCalls: [functionCall] } as GenerateContentResponse; + })(); + const stream2 = (async function* () { + yield { + candidates: [{ content: { parts: [{ text: 'Final answer' }] } }], + } as GenerateContentResponse; + })(); + mockChat.sendMessageStream + .mockResolvedValueOnce(stream1) + .mockResolvedValueOnce(stream2); + + await runNonInteractive(mockConfig, 'Use a tool'); + + expect(mockChat.sendMessageStream).toHaveBeenCalledTimes(2); + expect(mockCoreExecuteToolCall).toHaveBeenCalledWith( + expect.objectContaining({ callId: 'fc1', name: 'testTool' }), + mockToolRegistry, + expect.any(AbortSignal), + ); + expect(mockChat.sendMessageStream).toHaveBeenLastCalledWith( + expect.objectContaining({ + message: [toolResponsePart], + }), + ); + expect(mockProcessStdoutWrite).toHaveBeenCalledWith('Final answer'); + }); + + it('should handle error during tool execution', async () => { + const functionCall: FunctionCall = { + id: 'fcError', + name: 'errorTool', + args: {}, + }; + const errorResponsePart: Part = { + functionResponse: { + name: 'errorTool', + id: 'fcError', + response: { error: 'Tool failed' }, + }, + }; + + const { executeToolCall: mockCoreExecuteToolCall } = await import( + '@gemini-code/core' + ); + vi.mocked(mockCoreExecuteToolCall).mockResolvedValue({ + callId: 'fcError', + responseParts: [errorResponsePart], + resultDisplay: 'Tool execution failed badly', + error: new Error('Tool failed'), + }); + + const stream1 = (async function* () { + yield { functionCalls: [functionCall] } as GenerateContentResponse; + })(); + + const stream2 = (async function* () { + yield { + candidates: [ + { content: { parts: [{ text: 'Could not complete request.' }] } }, + ], + } as GenerateContentResponse; + })(); + mockChat.sendMessageStream + .mockResolvedValueOnce(stream1) + .mockResolvedValueOnce(stream2); + const consoleErrorSpy = vi + .spyOn(console, 'error') + .mockImplementation(() => {}); + + await runNonInteractive(mockConfig, 'Trigger tool error'); + + expect(mockCoreExecuteToolCall).toHaveBeenCalled(); + expect(consoleErrorSpy).toHaveBeenCalledWith( + 'Error executing tool errorTool: Tool execution failed badly', + ); + expect(mockChat.sendMessageStream).toHaveBeenLastCalledWith( + expect.objectContaining({ + message: [errorResponsePart], + }), + ); + expect(mockProcessStdoutWrite).toHaveBeenCalledWith( + 'Could not complete request.', + ); + consoleErrorSpy.mockRestore(); + }); + + it('should exit with error if sendMessageStream throws initially', async () => { + const apiError = new Error('API connection failed'); + mockChat.sendMessageStream.mockRejectedValue(apiError); + const consoleErrorSpy = vi + .spyOn(console, 'error') + .mockImplementation(() => {}); + + await runNonInteractive(mockConfig, 'Initial fail'); + + expect(consoleErrorSpy).toHaveBeenCalledWith( + 'Error processing input:', + apiError, + ); + consoleErrorSpy.mockRestore(); + }); +}); diff --git a/packages/cli/src/nonInteractiveCli.ts b/packages/cli/src/nonInteractiveCli.ts new file mode 100644 index 00000000..9077ecbf --- /dev/null +++ b/packages/cli/src/nonInteractiveCli.ts @@ -0,0 +1,114 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { + Config, + GeminiClient, + ToolCallRequestInfo, + executeToolCall, + ToolRegistry, +} from '@gemini-code/core'; +import { + Content, + Part, + FunctionCall, + GenerateContentResponse, +} from '@google/genai'; + +function getResponseText(response: GenerateContentResponse): string | null { + if (response.candidates && response.candidates.length > 0) { + const candidate = response.candidates[0]; + if ( + candidate.content && + candidate.content.parts && + candidate.content.parts.length > 0 + ) { + return candidate.content.parts + .filter((part) => part.text) + .map((part) => part.text) + .join(''); + } + } + return null; +} + +export async function runNonInteractive( + config: Config, + input: string, +): Promise { + const geminiClient = new GeminiClient(config); + const toolRegistry: ToolRegistry = config.getToolRegistry(); + await toolRegistry.discoverTools(); + + const chat = await geminiClient.startChat(); + const abortController = new AbortController(); + let currentMessages: Content[] = [{ role: 'user', parts: [{ text: input }] }]; + + try { + while (true) { + const functionCalls: FunctionCall[] = []; + + const responseStream = await chat.sendMessageStream({ + message: currentMessages[0]?.parts || [], // Ensure parts are always provided + config: { + abortSignal: abortController.signal, + tools: [ + { functionDeclarations: toolRegistry.getFunctionDeclarations() }, + ], + }, + }); + + for await (const resp of responseStream) { + if (abortController.signal.aborted) { + console.error('Operation cancelled.'); + return; + } + const textPart = getResponseText(resp); + if (textPart) { + process.stdout.write(textPart); + } + if (resp.functionCalls) { + functionCalls.push(...resp.functionCalls); + } + } + + if (functionCalls.length > 0) { + const toolResponseParts: Part[] = []; + + for (const fc of functionCalls) { + const callId = fc.id ?? `${fc.name}-${Date.now()}`; + const requestInfo: ToolCallRequestInfo = { + callId, + name: fc.name as string, + args: (fc.args ?? {}) as Record, + }; + + const toolResponse = await executeToolCall( + requestInfo, + toolRegistry, + abortController.signal, + ); + + if (toolResponse.error) { + console.error( + `Error executing tool ${fc.name}: ${toolResponse.resultDisplay || toolResponse.error.message}`, + ); + toolResponseParts.push(...(toolResponse.responseParts as Part[])); + } else { + toolResponseParts.push(...(toolResponse.responseParts as Part[])); + } + } + currentMessages = [{ role: 'user', parts: toolResponseParts }]; + } else { + process.stdout.write('\n'); // Ensure a final newline + return; + } + } + } catch (error) { + console.error('Error processing input:', error); + process.exit(1); + } +} diff --git a/packages/core/src/core/nonInteractiveToolExecutor.test.ts b/packages/core/src/core/nonInteractiveToolExecutor.test.ts new file mode 100644 index 00000000..3d7dc1a2 --- /dev/null +++ b/packages/core/src/core/nonInteractiveToolExecutor.test.ts @@ -0,0 +1,235 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { describe, it, expect, vi, beforeEach } from 'vitest'; +import { executeToolCall } from './nonInteractiveToolExecutor.js'; +import { + ToolRegistry, + ToolCallRequestInfo, + ToolResult, + Tool, + ToolCallConfirmationDetails, +} from '../index.js'; +import { Part, Type } from '@google/genai'; + +describe('executeToolCall', () => { + let mockToolRegistry: ToolRegistry; + let mockTool: Tool; + let abortController: AbortController; + + beforeEach(() => { + mockTool = { + name: 'testTool', + displayName: 'Test Tool', + description: 'A tool for testing', + schema: { + name: 'testTool', + description: 'A tool for testing', + parameters: { + type: Type.OBJECT, + properties: { + param1: { type: Type.STRING }, + }, + required: ['param1'], + }, + }, + execute: vi.fn(), + validateToolParams: vi.fn(() => null), + shouldConfirmExecute: vi.fn(() => + Promise.resolve(false as false | ToolCallConfirmationDetails), + ), + isOutputMarkdown: false, + canUpdateOutput: false, + getDescription: vi.fn(), + }; + + mockToolRegistry = { + getTool: vi.fn(), + // Add other ToolRegistry methods if needed, or use a more complete mock + } as unknown as ToolRegistry; + + abortController = new AbortController(); + }); + + it('should execute a tool successfully', async () => { + const request: ToolCallRequestInfo = { + callId: 'call1', + name: 'testTool', + args: { param1: 'value1' }, + }; + const toolResult: ToolResult = { + llmContent: 'Tool executed successfully', + returnDisplay: 'Success!', + }; + vi.mocked(mockToolRegistry.getTool).mockReturnValue(mockTool); + vi.mocked(mockTool.execute).mockResolvedValue(toolResult); + + const response = await executeToolCall( + request, + mockToolRegistry, + abortController.signal, + ); + + expect(mockToolRegistry.getTool).toHaveBeenCalledWith('testTool'); + expect(mockTool.execute).toHaveBeenCalledWith( + request.args, + abortController.signal, + ); + expect(response.callId).toBe('call1'); + expect(response.error).toBeUndefined(); + expect(response.resultDisplay).toBe('Success!'); + expect(response.responseParts).toEqual([ + { + functionResponse: { + name: 'testTool', + id: 'call1', + response: { output: 'Tool executed successfully' }, + }, + }, + ]); + }); + + it('should return an error if tool is not found', async () => { + const request: ToolCallRequestInfo = { + callId: 'call2', + name: 'nonExistentTool', + args: {}, + }; + vi.mocked(mockToolRegistry.getTool).mockReturnValue(undefined); + + const response = await executeToolCall( + request, + mockToolRegistry, + abortController.signal, + ); + + expect(response.callId).toBe('call2'); + expect(response.error).toBeInstanceOf(Error); + expect(response.error?.message).toBe( + 'Tool "nonExistentTool" not found in registry.', + ); + expect(response.resultDisplay).toBe( + 'Tool "nonExistentTool" not found in registry.', + ); + expect(response.responseParts).toEqual([ + { + functionResponse: { + name: 'nonExistentTool', + id: 'call2', + response: { error: 'Tool "nonExistentTool" not found in registry.' }, + }, + }, + ]); + }); + + it('should return an error if tool execution fails', async () => { + const request: ToolCallRequestInfo = { + callId: 'call3', + name: 'testTool', + args: { param1: 'value1' }, + }; + const executionError = new Error('Tool execution failed'); + vi.mocked(mockToolRegistry.getTool).mockReturnValue(mockTool); + vi.mocked(mockTool.execute).mockRejectedValue(executionError); + + const response = await executeToolCall( + request, + mockToolRegistry, + abortController.signal, + ); + + expect(response.callId).toBe('call3'); + expect(response.error).toBe(executionError); + expect(response.resultDisplay).toBe('Tool execution failed'); + expect(response.responseParts).toEqual([ + { + functionResponse: { + name: 'testTool', + id: 'call3', + response: { error: 'Tool execution failed' }, + }, + }, + ]); + }); + + it('should handle cancellation during tool execution', async () => { + const request: ToolCallRequestInfo = { + callId: 'call4', + name: 'testTool', + args: { param1: 'value1' }, + }; + const cancellationError = new Error('Operation cancelled'); + vi.mocked(mockToolRegistry.getTool).mockReturnValue(mockTool); + + vi.mocked(mockTool.execute).mockImplementation(async (_args, signal) => { + if (signal?.aborted) { + return Promise.reject(cancellationError); + } + return new Promise((_resolve, reject) => { + signal?.addEventListener('abort', () => { + reject(cancellationError); + }); + // Simulate work that might happen if not aborted immediately + const timeoutId = setTimeout( + () => + reject( + new Error('Should have been cancelled if not aborted prior'), + ), + 100, + ); + signal?.addEventListener('abort', () => clearTimeout(timeoutId)); + }); + }); + + abortController.abort(); // Abort before calling + const response = await executeToolCall( + request, + mockToolRegistry, + abortController.signal, + ); + + expect(response.callId).toBe('call4'); + expect(response.error?.message).toBe(cancellationError.message); + expect(response.resultDisplay).toBe('Operation cancelled'); + }); + + it('should correctly format llmContent with inlineData', async () => { + const request: ToolCallRequestInfo = { + callId: 'call5', + name: 'testTool', + args: {}, + }; + const imageDataPart: Part = { + inlineData: { mimeType: 'image/png', data: 'base64data' }, + }; + const toolResult: ToolResult = { + llmContent: [imageDataPart], + returnDisplay: 'Image processed', + }; + vi.mocked(mockToolRegistry.getTool).mockReturnValue(mockTool); + vi.mocked(mockTool.execute).mockResolvedValue(toolResult); + + const response = await executeToolCall( + request, + mockToolRegistry, + abortController.signal, + ); + + expect(response.resultDisplay).toBe('Image processed'); + expect(response.responseParts).toEqual([ + { + functionResponse: { + name: 'testTool', + id: 'call5', + response: { + status: 'Binary content of type image/png was processed.', + }, + }, + }, + imageDataPart, + ]); + }); +}); diff --git a/packages/core/src/core/nonInteractiveToolExecutor.ts b/packages/core/src/core/nonInteractiveToolExecutor.ts new file mode 100644 index 00000000..5b5c9a13 --- /dev/null +++ b/packages/core/src/core/nonInteractiveToolExecutor.ts @@ -0,0 +1,91 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { Part } from '@google/genai'; +import { + ToolCallRequestInfo, + ToolCallResponseInfo, + ToolRegistry, + ToolResult, +} from '../index.js'; +import { formatLlmContentForFunctionResponse } from './coreToolScheduler.js'; + +/** + * Executes a single tool call non-interactively. + * It does not handle confirmations, multiple calls, or live updates. + */ +export async function executeToolCall( + toolCallRequest: ToolCallRequestInfo, + toolRegistry: ToolRegistry, + abortSignal?: AbortSignal, +): Promise { + const tool = toolRegistry.getTool(toolCallRequest.name); + + if (!tool) { + const error = new Error( + `Tool "${toolCallRequest.name}" not found in registry.`, + ); + // Ensure the response structure matches what the API expects for an error + return { + callId: toolCallRequest.callId, + responseParts: [ + { + functionResponse: { + id: toolCallRequest.callId, + name: toolCallRequest.name, + response: { error: error.message }, + }, + }, + ], + resultDisplay: error.message, + error, + }; + } + + try { + // Directly execute without confirmation or live output handling + const effectiveAbortSignal = abortSignal ?? new AbortController().signal; + const toolResult: ToolResult = await tool.execute( + toolCallRequest.args, + effectiveAbortSignal, + // No live output callback for non-interactive mode + ); + + const { functionResponseJson, additionalParts } = + formatLlmContentForFunctionResponse(toolResult.llmContent); + + const functionResponsePart: Part = { + functionResponse: { + name: toolCallRequest.name, + id: toolCallRequest.callId, + response: functionResponseJson, + }, + }; + + return { + callId: toolCallRequest.callId, + responseParts: [functionResponsePart, ...additionalParts], + resultDisplay: toolResult.returnDisplay, + error: undefined, + }; + } catch (e) { + const error = e instanceof Error ? e : new Error(String(e)); + return { + callId: toolCallRequest.callId, + responseParts: [ + { + functionResponse: { + id: toolCallRequest.callId, + name: toolCallRequest.name, + response: { error: error.message }, + }, + }, + ], + resultDisplay: error.message, + error, + }; + } +} diff --git a/packages/core/src/index.ts b/packages/core/src/index.ts index f8c42336..bd28c864 100644 --- a/packages/core/src/index.ts +++ b/packages/core/src/index.ts @@ -14,6 +14,7 @@ export * from './core/prompts.js'; export * from './core/turn.js'; export * from './core/geminiRequest.js'; export * from './core/coreToolScheduler.js'; +export * from './core/nonInteractiveToolExecutor.js'; // Export utilities export * from './utils/paths.js'; @@ -35,3 +36,6 @@ export * from './tools/edit.js'; export * from './tools/write-file.js'; export * from './tools/web-fetch.js'; export * from './tools/memoryTool.js'; +export * from './tools/shell.js'; +export * from './tools/web-search.js'; +export * from './tools/read-many-files.js';