feat: Display initial token usage metrics in /stats (#879)

This commit is contained in:
Abhi 2025-06-09 20:25:37 -04:00 committed by GitHub
parent 6484dc9008
commit 7f1252d364
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 608 additions and 63 deletions

View File

@ -48,7 +48,7 @@ import {
} from '@gemini-cli/core'; } from '@gemini-cli/core';
import { useLogger } from './hooks/useLogger.js'; import { useLogger } from './hooks/useLogger.js';
import { StreamingContext } from './contexts/StreamingContext.js'; import { StreamingContext } from './contexts/StreamingContext.js';
import { SessionProvider } from './contexts/SessionContext.js'; import { SessionStatsProvider } from './contexts/SessionContext.js';
import { useGitBranchName } from './hooks/useGitBranchName.js'; import { useGitBranchName } from './hooks/useGitBranchName.js';
const CTRL_C_PROMPT_DURATION_MS = 1000; const CTRL_C_PROMPT_DURATION_MS = 1000;
@ -60,9 +60,9 @@ interface AppProps {
} }
export const AppWrapper = (props: AppProps) => ( export const AppWrapper = (props: AppProps) => (
<SessionProvider> <SessionStatsProvider>
<App {...props} /> <App {...props} />
</SessionProvider> </SessionStatsProvider>
); );
const App = ({ config, settings, startupWarnings = [] }: AppProps) => { const App = ({ config, settings, startupWarnings = [] }: AppProps) => {

View File

@ -4,26 +4,181 @@
* SPDX-License-Identifier: Apache-2.0 * SPDX-License-Identifier: Apache-2.0
*/ */
import { type MutableRefObject } from 'react';
import { render } from 'ink-testing-library'; import { render } from 'ink-testing-library';
import { Text } from 'ink'; import { act } from 'react-dom/test-utils';
import { SessionProvider, useSession } from './SessionContext.js'; import { SessionStatsProvider, useSessionStats } from './SessionContext.js';
import { describe, it, expect } from 'vitest'; import { describe, it, expect, vi } from 'vitest';
import { GenerateContentResponseUsageMetadata } from '@google/genai';
const TestComponent = () => { // Mock data that simulates what the Gemini API would return.
const { startTime } = useSession(); const mockMetadata1: GenerateContentResponseUsageMetadata = {
return <Text>{startTime.toISOString()}</Text>; promptTokenCount: 100,
candidatesTokenCount: 200,
totalTokenCount: 300,
cachedContentTokenCount: 50,
toolUsePromptTokenCount: 10,
thoughtsTokenCount: 20,
}; };
describe('SessionContext', () => { const mockMetadata2: GenerateContentResponseUsageMetadata = {
it('should provide a start time', () => { promptTokenCount: 10,
const { lastFrame } = render( candidatesTokenCount: 20,
<SessionProvider> totalTokenCount: 30,
<TestComponent /> cachedContentTokenCount: 5,
</SessionProvider>, toolUsePromptTokenCount: 1,
thoughtsTokenCount: 2,
};
/**
* A test harness component that uses the hook and exposes the context value
* via a mutable ref. This allows us to interact with the context's functions
* and assert against its state directly in our tests.
*/
const TestHarness = ({
contextRef,
}: {
contextRef: MutableRefObject<ReturnType<typeof useSessionStats> | undefined>;
}) => {
contextRef.current = useSessionStats();
return null;
};
describe('SessionStatsContext', () => {
it('should provide the correct initial state', () => {
const contextRef: MutableRefObject<
ReturnType<typeof useSessionStats> | undefined
> = { current: undefined };
render(
<SessionStatsProvider>
<TestHarness contextRef={contextRef} />
</SessionStatsProvider>,
); );
const frameText = lastFrame(); const stats = contextRef.current?.stats;
// Check if the output is a valid ISO string, which confirms it's a Date object.
expect(new Date(frameText!).toString()).not.toBe('Invalid Date'); expect(stats?.sessionStartTime).toBeInstanceOf(Date);
expect(stats?.lastTurn).toBeNull();
expect(stats?.cumulative.turnCount).toBe(0);
expect(stats?.cumulative.totalTokenCount).toBe(0);
expect(stats?.cumulative.promptTokenCount).toBe(0);
});
it('should increment turnCount when startNewTurn is called', () => {
const contextRef: MutableRefObject<
ReturnType<typeof useSessionStats> | undefined
> = { current: undefined };
render(
<SessionStatsProvider>
<TestHarness contextRef={contextRef} />
</SessionStatsProvider>,
);
act(() => {
contextRef.current?.startNewTurn();
});
const stats = contextRef.current?.stats;
expect(stats?.cumulative.turnCount).toBe(1);
// Ensure token counts are unaffected
expect(stats?.cumulative.totalTokenCount).toBe(0);
});
it('should aggregate token usage correctly when addUsage is called', () => {
const contextRef: MutableRefObject<
ReturnType<typeof useSessionStats> | undefined
> = { current: undefined };
render(
<SessionStatsProvider>
<TestHarness contextRef={contextRef} />
</SessionStatsProvider>,
);
act(() => {
contextRef.current?.addUsage(mockMetadata1);
});
const stats = contextRef.current?.stats;
// Check that token counts are updated
expect(stats?.cumulative.totalTokenCount).toBe(
mockMetadata1.totalTokenCount ?? 0,
);
expect(stats?.cumulative.promptTokenCount).toBe(
mockMetadata1.promptTokenCount ?? 0,
);
// Check that turn count is NOT incremented
expect(stats?.cumulative.turnCount).toBe(0);
// Check that lastTurn is updated
expect(stats?.lastTurn?.metadata).toEqual(mockMetadata1);
});
it('should correctly track a full logical turn with multiple API calls', () => {
const contextRef: MutableRefObject<
ReturnType<typeof useSessionStats> | undefined
> = { current: undefined };
render(
<SessionStatsProvider>
<TestHarness contextRef={contextRef} />
</SessionStatsProvider>,
);
// 1. User starts a new turn
act(() => {
contextRef.current?.startNewTurn();
});
// 2. First API call (e.g., prompt with a tool request)
act(() => {
contextRef.current?.addUsage(mockMetadata1);
});
// 3. Second API call (e.g., sending tool response back)
act(() => {
contextRef.current?.addUsage(mockMetadata2);
});
const stats = contextRef.current?.stats;
// Turn count should only be 1
expect(stats?.cumulative.turnCount).toBe(1);
// These fields should be the SUM of both calls
expect(stats?.cumulative.totalTokenCount).toBe(330); // 300 + 30
expect(stats?.cumulative.candidatesTokenCount).toBe(220); // 200 + 20
expect(stats?.cumulative.thoughtsTokenCount).toBe(22); // 20 + 2
// These fields should ONLY be from the FIRST call, because isNewTurnForAggregation was true
expect(stats?.cumulative.promptTokenCount).toBe(100);
expect(stats?.cumulative.cachedContentTokenCount).toBe(50);
expect(stats?.cumulative.toolUsePromptTokenCount).toBe(10);
// Last turn should hold the metadata from the most recent call
expect(stats?.lastTurn?.metadata).toEqual(mockMetadata2);
});
it('should throw an error when useSessionStats is used outside of a provider', () => {
// Suppress the expected console error during this test.
const errorSpy = vi.spyOn(console, 'error').mockImplementation(() => {});
const contextRef = { current: undefined };
// We expect rendering to fail, which React will catch and log as an error.
render(<TestHarness contextRef={contextRef} />);
// Assert that the first argument of the first call to console.error
// contains the expected message. This is more robust than checking
// the exact arguments, which can be affected by React/JSDOM internals.
expect(errorSpy.mock.calls[0][0]).toContain(
'useSessionStats must be used within a SessionStatsProvider',
);
errorSpy.mockRestore();
}); });
}); });

View File

@ -4,35 +4,140 @@
* SPDX-License-Identifier: Apache-2.0 * SPDX-License-Identifier: Apache-2.0
*/ */
import React, { createContext, useContext, useState, useMemo } from 'react'; import React, {
createContext,
useContext,
useState,
useMemo,
useCallback,
} from 'react';
interface SessionContextType { import { type GenerateContentResponseUsageMetadata } from '@google/genai';
startTime: Date;
// --- Interface Definitions ---
interface CumulativeStats {
turnCount: number;
promptTokenCount: number;
candidatesTokenCount: number;
totalTokenCount: number;
cachedContentTokenCount: number;
toolUsePromptTokenCount: number;
thoughtsTokenCount: number;
} }
const SessionContext = createContext<SessionContextType | null>(null); interface LastTurnStats {
metadata: GenerateContentResponseUsageMetadata;
// TODO(abhipatel12): Add apiTime, etc. here in a future step.
}
export const SessionProvider: React.FC<{ children: React.ReactNode }> = ({ interface SessionStatsState {
sessionStartTime: Date;
cumulative: CumulativeStats;
lastTurn: LastTurnStats | null;
isNewTurnForAggregation: boolean;
}
// Defines the final "value" of our context, including the state
// and the functions to update it.
interface SessionStatsContextValue {
stats: SessionStatsState;
startNewTurn: () => void;
addUsage: (metadata: GenerateContentResponseUsageMetadata) => void;
}
// --- Context Definition ---
const SessionStatsContext = createContext<SessionStatsContextValue | undefined>(
undefined,
);
// --- Provider Component ---
export const SessionStatsProvider: React.FC<{ children: React.ReactNode }> = ({
children, children,
}) => { }) => {
const [startTime] = useState(new Date()); const [stats, setStats] = useState<SessionStatsState>({
sessionStartTime: new Date(),
cumulative: {
turnCount: 0,
promptTokenCount: 0,
candidatesTokenCount: 0,
totalTokenCount: 0,
cachedContentTokenCount: 0,
toolUsePromptTokenCount: 0,
thoughtsTokenCount: 0,
},
lastTurn: null,
isNewTurnForAggregation: true,
});
// A single, internal worker function to handle all metadata aggregation.
const aggregateTokens = useCallback(
(metadata: GenerateContentResponseUsageMetadata) => {
setStats((prevState) => {
const { isNewTurnForAggregation } = prevState;
const newCumulative = { ...prevState.cumulative };
newCumulative.candidatesTokenCount +=
metadata.candidatesTokenCount ?? 0;
newCumulative.thoughtsTokenCount += metadata.thoughtsTokenCount ?? 0;
newCumulative.totalTokenCount += metadata.totalTokenCount ?? 0;
if (isNewTurnForAggregation) {
newCumulative.promptTokenCount += metadata.promptTokenCount ?? 0;
newCumulative.cachedContentTokenCount +=
metadata.cachedContentTokenCount ?? 0;
newCumulative.toolUsePromptTokenCount +=
metadata.toolUsePromptTokenCount ?? 0;
}
return {
...prevState,
cumulative: newCumulative,
lastTurn: { metadata },
isNewTurnForAggregation: false,
};
});
},
[],
);
const startNewTurn = useCallback(() => {
setStats((prevState) => ({
...prevState,
cumulative: {
...prevState.cumulative,
turnCount: prevState.cumulative.turnCount + 1,
},
isNewTurnForAggregation: true,
}));
}, []);
const value = useMemo( const value = useMemo(
() => ({ () => ({
startTime, stats,
startNewTurn,
addUsage: aggregateTokens,
}), }),
[startTime], [stats, startNewTurn, aggregateTokens],
); );
return ( return (
<SessionContext.Provider value={value}>{children}</SessionContext.Provider> <SessionStatsContext.Provider value={value}>
{children}
</SessionStatsContext.Provider>
); );
}; };
export const useSession = () => { // --- Consumer Hook ---
const context = useContext(SessionContext);
if (!context) { export const useSessionStats = () => {
throw new Error('useSession must be used within a SessionProvider'); const context = useContext(SessionStatsContext);
if (context === undefined) {
throw new Error(
'useSessionStats must be used within a SessionStatsProvider',
);
} }
return context; return context;
}; };

View File

@ -61,13 +61,13 @@ import {
MCPServerStatus, MCPServerStatus,
getMCPServerStatus, getMCPServerStatus,
} from '@gemini-cli/core'; } from '@gemini-cli/core';
import { useSession } from '../contexts/SessionContext.js'; import { useSessionStats } from '../contexts/SessionContext.js';
import * as ShowMemoryCommandModule from './useShowMemoryCommand.js'; import * as ShowMemoryCommandModule from './useShowMemoryCommand.js';
import { GIT_COMMIT_INFO } from '../../generated/git-commit.js'; import { GIT_COMMIT_INFO } from '../../generated/git-commit.js';
vi.mock('../contexts/SessionContext.js', () => ({ vi.mock('../contexts/SessionContext.js', () => ({
useSession: vi.fn(), useSessionStats: vi.fn(),
})); }));
vi.mock('./useShowMemoryCommand.js', () => ({ vi.mock('./useShowMemoryCommand.js', () => ({
@ -89,7 +89,7 @@ describe('useSlashCommandProcessor', () => {
let mockPerformMemoryRefresh: ReturnType<typeof vi.fn>; let mockPerformMemoryRefresh: ReturnType<typeof vi.fn>;
let mockConfig: Config; let mockConfig: Config;
let mockCorgiMode: ReturnType<typeof vi.fn>; let mockCorgiMode: ReturnType<typeof vi.fn>;
const mockUseSession = useSession as Mock; const mockUseSessionStats = useSessionStats as Mock;
beforeEach(() => { beforeEach(() => {
mockAddItem = vi.fn(); mockAddItem = vi.fn();
@ -105,8 +105,19 @@ describe('useSlashCommandProcessor', () => {
getModel: vi.fn(() => 'test-model'), getModel: vi.fn(() => 'test-model'),
} as unknown as Config; } as unknown as Config;
mockCorgiMode = vi.fn(); mockCorgiMode = vi.fn();
mockUseSession.mockReturnValue({ mockUseSessionStats.mockReturnValue({
startTime: new Date('2025-01-01T00:00:00.000Z'), stats: {
sessionStartTime: new Date('2025-01-01T00:00:00.000Z'),
cumulative: {
turnCount: 0,
promptTokenCount: 0,
candidatesTokenCount: 0,
totalTokenCount: 0,
cachedContentTokenCount: 0,
toolUsePromptTokenCount: 0,
thoughtsTokenCount: 0,
},
},
}); });
(open as Mock).mockClear(); (open as Mock).mockClear();
@ -240,29 +251,55 @@ describe('useSlashCommandProcessor', () => {
}); });
describe('/stats command', () => { describe('/stats command', () => {
it('should show the session duration', async () => { it('should show detailed session statistics', async () => {
const { handleSlashCommand } = getProcessor(); // Arrange
let commandResult: SlashCommandActionReturn | boolean = false; mockUseSessionStats.mockReturnValue({
stats: {
// Mock current time sessionStartTime: new Date('2025-01-01T00:00:00.000Z'),
const mockDate = new Date('2025-01-01T00:01:05.000Z'); cumulative: {
vi.setSystemTime(mockDate); totalTokenCount: 900,
promptTokenCount: 200,
await act(async () => { candidatesTokenCount: 400,
commandResult = handleSlashCommand('/stats'); cachedContentTokenCount: 100,
turnCount: 1,
toolUsePromptTokenCount: 50,
thoughtsTokenCount: 150,
},
},
}); });
const { handleSlashCommand } = getProcessor();
const mockDate = new Date('2025-01-01T01:02:03.000Z'); // 1h 2m 3s duration
vi.setSystemTime(mockDate);
// Act
await act(async () => {
handleSlashCommand('/stats');
});
// Assert
const expectedContent = [
` ⎿ Total duration (wall): 1h 2m 3s`,
` Total Token usage:`,
` Turns: 1`,
` Total: 900`,
` ├─ Input: 200`,
` ├─ Output: 400`,
` ├─ Cached: 100`,
` └─ Overhead: 200`,
` ├─ Model thoughts: 150`,
` └─ Tool-use prompts: 50`,
].join('\n');
expect(mockAddItem).toHaveBeenNthCalledWith( expect(mockAddItem).toHaveBeenNthCalledWith(
2, 2, // Called after the user message
expect.objectContaining({ expect.objectContaining({
type: MessageType.INFO, type: MessageType.INFO,
text: 'Session duration: 1m 5s', text: expectedContent,
}), }),
expect.any(Number), expect.any(Number),
); );
expect(commandResult).toBe(true);
// Restore system time
vi.useRealTimers(); vi.useRealTimers();
}); });
}); });

View File

@ -11,7 +11,7 @@ import process from 'node:process';
import { UseHistoryManagerReturn } from './useHistoryManager.js'; import { UseHistoryManagerReturn } from './useHistoryManager.js';
import { Config, MCPServerStatus, getMCPServerStatus } from '@gemini-cli/core'; import { Config, MCPServerStatus, getMCPServerStatus } from '@gemini-cli/core';
import { Message, MessageType, HistoryItemWithoutId } from '../types.js'; import { Message, MessageType, HistoryItemWithoutId } from '../types.js';
import { useSession } from '../contexts/SessionContext.js'; import { useSessionStats } from '../contexts/SessionContext.js';
import { createShowMemoryAction } from './useShowMemoryCommand.js'; import { createShowMemoryAction } from './useShowMemoryCommand.js';
import { GIT_COMMIT_INFO } from '../../generated/git-commit.js'; import { GIT_COMMIT_INFO } from '../../generated/git-commit.js';
import { formatMemoryUsage } from '../utils/formatters.js'; import { formatMemoryUsage } from '../utils/formatters.js';
@ -50,8 +50,7 @@ export const useSlashCommandProcessor = (
toggleCorgiMode: () => void, toggleCorgiMode: () => void,
showToolDescriptions: boolean = false, showToolDescriptions: boolean = false,
) => { ) => {
const session = useSession(); const session = useSessionStats();
const addMessage = useCallback( const addMessage = useCallback(
(message: Message) => { (message: Message) => {
// Convert Message to HistoryItemWithoutId // Convert Message to HistoryItemWithoutId
@ -147,7 +146,9 @@ export const useSlashCommandProcessor = (
description: 'check session stats', description: 'check session stats',
action: (_mainCommand, _subCommand, _args) => { action: (_mainCommand, _subCommand, _args) => {
const now = new Date(); const now = new Date();
const duration = now.getTime() - session.startTime.getTime(); const { sessionStartTime, cumulative } = session.stats;
const duration = now.getTime() - sessionStartTime.getTime();
const durationInSeconds = Math.floor(duration / 1000); const durationInSeconds = Math.floor(duration / 1000);
const hours = Math.floor(durationInSeconds / 3600); const hours = Math.floor(durationInSeconds / 3600);
const minutes = Math.floor((durationInSeconds % 3600) / 60); const minutes = Math.floor((durationInSeconds % 3600) / 60);
@ -161,9 +162,25 @@ export const useSlashCommandProcessor = (
.filter(Boolean) .filter(Boolean)
.join(' '); .join(' ');
const overheadTotal =
cumulative.thoughtsTokenCount + cumulative.toolUsePromptTokenCount;
const statsContent = [
` ⎿ Total duration (wall): ${durationString}`,
` Total Token usage:`,
` Turns: ${cumulative.turnCount.toLocaleString()}`,
` Total: ${cumulative.totalTokenCount.toLocaleString()}`,
` ├─ Input: ${cumulative.promptTokenCount.toLocaleString()}`,
` ├─ Output: ${cumulative.candidatesTokenCount.toLocaleString()}`,
` ├─ Cached: ${cumulative.cachedContentTokenCount.toLocaleString()}`,
` └─ Overhead: ${overheadTotal.toLocaleString()}`,
` ├─ Model thoughts: ${cumulative.thoughtsTokenCount.toLocaleString()}`,
` └─ Tool-use prompts: ${cumulative.toolUsePromptTokenCount.toLocaleString()}`,
].join('\n');
addMessage({ addMessage({
type: MessageType.INFO, type: MessageType.INFO,
content: `Session duration: ${durationString}`, content: statsContent,
timestamp: new Date(), timestamp: new Date(),
}); });
}, },
@ -477,7 +494,7 @@ Add any other context about the problem here.
toggleCorgiMode, toggleCorgiMode,
config, config,
showToolDescriptions, showToolDescriptions,
session.startTime, session,
], ],
); );

View File

@ -96,6 +96,15 @@ vi.mock('./useLogger.js', () => ({
}), }),
})); }));
const mockStartNewTurn = vi.fn();
const mockAddUsage = vi.fn();
vi.mock('../contexts/SessionContext.js', () => ({
useSessionStats: vi.fn(() => ({
startNewTurn: mockStartNewTurn,
addUsage: mockAddUsage,
})),
}));
vi.mock('./slashCommandProcessor.js', () => ({ vi.mock('./slashCommandProcessor.js', () => ({
handleSlashCommand: vi.fn().mockReturnValue(false), handleSlashCommand: vi.fn().mockReturnValue(false),
})); }));
@ -531,4 +540,63 @@ describe('useGeminiStream', () => {
}); });
}); });
}); });
describe('Session Stats Integration', () => {
it('should call startNewTurn and addUsage for a simple prompt', async () => {
const mockMetadata = { totalTokenCount: 123 };
const mockStream = (async function* () {
yield { type: 'content', value: 'Response' };
yield { type: 'usage_metadata', value: mockMetadata };
})();
mockSendMessageStream.mockReturnValue(mockStream);
const { result } = renderTestHook();
await act(async () => {
await result.current.submitQuery('Hello, world!');
});
expect(mockStartNewTurn).toHaveBeenCalledTimes(1);
expect(mockAddUsage).toHaveBeenCalledTimes(1);
expect(mockAddUsage).toHaveBeenCalledWith(mockMetadata);
});
it('should only call addUsage for a tool continuation prompt', async () => {
const mockMetadata = { totalTokenCount: 456 };
const mockStream = (async function* () {
yield { type: 'content', value: 'Final Answer' };
yield { type: 'usage_metadata', value: mockMetadata };
})();
mockSendMessageStream.mockReturnValue(mockStream);
const { result } = renderTestHook();
await act(async () => {
await result.current.submitQuery([{ text: 'tool response' }], {
isContinuation: true,
});
});
expect(mockStartNewTurn).not.toHaveBeenCalled();
expect(mockAddUsage).toHaveBeenCalledTimes(1);
expect(mockAddUsage).toHaveBeenCalledWith(mockMetadata);
});
it('should not call addUsage if the stream contains no usage metadata', async () => {
// Arrange: A stream that yields content but never a usage_metadata event
const mockStream = (async function* () {
yield { type: 'content', value: 'Some response text' };
})();
mockSendMessageStream.mockReturnValue(mockStream);
const { result } = renderTestHook();
await act(async () => {
await result.current.submitQuery('Query with no usage data');
});
expect(mockStartNewTurn).toHaveBeenCalledTimes(1);
expect(mockAddUsage).not.toHaveBeenCalled();
});
});
}); });

View File

@ -42,6 +42,7 @@ import {
TrackedCompletedToolCall, TrackedCompletedToolCall,
TrackedCancelledToolCall, TrackedCancelledToolCall,
} from './useReactToolScheduler.js'; } from './useReactToolScheduler.js';
import { useSessionStats } from '../contexts/SessionContext.js';
export function mergePartListUnions(list: PartListUnion[]): PartListUnion { export function mergePartListUnions(list: PartListUnion[]): PartListUnion {
const resultParts: PartListUnion = []; const resultParts: PartListUnion = [];
@ -82,6 +83,7 @@ export const useGeminiStream = (
const [pendingHistoryItemRef, setPendingHistoryItem] = const [pendingHistoryItemRef, setPendingHistoryItem] =
useStateAndRef<HistoryItemWithoutId | null>(null); useStateAndRef<HistoryItemWithoutId | null>(null);
const logger = useLogger(); const logger = useLogger();
const { startNewTurn, addUsage } = useSessionStats();
const [toolCalls, scheduleToolCalls, markToolsAsSubmitted] = const [toolCalls, scheduleToolCalls, markToolsAsSubmitted] =
useReactToolScheduler( useReactToolScheduler(
@ -390,6 +392,9 @@ export const useGeminiStream = (
case ServerGeminiEventType.ChatCompressed: case ServerGeminiEventType.ChatCompressed:
handleChatCompressionEvent(); handleChatCompressionEvent();
break; break;
case ServerGeminiEventType.UsageMetadata:
addUsage(event.value);
break;
case ServerGeminiEventType.ToolCallConfirmation: case ServerGeminiEventType.ToolCallConfirmation:
case ServerGeminiEventType.ToolCallResponse: case ServerGeminiEventType.ToolCallResponse:
// do nothing // do nothing
@ -412,11 +417,12 @@ export const useGeminiStream = (
handleErrorEvent, handleErrorEvent,
scheduleToolCalls, scheduleToolCalls,
handleChatCompressionEvent, handleChatCompressionEvent,
addUsage,
], ],
); );
const submitQuery = useCallback( const submitQuery = useCallback(
async (query: PartListUnion) => { async (query: PartListUnion, options?: { isContinuation: boolean }) => {
if ( if (
streamingState === StreamingState.Responding || streamingState === StreamingState.Responding ||
streamingState === StreamingState.WaitingForConfirmation streamingState === StreamingState.WaitingForConfirmation
@ -426,6 +432,10 @@ export const useGeminiStream = (
const userMessageTimestamp = Date.now(); const userMessageTimestamp = Date.now();
setShowHelp(false); setShowHelp(false);
if (!options?.isContinuation) {
startNewTurn();
}
abortControllerRef.current = new AbortController(); abortControllerRef.current = new AbortController();
const abortSignal = abortControllerRef.current.signal; const abortSignal = abortControllerRef.current.signal;
@ -491,6 +501,7 @@ export const useGeminiStream = (
setPendingHistoryItem, setPendingHistoryItem,
setInitError, setInitError,
geminiClient, geminiClient,
startNewTurn,
], ],
); );
@ -576,7 +587,9 @@ export const useGeminiStream = (
); );
markToolsAsSubmitted(callIdsToMarkAsSubmitted); markToolsAsSubmitted(callIdsToMarkAsSubmitted);
submitQuery(mergePartListUnions(responsesToSend)); submitQuery(mergePartListUnions(responsesToSend), {
isContinuation: true,
});
} }
}, [ }, [
toolCalls, toolCalls,

View File

@ -13,14 +13,32 @@ import {
GoogleGenAI, GoogleGenAI,
} from '@google/genai'; } from '@google/genai';
import { GeminiClient } from './client.js'; import { GeminiClient } from './client.js';
import { ContentGenerator } from './contentGenerator.js';
import { GeminiChat } from './geminiChat.js';
import { Config } from '../config/config.js'; import { Config } from '../config/config.js';
import { Turn } from './turn.js';
// --- Mocks --- // --- Mocks ---
const mockChatCreateFn = vi.fn(); const mockChatCreateFn = vi.fn();
const mockGenerateContentFn = vi.fn(); const mockGenerateContentFn = vi.fn();
const mockEmbedContentFn = vi.fn(); const mockEmbedContentFn = vi.fn();
const mockTurnRunFn = vi.fn();
vi.mock('@google/genai'); vi.mock('@google/genai');
vi.mock('./turn', () => {
// Define a mock class that has the same shape as the real Turn
class MockTurn {
pendingToolCalls = [];
// The run method is a property that holds our mock function
run = mockTurnRunFn;
constructor() {
// The constructor can be empty or do some mock setup
}
}
// Export the mock class as 'Turn'
return { Turn: MockTurn };
});
vi.mock('../config/config.js'); vi.mock('../config/config.js');
vi.mock('./prompts'); vi.mock('./prompts');
@ -237,4 +255,44 @@ describe('Gemini Client (client.ts)', () => {
expect(mockChat.addHistory).toHaveBeenCalledWith(newContent); expect(mockChat.addHistory).toHaveBeenCalledWith(newContent);
}); });
}); });
describe('sendMessageStream', () => {
it('should return the turn instance after the stream is complete', async () => {
// Arrange
const mockStream = (async function* () {
yield { type: 'content', value: 'Hello' };
})();
mockTurnRunFn.mockReturnValue(mockStream);
const mockChat: Partial<GeminiChat> = {
addHistory: vi.fn(),
getHistory: vi.fn().mockReturnValue([]),
};
client['chat'] = Promise.resolve(mockChat as GeminiChat);
const mockGenerator: Partial<ContentGenerator> = {
countTokens: vi.fn().mockResolvedValue({ totalTokens: 0 }),
};
client['contentGenerator'] = mockGenerator as ContentGenerator;
// Act
const stream = client.sendMessageStream(
[{ text: 'Hi' }],
new AbortController().signal,
);
// Consume the stream manually to get the final return value.
let finalResult: Turn | undefined;
while (true) {
const result = await stream.next();
if (result.done) {
finalResult = result.value;
break;
}
}
// Assert
expect(finalResult).toBeInstanceOf(Turn);
});
});
}); });

View File

@ -174,9 +174,10 @@ export class GeminiClient {
request: PartListUnion, request: PartListUnion,
signal: AbortSignal, signal: AbortSignal,
turns: number = this.MAX_TURNS, turns: number = this.MAX_TURNS,
): AsyncGenerator<ServerGeminiStreamEvent> { ): AsyncGenerator<ServerGeminiStreamEvent, Turn> {
if (!turns) { if (!turns) {
return; const chat = await this.chat;
return new Turn(chat);
} }
const compressed = await this.tryCompressChat(); const compressed = await this.tryCompressChat();
@ -193,9 +194,12 @@ export class GeminiClient {
const nextSpeakerCheck = await checkNextSpeaker(chat, this, signal); const nextSpeakerCheck = await checkNextSpeaker(chat, this, signal);
if (nextSpeakerCheck?.next_speaker === 'model') { if (nextSpeakerCheck?.next_speaker === 'model') {
const nextRequest = [{ text: 'Please continue.' }]; const nextRequest = [{ text: 'Please continue.' }];
// This recursive call's events will be yielded out, but the final
// turn object will be from the top-level call.
yield* this.sendMessageStream(nextRequest, signal, turns - 1); yield* this.sendMessageStream(nextRequest, signal, turns - 1);
} }
} }
return turn;
} }
private _logApiRequest(model: string, inputTokenCount: number): void { private _logApiRequest(model: string, inputTokenCount: number): void {
@ -423,6 +427,10 @@ export class GeminiClient {
}); });
const result = await retryWithBackoff(apiCall); const result = await retryWithBackoff(apiCall);
console.log(
'Raw API Response in client.ts:',
JSON.stringify(result, null, 2),
);
const durationMs = Date.now() - startTime; const durationMs = Date.now() - startTime;
this._logApiResponse(modelToUse, durationMs, attempt, result); this._logApiResponse(modelToUse, durationMs, attempt, result);
return result; return result;

View File

@ -10,8 +10,14 @@ import {
GeminiEventType, GeminiEventType,
ServerGeminiToolCallRequestEvent, ServerGeminiToolCallRequestEvent,
ServerGeminiErrorEvent, ServerGeminiErrorEvent,
ServerGeminiUsageMetadataEvent,
} from './turn.js'; } from './turn.js';
import { GenerateContentResponse, Part, Content } from '@google/genai'; import {
GenerateContentResponse,
Part,
Content,
GenerateContentResponseUsageMetadata,
} from '@google/genai';
import { reportError } from '../utils/errorReporting.js'; import { reportError } from '../utils/errorReporting.js';
import { GeminiChat } from './geminiChat.js'; import { GeminiChat } from './geminiChat.js';
@ -49,6 +55,24 @@ describe('Turn', () => {
}; };
let mockChatInstance: MockedChatInstance; let mockChatInstance: MockedChatInstance;
const mockMetadata1: GenerateContentResponseUsageMetadata = {
promptTokenCount: 10,
candidatesTokenCount: 20,
totalTokenCount: 30,
cachedContentTokenCount: 5,
toolUsePromptTokenCount: 2,
thoughtsTokenCount: 3,
};
const mockMetadata2: GenerateContentResponseUsageMetadata = {
promptTokenCount: 100,
candidatesTokenCount: 200,
totalTokenCount: 300,
cachedContentTokenCount: 50,
toolUsePromptTokenCount: 20,
thoughtsTokenCount: 30,
};
beforeEach(() => { beforeEach(() => {
vi.resetAllMocks(); vi.resetAllMocks();
mockChatInstance = { mockChatInstance = {
@ -96,6 +120,7 @@ describe('Turn', () => {
message: reqParts, message: reqParts,
config: { abortSignal: expect.any(AbortSignal) }, config: { abortSignal: expect.any(AbortSignal) },
}); });
expect(events).toEqual([ expect(events).toEqual([
{ type: GeminiEventType.Content, value: 'Hello' }, { type: GeminiEventType.Content, value: 'Hello' },
{ type: GeminiEventType.Content, value: ' world' }, { type: GeminiEventType.Content, value: ' world' },
@ -208,6 +233,41 @@ describe('Turn', () => {
); );
}); });
it('should yield the last UsageMetadata event from the stream', async () => {
const mockResponseStream = (async function* () {
yield {
candidates: [{ content: { parts: [{ text: 'First response' }] } }],
usageMetadata: mockMetadata1,
} as unknown as GenerateContentResponse;
yield {
functionCalls: [{ name: 'aTool' }],
usageMetadata: mockMetadata2,
} as unknown as GenerateContentResponse;
})();
mockSendMessageStream.mockResolvedValue(mockResponseStream);
const events = [];
const reqParts: Part[] = [{ text: 'Test metadata' }];
for await (const event of turn.run(
reqParts,
new AbortController().signal,
)) {
events.push(event);
}
// There should be a content event, a tool call, and our metadata event
expect(events.length).toBe(3);
const metadataEvent = events[2] as ServerGeminiUsageMetadataEvent;
expect(metadataEvent.type).toBe(GeminiEventType.UsageMetadata);
// The value should be the *last* metadata object received.
expect(metadataEvent.value).toEqual(mockMetadata2);
// Also check the public getter
expect(turn.getUsageMetadata()).toEqual(mockMetadata2);
});
it('should handle function calls with undefined name or args', async () => { it('should handle function calls with undefined name or args', async () => {
const mockResponseStream = (async function* () { const mockResponseStream = (async function* () {
yield { yield {
@ -219,7 +279,6 @@ describe('Turn', () => {
} as unknown as GenerateContentResponse; } as unknown as GenerateContentResponse;
})(); })();
mockSendMessageStream.mockResolvedValue(mockResponseStream); mockSendMessageStream.mockResolvedValue(mockResponseStream);
const events = []; const events = [];
const reqParts: Part[] = [{ text: 'Test undefined tool parts' }]; const reqParts: Part[] = [{ text: 'Test undefined tool parts' }];
for await (const event of turn.run( for await (const event of turn.run(

View File

@ -9,6 +9,7 @@ import {
GenerateContentResponse, GenerateContentResponse,
FunctionCall, FunctionCall,
FunctionDeclaration, FunctionDeclaration,
GenerateContentResponseUsageMetadata,
} from '@google/genai'; } from '@google/genai';
import { import {
ToolCallConfirmationDetails, ToolCallConfirmationDetails,
@ -43,6 +44,7 @@ export enum GeminiEventType {
UserCancelled = 'user_cancelled', UserCancelled = 'user_cancelled',
Error = 'error', Error = 'error',
ChatCompressed = 'chat_compressed', ChatCompressed = 'chat_compressed',
UsageMetadata = 'usage_metadata',
} }
export interface GeminiErrorEventValue { export interface GeminiErrorEventValue {
@ -100,6 +102,11 @@ export type ServerGeminiChatCompressedEvent = {
type: GeminiEventType.ChatCompressed; type: GeminiEventType.ChatCompressed;
}; };
export type ServerGeminiUsageMetadataEvent = {
type: GeminiEventType.UsageMetadata;
value: GenerateContentResponseUsageMetadata;
};
// The original union type, now composed of the individual types // The original union type, now composed of the individual types
export type ServerGeminiStreamEvent = export type ServerGeminiStreamEvent =
| ServerGeminiContentEvent | ServerGeminiContentEvent
@ -108,7 +115,8 @@ export type ServerGeminiStreamEvent =
| ServerGeminiToolCallConfirmationEvent | ServerGeminiToolCallConfirmationEvent
| ServerGeminiUserCancelledEvent | ServerGeminiUserCancelledEvent
| ServerGeminiErrorEvent | ServerGeminiErrorEvent
| ServerGeminiChatCompressedEvent; | ServerGeminiChatCompressedEvent
| ServerGeminiUsageMetadataEvent;
// A turn manages the agentic loop turn within the server context. // A turn manages the agentic loop turn within the server context.
export class Turn { export class Turn {
@ -118,6 +126,7 @@ export class Turn {
args: Record<string, unknown>; args: Record<string, unknown>;
}>; }>;
private debugResponses: GenerateContentResponse[]; private debugResponses: GenerateContentResponse[];
private lastUsageMetadata: GenerateContentResponseUsageMetadata | null = null;
constructor(private readonly chat: GeminiChat) { constructor(private readonly chat: GeminiChat) {
this.pendingToolCalls = []; this.pendingToolCalls = [];
@ -157,6 +166,18 @@ export class Turn {
yield event; yield event;
} }
} }
if (resp.usageMetadata) {
this.lastUsageMetadata =
resp.usageMetadata as GenerateContentResponseUsageMetadata;
}
}
if (this.lastUsageMetadata) {
yield {
type: GeminiEventType.UsageMetadata,
value: this.lastUsageMetadata,
};
} }
} catch (error) { } catch (error) {
if (signal.aborted) { if (signal.aborted) {
@ -197,4 +218,8 @@ export class Turn {
getDebugResponses(): GenerateContentResponse[] { getDebugResponses(): GenerateContentResponse[] {
return this.debugResponses; return this.debugResponses;
} }
getUsageMetadata(): GenerateContentResponseUsageMetadata | null {
return this.lastUsageMetadata;
}
} }