diff --git a/packages/cli/src/nonInteractiveCli.test.ts b/packages/cli/src/nonInteractiveCli.test.ts index 389d35f2..c071dedc 100644 --- a/packages/cli/src/nonInteractiveCli.test.ts +++ b/packages/cli/src/nonInteractiveCli.test.ts @@ -39,7 +39,7 @@ describe('runNonInteractive', () => { sendMessageStream: vi.fn(), }; mockGeminiClient = { - startChat: vi.fn().mockResolvedValue(mockChat), + getChat: vi.fn().mockResolvedValue(mockChat), } as unknown as GeminiClient; mockToolRegistry = { getFunctionDeclarations: vi.fn().mockReturnValue([]), @@ -80,7 +80,6 @@ describe('runNonInteractive', () => { await runNonInteractive(mockConfig, 'Test input'); - expect(mockGeminiClient.startChat).toHaveBeenCalled(); expect(mockChat.sendMessageStream).toHaveBeenCalledWith({ message: [{ text: 'Test input' }], config: { diff --git a/packages/cli/src/nonInteractiveCli.ts b/packages/cli/src/nonInteractiveCli.ts index f7b4108b..7505c736 100644 --- a/packages/cli/src/nonInteractiveCli.ts +++ b/packages/cli/src/nonInteractiveCli.ts @@ -42,7 +42,7 @@ export async function runNonInteractive( const geminiClient = new GeminiClient(config); const toolRegistry: ToolRegistry = await config.getToolRegistry(); - const chat = await geminiClient.startChat(); + const chat = await geminiClient.getChat(); const abortController = new AbortController(); let currentMessages: Content[] = [{ role: 'user', parts: [{ text: input }] }]; diff --git a/packages/cli/src/ui/hooks/useGeminiStream.test.tsx b/packages/cli/src/ui/hooks/useGeminiStream.test.tsx index 3a421ebf..d46fab9e 100644 --- a/packages/cli/src/ui/hooks/useGeminiStream.test.tsx +++ b/packages/cli/src/ui/hooks/useGeminiStream.test.tsx @@ -405,10 +405,9 @@ describe('useGeminiStream', () => { } as TrackedCancelledToolCall, ]; - let hookResult: any; - await act(async () => { - hookResult = renderTestHook(simplifiedToolCalls); - }); + const hookResult = await act(async () => + renderTestHook(simplifiedToolCalls), + ); const { mockMarkToolsAsSubmitted, @@ -431,9 +430,8 @@ describe('useGeminiStream', () => { toolCall2ResponseParts, ]); expect(localMockSendMessageStream).toHaveBeenCalledWith( - expect.anything(), expectedMergedResponse, - expect.anything(), + expect.any(AbortSignal), ); }); }); diff --git a/packages/cli/src/ui/hooks/useGeminiStream.ts b/packages/cli/src/ui/hooks/useGeminiStream.ts index 423f3489..284709cf 100644 --- a/packages/cli/src/ui/hooks/useGeminiStream.ts +++ b/packages/cli/src/ui/hooks/useGeminiStream.ts @@ -17,7 +17,6 @@ import { Config, MessageSenderType, ToolCallRequestInfo, - GeminiChat, } from '@gemini-code/core'; import { type PartListUnion } from '@google/genai'; import { @@ -76,7 +75,6 @@ export const useGeminiStream = ( ) => { const [initError, setInitError] = useState(null); const abortControllerRef = useRef(null); - const chatSessionRef = useRef(null); const geminiClientRef = useRef(null); const [isResponding, setIsResponding] = useState(false); const [pendingHistoryItemRef, setPendingHistoryItem] = @@ -256,31 +254,6 @@ export const useGeminiStream = ( ], ); - const ensureChatSession = useCallback(async (): Promise<{ - client: GeminiClient | null; - chat: GeminiChat | null; - }> => { - const currentClient = geminiClientRef.current; - if (!currentClient) { - const errorMsg = 'Gemini client is not available.'; - setInitError(errorMsg); - addItem({ type: MessageType.ERROR, text: errorMsg }, Date.now()); - return { client: null, chat: null }; - } - - if (!chatSessionRef.current) { - try { - chatSessionRef.current = await currentClient.startChat(); - } catch (err: unknown) { - const errorMsg = `Failed to start chat: ${getErrorMessage(err)}`; - setInitError(errorMsg); - addItem({ type: MessageType.ERROR, text: errorMsg }, Date.now()); - return { client: currentClient, chat: null }; - } - } - return { client: currentClient, chat: chatSessionRef.current }; - }, [addItem]); - // --- Stream Event Handlers --- const handleContentEvent = useCallback( @@ -444,9 +417,12 @@ export const useGeminiStream = ( return; } - const { client, chat } = await ensureChatSession(); + const client = geminiClientRef.current; - if (!client || !chat) { + if (!client) { + const errorMsg = 'Gemini client is not available.'; + setInitError(errorMsg); + addItem({ type: MessageType.ERROR, text: errorMsg }, Date.now()); return; } @@ -454,7 +430,7 @@ export const useGeminiStream = ( setInitError(null); try { - const stream = client.sendMessageStream(chat, queryToSend, abortSignal); + const stream = client.sendMessageStream(queryToSend, abortSignal); const processingStatus = await processGeminiStreamEvents( stream, userMessageTimestamp, @@ -487,7 +463,6 @@ export const useGeminiStream = ( streamingState, setShowHelp, prepareQueryForGemini, - ensureChatSession, processGeminiStreamEvents, pendingHistoryItemRef, addItem, diff --git a/packages/core/src/config/config.test.ts b/packages/core/src/config/config.test.ts index c3c46659..9b4b2664 100644 --- a/packages/core/src/config/config.test.ts +++ b/packages/core/src/config/config.test.ts @@ -35,6 +35,7 @@ vi.mock('../tools/memoryTool', () => ({ setGeminiMdFilename: vi.fn(), getCurrentGeminiMdFilename: vi.fn(() => 'GEMINI.md'), // Mock the original filename DEFAULT_CONTEXT_FILENAME: 'GEMINI.md', + GEMINI_CONFIG_DIR: '.gemini', })); describe('Server Config (config.ts)', () => { diff --git a/packages/core/src/core/client.ts b/packages/core/src/core/client.ts index 732126cb..fcad1ef0 100644 --- a/packages/core/src/core/client.ts +++ b/packages/core/src/core/client.ts @@ -27,6 +27,7 @@ import { GeminiChat } from './geminiChat.js'; import { retryWithBackoff } from '../utils/retry.js'; export class GeminiClient { + private chat: Promise; private client: GoogleGenAI; private model: string; private generateContentConfig: GenerateContentConfig = { @@ -50,6 +51,11 @@ export class GeminiClient { }, }); this.model = config.getModel(); + this.chat = this.startChat(); + } + + getChat(): Promise { + return this.chat; } private async getEnvironment(): Promise { @@ -114,12 +120,12 @@ export class GeminiClient { return initialParts; } - async startChat(): Promise { + private async startChat(extraHistory?: Content[]): Promise { const envParts = await this.getEnvironment(); const toolRegistry = await this.config.getToolRegistry(); const toolDeclarations = toolRegistry.getFunctionDeclarations(); const tools: Tool[] = [{ functionDeclarations: toolDeclarations }]; - const history: Content[] = [ + const initialHistory: Content[] = [ { role: 'user', parts: envParts, @@ -129,6 +135,7 @@ export class GeminiClient { parts: [{ text: 'Got it. Thanks for the context!' }], }, ]; + const history = initialHistory.concat(extraHistory ?? []); try { const userMemory = this.config.getUserMemory(); const systemInstruction = getCoreSystemPrompt(userMemory); @@ -157,7 +164,6 @@ export class GeminiClient { } async *sendMessageStream( - chat: GeminiChat, request: PartListUnion, signal: AbortSignal, turns: number = this.MAX_TURNS, @@ -166,6 +172,7 @@ export class GeminiClient { return; } + const chat = await this.chat; const turn = new Turn(chat); const resultStream = turn.run(request, signal); for await (const event of resultStream) { @@ -175,7 +182,7 @@ export class GeminiClient { const nextSpeakerCheck = await checkNextSpeaker(chat, this, signal); if (nextSpeakerCheck?.next_speaker === 'model') { const nextRequest = [{ text: 'Please continue.' }]; - yield* this.sendMessageStream(chat, nextRequest, signal, turns - 1); + yield* this.sendMessageStream(nextRequest, signal, turns - 1); } } } diff --git a/packages/core/src/tools/tool-registry.ts b/packages/core/src/tools/tool-registry.ts index 21aec687..12aa1a83 100644 --- a/packages/core/src/tools/tool-registry.ts +++ b/packages/core/src/tools/tool-registry.ts @@ -56,10 +56,10 @@ Signal: Signal number or \`(none)\` if no signal was received. let stdout = ''; let stderr = ''; child.stdout.on('data', (data) => { - stdout += data.toString(); + stdout += data?.toString(); }); child.stderr.on('data', (data) => { - stderr += data.toString(); + stderr += data?.toString(); }); let error: Error | null = null; child.on('error', (err: Error) => {