enable async tool discovery by making the registry accessor async; remove call to discoverTools that caused duplicate discovery (#691)

This commit is contained in:
Olcan 2025-06-02 09:56:32 -07:00 committed by GitHub
parent 467dec4edf
commit c5869db080
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 21 additions and 20 deletions

View File

@ -42,7 +42,6 @@ describe('runNonInteractive', () => {
startChat: vi.fn().mockResolvedValue(mockChat), startChat: vi.fn().mockResolvedValue(mockChat),
} as unknown as GeminiClient; } as unknown as GeminiClient;
mockToolRegistry = { mockToolRegistry = {
discoverTools: vi.fn().mockResolvedValue(undefined),
getFunctionDeclarations: vi.fn().mockReturnValue([]), getFunctionDeclarations: vi.fn().mockReturnValue([]),
getTool: vi.fn(), getTool: vi.fn(),
} as unknown as ToolRegistry; } as unknown as ToolRegistry;
@ -82,7 +81,6 @@ describe('runNonInteractive', () => {
await runNonInteractive(mockConfig, 'Test input'); await runNonInteractive(mockConfig, 'Test input');
expect(mockGeminiClient.startChat).toHaveBeenCalled(); expect(mockGeminiClient.startChat).toHaveBeenCalled();
expect(mockToolRegistry.discoverTools).toHaveBeenCalled();
expect(mockChat.sendMessageStream).toHaveBeenCalledWith({ expect(mockChat.sendMessageStream).toHaveBeenCalledWith({
message: [{ text: 'Test input' }], message: [{ text: 'Test input' }],
config: { config: {

View File

@ -40,8 +40,7 @@ export async function runNonInteractive(
input: string, input: string,
): Promise<void> { ): Promise<void> {
const geminiClient = new GeminiClient(config); const geminiClient = new GeminiClient(config);
const toolRegistry: ToolRegistry = config.getToolRegistry(); const toolRegistry: ToolRegistry = await config.getToolRegistry();
await toolRegistry.discoverTools();
const chat = await geminiClient.startChat(); const chat = await geminiClient.startChat();
const abortController = new AbortController(); const abortController = new AbortController();

View File

@ -138,7 +138,7 @@ export async function handleAtCommand({
const atPathToResolvedSpecMap = new Map<string, string>(); const atPathToResolvedSpecMap = new Map<string, string>();
const contentLabelsForDisplay: string[] = []; const contentLabelsForDisplay: string[] = [];
const toolRegistry = config.getToolRegistry(); const toolRegistry = await config.getToolRegistry();
const readManyFilesTool = toolRegistry.getTool('read_many_files'); const readManyFilesTool = toolRegistry.getTool('read_many_files');
const globTool = toolRegistry.getTool('glob'); const globTool = toolRegistry.getTool('glob');

View File

@ -60,7 +60,7 @@ export interface ConfigParameters {
} }
export class Config { export class Config {
private toolRegistry: ToolRegistry; private toolRegistry: Promise<ToolRegistry>;
private readonly apiKey: string; private readonly apiKey: string;
private readonly model: string; private readonly model: string;
private readonly sandbox: boolean | string; private readonly sandbox: boolean | string;
@ -124,7 +124,7 @@ export class Config {
return this.targetDir; return this.targetDir;
} }
getToolRegistry(): ToolRegistry { async getToolRegistry(): Promise<ToolRegistry> {
return this.toolRegistry; return this.toolRegistry;
} }
@ -232,7 +232,7 @@ export function createServerConfig(params: ConfigParameters): Config {
}); });
} }
export function createToolRegistry(config: Config): ToolRegistry { export function createToolRegistry(config: Config): Promise<ToolRegistry> {
const registry = new ToolRegistry(config); const registry = new ToolRegistry(config);
const targetDir = config.getTargetDir(); const targetDir = config.getTargetDir();
const tools = config.getCoreTools() const tools = config.getCoreTools()
@ -259,6 +259,8 @@ export function createToolRegistry(config: Config): ToolRegistry {
registerCoreTool(ShellTool, config); registerCoreTool(ShellTool, config);
registerCoreTool(MemoryTool); registerCoreTool(MemoryTool);
registerCoreTool(WebSearchTool, config); registerCoreTool(WebSearchTool, config);
registry.discoverTools(); return (async () => {
await registry.discoverTools();
return registry; return registry;
})();
} }

View File

@ -70,13 +70,14 @@ export class GeminiClient {
`.trim(); `.trim();
const initialParts: Part[] = [{ text: context }]; const initialParts: Part[] = [{ text: context }];
const toolRegistry = await this.config.getToolRegistry();
// Add full file context if the flag is set // Add full file context if the flag is set
if (this.config.getFullContext()) { if (this.config.getFullContext()) {
try { try {
const readManyFilesTool = this.config const readManyFilesTool = toolRegistry.getTool(
.getToolRegistry() 'read_many_files',
.getTool('read_many_files') as ReadManyFilesTool; ) as ReadManyFilesTool;
if (readManyFilesTool) { if (readManyFilesTool) {
// Read all files in the target directory // Read all files in the target directory
const result = await readManyFilesTool.execute( const result = await readManyFilesTool.execute(
@ -114,9 +115,8 @@ export class GeminiClient {
async startChat(): Promise<GeminiChat> { async startChat(): Promise<GeminiChat> {
const envParts = await this.getEnvironment(); const envParts = await this.getEnvironment();
const toolDeclarations = this.config const toolRegistry = await this.config.getToolRegistry();
.getToolRegistry() const toolDeclarations = toolRegistry.getFunctionDeclarations();
.getFunctionDeclarations();
const tools: Tool[] = [{ functionDeclarations: toolDeclarations }]; const tools: Tool[] = [{ functionDeclarations: toolDeclarations }];
const history: Content[] = [ const history: Content[] = [
{ {

View File

@ -155,14 +155,14 @@ const createErrorResponse = (
}); });
interface CoreToolSchedulerOptions { interface CoreToolSchedulerOptions {
toolRegistry: ToolRegistry; toolRegistry: Promise<ToolRegistry>;
outputUpdateHandler?: OutputUpdateHandler; outputUpdateHandler?: OutputUpdateHandler;
onAllToolCallsComplete?: AllToolCallsCompleteHandler; onAllToolCallsComplete?: AllToolCallsCompleteHandler;
onToolCallsUpdate?: ToolCallsUpdateHandler; onToolCallsUpdate?: ToolCallsUpdateHandler;
} }
export class CoreToolScheduler { export class CoreToolScheduler {
private toolRegistry: ToolRegistry; private toolRegistry: Promise<ToolRegistry>;
private toolCalls: ToolCall[] = []; private toolCalls: ToolCall[] = [];
private abortController: AbortController; private abortController: AbortController;
private outputUpdateHandler?: OutputUpdateHandler; private outputUpdateHandler?: OutputUpdateHandler;
@ -295,10 +295,11 @@ export class CoreToolScheduler {
); );
} }
const requestsToProcess = Array.isArray(request) ? request : [request]; const requestsToProcess = Array.isArray(request) ? request : [request];
const toolRegistry = await this.toolRegistry;
const newToolCalls: ToolCall[] = requestsToProcess.map( const newToolCalls: ToolCall[] = requestsToProcess.map(
(reqInfo): ToolCall => { (reqInfo): ToolCall => {
const toolInstance = this.toolRegistry.getTool(reqInfo.name); const toolInstance = toolRegistry.getTool(reqInfo.name);
if (!toolInstance) { if (!toolInstance) {
return { return {
status: 'error', status: 'error',

View File

@ -100,6 +100,7 @@ Signal: Signal number or \`(none)\` if no signal was received.
export class ToolRegistry { export class ToolRegistry {
private tools: Map<string, Tool> = new Map(); private tools: Map<string, Tool> = new Map();
private discovery: Promise<void> | null = null;
private config: Config; private config: Config;
constructor(config: Config) { constructor(config: Config) {
@ -121,7 +122,7 @@ export class ToolRegistry {
} }
/** /**
* Discovers tools from project, if a discovery command is configured. * Discovers tools from project (if available and configured).
* Can be called multiple times to update discovered tools. * Can be called multiple times to update discovered tools.
*/ */
async discoverTools(): Promise<void> { async discoverTools(): Promise<void> {