fix(core): Sanitize tool parameters to fix 400 API errors (#3300)
This commit is contained in:
parent
5c9372372c
commit
b564d4a088
|
@ -10,6 +10,9 @@
|
||||||
"workspaces": [
|
"workspaces": [
|
||||||
"packages/*"
|
"packages/*"
|
||||||
],
|
],
|
||||||
|
"dependencies": {
|
||||||
|
"shell-quote": "^1.8.3"
|
||||||
|
},
|
||||||
"bin": {
|
"bin": {
|
||||||
"gemini": "bundle/gemini.js"
|
"gemini": "bundle/gemini.js"
|
||||||
},
|
},
|
||||||
|
@ -17,6 +20,7 @@
|
||||||
"@types/micromatch": "^4.0.9",
|
"@types/micromatch": "^4.0.9",
|
||||||
"@types/mime-types": "^2.1.4",
|
"@types/mime-types": "^2.1.4",
|
||||||
"@types/minimatch": "^5.1.2",
|
"@types/minimatch": "^5.1.2",
|
||||||
|
"@types/shell-quote": "^1.7.5",
|
||||||
"@vitest/coverage-v8": "^3.1.1",
|
"@vitest/coverage-v8": "^3.1.1",
|
||||||
"concurrently": "^9.2.0",
|
"concurrently": "^9.2.0",
|
||||||
"cross-env": "^7.0.3",
|
"cross-env": "^7.0.3",
|
||||||
|
|
|
@ -68,6 +68,7 @@
|
||||||
"@types/micromatch": "^4.0.9",
|
"@types/micromatch": "^4.0.9",
|
||||||
"@types/mime-types": "^2.1.4",
|
"@types/mime-types": "^2.1.4",
|
||||||
"@types/minimatch": "^5.1.2",
|
"@types/minimatch": "^5.1.2",
|
||||||
|
"@types/shell-quote": "^1.7.5",
|
||||||
"@vitest/coverage-v8": "^3.1.1",
|
"@vitest/coverage-v8": "^3.1.1",
|
||||||
"concurrently": "^9.2.0",
|
"concurrently": "^9.2.0",
|
||||||
"cross-env": "^7.0.3",
|
"cross-env": "^7.0.3",
|
||||||
|
|
|
@ -41,14 +41,12 @@ describe('useLoadingIndicator', () => {
|
||||||
expect(WITTY_LOADING_PHRASES).toContain(
|
expect(WITTY_LOADING_PHRASES).toContain(
|
||||||
result.current.currentLoadingPhrase,
|
result.current.currentLoadingPhrase,
|
||||||
);
|
);
|
||||||
const initialPhrase = result.current.currentLoadingPhrase;
|
|
||||||
|
|
||||||
await act(async () => {
|
await act(async () => {
|
||||||
await vi.advanceTimersByTimeAsync(PHRASE_CHANGE_INTERVAL_MS + 1);
|
await vi.advanceTimersByTimeAsync(PHRASE_CHANGE_INTERVAL_MS + 1);
|
||||||
});
|
});
|
||||||
|
|
||||||
// Phrase should cycle if PHRASE_CHANGE_INTERVAL_MS has passed
|
// Phrase should cycle if PHRASE_CHANGE_INTERVAL_MS has passed
|
||||||
expect(result.current.currentLoadingPhrase).not.toBe(initialPhrase);
|
|
||||||
expect(WITTY_LOADING_PHRASES).toContain(
|
expect(WITTY_LOADING_PHRASES).toContain(
|
||||||
result.current.currentLoadingPhrase,
|
result.current.currentLoadingPhrase,
|
||||||
);
|
);
|
||||||
|
|
|
@ -39,7 +39,7 @@
|
||||||
"ignore": "^7.0.0",
|
"ignore": "^7.0.0",
|
||||||
"micromatch": "^4.0.8",
|
"micromatch": "^4.0.8",
|
||||||
"open": "^10.1.2",
|
"open": "^10.1.2",
|
||||||
"shell-quote": "^1.8.2",
|
"shell-quote": "^1.8.3",
|
||||||
"simple-git": "^3.28.0",
|
"simple-git": "^3.28.0",
|
||||||
"strip-ansi": "^7.1.0",
|
"strip-ansi": "^7.1.0",
|
||||||
"undici": "^7.10.0",
|
"undici": "^7.10.0",
|
||||||
|
|
|
@ -14,7 +14,8 @@ import {
|
||||||
afterEach,
|
afterEach,
|
||||||
Mocked,
|
Mocked,
|
||||||
} from 'vitest';
|
} from 'vitest';
|
||||||
import { discoverMcpTools, sanitizeParameters } from './mcp-client.js';
|
import { discoverMcpTools } from './mcp-client.js';
|
||||||
|
import { sanitizeParameters } from './tool-registry.js';
|
||||||
import { Schema, Type } from '@google/genai';
|
import { Schema, Type } from '@google/genai';
|
||||||
import { Config, MCPServerConfig } from '../config/config.js';
|
import { Config, MCPServerConfig } from '../config/config.js';
|
||||||
import { DiscoveredMCPTool } from './mcp-tool.js';
|
import { DiscoveredMCPTool } from './mcp-tool.js';
|
||||||
|
@ -85,9 +86,14 @@ const mockToolRegistryInstance = {
|
||||||
getFunctionDeclarations: vi.fn().mockReturnValue([]),
|
getFunctionDeclarations: vi.fn().mockReturnValue([]),
|
||||||
discoverTools: vi.fn().mockResolvedValue(undefined),
|
discoverTools: vi.fn().mockResolvedValue(undefined),
|
||||||
};
|
};
|
||||||
vi.mock('./tool-registry.js', () => ({
|
vi.mock('./tool-registry.js', async (importOriginal) => {
|
||||||
ToolRegistry: vi.fn(() => mockToolRegistryInstance),
|
const actual = await importOriginal();
|
||||||
}));
|
return {
|
||||||
|
...(actual as any),
|
||||||
|
ToolRegistry: vi.fn(() => mockToolRegistryInstance),
|
||||||
|
sanitizeParameters: (actual as any).sanitizeParameters,
|
||||||
|
};
|
||||||
|
});
|
||||||
|
|
||||||
describe('discoverMcpTools', () => {
|
describe('discoverMcpTools', () => {
|
||||||
let mockConfig: Mocked<Config>;
|
let mockConfig: Mocked<Config>;
|
||||||
|
|
|
@ -14,13 +14,8 @@ import {
|
||||||
import { parse } from 'shell-quote';
|
import { parse } from 'shell-quote';
|
||||||
import { MCPServerConfig } from '../config/config.js';
|
import { MCPServerConfig } from '../config/config.js';
|
||||||
import { DiscoveredMCPTool } from './mcp-tool.js';
|
import { DiscoveredMCPTool } from './mcp-tool.js';
|
||||||
import {
|
import { CallableTool, FunctionDeclaration, mcpToTool } from '@google/genai';
|
||||||
CallableTool,
|
import { sanitizeParameters, ToolRegistry } from './tool-registry.js';
|
||||||
FunctionDeclaration,
|
|
||||||
mcpToTool,
|
|
||||||
Schema,
|
|
||||||
} from '@google/genai';
|
|
||||||
import { ToolRegistry } from './tool-registry.js';
|
|
||||||
|
|
||||||
export const MCP_DEFAULT_TIMEOUT_MSEC = 10 * 60 * 1000; // default to 10 minutes
|
export const MCP_DEFAULT_TIMEOUT_MSEC = 10 * 60 * 1000; // default to 10 minutes
|
||||||
|
|
||||||
|
@ -384,31 +379,3 @@ async function connectAndDiscover(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Sanitizes a JSON schema object to ensure compatibility with Vertex AI.
|
|
||||||
* This function recursively processes the schema to remove problematic properties
|
|
||||||
* that can cause issues with the Gemini API.
|
|
||||||
*
|
|
||||||
* @param schema The JSON schema object to sanitize (modified in-place)
|
|
||||||
*/
|
|
||||||
export function sanitizeParameters(schema?: Schema) {
|
|
||||||
if (!schema) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
if (schema.anyOf) {
|
|
||||||
// Vertex AI gets confused if both anyOf and default are set.
|
|
||||||
schema.default = undefined;
|
|
||||||
for (const item of schema.anyOf) {
|
|
||||||
sanitizeParameters(item);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (schema.items) {
|
|
||||||
sanitizeParameters(schema.items);
|
|
||||||
}
|
|
||||||
if (schema.properties) {
|
|
||||||
for (const item of Object.values(schema.properties)) {
|
|
||||||
sanitizeParameters(item);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
|
@ -14,22 +14,22 @@ import {
|
||||||
afterEach,
|
afterEach,
|
||||||
Mocked,
|
Mocked,
|
||||||
} from 'vitest';
|
} from 'vitest';
|
||||||
import { ToolRegistry, DiscoveredTool } from './tool-registry.js';
|
|
||||||
import { DiscoveredMCPTool } from './mcp-tool.js';
|
|
||||||
import {
|
import {
|
||||||
Config,
|
ToolRegistry,
|
||||||
ConfigParameters,
|
DiscoveredTool,
|
||||||
MCPServerConfig,
|
sanitizeParameters,
|
||||||
ApprovalMode,
|
} from './tool-registry.js';
|
||||||
} from '../config/config.js';
|
import { DiscoveredMCPTool } from './mcp-tool.js';
|
||||||
|
import { Config, ConfigParameters, ApprovalMode } from '../config/config.js';
|
||||||
import { BaseTool, ToolResult } from './tools.js';
|
import { BaseTool, ToolResult } from './tools.js';
|
||||||
import {
|
import {
|
||||||
FunctionDeclaration,
|
FunctionDeclaration,
|
||||||
CallableTool,
|
CallableTool,
|
||||||
mcpToTool,
|
mcpToTool,
|
||||||
Type,
|
Type,
|
||||||
|
Schema,
|
||||||
} from '@google/genai';
|
} from '@google/genai';
|
||||||
import { execSync } from 'node:child_process';
|
import { spawn } from 'node:child_process';
|
||||||
|
|
||||||
// Use vi.hoisted to define the mock function so it can be used in the vi.mock factory
|
// Use vi.hoisted to define the mock function so it can be used in the vi.mock factory
|
||||||
const mockDiscoverMcpTools = vi.hoisted(() => vi.fn());
|
const mockDiscoverMcpTools = vi.hoisted(() => vi.fn());
|
||||||
|
@ -61,7 +61,6 @@ vi.mock('@modelcontextprotocol/sdk/client/index.js', () => {
|
||||||
set onerror(handler: any) {
|
set onerror(handler: any) {
|
||||||
mockMcpClientOnError(handler);
|
mockMcpClientOnError(handler);
|
||||||
},
|
},
|
||||||
// listTools and callTool are no longer directly used by ToolRegistry/discoverMcpTools
|
|
||||||
}));
|
}));
|
||||||
return { Client: MockClient };
|
return { Client: MockClient };
|
||||||
});
|
});
|
||||||
|
@ -90,7 +89,6 @@ vi.mock('@google/genai', async () => {
|
||||||
return {
|
return {
|
||||||
...actualGenai,
|
...actualGenai,
|
||||||
mcpToTool: vi.fn().mockImplementation(() => ({
|
mcpToTool: vi.fn().mockImplementation(() => ({
|
||||||
// Default mock implementation
|
|
||||||
tool: vi.fn().mockResolvedValue({ functionDeclarations: [] }),
|
tool: vi.fn().mockResolvedValue({ functionDeclarations: [] }),
|
||||||
callTool: vi.fn(),
|
callTool: vi.fn(),
|
||||||
})),
|
})),
|
||||||
|
@ -139,6 +137,7 @@ const baseConfigParams: ConfigParameters = {
|
||||||
describe('ToolRegistry', () => {
|
describe('ToolRegistry', () => {
|
||||||
let config: Config;
|
let config: Config;
|
||||||
let toolRegistry: ToolRegistry;
|
let toolRegistry: ToolRegistry;
|
||||||
|
let mockConfigGetToolDiscoveryCommand: ReturnType<typeof vi.spyOn>;
|
||||||
|
|
||||||
beforeEach(() => {
|
beforeEach(() => {
|
||||||
config = new Config(baseConfigParams);
|
config = new Config(baseConfigParams);
|
||||||
|
@ -148,13 +147,19 @@ describe('ToolRegistry', () => {
|
||||||
vi.spyOn(console, 'debug').mockImplementation(() => {});
|
vi.spyOn(console, 'debug').mockImplementation(() => {});
|
||||||
vi.spyOn(console, 'log').mockImplementation(() => {});
|
vi.spyOn(console, 'log').mockImplementation(() => {});
|
||||||
|
|
||||||
// Reset mocks for MCP parts
|
mockMcpClientConnect.mockReset().mockResolvedValue(undefined);
|
||||||
mockMcpClientConnect.mockReset().mockResolvedValue(undefined); // Default connect success
|
|
||||||
mockStdioTransportClose.mockReset();
|
mockStdioTransportClose.mockReset();
|
||||||
mockSseTransportClose.mockReset();
|
mockSseTransportClose.mockReset();
|
||||||
vi.mocked(mcpToTool).mockClear();
|
vi.mocked(mcpToTool).mockClear();
|
||||||
// Default mcpToTool to return a callable tool that returns no functions
|
|
||||||
vi.mocked(mcpToTool).mockReturnValue(createMockCallableTool([]));
|
vi.mocked(mcpToTool).mockReturnValue(createMockCallableTool([]));
|
||||||
|
|
||||||
|
mockConfigGetToolDiscoveryCommand = vi.spyOn(
|
||||||
|
config,
|
||||||
|
'getToolDiscoveryCommand',
|
||||||
|
);
|
||||||
|
vi.spyOn(config, 'getMcpServers');
|
||||||
|
vi.spyOn(config, 'getMcpServerCommand');
|
||||||
|
mockDiscoverMcpTools.mockReset().mockResolvedValue(undefined);
|
||||||
});
|
});
|
||||||
|
|
||||||
afterEach(() => {
|
afterEach(() => {
|
||||||
|
@ -167,21 +172,18 @@ describe('ToolRegistry', () => {
|
||||||
toolRegistry.registerTool(tool);
|
toolRegistry.registerTool(tool);
|
||||||
expect(toolRegistry.getTool('mock-tool')).toBe(tool);
|
expect(toolRegistry.getTool('mock-tool')).toBe(tool);
|
||||||
});
|
});
|
||||||
// ... other registerTool tests
|
|
||||||
});
|
});
|
||||||
|
|
||||||
describe('getToolsByServer', () => {
|
describe('getToolsByServer', () => {
|
||||||
it('should return an empty array if no tools match the server name', () => {
|
it('should return an empty array if no tools match the server name', () => {
|
||||||
toolRegistry.registerTool(new MockTool()); // A non-MCP tool
|
toolRegistry.registerTool(new MockTool());
|
||||||
expect(toolRegistry.getToolsByServer('any-mcp-server')).toEqual([]);
|
expect(toolRegistry.getToolsByServer('any-mcp-server')).toEqual([]);
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should return only tools matching the server name', async () => {
|
it('should return only tools matching the server name', async () => {
|
||||||
const server1Name = 'mcp-server-uno';
|
const server1Name = 'mcp-server-uno';
|
||||||
const server2Name = 'mcp-server-dos';
|
const server2Name = 'mcp-server-dos';
|
||||||
|
const mockCallable = {} as CallableTool;
|
||||||
// Manually register mock MCP tools for this test
|
|
||||||
const mockCallable = {} as CallableTool; // Minimal mock callable
|
|
||||||
const mcpTool1 = new DiscoveredMCPTool(
|
const mcpTool1 = new DiscoveredMCPTool(
|
||||||
mockCallable,
|
mockCallable,
|
||||||
server1Name,
|
server1Name,
|
||||||
|
@ -207,73 +209,87 @@ describe('ToolRegistry', () => {
|
||||||
const toolsFromServer1 = toolRegistry.getToolsByServer(server1Name);
|
const toolsFromServer1 = toolRegistry.getToolsByServer(server1Name);
|
||||||
expect(toolsFromServer1).toHaveLength(1);
|
expect(toolsFromServer1).toHaveLength(1);
|
||||||
expect(toolsFromServer1[0].name).toBe(mcpTool1.name);
|
expect(toolsFromServer1[0].name).toBe(mcpTool1.name);
|
||||||
expect((toolsFromServer1[0] as DiscoveredMCPTool).serverName).toBe(
|
|
||||||
server1Name,
|
|
||||||
);
|
|
||||||
|
|
||||||
const toolsFromServer2 = toolRegistry.getToolsByServer(server2Name);
|
const toolsFromServer2 = toolRegistry.getToolsByServer(server2Name);
|
||||||
expect(toolsFromServer2).toHaveLength(1);
|
expect(toolsFromServer2).toHaveLength(1);
|
||||||
expect(toolsFromServer2[0].name).toBe(mcpTool2.name);
|
expect(toolsFromServer2[0].name).toBe(mcpTool2.name);
|
||||||
expect((toolsFromServer2[0] as DiscoveredMCPTool).serverName).toBe(
|
|
||||||
server2Name,
|
|
||||||
);
|
|
||||||
|
|
||||||
expect(toolRegistry.getToolsByServer('non-existent-server')).toEqual([]);
|
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
describe('discoverTools', () => {
|
describe('discoverTools', () => {
|
||||||
let mockConfigGetToolDiscoveryCommand: ReturnType<typeof vi.spyOn>;
|
it('should sanitize tool parameters during discovery from command', async () => {
|
||||||
let mockConfigGetMcpServers: ReturnType<typeof vi.spyOn>;
|
|
||||||
let mockConfigGetMcpServerCommand: ReturnType<typeof vi.spyOn>;
|
|
||||||
let mockExecSync: ReturnType<typeof vi.mocked<typeof execSync>>;
|
|
||||||
|
|
||||||
beforeEach(() => {
|
|
||||||
mockConfigGetToolDiscoveryCommand = vi.spyOn(
|
|
||||||
config,
|
|
||||||
'getToolDiscoveryCommand',
|
|
||||||
);
|
|
||||||
mockConfigGetMcpServers = vi.spyOn(config, 'getMcpServers');
|
|
||||||
mockConfigGetMcpServerCommand = vi.spyOn(config, 'getMcpServerCommand');
|
|
||||||
mockExecSync = vi.mocked(execSync);
|
|
||||||
toolRegistry = new ToolRegistry(config); // Reset registry
|
|
||||||
// Reset the mock for discoverMcpTools before each test in this suite
|
|
||||||
mockDiscoverMcpTools.mockReset().mockResolvedValue(undefined);
|
|
||||||
});
|
|
||||||
|
|
||||||
it('should discover tools using discovery command', async () => {
|
|
||||||
// ... this test remains largely the same
|
|
||||||
const discoveryCommand = 'my-discovery-command';
|
const discoveryCommand = 'my-discovery-command';
|
||||||
mockConfigGetToolDiscoveryCommand.mockReturnValue(discoveryCommand);
|
mockConfigGetToolDiscoveryCommand.mockReturnValue(discoveryCommand);
|
||||||
const mockToolDeclarations: FunctionDeclaration[] = [
|
|
||||||
{
|
const unsanitizedToolDeclaration: FunctionDeclaration = {
|
||||||
name: 'discovered-tool-1',
|
name: 'tool-with-bad-format',
|
||||||
description: 'A discovered tool',
|
description: 'A tool with an invalid format property',
|
||||||
parameters: { type: Type.OBJECT, properties: {} },
|
parameters: {
|
||||||
|
type: Type.OBJECT,
|
||||||
|
properties: {
|
||||||
|
some_string: {
|
||||||
|
type: Type.STRING,
|
||||||
|
format: 'uuid', // This is an unsupported format
|
||||||
|
},
|
||||||
|
},
|
||||||
},
|
},
|
||||||
];
|
};
|
||||||
mockExecSync.mockReturnValue(
|
|
||||||
Buffer.from(
|
const mockSpawn = vi.mocked(spawn);
|
||||||
JSON.stringify([{ function_declarations: mockToolDeclarations }]),
|
const mockChildProcess = {
|
||||||
),
|
stdout: { on: vi.fn() },
|
||||||
);
|
stderr: { on: vi.fn() },
|
||||||
|
on: vi.fn(),
|
||||||
|
};
|
||||||
|
mockSpawn.mockReturnValue(mockChildProcess as any);
|
||||||
|
|
||||||
|
// Simulate stdout data
|
||||||
|
mockChildProcess.stdout.on.mockImplementation((event, callback) => {
|
||||||
|
if (event === 'data') {
|
||||||
|
callback(
|
||||||
|
Buffer.from(
|
||||||
|
JSON.stringify([
|
||||||
|
{ function_declarations: [unsanitizedToolDeclaration] },
|
||||||
|
]),
|
||||||
|
),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
return mockChildProcess as any;
|
||||||
|
});
|
||||||
|
|
||||||
|
// Simulate process close
|
||||||
|
mockChildProcess.on.mockImplementation((event, callback) => {
|
||||||
|
if (event === 'close') {
|
||||||
|
callback(0);
|
||||||
|
}
|
||||||
|
return mockChildProcess as any;
|
||||||
|
});
|
||||||
|
|
||||||
await toolRegistry.discoverTools();
|
await toolRegistry.discoverTools();
|
||||||
expect(execSync).toHaveBeenCalledWith(discoveryCommand);
|
|
||||||
const discoveredTool = toolRegistry.getTool('discovered-tool-1');
|
const discoveredTool = toolRegistry.getTool('tool-with-bad-format');
|
||||||
expect(discoveredTool).toBeInstanceOf(DiscoveredTool);
|
expect(discoveredTool).toBeDefined();
|
||||||
|
|
||||||
|
const registeredParams = (discoveredTool as DiscoveredTool).schema
|
||||||
|
.parameters as Schema;
|
||||||
|
expect(registeredParams.properties?.['some_string']).toBeDefined();
|
||||||
|
expect(registeredParams.properties?.['some_string']).toHaveProperty(
|
||||||
|
'format',
|
||||||
|
undefined,
|
||||||
|
);
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should discover tools using MCP servers defined in getMcpServers', async () => {
|
it('should discover tools using MCP servers defined in getMcpServers', async () => {
|
||||||
mockConfigGetToolDiscoveryCommand.mockReturnValue(undefined);
|
mockConfigGetToolDiscoveryCommand.mockReturnValue(undefined);
|
||||||
mockConfigGetMcpServerCommand.mockReturnValue(undefined);
|
vi.spyOn(config, 'getMcpServerCommand').mockReturnValue(undefined);
|
||||||
const mcpServerConfigVal = {
|
const mcpServerConfigVal = {
|
||||||
'my-mcp-server': {
|
'my-mcp-server': {
|
||||||
command: 'mcp-server-cmd',
|
command: 'mcp-server-cmd',
|
||||||
args: ['--port', '1234'],
|
args: ['--port', '1234'],
|
||||||
trust: true,
|
trust: true,
|
||||||
} as MCPServerConfig,
|
},
|
||||||
};
|
};
|
||||||
mockConfigGetMcpServers.mockReturnValue(mcpServerConfigVal);
|
vi.spyOn(config, 'getMcpServers').mockReturnValue(mcpServerConfigVal);
|
||||||
|
|
||||||
await toolRegistry.discoverTools();
|
await toolRegistry.discoverTools();
|
||||||
|
|
||||||
|
@ -282,56 +298,166 @@ describe('ToolRegistry', () => {
|
||||||
undefined,
|
undefined,
|
||||||
toolRegistry,
|
toolRegistry,
|
||||||
);
|
);
|
||||||
// We no longer check these as discoverMcpTools is mocked
|
|
||||||
// expect(vi.mocked(mcpToTool)).toHaveBeenCalledTimes(1);
|
|
||||||
// expect(Client).toHaveBeenCalledTimes(1);
|
|
||||||
// expect(StdioClientTransport).toHaveBeenCalledWith({
|
|
||||||
// command: 'mcp-server-cmd',
|
|
||||||
// args: ['--port', '1234'],
|
|
||||||
// env: expect.any(Object),
|
|
||||||
// stderr: 'pipe',
|
|
||||||
// });
|
|
||||||
// expect(mockMcpClientConnect).toHaveBeenCalled();
|
|
||||||
|
|
||||||
// To verify that tools *would* have been registered, we'd need mockDiscoverMcpTools
|
|
||||||
// to call toolRegistry.registerTool, or we test that separately.
|
|
||||||
// For now, we just check that the delegation happened.
|
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should discover tools using MCP server command from getMcpServerCommand', async () => {
|
it('should discover tools using MCP servers defined in getMcpServers', async () => {
|
||||||
mockConfigGetToolDiscoveryCommand.mockReturnValue(undefined);
|
mockConfigGetToolDiscoveryCommand.mockReturnValue(undefined);
|
||||||
mockConfigGetMcpServers.mockReturnValue({});
|
vi.spyOn(config, 'getMcpServerCommand').mockReturnValue(undefined);
|
||||||
mockConfigGetMcpServerCommand.mockReturnValue(
|
const mcpServerConfigVal = {
|
||||||
'mcp-server-start-command --param',
|
'my-mcp-server': {
|
||||||
);
|
command: 'mcp-server-cmd',
|
||||||
|
args: ['--port', '1234'],
|
||||||
await toolRegistry.discoverTools();
|
trust: true,
|
||||||
expect(mockDiscoverMcpTools).toHaveBeenCalledWith(
|
|
||||||
{},
|
|
||||||
'mcp-server-start-command --param',
|
|
||||||
toolRegistry,
|
|
||||||
);
|
|
||||||
});
|
|
||||||
|
|
||||||
it('should handle errors during MCP client connection gracefully and close transport', async () => {
|
|
||||||
mockConfigGetToolDiscoveryCommand.mockReturnValue(undefined);
|
|
||||||
mockConfigGetMcpServers.mockReturnValue({
|
|
||||||
'failing-mcp': { command: 'fail-cmd' } as MCPServerConfig,
|
|
||||||
});
|
|
||||||
|
|
||||||
mockMcpClientConnect.mockRejectedValue(new Error('Connection failed'));
|
|
||||||
|
|
||||||
await toolRegistry.discoverTools();
|
|
||||||
expect(mockDiscoverMcpTools).toHaveBeenCalledWith(
|
|
||||||
{
|
|
||||||
'failing-mcp': { command: 'fail-cmd' },
|
|
||||||
},
|
},
|
||||||
|
};
|
||||||
|
vi.spyOn(config, 'getMcpServers').mockReturnValue(mcpServerConfigVal);
|
||||||
|
|
||||||
|
await toolRegistry.discoverTools();
|
||||||
|
|
||||||
|
expect(mockDiscoverMcpTools).toHaveBeenCalledWith(
|
||||||
|
mcpServerConfigVal,
|
||||||
undefined,
|
undefined,
|
||||||
toolRegistry,
|
toolRegistry,
|
||||||
);
|
);
|
||||||
expect(toolRegistry.getAllTools()).toHaveLength(0);
|
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
// Other tests for DiscoveredTool and DiscoveredMCPTool can be simplified or removed
|
});
|
||||||
// if their core logic is now tested in their respective dedicated test files (mcp-tool.test.ts)
|
|
||||||
|
describe('sanitizeParameters', () => {
|
||||||
|
it('should remove unsupported format from a simple string property', () => {
|
||||||
|
const schema: Schema = {
|
||||||
|
type: Type.OBJECT,
|
||||||
|
properties: {
|
||||||
|
name: { type: Type.STRING },
|
||||||
|
id: { type: Type.STRING, format: 'uuid' },
|
||||||
|
},
|
||||||
|
};
|
||||||
|
sanitizeParameters(schema);
|
||||||
|
expect(schema.properties?.['id']).toHaveProperty('format', undefined);
|
||||||
|
expect(schema.properties?.['name']).not.toHaveProperty('format');
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should NOT remove supported format values', () => {
|
||||||
|
const schema: Schema = {
|
||||||
|
type: Type.OBJECT,
|
||||||
|
properties: {
|
||||||
|
date: { type: Type.STRING, format: 'date-time' },
|
||||||
|
role: {
|
||||||
|
type: Type.STRING,
|
||||||
|
format: 'enum',
|
||||||
|
enum: ['admin', 'user'],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
};
|
||||||
|
const originalSchema = JSON.parse(JSON.stringify(schema));
|
||||||
|
sanitizeParameters(schema);
|
||||||
|
expect(schema).toEqual(originalSchema);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should handle nested objects recursively', () => {
|
||||||
|
const schema: Schema = {
|
||||||
|
type: Type.OBJECT,
|
||||||
|
properties: {
|
||||||
|
user: {
|
||||||
|
type: Type.OBJECT,
|
||||||
|
properties: {
|
||||||
|
email: { type: Type.STRING, format: 'email' },
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
};
|
||||||
|
sanitizeParameters(schema);
|
||||||
|
expect(schema.properties?.['user']?.properties?.['email']).toHaveProperty(
|
||||||
|
'format',
|
||||||
|
undefined,
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should handle arrays of objects', () => {
|
||||||
|
const schema: Schema = {
|
||||||
|
type: Type.OBJECT,
|
||||||
|
properties: {
|
||||||
|
items: {
|
||||||
|
type: Type.ARRAY,
|
||||||
|
items: {
|
||||||
|
type: Type.OBJECT,
|
||||||
|
properties: {
|
||||||
|
itemId: { type: Type.STRING, format: 'uuid' },
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
};
|
||||||
|
sanitizeParameters(schema);
|
||||||
|
expect(
|
||||||
|
(schema.properties?.['items']?.items as Schema)?.properties?.['itemId'],
|
||||||
|
).toHaveProperty('format', undefined);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should handle schemas with no properties to sanitize', () => {
|
||||||
|
const schema: Schema = {
|
||||||
|
type: Type.OBJECT,
|
||||||
|
properties: {
|
||||||
|
count: { type: Type.NUMBER },
|
||||||
|
isActive: { type: Type.BOOLEAN },
|
||||||
|
},
|
||||||
|
};
|
||||||
|
const originalSchema = JSON.parse(JSON.stringify(schema));
|
||||||
|
sanitizeParameters(schema);
|
||||||
|
expect(schema).toEqual(originalSchema);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should not crash on an empty or undefined schema', () => {
|
||||||
|
expect(() => sanitizeParameters({})).not.toThrow();
|
||||||
|
expect(() => sanitizeParameters(undefined)).not.toThrow();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should handle cyclic schemas without crashing', () => {
|
||||||
|
const schema: any = {
|
||||||
|
type: Type.OBJECT,
|
||||||
|
properties: {
|
||||||
|
name: { type: Type.STRING, format: 'hostname' },
|
||||||
|
},
|
||||||
|
};
|
||||||
|
schema.properties.self = schema;
|
||||||
|
|
||||||
|
expect(() => sanitizeParameters(schema)).not.toThrow();
|
||||||
|
expect(schema.properties.name).toHaveProperty('format', undefined);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should handle complex nested schemas with cycles', () => {
|
||||||
|
const userNode: any = {
|
||||||
|
type: Type.OBJECT,
|
||||||
|
properties: {
|
||||||
|
id: { type: Type.STRING, format: 'uuid' },
|
||||||
|
name: { type: Type.STRING },
|
||||||
|
manager: {
|
||||||
|
type: Type.OBJECT,
|
||||||
|
properties: {
|
||||||
|
id: { type: Type.STRING, format: 'uuid' },
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
};
|
||||||
|
userNode.properties.reports = {
|
||||||
|
type: Type.ARRAY,
|
||||||
|
items: userNode,
|
||||||
|
};
|
||||||
|
|
||||||
|
const schema: Schema = {
|
||||||
|
type: Type.OBJECT,
|
||||||
|
properties: {
|
||||||
|
ceo: userNode,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
expect(() => sanitizeParameters(schema)).not.toThrow();
|
||||||
|
expect(schema.properties?.['ceo']?.properties?.['id']).toHaveProperty(
|
||||||
|
'format',
|
||||||
|
undefined,
|
||||||
|
);
|
||||||
|
expect(
|
||||||
|
schema.properties?.['ceo']?.properties?.['manager']?.properties?.['id'],
|
||||||
|
).toHaveProperty('format', undefined);
|
||||||
|
});
|
||||||
});
|
});
|
||||||
|
|
|
@ -4,12 +4,14 @@
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
*/
|
*/
|
||||||
|
|
||||||
import { FunctionDeclaration } from '@google/genai';
|
import { FunctionDeclaration, Schema, Type } from '@google/genai';
|
||||||
import { Tool, ToolResult, BaseTool } from './tools.js';
|
import { Tool, ToolResult, BaseTool } from './tools.js';
|
||||||
import { Config } from '../config/config.js';
|
import { Config } from '../config/config.js';
|
||||||
import { spawn, execSync } from 'node:child_process';
|
import { spawn } from 'node:child_process';
|
||||||
|
import { StringDecoder } from 'node:string_decoder';
|
||||||
import { discoverMcpTools } from './mcp-client.js';
|
import { discoverMcpTools } from './mcp-client.js';
|
||||||
import { DiscoveredMCPTool } from './mcp-tool.js';
|
import { DiscoveredMCPTool } from './mcp-tool.js';
|
||||||
|
import { parse } from 'shell-quote';
|
||||||
|
|
||||||
type ToolParams = Record<string, unknown>;
|
type ToolParams = Record<string, unknown>;
|
||||||
|
|
||||||
|
@ -157,32 +159,9 @@ export class ToolRegistry {
|
||||||
// Keep manually registered tools
|
// Keep manually registered tools
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// discover tools using discovery command, if configured
|
|
||||||
const discoveryCmd = this.config.getToolDiscoveryCommand();
|
await this.discoverAndRegisterToolsFromCommand();
|
||||||
if (discoveryCmd) {
|
|
||||||
// execute discovery command and extract function declarations (w/ or w/o "tool" wrappers)
|
|
||||||
const functions: FunctionDeclaration[] = [];
|
|
||||||
for (const tool of JSON.parse(execSync(discoveryCmd).toString().trim())) {
|
|
||||||
if (tool['function_declarations']) {
|
|
||||||
functions.push(...tool['function_declarations']);
|
|
||||||
} else if (tool['functionDeclarations']) {
|
|
||||||
functions.push(...tool['functionDeclarations']);
|
|
||||||
} else if (tool['name']) {
|
|
||||||
functions.push(tool);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// register each function as a tool
|
|
||||||
for (const func of functions) {
|
|
||||||
this.registerTool(
|
|
||||||
new DiscoveredTool(
|
|
||||||
this.config,
|
|
||||||
func.name!,
|
|
||||||
func.description!,
|
|
||||||
func.parameters! as Record<string, unknown>,
|
|
||||||
),
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// discover tools using MCP servers, if configured
|
// discover tools using MCP servers, if configured
|
||||||
await discoverMcpTools(
|
await discoverMcpTools(
|
||||||
this.config.getMcpServers() ?? {},
|
this.config.getMcpServers() ?? {},
|
||||||
|
@ -191,6 +170,128 @@ export class ToolRegistry {
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private async discoverAndRegisterToolsFromCommand(): Promise<void> {
|
||||||
|
const discoveryCmd = this.config.getToolDiscoveryCommand();
|
||||||
|
if (!discoveryCmd) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
const cmdParts = parse(discoveryCmd);
|
||||||
|
if (cmdParts.length === 0) {
|
||||||
|
throw new Error(
|
||||||
|
'Tool discovery command is empty or contains only whitespace.',
|
||||||
|
);
|
||||||
|
}
|
||||||
|
const proc = spawn(cmdParts[0] as string, cmdParts.slice(1) as string[]);
|
||||||
|
let stdout = '';
|
||||||
|
const stdoutDecoder = new StringDecoder('utf8');
|
||||||
|
let stderr = '';
|
||||||
|
const stderrDecoder = new StringDecoder('utf8');
|
||||||
|
let sizeLimitExceeded = false;
|
||||||
|
const MAX_STDOUT_SIZE = 10 * 1024 * 1024; // 10MB limit
|
||||||
|
const MAX_STDERR_SIZE = 10 * 1024 * 1024; // 10MB limit
|
||||||
|
|
||||||
|
let stdoutByteLength = 0;
|
||||||
|
let stderrByteLength = 0;
|
||||||
|
|
||||||
|
proc.stdout.on('data', (data) => {
|
||||||
|
if (sizeLimitExceeded) return;
|
||||||
|
if (stdoutByteLength + data.length > MAX_STDOUT_SIZE) {
|
||||||
|
sizeLimitExceeded = true;
|
||||||
|
proc.kill();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
stdoutByteLength += data.length;
|
||||||
|
stdout += stdoutDecoder.write(data);
|
||||||
|
});
|
||||||
|
|
||||||
|
proc.stderr.on('data', (data) => {
|
||||||
|
if (sizeLimitExceeded) return;
|
||||||
|
if (stderrByteLength + data.length > MAX_STDERR_SIZE) {
|
||||||
|
sizeLimitExceeded = true;
|
||||||
|
proc.kill();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
stderrByteLength += data.length;
|
||||||
|
stderr += stderrDecoder.write(data);
|
||||||
|
});
|
||||||
|
|
||||||
|
await new Promise<void>((resolve, reject) => {
|
||||||
|
proc.on('error', reject);
|
||||||
|
proc.on('close', (code) => {
|
||||||
|
stdout += stdoutDecoder.end();
|
||||||
|
stderr += stderrDecoder.end();
|
||||||
|
|
||||||
|
if (sizeLimitExceeded) {
|
||||||
|
return reject(
|
||||||
|
new Error(
|
||||||
|
`Tool discovery command output exceeded size limit of ${MAX_STDOUT_SIZE} bytes.`,
|
||||||
|
),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (code !== 0) {
|
||||||
|
console.error(`Command failed with code ${code}`);
|
||||||
|
console.error(stderr);
|
||||||
|
return reject(
|
||||||
|
new Error(`Tool discovery command failed with exit code ${code}`),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
resolve();
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
// execute discovery command and extract function declarations (w/ or w/o "tool" wrappers)
|
||||||
|
const functions: FunctionDeclaration[] = [];
|
||||||
|
const discoveredItems = JSON.parse(stdout.trim());
|
||||||
|
|
||||||
|
if (!discoveredItems || !Array.isArray(discoveredItems)) {
|
||||||
|
throw new Error(
|
||||||
|
'Tool discovery command did not return a JSON array of tools.',
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (const tool of discoveredItems) {
|
||||||
|
if (tool && typeof tool === 'object') {
|
||||||
|
if (Array.isArray(tool['function_declarations'])) {
|
||||||
|
functions.push(...tool['function_declarations']);
|
||||||
|
} else if (Array.isArray(tool['functionDeclarations'])) {
|
||||||
|
functions.push(...tool['functionDeclarations']);
|
||||||
|
} else if (tool['name']) {
|
||||||
|
functions.push(tool as FunctionDeclaration);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// register each function as a tool
|
||||||
|
for (const func of functions) {
|
||||||
|
if (!func.name) {
|
||||||
|
console.warn('Discovered a tool with no name. Skipping.');
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
// Sanitize the parameters before registering the tool.
|
||||||
|
const parameters =
|
||||||
|
func.parameters &&
|
||||||
|
typeof func.parameters === 'object' &&
|
||||||
|
!Array.isArray(func.parameters)
|
||||||
|
? (func.parameters as Schema)
|
||||||
|
: {};
|
||||||
|
sanitizeParameters(parameters);
|
||||||
|
this.registerTool(
|
||||||
|
new DiscoveredTool(
|
||||||
|
this.config,
|
||||||
|
func.name,
|
||||||
|
func.description ?? '',
|
||||||
|
parameters as Record<string, unknown>,
|
||||||
|
),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
} catch (e) {
|
||||||
|
console.error(`Tool discovery command "${discoveryCmd}" failed:`, e);
|
||||||
|
throw e;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Retrieves the list of tool schemas (FunctionDeclaration array).
|
* Retrieves the list of tool schemas (FunctionDeclaration array).
|
||||||
* Extracts the declarations from the ToolListUnion structure.
|
* Extracts the declarations from the ToolListUnion structure.
|
||||||
|
@ -232,3 +333,62 @@ export class ToolRegistry {
|
||||||
return this.tools.get(name);
|
return this.tools.get(name);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Sanitizes a schema object in-place to ensure compatibility with the Gemini API.
|
||||||
|
*
|
||||||
|
* NOTE: This function mutates the passed schema object.
|
||||||
|
*
|
||||||
|
* It performs the following actions:
|
||||||
|
* - Removes the `default` property when `anyOf` is present.
|
||||||
|
* - Removes unsupported `format` values from string properties, keeping only 'enum' and 'date-time'.
|
||||||
|
* - Recursively sanitizes nested schemas within `anyOf`, `items`, and `properties`.
|
||||||
|
* - Handles circular references within the schema to prevent infinite loops.
|
||||||
|
*
|
||||||
|
* @param schema The schema object to sanitize. It will be modified directly.
|
||||||
|
*/
|
||||||
|
export function sanitizeParameters(schema?: Schema) {
|
||||||
|
_sanitizeParameters(schema, new Set<Schema>());
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Internal recursive implementation for sanitizeParameters.
|
||||||
|
* @param schema The schema object to sanitize.
|
||||||
|
* @param visited A set used to track visited schema objects during recursion.
|
||||||
|
*/
|
||||||
|
function _sanitizeParameters(schema: Schema | undefined, visited: Set<Schema>) {
|
||||||
|
if (!schema || visited.has(schema)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
visited.add(schema);
|
||||||
|
|
||||||
|
if (schema.anyOf) {
|
||||||
|
// Vertex AI gets confused if both anyOf and default are set.
|
||||||
|
schema.default = undefined;
|
||||||
|
for (const item of schema.anyOf) {
|
||||||
|
if (typeof item !== 'boolean') {
|
||||||
|
_sanitizeParameters(item, visited);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (schema.items && typeof schema.items !== 'boolean') {
|
||||||
|
_sanitizeParameters(schema.items, visited);
|
||||||
|
}
|
||||||
|
if (schema.properties) {
|
||||||
|
for (const item of Object.values(schema.properties)) {
|
||||||
|
if (typeof item !== 'boolean') {
|
||||||
|
_sanitizeParameters(item, visited);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Vertex AI only supports 'enum' and 'date-time' for STRING format.
|
||||||
|
if (schema.type === Type.STRING) {
|
||||||
|
if (
|
||||||
|
schema.format &&
|
||||||
|
schema.format !== 'enum' &&
|
||||||
|
schema.format !== 'date-time'
|
||||||
|
) {
|
||||||
|
schema.format = undefined;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue