Get ToolRegistry from config instead of passing it (#6592)

This commit is contained in:
Tommaso Sciortino 2025-08-19 16:27:15 -07:00 committed by GitHub
parent f1575f6d8d
commit a01d411c5a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 14 additions and 29 deletions

View File

@ -137,7 +137,6 @@ describe('runNonInteractive', () => {
expect(mockCoreExecuteToolCall).toHaveBeenCalledWith( expect(mockCoreExecuteToolCall).toHaveBeenCalledWith(
mockConfig, mockConfig,
expect.objectContaining({ name: 'testTool' }), expect.objectContaining({ name: 'testTool' }),
mockToolRegistry,
expect.any(AbortSignal), expect.any(AbortSignal),
); );
expect(mockGeminiClient.sendMessageStream).toHaveBeenNthCalledWith( expect(mockGeminiClient.sendMessageStream).toHaveBeenNthCalledWith(

View File

@ -8,7 +8,6 @@ import {
Config, Config,
ToolCallRequestInfo, ToolCallRequestInfo,
executeToolCall, executeToolCall,
ToolRegistry,
shutdownTelemetry, shutdownTelemetry,
isTelemetrySdkInitialized, isTelemetrySdkInitialized,
GeminiEventType, GeminiEventType,
@ -39,7 +38,6 @@ export async function runNonInteractive(
}); });
const geminiClient = config.getGeminiClient(); const geminiClient = config.getGeminiClient();
const toolRegistry: ToolRegistry = config.getToolRegistry();
const abortController = new AbortController(); const abortController = new AbortController();
let currentMessages: Content[] = [ let currentMessages: Content[] = [
@ -100,7 +98,6 @@ export async function runNonInteractive(
const toolResponse = await executeToolCall( const toolResponse = await executeToolCall(
config, config,
requestInfo, requestInfo,
toolRegistry,
abortController.signal, abortController.signal,
); );

View File

@ -16,20 +16,11 @@ import {
import { Part } from '@google/genai'; import { Part } from '@google/genai';
import { MockTool } from '../test-utils/tools.js'; import { MockTool } from '../test-utils/tools.js';
const mockConfig = {
getSessionId: () => 'test-session-id',
getUsageStatisticsEnabled: () => true,
getDebugMode: () => false,
getContentGeneratorConfig: () => ({
model: 'test-model',
authType: 'oauth-personal',
}),
} as unknown as Config;
describe('executeToolCall', () => { describe('executeToolCall', () => {
let mockToolRegistry: ToolRegistry; let mockToolRegistry: ToolRegistry;
let mockTool: MockTool; let mockTool: MockTool;
let abortController: AbortController; let abortController: AbortController;
let mockConfig: Config;
beforeEach(() => { beforeEach(() => {
mockTool = new MockTool(); mockTool = new MockTool();
@ -39,6 +30,17 @@ describe('executeToolCall', () => {
// Add other ToolRegistry methods if needed, or use a more complete mock // Add other ToolRegistry methods if needed, or use a more complete mock
} as unknown as ToolRegistry; } as unknown as ToolRegistry;
mockConfig = {
getSessionId: () => 'test-session-id',
getUsageStatisticsEnabled: () => true,
getDebugMode: () => false,
getContentGeneratorConfig: () => ({
model: 'test-model',
authType: 'oauth-personal',
}),
getToolRegistry: () => mockToolRegistry,
} as unknown as Config;
abortController = new AbortController(); abortController = new AbortController();
}); });
@ -60,7 +62,6 @@ describe('executeToolCall', () => {
const response = await executeToolCall( const response = await executeToolCall(
mockConfig, mockConfig,
request, request,
mockToolRegistry,
abortController.signal, abortController.signal,
); );
@ -94,7 +95,6 @@ describe('executeToolCall', () => {
const response = await executeToolCall( const response = await executeToolCall(
mockConfig, mockConfig,
request, request,
mockToolRegistry,
abortController.signal, abortController.signal,
); );
@ -141,7 +141,6 @@ describe('executeToolCall', () => {
const response = await executeToolCall( const response = await executeToolCall(
mockConfig, mockConfig,
request, request,
mockToolRegistry,
abortController.signal, abortController.signal,
); );
expect(response).toStrictEqual({ expect(response).toStrictEqual({
@ -185,7 +184,6 @@ describe('executeToolCall', () => {
const response = await executeToolCall( const response = await executeToolCall(
mockConfig, mockConfig,
request, request,
mockToolRegistry,
abortController.signal, abortController.signal,
); );
expect(response).toStrictEqual({ expect(response).toStrictEqual({
@ -222,7 +220,6 @@ describe('executeToolCall', () => {
const response = await executeToolCall( const response = await executeToolCall(
mockConfig, mockConfig,
request, request,
mockToolRegistry,
abortController.signal, abortController.signal,
); );
@ -262,7 +259,6 @@ describe('executeToolCall', () => {
const response = await executeToolCall( const response = await executeToolCall(
mockConfig, mockConfig,
request, request,
mockToolRegistry,
abortController.signal, abortController.signal,
); );

View File

@ -10,7 +10,6 @@ import {
ToolCallRequestInfo, ToolCallRequestInfo,
ToolCallResponseInfo, ToolCallResponseInfo,
ToolErrorType, ToolErrorType,
ToolRegistry,
ToolResult, ToolResult,
} from '../index.js'; } from '../index.js';
import { DiscoveredMCPTool } from '../tools/mcp-tool.js'; import { DiscoveredMCPTool } from '../tools/mcp-tool.js';
@ -25,10 +24,9 @@ import { ToolCallDecision } from '../telemetry/tool-call-decision.js';
export async function executeToolCall( export async function executeToolCall(
config: Config, config: Config,
toolCallRequest: ToolCallRequestInfo, toolCallRequest: ToolCallRequestInfo,
toolRegistry: ToolRegistry,
abortSignal?: AbortSignal, abortSignal?: AbortSignal,
): Promise<ToolCallResponseInfo> { ): Promise<ToolCallResponseInfo> {
const tool = toolRegistry.getTool(toolCallRequest.name); const tool = config.getToolRegistry().getTool(toolCallRequest.name);
const startTime = Date.now(); const startTime = Date.now();
if (!tool) { if (!tool) {

View File

@ -534,7 +534,7 @@ describe('subagent.ts', () => {
parameters: { type: Type.OBJECT, properties: {} }, parameters: { type: Type.OBJECT, properties: {} },
}; };
const { config, toolRegistry } = await createMockConfig({ const { config } = await createMockConfig({
getFunctionDeclarationsFiltered: vi getFunctionDeclarationsFiltered: vi
.fn() .fn()
.mockReturnValue([listFilesToolDef]), .mockReturnValue([listFilesToolDef]),
@ -580,7 +580,6 @@ describe('subagent.ts', () => {
expect(executeToolCall).toHaveBeenCalledWith( expect(executeToolCall).toHaveBeenCalledWith(
config, config,
expect.objectContaining({ name: 'list_files', args: { path: '.' } }), expect.objectContaining({ name: 'list_files', args: { path: '.' } }),
toolRegistry,
expect.any(AbortSignal), expect.any(AbortSignal),
); );

View File

@ -5,7 +5,6 @@
*/ */
import { reportError } from '../utils/errorReporting.js'; import { reportError } from '../utils/errorReporting.js';
import { ToolRegistry } from '../tools/tool-registry.js';
import { Config } from '../config/config.js'; import { Config } from '../config/config.js';
import { ToolCallRequestInfo } from './turn.js'; import { ToolCallRequestInfo } from './turn.js';
import { executeToolCall } from './nonInteractiveToolExecutor.js'; import { executeToolCall } from './nonInteractiveToolExecutor.js';
@ -422,7 +421,6 @@ export class SubAgentScope {
if (functionCalls.length > 0) { if (functionCalls.length > 0) {
currentMessages = await this.processFunctionCalls( currentMessages = await this.processFunctionCalls(
functionCalls, functionCalls,
toolRegistry,
abortController, abortController,
promptId, promptId,
); );
@ -479,7 +477,6 @@ export class SubAgentScope {
*/ */
private async processFunctionCalls( private async processFunctionCalls(
functionCalls: FunctionCall[], functionCalls: FunctionCall[],
toolRegistry: ToolRegistry,
abortController: AbortController, abortController: AbortController,
promptId: string, promptId: string,
): Promise<Content[]> { ): Promise<Content[]> {
@ -513,7 +510,6 @@ export class SubAgentScope {
toolResponse = await executeToolCall( toolResponse = await executeToolCall(
this.runtimeContext, this.runtimeContext,
requestInfo, requestInfo,
toolRegistry,
abortController.signal, abortController.signal,
); );
} }