Refactor: Use config.getGeminiClient() for GeminiClient instantiation (#715)

This commit is contained in:
N. Taylor Mullen 2025-06-02 22:30:52 -07:00 committed by GitHub
parent cf84f1af68
commit 8ab74ef1bb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 39 additions and 16 deletions

View File

@ -51,6 +51,7 @@ describe('runNonInteractive', () => {
mockConfig = { mockConfig = {
getToolRegistry: vi.fn().mockReturnValue(mockToolRegistry), getToolRegistry: vi.fn().mockReturnValue(mockToolRegistry),
getGeminiClient: vi.fn().mockReturnValue(mockGeminiClient),
} as unknown as Config; } as unknown as Config;
mockProcessStdoutWrite = vi.fn().mockImplementation(() => true); mockProcessStdoutWrite = vi.fn().mockImplementation(() => true);

View File

@ -6,7 +6,6 @@
import { import {
Config, Config,
GeminiClient,
ToolCallRequestInfo, ToolCallRequestInfo,
executeToolCall, executeToolCall,
ToolRegistry, ToolRegistry,
@ -39,7 +38,7 @@ export async function runNonInteractive(
config: Config, config: Config,
input: string, input: string,
): Promise<void> { ): Promise<void> {
const geminiClient = new GeminiClient(config); const geminiClient = config.getGeminiClient();
const toolRegistry: ToolRegistry = await config.getToolRegistry(); const toolRegistry: ToolRegistry = await config.getToolRegistry();
const chat = await geminiClient.getChat(); const chat = await geminiClient.getChat();

View File

@ -25,20 +25,20 @@ const mockSendMessageStream = vi
.mockReturnValue((async function* () {})()); .mockReturnValue((async function* () {})());
const mockStartChat = vi.fn(); const mockStartChat = vi.fn();
vi.mock('@gemini-code/core', async (importOriginal) => { const MockedGeminiClientClass = vi.hoisted(() =>
const actualCoreModule = (await importOriginal()) as any; vi.fn().mockImplementation(function (this: any, _config: any) {
const MockedGeminiClientClass = vi.fn().mockImplementation(function (
this: any,
_config: any,
) {
// _config // _config
this.startChat = mockStartChat; this.startChat = mockStartChat;
this.sendMessageStream = mockSendMessageStream; this.sendMessageStream = mockSendMessageStream;
}); }),
);
vi.mock('@gemini-code/core', async (importOriginal) => {
const actualCoreModule = (await importOriginal()) as any;
return { return {
...(actualCoreModule || {}), ...(actualCoreModule || {}),
GeminiClient: MockedGeminiClientClass, GeminiClient: MockedGeminiClientClass, // Export the class for type checking or other direct uses
// GeminiChat will be from actualCoreModule if it exists, otherwise undefined Config: actualCoreModule.Config, // Ensure Config is passed through
}; };
}); });
@ -235,6 +235,14 @@ describe('useGeminiStream', () => {
mockAddItem = vi.fn(); mockAddItem = vi.fn();
mockSetShowHelp = vi.fn(); mockSetShowHelp = vi.fn();
// Define the mock for getGeminiClient
const mockGetGeminiClient = vi.fn().mockImplementation(() => {
// MockedGeminiClientClass is defined in the module scope by the previous change.
// It will use the mockStartChat and mockSendMessageStream that are managed within beforeEach.
const clientInstance = new MockedGeminiClientClass(mockConfig);
return clientInstance;
});
mockConfig = { mockConfig = {
apiKey: 'test-api-key', apiKey: 'test-api-key',
model: 'gemini-pro', model: 'gemini-pro',
@ -258,6 +266,7 @@ describe('useGeminiStream', () => {
getToolRegistry: vi.fn( getToolRegistry: vi.fn(
() => ({ getToolSchemaList: vi.fn(() => []) }) as any, () => ({ getToolSchemaList: vi.fn(() => []) }) as any,
), ),
getGeminiClient: mockGetGeminiClient,
} as unknown as Config; } as unknown as Config;
mockOnDebugMessage = vi.fn(); mockOnDebugMessage = vi.fn();
mockHandleSlashCommand = vi.fn().mockReturnValue(false); mockHandleSlashCommand = vi.fn().mockReturnValue(false);

View File

@ -145,7 +145,7 @@ export const useGeminiStream = (
setInitError(null); setInitError(null);
if (!geminiClientRef.current) { if (!geminiClientRef.current) {
try { try {
geminiClientRef.current = new GeminiClient(config); geminiClientRef.current = config.getGeminiClient();
} catch (error: unknown) { } catch (error: unknown) {
const errorMsg = `Failed to initialize client: ${getErrorMessage(error) || 'Unknown error'}`; const errorMsg = `Failed to initialize client: ${getErrorMessage(error) || 'Unknown error'}`;
setInitError(errorMsg); setInitError(errorMsg);

View File

@ -25,6 +25,7 @@ import { checkNextSpeaker } from '../utils/nextSpeakerChecker.js';
import { reportError } from '../utils/errorReporting.js'; import { reportError } from '../utils/errorReporting.js';
import { GeminiChat } from './geminiChat.js'; import { GeminiChat } from './geminiChat.js';
import { retryWithBackoff } from '../utils/retry.js'; import { retryWithBackoff } from '../utils/retry.js';
import { getErrorMessage } from '../utils/errors.js';
export class GeminiClient { export class GeminiClient {
private chat: Promise<GeminiChat>; private chat: Promise<GeminiChat>;
@ -158,8 +159,7 @@ export class GeminiClient {
history, history,
'startChat', 'startChat',
); );
const message = error instanceof Error ? error.message : 'Unknown error.'; throw new Error(`Failed to initialize chat: ${getErrorMessage(error)}`);
throw new Error(`Failed to initialize chat: ${message}`);
} }
} }

View File

@ -39,7 +39,15 @@ describe('EditTool', () => {
rootDir = path.join(tempDir, 'root'); rootDir = path.join(tempDir, 'root');
fs.mkdirSync(rootDir); fs.mkdirSync(rootDir);
// The client instance that EditTool will use
const mockClientInstanceWithGenerateJson = {
generateJson: mockGenerateJson, // mockGenerateJson is already defined and hoisted
};
mockConfig = { mockConfig = {
getGeminiClient: vi
.fn()
.mockReturnValue(mockClientInstanceWithGenerateJson),
getTargetDir: () => rootDir, getTargetDir: () => rootDir,
getApprovalMode: vi.fn(() => false), getApprovalMode: vi.fn(() => false),
setApprovalMode: vi.fn(), setApprovalMode: vi.fn(),

View File

@ -114,7 +114,7 @@ Expectation for required parameters:
); );
this.config = config; this.config = config;
this.rootDirectory = path.resolve(this.config.getTargetDir()); this.rootDirectory = path.resolve(this.config.getTargetDir());
this.client = new GeminiClient(this.config); this.client = config.getGeminiClient();
} }
/** /**

View File

@ -53,6 +53,7 @@ const mockConfigInternal = {
getTargetDir: () => rootDir, getTargetDir: () => rootDir,
getApprovalMode: vi.fn(() => ApprovalMode.DEFAULT), getApprovalMode: vi.fn(() => ApprovalMode.DEFAULT),
setApprovalMode: vi.fn(), setApprovalMode: vi.fn(),
getGeminiClient: vi.fn(), // Initialize as a plain mock function
getApiKey: () => 'test-key', getApiKey: () => 'test-key',
getModel: () => 'test-model', getModel: () => 'test-model',
getSandbox: () => false, getSandbox: () => false,
@ -97,6 +98,11 @@ describe('WriteFileTool', () => {
) as Mocked<GeminiClient>; ) as Mocked<GeminiClient>;
vi.mocked(GeminiClient).mockImplementation(() => mockGeminiClientInstance); vi.mocked(GeminiClient).mockImplementation(() => mockGeminiClientInstance);
// Now that mockGeminiClientInstance is initialized, set the mock implementation for getGeminiClient
mockConfigInternal.getGeminiClient.mockReturnValue(
mockGeminiClientInstance,
);
tool = new WriteFileTool(mockConfig); tool = new WriteFileTool(mockConfig);
// Reset mocks before each test // Reset mocks before each test

View File

@ -77,7 +77,7 @@ export class WriteFileTool extends BaseTool<WriteFileToolParams, ToolResult> {
}, },
); );
this.client = new GeminiClient(this.config); this.client = this.config.getGeminiClient();
} }
private isWithinRoot(pathToCheck: string): boolean { private isWithinRoot(pathToCheck: string): boolean {