enable async tool discovery by making the registry accessor async; remove call to discoverTools that caused duplicate discovery (#691)
This commit is contained in:
parent
467dec4edf
commit
c5869db080
|
@ -42,7 +42,6 @@ describe('runNonInteractive', () => {
|
||||||
startChat: vi.fn().mockResolvedValue(mockChat),
|
startChat: vi.fn().mockResolvedValue(mockChat),
|
||||||
} as unknown as GeminiClient;
|
} as unknown as GeminiClient;
|
||||||
mockToolRegistry = {
|
mockToolRegistry = {
|
||||||
discoverTools: vi.fn().mockResolvedValue(undefined),
|
|
||||||
getFunctionDeclarations: vi.fn().mockReturnValue([]),
|
getFunctionDeclarations: vi.fn().mockReturnValue([]),
|
||||||
getTool: vi.fn(),
|
getTool: vi.fn(),
|
||||||
} as unknown as ToolRegistry;
|
} as unknown as ToolRegistry;
|
||||||
|
@ -82,7 +81,6 @@ describe('runNonInteractive', () => {
|
||||||
await runNonInteractive(mockConfig, 'Test input');
|
await runNonInteractive(mockConfig, 'Test input');
|
||||||
|
|
||||||
expect(mockGeminiClient.startChat).toHaveBeenCalled();
|
expect(mockGeminiClient.startChat).toHaveBeenCalled();
|
||||||
expect(mockToolRegistry.discoverTools).toHaveBeenCalled();
|
|
||||||
expect(mockChat.sendMessageStream).toHaveBeenCalledWith({
|
expect(mockChat.sendMessageStream).toHaveBeenCalledWith({
|
||||||
message: [{ text: 'Test input' }],
|
message: [{ text: 'Test input' }],
|
||||||
config: {
|
config: {
|
||||||
|
|
|
@ -40,8 +40,7 @@ export async function runNonInteractive(
|
||||||
input: string,
|
input: string,
|
||||||
): Promise<void> {
|
): Promise<void> {
|
||||||
const geminiClient = new GeminiClient(config);
|
const geminiClient = new GeminiClient(config);
|
||||||
const toolRegistry: ToolRegistry = config.getToolRegistry();
|
const toolRegistry: ToolRegistry = await config.getToolRegistry();
|
||||||
await toolRegistry.discoverTools();
|
|
||||||
|
|
||||||
const chat = await geminiClient.startChat();
|
const chat = await geminiClient.startChat();
|
||||||
const abortController = new AbortController();
|
const abortController = new AbortController();
|
||||||
|
|
|
@ -138,7 +138,7 @@ export async function handleAtCommand({
|
||||||
const atPathToResolvedSpecMap = new Map<string, string>();
|
const atPathToResolvedSpecMap = new Map<string, string>();
|
||||||
const contentLabelsForDisplay: string[] = [];
|
const contentLabelsForDisplay: string[] = [];
|
||||||
|
|
||||||
const toolRegistry = config.getToolRegistry();
|
const toolRegistry = await config.getToolRegistry();
|
||||||
const readManyFilesTool = toolRegistry.getTool('read_many_files');
|
const readManyFilesTool = toolRegistry.getTool('read_many_files');
|
||||||
const globTool = toolRegistry.getTool('glob');
|
const globTool = toolRegistry.getTool('glob');
|
||||||
|
|
||||||
|
|
|
@ -60,7 +60,7 @@ export interface ConfigParameters {
|
||||||
}
|
}
|
||||||
|
|
||||||
export class Config {
|
export class Config {
|
||||||
private toolRegistry: ToolRegistry;
|
private toolRegistry: Promise<ToolRegistry>;
|
||||||
private readonly apiKey: string;
|
private readonly apiKey: string;
|
||||||
private readonly model: string;
|
private readonly model: string;
|
||||||
private readonly sandbox: boolean | string;
|
private readonly sandbox: boolean | string;
|
||||||
|
@ -124,7 +124,7 @@ export class Config {
|
||||||
return this.targetDir;
|
return this.targetDir;
|
||||||
}
|
}
|
||||||
|
|
||||||
getToolRegistry(): ToolRegistry {
|
async getToolRegistry(): Promise<ToolRegistry> {
|
||||||
return this.toolRegistry;
|
return this.toolRegistry;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -232,7 +232,7 @@ export function createServerConfig(params: ConfigParameters): Config {
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
export function createToolRegistry(config: Config): ToolRegistry {
|
export function createToolRegistry(config: Config): Promise<ToolRegistry> {
|
||||||
const registry = new ToolRegistry(config);
|
const registry = new ToolRegistry(config);
|
||||||
const targetDir = config.getTargetDir();
|
const targetDir = config.getTargetDir();
|
||||||
const tools = config.getCoreTools()
|
const tools = config.getCoreTools()
|
||||||
|
@ -259,6 +259,8 @@ export function createToolRegistry(config: Config): ToolRegistry {
|
||||||
registerCoreTool(ShellTool, config);
|
registerCoreTool(ShellTool, config);
|
||||||
registerCoreTool(MemoryTool);
|
registerCoreTool(MemoryTool);
|
||||||
registerCoreTool(WebSearchTool, config);
|
registerCoreTool(WebSearchTool, config);
|
||||||
registry.discoverTools();
|
return (async () => {
|
||||||
return registry;
|
await registry.discoverTools();
|
||||||
|
return registry;
|
||||||
|
})();
|
||||||
}
|
}
|
||||||
|
|
|
@ -70,13 +70,14 @@ export class GeminiClient {
|
||||||
`.trim();
|
`.trim();
|
||||||
|
|
||||||
const initialParts: Part[] = [{ text: context }];
|
const initialParts: Part[] = [{ text: context }];
|
||||||
|
const toolRegistry = await this.config.getToolRegistry();
|
||||||
|
|
||||||
// Add full file context if the flag is set
|
// Add full file context if the flag is set
|
||||||
if (this.config.getFullContext()) {
|
if (this.config.getFullContext()) {
|
||||||
try {
|
try {
|
||||||
const readManyFilesTool = this.config
|
const readManyFilesTool = toolRegistry.getTool(
|
||||||
.getToolRegistry()
|
'read_many_files',
|
||||||
.getTool('read_many_files') as ReadManyFilesTool;
|
) as ReadManyFilesTool;
|
||||||
if (readManyFilesTool) {
|
if (readManyFilesTool) {
|
||||||
// Read all files in the target directory
|
// Read all files in the target directory
|
||||||
const result = await readManyFilesTool.execute(
|
const result = await readManyFilesTool.execute(
|
||||||
|
@ -114,9 +115,8 @@ export class GeminiClient {
|
||||||
|
|
||||||
async startChat(): Promise<GeminiChat> {
|
async startChat(): Promise<GeminiChat> {
|
||||||
const envParts = await this.getEnvironment();
|
const envParts = await this.getEnvironment();
|
||||||
const toolDeclarations = this.config
|
const toolRegistry = await this.config.getToolRegistry();
|
||||||
.getToolRegistry()
|
const toolDeclarations = toolRegistry.getFunctionDeclarations();
|
||||||
.getFunctionDeclarations();
|
|
||||||
const tools: Tool[] = [{ functionDeclarations: toolDeclarations }];
|
const tools: Tool[] = [{ functionDeclarations: toolDeclarations }];
|
||||||
const history: Content[] = [
|
const history: Content[] = [
|
||||||
{
|
{
|
||||||
|
|
|
@ -155,14 +155,14 @@ const createErrorResponse = (
|
||||||
});
|
});
|
||||||
|
|
||||||
interface CoreToolSchedulerOptions {
|
interface CoreToolSchedulerOptions {
|
||||||
toolRegistry: ToolRegistry;
|
toolRegistry: Promise<ToolRegistry>;
|
||||||
outputUpdateHandler?: OutputUpdateHandler;
|
outputUpdateHandler?: OutputUpdateHandler;
|
||||||
onAllToolCallsComplete?: AllToolCallsCompleteHandler;
|
onAllToolCallsComplete?: AllToolCallsCompleteHandler;
|
||||||
onToolCallsUpdate?: ToolCallsUpdateHandler;
|
onToolCallsUpdate?: ToolCallsUpdateHandler;
|
||||||
}
|
}
|
||||||
|
|
||||||
export class CoreToolScheduler {
|
export class CoreToolScheduler {
|
||||||
private toolRegistry: ToolRegistry;
|
private toolRegistry: Promise<ToolRegistry>;
|
||||||
private toolCalls: ToolCall[] = [];
|
private toolCalls: ToolCall[] = [];
|
||||||
private abortController: AbortController;
|
private abortController: AbortController;
|
||||||
private outputUpdateHandler?: OutputUpdateHandler;
|
private outputUpdateHandler?: OutputUpdateHandler;
|
||||||
|
@ -295,10 +295,11 @@ export class CoreToolScheduler {
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
const requestsToProcess = Array.isArray(request) ? request : [request];
|
const requestsToProcess = Array.isArray(request) ? request : [request];
|
||||||
|
const toolRegistry = await this.toolRegistry;
|
||||||
|
|
||||||
const newToolCalls: ToolCall[] = requestsToProcess.map(
|
const newToolCalls: ToolCall[] = requestsToProcess.map(
|
||||||
(reqInfo): ToolCall => {
|
(reqInfo): ToolCall => {
|
||||||
const toolInstance = this.toolRegistry.getTool(reqInfo.name);
|
const toolInstance = toolRegistry.getTool(reqInfo.name);
|
||||||
if (!toolInstance) {
|
if (!toolInstance) {
|
||||||
return {
|
return {
|
||||||
status: 'error',
|
status: 'error',
|
||||||
|
|
|
@ -100,6 +100,7 @@ Signal: Signal number or \`(none)\` if no signal was received.
|
||||||
|
|
||||||
export class ToolRegistry {
|
export class ToolRegistry {
|
||||||
private tools: Map<string, Tool> = new Map();
|
private tools: Map<string, Tool> = new Map();
|
||||||
|
private discovery: Promise<void> | null = null;
|
||||||
private config: Config;
|
private config: Config;
|
||||||
|
|
||||||
constructor(config: Config) {
|
constructor(config: Config) {
|
||||||
|
@ -121,7 +122,7 @@ export class ToolRegistry {
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Discovers tools from project, if a discovery command is configured.
|
* Discovers tools from project (if available and configured).
|
||||||
* Can be called multiple times to update discovered tools.
|
* Can be called multiple times to update discovered tools.
|
||||||
*/
|
*/
|
||||||
async discoverTools(): Promise<void> {
|
async discoverTools(): Promise<void> {
|
||||||
|
|
Loading…
Reference in New Issue