From 7f1252d364ec251a4a76becbcb3f101b361f2656 Mon Sep 17 00:00:00 2001
From: Abhi <43648792+abhipatel12@users.noreply.github.com>
Date: Mon, 9 Jun 2025 20:25:37 -0400
Subject: [PATCH] feat: Display initial token usage metrics in /stats (#879)
---
packages/cli/src/ui/App.tsx | 6 +-
.../src/ui/contexts/SessionContext.test.tsx | 185 ++++++++++++++++--
.../cli/src/ui/contexts/SessionContext.tsx | 131 +++++++++++--
.../ui/hooks/slashCommandProcessor.test.ts | 75 +++++--
.../cli/src/ui/hooks/slashCommandProcessor.ts | 29 ++-
.../cli/src/ui/hooks/useGeminiStream.test.tsx | 68 +++++++
packages/cli/src/ui/hooks/useGeminiStream.ts | 17 +-
packages/core/src/core/client.test.ts | 58 ++++++
packages/core/src/core/client.ts | 12 +-
packages/core/src/core/turn.test.ts | 63 +++++-
packages/core/src/core/turn.ts | 27 ++-
11 files changed, 608 insertions(+), 63 deletions(-)
diff --git a/packages/cli/src/ui/App.tsx b/packages/cli/src/ui/App.tsx
index b458a822..bf8c2abb 100644
--- a/packages/cli/src/ui/App.tsx
+++ b/packages/cli/src/ui/App.tsx
@@ -48,7 +48,7 @@ import {
} from '@gemini-cli/core';
import { useLogger } from './hooks/useLogger.js';
import { StreamingContext } from './contexts/StreamingContext.js';
-import { SessionProvider } from './contexts/SessionContext.js';
+import { SessionStatsProvider } from './contexts/SessionContext.js';
import { useGitBranchName } from './hooks/useGitBranchName.js';
const CTRL_C_PROMPT_DURATION_MS = 1000;
@@ -60,9 +60,9 @@ interface AppProps {
}
export const AppWrapper = (props: AppProps) => (
-
+
-
+
);
const App = ({ config, settings, startupWarnings = [] }: AppProps) => {
diff --git a/packages/cli/src/ui/contexts/SessionContext.test.tsx b/packages/cli/src/ui/contexts/SessionContext.test.tsx
index 3b5454cf..fedf3d74 100644
--- a/packages/cli/src/ui/contexts/SessionContext.test.tsx
+++ b/packages/cli/src/ui/contexts/SessionContext.test.tsx
@@ -4,26 +4,181 @@
* SPDX-License-Identifier: Apache-2.0
*/
+import { type MutableRefObject } from 'react';
import { render } from 'ink-testing-library';
-import { Text } from 'ink';
-import { SessionProvider, useSession } from './SessionContext.js';
-import { describe, it, expect } from 'vitest';
+import { act } from 'react-dom/test-utils';
+import { SessionStatsProvider, useSessionStats } from './SessionContext.js';
+import { describe, it, expect, vi } from 'vitest';
+import { GenerateContentResponseUsageMetadata } from '@google/genai';
-const TestComponent = () => {
- const { startTime } = useSession();
- return {startTime.toISOString()};
+// Mock data that simulates what the Gemini API would return.
+const mockMetadata1: GenerateContentResponseUsageMetadata = {
+ promptTokenCount: 100,
+ candidatesTokenCount: 200,
+ totalTokenCount: 300,
+ cachedContentTokenCount: 50,
+ toolUsePromptTokenCount: 10,
+ thoughtsTokenCount: 20,
};
-describe('SessionContext', () => {
- it('should provide a start time', () => {
- const { lastFrame } = render(
-
-
- ,
+const mockMetadata2: GenerateContentResponseUsageMetadata = {
+ promptTokenCount: 10,
+ candidatesTokenCount: 20,
+ totalTokenCount: 30,
+ cachedContentTokenCount: 5,
+ toolUsePromptTokenCount: 1,
+ thoughtsTokenCount: 2,
+};
+
+/**
+ * A test harness component that uses the hook and exposes the context value
+ * via a mutable ref. This allows us to interact with the context's functions
+ * and assert against its state directly in our tests.
+ */
+const TestHarness = ({
+ contextRef,
+}: {
+ contextRef: MutableRefObject | undefined>;
+}) => {
+ contextRef.current = useSessionStats();
+ return null;
+};
+
+describe('SessionStatsContext', () => {
+ it('should provide the correct initial state', () => {
+ const contextRef: MutableRefObject<
+ ReturnType | undefined
+ > = { current: undefined };
+
+ render(
+
+
+ ,
);
- const frameText = lastFrame();
- // Check if the output is a valid ISO string, which confirms it's a Date object.
- expect(new Date(frameText!).toString()).not.toBe('Invalid Date');
+ const stats = contextRef.current?.stats;
+
+ expect(stats?.sessionStartTime).toBeInstanceOf(Date);
+ expect(stats?.lastTurn).toBeNull();
+ expect(stats?.cumulative.turnCount).toBe(0);
+ expect(stats?.cumulative.totalTokenCount).toBe(0);
+ expect(stats?.cumulative.promptTokenCount).toBe(0);
+ });
+
+ it('should increment turnCount when startNewTurn is called', () => {
+ const contextRef: MutableRefObject<
+ ReturnType | undefined
+ > = { current: undefined };
+
+ render(
+
+
+ ,
+ );
+
+ act(() => {
+ contextRef.current?.startNewTurn();
+ });
+
+ const stats = contextRef.current?.stats;
+ expect(stats?.cumulative.turnCount).toBe(1);
+ // Ensure token counts are unaffected
+ expect(stats?.cumulative.totalTokenCount).toBe(0);
+ });
+
+ it('should aggregate token usage correctly when addUsage is called', () => {
+ const contextRef: MutableRefObject<
+ ReturnType | undefined
+ > = { current: undefined };
+
+ render(
+
+
+ ,
+ );
+
+ act(() => {
+ contextRef.current?.addUsage(mockMetadata1);
+ });
+
+ const stats = contextRef.current?.stats;
+
+ // Check that token counts are updated
+ expect(stats?.cumulative.totalTokenCount).toBe(
+ mockMetadata1.totalTokenCount ?? 0,
+ );
+ expect(stats?.cumulative.promptTokenCount).toBe(
+ mockMetadata1.promptTokenCount ?? 0,
+ );
+
+ // Check that turn count is NOT incremented
+ expect(stats?.cumulative.turnCount).toBe(0);
+
+ // Check that lastTurn is updated
+ expect(stats?.lastTurn?.metadata).toEqual(mockMetadata1);
+ });
+
+ it('should correctly track a full logical turn with multiple API calls', () => {
+ const contextRef: MutableRefObject<
+ ReturnType | undefined
+ > = { current: undefined };
+
+ render(
+
+
+ ,
+ );
+
+ // 1. User starts a new turn
+ act(() => {
+ contextRef.current?.startNewTurn();
+ });
+
+ // 2. First API call (e.g., prompt with a tool request)
+ act(() => {
+ contextRef.current?.addUsage(mockMetadata1);
+ });
+
+ // 3. Second API call (e.g., sending tool response back)
+ act(() => {
+ contextRef.current?.addUsage(mockMetadata2);
+ });
+
+ const stats = contextRef.current?.stats;
+
+ // Turn count should only be 1
+ expect(stats?.cumulative.turnCount).toBe(1);
+
+ // These fields should be the SUM of both calls
+ expect(stats?.cumulative.totalTokenCount).toBe(330); // 300 + 30
+ expect(stats?.cumulative.candidatesTokenCount).toBe(220); // 200 + 20
+ expect(stats?.cumulative.thoughtsTokenCount).toBe(22); // 20 + 2
+
+ // These fields should ONLY be from the FIRST call, because isNewTurnForAggregation was true
+ expect(stats?.cumulative.promptTokenCount).toBe(100);
+ expect(stats?.cumulative.cachedContentTokenCount).toBe(50);
+ expect(stats?.cumulative.toolUsePromptTokenCount).toBe(10);
+
+ // Last turn should hold the metadata from the most recent call
+ expect(stats?.lastTurn?.metadata).toEqual(mockMetadata2);
+ });
+
+ it('should throw an error when useSessionStats is used outside of a provider', () => {
+ // Suppress the expected console error during this test.
+ const errorSpy = vi.spyOn(console, 'error').mockImplementation(() => {});
+
+ const contextRef = { current: undefined };
+
+ // We expect rendering to fail, which React will catch and log as an error.
+ render();
+
+ // Assert that the first argument of the first call to console.error
+ // contains the expected message. This is more robust than checking
+ // the exact arguments, which can be affected by React/JSDOM internals.
+ expect(errorSpy.mock.calls[0][0]).toContain(
+ 'useSessionStats must be used within a SessionStatsProvider',
+ );
+
+ errorSpy.mockRestore();
});
});
diff --git a/packages/cli/src/ui/contexts/SessionContext.tsx b/packages/cli/src/ui/contexts/SessionContext.tsx
index c511aa46..0549e3e1 100644
--- a/packages/cli/src/ui/contexts/SessionContext.tsx
+++ b/packages/cli/src/ui/contexts/SessionContext.tsx
@@ -4,35 +4,140 @@
* SPDX-License-Identifier: Apache-2.0
*/
-import React, { createContext, useContext, useState, useMemo } from 'react';
+import React, {
+ createContext,
+ useContext,
+ useState,
+ useMemo,
+ useCallback,
+} from 'react';
-interface SessionContextType {
- startTime: Date;
+import { type GenerateContentResponseUsageMetadata } from '@google/genai';
+
+// --- Interface Definitions ---
+
+interface CumulativeStats {
+ turnCount: number;
+ promptTokenCount: number;
+ candidatesTokenCount: number;
+ totalTokenCount: number;
+ cachedContentTokenCount: number;
+ toolUsePromptTokenCount: number;
+ thoughtsTokenCount: number;
}
-const SessionContext = createContext(null);
+interface LastTurnStats {
+ metadata: GenerateContentResponseUsageMetadata;
+ // TODO(abhipatel12): Add apiTime, etc. here in a future step.
+}
-export const SessionProvider: React.FC<{ children: React.ReactNode }> = ({
+interface SessionStatsState {
+ sessionStartTime: Date;
+ cumulative: CumulativeStats;
+ lastTurn: LastTurnStats | null;
+ isNewTurnForAggregation: boolean;
+}
+
+// Defines the final "value" of our context, including the state
+// and the functions to update it.
+interface SessionStatsContextValue {
+ stats: SessionStatsState;
+ startNewTurn: () => void;
+ addUsage: (metadata: GenerateContentResponseUsageMetadata) => void;
+}
+
+// --- Context Definition ---
+
+const SessionStatsContext = createContext(
+ undefined,
+);
+
+// --- Provider Component ---
+
+export const SessionStatsProvider: React.FC<{ children: React.ReactNode }> = ({
children,
}) => {
- const [startTime] = useState(new Date());
+ const [stats, setStats] = useState({
+ sessionStartTime: new Date(),
+ cumulative: {
+ turnCount: 0,
+ promptTokenCount: 0,
+ candidatesTokenCount: 0,
+ totalTokenCount: 0,
+ cachedContentTokenCount: 0,
+ toolUsePromptTokenCount: 0,
+ thoughtsTokenCount: 0,
+ },
+ lastTurn: null,
+ isNewTurnForAggregation: true,
+ });
+
+ // A single, internal worker function to handle all metadata aggregation.
+ const aggregateTokens = useCallback(
+ (metadata: GenerateContentResponseUsageMetadata) => {
+ setStats((prevState) => {
+ const { isNewTurnForAggregation } = prevState;
+ const newCumulative = { ...prevState.cumulative };
+
+ newCumulative.candidatesTokenCount +=
+ metadata.candidatesTokenCount ?? 0;
+ newCumulative.thoughtsTokenCount += metadata.thoughtsTokenCount ?? 0;
+ newCumulative.totalTokenCount += metadata.totalTokenCount ?? 0;
+
+ if (isNewTurnForAggregation) {
+ newCumulative.promptTokenCount += metadata.promptTokenCount ?? 0;
+ newCumulative.cachedContentTokenCount +=
+ metadata.cachedContentTokenCount ?? 0;
+ newCumulative.toolUsePromptTokenCount +=
+ metadata.toolUsePromptTokenCount ?? 0;
+ }
+
+ return {
+ ...prevState,
+ cumulative: newCumulative,
+ lastTurn: { metadata },
+ isNewTurnForAggregation: false,
+ };
+ });
+ },
+ [],
+ );
+
+ const startNewTurn = useCallback(() => {
+ setStats((prevState) => ({
+ ...prevState,
+ cumulative: {
+ ...prevState.cumulative,
+ turnCount: prevState.cumulative.turnCount + 1,
+ },
+ isNewTurnForAggregation: true,
+ }));
+ }, []);
const value = useMemo(
() => ({
- startTime,
+ stats,
+ startNewTurn,
+ addUsage: aggregateTokens,
}),
- [startTime],
+ [stats, startNewTurn, aggregateTokens],
);
return (
- {children}
+
+ {children}
+
);
};
-export const useSession = () => {
- const context = useContext(SessionContext);
- if (!context) {
- throw new Error('useSession must be used within a SessionProvider');
+// --- Consumer Hook ---
+
+export const useSessionStats = () => {
+ const context = useContext(SessionStatsContext);
+ if (context === undefined) {
+ throw new Error(
+ 'useSessionStats must be used within a SessionStatsProvider',
+ );
}
return context;
};
diff --git a/packages/cli/src/ui/hooks/slashCommandProcessor.test.ts b/packages/cli/src/ui/hooks/slashCommandProcessor.test.ts
index aa1e701f..cc6be49e 100644
--- a/packages/cli/src/ui/hooks/slashCommandProcessor.test.ts
+++ b/packages/cli/src/ui/hooks/slashCommandProcessor.test.ts
@@ -61,13 +61,13 @@ import {
MCPServerStatus,
getMCPServerStatus,
} from '@gemini-cli/core';
-import { useSession } from '../contexts/SessionContext.js';
+import { useSessionStats } from '../contexts/SessionContext.js';
import * as ShowMemoryCommandModule from './useShowMemoryCommand.js';
import { GIT_COMMIT_INFO } from '../../generated/git-commit.js';
vi.mock('../contexts/SessionContext.js', () => ({
- useSession: vi.fn(),
+ useSessionStats: vi.fn(),
}));
vi.mock('./useShowMemoryCommand.js', () => ({
@@ -89,7 +89,7 @@ describe('useSlashCommandProcessor', () => {
let mockPerformMemoryRefresh: ReturnType;
let mockConfig: Config;
let mockCorgiMode: ReturnType;
- const mockUseSession = useSession as Mock;
+ const mockUseSessionStats = useSessionStats as Mock;
beforeEach(() => {
mockAddItem = vi.fn();
@@ -105,8 +105,19 @@ describe('useSlashCommandProcessor', () => {
getModel: vi.fn(() => 'test-model'),
} as unknown as Config;
mockCorgiMode = vi.fn();
- mockUseSession.mockReturnValue({
- startTime: new Date('2025-01-01T00:00:00.000Z'),
+ mockUseSessionStats.mockReturnValue({
+ stats: {
+ sessionStartTime: new Date('2025-01-01T00:00:00.000Z'),
+ cumulative: {
+ turnCount: 0,
+ promptTokenCount: 0,
+ candidatesTokenCount: 0,
+ totalTokenCount: 0,
+ cachedContentTokenCount: 0,
+ toolUsePromptTokenCount: 0,
+ thoughtsTokenCount: 0,
+ },
+ },
});
(open as Mock).mockClear();
@@ -240,29 +251,55 @@ describe('useSlashCommandProcessor', () => {
});
describe('/stats command', () => {
- it('should show the session duration', async () => {
- const { handleSlashCommand } = getProcessor();
- let commandResult: SlashCommandActionReturn | boolean = false;
-
- // Mock current time
- const mockDate = new Date('2025-01-01T00:01:05.000Z');
- vi.setSystemTime(mockDate);
-
- await act(async () => {
- commandResult = handleSlashCommand('/stats');
+ it('should show detailed session statistics', async () => {
+ // Arrange
+ mockUseSessionStats.mockReturnValue({
+ stats: {
+ sessionStartTime: new Date('2025-01-01T00:00:00.000Z'),
+ cumulative: {
+ totalTokenCount: 900,
+ promptTokenCount: 200,
+ candidatesTokenCount: 400,
+ cachedContentTokenCount: 100,
+ turnCount: 1,
+ toolUsePromptTokenCount: 50,
+ thoughtsTokenCount: 150,
+ },
+ },
});
+ const { handleSlashCommand } = getProcessor();
+ const mockDate = new Date('2025-01-01T01:02:03.000Z'); // 1h 2m 3s duration
+ vi.setSystemTime(mockDate);
+
+ // Act
+ await act(async () => {
+ handleSlashCommand('/stats');
+ });
+
+ // Assert
+ const expectedContent = [
+ ` ⎿ Total duration (wall): 1h 2m 3s`,
+ ` Total Token usage:`,
+ ` Turns: 1`,
+ ` Total: 900`,
+ ` ├─ Input: 200`,
+ ` ├─ Output: 400`,
+ ` ├─ Cached: 100`,
+ ` └─ Overhead: 200`,
+ ` ├─ Model thoughts: 150`,
+ ` └─ Tool-use prompts: 50`,
+ ].join('\n');
+
expect(mockAddItem).toHaveBeenNthCalledWith(
- 2,
+ 2, // Called after the user message
expect.objectContaining({
type: MessageType.INFO,
- text: 'Session duration: 1m 5s',
+ text: expectedContent,
}),
expect.any(Number),
);
- expect(commandResult).toBe(true);
- // Restore system time
vi.useRealTimers();
});
});
diff --git a/packages/cli/src/ui/hooks/slashCommandProcessor.ts b/packages/cli/src/ui/hooks/slashCommandProcessor.ts
index daec0379..6159fe89 100644
--- a/packages/cli/src/ui/hooks/slashCommandProcessor.ts
+++ b/packages/cli/src/ui/hooks/slashCommandProcessor.ts
@@ -11,7 +11,7 @@ import process from 'node:process';
import { UseHistoryManagerReturn } from './useHistoryManager.js';
import { Config, MCPServerStatus, getMCPServerStatus } from '@gemini-cli/core';
import { Message, MessageType, HistoryItemWithoutId } from '../types.js';
-import { useSession } from '../contexts/SessionContext.js';
+import { useSessionStats } from '../contexts/SessionContext.js';
import { createShowMemoryAction } from './useShowMemoryCommand.js';
import { GIT_COMMIT_INFO } from '../../generated/git-commit.js';
import { formatMemoryUsage } from '../utils/formatters.js';
@@ -50,8 +50,7 @@ export const useSlashCommandProcessor = (
toggleCorgiMode: () => void,
showToolDescriptions: boolean = false,
) => {
- const session = useSession();
-
+ const session = useSessionStats();
const addMessage = useCallback(
(message: Message) => {
// Convert Message to HistoryItemWithoutId
@@ -147,7 +146,9 @@ export const useSlashCommandProcessor = (
description: 'check session stats',
action: (_mainCommand, _subCommand, _args) => {
const now = new Date();
- const duration = now.getTime() - session.startTime.getTime();
+ const { sessionStartTime, cumulative } = session.stats;
+
+ const duration = now.getTime() - sessionStartTime.getTime();
const durationInSeconds = Math.floor(duration / 1000);
const hours = Math.floor(durationInSeconds / 3600);
const minutes = Math.floor((durationInSeconds % 3600) / 60);
@@ -161,9 +162,25 @@ export const useSlashCommandProcessor = (
.filter(Boolean)
.join(' ');
+ const overheadTotal =
+ cumulative.thoughtsTokenCount + cumulative.toolUsePromptTokenCount;
+
+ const statsContent = [
+ ` ⎿ Total duration (wall): ${durationString}`,
+ ` Total Token usage:`,
+ ` Turns: ${cumulative.turnCount.toLocaleString()}`,
+ ` Total: ${cumulative.totalTokenCount.toLocaleString()}`,
+ ` ├─ Input: ${cumulative.promptTokenCount.toLocaleString()}`,
+ ` ├─ Output: ${cumulative.candidatesTokenCount.toLocaleString()}`,
+ ` ├─ Cached: ${cumulative.cachedContentTokenCount.toLocaleString()}`,
+ ` └─ Overhead: ${overheadTotal.toLocaleString()}`,
+ ` ├─ Model thoughts: ${cumulative.thoughtsTokenCount.toLocaleString()}`,
+ ` └─ Tool-use prompts: ${cumulative.toolUsePromptTokenCount.toLocaleString()}`,
+ ].join('\n');
+
addMessage({
type: MessageType.INFO,
- content: `Session duration: ${durationString}`,
+ content: statsContent,
timestamp: new Date(),
});
},
@@ -477,7 +494,7 @@ Add any other context about the problem here.
toggleCorgiMode,
config,
showToolDescriptions,
- session.startTime,
+ session,
],
);
diff --git a/packages/cli/src/ui/hooks/useGeminiStream.test.tsx b/packages/cli/src/ui/hooks/useGeminiStream.test.tsx
index f41f7f9c..ed0f2aac 100644
--- a/packages/cli/src/ui/hooks/useGeminiStream.test.tsx
+++ b/packages/cli/src/ui/hooks/useGeminiStream.test.tsx
@@ -96,6 +96,15 @@ vi.mock('./useLogger.js', () => ({
}),
}));
+const mockStartNewTurn = vi.fn();
+const mockAddUsage = vi.fn();
+vi.mock('../contexts/SessionContext.js', () => ({
+ useSessionStats: vi.fn(() => ({
+ startNewTurn: mockStartNewTurn,
+ addUsage: mockAddUsage,
+ })),
+}));
+
vi.mock('./slashCommandProcessor.js', () => ({
handleSlashCommand: vi.fn().mockReturnValue(false),
}));
@@ -531,4 +540,63 @@ describe('useGeminiStream', () => {
});
});
});
+
+ describe('Session Stats Integration', () => {
+ it('should call startNewTurn and addUsage for a simple prompt', async () => {
+ const mockMetadata = { totalTokenCount: 123 };
+ const mockStream = (async function* () {
+ yield { type: 'content', value: 'Response' };
+ yield { type: 'usage_metadata', value: mockMetadata };
+ })();
+ mockSendMessageStream.mockReturnValue(mockStream);
+
+ const { result } = renderTestHook();
+
+ await act(async () => {
+ await result.current.submitQuery('Hello, world!');
+ });
+
+ expect(mockStartNewTurn).toHaveBeenCalledTimes(1);
+ expect(mockAddUsage).toHaveBeenCalledTimes(1);
+ expect(mockAddUsage).toHaveBeenCalledWith(mockMetadata);
+ });
+
+ it('should only call addUsage for a tool continuation prompt', async () => {
+ const mockMetadata = { totalTokenCount: 456 };
+ const mockStream = (async function* () {
+ yield { type: 'content', value: 'Final Answer' };
+ yield { type: 'usage_metadata', value: mockMetadata };
+ })();
+ mockSendMessageStream.mockReturnValue(mockStream);
+
+ const { result } = renderTestHook();
+
+ await act(async () => {
+ await result.current.submitQuery([{ text: 'tool response' }], {
+ isContinuation: true,
+ });
+ });
+
+ expect(mockStartNewTurn).not.toHaveBeenCalled();
+ expect(mockAddUsage).toHaveBeenCalledTimes(1);
+ expect(mockAddUsage).toHaveBeenCalledWith(mockMetadata);
+ });
+
+ it('should not call addUsage if the stream contains no usage metadata', async () => {
+ // Arrange: A stream that yields content but never a usage_metadata event
+ const mockStream = (async function* () {
+ yield { type: 'content', value: 'Some response text' };
+ })();
+ mockSendMessageStream.mockReturnValue(mockStream);
+
+ const { result } = renderTestHook();
+
+ await act(async () => {
+ await result.current.submitQuery('Query with no usage data');
+ });
+
+ expect(mockStartNewTurn).toHaveBeenCalledTimes(1);
+ expect(mockAddUsage).not.toHaveBeenCalled();
+ });
+ });
});
diff --git a/packages/cli/src/ui/hooks/useGeminiStream.ts b/packages/cli/src/ui/hooks/useGeminiStream.ts
index 2b47ae6f..bad9f78a 100644
--- a/packages/cli/src/ui/hooks/useGeminiStream.ts
+++ b/packages/cli/src/ui/hooks/useGeminiStream.ts
@@ -42,6 +42,7 @@ import {
TrackedCompletedToolCall,
TrackedCancelledToolCall,
} from './useReactToolScheduler.js';
+import { useSessionStats } from '../contexts/SessionContext.js';
export function mergePartListUnions(list: PartListUnion[]): PartListUnion {
const resultParts: PartListUnion = [];
@@ -82,6 +83,7 @@ export const useGeminiStream = (
const [pendingHistoryItemRef, setPendingHistoryItem] =
useStateAndRef(null);
const logger = useLogger();
+ const { startNewTurn, addUsage } = useSessionStats();
const [toolCalls, scheduleToolCalls, markToolsAsSubmitted] =
useReactToolScheduler(
@@ -390,6 +392,9 @@ export const useGeminiStream = (
case ServerGeminiEventType.ChatCompressed:
handleChatCompressionEvent();
break;
+ case ServerGeminiEventType.UsageMetadata:
+ addUsage(event.value);
+ break;
case ServerGeminiEventType.ToolCallConfirmation:
case ServerGeminiEventType.ToolCallResponse:
// do nothing
@@ -412,11 +417,12 @@ export const useGeminiStream = (
handleErrorEvent,
scheduleToolCalls,
handleChatCompressionEvent,
+ addUsage,
],
);
const submitQuery = useCallback(
- async (query: PartListUnion) => {
+ async (query: PartListUnion, options?: { isContinuation: boolean }) => {
if (
streamingState === StreamingState.Responding ||
streamingState === StreamingState.WaitingForConfirmation
@@ -426,6 +432,10 @@ export const useGeminiStream = (
const userMessageTimestamp = Date.now();
setShowHelp(false);
+ if (!options?.isContinuation) {
+ startNewTurn();
+ }
+
abortControllerRef.current = new AbortController();
const abortSignal = abortControllerRef.current.signal;
@@ -491,6 +501,7 @@ export const useGeminiStream = (
setPendingHistoryItem,
setInitError,
geminiClient,
+ startNewTurn,
],
);
@@ -576,7 +587,9 @@ export const useGeminiStream = (
);
markToolsAsSubmitted(callIdsToMarkAsSubmitted);
- submitQuery(mergePartListUnions(responsesToSend));
+ submitQuery(mergePartListUnions(responsesToSend), {
+ isContinuation: true,
+ });
}
}, [
toolCalls,
diff --git a/packages/core/src/core/client.test.ts b/packages/core/src/core/client.test.ts
index cbbbd113..58ad5dbd 100644
--- a/packages/core/src/core/client.test.ts
+++ b/packages/core/src/core/client.test.ts
@@ -13,14 +13,32 @@ import {
GoogleGenAI,
} from '@google/genai';
import { GeminiClient } from './client.js';
+import { ContentGenerator } from './contentGenerator.js';
+import { GeminiChat } from './geminiChat.js';
import { Config } from '../config/config.js';
+import { Turn } from './turn.js';
// --- Mocks ---
const mockChatCreateFn = vi.fn();
const mockGenerateContentFn = vi.fn();
const mockEmbedContentFn = vi.fn();
+const mockTurnRunFn = vi.fn();
vi.mock('@google/genai');
+vi.mock('./turn', () => {
+ // Define a mock class that has the same shape as the real Turn
+ class MockTurn {
+ pendingToolCalls = [];
+ // The run method is a property that holds our mock function
+ run = mockTurnRunFn;
+
+ constructor() {
+ // The constructor can be empty or do some mock setup
+ }
+ }
+ // Export the mock class as 'Turn'
+ return { Turn: MockTurn };
+});
vi.mock('../config/config.js');
vi.mock('./prompts');
@@ -237,4 +255,44 @@ describe('Gemini Client (client.ts)', () => {
expect(mockChat.addHistory).toHaveBeenCalledWith(newContent);
});
});
+
+ describe('sendMessageStream', () => {
+ it('should return the turn instance after the stream is complete', async () => {
+ // Arrange
+ const mockStream = (async function* () {
+ yield { type: 'content', value: 'Hello' };
+ })();
+ mockTurnRunFn.mockReturnValue(mockStream);
+
+ const mockChat: Partial = {
+ addHistory: vi.fn(),
+ getHistory: vi.fn().mockReturnValue([]),
+ };
+ client['chat'] = Promise.resolve(mockChat as GeminiChat);
+
+ const mockGenerator: Partial = {
+ countTokens: vi.fn().mockResolvedValue({ totalTokens: 0 }),
+ };
+ client['contentGenerator'] = mockGenerator as ContentGenerator;
+
+ // Act
+ const stream = client.sendMessageStream(
+ [{ text: 'Hi' }],
+ new AbortController().signal,
+ );
+
+ // Consume the stream manually to get the final return value.
+ let finalResult: Turn | undefined;
+ while (true) {
+ const result = await stream.next();
+ if (result.done) {
+ finalResult = result.value;
+ break;
+ }
+ }
+
+ // Assert
+ expect(finalResult).toBeInstanceOf(Turn);
+ });
+ });
});
diff --git a/packages/core/src/core/client.ts b/packages/core/src/core/client.ts
index 8b921ab1..1b953d30 100644
--- a/packages/core/src/core/client.ts
+++ b/packages/core/src/core/client.ts
@@ -174,9 +174,10 @@ export class GeminiClient {
request: PartListUnion,
signal: AbortSignal,
turns: number = this.MAX_TURNS,
- ): AsyncGenerator {
+ ): AsyncGenerator {
if (!turns) {
- return;
+ const chat = await this.chat;
+ return new Turn(chat);
}
const compressed = await this.tryCompressChat();
@@ -193,9 +194,12 @@ export class GeminiClient {
const nextSpeakerCheck = await checkNextSpeaker(chat, this, signal);
if (nextSpeakerCheck?.next_speaker === 'model') {
const nextRequest = [{ text: 'Please continue.' }];
+ // This recursive call's events will be yielded out, but the final
+ // turn object will be from the top-level call.
yield* this.sendMessageStream(nextRequest, signal, turns - 1);
}
}
+ return turn;
}
private _logApiRequest(model: string, inputTokenCount: number): void {
@@ -423,6 +427,10 @@ export class GeminiClient {
});
const result = await retryWithBackoff(apiCall);
+ console.log(
+ 'Raw API Response in client.ts:',
+ JSON.stringify(result, null, 2),
+ );
const durationMs = Date.now() - startTime;
this._logApiResponse(modelToUse, durationMs, attempt, result);
return result;
diff --git a/packages/core/src/core/turn.test.ts b/packages/core/src/core/turn.test.ts
index 8fb3a4c1..2217e5da 100644
--- a/packages/core/src/core/turn.test.ts
+++ b/packages/core/src/core/turn.test.ts
@@ -10,8 +10,14 @@ import {
GeminiEventType,
ServerGeminiToolCallRequestEvent,
ServerGeminiErrorEvent,
+ ServerGeminiUsageMetadataEvent,
} from './turn.js';
-import { GenerateContentResponse, Part, Content } from '@google/genai';
+import {
+ GenerateContentResponse,
+ Part,
+ Content,
+ GenerateContentResponseUsageMetadata,
+} from '@google/genai';
import { reportError } from '../utils/errorReporting.js';
import { GeminiChat } from './geminiChat.js';
@@ -49,6 +55,24 @@ describe('Turn', () => {
};
let mockChatInstance: MockedChatInstance;
+ const mockMetadata1: GenerateContentResponseUsageMetadata = {
+ promptTokenCount: 10,
+ candidatesTokenCount: 20,
+ totalTokenCount: 30,
+ cachedContentTokenCount: 5,
+ toolUsePromptTokenCount: 2,
+ thoughtsTokenCount: 3,
+ };
+
+ const mockMetadata2: GenerateContentResponseUsageMetadata = {
+ promptTokenCount: 100,
+ candidatesTokenCount: 200,
+ totalTokenCount: 300,
+ cachedContentTokenCount: 50,
+ toolUsePromptTokenCount: 20,
+ thoughtsTokenCount: 30,
+ };
+
beforeEach(() => {
vi.resetAllMocks();
mockChatInstance = {
@@ -96,6 +120,7 @@ describe('Turn', () => {
message: reqParts,
config: { abortSignal: expect.any(AbortSignal) },
});
+
expect(events).toEqual([
{ type: GeminiEventType.Content, value: 'Hello' },
{ type: GeminiEventType.Content, value: ' world' },
@@ -208,6 +233,41 @@ describe('Turn', () => {
);
});
+ it('should yield the last UsageMetadata event from the stream', async () => {
+ const mockResponseStream = (async function* () {
+ yield {
+ candidates: [{ content: { parts: [{ text: 'First response' }] } }],
+ usageMetadata: mockMetadata1,
+ } as unknown as GenerateContentResponse;
+ yield {
+ functionCalls: [{ name: 'aTool' }],
+ usageMetadata: mockMetadata2,
+ } as unknown as GenerateContentResponse;
+ })();
+ mockSendMessageStream.mockResolvedValue(mockResponseStream);
+
+ const events = [];
+ const reqParts: Part[] = [{ text: 'Test metadata' }];
+ for await (const event of turn.run(
+ reqParts,
+ new AbortController().signal,
+ )) {
+ events.push(event);
+ }
+
+ // There should be a content event, a tool call, and our metadata event
+ expect(events.length).toBe(3);
+
+ const metadataEvent = events[2] as ServerGeminiUsageMetadataEvent;
+ expect(metadataEvent.type).toBe(GeminiEventType.UsageMetadata);
+
+ // The value should be the *last* metadata object received.
+ expect(metadataEvent.value).toEqual(mockMetadata2);
+
+ // Also check the public getter
+ expect(turn.getUsageMetadata()).toEqual(mockMetadata2);
+ });
+
it('should handle function calls with undefined name or args', async () => {
const mockResponseStream = (async function* () {
yield {
@@ -219,7 +279,6 @@ describe('Turn', () => {
} as unknown as GenerateContentResponse;
})();
mockSendMessageStream.mockResolvedValue(mockResponseStream);
-
const events = [];
const reqParts: Part[] = [{ text: 'Test undefined tool parts' }];
for await (const event of turn.run(
diff --git a/packages/core/src/core/turn.ts b/packages/core/src/core/turn.ts
index 637fc19d..34e4a494 100644
--- a/packages/core/src/core/turn.ts
+++ b/packages/core/src/core/turn.ts
@@ -9,6 +9,7 @@ import {
GenerateContentResponse,
FunctionCall,
FunctionDeclaration,
+ GenerateContentResponseUsageMetadata,
} from '@google/genai';
import {
ToolCallConfirmationDetails,
@@ -43,6 +44,7 @@ export enum GeminiEventType {
UserCancelled = 'user_cancelled',
Error = 'error',
ChatCompressed = 'chat_compressed',
+ UsageMetadata = 'usage_metadata',
}
export interface GeminiErrorEventValue {
@@ -100,6 +102,11 @@ export type ServerGeminiChatCompressedEvent = {
type: GeminiEventType.ChatCompressed;
};
+export type ServerGeminiUsageMetadataEvent = {
+ type: GeminiEventType.UsageMetadata;
+ value: GenerateContentResponseUsageMetadata;
+};
+
// The original union type, now composed of the individual types
export type ServerGeminiStreamEvent =
| ServerGeminiContentEvent
@@ -108,7 +115,8 @@ export type ServerGeminiStreamEvent =
| ServerGeminiToolCallConfirmationEvent
| ServerGeminiUserCancelledEvent
| ServerGeminiErrorEvent
- | ServerGeminiChatCompressedEvent;
+ | ServerGeminiChatCompressedEvent
+ | ServerGeminiUsageMetadataEvent;
// A turn manages the agentic loop turn within the server context.
export class Turn {
@@ -118,6 +126,7 @@ export class Turn {
args: Record;
}>;
private debugResponses: GenerateContentResponse[];
+ private lastUsageMetadata: GenerateContentResponseUsageMetadata | null = null;
constructor(private readonly chat: GeminiChat) {
this.pendingToolCalls = [];
@@ -157,6 +166,18 @@ export class Turn {
yield event;
}
}
+
+ if (resp.usageMetadata) {
+ this.lastUsageMetadata =
+ resp.usageMetadata as GenerateContentResponseUsageMetadata;
+ }
+ }
+
+ if (this.lastUsageMetadata) {
+ yield {
+ type: GeminiEventType.UsageMetadata,
+ value: this.lastUsageMetadata,
+ };
}
} catch (error) {
if (signal.aborted) {
@@ -197,4 +218,8 @@ export class Turn {
getDebugResponses(): GenerateContentResponse[] {
return this.debugResponses;
}
+
+ getUsageMetadata(): GenerateContentResponseUsageMetadata | null {
+ return this.lastUsageMetadata;
+ }
}