diff --git a/packages/core/src/core/client.ts b/packages/core/src/core/client.ts index f8b9a7de..a16a72cc 100644 --- a/packages/core/src/core/client.ts +++ b/packages/core/src/core/client.ts @@ -7,14 +7,16 @@ import { EmbedContentParameters, GenerateContentConfig, - Part, SchemaUnion, PartListUnion, Content, Tool, GenerateContentResponse, } from '@google/genai'; -import { getFolderStructure } from '../utils/getFolderStructure.js'; +import { + getDirectoryContextString, + getEnvironmentContext, +} from '../utils/environmentContext.js'; import { Turn, ServerGeminiStreamEvent, @@ -182,112 +184,12 @@ export class GeminiClient { this.getChat().addHistory({ role: 'user', - parts: [{ text: await this.getDirectoryContext() }], + parts: [{ text: await getDirectoryContextString(this.config) }], }); } - private async getDirectoryContext(): Promise { - const workspaceContext = this.config.getWorkspaceContext(); - const workspaceDirectories = workspaceContext.getDirectories(); - - const folderStructures = await Promise.all( - workspaceDirectories.map((dir) => - getFolderStructure(dir, { - fileService: this.config.getFileService(), - }), - ), - ); - - const folderStructure = folderStructures.join('\n'); - const dirList = workspaceDirectories.map((dir) => ` - ${dir}`).join('\n'); - const workingDirPreamble = `I'm currently working in the following directories:\n${dirList}\n Folder structures are as follows:\n${folderStructure}`; - return workingDirPreamble; - } - - private async getEnvironment(): Promise { - const today = new Date().toLocaleDateString(undefined, { - weekday: 'long', - year: 'numeric', - month: 'long', - day: 'numeric', - }); - const platform = process.platform; - - const workspaceContext = this.config.getWorkspaceContext(); - const workspaceDirectories = workspaceContext.getDirectories(); - - const folderStructures = await Promise.all( - workspaceDirectories.map((dir) => - getFolderStructure(dir, { - fileService: this.config.getFileService(), - }), - ), - ); - - 
const folderStructure = folderStructures.join('\n'); - - let workingDirPreamble: string; - if (workspaceDirectories.length === 1) { - workingDirPreamble = `I'm currently working in the directory: ${workspaceDirectories[0]}`; - } else { - const dirList = workspaceDirectories - .map((dir) => ` - ${dir}`) - .join('\n'); - workingDirPreamble = `I'm currently working in the following directories:\n${dirList}`; - } - - const context = ` - This is the Gemini CLI. We are setting up the context for our chat. - Today's date is ${today}. - My operating system is: ${platform} - ${workingDirPreamble} - Here is the folder structure of the current working directories:\n - ${folderStructure} - `.trim(); - - const initialParts: Part[] = [{ text: context }]; - const toolRegistry = await this.config.getToolRegistry(); - - // Add full file context if the flag is set - if (this.config.getFullContext()) { - try { - const readManyFilesTool = toolRegistry.getTool('read_many_files'); - if (readManyFilesTool) { - const invocation = readManyFilesTool.build({ - paths: ['**/*'], // Read everything recursively - useDefaultExcludes: true, // Use default excludes - }); - - // Read all files in the target directory - const result = await invocation.execute(AbortSignal.timeout(30000)); - if (result.llmContent) { - initialParts.push({ - text: `\n--- Full File Context ---\n${result.llmContent}`, - }); - } else { - console.warn( - 'Full context requested, but read_many_files returned no content.', - ); - } - } else { - console.warn( - 'Full context requested, but read_many_files tool not found.', - ); - } - } catch (error) { - // Not using reportError here as it's a startup/config phase, not a chat/generation phase error. 
/**
 * Overwrites the `systemInstruction` field of this chat's generation config
 * with the given text.
 *
 * @param sysInstr - The system instruction text to store.
 */
setSystemInstruction(sysInstr: string) {
  this.generationConfig.systemInstruction = sysInstr;
}
/**
 * Builds a real `Config` for the tests and replaces its tool registry with a
 * mock.
 *
 * @param toolRegistryMocks - Properties spread over the default mock registry
 *   (`getTool`, `getFunctionDeclarationsFiltered`), letting a test override
 *   either of them.
 * @returns The initialized config together with the mock registry that
 *   `config.getToolRegistry()` now resolves to.
 */
async function createMockConfig(
  toolRegistryMocks = {},
): Promise<{ config: Config; toolRegistry: ToolRegistry }> {
  const configParams: ConfigParameters = {
    sessionId: 'test-session',
    model: DEFAULT_GEMINI_MODEL,
    targetDir: '.',
    debugMode: false,
    cwd: process.cwd(),
  };
  const config = new Config(configParams);
  await config.initialize();
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  await config.refreshAuth('test-auth' as any);

  // Mock ToolRegistry: defaults first so that caller-supplied mocks win.
  const mockToolRegistry = {
    getTool: vi.fn(),
    getFunctionDeclarationsFiltered: vi.fn().mockReturnValue([]),
    ...toolRegistryMocks,
  } as unknown as ToolRegistry;

  vi.spyOn(config, 'getToolRegistry').mockResolvedValue(mockToolRegistry);
  return { config, toolRegistry: mockToolRegistry };
}

/**
 * Helper to simulate LLM responses (a sequence of tool calls over multiple
 * turns).
 *
 * Each invocation of the returned mock consumes the next entry of
 * `functionCallsList` and returns an async generator yielding one chunk:
 * `{ functionCalls }` for a non-empty array, or `{ text: 'Done.' }` for
 * `'stop'`, for an empty array, and once the list is exhausted.
 *
 * @param functionCallsList - One entry per model turn: the function calls the
 *   model should request on that turn, or `'stop'` to end without tool calls.
 * @returns A `vi.fn()` usable as a `sendMessageStream` implementation.
 */
const createMockStream = (
  functionCallsList: Array<FunctionCall[] | 'stop'>,
) => {
  let index = 0;
  return vi.fn().mockImplementation(() => {
    // Running past the end of the script behaves like an explicit 'stop'.
    const response = functionCallsList[index] ?? 'stop';
    index++;
    return (async function* () {
      if (response === 'stop' || response.length === 0) {
        // When stopping, the model might return text, but the subagent logic
        // primarily cares about the absence of functionCalls.
        yield { text: 'Done.' };
      } else {
        yield { functionCalls: response };
      }
    })();
  });
};
+ vi.mocked(GeminiChat).mockImplementation( + () => + ({ + sendMessageStream: mockSendMessageStream, + }) as unknown as GeminiChat, + ); + }); + + afterEach(() => { + vi.restoreAllMocks(); + }); + + // Helper to safely access generationConfig from mock calls + const getGenerationConfigFromMock = ( + callIndex = 0, + ): GenerateContentConfig & { systemInstruction?: string | Content } => { + const callArgs = vi.mocked(GeminiChat).mock.calls[callIndex]; + const generationConfig = callArgs?.[2]; + // Ensure it's defined before proceeding + expect(generationConfig).toBeDefined(); + if (!generationConfig) throw new Error('generationConfig is undefined'); + return generationConfig as GenerateContentConfig & { + systemInstruction?: string | Content; + }; + }; + + describe('create (Tool Validation)', () => { + const promptConfig: PromptConfig = { systemPrompt: 'Test prompt' }; + + it('should create a SubAgentScope successfully with minimal config', async () => { + const { config } = await createMockConfig(); + const scope = await SubAgentScope.create( + 'test-agent', + config, + promptConfig, + defaultModelConfig, + defaultRunConfig, + ); + expect(scope).toBeInstanceOf(SubAgentScope); + }); + + it('should throw an error if a tool requires confirmation', async () => { + const mockTool = { + schema: { parameters: { type: Type.OBJECT, properties: {} } }, + build: vi.fn().mockReturnValue({ + shouldConfirmExecute: vi.fn().mockResolvedValue({ + type: 'exec', + title: 'Confirm', + command: 'rm -rf /', + }), + }), + }; + + const { config } = await createMockConfig({ + // eslint-disable-next-line @typescript-eslint/no-explicit-any + getTool: vi.fn().mockReturnValue(mockTool as any), + }); + + const toolConfig: ToolConfig = { tools: ['risky_tool'] }; + + await expect( + SubAgentScope.create( + 'test-agent', + config, + promptConfig, + defaultModelConfig, + defaultRunConfig, + toolConfig, + ), + ).rejects.toThrow( + 'Tool "risky_tool" requires user confirmation and cannot be used in a 
non-interactive subagent.', + ); + }); + + it('should succeed if tools do not require confirmation', async () => { + const mockTool = { + schema: { parameters: { type: Type.OBJECT, properties: {} } }, + build: vi.fn().mockReturnValue({ + shouldConfirmExecute: vi.fn().mockResolvedValue(null), + }), + }; + const { config } = await createMockConfig({ + // eslint-disable-next-line @typescript-eslint/no-explicit-any + getTool: vi.fn().mockReturnValue(mockTool as any), + }); + + const toolConfig: ToolConfig = { tools: ['safe_tool'] }; + + const scope = await SubAgentScope.create( + 'test-agent', + config, + promptConfig, + defaultModelConfig, + defaultRunConfig, + toolConfig, + ); + expect(scope).toBeInstanceOf(SubAgentScope); + }); + + it('should skip interactivity check and warn for tools with required parameters', async () => { + const consoleWarnSpy = vi + .spyOn(console, 'warn') + .mockImplementation(() => {}); + + const mockToolWithParams = { + schema: { + parameters: { + type: Type.OBJECT, + properties: { + path: { type: Type.STRING }, + }, + required: ['path'], + }, + }, + // build should not be called, but we mock it to be safe + build: vi.fn(), + }; + + const { config } = await createMockConfig({ + getTool: vi.fn().mockReturnValue(mockToolWithParams), + }); + + const toolConfig: ToolConfig = { tools: ['tool_with_params'] }; + + // The creation should succeed without throwing + const scope = await SubAgentScope.create( + 'test-agent', + config, + promptConfig, + defaultModelConfig, + defaultRunConfig, + toolConfig, + ); + + expect(scope).toBeInstanceOf(SubAgentScope); + + // Check that the warning was logged + expect(consoleWarnSpy).toHaveBeenCalledWith( + 'Cannot check tool "tool_with_params" for interactivity because it requires parameters. 
Assuming it is safe for non-interactive use.', + ); + + // Ensure build was never called + expect(mockToolWithParams.build).not.toHaveBeenCalled(); + + consoleWarnSpy.mockRestore(); + }); + }); + + describe('runNonInteractive - Initialization and Prompting', () => { + it('should correctly template the system prompt and initialize GeminiChat', async () => { + const { config } = await createMockConfig(); + + vi.mocked(GeminiChat).mockClear(); + + const promptConfig: PromptConfig = { + systemPrompt: 'Hello ${name}, your task is ${task}.', + }; + const context = new ContextState(); + context.set('name', 'Agent'); + context.set('task', 'Testing'); + + // Model stops immediately + mockSendMessageStream.mockImplementation(createMockStream(['stop'])); + + const scope = await SubAgentScope.create( + 'test-agent', + config, + promptConfig, + defaultModelConfig, + defaultRunConfig, + ); + + await scope.runNonInteractive(context); + + // Check if GeminiChat was initialized correctly by the subagent + expect(GeminiChat).toHaveBeenCalledTimes(1); + const callArgs = vi.mocked(GeminiChat).mock.calls[0]; + + // Check Generation Config + const generationConfig = getGenerationConfigFromMock(); + + // Check temperature override + expect(generationConfig.temperature).toBe(defaultModelConfig.temp); + expect(generationConfig.systemInstruction).toContain( + 'Hello Agent, your task is Testing.', + ); + expect(generationConfig.systemInstruction).toContain( + 'Important Rules:', + ); + + // Check History (should include environment context) + const history = callArgs[3]; + expect(history).toEqual([ + { role: 'user', parts: [{ text: 'Env Context' }] }, + { + role: 'model', + parts: [{ text: 'Got it. Thanks for the context!' 
}], + }, + ]); + }); + + it('should include output instructions in the system prompt when outputs are defined', async () => { + const { config } = await createMockConfig(); + vi.mocked(GeminiChat).mockClear(); + + const promptConfig: PromptConfig = { systemPrompt: 'Do the task.' }; + const outputConfig: OutputConfig = { + outputs: { + result1: 'The first result', + }, + }; + const context = new ContextState(); + + // Model stops immediately + mockSendMessageStream.mockImplementation(createMockStream(['stop'])); + + const scope = await SubAgentScope.create( + 'test-agent', + config, + promptConfig, + defaultModelConfig, + defaultRunConfig, + undefined, // ToolConfig + outputConfig, + ); + + await scope.runNonInteractive(context); + + const generationConfig = getGenerationConfigFromMock(); + const systemInstruction = generationConfig.systemInstruction as string; + + expect(systemInstruction).toContain('Do the task.'); + expect(systemInstruction).toContain( + 'you MUST emit the required output variables', + ); + expect(systemInstruction).toContain( + "Use 'self.emitvalue' to emit the 'result1' key", + ); + }); + + it('should use initialMessages instead of systemPrompt if provided', async () => { + const { config } = await createMockConfig(); + vi.mocked(GeminiChat).mockClear(); + + const initialMessages: Content[] = [ + { role: 'user', parts: [{ text: 'Hi' }] }, + ]; + const promptConfig: PromptConfig = { initialMessages }; + const context = new ContextState(); + + // Model stops immediately + mockSendMessageStream.mockImplementation(createMockStream(['stop'])); + + const scope = await SubAgentScope.create( + 'test-agent', + config, + promptConfig, + defaultModelConfig, + defaultRunConfig, + ); + + await scope.runNonInteractive(context); + + const callArgs = vi.mocked(GeminiChat).mock.calls[0]; + const generationConfig = getGenerationConfigFromMock(); + const history = callArgs[3]; + + expect(generationConfig.systemInstruction).toBeUndefined(); + 
expect(history).toEqual([ + { role: 'user', parts: [{ text: 'Env Context' }] }, + { + role: 'model', + parts: [{ text: 'Got it. Thanks for the context!' }], + }, + ...initialMessages, + ]); + }); + + it('should throw an error if template variables are missing', async () => { + const { config } = await createMockConfig(); + const promptConfig: PromptConfig = { + systemPrompt: 'Hello ${name}, you are missing ${missing}.', + }; + const context = new ContextState(); + context.set('name', 'Agent'); + // 'missing' is not set + + const scope = await SubAgentScope.create( + 'test-agent', + config, + promptConfig, + defaultModelConfig, + defaultRunConfig, + ); + + // The error from templating causes the runNonInteractive to reject and the terminate_reason to be ERROR. + await expect(scope.runNonInteractive(context)).rejects.toThrow( + 'Missing context values for the following keys: missing', + ); + expect(scope.output.terminate_reason).toBe(SubagentTerminateMode.ERROR); + }); + + it('should validate that systemPrompt and initialMessages are mutually exclusive', async () => { + const { config } = await createMockConfig(); + const promptConfig: PromptConfig = { + systemPrompt: 'System', + initialMessages: [{ role: 'user', parts: [{ text: 'Hi' }] }], + }; + const context = new ContextState(); + + const agent = await SubAgentScope.create( + 'TestAgent', + config, + promptConfig, + defaultModelConfig, + defaultRunConfig, + ); + + await expect(agent.runNonInteractive(context)).rejects.toThrow( + 'PromptConfig cannot have both `systemPrompt` and `initialMessages` defined.', + ); + expect(agent.output.terminate_reason).toBe(SubagentTerminateMode.ERROR); + }); + }); + + describe('runNonInteractive - Execution and Tool Use', () => { + const promptConfig: PromptConfig = { systemPrompt: 'Execute task.' 
}; + + it('should terminate with GOAL if no outputs are expected and model stops', async () => { + const { config } = await createMockConfig(); + // Model stops immediately + mockSendMessageStream.mockImplementation(createMockStream(['stop'])); + + const scope = await SubAgentScope.create( + 'test-agent', + config, + promptConfig, + defaultModelConfig, + defaultRunConfig, + // No ToolConfig, No OutputConfig + ); + + await scope.runNonInteractive(new ContextState()); + + expect(scope.output.terminate_reason).toBe(SubagentTerminateMode.GOAL); + expect(scope.output.emitted_vars).toEqual({}); + expect(mockSendMessageStream).toHaveBeenCalledTimes(1); + // Check the initial message + expect(mockSendMessageStream.mock.calls[0][0].message).toEqual([ + { text: 'Get Started!' }, + ]); + }); + + it('should handle self.emitvalue and terminate with GOAL when outputs are met', async () => { + const { config } = await createMockConfig(); + const outputConfig: OutputConfig = { + outputs: { result: 'The final result' }, + }; + + // Turn 1: Model responds with emitvalue call + // Turn 2: Model stops after receiving the tool response + mockSendMessageStream.mockImplementation( + createMockStream([ + [ + { + name: 'self.emitvalue', + args: { + emit_variable_name: 'result', + emit_variable_value: 'Success!', + }, + }, + ], + 'stop', + ]), + ); + + const scope = await SubAgentScope.create( + 'test-agent', + config, + promptConfig, + defaultModelConfig, + defaultRunConfig, + undefined, + outputConfig, + ); + + await scope.runNonInteractive(new ContextState()); + + expect(scope.output.terminate_reason).toBe(SubagentTerminateMode.GOAL); + expect(scope.output.emitted_vars).toEqual({ result: 'Success!' 
}); + expect(mockSendMessageStream).toHaveBeenCalledTimes(2); + + // Check the tool response sent back in the second call + const secondCallArgs = mockSendMessageStream.mock.calls[1][0]; + expect(secondCallArgs.message).toEqual([ + { text: 'Emitted variable result successfully' }, + ]); + }); + + it('should execute external tools and provide the response to the model', async () => { + const listFilesToolDef: FunctionDeclaration = { + name: 'list_files', + description: 'Lists files', + parameters: { type: Type.OBJECT, properties: {} }, + }; + + const { config, toolRegistry } = await createMockConfig({ + getFunctionDeclarationsFiltered: vi + .fn() + .mockReturnValue([listFilesToolDef]), + }); + const toolConfig: ToolConfig = { tools: ['list_files'] }; + + // Turn 1: Model calls the external tool + // Turn 2: Model stops + mockSendMessageStream.mockImplementation( + createMockStream([ + [ + { + id: 'call_1', + name: 'list_files', + args: { path: '.' }, + }, + ], + 'stop', + ]), + ); + + // Mock the tool execution result + vi.mocked(executeToolCall).mockResolvedValue({ + callId: 'call_1', + responseParts: 'file1.txt\nfile2.ts', + resultDisplay: 'Listed 2 files', + error: undefined, + errorType: undefined, // Or ToolErrorType.NONE if available and appropriate + }); + + const scope = await SubAgentScope.create( + 'test-agent', + config, + promptConfig, + defaultModelConfig, + defaultRunConfig, + toolConfig, + ); + + await scope.runNonInteractive(new ContextState()); + + // Check tool execution + expect(executeToolCall).toHaveBeenCalledWith( + config, + expect.objectContaining({ name: 'list_files', args: { path: '.' 
} }), + toolRegistry, + expect.any(AbortSignal), + ); + + // Check the response sent back to the model + const secondCallArgs = mockSendMessageStream.mock.calls[1][0]; + expect(secondCallArgs.message).toEqual([ + { text: 'file1.txt\nfile2.ts' }, + ]); + + expect(scope.output.terminate_reason).toBe(SubagentTerminateMode.GOAL); + }); + + it('should provide specific tool error responses to the model', async () => { + const { config } = await createMockConfig(); + const toolConfig: ToolConfig = { tools: ['failing_tool'] }; + + // Turn 1: Model calls the failing tool + // Turn 2: Model stops after receiving the error response + mockSendMessageStream.mockImplementation( + createMockStream([ + [ + { + id: 'call_fail', + name: 'failing_tool', + args: {}, + }, + ], + 'stop', + ]), + ); + + // Mock the tool execution failure. + vi.mocked(executeToolCall).mockResolvedValue({ + callId: 'call_fail', + responseParts: 'ERROR: Tool failed catastrophically', // This should be sent to the model + resultDisplay: 'Tool failed catastrophically', + error: new Error('Failure'), + errorType: ToolErrorType.INVALID_TOOL_PARAMS, + }); + + const scope = await SubAgentScope.create( + 'test-agent', + config, + promptConfig, + defaultModelConfig, + defaultRunConfig, + toolConfig, + ); + + await scope.runNonInteractive(new ContextState()); + + // The agent should send the specific error message from responseParts. 
+ const secondCallArgs = mockSendMessageStream.mock.calls[1][0]; + + expect(secondCallArgs.message).toEqual([ + { + text: 'ERROR: Tool failed catastrophically', + }, + ]); + }); + + it('should nudge the model if it stops before emitting all required variables', async () => { + const { config } = await createMockConfig(); + const outputConfig: OutputConfig = { + outputs: { required_var: 'Must be present' }, + }; + + // Turn 1: Model stops prematurely + // Turn 2: Model responds to the nudge and emits the variable + // Turn 3: Model stops + mockSendMessageStream.mockImplementation( + createMockStream([ + 'stop', + [ + { + name: 'self.emitvalue', + args: { + emit_variable_name: 'required_var', + emit_variable_value: 'Here it is', + }, + }, + ], + 'stop', + ]), + ); + + const scope = await SubAgentScope.create( + 'test-agent', + config, + promptConfig, + defaultModelConfig, + defaultRunConfig, + undefined, + outputConfig, + ); + + await scope.runNonInteractive(new ContextState()); + + // Check the nudge message sent in Turn 2 + const secondCallArgs = mockSendMessageStream.mock.calls[1][0]; + + // We check that the message contains the required variable name and the nudge phrasing. + expect(secondCallArgs.message[0].text).toContain('required_var'); + expect(secondCallArgs.message[0].text).toContain( + 'You have stopped calling tools', + ); + + expect(scope.output.terminate_reason).toBe(SubagentTerminateMode.GOAL); + expect(scope.output.emitted_vars).toEqual({ + required_var: 'Here it is', + }); + expect(mockSendMessageStream).toHaveBeenCalledTimes(3); + }); + }); + + describe('runNonInteractive - Termination and Recovery', () => { + const promptConfig: PromptConfig = { systemPrompt: 'Execute task.' 
}; + + it('should terminate with MAX_TURNS if the limit is reached', async () => { + const { config } = await createMockConfig(); + const runConfig: RunConfig = { ...defaultRunConfig, max_turns: 2 }; + + // Model keeps looping by calling emitvalue repeatedly + mockSendMessageStream.mockImplementation( + createMockStream([ + [ + { + name: 'self.emitvalue', + args: { emit_variable_name: 'loop', emit_variable_value: 'v1' }, + }, + ], + [ + { + name: 'self.emitvalue', + args: { emit_variable_name: 'loop', emit_variable_value: 'v2' }, + }, + ], + // This turn should not happen + [ + { + name: 'self.emitvalue', + args: { emit_variable_name: 'loop', emit_variable_value: 'v3' }, + }, + ], + ]), + ); + + const scope = await SubAgentScope.create( + 'test-agent', + config, + promptConfig, + defaultModelConfig, + runConfig, + ); + + await scope.runNonInteractive(new ContextState()); + + expect(mockSendMessageStream).toHaveBeenCalledTimes(2); + expect(scope.output.terminate_reason).toBe( + SubagentTerminateMode.MAX_TURNS, + ); + }); + + it('should terminate with TIMEOUT if the time limit is reached during an LLM call', async () => { + // Use fake timers to reliably test timeouts + vi.useFakeTimers(); + + const { config } = await createMockConfig(); + const runConfig: RunConfig = { max_time_minutes: 5, max_turns: 100 }; + + // We need to control the resolution of the sendMessageStream promise to advance the timer during execution. + let resolveStream: ( + value: AsyncGenerator, + ) => void; + const streamPromise = new Promise< + AsyncGenerator + >((resolve) => { + // eslint-disable-next-line @typescript-eslint/no-explicit-any + resolveStream = resolve as any; + }); + + // The LLM call will hang until we resolve the promise. 
+ mockSendMessageStream.mockReturnValue(streamPromise); + + const scope = await SubAgentScope.create( + 'test-agent', + config, + promptConfig, + defaultModelConfig, + runConfig, + ); + + const runPromise = scope.runNonInteractive(new ContextState()); + + // Advance time beyond the limit (6 minutes) while the agent is awaiting the LLM response. + await vi.advanceTimersByTimeAsync(6 * 60 * 1000); + + // Now resolve the stream. The model returns 'stop'. + // eslint-disable-next-line @typescript-eslint/no-explicit-any + resolveStream!(createMockStream(['stop'])() as any); + + await runPromise; + + expect(scope.output.terminate_reason).toBe( + SubagentTerminateMode.TIMEOUT, + ); + expect(mockSendMessageStream).toHaveBeenCalledTimes(1); + + vi.useRealTimers(); + }); + + it('should terminate with ERROR if the model call throws', async () => { + const { config } = await createMockConfig(); + mockSendMessageStream.mockRejectedValue(new Error('API Failure')); + + const scope = await SubAgentScope.create( + 'test-agent', + config, + promptConfig, + defaultModelConfig, + defaultRunConfig, + ); + + await expect( + scope.runNonInteractive(new ContextState()), + ).rejects.toThrow('API Failure'); + expect(scope.output.terminate_reason).toBe(SubagentTerminateMode.ERROR); + }); + }); + }); +}); diff --git a/packages/core/src/core/subagent.ts b/packages/core/src/core/subagent.ts new file mode 100644 index 00000000..e11a5209 --- /dev/null +++ b/packages/core/src/core/subagent.ts @@ -0,0 +1,681 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { reportError } from '../utils/errorReporting.js'; +import { ToolRegistry } from '../tools/tool-registry.js'; +import { Config } from '../config/config.js'; +import { ToolCallRequestInfo } from './turn.js'; +import { executeToolCall } from './nonInteractiveToolExecutor.js'; +import { createContentGenerator } from './contentGenerator.js'; +import { getEnvironmentContext } from 
'../utils/environmentContext.js'; +import { + Content, + Part, + FunctionCall, + GenerateContentConfig, + FunctionDeclaration, + Type, +} from '@google/genai'; +import { GeminiChat } from './geminiChat.js'; + +/** + * @fileoverview Defines the configuration interfaces for a subagent. + * + * These interfaces specify the structure for defining the subagent's prompt, + * the model parameters, and the execution settings. + */ + +/** + * Describes the possible termination modes for a subagent. + * This enum provides a clear indication of why a subagent's execution might have ended. + */ +export enum SubagentTerminateMode { + /** + * Indicates that the subagent's execution terminated due to an unrecoverable error. + */ + ERROR = 'ERROR', + /** + * Indicates that the subagent's execution terminated because it exceeded the maximum allowed working time. + */ + TIMEOUT = 'TIMEOUT', + /** + * Indicates that the subagent's execution successfully completed all its defined goals. + */ + GOAL = 'GOAL', + /** + * Indicates that the subagent's execution terminated because it exceeded the maximum number of turns. + */ + MAX_TURNS = 'MAX_TURNS', +} + +/** + * Represents the output structure of a subagent's execution. + * This interface defines the data that a subagent will return upon completion, + * including any emitted variables and the reason for its termination. + */ +export interface OutputObject { + /** + * A record of key-value pairs representing variables emitted by the subagent + * during its execution. These variables can be used by the calling agent. + */ + emitted_vars: Record; + /** + * The reason for the subagent's termination, indicating whether it completed + * successfully, timed out, or encountered an error. + */ + terminate_reason: SubagentTerminateMode; +} + +/** + * Configures the initial prompt for the subagent. + */ +export interface PromptConfig { + /** + * A single system prompt string that defines the subagent's persona and instructions. 
+ * Note: You should use either `systemPrompt` or `initialMessages`, but not both. + */ + systemPrompt?: string; + + /** + * An array of user/model content pairs to seed the chat history for few-shot prompting. + * Note: You should use either `systemPrompt` or `initialMessages`, but not both. + */ + initialMessages?: Content[]; +} + +/** + * Configures the tools available to the subagent during its execution. + */ +export interface ToolConfig { + /** + * A list of tool names (from the tool registry) or full function declarations + * that the subagent is permitted to use. + */ + tools: Array; +} + +/** + * Configures the expected outputs for the subagent. + */ +export interface OutputConfig { + /** + * A record describing the variables the subagent is expected to emit. + * The subagent will be prompted to generate these values before terminating. + */ + outputs: Record; +} + +/** + * Configures the generative model parameters for the subagent. + * This interface specifies the model to be used and its associated generation settings, + * such as temperature and top-p values, which influence the creativity and diversity of the model's output. + */ +export interface ModelConfig { + /** + * The name or identifier of the model to be used (e.g., 'gemini-2.5-pro'). + * + * TODO: In the future, this needs to support 'auto' or some other string to support routing use cases. + */ + model: string; + /** + * The temperature for the model's sampling process. + */ + temp: number; + /** + * The top-p value for nucleus sampling. + */ + top_p: number; +} + +/** + * Configures the execution environment and constraints for the subagent. + * This interface defines parameters that control the subagent's runtime behavior, + * such as maximum execution time, to prevent infinite loops or excessive resource consumption. + * + * TODO: Consider adding max_tokens as a form of budgeting. + */ +export interface RunConfig { + /** The maximum execution time for the subagent in minutes. 
/**
 * Configures the execution environment and constraints for the subagent.
 * These limits bound the subagent's run loop so that it cannot run forever.
 *
 * TODO: Consider adding max_tokens as a form of budgeting.
 */
export interface RunConfig {
  /** The maximum execution time for the subagent in minutes. */
  max_time_minutes: number;
  /**
   * The maximum number of conversational turns (a user message + model response)
   * before the execution is terminated. Helps prevent infinite loops.
   */
  max_turns?: number;
}

/**
 * Manages the runtime context state for the subagent.
 * Stores key-value pairs that represent the variables available to the
 * subagent during its execution (for example, for prompt templating).
 */
export class ContextState {
  // A null-prototype object is used so that only keys explicitly stored via
  // `set` are ever visible: names such as `toString` or `constructor` are not
  // inherited, and `__proto__` behaves like any other key.
  private state: Record<string, unknown> = Object.create(null);

  /**
   * Retrieves a value from the context state.
   *
   * @param key - The key of the value to retrieve.
   * @returns The value associated with the key, or undefined if the key is not found.
   */
  get(key: string): unknown {
    return this.state[key];
  }

  /**
   * Sets a value in the context state, overwriting any previous value.
   *
   * @param key - The key to set the value under.
   * @param value - The value to set.
   */
  set(key: string, value: unknown): void {
    this.state[key] = value;
  }

  /**
   * Retrieves all keys in the context state.
   *
   * @returns An array of all keys that have been set.
   */
  get_keys(): string[] {
    return Object.keys(this.state);
  }
}

/**
 * Replaces `${...}` placeholders in a template string with values from a context.
 *
 * Placeholders have the form `${key}` where `key` consists of word characters
 * (`\w+`). Every referenced key must exist in the provided `ContextState`;
 * validation happens before any substitution is performed. Values are
 * converted with `String(...)`.
 *
 * @param template The template string containing placeholders.
 * @param context The `ContextState` object providing placeholder values.
 * @returns The populated string with all placeholders replaced.
 * @throws {Error} if any placeholder key is not found in the context.
 */
function templateString(template: string, context: ContextState): string {
  const placeholderRegex = /\$\{(\w+)\}/g;

  // Collect the unique keys referenced by the template (in order of first use).
  const requiredKeys = new Set(
    Array.from(template.matchAll(placeholderRegex), (match) => match[1]),
  );

  // Validate up front so that the error lists every missing key at once.
  const contextKeys = new Set(context.get_keys());
  const missingKeys = Array.from(requiredKeys).filter(
    (key) => !contextKeys.has(key),
  );

  if (missingKeys.length > 0) {
    throw new Error(
      `Missing context values for the following keys: ${missingKeys.join(
        ', ',
      )}`,
    );
  }

  // A replacer function is used (rather than a replacement string) so that
  // `$`-sequences inside context values are inserted literally.
  return template.replace(placeholderRegex, (_match, key: string) =>
    String(context.get(key)),
  );
}
+ */ + private constructor( + readonly name: string, + readonly runtimeContext: Config, + private readonly promptConfig: PromptConfig, + private readonly modelConfig: ModelConfig, + private readonly runConfig: RunConfig, + private readonly toolConfig?: ToolConfig, + private readonly outputConfig?: OutputConfig, + ) { + const randomPart = Math.random().toString(36).slice(2, 8); + this.subagentId = `${this.name}-${randomPart}`; + } + + /** + * Creates and validates a new SubAgentScope instance. + * This factory method ensures that all tools provided in the prompt configuration + * are valid for non-interactive use before creating the subagent instance. + * @param {string} name - The name of the subagent. + * @param {Config} runtimeContext - The shared runtime configuration and services. + * @param {PromptConfig} promptConfig - Configuration for the subagent's prompt and behavior. + * @param {ModelConfig} modelConfig - Configuration for the generative model parameters. + * @param {RunConfig} runConfig - Configuration for the subagent's execution environment. + * @param {ToolConfig} [toolConfig] - Optional configuration for tools. + * @param {OutputConfig} [outputConfig] - Optional configuration for expected outputs. + * @returns {Promise} A promise that resolves to a valid SubAgentScope instance. + * @throws {Error} If any tool requires user confirmation. 
+ */ + static async create( + name: string, + runtimeContext: Config, + promptConfig: PromptConfig, + modelConfig: ModelConfig, + runConfig: RunConfig, + toolConfig?: ToolConfig, + outputConfig?: OutputConfig, + ): Promise { + if (toolConfig) { + const toolRegistry: ToolRegistry = await runtimeContext.getToolRegistry(); + const toolsToLoad: string[] = []; + for (const tool of toolConfig.tools) { + if (typeof tool === 'string') { + toolsToLoad.push(tool); + } + } + + for (const toolName of toolsToLoad) { + const tool = toolRegistry.getTool(toolName); + if (tool) { + const requiredParams = tool.schema.parameters?.required ?? []; + if (requiredParams.length > 0) { + // This check is imperfect. A tool might require parameters but still + // be interactive (e.g., `delete_file(path)`). However, we cannot + // build a generic invocation without knowing what dummy parameters + // to provide. Crashing here because `build({})` fails is worse + // than allowing a potential hang later if an interactive tool is + // used. This is a best-effort check. + console.warn( + `Cannot check tool "${toolName}" for interactivity because it requires parameters. Assuming it is safe for non-interactive use.`, + ); + continue; + } + + const invocation = tool.build({}); + const confirmationDetails = await invocation.shouldConfirmExecute( + new AbortController().signal, + ); + if (confirmationDetails) { + throw new Error( + `Tool "${toolName}" requires user confirmation and cannot be used in a non-interactive subagent.`, + ); + } + } + } + } + + return new SubAgentScope( + name, + runtimeContext, + promptConfig, + modelConfig, + runConfig, + toolConfig, + outputConfig, + ); + } + + /** + * Runs the subagent in a non-interactive mode. + * This method orchestrates the subagent's execution loop, including prompt templating, + * tool execution, and termination conditions. + * @param {ContextState} context - The current context state containing variables for prompt templating. 
+ * @returns {Promise} A promise that resolves when the subagent has completed its execution. + */ + async runNonInteractive(context: ContextState): Promise { + const chat = await this.createChatObject(context); + + if (!chat) { + this.output.terminate_reason = SubagentTerminateMode.ERROR; + return; + } + + const abortController = new AbortController(); + const toolRegistry: ToolRegistry = + await this.runtimeContext.getToolRegistry(); + + // Prepare the list of tools available to the subagent. + const toolsList: FunctionDeclaration[] = []; + if (this.toolConfig) { + const toolsToLoad: string[] = []; + for (const tool of this.toolConfig.tools) { + if (typeof tool === 'string') { + toolsToLoad.push(tool); + } else { + toolsList.push(tool); + } + } + toolsList.push( + ...toolRegistry.getFunctionDeclarationsFiltered(toolsToLoad), + ); + } + // Add local scope functions if outputs are expected. + if (this.outputConfig && this.outputConfig.outputs) { + toolsList.push(...this.getScopeLocalFuncDefs()); + } + + let currentMessages: Content[] = [ + { role: 'user', parts: [{ text: 'Get Started!' }] }, + ]; + + const startTime = Date.now(); + let turnCounter = 0; + try { + while (true) { + // Check termination conditions. 
+ if ( + this.runConfig.max_turns && + turnCounter >= this.runConfig.max_turns + ) { + this.output.terminate_reason = SubagentTerminateMode.MAX_TURNS; + break; + } + let durationMin = (Date.now() - startTime) / (1000 * 60); + if (durationMin >= this.runConfig.max_time_minutes) { + this.output.terminate_reason = SubagentTerminateMode.TIMEOUT; + break; + } + + const promptId = `${this.runtimeContext.getSessionId()}#${this.subagentId}#${turnCounter++}`; + const messageParams = { + message: currentMessages[0]?.parts || [], + config: { + abortSignal: abortController.signal, + tools: [{ functionDeclarations: toolsList }], + }, + }; + + const responseStream = await chat.sendMessageStream( + messageParams, + promptId, + ); + + const functionCalls: FunctionCall[] = []; + for await (const resp of responseStream) { + if (abortController.signal.aborted) return; + if (resp.functionCalls) functionCalls.push(...resp.functionCalls); + } + + durationMin = (Date.now() - startTime) / (1000 * 60); + if (durationMin >= this.runConfig.max_time_minutes) { + this.output.terminate_reason = SubagentTerminateMode.TIMEOUT; + break; + } + + if (functionCalls.length > 0) { + currentMessages = await this.processFunctionCalls( + functionCalls, + toolRegistry, + abortController, + promptId, + ); + } else { + // Model stopped calling tools. Check if goal is met. + if ( + !this.outputConfig || + Object.keys(this.outputConfig.outputs).length === 0 + ) { + this.output.terminate_reason = SubagentTerminateMode.GOAL; + break; + } + + const remainingVars = Object.keys(this.outputConfig.outputs).filter( + (key) => !(key in this.output.emitted_vars), + ); + + if (remainingVars.length === 0) { + this.output.terminate_reason = SubagentTerminateMode.GOAL; + break; + } + + const nudgeMessage = `You have stopped calling tools but have not emitted the following required variables: ${remainingVars.join( + ', ', + )}. 
Please use the 'self.emitvalue' tool to emit them now, or continue working if necessary.`; + + console.debug(nudgeMessage); + + currentMessages = [ + { + role: 'user', + parts: [{ text: nudgeMessage }], + }, + ]; + } + } + } catch (error) { + console.error('Error during subagent execution:', error); + this.output.terminate_reason = SubagentTerminateMode.ERROR; + throw error; + } + } + + /** + * Processes a list of function calls, executing each one and collecting their responses. + * This method iterates through the provided function calls, executes them using the + * `executeToolCall` function (or handles `self.emitvalue` internally), and aggregates + * their results. It also manages error reporting for failed tool executions. + * @param {FunctionCall[]} functionCalls - An array of `FunctionCall` objects to process. + * @param {ToolRegistry} toolRegistry - The tool registry to look up and execute tools. + * @param {AbortController} abortController - An `AbortController` to signal cancellation of tool executions. + * @returns {Promise} A promise that resolves to an array of `Content` parts representing the tool responses, + * which are then used to update the chat history. + */ + private async processFunctionCalls( + functionCalls: FunctionCall[], + toolRegistry: ToolRegistry, + abortController: AbortController, + promptId: string, + ): Promise { + const toolResponseParts: Part[] = []; + + for (const functionCall of functionCalls) { + const callId = functionCall.id ?? `${functionCall.name}-${Date.now()}`; + const requestInfo: ToolCallRequestInfo = { + callId, + name: functionCall.name as string, + args: (functionCall.args ?? {}) as Record, + isClientInitiated: true, + prompt_id: promptId, + }; + + let toolResponse; + + // Handle scope-local tools first. 
+ if (functionCall.name === 'self.emitvalue') { + const valName = String(requestInfo.args['emit_variable_name']); + const valVal = String(requestInfo.args['emit_variable_value']); + this.output.emitted_vars[valName] = valVal; + + toolResponse = { + callId, + responseParts: `Emitted variable ${valName} successfully`, + resultDisplay: `Emitted variable ${valName} successfully`, + error: undefined, + }; + } else { + toolResponse = await executeToolCall( + this.runtimeContext, + requestInfo, + toolRegistry, + abortController.signal, + ); + } + + if (toolResponse.error) { + console.error( + `Error executing tool ${functionCall.name}: ${toolResponse.resultDisplay || toolResponse.error.message}`, + ); + } + + if (toolResponse.responseParts) { + const parts = Array.isArray(toolResponse.responseParts) + ? toolResponse.responseParts + : [toolResponse.responseParts]; + for (const part of parts) { + if (typeof part === 'string') { + toolResponseParts.push({ text: part }); + } else if (part) { + toolResponseParts.push(part); + } + } + } + } + // If all tool calls failed, inform the model so it can re-evaluate. + if (functionCalls.length > 0 && toolResponseParts.length === 0) { + toolResponseParts.push({ + text: 'All tool calls failed. Please analyze the errors and try an alternative approach.', + }); + } + + return [{ role: 'user', parts: toolResponseParts }]; + } + + private async createChatObject(context: ContextState) { + if (!this.promptConfig.systemPrompt && !this.promptConfig.initialMessages) { + throw new Error( + 'PromptConfig must have either `systemPrompt` or `initialMessages` defined.', + ); + } + if (this.promptConfig.systemPrompt && this.promptConfig.initialMessages) { + throw new Error( + 'PromptConfig cannot have both `systemPrompt` and `initialMessages` defined.', + ); + } + + const envParts = await getEnvironmentContext(this.runtimeContext); + const envHistory: Content[] = [ + { role: 'user', parts: envParts }, + { role: 'model', parts: [{ text: 'Got it. 
Thanks for the context!' }] }, + ]; + + const start_history = [ + ...envHistory, + ...(this.promptConfig.initialMessages ?? []), + ]; + + const systemInstruction = this.promptConfig.systemPrompt + ? this.buildChatSystemPrompt(context) + : undefined; + + try { + const generationConfig: GenerateContentConfig & { + systemInstruction?: string | Content; + } = { + temperature: this.modelConfig.temp, + topP: this.modelConfig.top_p, + }; + + if (systemInstruction) { + generationConfig.systemInstruction = systemInstruction; + } + + const contentGenerator = await createContentGenerator( + this.runtimeContext.getContentGeneratorConfig(), + this.runtimeContext, + this.runtimeContext.getSessionId(), + ); + + this.runtimeContext.setModel(this.modelConfig.model); + + return new GeminiChat( + this.runtimeContext, + contentGenerator, + generationConfig, + start_history, + ); + } catch (error) { + await reportError( + error, + 'Error initializing Gemini chat session.', + start_history, + 'startChat', + ); + // The calling function will handle the undefined return. + return undefined; + } + } + + /** + * Returns an array of FunctionDeclaration objects for tools that are local to the subagent's scope. + * Currently, this includes the `self.emitvalue` tool for emitting variables. + * @returns An array of `FunctionDeclaration` objects. + */ + private getScopeLocalFuncDefs() { + const emitValueTool: FunctionDeclaration = { + name: 'self.emitvalue', + description: `* This tool emits A SINGLE return value from this execution, such that it can be collected and presented to the calling function. + * You can only emit ONE VALUE each time you call this tool. 
You are expected to call this tool MULTIPLE TIMES if you have MULTIPLE OUTPUTS.`, + parameters: { + type: Type.OBJECT, + properties: { + emit_variable_name: { + description: 'This is the name of the variable to be returned.', + type: Type.STRING, + }, + emit_variable_value: { + description: + 'This is the _value_ to be returned for this variable.', + type: Type.STRING, + }, + }, + required: ['emit_variable_name', 'emit_variable_value'], + }, + }; + + return [emitValueTool]; + } + + /** + * Builds the system prompt for the chat based on the provided configurations. + * It templates the base system prompt and appends instructions for emitting + * variables if an `OutputConfig` is provided. + * @param {ContextState} context - The context for templating. + * @returns {string} The complete system prompt. + */ + private buildChatSystemPrompt(context: ContextState): string { + if (!this.promptConfig.systemPrompt) { + // This should ideally be caught in createChatObject, but serves as a safeguard. + return ''; + } + + let finalPrompt = templateString(this.promptConfig.systemPrompt, context); + + // Add instructions for emitting variables if needed. + if (this.outputConfig && this.outputConfig.outputs) { + let outputInstructions = + '\n\nAfter you have achieved all other goals, you MUST emit the required output variables. For each expected output, make one final call to the `self.emitvalue` tool.'; + + for (const [key, value] of Object.entries(this.outputConfig.outputs)) { + outputInstructions += `\n* Use 'self.emitvalue' to emit the '${key}' key, with a value described as: '${value}'`; + } + finalPrompt += outputInstructions; + } + + // Add general non-interactive instructions. + finalPrompt += ` + +Important Rules: + * You are running in a non-interactive mode. You CANNOT ask the user for input or clarification. You must proceed with the information you have. 
+ * Once you believe all goals have been met and all required outputs have been emitted, stop calling tools.`; + + return finalPrompt; + } +} diff --git a/packages/core/src/tools/tool-registry.ts b/packages/core/src/tools/tool-registry.ts index 73b427d4..c77fab8c 100644 --- a/packages/core/src/tools/tool-registry.ts +++ b/packages/core/src/tools/tool-registry.ts @@ -365,6 +365,22 @@ export class ToolRegistry { return declarations; } + /** + * Retrieves a filtered list of tool schemas based on a list of tool names. + * @param toolNames - An array of tool names to include. + * @returns An array of FunctionDeclarations for the specified tools. + */ + getFunctionDeclarationsFiltered(toolNames: string[]): FunctionDeclaration[] { + const declarations: FunctionDeclaration[] = []; + for (const name of toolNames) { + const tool = this.tools.get(name); + if (tool) { + declarations.push(tool.schema); + } + } + return declarations; + } + /** * Returns an array of all registered and discovered tool instances. 
*/ diff --git a/packages/core/src/utils/environmentContext.test.ts b/packages/core/src/utils/environmentContext.test.ts new file mode 100644 index 00000000..656fb63f --- /dev/null +++ b/packages/core/src/utils/environmentContext.test.ts @@ -0,0 +1,205 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { + describe, + it, + expect, + vi, + beforeEach, + afterEach, + type Mock, +} from 'vitest'; +import { + getEnvironmentContext, + getDirectoryContextString, +} from './environmentContext.js'; +import { Config } from '../config/config.js'; +import { getFolderStructure } from './getFolderStructure.js'; + +vi.mock('../config/config.js'); +vi.mock('./getFolderStructure.js', () => ({ + getFolderStructure: vi.fn(), +})); +vi.mock('../tools/read-many-files.js'); + +describe('getDirectoryContextString', () => { + let mockConfig: Partial; + + beforeEach(() => { + mockConfig = { + getWorkspaceContext: vi.fn().mockReturnValue({ + getDirectories: vi.fn().mockReturnValue(['/test/dir']), + }), + getFileService: vi.fn(), + }; + vi.mocked(getFolderStructure).mockResolvedValue('Mock Folder Structure'); + }); + + afterEach(() => { + vi.resetAllMocks(); + }); + + it('should return context string for a single directory', async () => { + const contextString = await getDirectoryContextString(mockConfig as Config); + expect(contextString).toContain( + "I'm currently working in the directory: /test/dir", + ); + expect(contextString).toContain( + 'Here is the folder structure of the current working directories:\n\nMock Folder Structure', + ); + }); + + it('should return context string for multiple directories', async () => { + ( + vi.mocked(mockConfig.getWorkspaceContext!)().getDirectories as Mock + ).mockReturnValue(['/test/dir1', '/test/dir2']); + vi.mocked(getFolderStructure) + .mockResolvedValueOnce('Structure 1') + .mockResolvedValueOnce('Structure 2'); + + const contextString = await getDirectoryContextString(mockConfig as Config); 
+ expect(contextString).toContain( + "I'm currently working in the following directories:\n - /test/dir1\n - /test/dir2", + ); + expect(contextString).toContain( + 'Here is the folder structure of the current working directories:\n\nStructure 1\nStructure 2', + ); + }); +}); + +describe('getEnvironmentContext', () => { + let mockConfig: Partial; + let mockToolRegistry: { getTool: Mock }; + + beforeEach(() => { + vi.useFakeTimers(); + vi.setSystemTime(new Date('2025-08-05T12:00:00Z')); + + mockToolRegistry = { + getTool: vi.fn(), + }; + + mockConfig = { + getWorkspaceContext: vi.fn().mockReturnValue({ + getDirectories: vi.fn().mockReturnValue(['/test/dir']), + }), + getFileService: vi.fn(), + getFullContext: vi.fn().mockReturnValue(false), + getToolRegistry: vi.fn().mockResolvedValue(mockToolRegistry), + }; + + vi.mocked(getFolderStructure).mockResolvedValue('Mock Folder Structure'); + }); + + afterEach(() => { + vi.useRealTimers(); + vi.resetAllMocks(); + }); + + it('should return basic environment context for a single directory', async () => { + const parts = await getEnvironmentContext(mockConfig as Config); + + expect(parts.length).toBe(1); + const context = parts[0].text; + + expect(context).toContain("Today's date is Tuesday, August 5, 2025"); + expect(context).toContain(`My operating system is: ${process.platform}`); + expect(context).toContain( + "I'm currently working in the directory: /test/dir", + ); + expect(context).toContain( + 'Here is the folder structure of the current working directories:\n\nMock Folder Structure', + ); + expect(getFolderStructure).toHaveBeenCalledWith('/test/dir', { + fileService: undefined, + }); + }); + + it('should return basic environment context for multiple directories', async () => { + ( + vi.mocked(mockConfig.getWorkspaceContext!)().getDirectories as Mock + ).mockReturnValue(['/test/dir1', '/test/dir2']); + vi.mocked(getFolderStructure) + .mockResolvedValueOnce('Structure 1') + .mockResolvedValueOnce('Structure 2'); + + 
const parts = await getEnvironmentContext(mockConfig as Config); + + expect(parts.length).toBe(1); + const context = parts[0].text; + + expect(context).toContain( + "I'm currently working in the following directories:\n - /test/dir1\n - /test/dir2", + ); + expect(context).toContain( + 'Here is the folder structure of the current working directories:\n\nStructure 1\nStructure 2', + ); + expect(getFolderStructure).toHaveBeenCalledTimes(2); + }); + + it('should include full file context when getFullContext is true', async () => { + mockConfig.getFullContext = vi.fn().mockReturnValue(true); + const mockReadManyFilesTool = { + build: vi.fn().mockReturnValue({ + execute: vi + .fn() + .mockResolvedValue({ llmContent: 'Full file content here' }), + }), + }; + mockToolRegistry.getTool.mockReturnValue(mockReadManyFilesTool); + + const parts = await getEnvironmentContext(mockConfig as Config); + + expect(parts.length).toBe(2); + expect(parts[1].text).toBe( + '\n--- Full File Context ---\nFull file content here', + ); + expect(mockToolRegistry.getTool).toHaveBeenCalledWith('read_many_files'); + expect(mockReadManyFilesTool.build).toHaveBeenCalledWith({ + paths: ['**/*'], + useDefaultExcludes: true, + }); + }); + + it('should handle read_many_files returning no content', async () => { + mockConfig.getFullContext = vi.fn().mockReturnValue(true); + const mockReadManyFilesTool = { + build: vi.fn().mockReturnValue({ + execute: vi.fn().mockResolvedValue({ llmContent: '' }), + }), + }; + mockToolRegistry.getTool.mockReturnValue(mockReadManyFilesTool); + + const parts = await getEnvironmentContext(mockConfig as Config); + + expect(parts.length).toBe(1); // No extra part added + }); + + it('should handle read_many_files tool not being found', async () => { + mockConfig.getFullContext = vi.fn().mockReturnValue(true); + mockToolRegistry.getTool.mockReturnValue(null); + + const parts = await getEnvironmentContext(mockConfig as Config); + + expect(parts.length).toBe(1); // No extra part 
added + }); + + it('should handle errors when reading full file context', async () => { + mockConfig.getFullContext = vi.fn().mockReturnValue(true); + const mockReadManyFilesTool = { + build: vi.fn().mockReturnValue({ + execute: vi.fn().mockRejectedValue(new Error('Read error')), + }), + }; + mockToolRegistry.getTool.mockReturnValue(mockReadManyFilesTool); + + const parts = await getEnvironmentContext(mockConfig as Config); + + expect(parts.length).toBe(2); + expect(parts[1].text).toBe('\n--- Error reading full file context ---'); + }); +}); diff --git a/packages/core/src/utils/environmentContext.ts b/packages/core/src/utils/environmentContext.ts new file mode 100644 index 00000000..79fb6049 --- /dev/null +++ b/packages/core/src/utils/environmentContext.ts @@ -0,0 +1,109 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { Part } from '@google/genai'; +import { Config } from '../config/config.js'; +import { getFolderStructure } from './getFolderStructure.js'; + +/** + * Generates a string describing the current workspace directories and their structures. + * @param {Config} config - The runtime configuration and services. + * @returns {Promise} A promise that resolves to the directory context string. 
/**
 * Generates a string describing the current workspace directories and their structures.
 *
 * The folder structure of every workspace directory is computed in parallel
 * and the results are joined with newlines. The preamble is phrased for a
 * single directory when there is exactly one, and as a bulleted list otherwise.
 *
 * @param config - The runtime configuration and services.
 * @returns A promise that resolves to the directory context string.
 */
export async function getDirectoryContextString(
  config: Config,
): Promise<string> {
  const workspaceContext = config.getWorkspaceContext();
  const workspaceDirectories = workspaceContext.getDirectories();

  const folderStructures = await Promise.all(
    workspaceDirectories.map((dir) =>
      getFolderStructure(dir, {
        fileService: config.getFileService(),
      }),
    ),
  );

  const folderStructure = folderStructures.join('\n');

  let workingDirPreamble: string;
  if (workspaceDirectories.length === 1) {
    workingDirPreamble = `I'm currently working in the directory: ${workspaceDirectories[0]}`;
  } else {
    const dirList = workspaceDirectories.map((dir) => ` - ${dir}`).join('\n');
    workingDirPreamble = `I'm currently working in the following directories:\n${dirList}`;
  }

  return `${workingDirPreamble}
Here is the folder structure of the current working directories:

${folderStructure}`;
}

/**
 * Retrieves environment-related information to be included in the chat context.
 *
 * The first part always contains the current date, the operating system
 * platform, and the workspace directory context. When `config.getFullContext()`
 * is enabled, the `read_many_files` tool is used to append a second part with
 * the contents of the workspace files; if that read throws, an error marker
 * part is appended instead. If the tool is missing or returns no content, a
 * warning is logged and no extra part is added.
 *
 * @param config - The runtime configuration and services.
 * @returns A promise that resolves to an array of `Part` objects containing environment information.
 */
export async function getEnvironmentContext(config: Config): Promise<Part[]> {
  const today = new Date().toLocaleDateString(undefined, {
    weekday: 'long',
    year: 'numeric',
    month: 'long',
    day: 'numeric',
  });
  const platform = process.platform;
  const directoryContext = await getDirectoryContextString(config);

  const context = `
This is the Gemini CLI. We are setting up the context for our chat.
Today's date is ${today}.
My operating system is: ${platform}
${directoryContext}
  `.trim();

  const initialParts: Part[] = [{ text: context }];
  const toolRegistry = await config.getToolRegistry();

  // Add full file context if the flag is set
  if (config.getFullContext()) {
    try {
      const readManyFilesTool = toolRegistry.getTool('read_many_files');
      if (readManyFilesTool) {
        const invocation = readManyFilesTool.build({
          paths: ['**/*'], // Read everything recursively
          useDefaultExcludes: true, // Use default excludes
        });

        // Read all files in the target directory; give up after 30 seconds.
        const result = await invocation.execute(AbortSignal.timeout(30000));
        if (result.llmContent) {
          initialParts.push({
            text: `\n--- Full File Context ---\n${result.llmContent}`,
          });
        } else {
          console.warn(
            'Full context requested, but read_many_files returned no content.',
          );
        }
      } else {
        console.warn(
          'Full context requested, but read_many_files tool not found.',
        );
      }
    } catch (error) {
      // Not using reportError here as it's a startup/config phase, not a chat/generation phase error.
      console.error('Error reading full file context:', error);
      initialParts.push({
        text: '\n--- Error reading full file context ---',
      });
    }
  }

  return initialParts;
}