diff --git a/packages/cli/src/ui/App.tsx b/packages/cli/src/ui/App.tsx
index b458a822..bf8c2abb 100644
--- a/packages/cli/src/ui/App.tsx
+++ b/packages/cli/src/ui/App.tsx
@@ -48,7 +48,7 @@ import {
 } from '@gemini-cli/core';
 import { useLogger } from './hooks/useLogger.js';
 import { StreamingContext } from './contexts/StreamingContext.js';
-import { SessionProvider } from './contexts/SessionContext.js';
+import { SessionStatsProvider } from './contexts/SessionContext.js';
 import { useGitBranchName } from './hooks/useGitBranchName.js';
 
 const CTRL_C_PROMPT_DURATION_MS = 1000;
@@ -60,9 +60,9 @@ interface AppProps {
 }
 
 export const AppWrapper = (props: AppProps) => (
-  <SessionProvider>
+  <SessionStatsProvider>
     <App {...props} />
-  </SessionProvider>
+  </SessionStatsProvider>
 );
 
 const App = ({ config, settings, startupWarnings = [] }: AppProps) => {
diff --git a/packages/cli/src/ui/contexts/SessionContext.test.tsx b/packages/cli/src/ui/contexts/SessionContext.test.tsx
index 3b5454cf..fedf3d74 100644
--- a/packages/cli/src/ui/contexts/SessionContext.test.tsx
+++ b/packages/cli/src/ui/contexts/SessionContext.test.tsx
@@ -4,26 +4,181 @@
  * SPDX-License-Identifier: Apache-2.0
  */
 
+import { type MutableRefObject } from 'react';
 import { render } from 'ink-testing-library';
-import { Text } from 'ink';
-import { SessionProvider, useSession } from './SessionContext.js';
-import { describe, it, expect } from 'vitest';
+import { act } from 'react-dom/test-utils';
+import { SessionStatsProvider, useSessionStats } from './SessionContext.js';
+import { describe, it, expect, vi } from 'vitest';
+import { GenerateContentResponseUsageMetadata } from '@google/genai';
 
-const TestComponent = () => {
-  const { startTime } = useSession();
-  return <Text>{startTime.toISOString()}</Text>;
+// Mock data that simulates what the Gemini API would return.
+const mockMetadata1: GenerateContentResponseUsageMetadata = {
+  promptTokenCount: 100,
+  candidatesTokenCount: 200,
+  totalTokenCount: 300,
+  cachedContentTokenCount: 50,
+  toolUsePromptTokenCount: 10,
+  thoughtsTokenCount: 20,
 };
 
-describe('SessionContext', () => {
-  it('should provide a start time', () => {
-    const { lastFrame } = render(
-      <SessionProvider>
-        <TestComponent />
-      </SessionProvider>,
+const mockMetadata2: GenerateContentResponseUsageMetadata = {
+  promptTokenCount: 10,
+  candidatesTokenCount: 20,
+  totalTokenCount: 30,
+  cachedContentTokenCount: 5,
+  toolUsePromptTokenCount: 1,
+  thoughtsTokenCount: 2,
+};
+
+/**
+ * A test harness component that uses the hook and exposes the context value
+ * via a mutable ref. This allows us to interact with the context's functions
+ * and assert against its state directly in our tests.
+ */
+const TestHarness = ({
+  contextRef,
+}: {
+  contextRef: MutableRefObject<ReturnType<typeof useSessionStats> | undefined>;
+}) => {
+  contextRef.current = useSessionStats();
+  return null;
+};
+
+describe('SessionStatsContext', () => {
+  it('should provide the correct initial state', () => {
+    const contextRef: MutableRefObject<
+      ReturnType<typeof useSessionStats> | undefined
+    > = { current: undefined };
+
+    render(
+      <SessionStatsProvider>
+        <TestHarness contextRef={contextRef} />
+      </SessionStatsProvider>,
     );
-    const frameText = lastFrame();
-    // Check if the output is a valid ISO string, which confirms it's a Date object.
-    expect(new Date(frameText!).toString()).not.toBe('Invalid Date');
+    const stats = contextRef.current?.stats;
+
+    expect(stats?.sessionStartTime).toBeInstanceOf(Date);
+    expect(stats?.lastTurn).toBeNull();
+    expect(stats?.cumulative.turnCount).toBe(0);
+    expect(stats?.cumulative.totalTokenCount).toBe(0);
+    expect(stats?.cumulative.promptTokenCount).toBe(0);
+  });
+
+  it('should increment turnCount when startNewTurn is called', () => {
+    const contextRef: MutableRefObject<
+      ReturnType<typeof useSessionStats> | undefined
+    > = { current: undefined };
+
+    render(
+      <SessionStatsProvider>
+        <TestHarness contextRef={contextRef} />
+      </SessionStatsProvider>,
+    );
+
+    act(() => {
+      contextRef.current?.startNewTurn();
+    });
+
+    const stats = contextRef.current?.stats;
+    expect(stats?.cumulative.turnCount).toBe(1);
+    // Ensure token counts are unaffected
+    expect(stats?.cumulative.totalTokenCount).toBe(0);
+  });
+
+  it('should aggregate token usage correctly when addUsage is called', () => {
+    const contextRef: MutableRefObject<
+      ReturnType<typeof useSessionStats> | undefined
+    > = { current: undefined };
+
+    render(
+      <SessionStatsProvider>
+        <TestHarness contextRef={contextRef} />
+      </SessionStatsProvider>,
+    );
+
+    act(() => {
+      contextRef.current?.addUsage(mockMetadata1);
+    });
+
+    const stats = contextRef.current?.stats;
+
+    // Check that token counts are updated
+    expect(stats?.cumulative.totalTokenCount).toBe(
+      mockMetadata1.totalTokenCount ?? 0,
+    );
+    expect(stats?.cumulative.promptTokenCount).toBe(
+      mockMetadata1.promptTokenCount ?? 0,
+    );
+
+    // Check that turn count is NOT incremented
+    expect(stats?.cumulative.turnCount).toBe(0);
+
+    // Check that lastTurn is updated
+    expect(stats?.lastTurn?.metadata).toEqual(mockMetadata1);
+  });
+
+  it('should correctly track a full logical turn with multiple API calls', () => {
+    const contextRef: MutableRefObject<
+      ReturnType<typeof useSessionStats> | undefined
+    > = { current: undefined };
+
+    render(
+      <SessionStatsProvider>
+        <TestHarness contextRef={contextRef} />
+      </SessionStatsProvider>,
+    );
+
+    // 1. User starts a new turn
+    act(() => {
+      contextRef.current?.startNewTurn();
+    });
+
+    // 2. First API call (e.g., prompt with a tool request)
+    act(() => {
+      contextRef.current?.addUsage(mockMetadata1);
+    });
+
+    // 3. Second API call (e.g., sending tool response back)
+    act(() => {
+      contextRef.current?.addUsage(mockMetadata2);
+    });
+
+    const stats = contextRef.current?.stats;
+
+    // Turn count should only be 1
+    expect(stats?.cumulative.turnCount).toBe(1);
+
+    // These fields should be the SUM of both calls
+    expect(stats?.cumulative.totalTokenCount).toBe(330); // 300 + 30
+    expect(stats?.cumulative.candidatesTokenCount).toBe(220); // 200 + 20
+    expect(stats?.cumulative.thoughtsTokenCount).toBe(22); // 20 + 2
+
+    // These fields should ONLY be from the FIRST call, because isNewTurnForAggregation was true
+    expect(stats?.cumulative.promptTokenCount).toBe(100);
+    expect(stats?.cumulative.cachedContentTokenCount).toBe(50);
+    expect(stats?.cumulative.toolUsePromptTokenCount).toBe(10);
+
+    // Last turn should hold the metadata from the most recent call
+    expect(stats?.lastTurn?.metadata).toEqual(mockMetadata2);
+  });
+
+  it('should throw an error when useSessionStats is used outside of a provider', () => {
+    // Suppress the expected console error during this test.
+    const errorSpy = vi.spyOn(console, 'error').mockImplementation(() => {});
+
+    const contextRef = { current: undefined };
+
+    // We expect rendering to fail, which React will catch and log as an error.
+    render(<TestHarness contextRef={contextRef} />);
+
+    // Assert that the first argument of the first call to console.error
+    // contains the expected message. This is more robust than checking
+    // the exact arguments, which can be affected by React/JSDOM internals.
+    expect(errorSpy.mock.calls[0][0]).toContain(
+      'useSessionStats must be used within a SessionStatsProvider',
+    );
+
+    errorSpy.mockRestore();
   });
 });
diff --git a/packages/cli/src/ui/contexts/SessionContext.tsx b/packages/cli/src/ui/contexts/SessionContext.tsx
index c511aa46..0549e3e1 100644
--- a/packages/cli/src/ui/contexts/SessionContext.tsx
+++ b/packages/cli/src/ui/contexts/SessionContext.tsx
@@ -4,35 +4,140 @@
  * SPDX-License-Identifier: Apache-2.0
  */
 
-import React, { createContext, useContext, useState, useMemo } from 'react';
+import React, {
+  createContext,
+  useContext,
+  useState,
+  useMemo,
+  useCallback,
+} from 'react';
 
-interface SessionContextType {
-  startTime: Date;
+import { type GenerateContentResponseUsageMetadata } from '@google/genai';
+
+// --- Interface Definitions ---
+
+interface CumulativeStats {
+  turnCount: number;
+  promptTokenCount: number;
+  candidatesTokenCount: number;
+  totalTokenCount: number;
+  cachedContentTokenCount: number;
+  toolUsePromptTokenCount: number;
+  thoughtsTokenCount: number;
 }
 
-const SessionContext = createContext<SessionContextType | null>(null);
+interface LastTurnStats {
+  metadata: GenerateContentResponseUsageMetadata;
+  // TODO(abhipatel12): Add apiTime, etc. here in a future step.
+}
 
-export const SessionProvider: React.FC<{ children: React.ReactNode }> = ({
+interface SessionStatsState {
+  sessionStartTime: Date;
+  cumulative: CumulativeStats;
+  lastTurn: LastTurnStats | null;
+  isNewTurnForAggregation: boolean;
+}
+
+// Defines the final "value" of our context, including the state
+// and the functions to update it.
+interface SessionStatsContextValue {
+  stats: SessionStatsState;
+  startNewTurn: () => void;
+  addUsage: (metadata: GenerateContentResponseUsageMetadata) => void;
+}
+
+// --- Context Definition ---
+
+const SessionStatsContext = createContext<SessionStatsContextValue | undefined>(
+  undefined,
+);
+
+// --- Provider Component ---
+
+export const SessionStatsProvider: React.FC<{ children: React.ReactNode }> = ({
   children,
 }) => {
-  const [startTime] = useState(new Date());
+  const [stats, setStats] = useState<SessionStatsState>({
+    sessionStartTime: new Date(),
+    cumulative: {
+      turnCount: 0,
+      promptTokenCount: 0,
+      candidatesTokenCount: 0,
+      totalTokenCount: 0,
+      cachedContentTokenCount: 0,
+      toolUsePromptTokenCount: 0,
+      thoughtsTokenCount: 0,
+    },
+    lastTurn: null,
+    isNewTurnForAggregation: true,
+  });
+
+  // A single, internal worker function to handle all metadata aggregation.
+  const aggregateTokens = useCallback(
+    (metadata: GenerateContentResponseUsageMetadata) => {
+      setStats((prevState) => {
+        const { isNewTurnForAggregation } = prevState;
+        const newCumulative = { ...prevState.cumulative };
+
+        newCumulative.candidatesTokenCount +=
+          metadata.candidatesTokenCount ?? 0;
+        newCumulative.thoughtsTokenCount += metadata.thoughtsTokenCount ?? 0;
+        newCumulative.totalTokenCount += metadata.totalTokenCount ?? 0;
+
+        if (isNewTurnForAggregation) {
+          newCumulative.promptTokenCount += metadata.promptTokenCount ?? 0;
+          newCumulative.cachedContentTokenCount +=
+            metadata.cachedContentTokenCount ?? 0;
+          newCumulative.toolUsePromptTokenCount +=
+            metadata.toolUsePromptTokenCount ?? 0;
+        }
+
+        return {
+          ...prevState,
+          cumulative: newCumulative,
+          lastTurn: { metadata },
+          isNewTurnForAggregation: false,
+        };
+      });
+    },
+    [],
+  );
+
+  const startNewTurn = useCallback(() => {
+    setStats((prevState) => ({
+      ...prevState,
+      cumulative: {
+        ...prevState.cumulative,
+        turnCount: prevState.cumulative.turnCount + 1,
+      },
+      isNewTurnForAggregation: true,
+    }));
+  }, []);
 
   const value = useMemo(
     () => ({
-      startTime,
+      stats,
+      startNewTurn,
+      addUsage: aggregateTokens,
     }),
-    [startTime],
+    [stats, startNewTurn, aggregateTokens],
   );
 
   return (
-    <SessionContext.Provider value={value}>{children}</SessionContext.Provider>
+    <SessionStatsContext.Provider value={value}>
+      {children}
+    </SessionStatsContext.Provider>
   );
 };
 
-export const useSession = () => {
-  const context = useContext(SessionContext);
-  if (!context) {
-    throw new Error('useSession must be used within a SessionProvider');
+// --- Consumer Hook ---
+
+export const useSessionStats = () => {
+  const context = useContext(SessionStatsContext);
+  if (context === undefined) {
+    throw new Error(
+      'useSessionStats must be used within a SessionStatsProvider',
+    );
   }
   return context;
 };
diff --git a/packages/cli/src/ui/hooks/slashCommandProcessor.test.ts b/packages/cli/src/ui/hooks/slashCommandProcessor.test.ts
index aa1e701f..cc6be49e 100644
--- a/packages/cli/src/ui/hooks/slashCommandProcessor.test.ts
+++ b/packages/cli/src/ui/hooks/slashCommandProcessor.test.ts
@@ -61,13 +61,13 @@ import {
   MCPServerStatus,
   getMCPServerStatus,
 } from '@gemini-cli/core';
-import { useSession } from '../contexts/SessionContext.js';
+import { useSessionStats } from '../contexts/SessionContext.js';
 import * as ShowMemoryCommandModule from './useShowMemoryCommand.js';
 import { GIT_COMMIT_INFO } from '../../generated/git-commit.js';
 
 vi.mock('../contexts/SessionContext.js', () => ({
-  useSession: vi.fn(),
+  useSessionStats: vi.fn(),
 }));
 
 vi.mock('./useShowMemoryCommand.js', () => ({
@@ -89,7 +89,7 @@ describe('useSlashCommandProcessor', () => {
   let mockPerformMemoryRefresh: ReturnType<typeof vi.fn>;
   let mockConfig: Config;
   let mockCorgiMode: ReturnType<typeof vi.fn>;
-  const mockUseSession = useSession as Mock;
+  const mockUseSessionStats = useSessionStats as Mock;
 
   beforeEach(() => {
     mockAddItem = vi.fn();
@@ -105,8 +105,19 @@
       getModel: vi.fn(() => 'test-model'),
     } as unknown as Config;
     mockCorgiMode = vi.fn();
-    mockUseSession.mockReturnValue({
-      startTime: new Date('2025-01-01T00:00:00.000Z'),
+    mockUseSessionStats.mockReturnValue({
+      stats: {
+        sessionStartTime: new Date('2025-01-01T00:00:00.000Z'),
+        cumulative: {
+          turnCount: 0,
+          promptTokenCount: 0,
+          candidatesTokenCount: 0,
+          totalTokenCount: 0,
+          cachedContentTokenCount: 0,
+          toolUsePromptTokenCount: 0,
+          thoughtsTokenCount: 0,
+        },
+      },
     });
 
     (open as Mock).mockClear();
@@ -240,29 +251,55 @@ describe('useSlashCommandProcessor', () => {
   });
 
   describe('/stats command', () => {
-    it('should show the session duration', async () => {
-      const { handleSlashCommand } = getProcessor();
-      let commandResult: SlashCommandActionReturn | boolean = false;
-
-      // Mock current time
-      const mockDate = new Date('2025-01-01T00:01:05.000Z');
-      vi.setSystemTime(mockDate);
-
-      await act(async () => {
-        commandResult = handleSlashCommand('/stats');
+    it('should show detailed session statistics', async () => {
+      // Arrange
+      mockUseSessionStats.mockReturnValue({
+        stats: {
+          sessionStartTime: new Date('2025-01-01T00:00:00.000Z'),
+          cumulative: {
+            totalTokenCount: 900,
+            promptTokenCount: 200,
+            candidatesTokenCount: 400,
+            cachedContentTokenCount: 100,
+            turnCount: 1,
+            toolUsePromptTokenCount: 50,
+            thoughtsTokenCount: 150,
+          },
+        },
}); + const { handleSlashCommand } = getProcessor(); + const mockDate = new Date('2025-01-01T01:02:03.000Z'); // 1h 2m 3s duration + vi.setSystemTime(mockDate); + + // Act + await act(async () => { + handleSlashCommand('/stats'); + }); + + // Assert + const expectedContent = [ + ` ⎿ Total duration (wall): 1h 2m 3s`, + ` Total Token usage:`, + ` Turns: 1`, + ` Total: 900`, + ` ├─ Input: 200`, + ` ├─ Output: 400`, + ` ├─ Cached: 100`, + ` └─ Overhead: 200`, + ` ├─ Model thoughts: 150`, + ` └─ Tool-use prompts: 50`, + ].join('\n'); + expect(mockAddItem).toHaveBeenNthCalledWith( - 2, + 2, // Called after the user message expect.objectContaining({ type: MessageType.INFO, - text: 'Session duration: 1m 5s', + text: expectedContent, }), expect.any(Number), ); - expect(commandResult).toBe(true); - // Restore system time vi.useRealTimers(); }); }); diff --git a/packages/cli/src/ui/hooks/slashCommandProcessor.ts b/packages/cli/src/ui/hooks/slashCommandProcessor.ts index daec0379..6159fe89 100644 --- a/packages/cli/src/ui/hooks/slashCommandProcessor.ts +++ b/packages/cli/src/ui/hooks/slashCommandProcessor.ts @@ -11,7 +11,7 @@ import process from 'node:process'; import { UseHistoryManagerReturn } from './useHistoryManager.js'; import { Config, MCPServerStatus, getMCPServerStatus } from '@gemini-cli/core'; import { Message, MessageType, HistoryItemWithoutId } from '../types.js'; -import { useSession } from '../contexts/SessionContext.js'; +import { useSessionStats } from '../contexts/SessionContext.js'; import { createShowMemoryAction } from './useShowMemoryCommand.js'; import { GIT_COMMIT_INFO } from '../../generated/git-commit.js'; import { formatMemoryUsage } from '../utils/formatters.js'; @@ -50,8 +50,7 @@ export const useSlashCommandProcessor = ( toggleCorgiMode: () => void, showToolDescriptions: boolean = false, ) => { - const session = useSession(); - + const session = useSessionStats(); const addMessage = useCallback( (message: Message) => { // Convert Message to HistoryItemWithoutId @@ -147,7 +146,9 @@ export const useSlashCommandProcessor = ( description: 'check session stats', action: (_mainCommand, _subCommand, _args) => { const now = new Date(); - const duration = now.getTime() - session.startTime.getTime(); + const { sessionStartTime, cumulative } = session.stats; + + const duration = now.getTime() - sessionStartTime.getTime(); const durationInSeconds = Math.floor(duration / 1000); const hours = Math.floor(durationInSeconds / 3600); const minutes = Math.floor((durationInSeconds % 3600) / 60); @@ -161,9 +162,25 @@ export const useSlashCommandProcessor = ( .filter(Boolean) .join(' '); + const overheadTotal = + cumulative.thoughtsTokenCount + cumulative.toolUsePromptTokenCount; + + const statsContent = [ + ` ⎿ Total duration (wall): ${durationString}`, + ` Total Token usage:`, + ` Turns: ${cumulative.turnCount.toLocaleString()}`, + ` Total: ${cumulative.totalTokenCount.toLocaleString()}`, + ` ├─ Input: ${cumulative.promptTokenCount.toLocaleString()}`, + ` ├─ Output: ${cumulative.candidatesTokenCount.toLocaleString()}`, + ` ├─ Cached: ${cumulative.cachedContentTokenCount.toLocaleString()}`, + ` └─ Overhead: ${overheadTotal.toLocaleString()}`, + ` ├─ Model thoughts: ${cumulative.thoughtsTokenCount.toLocaleString()}`, + ` └─ Tool-use prompts: ${cumulative.toolUsePromptTokenCount.toLocaleString()}`, + ].join('\n'); + addMessage({ type: MessageType.INFO, - content: `Session duration: ${durationString}`, + content: statsContent, timestamp: new Date(), }); }, @@ -477,7 +494,7 @@ Add any other 
context about the problem here. toggleCorgiMode, config, showToolDescriptions, - session.startTime, + session, ], ); diff --git a/packages/cli/src/ui/hooks/useGeminiStream.test.tsx b/packages/cli/src/ui/hooks/useGeminiStream.test.tsx index f41f7f9c..ed0f2aac 100644 --- a/packages/cli/src/ui/hooks/useGeminiStream.test.tsx +++ b/packages/cli/src/ui/hooks/useGeminiStream.test.tsx @@ -96,6 +96,15 @@ vi.mock('./useLogger.js', () => ({ }), })); +const mockStartNewTurn = vi.fn(); +const mockAddUsage = vi.fn(); +vi.mock('../contexts/SessionContext.js', () => ({ + useSessionStats: vi.fn(() => ({ + startNewTurn: mockStartNewTurn, + addUsage: mockAddUsage, + })), +})); + vi.mock('./slashCommandProcessor.js', () => ({ handleSlashCommand: vi.fn().mockReturnValue(false), })); @@ -531,4 +540,63 @@ describe('useGeminiStream', () => { }); }); }); + + describe('Session Stats Integration', () => { + it('should call startNewTurn and addUsage for a simple prompt', async () => { + const mockMetadata = { totalTokenCount: 123 }; + const mockStream = (async function* () { + yield { type: 'content', value: 'Response' }; + yield { type: 'usage_metadata', value: mockMetadata }; + })(); + mockSendMessageStream.mockReturnValue(mockStream); + + const { result } = renderTestHook(); + + await act(async () => { + await result.current.submitQuery('Hello, world!'); + }); + + expect(mockStartNewTurn).toHaveBeenCalledTimes(1); + expect(mockAddUsage).toHaveBeenCalledTimes(1); + expect(mockAddUsage).toHaveBeenCalledWith(mockMetadata); + }); + + it('should only call addUsage for a tool continuation prompt', async () => { + const mockMetadata = { totalTokenCount: 456 }; + const mockStream = (async function* () { + yield { type: 'content', value: 'Final Answer' }; + yield { type: 'usage_metadata', value: mockMetadata }; + })(); + mockSendMessageStream.mockReturnValue(mockStream); + + const { result } = renderTestHook(); + + await act(async () => { + await result.current.submitQuery([{ text: 'tool response' }], { + isContinuation: true, + }); + }); + + expect(mockStartNewTurn).not.toHaveBeenCalled(); + expect(mockAddUsage).toHaveBeenCalledTimes(1); + expect(mockAddUsage).toHaveBeenCalledWith(mockMetadata); + }); + + it('should not call addUsage if the stream contains no usage metadata', async () => { + // Arrange: A stream that yields content but never a usage_metadata event + const mockStream = (async function* () { + yield { type: 'content', value: 'Some response text' }; + })(); + mockSendMessageStream.mockReturnValue(mockStream); + + const { result } = renderTestHook(); + + await act(async () => { + await result.current.submitQuery('Query with no usage data'); + }); + + expect(mockStartNewTurn).toHaveBeenCalledTimes(1); + expect(mockAddUsage).not.toHaveBeenCalled(); + }); + }); }); diff --git a/packages/cli/src/ui/hooks/useGeminiStream.ts b/packages/cli/src/ui/hooks/useGeminiStream.ts index 2b47ae6f..bad9f78a 100644 --- a/packages/cli/src/ui/hooks/useGeminiStream.ts +++ b/packages/cli/src/ui/hooks/useGeminiStream.ts @@ -42,6 +42,7 @@ import { TrackedCompletedToolCall, TrackedCancelledToolCall, } from './useReactToolScheduler.js'; +import { useSessionStats } from '../contexts/SessionContext.js'; export function mergePartListUnions(list: PartListUnion[]): PartListUnion { const resultParts: PartListUnion = []; @@ -82,6 +83,7 @@ export const useGeminiStream = ( const [pendingHistoryItemRef, setPendingHistoryItem] = useStateAndRef(null); const logger = useLogger(); + const { startNewTurn, addUsage } = useSessionStats(); const 
[toolCalls, scheduleToolCalls, markToolsAsSubmitted] = useReactToolScheduler( @@ -390,6 +392,9 @@ export const useGeminiStream = ( case ServerGeminiEventType.ChatCompressed: handleChatCompressionEvent(); break; + case ServerGeminiEventType.UsageMetadata: + addUsage(event.value); + break; case ServerGeminiEventType.ToolCallConfirmation: case ServerGeminiEventType.ToolCallResponse: // do nothing @@ -412,11 +417,12 @@ export const useGeminiStream = ( handleErrorEvent, scheduleToolCalls, handleChatCompressionEvent, + addUsage, ], ); const submitQuery = useCallback( - async (query: PartListUnion) => { + async (query: PartListUnion, options?: { isContinuation: boolean }) => { if ( streamingState === StreamingState.Responding || streamingState === StreamingState.WaitingForConfirmation @@ -426,6 +432,10 @@ export const useGeminiStream = ( const userMessageTimestamp = Date.now(); setShowHelp(false); + if (!options?.isContinuation) { + startNewTurn(); + } + abortControllerRef.current = new AbortController(); const abortSignal = abortControllerRef.current.signal; @@ -491,6 +501,7 @@ export const useGeminiStream = ( setPendingHistoryItem, setInitError, geminiClient, + startNewTurn, ], ); @@ -576,7 +587,9 @@ export const useGeminiStream = ( ); markToolsAsSubmitted(callIdsToMarkAsSubmitted); - submitQuery(mergePartListUnions(responsesToSend)); + submitQuery(mergePartListUnions(responsesToSend), { + isContinuation: true, + }); } }, [ toolCalls, diff --git a/packages/core/src/core/client.test.ts b/packages/core/src/core/client.test.ts index cbbbd113..58ad5dbd 100644 --- a/packages/core/src/core/client.test.ts +++ b/packages/core/src/core/client.test.ts @@ -13,14 +13,32 @@ import { GoogleGenAI, } from '@google/genai'; import { GeminiClient } from './client.js'; +import { ContentGenerator } from './contentGenerator.js'; +import { GeminiChat } from './geminiChat.js'; import { Config } from '../config/config.js'; +import { Turn } from './turn.js'; // --- Mocks --- const mockChatCreateFn = vi.fn(); const mockGenerateContentFn = vi.fn(); const mockEmbedContentFn = vi.fn(); +const mockTurnRunFn = vi.fn(); vi.mock('@google/genai'); +vi.mock('./turn', () => { + // Define a mock class that has the same shape as the real Turn + class MockTurn { + pendingToolCalls = []; + // The run method is a property that holds our mock function + run = mockTurnRunFn; + + constructor() { + // The constructor can be empty or do some mock setup + } + } + // Export the mock class as 'Turn' + return { Turn: MockTurn }; +}); vi.mock('../config/config.js'); vi.mock('./prompts'); @@ -237,4 +255,44 @@ describe('Gemini Client (client.ts)', () => { expect(mockChat.addHistory).toHaveBeenCalledWith(newContent); }); }); + + describe('sendMessageStream', () => { + it('should return the turn instance after the stream is complete', async () => { + // Arrange + const mockStream = (async function* () { + yield { type: 'content', value: 'Hello' }; + })(); + mockTurnRunFn.mockReturnValue(mockStream); + + const mockChat: Partial = { + addHistory: vi.fn(), + getHistory: vi.fn().mockReturnValue([]), + }; + client['chat'] = Promise.resolve(mockChat as GeminiChat); + + const mockGenerator: Partial = { + countTokens: vi.fn().mockResolvedValue({ totalTokens: 0 }), + }; + client['contentGenerator'] = mockGenerator as ContentGenerator; + + // Act + const stream = client.sendMessageStream( + [{ text: 'Hi' }], + new AbortController().signal, + ); + + // Consume the stream manually to get the final return value. 
+ let finalResult: Turn | undefined; + while (true) { + const result = await stream.next(); + if (result.done) { + finalResult = result.value; + break; + } + } + + // Assert + expect(finalResult).toBeInstanceOf(Turn); + }); + }); }); diff --git a/packages/core/src/core/client.ts b/packages/core/src/core/client.ts index 8b921ab1..1b953d30 100644 --- a/packages/core/src/core/client.ts +++ b/packages/core/src/core/client.ts @@ -174,9 +174,10 @@ export class GeminiClient { request: PartListUnion, signal: AbortSignal, turns: number = this.MAX_TURNS, - ): AsyncGenerator { + ): AsyncGenerator { if (!turns) { - return; + const chat = await this.chat; + return new Turn(chat); } const compressed = await this.tryCompressChat(); @@ -193,9 +194,12 @@ export class GeminiClient { const nextSpeakerCheck = await checkNextSpeaker(chat, this, signal); if (nextSpeakerCheck?.next_speaker === 'model') { const nextRequest = [{ text: 'Please continue.' }]; + // This recursive call's events will be yielded out, but the final + // turn object will be from the top-level call. yield* this.sendMessageStream(nextRequest, signal, turns - 1); } } + return turn; } private _logApiRequest(model: string, inputTokenCount: number): void { @@ -423,6 +427,10 @@ export class GeminiClient { }); const result = await retryWithBackoff(apiCall); + console.log( + 'Raw API Response in client.ts:', + JSON.stringify(result, null, 2), + ); const durationMs = Date.now() - startTime; this._logApiResponse(modelToUse, durationMs, attempt, result); return result; diff --git a/packages/core/src/core/turn.test.ts b/packages/core/src/core/turn.test.ts index 8fb3a4c1..2217e5da 100644 --- a/packages/core/src/core/turn.test.ts +++ b/packages/core/src/core/turn.test.ts @@ -10,8 +10,14 @@ import { GeminiEventType, ServerGeminiToolCallRequestEvent, ServerGeminiErrorEvent, + ServerGeminiUsageMetadataEvent, } from './turn.js'; -import { GenerateContentResponse, Part, Content } from '@google/genai'; +import { + GenerateContentResponse, + Part, + Content, + GenerateContentResponseUsageMetadata, +} from '@google/genai'; import { reportError } from '../utils/errorReporting.js'; import { GeminiChat } from './geminiChat.js'; @@ -49,6 +55,24 @@ describe('Turn', () => { }; let mockChatInstance: MockedChatInstance; + const mockMetadata1: GenerateContentResponseUsageMetadata = { + promptTokenCount: 10, + candidatesTokenCount: 20, + totalTokenCount: 30, + cachedContentTokenCount: 5, + toolUsePromptTokenCount: 2, + thoughtsTokenCount: 3, + }; + + const mockMetadata2: GenerateContentResponseUsageMetadata = { + promptTokenCount: 100, + candidatesTokenCount: 200, + totalTokenCount: 300, + cachedContentTokenCount: 50, + toolUsePromptTokenCount: 20, + thoughtsTokenCount: 30, + }; + beforeEach(() => { vi.resetAllMocks(); mockChatInstance = { @@ -96,6 +120,7 @@ describe('Turn', () => { message: reqParts, config: { abortSignal: expect.any(AbortSignal) }, }); + expect(events).toEqual([ { type: GeminiEventType.Content, value: 'Hello' }, { type: GeminiEventType.Content, value: ' world' }, @@ -208,6 +233,41 @@ describe('Turn', () => { ); }); + it('should yield the last UsageMetadata event from the stream', async () => { + const mockResponseStream = (async function* () { + yield { + candidates: [{ content: { parts: [{ text: 'First response' }] } }], + usageMetadata: mockMetadata1, + } as unknown as GenerateContentResponse; + yield { + functionCalls: [{ name: 'aTool' }], + usageMetadata: mockMetadata2, + } as unknown as GenerateContentResponse; + })(); + 
mockSendMessageStream.mockResolvedValue(mockResponseStream); + + const events = []; + const reqParts: Part[] = [{ text: 'Test metadata' }]; + for await (const event of turn.run( + reqParts, + new AbortController().signal, + )) { + events.push(event); + } + + // There should be a content event, a tool call, and our metadata event + expect(events.length).toBe(3); + + const metadataEvent = events[2] as ServerGeminiUsageMetadataEvent; + expect(metadataEvent.type).toBe(GeminiEventType.UsageMetadata); + + // The value should be the *last* metadata object received. + expect(metadataEvent.value).toEqual(mockMetadata2); + + // Also check the public getter + expect(turn.getUsageMetadata()).toEqual(mockMetadata2); + }); + it('should handle function calls with undefined name or args', async () => { const mockResponseStream = (async function* () { yield { @@ -219,7 +279,6 @@ describe('Turn', () => { } as unknown as GenerateContentResponse; })(); mockSendMessageStream.mockResolvedValue(mockResponseStream); - const events = []; const reqParts: Part[] = [{ text: 'Test undefined tool parts' }]; for await (const event of turn.run( diff --git a/packages/core/src/core/turn.ts b/packages/core/src/core/turn.ts index 637fc19d..34e4a494 100644 --- a/packages/core/src/core/turn.ts +++ b/packages/core/src/core/turn.ts @@ -9,6 +9,7 @@ import { GenerateContentResponse, FunctionCall, FunctionDeclaration, + GenerateContentResponseUsageMetadata, } from '@google/genai'; import { ToolCallConfirmationDetails, @@ -43,6 +44,7 @@ export enum GeminiEventType { UserCancelled = 'user_cancelled', Error = 'error', ChatCompressed = 'chat_compressed', + UsageMetadata = 'usage_metadata', } export interface GeminiErrorEventValue { @@ -100,6 +102,11 @@ export type ServerGeminiChatCompressedEvent = { type: GeminiEventType.ChatCompressed; }; +export type ServerGeminiUsageMetadataEvent = { + type: GeminiEventType.UsageMetadata; + value: GenerateContentResponseUsageMetadata; +}; + // The original union type, now composed of the individual types export type ServerGeminiStreamEvent = | ServerGeminiContentEvent @@ -108,7 +115,8 @@ export type ServerGeminiStreamEvent = | ServerGeminiToolCallConfirmationEvent | ServerGeminiUserCancelledEvent | ServerGeminiErrorEvent - | ServerGeminiChatCompressedEvent; + | ServerGeminiChatCompressedEvent + | ServerGeminiUsageMetadataEvent; // A turn manages the agentic loop turn within the server context. export class Turn { @@ -118,6 +126,7 @@ export class Turn { args: Record; }>; private debugResponses: GenerateContentResponse[]; + private lastUsageMetadata: GenerateContentResponseUsageMetadata | null = null; constructor(private readonly chat: GeminiChat) { this.pendingToolCalls = []; @@ -157,6 +166,18 @@ export class Turn { yield event; } } + + if (resp.usageMetadata) { + this.lastUsageMetadata = + resp.usageMetadata as GenerateContentResponseUsageMetadata; + } + } + + if (this.lastUsageMetadata) { + yield { + type: GeminiEventType.UsageMetadata, + value: this.lastUsageMetadata, + }; } } catch (error) { if (signal.aborted) { @@ -197,4 +218,8 @@ export class Turn { getDebugResponses(): GenerateContentResponse[] { return this.debugResponses; } + + getUsageMetadata(): GenerateContentResponseUsageMetadata | null { + return this.lastUsageMetadata; + } }
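A note on the aggregation rule implemented by SessionStatsProvider above: output-side counts (candidatesTokenCount, thoughtsTokenCount, totalTokenCount) are summed for every API response, while input-side counts (promptTokenCount, cachedContentTokenCount, toolUsePromptTokenCount) are only folded in for the first response after startNewTurn, gated by isNewTurnForAggregation. The standalone TypeScript sketch below restates that rule outside React so it can be read in isolation; reduceUsage and isFirstResponseOfTurn are illustrative names and are not part of this change.

// Illustrative sketch only: restates the aggregation rule from
// SessionStatsProvider without React state. Not part of the diff above.
import { type GenerateContentResponseUsageMetadata } from '@google/genai';

interface CumulativeStats {
  turnCount: number;
  promptTokenCount: number;
  candidatesTokenCount: number;
  totalTokenCount: number;
  cachedContentTokenCount: number;
  toolUsePromptTokenCount: number;
  thoughtsTokenCount: number;
}

// Hypothetical helper: folds one API response into the running totals.
// `isFirstResponseOfTurn` plays the role of `isNewTurnForAggregation`.
function reduceUsage(
  cumulative: CumulativeStats,
  metadata: GenerateContentResponseUsageMetadata,
  isFirstResponseOfTurn: boolean,
): CumulativeStats {
  const next = { ...cumulative };

  // Output-side counts accumulate on every response in the turn.
  next.candidatesTokenCount += metadata.candidatesTokenCount ?? 0;
  next.thoughtsTokenCount += metadata.thoughtsTokenCount ?? 0;
  next.totalTokenCount += metadata.totalTokenCount ?? 0;

  // Input-side counts are taken only from the first response of a logical
  // turn; follow-up calls (e.g. tool continuations) would double-count them.
  if (isFirstResponseOfTurn) {
    next.promptTokenCount += metadata.promptTokenCount ?? 0;
    next.cachedContentTokenCount += metadata.cachedContentTokenCount ?? 0;
    next.toolUsePromptTokenCount += metadata.toolUsePromptTokenCount ?? 0;
  }

  return next;
}

// Replaying the two payloads used in SessionContext.test.tsx within one turn:
const initial: CumulativeStats = {
  turnCount: 1, // startNewTurn() has been called once
  promptTokenCount: 0,
  candidatesTokenCount: 0,
  totalTokenCount: 0,
  cachedContentTokenCount: 0,
  toolUsePromptTokenCount: 0,
  thoughtsTokenCount: 0,
};
const afterFirstCall = reduceUsage(
  initial,
  {
    promptTokenCount: 100,
    candidatesTokenCount: 200,
    totalTokenCount: 300,
    cachedContentTokenCount: 50,
    toolUsePromptTokenCount: 10,
    thoughtsTokenCount: 20,
  },
  true,
);
const afterToolCall = reduceUsage(
  afterFirstCall,
  {
    promptTokenCount: 10,
    candidatesTokenCount: 20,
    totalTokenCount: 30,
    cachedContentTokenCount: 5,
    toolUsePromptTokenCount: 1,
    thoughtsTokenCount: 2,
  },
  false,
);
// afterToolCall.totalTokenCount === 330 and afterToolCall.promptTokenCount === 100,
// matching the "full logical turn" expectations in SessionContext.test.tsx.

This mirrors how the change wires the pieces together: Turn.run() yields a single UsageMetadata event per API call, useGeminiStream forwards it via addUsage, and the /stats command renders the cumulative totals.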