refactor: maintain 1 GeminiChat per GeminiClient (#710)
This commit is contained in:
parent
447826ab40
commit
74801e9004
|
@ -39,7 +39,7 @@ describe('runNonInteractive', () => {
|
||||||
sendMessageStream: vi.fn(),
|
sendMessageStream: vi.fn(),
|
||||||
};
|
};
|
||||||
mockGeminiClient = {
|
mockGeminiClient = {
|
||||||
startChat: vi.fn().mockResolvedValue(mockChat),
|
getChat: vi.fn().mockResolvedValue(mockChat),
|
||||||
} as unknown as GeminiClient;
|
} as unknown as GeminiClient;
|
||||||
mockToolRegistry = {
|
mockToolRegistry = {
|
||||||
getFunctionDeclarations: vi.fn().mockReturnValue([]),
|
getFunctionDeclarations: vi.fn().mockReturnValue([]),
|
||||||
|
@ -80,7 +80,6 @@ describe('runNonInteractive', () => {
|
||||||
|
|
||||||
await runNonInteractive(mockConfig, 'Test input');
|
await runNonInteractive(mockConfig, 'Test input');
|
||||||
|
|
||||||
expect(mockGeminiClient.startChat).toHaveBeenCalled();
|
|
||||||
expect(mockChat.sendMessageStream).toHaveBeenCalledWith({
|
expect(mockChat.sendMessageStream).toHaveBeenCalledWith({
|
||||||
message: [{ text: 'Test input' }],
|
message: [{ text: 'Test input' }],
|
||||||
config: {
|
config: {
|
||||||
|
|
|
@ -42,7 +42,7 @@ export async function runNonInteractive(
|
||||||
const geminiClient = new GeminiClient(config);
|
const geminiClient = new GeminiClient(config);
|
||||||
const toolRegistry: ToolRegistry = await config.getToolRegistry();
|
const toolRegistry: ToolRegistry = await config.getToolRegistry();
|
||||||
|
|
||||||
const chat = await geminiClient.startChat();
|
const chat = await geminiClient.getChat();
|
||||||
const abortController = new AbortController();
|
const abortController = new AbortController();
|
||||||
let currentMessages: Content[] = [{ role: 'user', parts: [{ text: input }] }];
|
let currentMessages: Content[] = [{ role: 'user', parts: [{ text: input }] }];
|
||||||
|
|
||||||
|
|
|
@ -405,10 +405,9 @@ describe('useGeminiStream', () => {
|
||||||
} as TrackedCancelledToolCall,
|
} as TrackedCancelledToolCall,
|
||||||
];
|
];
|
||||||
|
|
||||||
let hookResult: any;
|
const hookResult = await act(async () =>
|
||||||
await act(async () => {
|
renderTestHook(simplifiedToolCalls),
|
||||||
hookResult = renderTestHook(simplifiedToolCalls);
|
);
|
||||||
});
|
|
||||||
|
|
||||||
const {
|
const {
|
||||||
mockMarkToolsAsSubmitted,
|
mockMarkToolsAsSubmitted,
|
||||||
|
@ -431,9 +430,8 @@ describe('useGeminiStream', () => {
|
||||||
toolCall2ResponseParts,
|
toolCall2ResponseParts,
|
||||||
]);
|
]);
|
||||||
expect(localMockSendMessageStream).toHaveBeenCalledWith(
|
expect(localMockSendMessageStream).toHaveBeenCalledWith(
|
||||||
expect.anything(),
|
|
||||||
expectedMergedResponse,
|
expectedMergedResponse,
|
||||||
expect.anything(),
|
expect.any(AbortSignal),
|
||||||
);
|
);
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
|
@ -17,7 +17,6 @@ import {
|
||||||
Config,
|
Config,
|
||||||
MessageSenderType,
|
MessageSenderType,
|
||||||
ToolCallRequestInfo,
|
ToolCallRequestInfo,
|
||||||
GeminiChat,
|
|
||||||
} from '@gemini-code/core';
|
} from '@gemini-code/core';
|
||||||
import { type PartListUnion } from '@google/genai';
|
import { type PartListUnion } from '@google/genai';
|
||||||
import {
|
import {
|
||||||
|
@ -76,7 +75,6 @@ export const useGeminiStream = (
|
||||||
) => {
|
) => {
|
||||||
const [initError, setInitError] = useState<string | null>(null);
|
const [initError, setInitError] = useState<string | null>(null);
|
||||||
const abortControllerRef = useRef<AbortController | null>(null);
|
const abortControllerRef = useRef<AbortController | null>(null);
|
||||||
const chatSessionRef = useRef<GeminiChat | null>(null);
|
|
||||||
const geminiClientRef = useRef<GeminiClient | null>(null);
|
const geminiClientRef = useRef<GeminiClient | null>(null);
|
||||||
const [isResponding, setIsResponding] = useState<boolean>(false);
|
const [isResponding, setIsResponding] = useState<boolean>(false);
|
||||||
const [pendingHistoryItemRef, setPendingHistoryItem] =
|
const [pendingHistoryItemRef, setPendingHistoryItem] =
|
||||||
|
@ -256,31 +254,6 @@ export const useGeminiStream = (
|
||||||
],
|
],
|
||||||
);
|
);
|
||||||
|
|
||||||
const ensureChatSession = useCallback(async (): Promise<{
|
|
||||||
client: GeminiClient | null;
|
|
||||||
chat: GeminiChat | null;
|
|
||||||
}> => {
|
|
||||||
const currentClient = geminiClientRef.current;
|
|
||||||
if (!currentClient) {
|
|
||||||
const errorMsg = 'Gemini client is not available.';
|
|
||||||
setInitError(errorMsg);
|
|
||||||
addItem({ type: MessageType.ERROR, text: errorMsg }, Date.now());
|
|
||||||
return { client: null, chat: null };
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!chatSessionRef.current) {
|
|
||||||
try {
|
|
||||||
chatSessionRef.current = await currentClient.startChat();
|
|
||||||
} catch (err: unknown) {
|
|
||||||
const errorMsg = `Failed to start chat: ${getErrorMessage(err)}`;
|
|
||||||
setInitError(errorMsg);
|
|
||||||
addItem({ type: MessageType.ERROR, text: errorMsg }, Date.now());
|
|
||||||
return { client: currentClient, chat: null };
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return { client: currentClient, chat: chatSessionRef.current };
|
|
||||||
}, [addItem]);
|
|
||||||
|
|
||||||
// --- Stream Event Handlers ---
|
// --- Stream Event Handlers ---
|
||||||
|
|
||||||
const handleContentEvent = useCallback(
|
const handleContentEvent = useCallback(
|
||||||
|
@ -444,9 +417,12 @@ export const useGeminiStream = (
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
const { client, chat } = await ensureChatSession();
|
const client = geminiClientRef.current;
|
||||||
|
|
||||||
if (!client || !chat) {
|
if (!client) {
|
||||||
|
const errorMsg = 'Gemini client is not available.';
|
||||||
|
setInitError(errorMsg);
|
||||||
|
addItem({ type: MessageType.ERROR, text: errorMsg }, Date.now());
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -454,7 +430,7 @@ export const useGeminiStream = (
|
||||||
setInitError(null);
|
setInitError(null);
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const stream = client.sendMessageStream(chat, queryToSend, abortSignal);
|
const stream = client.sendMessageStream(queryToSend, abortSignal);
|
||||||
const processingStatus = await processGeminiStreamEvents(
|
const processingStatus = await processGeminiStreamEvents(
|
||||||
stream,
|
stream,
|
||||||
userMessageTimestamp,
|
userMessageTimestamp,
|
||||||
|
@ -487,7 +463,6 @@ export const useGeminiStream = (
|
||||||
streamingState,
|
streamingState,
|
||||||
setShowHelp,
|
setShowHelp,
|
||||||
prepareQueryForGemini,
|
prepareQueryForGemini,
|
||||||
ensureChatSession,
|
|
||||||
processGeminiStreamEvents,
|
processGeminiStreamEvents,
|
||||||
pendingHistoryItemRef,
|
pendingHistoryItemRef,
|
||||||
addItem,
|
addItem,
|
||||||
|
|
|
@ -35,6 +35,7 @@ vi.mock('../tools/memoryTool', () => ({
|
||||||
setGeminiMdFilename: vi.fn(),
|
setGeminiMdFilename: vi.fn(),
|
||||||
getCurrentGeminiMdFilename: vi.fn(() => 'GEMINI.md'), // Mock the original filename
|
getCurrentGeminiMdFilename: vi.fn(() => 'GEMINI.md'), // Mock the original filename
|
||||||
DEFAULT_CONTEXT_FILENAME: 'GEMINI.md',
|
DEFAULT_CONTEXT_FILENAME: 'GEMINI.md',
|
||||||
|
GEMINI_CONFIG_DIR: '.gemini',
|
||||||
}));
|
}));
|
||||||
|
|
||||||
describe('Server Config (config.ts)', () => {
|
describe('Server Config (config.ts)', () => {
|
||||||
|
|
|
@ -27,6 +27,7 @@ import { GeminiChat } from './geminiChat.js';
|
||||||
import { retryWithBackoff } from '../utils/retry.js';
|
import { retryWithBackoff } from '../utils/retry.js';
|
||||||
|
|
||||||
export class GeminiClient {
|
export class GeminiClient {
|
||||||
|
private chat: Promise<GeminiChat>;
|
||||||
private client: GoogleGenAI;
|
private client: GoogleGenAI;
|
||||||
private model: string;
|
private model: string;
|
||||||
private generateContentConfig: GenerateContentConfig = {
|
private generateContentConfig: GenerateContentConfig = {
|
||||||
|
@ -50,6 +51,11 @@ export class GeminiClient {
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
this.model = config.getModel();
|
this.model = config.getModel();
|
||||||
|
this.chat = this.startChat();
|
||||||
|
}
|
||||||
|
|
||||||
|
getChat(): Promise<GeminiChat> {
|
||||||
|
return this.chat;
|
||||||
}
|
}
|
||||||
|
|
||||||
private async getEnvironment(): Promise<Part[]> {
|
private async getEnvironment(): Promise<Part[]> {
|
||||||
|
@ -114,12 +120,12 @@ export class GeminiClient {
|
||||||
return initialParts;
|
return initialParts;
|
||||||
}
|
}
|
||||||
|
|
||||||
async startChat(): Promise<GeminiChat> {
|
private async startChat(extraHistory?: Content[]): Promise<GeminiChat> {
|
||||||
const envParts = await this.getEnvironment();
|
const envParts = await this.getEnvironment();
|
||||||
const toolRegistry = await this.config.getToolRegistry();
|
const toolRegistry = await this.config.getToolRegistry();
|
||||||
const toolDeclarations = toolRegistry.getFunctionDeclarations();
|
const toolDeclarations = toolRegistry.getFunctionDeclarations();
|
||||||
const tools: Tool[] = [{ functionDeclarations: toolDeclarations }];
|
const tools: Tool[] = [{ functionDeclarations: toolDeclarations }];
|
||||||
const history: Content[] = [
|
const initialHistory: Content[] = [
|
||||||
{
|
{
|
||||||
role: 'user',
|
role: 'user',
|
||||||
parts: envParts,
|
parts: envParts,
|
||||||
|
@ -129,6 +135,7 @@ export class GeminiClient {
|
||||||
parts: [{ text: 'Got it. Thanks for the context!' }],
|
parts: [{ text: 'Got it. Thanks for the context!' }],
|
||||||
},
|
},
|
||||||
];
|
];
|
||||||
|
const history = initialHistory.concat(extraHistory ?? []);
|
||||||
try {
|
try {
|
||||||
const userMemory = this.config.getUserMemory();
|
const userMemory = this.config.getUserMemory();
|
||||||
const systemInstruction = getCoreSystemPrompt(userMemory);
|
const systemInstruction = getCoreSystemPrompt(userMemory);
|
||||||
|
@ -157,7 +164,6 @@ export class GeminiClient {
|
||||||
}
|
}
|
||||||
|
|
||||||
async *sendMessageStream(
|
async *sendMessageStream(
|
||||||
chat: GeminiChat,
|
|
||||||
request: PartListUnion,
|
request: PartListUnion,
|
||||||
signal: AbortSignal,
|
signal: AbortSignal,
|
||||||
turns: number = this.MAX_TURNS,
|
turns: number = this.MAX_TURNS,
|
||||||
|
@ -166,6 +172,7 @@ export class GeminiClient {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const chat = await this.chat;
|
||||||
const turn = new Turn(chat);
|
const turn = new Turn(chat);
|
||||||
const resultStream = turn.run(request, signal);
|
const resultStream = turn.run(request, signal);
|
||||||
for await (const event of resultStream) {
|
for await (const event of resultStream) {
|
||||||
|
@ -175,7 +182,7 @@ export class GeminiClient {
|
||||||
const nextSpeakerCheck = await checkNextSpeaker(chat, this, signal);
|
const nextSpeakerCheck = await checkNextSpeaker(chat, this, signal);
|
||||||
if (nextSpeakerCheck?.next_speaker === 'model') {
|
if (nextSpeakerCheck?.next_speaker === 'model') {
|
||||||
const nextRequest = [{ text: 'Please continue.' }];
|
const nextRequest = [{ text: 'Please continue.' }];
|
||||||
yield* this.sendMessageStream(chat, nextRequest, signal, turns - 1);
|
yield* this.sendMessageStream(nextRequest, signal, turns - 1);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -56,10 +56,10 @@ Signal: Signal number or \`(none)\` if no signal was received.
|
||||||
let stdout = '';
|
let stdout = '';
|
||||||
let stderr = '';
|
let stderr = '';
|
||||||
child.stdout.on('data', (data) => {
|
child.stdout.on('data', (data) => {
|
||||||
stdout += data.toString();
|
stdout += data?.toString();
|
||||||
});
|
});
|
||||||
child.stderr.on('data', (data) => {
|
child.stderr.on('data', (data) => {
|
||||||
stderr += data.toString();
|
stderr += data?.toString();
|
||||||
});
|
});
|
||||||
let error: Error | null = null;
|
let error: Error | null = null;
|
||||||
child.on('error', (err: Error) => {
|
child.on('error', (err: Error) => {
|
||||||
|
|
Loading…
Reference in New Issue