Get ToolRegistry from config instead of passing it (#6592)
This commit is contained in:
parent
f1575f6d8d
commit
a01d411c5a
|
@ -137,7 +137,6 @@ describe('runNonInteractive', () => {
|
||||||
expect(mockCoreExecuteToolCall).toHaveBeenCalledWith(
|
expect(mockCoreExecuteToolCall).toHaveBeenCalledWith(
|
||||||
mockConfig,
|
mockConfig,
|
||||||
expect.objectContaining({ name: 'testTool' }),
|
expect.objectContaining({ name: 'testTool' }),
|
||||||
mockToolRegistry,
|
|
||||||
expect.any(AbortSignal),
|
expect.any(AbortSignal),
|
||||||
);
|
);
|
||||||
expect(mockGeminiClient.sendMessageStream).toHaveBeenNthCalledWith(
|
expect(mockGeminiClient.sendMessageStream).toHaveBeenNthCalledWith(
|
||||||
|
|
|
@ -8,7 +8,6 @@ import {
|
||||||
Config,
|
Config,
|
||||||
ToolCallRequestInfo,
|
ToolCallRequestInfo,
|
||||||
executeToolCall,
|
executeToolCall,
|
||||||
ToolRegistry,
|
|
||||||
shutdownTelemetry,
|
shutdownTelemetry,
|
||||||
isTelemetrySdkInitialized,
|
isTelemetrySdkInitialized,
|
||||||
GeminiEventType,
|
GeminiEventType,
|
||||||
|
@ -39,7 +38,6 @@ export async function runNonInteractive(
|
||||||
});
|
});
|
||||||
|
|
||||||
const geminiClient = config.getGeminiClient();
|
const geminiClient = config.getGeminiClient();
|
||||||
const toolRegistry: ToolRegistry = config.getToolRegistry();
|
|
||||||
|
|
||||||
const abortController = new AbortController();
|
const abortController = new AbortController();
|
||||||
let currentMessages: Content[] = [
|
let currentMessages: Content[] = [
|
||||||
|
@ -100,7 +98,6 @@ export async function runNonInteractive(
|
||||||
const toolResponse = await executeToolCall(
|
const toolResponse = await executeToolCall(
|
||||||
config,
|
config,
|
||||||
requestInfo,
|
requestInfo,
|
||||||
toolRegistry,
|
|
||||||
abortController.signal,
|
abortController.signal,
|
||||||
);
|
);
|
||||||
|
|
||||||
|
|
|
@ -16,20 +16,11 @@ import {
|
||||||
import { Part } from '@google/genai';
|
import { Part } from '@google/genai';
|
||||||
import { MockTool } from '../test-utils/tools.js';
|
import { MockTool } from '../test-utils/tools.js';
|
||||||
|
|
||||||
const mockConfig = {
|
|
||||||
getSessionId: () => 'test-session-id',
|
|
||||||
getUsageStatisticsEnabled: () => true,
|
|
||||||
getDebugMode: () => false,
|
|
||||||
getContentGeneratorConfig: () => ({
|
|
||||||
model: 'test-model',
|
|
||||||
authType: 'oauth-personal',
|
|
||||||
}),
|
|
||||||
} as unknown as Config;
|
|
||||||
|
|
||||||
describe('executeToolCall', () => {
|
describe('executeToolCall', () => {
|
||||||
let mockToolRegistry: ToolRegistry;
|
let mockToolRegistry: ToolRegistry;
|
||||||
let mockTool: MockTool;
|
let mockTool: MockTool;
|
||||||
let abortController: AbortController;
|
let abortController: AbortController;
|
||||||
|
let mockConfig: Config;
|
||||||
|
|
||||||
beforeEach(() => {
|
beforeEach(() => {
|
||||||
mockTool = new MockTool();
|
mockTool = new MockTool();
|
||||||
|
@ -39,6 +30,17 @@ describe('executeToolCall', () => {
|
||||||
// Add other ToolRegistry methods if needed, or use a more complete mock
|
// Add other ToolRegistry methods if needed, or use a more complete mock
|
||||||
} as unknown as ToolRegistry;
|
} as unknown as ToolRegistry;
|
||||||
|
|
||||||
|
mockConfig = {
|
||||||
|
getSessionId: () => 'test-session-id',
|
||||||
|
getUsageStatisticsEnabled: () => true,
|
||||||
|
getDebugMode: () => false,
|
||||||
|
getContentGeneratorConfig: () => ({
|
||||||
|
model: 'test-model',
|
||||||
|
authType: 'oauth-personal',
|
||||||
|
}),
|
||||||
|
getToolRegistry: () => mockToolRegistry,
|
||||||
|
} as unknown as Config;
|
||||||
|
|
||||||
abortController = new AbortController();
|
abortController = new AbortController();
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -60,7 +62,6 @@ describe('executeToolCall', () => {
|
||||||
const response = await executeToolCall(
|
const response = await executeToolCall(
|
||||||
mockConfig,
|
mockConfig,
|
||||||
request,
|
request,
|
||||||
mockToolRegistry,
|
|
||||||
abortController.signal,
|
abortController.signal,
|
||||||
);
|
);
|
||||||
|
|
||||||
|
@ -94,7 +95,6 @@ describe('executeToolCall', () => {
|
||||||
const response = await executeToolCall(
|
const response = await executeToolCall(
|
||||||
mockConfig,
|
mockConfig,
|
||||||
request,
|
request,
|
||||||
mockToolRegistry,
|
|
||||||
abortController.signal,
|
abortController.signal,
|
||||||
);
|
);
|
||||||
|
|
||||||
|
@ -141,7 +141,6 @@ describe('executeToolCall', () => {
|
||||||
const response = await executeToolCall(
|
const response = await executeToolCall(
|
||||||
mockConfig,
|
mockConfig,
|
||||||
request,
|
request,
|
||||||
mockToolRegistry,
|
|
||||||
abortController.signal,
|
abortController.signal,
|
||||||
);
|
);
|
||||||
expect(response).toStrictEqual({
|
expect(response).toStrictEqual({
|
||||||
|
@ -185,7 +184,6 @@ describe('executeToolCall', () => {
|
||||||
const response = await executeToolCall(
|
const response = await executeToolCall(
|
||||||
mockConfig,
|
mockConfig,
|
||||||
request,
|
request,
|
||||||
mockToolRegistry,
|
|
||||||
abortController.signal,
|
abortController.signal,
|
||||||
);
|
);
|
||||||
expect(response).toStrictEqual({
|
expect(response).toStrictEqual({
|
||||||
|
@ -222,7 +220,6 @@ describe('executeToolCall', () => {
|
||||||
const response = await executeToolCall(
|
const response = await executeToolCall(
|
||||||
mockConfig,
|
mockConfig,
|
||||||
request,
|
request,
|
||||||
mockToolRegistry,
|
|
||||||
abortController.signal,
|
abortController.signal,
|
||||||
);
|
);
|
||||||
|
|
||||||
|
@ -262,7 +259,6 @@ describe('executeToolCall', () => {
|
||||||
const response = await executeToolCall(
|
const response = await executeToolCall(
|
||||||
mockConfig,
|
mockConfig,
|
||||||
request,
|
request,
|
||||||
mockToolRegistry,
|
|
||||||
abortController.signal,
|
abortController.signal,
|
||||||
);
|
);
|
||||||
|
|
||||||
|
|
|
@ -10,7 +10,6 @@ import {
|
||||||
ToolCallRequestInfo,
|
ToolCallRequestInfo,
|
||||||
ToolCallResponseInfo,
|
ToolCallResponseInfo,
|
||||||
ToolErrorType,
|
ToolErrorType,
|
||||||
ToolRegistry,
|
|
||||||
ToolResult,
|
ToolResult,
|
||||||
} from '../index.js';
|
} from '../index.js';
|
||||||
import { DiscoveredMCPTool } from '../tools/mcp-tool.js';
|
import { DiscoveredMCPTool } from '../tools/mcp-tool.js';
|
||||||
|
@ -25,10 +24,9 @@ import { ToolCallDecision } from '../telemetry/tool-call-decision.js';
|
||||||
export async function executeToolCall(
|
export async function executeToolCall(
|
||||||
config: Config,
|
config: Config,
|
||||||
toolCallRequest: ToolCallRequestInfo,
|
toolCallRequest: ToolCallRequestInfo,
|
||||||
toolRegistry: ToolRegistry,
|
|
||||||
abortSignal?: AbortSignal,
|
abortSignal?: AbortSignal,
|
||||||
): Promise<ToolCallResponseInfo> {
|
): Promise<ToolCallResponseInfo> {
|
||||||
const tool = toolRegistry.getTool(toolCallRequest.name);
|
const tool = config.getToolRegistry().getTool(toolCallRequest.name);
|
||||||
|
|
||||||
const startTime = Date.now();
|
const startTime = Date.now();
|
||||||
if (!tool) {
|
if (!tool) {
|
||||||
|
|
|
@ -534,7 +534,7 @@ describe('subagent.ts', () => {
|
||||||
parameters: { type: Type.OBJECT, properties: {} },
|
parameters: { type: Type.OBJECT, properties: {} },
|
||||||
};
|
};
|
||||||
|
|
||||||
const { config, toolRegistry } = await createMockConfig({
|
const { config } = await createMockConfig({
|
||||||
getFunctionDeclarationsFiltered: vi
|
getFunctionDeclarationsFiltered: vi
|
||||||
.fn()
|
.fn()
|
||||||
.mockReturnValue([listFilesToolDef]),
|
.mockReturnValue([listFilesToolDef]),
|
||||||
|
@ -580,7 +580,6 @@ describe('subagent.ts', () => {
|
||||||
expect(executeToolCall).toHaveBeenCalledWith(
|
expect(executeToolCall).toHaveBeenCalledWith(
|
||||||
config,
|
config,
|
||||||
expect.objectContaining({ name: 'list_files', args: { path: '.' } }),
|
expect.objectContaining({ name: 'list_files', args: { path: '.' } }),
|
||||||
toolRegistry,
|
|
||||||
expect.any(AbortSignal),
|
expect.any(AbortSignal),
|
||||||
);
|
);
|
||||||
|
|
||||||
|
|
|
@ -5,7 +5,6 @@
|
||||||
*/
|
*/
|
||||||
|
|
||||||
import { reportError } from '../utils/errorReporting.js';
|
import { reportError } from '../utils/errorReporting.js';
|
||||||
import { ToolRegistry } from '../tools/tool-registry.js';
|
|
||||||
import { Config } from '../config/config.js';
|
import { Config } from '../config/config.js';
|
||||||
import { ToolCallRequestInfo } from './turn.js';
|
import { ToolCallRequestInfo } from './turn.js';
|
||||||
import { executeToolCall } from './nonInteractiveToolExecutor.js';
|
import { executeToolCall } from './nonInteractiveToolExecutor.js';
|
||||||
|
@ -422,7 +421,6 @@ export class SubAgentScope {
|
||||||
if (functionCalls.length > 0) {
|
if (functionCalls.length > 0) {
|
||||||
currentMessages = await this.processFunctionCalls(
|
currentMessages = await this.processFunctionCalls(
|
||||||
functionCalls,
|
functionCalls,
|
||||||
toolRegistry,
|
|
||||||
abortController,
|
abortController,
|
||||||
promptId,
|
promptId,
|
||||||
);
|
);
|
||||||
|
@ -479,7 +477,6 @@ export class SubAgentScope {
|
||||||
*/
|
*/
|
||||||
private async processFunctionCalls(
|
private async processFunctionCalls(
|
||||||
functionCalls: FunctionCall[],
|
functionCalls: FunctionCall[],
|
||||||
toolRegistry: ToolRegistry,
|
|
||||||
abortController: AbortController,
|
abortController: AbortController,
|
||||||
promptId: string,
|
promptId: string,
|
||||||
): Promise<Content[]> {
|
): Promise<Content[]> {
|
||||||
|
@ -513,7 +510,6 @@ export class SubAgentScope {
|
||||||
toolResponse = await executeToolCall(
|
toolResponse = await executeToolCall(
|
||||||
this.runtimeContext,
|
this.runtimeContext,
|
||||||
requestInfo,
|
requestInfo,
|
||||||
toolRegistry,
|
|
||||||
abortController.signal,
|
abortController.signal,
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue