fix(core): Sanitize tool parameters to fix 400 API errors (#3300)
This commit is contained in:
parent
5c9372372c
commit
b564d4a088
|
@ -10,6 +10,9 @@
|
|||
"workspaces": [
|
||||
"packages/*"
|
||||
],
|
||||
"dependencies": {
|
||||
"shell-quote": "^1.8.3"
|
||||
},
|
||||
"bin": {
|
||||
"gemini": "bundle/gemini.js"
|
||||
},
|
||||
|
@ -17,6 +20,7 @@
|
|||
"@types/micromatch": "^4.0.9",
|
||||
"@types/mime-types": "^2.1.4",
|
||||
"@types/minimatch": "^5.1.2",
|
||||
"@types/shell-quote": "^1.7.5",
|
||||
"@vitest/coverage-v8": "^3.1.1",
|
||||
"concurrently": "^9.2.0",
|
||||
"cross-env": "^7.0.3",
|
||||
|
|
|
@ -68,6 +68,7 @@
|
|||
"@types/micromatch": "^4.0.9",
|
||||
"@types/mime-types": "^2.1.4",
|
||||
"@types/minimatch": "^5.1.2",
|
||||
"@types/shell-quote": "^1.7.5",
|
||||
"@vitest/coverage-v8": "^3.1.1",
|
||||
"concurrently": "^9.2.0",
|
||||
"cross-env": "^7.0.3",
|
||||
|
|
|
@ -41,14 +41,12 @@ describe('useLoadingIndicator', () => {
|
|||
expect(WITTY_LOADING_PHRASES).toContain(
|
||||
result.current.currentLoadingPhrase,
|
||||
);
|
||||
const initialPhrase = result.current.currentLoadingPhrase;
|
||||
|
||||
await act(async () => {
|
||||
await vi.advanceTimersByTimeAsync(PHRASE_CHANGE_INTERVAL_MS + 1);
|
||||
});
|
||||
|
||||
// Phrase should cycle if PHRASE_CHANGE_INTERVAL_MS has passed
|
||||
expect(result.current.currentLoadingPhrase).not.toBe(initialPhrase);
|
||||
expect(WITTY_LOADING_PHRASES).toContain(
|
||||
result.current.currentLoadingPhrase,
|
||||
);
|
||||
|
|
|
@ -39,7 +39,7 @@
|
|||
"ignore": "^7.0.0",
|
||||
"micromatch": "^4.0.8",
|
||||
"open": "^10.1.2",
|
||||
"shell-quote": "^1.8.2",
|
||||
"shell-quote": "^1.8.3",
|
||||
"simple-git": "^3.28.0",
|
||||
"strip-ansi": "^7.1.0",
|
||||
"undici": "^7.10.0",
|
||||
|
|
|
@ -14,7 +14,8 @@ import {
|
|||
afterEach,
|
||||
Mocked,
|
||||
} from 'vitest';
|
||||
import { discoverMcpTools, sanitizeParameters } from './mcp-client.js';
|
||||
import { discoverMcpTools } from './mcp-client.js';
|
||||
import { sanitizeParameters } from './tool-registry.js';
|
||||
import { Schema, Type } from '@google/genai';
|
||||
import { Config, MCPServerConfig } from '../config/config.js';
|
||||
import { DiscoveredMCPTool } from './mcp-tool.js';
|
||||
|
@ -85,9 +86,14 @@ const mockToolRegistryInstance = {
|
|||
getFunctionDeclarations: vi.fn().mockReturnValue([]),
|
||||
discoverTools: vi.fn().mockResolvedValue(undefined),
|
||||
};
|
||||
vi.mock('./tool-registry.js', () => ({
|
||||
ToolRegistry: vi.fn(() => mockToolRegistryInstance),
|
||||
}));
|
||||
vi.mock('./tool-registry.js', async (importOriginal) => {
|
||||
const actual = await importOriginal();
|
||||
return {
|
||||
...(actual as any),
|
||||
ToolRegistry: vi.fn(() => mockToolRegistryInstance),
|
||||
sanitizeParameters: (actual as any).sanitizeParameters,
|
||||
};
|
||||
});
|
||||
|
||||
describe('discoverMcpTools', () => {
|
||||
let mockConfig: Mocked<Config>;
|
||||
|
|
|
@ -14,13 +14,8 @@ import {
|
|||
import { parse } from 'shell-quote';
|
||||
import { MCPServerConfig } from '../config/config.js';
|
||||
import { DiscoveredMCPTool } from './mcp-tool.js';
|
||||
import {
|
||||
CallableTool,
|
||||
FunctionDeclaration,
|
||||
mcpToTool,
|
||||
Schema,
|
||||
} from '@google/genai';
|
||||
import { ToolRegistry } from './tool-registry.js';
|
||||
import { CallableTool, FunctionDeclaration, mcpToTool } from '@google/genai';
|
||||
import { sanitizeParameters, ToolRegistry } from './tool-registry.js';
|
||||
|
||||
export const MCP_DEFAULT_TIMEOUT_MSEC = 10 * 60 * 1000; // default to 10 minutes
|
||||
|
||||
|
@ -384,31 +379,3 @@ async function connectAndDiscover(
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Sanitizes a JSON schema object to ensure compatibility with Vertex AI.
|
||||
* This function recursively processes the schema to remove problematic properties
|
||||
* that can cause issues with the Gemini API.
|
||||
*
|
||||
* @param schema The JSON schema object to sanitize (modified in-place)
|
||||
*/
|
||||
export function sanitizeParameters(schema?: Schema) {
|
||||
if (!schema) {
|
||||
return;
|
||||
}
|
||||
if (schema.anyOf) {
|
||||
// Vertex AI gets confused if both anyOf and default are set.
|
||||
schema.default = undefined;
|
||||
for (const item of schema.anyOf) {
|
||||
sanitizeParameters(item);
|
||||
}
|
||||
}
|
||||
if (schema.items) {
|
||||
sanitizeParameters(schema.items);
|
||||
}
|
||||
if (schema.properties) {
|
||||
for (const item of Object.values(schema.properties)) {
|
||||
sanitizeParameters(item);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -14,22 +14,22 @@ import {
|
|||
afterEach,
|
||||
Mocked,
|
||||
} from 'vitest';
|
||||
import { ToolRegistry, DiscoveredTool } from './tool-registry.js';
|
||||
import { DiscoveredMCPTool } from './mcp-tool.js';
|
||||
import {
|
||||
Config,
|
||||
ConfigParameters,
|
||||
MCPServerConfig,
|
||||
ApprovalMode,
|
||||
} from '../config/config.js';
|
||||
ToolRegistry,
|
||||
DiscoveredTool,
|
||||
sanitizeParameters,
|
||||
} from './tool-registry.js';
|
||||
import { DiscoveredMCPTool } from './mcp-tool.js';
|
||||
import { Config, ConfigParameters, ApprovalMode } from '../config/config.js';
|
||||
import { BaseTool, ToolResult } from './tools.js';
|
||||
import {
|
||||
FunctionDeclaration,
|
||||
CallableTool,
|
||||
mcpToTool,
|
||||
Type,
|
||||
Schema,
|
||||
} from '@google/genai';
|
||||
import { execSync } from 'node:child_process';
|
||||
import { spawn } from 'node:child_process';
|
||||
|
||||
// Use vi.hoisted to define the mock function so it can be used in the vi.mock factory
|
||||
const mockDiscoverMcpTools = vi.hoisted(() => vi.fn());
|
||||
|
@ -61,7 +61,6 @@ vi.mock('@modelcontextprotocol/sdk/client/index.js', () => {
|
|||
set onerror(handler: any) {
|
||||
mockMcpClientOnError(handler);
|
||||
},
|
||||
// listTools and callTool are no longer directly used by ToolRegistry/discoverMcpTools
|
||||
}));
|
||||
return { Client: MockClient };
|
||||
});
|
||||
|
@ -90,7 +89,6 @@ vi.mock('@google/genai', async () => {
|
|||
return {
|
||||
...actualGenai,
|
||||
mcpToTool: vi.fn().mockImplementation(() => ({
|
||||
// Default mock implementation
|
||||
tool: vi.fn().mockResolvedValue({ functionDeclarations: [] }),
|
||||
callTool: vi.fn(),
|
||||
})),
|
||||
|
@ -139,6 +137,7 @@ const baseConfigParams: ConfigParameters = {
|
|||
describe('ToolRegistry', () => {
|
||||
let config: Config;
|
||||
let toolRegistry: ToolRegistry;
|
||||
let mockConfigGetToolDiscoveryCommand: ReturnType<typeof vi.spyOn>;
|
||||
|
||||
beforeEach(() => {
|
||||
config = new Config(baseConfigParams);
|
||||
|
@ -148,13 +147,19 @@ describe('ToolRegistry', () => {
|
|||
vi.spyOn(console, 'debug').mockImplementation(() => {});
|
||||
vi.spyOn(console, 'log').mockImplementation(() => {});
|
||||
|
||||
// Reset mocks for MCP parts
|
||||
mockMcpClientConnect.mockReset().mockResolvedValue(undefined); // Default connect success
|
||||
mockMcpClientConnect.mockReset().mockResolvedValue(undefined);
|
||||
mockStdioTransportClose.mockReset();
|
||||
mockSseTransportClose.mockReset();
|
||||
vi.mocked(mcpToTool).mockClear();
|
||||
// Default mcpToTool to return a callable tool that returns no functions
|
||||
vi.mocked(mcpToTool).mockReturnValue(createMockCallableTool([]));
|
||||
|
||||
mockConfigGetToolDiscoveryCommand = vi.spyOn(
|
||||
config,
|
||||
'getToolDiscoveryCommand',
|
||||
);
|
||||
vi.spyOn(config, 'getMcpServers');
|
||||
vi.spyOn(config, 'getMcpServerCommand');
|
||||
mockDiscoverMcpTools.mockReset().mockResolvedValue(undefined);
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
|
@ -167,21 +172,18 @@ describe('ToolRegistry', () => {
|
|||
toolRegistry.registerTool(tool);
|
||||
expect(toolRegistry.getTool('mock-tool')).toBe(tool);
|
||||
});
|
||||
// ... other registerTool tests
|
||||
});
|
||||
|
||||
describe('getToolsByServer', () => {
|
||||
it('should return an empty array if no tools match the server name', () => {
|
||||
toolRegistry.registerTool(new MockTool()); // A non-MCP tool
|
||||
toolRegistry.registerTool(new MockTool());
|
||||
expect(toolRegistry.getToolsByServer('any-mcp-server')).toEqual([]);
|
||||
});
|
||||
|
||||
it('should return only tools matching the server name', async () => {
|
||||
const server1Name = 'mcp-server-uno';
|
||||
const server2Name = 'mcp-server-dos';
|
||||
|
||||
// Manually register mock MCP tools for this test
|
||||
const mockCallable = {} as CallableTool; // Minimal mock callable
|
||||
const mockCallable = {} as CallableTool;
|
||||
const mcpTool1 = new DiscoveredMCPTool(
|
||||
mockCallable,
|
||||
server1Name,
|
||||
|
@ -207,73 +209,87 @@ describe('ToolRegistry', () => {
|
|||
const toolsFromServer1 = toolRegistry.getToolsByServer(server1Name);
|
||||
expect(toolsFromServer1).toHaveLength(1);
|
||||
expect(toolsFromServer1[0].name).toBe(mcpTool1.name);
|
||||
expect((toolsFromServer1[0] as DiscoveredMCPTool).serverName).toBe(
|
||||
server1Name,
|
||||
);
|
||||
|
||||
const toolsFromServer2 = toolRegistry.getToolsByServer(server2Name);
|
||||
expect(toolsFromServer2).toHaveLength(1);
|
||||
expect(toolsFromServer2[0].name).toBe(mcpTool2.name);
|
||||
expect((toolsFromServer2[0] as DiscoveredMCPTool).serverName).toBe(
|
||||
server2Name,
|
||||
);
|
||||
|
||||
expect(toolRegistry.getToolsByServer('non-existent-server')).toEqual([]);
|
||||
});
|
||||
});
|
||||
|
||||
describe('discoverTools', () => {
|
||||
let mockConfigGetToolDiscoveryCommand: ReturnType<typeof vi.spyOn>;
|
||||
let mockConfigGetMcpServers: ReturnType<typeof vi.spyOn>;
|
||||
let mockConfigGetMcpServerCommand: ReturnType<typeof vi.spyOn>;
|
||||
let mockExecSync: ReturnType<typeof vi.mocked<typeof execSync>>;
|
||||
|
||||
beforeEach(() => {
|
||||
mockConfigGetToolDiscoveryCommand = vi.spyOn(
|
||||
config,
|
||||
'getToolDiscoveryCommand',
|
||||
);
|
||||
mockConfigGetMcpServers = vi.spyOn(config, 'getMcpServers');
|
||||
mockConfigGetMcpServerCommand = vi.spyOn(config, 'getMcpServerCommand');
|
||||
mockExecSync = vi.mocked(execSync);
|
||||
toolRegistry = new ToolRegistry(config); // Reset registry
|
||||
// Reset the mock for discoverMcpTools before each test in this suite
|
||||
mockDiscoverMcpTools.mockReset().mockResolvedValue(undefined);
|
||||
});
|
||||
|
||||
it('should discover tools using discovery command', async () => {
|
||||
// ... this test remains largely the same
|
||||
it('should sanitize tool parameters during discovery from command', async () => {
|
||||
const discoveryCommand = 'my-discovery-command';
|
||||
mockConfigGetToolDiscoveryCommand.mockReturnValue(discoveryCommand);
|
||||
const mockToolDeclarations: FunctionDeclaration[] = [
|
||||
{
|
||||
name: 'discovered-tool-1',
|
||||
description: 'A discovered tool',
|
||||
parameters: { type: Type.OBJECT, properties: {} },
|
||||
|
||||
const unsanitizedToolDeclaration: FunctionDeclaration = {
|
||||
name: 'tool-with-bad-format',
|
||||
description: 'A tool with an invalid format property',
|
||||
parameters: {
|
||||
type: Type.OBJECT,
|
||||
properties: {
|
||||
some_string: {
|
||||
type: Type.STRING,
|
||||
format: 'uuid', // This is an unsupported format
|
||||
},
|
||||
},
|
||||
},
|
||||
];
|
||||
mockExecSync.mockReturnValue(
|
||||
Buffer.from(
|
||||
JSON.stringify([{ function_declarations: mockToolDeclarations }]),
|
||||
),
|
||||
);
|
||||
};
|
||||
|
||||
const mockSpawn = vi.mocked(spawn);
|
||||
const mockChildProcess = {
|
||||
stdout: { on: vi.fn() },
|
||||
stderr: { on: vi.fn() },
|
||||
on: vi.fn(),
|
||||
};
|
||||
mockSpawn.mockReturnValue(mockChildProcess as any);
|
||||
|
||||
// Simulate stdout data
|
||||
mockChildProcess.stdout.on.mockImplementation((event, callback) => {
|
||||
if (event === 'data') {
|
||||
callback(
|
||||
Buffer.from(
|
||||
JSON.stringify([
|
||||
{ function_declarations: [unsanitizedToolDeclaration] },
|
||||
]),
|
||||
),
|
||||
);
|
||||
}
|
||||
return mockChildProcess as any;
|
||||
});
|
||||
|
||||
// Simulate process close
|
||||
mockChildProcess.on.mockImplementation((event, callback) => {
|
||||
if (event === 'close') {
|
||||
callback(0);
|
||||
}
|
||||
return mockChildProcess as any;
|
||||
});
|
||||
|
||||
await toolRegistry.discoverTools();
|
||||
expect(execSync).toHaveBeenCalledWith(discoveryCommand);
|
||||
const discoveredTool = toolRegistry.getTool('discovered-tool-1');
|
||||
expect(discoveredTool).toBeInstanceOf(DiscoveredTool);
|
||||
|
||||
const discoveredTool = toolRegistry.getTool('tool-with-bad-format');
|
||||
expect(discoveredTool).toBeDefined();
|
||||
|
||||
const registeredParams = (discoveredTool as DiscoveredTool).schema
|
||||
.parameters as Schema;
|
||||
expect(registeredParams.properties?.['some_string']).toBeDefined();
|
||||
expect(registeredParams.properties?.['some_string']).toHaveProperty(
|
||||
'format',
|
||||
undefined,
|
||||
);
|
||||
});
|
||||
|
||||
it('should discover tools using MCP servers defined in getMcpServers', async () => {
|
||||
mockConfigGetToolDiscoveryCommand.mockReturnValue(undefined);
|
||||
mockConfigGetMcpServerCommand.mockReturnValue(undefined);
|
||||
vi.spyOn(config, 'getMcpServerCommand').mockReturnValue(undefined);
|
||||
const mcpServerConfigVal = {
|
||||
'my-mcp-server': {
|
||||
command: 'mcp-server-cmd',
|
||||
args: ['--port', '1234'],
|
||||
trust: true,
|
||||
} as MCPServerConfig,
|
||||
},
|
||||
};
|
||||
mockConfigGetMcpServers.mockReturnValue(mcpServerConfigVal);
|
||||
vi.spyOn(config, 'getMcpServers').mockReturnValue(mcpServerConfigVal);
|
||||
|
||||
await toolRegistry.discoverTools();
|
||||
|
||||
|
@ -282,56 +298,166 @@ describe('ToolRegistry', () => {
|
|||
undefined,
|
||||
toolRegistry,
|
||||
);
|
||||
// We no longer check these as discoverMcpTools is mocked
|
||||
// expect(vi.mocked(mcpToTool)).toHaveBeenCalledTimes(1);
|
||||
// expect(Client).toHaveBeenCalledTimes(1);
|
||||
// expect(StdioClientTransport).toHaveBeenCalledWith({
|
||||
// command: 'mcp-server-cmd',
|
||||
// args: ['--port', '1234'],
|
||||
// env: expect.any(Object),
|
||||
// stderr: 'pipe',
|
||||
// });
|
||||
// expect(mockMcpClientConnect).toHaveBeenCalled();
|
||||
|
||||
// To verify that tools *would* have been registered, we'd need mockDiscoverMcpTools
|
||||
// to call toolRegistry.registerTool, or we test that separately.
|
||||
// For now, we just check that the delegation happened.
|
||||
});
|
||||
|
||||
it('should discover tools using MCP server command from getMcpServerCommand', async () => {
|
||||
it('should discover tools using MCP servers defined in getMcpServers', async () => {
|
||||
mockConfigGetToolDiscoveryCommand.mockReturnValue(undefined);
|
||||
mockConfigGetMcpServers.mockReturnValue({});
|
||||
mockConfigGetMcpServerCommand.mockReturnValue(
|
||||
'mcp-server-start-command --param',
|
||||
);
|
||||
|
||||
await toolRegistry.discoverTools();
|
||||
expect(mockDiscoverMcpTools).toHaveBeenCalledWith(
|
||||
{},
|
||||
'mcp-server-start-command --param',
|
||||
toolRegistry,
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle errors during MCP client connection gracefully and close transport', async () => {
|
||||
mockConfigGetToolDiscoveryCommand.mockReturnValue(undefined);
|
||||
mockConfigGetMcpServers.mockReturnValue({
|
||||
'failing-mcp': { command: 'fail-cmd' } as MCPServerConfig,
|
||||
});
|
||||
|
||||
mockMcpClientConnect.mockRejectedValue(new Error('Connection failed'));
|
||||
|
||||
await toolRegistry.discoverTools();
|
||||
expect(mockDiscoverMcpTools).toHaveBeenCalledWith(
|
||||
{
|
||||
'failing-mcp': { command: 'fail-cmd' },
|
||||
vi.spyOn(config, 'getMcpServerCommand').mockReturnValue(undefined);
|
||||
const mcpServerConfigVal = {
|
||||
'my-mcp-server': {
|
||||
command: 'mcp-server-cmd',
|
||||
args: ['--port', '1234'],
|
||||
trust: true,
|
||||
},
|
||||
};
|
||||
vi.spyOn(config, 'getMcpServers').mockReturnValue(mcpServerConfigVal);
|
||||
|
||||
await toolRegistry.discoverTools();
|
||||
|
||||
expect(mockDiscoverMcpTools).toHaveBeenCalledWith(
|
||||
mcpServerConfigVal,
|
||||
undefined,
|
||||
toolRegistry,
|
||||
);
|
||||
expect(toolRegistry.getAllTools()).toHaveLength(0);
|
||||
});
|
||||
});
|
||||
// Other tests for DiscoveredTool and DiscoveredMCPTool can be simplified or removed
|
||||
// if their core logic is now tested in their respective dedicated test files (mcp-tool.test.ts)
|
||||
});
|
||||
|
||||
describe('sanitizeParameters', () => {
|
||||
it('should remove unsupported format from a simple string property', () => {
|
||||
const schema: Schema = {
|
||||
type: Type.OBJECT,
|
||||
properties: {
|
||||
name: { type: Type.STRING },
|
||||
id: { type: Type.STRING, format: 'uuid' },
|
||||
},
|
||||
};
|
||||
sanitizeParameters(schema);
|
||||
expect(schema.properties?.['id']).toHaveProperty('format', undefined);
|
||||
expect(schema.properties?.['name']).not.toHaveProperty('format');
|
||||
});
|
||||
|
||||
it('should NOT remove supported format values', () => {
|
||||
const schema: Schema = {
|
||||
type: Type.OBJECT,
|
||||
properties: {
|
||||
date: { type: Type.STRING, format: 'date-time' },
|
||||
role: {
|
||||
type: Type.STRING,
|
||||
format: 'enum',
|
||||
enum: ['admin', 'user'],
|
||||
},
|
||||
},
|
||||
};
|
||||
const originalSchema = JSON.parse(JSON.stringify(schema));
|
||||
sanitizeParameters(schema);
|
||||
expect(schema).toEqual(originalSchema);
|
||||
});
|
||||
|
||||
it('should handle nested objects recursively', () => {
|
||||
const schema: Schema = {
|
||||
type: Type.OBJECT,
|
||||
properties: {
|
||||
user: {
|
||||
type: Type.OBJECT,
|
||||
properties: {
|
||||
email: { type: Type.STRING, format: 'email' },
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
sanitizeParameters(schema);
|
||||
expect(schema.properties?.['user']?.properties?.['email']).toHaveProperty(
|
||||
'format',
|
||||
undefined,
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle arrays of objects', () => {
|
||||
const schema: Schema = {
|
||||
type: Type.OBJECT,
|
||||
properties: {
|
||||
items: {
|
||||
type: Type.ARRAY,
|
||||
items: {
|
||||
type: Type.OBJECT,
|
||||
properties: {
|
||||
itemId: { type: Type.STRING, format: 'uuid' },
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
sanitizeParameters(schema);
|
||||
expect(
|
||||
(schema.properties?.['items']?.items as Schema)?.properties?.['itemId'],
|
||||
).toHaveProperty('format', undefined);
|
||||
});
|
||||
|
||||
it('should handle schemas with no properties to sanitize', () => {
|
||||
const schema: Schema = {
|
||||
type: Type.OBJECT,
|
||||
properties: {
|
||||
count: { type: Type.NUMBER },
|
||||
isActive: { type: Type.BOOLEAN },
|
||||
},
|
||||
};
|
||||
const originalSchema = JSON.parse(JSON.stringify(schema));
|
||||
sanitizeParameters(schema);
|
||||
expect(schema).toEqual(originalSchema);
|
||||
});
|
||||
|
||||
it('should not crash on an empty or undefined schema', () => {
|
||||
expect(() => sanitizeParameters({})).not.toThrow();
|
||||
expect(() => sanitizeParameters(undefined)).not.toThrow();
|
||||
});
|
||||
|
||||
it('should handle cyclic schemas without crashing', () => {
|
||||
const schema: any = {
|
||||
type: Type.OBJECT,
|
||||
properties: {
|
||||
name: { type: Type.STRING, format: 'hostname' },
|
||||
},
|
||||
};
|
||||
schema.properties.self = schema;
|
||||
|
||||
expect(() => sanitizeParameters(schema)).not.toThrow();
|
||||
expect(schema.properties.name).toHaveProperty('format', undefined);
|
||||
});
|
||||
|
||||
it('should handle complex nested schemas with cycles', () => {
|
||||
const userNode: any = {
|
||||
type: Type.OBJECT,
|
||||
properties: {
|
||||
id: { type: Type.STRING, format: 'uuid' },
|
||||
name: { type: Type.STRING },
|
||||
manager: {
|
||||
type: Type.OBJECT,
|
||||
properties: {
|
||||
id: { type: Type.STRING, format: 'uuid' },
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
userNode.properties.reports = {
|
||||
type: Type.ARRAY,
|
||||
items: userNode,
|
||||
};
|
||||
|
||||
const schema: Schema = {
|
||||
type: Type.OBJECT,
|
||||
properties: {
|
||||
ceo: userNode,
|
||||
},
|
||||
};
|
||||
|
||||
expect(() => sanitizeParameters(schema)).not.toThrow();
|
||||
expect(schema.properties?.['ceo']?.properties?.['id']).toHaveProperty(
|
||||
'format',
|
||||
undefined,
|
||||
);
|
||||
expect(
|
||||
schema.properties?.['ceo']?.properties?.['manager']?.properties?.['id'],
|
||||
).toHaveProperty('format', undefined);
|
||||
});
|
||||
});
|
||||
|
|
|
@ -4,12 +4,14 @@
|
|||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { FunctionDeclaration } from '@google/genai';
|
||||
import { FunctionDeclaration, Schema, Type } from '@google/genai';
|
||||
import { Tool, ToolResult, BaseTool } from './tools.js';
|
||||
import { Config } from '../config/config.js';
|
||||
import { spawn, execSync } from 'node:child_process';
|
||||
import { spawn } from 'node:child_process';
|
||||
import { StringDecoder } from 'node:string_decoder';
|
||||
import { discoverMcpTools } from './mcp-client.js';
|
||||
import { DiscoveredMCPTool } from './mcp-tool.js';
|
||||
import { parse } from 'shell-quote';
|
||||
|
||||
type ToolParams = Record<string, unknown>;
|
||||
|
||||
|
@ -157,32 +159,9 @@ export class ToolRegistry {
|
|||
// Keep manually registered tools
|
||||
}
|
||||
}
|
||||
// discover tools using discovery command, if configured
|
||||
const discoveryCmd = this.config.getToolDiscoveryCommand();
|
||||
if (discoveryCmd) {
|
||||
// execute discovery command and extract function declarations (w/ or w/o "tool" wrappers)
|
||||
const functions: FunctionDeclaration[] = [];
|
||||
for (const tool of JSON.parse(execSync(discoveryCmd).toString().trim())) {
|
||||
if (tool['function_declarations']) {
|
||||
functions.push(...tool['function_declarations']);
|
||||
} else if (tool['functionDeclarations']) {
|
||||
functions.push(...tool['functionDeclarations']);
|
||||
} else if (tool['name']) {
|
||||
functions.push(tool);
|
||||
}
|
||||
}
|
||||
// register each function as a tool
|
||||
for (const func of functions) {
|
||||
this.registerTool(
|
||||
new DiscoveredTool(
|
||||
this.config,
|
||||
func.name!,
|
||||
func.description!,
|
||||
func.parameters! as Record<string, unknown>,
|
||||
),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
await this.discoverAndRegisterToolsFromCommand();
|
||||
|
||||
// discover tools using MCP servers, if configured
|
||||
await discoverMcpTools(
|
||||
this.config.getMcpServers() ?? {},
|
||||
|
@ -191,6 +170,128 @@ export class ToolRegistry {
|
|||
);
|
||||
}
|
||||
|
||||
private async discoverAndRegisterToolsFromCommand(): Promise<void> {
|
||||
const discoveryCmd = this.config.getToolDiscoveryCommand();
|
||||
if (!discoveryCmd) {
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
const cmdParts = parse(discoveryCmd);
|
||||
if (cmdParts.length === 0) {
|
||||
throw new Error(
|
||||
'Tool discovery command is empty or contains only whitespace.',
|
||||
);
|
||||
}
|
||||
const proc = spawn(cmdParts[0] as string, cmdParts.slice(1) as string[]);
|
||||
let stdout = '';
|
||||
const stdoutDecoder = new StringDecoder('utf8');
|
||||
let stderr = '';
|
||||
const stderrDecoder = new StringDecoder('utf8');
|
||||
let sizeLimitExceeded = false;
|
||||
const MAX_STDOUT_SIZE = 10 * 1024 * 1024; // 10MB limit
|
||||
const MAX_STDERR_SIZE = 10 * 1024 * 1024; // 10MB limit
|
||||
|
||||
let stdoutByteLength = 0;
|
||||
let stderrByteLength = 0;
|
||||
|
||||
proc.stdout.on('data', (data) => {
|
||||
if (sizeLimitExceeded) return;
|
||||
if (stdoutByteLength + data.length > MAX_STDOUT_SIZE) {
|
||||
sizeLimitExceeded = true;
|
||||
proc.kill();
|
||||
return;
|
||||
}
|
||||
stdoutByteLength += data.length;
|
||||
stdout += stdoutDecoder.write(data);
|
||||
});
|
||||
|
||||
proc.stderr.on('data', (data) => {
|
||||
if (sizeLimitExceeded) return;
|
||||
if (stderrByteLength + data.length > MAX_STDERR_SIZE) {
|
||||
sizeLimitExceeded = true;
|
||||
proc.kill();
|
||||
return;
|
||||
}
|
||||
stderrByteLength += data.length;
|
||||
stderr += stderrDecoder.write(data);
|
||||
});
|
||||
|
||||
await new Promise<void>((resolve, reject) => {
|
||||
proc.on('error', reject);
|
||||
proc.on('close', (code) => {
|
||||
stdout += stdoutDecoder.end();
|
||||
stderr += stderrDecoder.end();
|
||||
|
||||
if (sizeLimitExceeded) {
|
||||
return reject(
|
||||
new Error(
|
||||
`Tool discovery command output exceeded size limit of ${MAX_STDOUT_SIZE} bytes.`,
|
||||
),
|
||||
);
|
||||
}
|
||||
|
||||
if (code !== 0) {
|
||||
console.error(`Command failed with code ${code}`);
|
||||
console.error(stderr);
|
||||
return reject(
|
||||
new Error(`Tool discovery command failed with exit code ${code}`),
|
||||
);
|
||||
}
|
||||
resolve();
|
||||
});
|
||||
});
|
||||
|
||||
// execute discovery command and extract function declarations (w/ or w/o "tool" wrappers)
|
||||
const functions: FunctionDeclaration[] = [];
|
||||
const discoveredItems = JSON.parse(stdout.trim());
|
||||
|
||||
if (!discoveredItems || !Array.isArray(discoveredItems)) {
|
||||
throw new Error(
|
||||
'Tool discovery command did not return a JSON array of tools.',
|
||||
);
|
||||
}
|
||||
|
||||
for (const tool of discoveredItems) {
|
||||
if (tool && typeof tool === 'object') {
|
||||
if (Array.isArray(tool['function_declarations'])) {
|
||||
functions.push(...tool['function_declarations']);
|
||||
} else if (Array.isArray(tool['functionDeclarations'])) {
|
||||
functions.push(...tool['functionDeclarations']);
|
||||
} else if (tool['name']) {
|
||||
functions.push(tool as FunctionDeclaration);
|
||||
}
|
||||
}
|
||||
}
|
||||
// register each function as a tool
|
||||
for (const func of functions) {
|
||||
if (!func.name) {
|
||||
console.warn('Discovered a tool with no name. Skipping.');
|
||||
continue;
|
||||
}
|
||||
// Sanitize the parameters before registering the tool.
|
||||
const parameters =
|
||||
func.parameters &&
|
||||
typeof func.parameters === 'object' &&
|
||||
!Array.isArray(func.parameters)
|
||||
? (func.parameters as Schema)
|
||||
: {};
|
||||
sanitizeParameters(parameters);
|
||||
this.registerTool(
|
||||
new DiscoveredTool(
|
||||
this.config,
|
||||
func.name,
|
||||
func.description ?? '',
|
||||
parameters as Record<string, unknown>,
|
||||
),
|
||||
);
|
||||
}
|
||||
} catch (e) {
|
||||
console.error(`Tool discovery command "${discoveryCmd}" failed:`, e);
|
||||
throw e;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Retrieves the list of tool schemas (FunctionDeclaration array).
|
||||
* Extracts the declarations from the ToolListUnion structure.
|
||||
|
@ -232,3 +333,62 @@ export class ToolRegistry {
|
|||
return this.tools.get(name);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Sanitizes a schema object in-place to ensure compatibility with the Gemini API.
|
||||
*
|
||||
* NOTE: This function mutates the passed schema object.
|
||||
*
|
||||
* It performs the following actions:
|
||||
* - Removes the `default` property when `anyOf` is present.
|
||||
* - Removes unsupported `format` values from string properties, keeping only 'enum' and 'date-time'.
|
||||
* - Recursively sanitizes nested schemas within `anyOf`, `items`, and `properties`.
|
||||
* - Handles circular references within the schema to prevent infinite loops.
|
||||
*
|
||||
* @param schema The schema object to sanitize. It will be modified directly.
|
||||
*/
|
||||
export function sanitizeParameters(schema?: Schema) {
|
||||
_sanitizeParameters(schema, new Set<Schema>());
|
||||
}
|
||||
|
||||
/**
|
||||
* Internal recursive implementation for sanitizeParameters.
|
||||
* @param schema The schema object to sanitize.
|
||||
* @param visited A set used to track visited schema objects during recursion.
|
||||
*/
|
||||
function _sanitizeParameters(schema: Schema | undefined, visited: Set<Schema>) {
|
||||
if (!schema || visited.has(schema)) {
|
||||
return;
|
||||
}
|
||||
visited.add(schema);
|
||||
|
||||
if (schema.anyOf) {
|
||||
// Vertex AI gets confused if both anyOf and default are set.
|
||||
schema.default = undefined;
|
||||
for (const item of schema.anyOf) {
|
||||
if (typeof item !== 'boolean') {
|
||||
_sanitizeParameters(item, visited);
|
||||
}
|
||||
}
|
||||
}
|
||||
if (schema.items && typeof schema.items !== 'boolean') {
|
||||
_sanitizeParameters(schema.items, visited);
|
||||
}
|
||||
if (schema.properties) {
|
||||
for (const item of Object.values(schema.properties)) {
|
||||
if (typeof item !== 'boolean') {
|
||||
_sanitizeParameters(item, visited);
|
||||
}
|
||||
}
|
||||
}
|
||||
// Vertex AI only supports 'enum' and 'date-time' for STRING format.
|
||||
if (schema.type === Type.STRING) {
|
||||
if (
|
||||
schema.format &&
|
||||
schema.format !== 'enum' &&
|
||||
schema.format !== 'date-time'
|
||||
) {
|
||||
schema.format = undefined;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue