refactor: maintain 1 GeminiChat per GeminiClient (#710)
This commit is contained in:
parent
447826ab40
commit
74801e9004
|
@ -39,7 +39,7 @@ describe('runNonInteractive', () => {
|
|||
sendMessageStream: vi.fn(),
|
||||
};
|
||||
mockGeminiClient = {
|
||||
startChat: vi.fn().mockResolvedValue(mockChat),
|
||||
getChat: vi.fn().mockResolvedValue(mockChat),
|
||||
} as unknown as GeminiClient;
|
||||
mockToolRegistry = {
|
||||
getFunctionDeclarations: vi.fn().mockReturnValue([]),
|
||||
|
@ -80,7 +80,6 @@ describe('runNonInteractive', () => {
|
|||
|
||||
await runNonInteractive(mockConfig, 'Test input');
|
||||
|
||||
expect(mockGeminiClient.startChat).toHaveBeenCalled();
|
||||
expect(mockChat.sendMessageStream).toHaveBeenCalledWith({
|
||||
message: [{ text: 'Test input' }],
|
||||
config: {
|
||||
|
|
|
@ -42,7 +42,7 @@ export async function runNonInteractive(
|
|||
const geminiClient = new GeminiClient(config);
|
||||
const toolRegistry: ToolRegistry = await config.getToolRegistry();
|
||||
|
||||
const chat = await geminiClient.startChat();
|
||||
const chat = await geminiClient.getChat();
|
||||
const abortController = new AbortController();
|
||||
let currentMessages: Content[] = [{ role: 'user', parts: [{ text: input }] }];
|
||||
|
||||
|
|
|
@ -405,10 +405,9 @@ describe('useGeminiStream', () => {
|
|||
} as TrackedCancelledToolCall,
|
||||
];
|
||||
|
||||
let hookResult: any;
|
||||
await act(async () => {
|
||||
hookResult = renderTestHook(simplifiedToolCalls);
|
||||
});
|
||||
const hookResult = await act(async () =>
|
||||
renderTestHook(simplifiedToolCalls),
|
||||
);
|
||||
|
||||
const {
|
||||
mockMarkToolsAsSubmitted,
|
||||
|
@ -431,9 +430,8 @@ describe('useGeminiStream', () => {
|
|||
toolCall2ResponseParts,
|
||||
]);
|
||||
expect(localMockSendMessageStream).toHaveBeenCalledWith(
|
||||
expect.anything(),
|
||||
expectedMergedResponse,
|
||||
expect.anything(),
|
||||
expect.any(AbortSignal),
|
||||
);
|
||||
});
|
||||
});
|
||||
|
|
|
@ -17,7 +17,6 @@ import {
|
|||
Config,
|
||||
MessageSenderType,
|
||||
ToolCallRequestInfo,
|
||||
GeminiChat,
|
||||
} from '@gemini-code/core';
|
||||
import { type PartListUnion } from '@google/genai';
|
||||
import {
|
||||
|
@ -76,7 +75,6 @@ export const useGeminiStream = (
|
|||
) => {
|
||||
const [initError, setInitError] = useState<string | null>(null);
|
||||
const abortControllerRef = useRef<AbortController | null>(null);
|
||||
const chatSessionRef = useRef<GeminiChat | null>(null);
|
||||
const geminiClientRef = useRef<GeminiClient | null>(null);
|
||||
const [isResponding, setIsResponding] = useState<boolean>(false);
|
||||
const [pendingHistoryItemRef, setPendingHistoryItem] =
|
||||
|
@ -256,31 +254,6 @@ export const useGeminiStream = (
|
|||
],
|
||||
);
|
||||
|
||||
const ensureChatSession = useCallback(async (): Promise<{
|
||||
client: GeminiClient | null;
|
||||
chat: GeminiChat | null;
|
||||
}> => {
|
||||
const currentClient = geminiClientRef.current;
|
||||
if (!currentClient) {
|
||||
const errorMsg = 'Gemini client is not available.';
|
||||
setInitError(errorMsg);
|
||||
addItem({ type: MessageType.ERROR, text: errorMsg }, Date.now());
|
||||
return { client: null, chat: null };
|
||||
}
|
||||
|
||||
if (!chatSessionRef.current) {
|
||||
try {
|
||||
chatSessionRef.current = await currentClient.startChat();
|
||||
} catch (err: unknown) {
|
||||
const errorMsg = `Failed to start chat: ${getErrorMessage(err)}`;
|
||||
setInitError(errorMsg);
|
||||
addItem({ type: MessageType.ERROR, text: errorMsg }, Date.now());
|
||||
return { client: currentClient, chat: null };
|
||||
}
|
||||
}
|
||||
return { client: currentClient, chat: chatSessionRef.current };
|
||||
}, [addItem]);
|
||||
|
||||
// --- Stream Event Handlers ---
|
||||
|
||||
const handleContentEvent = useCallback(
|
||||
|
@ -444,9 +417,12 @@ export const useGeminiStream = (
|
|||
return;
|
||||
}
|
||||
|
||||
const { client, chat } = await ensureChatSession();
|
||||
const client = geminiClientRef.current;
|
||||
|
||||
if (!client || !chat) {
|
||||
if (!client) {
|
||||
const errorMsg = 'Gemini client is not available.';
|
||||
setInitError(errorMsg);
|
||||
addItem({ type: MessageType.ERROR, text: errorMsg }, Date.now());
|
||||
return;
|
||||
}
|
||||
|
||||
|
@ -454,7 +430,7 @@ export const useGeminiStream = (
|
|||
setInitError(null);
|
||||
|
||||
try {
|
||||
const stream = client.sendMessageStream(chat, queryToSend, abortSignal);
|
||||
const stream = client.sendMessageStream(queryToSend, abortSignal);
|
||||
const processingStatus = await processGeminiStreamEvents(
|
||||
stream,
|
||||
userMessageTimestamp,
|
||||
|
@ -487,7 +463,6 @@ export const useGeminiStream = (
|
|||
streamingState,
|
||||
setShowHelp,
|
||||
prepareQueryForGemini,
|
||||
ensureChatSession,
|
||||
processGeminiStreamEvents,
|
||||
pendingHistoryItemRef,
|
||||
addItem,
|
||||
|
|
|
@ -35,6 +35,7 @@ vi.mock('../tools/memoryTool', () => ({
|
|||
setGeminiMdFilename: vi.fn(),
|
||||
getCurrentGeminiMdFilename: vi.fn(() => 'GEMINI.md'), // Mock the original filename
|
||||
DEFAULT_CONTEXT_FILENAME: 'GEMINI.md',
|
||||
GEMINI_CONFIG_DIR: '.gemini',
|
||||
}));
|
||||
|
||||
describe('Server Config (config.ts)', () => {
|
||||
|
|
|
@ -27,6 +27,7 @@ import { GeminiChat } from './geminiChat.js';
|
|||
import { retryWithBackoff } from '../utils/retry.js';
|
||||
|
||||
export class GeminiClient {
|
||||
private chat: Promise<GeminiChat>;
|
||||
private client: GoogleGenAI;
|
||||
private model: string;
|
||||
private generateContentConfig: GenerateContentConfig = {
|
||||
|
@ -50,6 +51,11 @@ export class GeminiClient {
|
|||
},
|
||||
});
|
||||
this.model = config.getModel();
|
||||
this.chat = this.startChat();
|
||||
}
|
||||
|
||||
getChat(): Promise<GeminiChat> {
|
||||
return this.chat;
|
||||
}
|
||||
|
||||
private async getEnvironment(): Promise<Part[]> {
|
||||
|
@ -114,12 +120,12 @@ export class GeminiClient {
|
|||
return initialParts;
|
||||
}
|
||||
|
||||
async startChat(): Promise<GeminiChat> {
|
||||
private async startChat(extraHistory?: Content[]): Promise<GeminiChat> {
|
||||
const envParts = await this.getEnvironment();
|
||||
const toolRegistry = await this.config.getToolRegistry();
|
||||
const toolDeclarations = toolRegistry.getFunctionDeclarations();
|
||||
const tools: Tool[] = [{ functionDeclarations: toolDeclarations }];
|
||||
const history: Content[] = [
|
||||
const initialHistory: Content[] = [
|
||||
{
|
||||
role: 'user',
|
||||
parts: envParts,
|
||||
|
@ -129,6 +135,7 @@ export class GeminiClient {
|
|||
parts: [{ text: 'Got it. Thanks for the context!' }],
|
||||
},
|
||||
];
|
||||
const history = initialHistory.concat(extraHistory ?? []);
|
||||
try {
|
||||
const userMemory = this.config.getUserMemory();
|
||||
const systemInstruction = getCoreSystemPrompt(userMemory);
|
||||
|
@ -157,7 +164,6 @@ export class GeminiClient {
|
|||
}
|
||||
|
||||
async *sendMessageStream(
|
||||
chat: GeminiChat,
|
||||
request: PartListUnion,
|
||||
signal: AbortSignal,
|
||||
turns: number = this.MAX_TURNS,
|
||||
|
@ -166,6 +172,7 @@ export class GeminiClient {
|
|||
return;
|
||||
}
|
||||
|
||||
const chat = await this.chat;
|
||||
const turn = new Turn(chat);
|
||||
const resultStream = turn.run(request, signal);
|
||||
for await (const event of resultStream) {
|
||||
|
@ -175,7 +182,7 @@ export class GeminiClient {
|
|||
const nextSpeakerCheck = await checkNextSpeaker(chat, this, signal);
|
||||
if (nextSpeakerCheck?.next_speaker === 'model') {
|
||||
const nextRequest = [{ text: 'Please continue.' }];
|
||||
yield* this.sendMessageStream(chat, nextRequest, signal, turns - 1);
|
||||
yield* this.sendMessageStream(nextRequest, signal, turns - 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -56,10 +56,10 @@ Signal: Signal number or \`(none)\` if no signal was received.
|
|||
let stdout = '';
|
||||
let stderr = '';
|
||||
child.stdout.on('data', (data) => {
|
||||
stdout += data.toString();
|
||||
stdout += data?.toString();
|
||||
});
|
||||
child.stderr.on('data', (data) => {
|
||||
stderr += data.toString();
|
||||
stderr += data?.toString();
|
||||
});
|
||||
let error: Error | null = null;
|
||||
child.on('error', (err: Error) => {
|
||||
|
|
Loading…
Reference in New Issue