feat: Display initial token usage metrics in /stats (#879)
parent 6484dc9008, commit 7f1252d364
@@ -48,7 +48,7 @@ import {
 } from '@gemini-cli/core';
 import { useLogger } from './hooks/useLogger.js';
 import { StreamingContext } from './contexts/StreamingContext.js';
-import { SessionProvider } from './contexts/SessionContext.js';
+import { SessionStatsProvider } from './contexts/SessionContext.js';
 import { useGitBranchName } from './hooks/useGitBranchName.js';
 
 const CTRL_C_PROMPT_DURATION_MS = 1000;
@@ -60,9 +60,9 @@ interface AppProps {
 }
 
 export const AppWrapper = (props: AppProps) => (
-  <SessionProvider>
+  <SessionStatsProvider>
     <App {...props} />
-  </SessionProvider>
+  </SessionStatsProvider>
 );
 
 const App = ({ config, settings, startupWarnings = [] }: AppProps) => {
@@ -4,26 +4,181 @@
  * SPDX-License-Identifier: Apache-2.0
  */
 
+import { type MutableRefObject } from 'react';
 import { render } from 'ink-testing-library';
-import { Text } from 'ink';
-import { SessionProvider, useSession } from './SessionContext.js';
-import { describe, it, expect } from 'vitest';
+import { act } from 'react-dom/test-utils';
+import { SessionStatsProvider, useSessionStats } from './SessionContext.js';
+import { describe, it, expect, vi } from 'vitest';
+import { GenerateContentResponseUsageMetadata } from '@google/genai';
 
-const TestComponent = () => {
-  const { startTime } = useSession();
-  return <Text>{startTime.toISOString()}</Text>;
+// Mock data that simulates what the Gemini API would return.
+const mockMetadata1: GenerateContentResponseUsageMetadata = {
+  promptTokenCount: 100,
+  candidatesTokenCount: 200,
+  totalTokenCount: 300,
+  cachedContentTokenCount: 50,
+  toolUsePromptTokenCount: 10,
+  thoughtsTokenCount: 20,
 };
 
-describe('SessionContext', () => {
-  it('should provide a start time', () => {
-    const { lastFrame } = render(
-      <SessionProvider>
-        <TestComponent />
-      </SessionProvider>,
+const mockMetadata2: GenerateContentResponseUsageMetadata = {
+  promptTokenCount: 10,
+  candidatesTokenCount: 20,
+  totalTokenCount: 30,
+  cachedContentTokenCount: 5,
+  toolUsePromptTokenCount: 1,
+  thoughtsTokenCount: 2,
+};
+
+/**
+ * A test harness component that uses the hook and exposes the context value
+ * via a mutable ref. This allows us to interact with the context's functions
+ * and assert against its state directly in our tests.
+ */
+const TestHarness = ({
+  contextRef,
+}: {
+  contextRef: MutableRefObject<ReturnType<typeof useSessionStats> | undefined>;
+}) => {
+  contextRef.current = useSessionStats();
+  return null;
+};
+
+describe('SessionStatsContext', () => {
+  it('should provide the correct initial state', () => {
+    const contextRef: MutableRefObject<
+      ReturnType<typeof useSessionStats> | undefined
+    > = { current: undefined };
+
+    render(
+      <SessionStatsProvider>
+        <TestHarness contextRef={contextRef} />
+      </SessionStatsProvider>,
     );
 
-    const frameText = lastFrame();
-    // Check if the output is a valid ISO string, which confirms it's a Date object.
-    expect(new Date(frameText!).toString()).not.toBe('Invalid Date');
+    const stats = contextRef.current?.stats;
+
+    expect(stats?.sessionStartTime).toBeInstanceOf(Date);
+    expect(stats?.lastTurn).toBeNull();
+    expect(stats?.cumulative.turnCount).toBe(0);
+    expect(stats?.cumulative.totalTokenCount).toBe(0);
+    expect(stats?.cumulative.promptTokenCount).toBe(0);
   });
+
+  it('should increment turnCount when startNewTurn is called', () => {
+    const contextRef: MutableRefObject<
+      ReturnType<typeof useSessionStats> | undefined
+    > = { current: undefined };
+
+    render(
+      <SessionStatsProvider>
+        <TestHarness contextRef={contextRef} />
+      </SessionStatsProvider>,
+    );
+
+    act(() => {
+      contextRef.current?.startNewTurn();
+    });
+
+    const stats = contextRef.current?.stats;
+    expect(stats?.cumulative.turnCount).toBe(1);
+    // Ensure token counts are unaffected
+    expect(stats?.cumulative.totalTokenCount).toBe(0);
+  });
+
+  it('should aggregate token usage correctly when addUsage is called', () => {
+    const contextRef: MutableRefObject<
+      ReturnType<typeof useSessionStats> | undefined
+    > = { current: undefined };
+
+    render(
+      <SessionStatsProvider>
+        <TestHarness contextRef={contextRef} />
+      </SessionStatsProvider>,
+    );
+
+    act(() => {
+      contextRef.current?.addUsage(mockMetadata1);
+    });
+
+    const stats = contextRef.current?.stats;
+
+    // Check that token counts are updated
+    expect(stats?.cumulative.totalTokenCount).toBe(
+      mockMetadata1.totalTokenCount ?? 0,
+    );
+    expect(stats?.cumulative.promptTokenCount).toBe(
+      mockMetadata1.promptTokenCount ?? 0,
+    );
+
+    // Check that turn count is NOT incremented
+    expect(stats?.cumulative.turnCount).toBe(0);
+
+    // Check that lastTurn is updated
+    expect(stats?.lastTurn?.metadata).toEqual(mockMetadata1);
+  });
+
+  it('should correctly track a full logical turn with multiple API calls', () => {
+    const contextRef: MutableRefObject<
+      ReturnType<typeof useSessionStats> | undefined
+    > = { current: undefined };
+
+    render(
+      <SessionStatsProvider>
+        <TestHarness contextRef={contextRef} />
+      </SessionStatsProvider>,
+    );
+
+    // 1. User starts a new turn
+    act(() => {
+      contextRef.current?.startNewTurn();
+    });
+
+    // 2. First API call (e.g., prompt with a tool request)
+    act(() => {
+      contextRef.current?.addUsage(mockMetadata1);
+    });
+
+    // 3. Second API call (e.g., sending tool response back)
+    act(() => {
+      contextRef.current?.addUsage(mockMetadata2);
+    });
+
+    const stats = contextRef.current?.stats;
+
+    // Turn count should only be 1
+    expect(stats?.cumulative.turnCount).toBe(1);
+
+    // These fields should be the SUM of both calls
+    expect(stats?.cumulative.totalTokenCount).toBe(330); // 300 + 30
+    expect(stats?.cumulative.candidatesTokenCount).toBe(220); // 200 + 20
+    expect(stats?.cumulative.thoughtsTokenCount).toBe(22); // 20 + 2
+
+    // These fields should ONLY be from the FIRST call, because isNewTurnForAggregation was true
+    expect(stats?.cumulative.promptTokenCount).toBe(100);
+    expect(stats?.cumulative.cachedContentTokenCount).toBe(50);
+    expect(stats?.cumulative.toolUsePromptTokenCount).toBe(10);
+
+    // Last turn should hold the metadata from the most recent call
+    expect(stats?.lastTurn?.metadata).toEqual(mockMetadata2);
+  });
+
+  it('should throw an error when useSessionStats is used outside of a provider', () => {
+    // Suppress the expected console error during this test.
+    const errorSpy = vi.spyOn(console, 'error').mockImplementation(() => {});
+
+    const contextRef = { current: undefined };
+
+    // We expect rendering to fail, which React will catch and log as an error.
+    render(<TestHarness contextRef={contextRef} />);
+
+    // Assert that the first argument of the first call to console.error
+    // contains the expected message. This is more robust than checking
+    // the exact arguments, which can be affected by React/JSDOM internals.
+    expect(errorSpy.mock.calls[0][0]).toContain(
+      'useSessionStats must be used within a SessionStatsProvider',
+    );
+
+    errorSpy.mockRestore();
+  });
 });
@@ -4,35 +4,140 @@
  * SPDX-License-Identifier: Apache-2.0
  */
 
-import React, { createContext, useContext, useState, useMemo } from 'react';
+import React, {
+  createContext,
+  useContext,
+  useState,
+  useMemo,
+  useCallback,
+} from 'react';
+import { type GenerateContentResponseUsageMetadata } from '@google/genai';
 
-interface SessionContextType {
-  startTime: Date;
+// --- Interface Definitions ---
+
+interface CumulativeStats {
+  turnCount: number;
+  promptTokenCount: number;
+  candidatesTokenCount: number;
+  totalTokenCount: number;
+  cachedContentTokenCount: number;
+  toolUsePromptTokenCount: number;
+  thoughtsTokenCount: number;
 }
 
-const SessionContext = createContext<SessionContextType | null>(null);
+interface LastTurnStats {
+  metadata: GenerateContentResponseUsageMetadata;
+  // TODO(abhipatel12): Add apiTime, etc. here in a future step.
+}
 
-export const SessionProvider: React.FC<{ children: React.ReactNode }> = ({
+interface SessionStatsState {
+  sessionStartTime: Date;
+  cumulative: CumulativeStats;
+  lastTurn: LastTurnStats | null;
+  isNewTurnForAggregation: boolean;
+}
+
+// Defines the final "value" of our context, including the state
+// and the functions to update it.
+interface SessionStatsContextValue {
+  stats: SessionStatsState;
+  startNewTurn: () => void;
+  addUsage: (metadata: GenerateContentResponseUsageMetadata) => void;
+}
+
+// --- Context Definition ---
+
+const SessionStatsContext = createContext<SessionStatsContextValue | undefined>(
+  undefined,
+);
+
+// --- Provider Component ---
+
+export const SessionStatsProvider: React.FC<{ children: React.ReactNode }> = ({
   children,
 }) => {
-  const [startTime] = useState(new Date());
+  const [stats, setStats] = useState<SessionStatsState>({
+    sessionStartTime: new Date(),
+    cumulative: {
+      turnCount: 0,
+      promptTokenCount: 0,
+      candidatesTokenCount: 0,
+      totalTokenCount: 0,
+      cachedContentTokenCount: 0,
+      toolUsePromptTokenCount: 0,
+      thoughtsTokenCount: 0,
+    },
+    lastTurn: null,
+    isNewTurnForAggregation: true,
+  });
 
+  // A single, internal worker function to handle all metadata aggregation.
+  const aggregateTokens = useCallback(
+    (metadata: GenerateContentResponseUsageMetadata) => {
+      setStats((prevState) => {
+        const { isNewTurnForAggregation } = prevState;
+        const newCumulative = { ...prevState.cumulative };
+
+        newCumulative.candidatesTokenCount +=
+          metadata.candidatesTokenCount ?? 0;
+        newCumulative.thoughtsTokenCount += metadata.thoughtsTokenCount ?? 0;
+        newCumulative.totalTokenCount += metadata.totalTokenCount ?? 0;
+
+        if (isNewTurnForAggregation) {
+          newCumulative.promptTokenCount += metadata.promptTokenCount ?? 0;
+          newCumulative.cachedContentTokenCount +=
+            metadata.cachedContentTokenCount ?? 0;
+          newCumulative.toolUsePromptTokenCount +=
+            metadata.toolUsePromptTokenCount ?? 0;
+        }
+
+        return {
+          ...prevState,
+          cumulative: newCumulative,
+          lastTurn: { metadata },
+          isNewTurnForAggregation: false,
+        };
+      });
+    },
+    [],
+  );
+
+  const startNewTurn = useCallback(() => {
+    setStats((prevState) => ({
+      ...prevState,
+      cumulative: {
+        ...prevState.cumulative,
+        turnCount: prevState.cumulative.turnCount + 1,
+      },
+      isNewTurnForAggregation: true,
+    }));
+  }, []);
+
   const value = useMemo(
     () => ({
-      startTime,
+      stats,
+      startNewTurn,
+      addUsage: aggregateTokens,
     }),
-    [startTime],
+    [stats, startNewTurn, aggregateTokens],
   );
 
   return (
-    <SessionContext.Provider value={value}>{children}</SessionContext.Provider>
+    <SessionStatsContext.Provider value={value}>
+      {children}
+    </SessionStatsContext.Provider>
   );
 };
 
-export const useSession = () => {
-  const context = useContext(SessionContext);
-  if (!context) {
-    throw new Error('useSession must be used within a SessionProvider');
+// --- Consumer Hook ---
+
+export const useSessionStats = () => {
+  const context = useContext(SessionStatsContext);
+  if (context === undefined) {
+    throw new Error(
+      'useSessionStats must be used within a SessionStatsProvider',
+    );
  }
  return context;
};
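For context, any component rendered below the provider can read the aggregated stats through the hook above. A minimal consumer sketch, assuming the provider file above; the StatsSummary and StatsApp components, their output, and the import paths are illustrative and not part of this commit — only SessionStatsProvider, useSessionStats, and the stats shape come from the change itself:

import React from 'react';
import { Text } from 'ink';
import { SessionStatsProvider, useSessionStats } from './SessionContext.js';

// Hypothetical consumer: reads cumulative totals from the context.
const StatsSummary = () => {
  const { stats } = useSessionStats();
  const { turnCount, totalTokenCount } = stats.cumulative;
  return (
    <Text>
      Turns: {turnCount}, total tokens: {totalTokenCount}
    </Text>
  );
};

// The provider must wrap any component that calls useSessionStats();
// otherwise the hook throws (see the consumer hook above).
export const StatsApp = () => (
  <SessionStatsProvider>
    <StatsSummary />
  </SessionStatsProvider>
);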
@@ -61,13 +61,13 @@ import {
   MCPServerStatus,
   getMCPServerStatus,
 } from '@gemini-cli/core';
-import { useSession } from '../contexts/SessionContext.js';
+import { useSessionStats } from '../contexts/SessionContext.js';
 
 import * as ShowMemoryCommandModule from './useShowMemoryCommand.js';
 import { GIT_COMMIT_INFO } from '../../generated/git-commit.js';
 
 vi.mock('../contexts/SessionContext.js', () => ({
-  useSession: vi.fn(),
+  useSessionStats: vi.fn(),
 }));
 
 vi.mock('./useShowMemoryCommand.js', () => ({
@@ -89,7 +89,7 @@ describe('useSlashCommandProcessor', () => {
   let mockPerformMemoryRefresh: ReturnType<typeof vi.fn>;
   let mockConfig: Config;
   let mockCorgiMode: ReturnType<typeof vi.fn>;
-  const mockUseSession = useSession as Mock;
+  const mockUseSessionStats = useSessionStats as Mock;
 
   beforeEach(() => {
     mockAddItem = vi.fn();
@@ -105,8 +105,19 @@ describe('useSlashCommandProcessor', () => {
       getModel: vi.fn(() => 'test-model'),
     } as unknown as Config;
     mockCorgiMode = vi.fn();
-    mockUseSession.mockReturnValue({
-      startTime: new Date('2025-01-01T00:00:00.000Z'),
+    mockUseSessionStats.mockReturnValue({
+      stats: {
+        sessionStartTime: new Date('2025-01-01T00:00:00.000Z'),
+        cumulative: {
+          turnCount: 0,
+          promptTokenCount: 0,
+          candidatesTokenCount: 0,
+          totalTokenCount: 0,
+          cachedContentTokenCount: 0,
+          toolUsePromptTokenCount: 0,
+          thoughtsTokenCount: 0,
+        },
+      },
     });
 
     (open as Mock).mockClear();
@@ -240,29 +251,55 @@ describe('useSlashCommandProcessor', () => {
   });
 
   describe('/stats command', () => {
-    it('should show the session duration', async () => {
-      const { handleSlashCommand } = getProcessor();
-      let commandResult: SlashCommandActionReturn | boolean = false;
-
-      // Mock current time
-      const mockDate = new Date('2025-01-01T00:01:05.000Z');
-      vi.setSystemTime(mockDate);
-
-      await act(async () => {
-        commandResult = handleSlashCommand('/stats');
+    it('should show detailed session statistics', async () => {
+      // Arrange
+      mockUseSessionStats.mockReturnValue({
+        stats: {
+          sessionStartTime: new Date('2025-01-01T00:00:00.000Z'),
+          cumulative: {
+            totalTokenCount: 900,
+            promptTokenCount: 200,
+            candidatesTokenCount: 400,
+            cachedContentTokenCount: 100,
+            turnCount: 1,
+            toolUsePromptTokenCount: 50,
+            thoughtsTokenCount: 150,
+          },
+        },
       });
 
+      const { handleSlashCommand } = getProcessor();
+      const mockDate = new Date('2025-01-01T01:02:03.000Z'); // 1h 2m 3s duration
+      vi.setSystemTime(mockDate);
+
+      // Act
+      await act(async () => {
+        handleSlashCommand('/stats');
+      });
+
+      // Assert
+      const expectedContent = [
+        ` ⎿ Total duration (wall): 1h 2m 3s`,
+        ` Total Token usage:`,
+        ` Turns: 1`,
+        ` Total: 900`,
+        ` ├─ Input: 200`,
+        ` ├─ Output: 400`,
+        ` ├─ Cached: 100`,
+        ` └─ Overhead: 200`,
+        ` ├─ Model thoughts: 150`,
+        ` └─ Tool-use prompts: 50`,
+      ].join('\n');
+
       expect(mockAddItem).toHaveBeenNthCalledWith(
-        2,
+        2, // Called after the user message
        expect.objectContaining({
          type: MessageType.INFO,
-          text: 'Session duration: 1m 5s',
+          text: expectedContent,
        }),
        expect.any(Number),
      );
-      expect(commandResult).toBe(true);
 
      // Restore system time
      vi.useRealTimers();
    });
  });
@@ -11,7 +11,7 @@ import process from 'node:process';
 import { UseHistoryManagerReturn } from './useHistoryManager.js';
 import { Config, MCPServerStatus, getMCPServerStatus } from '@gemini-cli/core';
 import { Message, MessageType, HistoryItemWithoutId } from '../types.js';
-import { useSession } from '../contexts/SessionContext.js';
+import { useSessionStats } from '../contexts/SessionContext.js';
 import { createShowMemoryAction } from './useShowMemoryCommand.js';
 import { GIT_COMMIT_INFO } from '../../generated/git-commit.js';
 import { formatMemoryUsage } from '../utils/formatters.js';
@@ -50,8 +50,7 @@ export const useSlashCommandProcessor = (
   toggleCorgiMode: () => void,
   showToolDescriptions: boolean = false,
 ) => {
-  const session = useSession();
-
+  const session = useSessionStats();
   const addMessage = useCallback(
     (message: Message) => {
       // Convert Message to HistoryItemWithoutId
@@ -147,7 +146,9 @@ export const useSlashCommandProcessor = (
         description: 'check session stats',
         action: (_mainCommand, _subCommand, _args) => {
           const now = new Date();
-          const duration = now.getTime() - session.startTime.getTime();
+          const { sessionStartTime, cumulative } = session.stats;
+
+          const duration = now.getTime() - sessionStartTime.getTime();
           const durationInSeconds = Math.floor(duration / 1000);
           const hours = Math.floor(durationInSeconds / 3600);
           const minutes = Math.floor((durationInSeconds % 3600) / 60);
@@ -161,9 +162,25 @@ export const useSlashCommandProcessor = (
             .filter(Boolean)
             .join(' ');
 
+          const overheadTotal =
+            cumulative.thoughtsTokenCount + cumulative.toolUsePromptTokenCount;
+
+          const statsContent = [
+            ` ⎿ Total duration (wall): ${durationString}`,
+            ` Total Token usage:`,
+            ` Turns: ${cumulative.turnCount.toLocaleString()}`,
+            ` Total: ${cumulative.totalTokenCount.toLocaleString()}`,
+            ` ├─ Input: ${cumulative.promptTokenCount.toLocaleString()}`,
+            ` ├─ Output: ${cumulative.candidatesTokenCount.toLocaleString()}`,
+            ` ├─ Cached: ${cumulative.cachedContentTokenCount.toLocaleString()}`,
+            ` └─ Overhead: ${overheadTotal.toLocaleString()}`,
+            ` ├─ Model thoughts: ${cumulative.thoughtsTokenCount.toLocaleString()}`,
+            ` └─ Tool-use prompts: ${cumulative.toolUsePromptTokenCount.toLocaleString()}`,
+          ].join('\n');
+
           addMessage({
             type: MessageType.INFO,
-            content: `Session duration: ${durationString}`,
+            content: statsContent,
             timestamp: new Date(),
           });
         },
@@ -477,7 +494,7 @@ Add any other context about the problem here.
       toggleCorgiMode,
       config,
       showToolDescriptions,
-      session.startTime,
+      session,
     ],
   );
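The /stats breakdown above is additive by construction: Overhead is the sum of model-thought and tool-use prompt tokens, and with the fixture used in the test (200 input, 400 output, 100 cached, 150 thoughts, 50 tool-use prompts) the branches happen to sum back to the reported total of 900. A standalone sketch of that arithmetic, reusing the CumulativeStats shape from SessionContext.tsx; the fixture values and the final check are illustrative only, since the API does not guarantee that the parts always sum exactly to totalTokenCount:

// Sketch only: mirrors the /stats math with the fixture values from the test above.
interface CumulativeStats {
  turnCount: number;
  promptTokenCount: number;
  candidatesTokenCount: number;
  totalTokenCount: number;
  cachedContentTokenCount: number;
  toolUsePromptTokenCount: number;
  thoughtsTokenCount: number;
}

const cumulative: CumulativeStats = {
  turnCount: 1,
  promptTokenCount: 200,
  candidatesTokenCount: 400,
  totalTokenCount: 900,
  cachedContentTokenCount: 100,
  toolUsePromptTokenCount: 50,
  thoughtsTokenCount: 150,
};

// Overhead = thoughts + tool-use prompts, exactly as in the command action above.
const overheadTotal =
  cumulative.thoughtsTokenCount + cumulative.toolUsePromptTokenCount; // 150 + 50 = 200

// For this fixture: 200 + 400 + 100 + 200 = 900.
const reconstructedTotal =
  cumulative.promptTokenCount +
  cumulative.candidatesTokenCount +
  cumulative.cachedContentTokenCount +
  overheadTotal;

console.log(reconstructedTotal === cumulative.totalTokenCount); // true for this fixture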
@@ -96,6 +96,15 @@ vi.mock('./useLogger.js', () => ({
   }),
 }));
 
+const mockStartNewTurn = vi.fn();
+const mockAddUsage = vi.fn();
+vi.mock('../contexts/SessionContext.js', () => ({
+  useSessionStats: vi.fn(() => ({
+    startNewTurn: mockStartNewTurn,
+    addUsage: mockAddUsage,
+  })),
+}));
+
 vi.mock('./slashCommandProcessor.js', () => ({
   handleSlashCommand: vi.fn().mockReturnValue(false),
 }));
@@ -531,4 +540,63 @@ describe('useGeminiStream', () => {
       });
     });
   });
+
+  describe('Session Stats Integration', () => {
+    it('should call startNewTurn and addUsage for a simple prompt', async () => {
+      const mockMetadata = { totalTokenCount: 123 };
+      const mockStream = (async function* () {
+        yield { type: 'content', value: 'Response' };
+        yield { type: 'usage_metadata', value: mockMetadata };
+      })();
+      mockSendMessageStream.mockReturnValue(mockStream);
+
+      const { result } = renderTestHook();
+
+      await act(async () => {
+        await result.current.submitQuery('Hello, world!');
+      });
+
+      expect(mockStartNewTurn).toHaveBeenCalledTimes(1);
+      expect(mockAddUsage).toHaveBeenCalledTimes(1);
+      expect(mockAddUsage).toHaveBeenCalledWith(mockMetadata);
+    });
+
+    it('should only call addUsage for a tool continuation prompt', async () => {
+      const mockMetadata = { totalTokenCount: 456 };
+      const mockStream = (async function* () {
+        yield { type: 'content', value: 'Final Answer' };
+        yield { type: 'usage_metadata', value: mockMetadata };
+      })();
+      mockSendMessageStream.mockReturnValue(mockStream);
+
+      const { result } = renderTestHook();
+
+      await act(async () => {
+        await result.current.submitQuery([{ text: 'tool response' }], {
+          isContinuation: true,
+        });
+      });
+
+      expect(mockStartNewTurn).not.toHaveBeenCalled();
+      expect(mockAddUsage).toHaveBeenCalledTimes(1);
+      expect(mockAddUsage).toHaveBeenCalledWith(mockMetadata);
+    });
+
+    it('should not call addUsage if the stream contains no usage metadata', async () => {
+      // Arrange: A stream that yields content but never a usage_metadata event
+      const mockStream = (async function* () {
+        yield { type: 'content', value: 'Some response text' };
+      })();
+      mockSendMessageStream.mockReturnValue(mockStream);
+
+      const { result } = renderTestHook();
+
+      await act(async () => {
+        await result.current.submitQuery('Query with no usage data');
+      });
+
+      expect(mockStartNewTurn).toHaveBeenCalledTimes(1);
+      expect(mockAddUsage).not.toHaveBeenCalled();
+    });
+  });
 });
@@ -42,6 +42,7 @@ import {
   TrackedCompletedToolCall,
   TrackedCancelledToolCall,
 } from './useReactToolScheduler.js';
+import { useSessionStats } from '../contexts/SessionContext.js';
 
 export function mergePartListUnions(list: PartListUnion[]): PartListUnion {
   const resultParts: PartListUnion = [];
@@ -82,6 +83,7 @@ export const useGeminiStream = (
   const [pendingHistoryItemRef, setPendingHistoryItem] =
     useStateAndRef<HistoryItemWithoutId | null>(null);
   const logger = useLogger();
+  const { startNewTurn, addUsage } = useSessionStats();
 
   const [toolCalls, scheduleToolCalls, markToolsAsSubmitted] =
     useReactToolScheduler(
@@ -390,6 +392,9 @@ export const useGeminiStream = (
         case ServerGeminiEventType.ChatCompressed:
           handleChatCompressionEvent();
           break;
+        case ServerGeminiEventType.UsageMetadata:
+          addUsage(event.value);
+          break;
         case ServerGeminiEventType.ToolCallConfirmation:
         case ServerGeminiEventType.ToolCallResponse:
           // do nothing
@@ -412,11 +417,12 @@ export const useGeminiStream = (
       handleErrorEvent,
       scheduleToolCalls,
       handleChatCompressionEvent,
+      addUsage,
     ],
   );
 
   const submitQuery = useCallback(
-    async (query: PartListUnion) => {
+    async (query: PartListUnion, options?: { isContinuation: boolean }) => {
       if (
         streamingState === StreamingState.Responding ||
         streamingState === StreamingState.WaitingForConfirmation
@@ -426,6 +432,10 @@ export const useGeminiStream = (
       const userMessageTimestamp = Date.now();
       setShowHelp(false);
 
+      if (!options?.isContinuation) {
+        startNewTurn();
+      }
+
       abortControllerRef.current = new AbortController();
       const abortSignal = abortControllerRef.current.signal;
 
@@ -491,6 +501,7 @@ export const useGeminiStream = (
       setPendingHistoryItem,
       setInitError,
       geminiClient,
+      startNewTurn,
     ],
   );
 
@@ -576,7 +587,9 @@ export const useGeminiStream = (
       );
 
       markToolsAsSubmitted(callIdsToMarkAsSubmitted);
-      submitQuery(mergePartListUnions(responsesToSend));
+      submitQuery(mergePartListUnions(responsesToSend), {
+        isContinuation: true,
+      });
     }
   }, [
     toolCalls,
@@ -13,14 +13,32 @@ import {
   GoogleGenAI,
 } from '@google/genai';
 import { GeminiClient } from './client.js';
 import { ContentGenerator } from './contentGenerator.js';
 import { GeminiChat } from './geminiChat.js';
 import { Config } from '../config/config.js';
+import { Turn } from './turn.js';
 
 // --- Mocks ---
 const mockChatCreateFn = vi.fn();
 const mockGenerateContentFn = vi.fn();
 const mockEmbedContentFn = vi.fn();
+const mockTurnRunFn = vi.fn();
 
 vi.mock('@google/genai');
+vi.mock('./turn', () => {
+  // Define a mock class that has the same shape as the real Turn
+  class MockTurn {
+    pendingToolCalls = [];
+    // The run method is a property that holds our mock function
+    run = mockTurnRunFn;
+
+    constructor() {
+      // The constructor can be empty or do some mock setup
+    }
+  }
+  // Export the mock class as 'Turn'
+  return { Turn: MockTurn };
+});
+
 vi.mock('../config/config.js');
 vi.mock('./prompts');
@@ -237,4 +255,44 @@ describe('Gemini Client (client.ts)', () => {
       expect(mockChat.addHistory).toHaveBeenCalledWith(newContent);
     });
   });
+
+  describe('sendMessageStream', () => {
+    it('should return the turn instance after the stream is complete', async () => {
+      // Arrange
+      const mockStream = (async function* () {
+        yield { type: 'content', value: 'Hello' };
+      })();
+      mockTurnRunFn.mockReturnValue(mockStream);
+
+      const mockChat: Partial<GeminiChat> = {
+        addHistory: vi.fn(),
+        getHistory: vi.fn().mockReturnValue([]),
+      };
+      client['chat'] = Promise.resolve(mockChat as GeminiChat);
+
+      const mockGenerator: Partial<ContentGenerator> = {
+        countTokens: vi.fn().mockResolvedValue({ totalTokens: 0 }),
+      };
+      client['contentGenerator'] = mockGenerator as ContentGenerator;
+
+      // Act
+      const stream = client.sendMessageStream(
+        [{ text: 'Hi' }],
+        new AbortController().signal,
+      );
+
+      // Consume the stream manually to get the final return value.
+      let finalResult: Turn | undefined;
+      while (true) {
+        const result = await stream.next();
+        if (result.done) {
+          finalResult = result.value;
+          break;
+        }
+      }
+
+      // Assert
+      expect(finalResult).toBeInstanceOf(Turn);
+    });
+  });
 });
@@ -174,9 +174,10 @@ export class GeminiClient {
     request: PartListUnion,
     signal: AbortSignal,
     turns: number = this.MAX_TURNS,
-  ): AsyncGenerator<ServerGeminiStreamEvent> {
+  ): AsyncGenerator<ServerGeminiStreamEvent, Turn> {
     if (!turns) {
-      return;
+      const chat = await this.chat;
+      return new Turn(chat);
     }
 
     const compressed = await this.tryCompressChat();
@@ -193,9 +194,12 @@ export class GeminiClient {
       const nextSpeakerCheck = await checkNextSpeaker(chat, this, signal);
       if (nextSpeakerCheck?.next_speaker === 'model') {
         const nextRequest = [{ text: 'Please continue.' }];
+        // This recursive call's events will be yielded out, but the final
+        // turn object will be from the top-level call.
         yield* this.sendMessageStream(nextRequest, signal, turns - 1);
       }
     }
+    return turn;
   }
 
   private _logApiRequest(model: string, inputTokenCount: number): void {
@@ -423,6 +427,10 @@ export class GeminiClient {
       });
 
       const result = await retryWithBackoff(apiCall);
+      console.log(
+        'Raw API Response in client.ts:',
+        JSON.stringify(result, null, 2),
+      );
       const durationMs = Date.now() - startTime;
       this._logApiResponse(modelToUse, durationMs, attempt, result);
       return result;
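A note on the new AsyncGenerator<ServerGeminiStreamEvent, Turn> signature: for await...of discards an async generator's return value, which is why the client test above drives the iterator with next() until done to obtain the final Turn. A self-contained sketch of that pattern; the countToTen generator is a made-up stand-in, not part of this commit:

// Illustrative only: read both the yielded values and the generator's
// return value, which for-await-of would silently drop.
async function* countToTen(): AsyncGenerator<number, string> {
  for (let i = 1; i <= 10; i++) {
    yield i;
  }
  return 'done'; // the return value, analogous to the Turn instance
}

async function main(): Promise<void> {
  const gen = countToTen();
  let returned: string | undefined;
  while (true) {
    const result = await gen.next();
    if (result.done) {
      returned = result.value; // 'done'
      break;
    }
    console.log(result.value); // 1..10
  }
  console.log(returned);
}

main();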
@@ -10,8 +10,14 @@ import {
   GeminiEventType,
   ServerGeminiToolCallRequestEvent,
   ServerGeminiErrorEvent,
+  ServerGeminiUsageMetadataEvent,
 } from './turn.js';
-import { GenerateContentResponse, Part, Content } from '@google/genai';
+import {
+  GenerateContentResponse,
+  Part,
+  Content,
+  GenerateContentResponseUsageMetadata,
+} from '@google/genai';
 import { reportError } from '../utils/errorReporting.js';
 import { GeminiChat } from './geminiChat.js';
 
@@ -49,6 +55,24 @@ describe('Turn', () => {
   };
   let mockChatInstance: MockedChatInstance;
 
+  const mockMetadata1: GenerateContentResponseUsageMetadata = {
+    promptTokenCount: 10,
+    candidatesTokenCount: 20,
+    totalTokenCount: 30,
+    cachedContentTokenCount: 5,
+    toolUsePromptTokenCount: 2,
+    thoughtsTokenCount: 3,
+  };
+
+  const mockMetadata2: GenerateContentResponseUsageMetadata = {
+    promptTokenCount: 100,
+    candidatesTokenCount: 200,
+    totalTokenCount: 300,
+    cachedContentTokenCount: 50,
+    toolUsePromptTokenCount: 20,
+    thoughtsTokenCount: 30,
+  };
+
   beforeEach(() => {
     vi.resetAllMocks();
     mockChatInstance = {
@@ -96,6 +120,7 @@ describe('Turn', () => {
       message: reqParts,
       config: { abortSignal: expect.any(AbortSignal) },
     });
+
     expect(events).toEqual([
       { type: GeminiEventType.Content, value: 'Hello' },
       { type: GeminiEventType.Content, value: ' world' },
@@ -208,6 +233,41 @@ describe('Turn', () => {
     );
   });
 
+  it('should yield the last UsageMetadata event from the stream', async () => {
+    const mockResponseStream = (async function* () {
+      yield {
+        candidates: [{ content: { parts: [{ text: 'First response' }] } }],
+        usageMetadata: mockMetadata1,
+      } as unknown as GenerateContentResponse;
+      yield {
+        functionCalls: [{ name: 'aTool' }],
+        usageMetadata: mockMetadata2,
+      } as unknown as GenerateContentResponse;
+    })();
+    mockSendMessageStream.mockResolvedValue(mockResponseStream);
+
+    const events = [];
+    const reqParts: Part[] = [{ text: 'Test metadata' }];
+    for await (const event of turn.run(
+      reqParts,
+      new AbortController().signal,
+    )) {
+      events.push(event);
+    }
+
+    // There should be a content event, a tool call, and our metadata event
+    expect(events.length).toBe(3);
+
+    const metadataEvent = events[2] as ServerGeminiUsageMetadataEvent;
+    expect(metadataEvent.type).toBe(GeminiEventType.UsageMetadata);
+
+    // The value should be the *last* metadata object received.
+    expect(metadataEvent.value).toEqual(mockMetadata2);
+
+    // Also check the public getter
+    expect(turn.getUsageMetadata()).toEqual(mockMetadata2);
+  });
+
   it('should handle function calls with undefined name or args', async () => {
     const mockResponseStream = (async function* () {
       yield {
@@ -219,7 +279,6 @@ describe('Turn', () => {
       } as unknown as GenerateContentResponse;
     })();
     mockSendMessageStream.mockResolvedValue(mockResponseStream);
-
     const events = [];
     const reqParts: Part[] = [{ text: 'Test undefined tool parts' }];
     for await (const event of turn.run(
@@ -9,6 +9,7 @@ import {
   GenerateContentResponse,
   FunctionCall,
   FunctionDeclaration,
+  GenerateContentResponseUsageMetadata,
 } from '@google/genai';
 import {
   ToolCallConfirmationDetails,
@@ -43,6 +44,7 @@ export enum GeminiEventType {
   UserCancelled = 'user_cancelled',
   Error = 'error',
   ChatCompressed = 'chat_compressed',
+  UsageMetadata = 'usage_metadata',
 }
 
 export interface GeminiErrorEventValue {
@@ -100,6 +102,11 @@ export type ServerGeminiChatCompressedEvent = {
   type: GeminiEventType.ChatCompressed;
 };
 
+export type ServerGeminiUsageMetadataEvent = {
+  type: GeminiEventType.UsageMetadata;
+  value: GenerateContentResponseUsageMetadata;
+};
+
 // The original union type, now composed of the individual types
 export type ServerGeminiStreamEvent =
   | ServerGeminiContentEvent
@@ -108,7 +115,8 @@ export type ServerGeminiStreamEvent =
   | ServerGeminiToolCallConfirmationEvent
   | ServerGeminiUserCancelledEvent
   | ServerGeminiErrorEvent
-  | ServerGeminiChatCompressedEvent;
+  | ServerGeminiChatCompressedEvent
+  | ServerGeminiUsageMetadataEvent;
 
 // A turn manages the agentic loop turn within the server context.
 export class Turn {
@@ -118,6 +126,7 @@ export class Turn {
     args: Record<string, unknown>;
   }>;
   private debugResponses: GenerateContentResponse[];
+  private lastUsageMetadata: GenerateContentResponseUsageMetadata | null = null;
 
   constructor(private readonly chat: GeminiChat) {
     this.pendingToolCalls = [];
@@ -157,6 +166,18 @@ export class Turn {
             yield event;
           }
         }
+
+        if (resp.usageMetadata) {
+          this.lastUsageMetadata =
+            resp.usageMetadata as GenerateContentResponseUsageMetadata;
+        }
       }
+
+      if (this.lastUsageMetadata) {
+        yield {
+          type: GeminiEventType.UsageMetadata,
+          value: this.lastUsageMetadata,
+        };
+      }
     } catch (error) {
       if (signal.aborted) {
@@ -197,4 +218,8 @@ export class Turn {
   getDebugResponses(): GenerateContentResponse[] {
     return this.debugResponses;
   }
+
+  getUsageMetadata(): GenerateContentResponseUsageMetadata | null {
+    return this.lastUsageMetadata;
+  }
 }