check for the prompt capability before listing prompts from MCP servers (#5616)
Co-authored-by: Jacob Richman <jacob314@gmail.com> Co-authored-by: Sandy Tao <sandytao520@icloud.com>
This commit is contained in:
parent
aebe3ace3c
commit
6a72cd064b
|
@ -58,9 +58,7 @@ describe('mcp-client', () => {
|
||||||
const mockedClient = {} as unknown as ClientLib.Client;
|
const mockedClient = {} as unknown as ClientLib.Client;
|
||||||
const consoleErrorSpy = vi
|
const consoleErrorSpy = vi
|
||||||
.spyOn(console, 'error')
|
.spyOn(console, 'error')
|
||||||
.mockImplementation(() => {
|
.mockImplementation(() => {});
|
||||||
// no-op
|
|
||||||
});
|
|
||||||
|
|
||||||
const testError = new Error('Invalid tool name');
|
const testError = new Error('Invalid tool name');
|
||||||
vi.mocked(DiscoveredMCPTool).mockImplementation(
|
vi.mocked(DiscoveredMCPTool).mockImplementation(
|
||||||
|
@ -113,12 +111,17 @@ describe('mcp-client', () => {
|
||||||
{ name: 'prompt2' },
|
{ name: 'prompt2' },
|
||||||
],
|
],
|
||||||
});
|
});
|
||||||
|
const mockGetServerCapabilities = vi.fn().mockReturnValue({
|
||||||
|
prompts: {},
|
||||||
|
});
|
||||||
const mockedClient = {
|
const mockedClient = {
|
||||||
|
getServerCapabilities: mockGetServerCapabilities,
|
||||||
request: mockRequest,
|
request: mockRequest,
|
||||||
} as unknown as ClientLib.Client;
|
} as unknown as ClientLib.Client;
|
||||||
|
|
||||||
await discoverPrompts('test-server', mockedClient, mockedPromptRegistry);
|
await discoverPrompts('test-server', mockedClient, mockedPromptRegistry);
|
||||||
|
|
||||||
|
expect(mockGetServerCapabilities).toHaveBeenCalledOnce();
|
||||||
expect(mockRequest).toHaveBeenCalledWith(
|
expect(mockRequest).toHaveBeenCalledWith(
|
||||||
{ method: 'prompts/list', params: {} },
|
{ method: 'prompts/list', params: {} },
|
||||||
expect.anything(),
|
expect.anything(),
|
||||||
|
@ -129,37 +132,67 @@ describe('mcp-client', () => {
|
||||||
const mockRequest = vi.fn().mockResolvedValue({
|
const mockRequest = vi.fn().mockResolvedValue({
|
||||||
prompts: [],
|
prompts: [],
|
||||||
});
|
});
|
||||||
|
const mockGetServerCapabilities = vi.fn().mockReturnValue({
|
||||||
|
prompts: {},
|
||||||
|
});
|
||||||
|
|
||||||
const mockedClient = {
|
const mockedClient = {
|
||||||
|
getServerCapabilities: mockGetServerCapabilities,
|
||||||
request: mockRequest,
|
request: mockRequest,
|
||||||
} as unknown as ClientLib.Client;
|
} as unknown as ClientLib.Client;
|
||||||
|
|
||||||
const consoleLogSpy = vi
|
const consoleLogSpy = vi
|
||||||
.spyOn(console, 'debug')
|
.spyOn(console, 'debug')
|
||||||
.mockImplementation(() => {
|
.mockImplementation(() => {});
|
||||||
// no-op
|
|
||||||
});
|
|
||||||
|
|
||||||
await discoverPrompts('test-server', mockedClient, mockedPromptRegistry);
|
await discoverPrompts('test-server', mockedClient, mockedPromptRegistry);
|
||||||
|
|
||||||
|
expect(mockGetServerCapabilities).toHaveBeenCalledOnce();
|
||||||
expect(mockRequest).toHaveBeenCalledOnce();
|
expect(mockRequest).toHaveBeenCalledOnce();
|
||||||
expect(consoleLogSpy).not.toHaveBeenCalled();
|
expect(consoleLogSpy).not.toHaveBeenCalled();
|
||||||
|
|
||||||
consoleLogSpy.mockRestore();
|
consoleLogSpy.mockRestore();
|
||||||
});
|
});
|
||||||
|
|
||||||
|
it('should do nothing if the server has no prompt support', async () => {
|
||||||
|
const mockRequest = vi.fn().mockResolvedValue({
|
||||||
|
prompts: [],
|
||||||
|
});
|
||||||
|
const mockGetServerCapabilities = vi.fn().mockReturnValue({});
|
||||||
|
|
||||||
|
const mockedClient = {
|
||||||
|
getServerCapabilities: mockGetServerCapabilities,
|
||||||
|
request: mockRequest,
|
||||||
|
} as unknown as ClientLib.Client;
|
||||||
|
|
||||||
|
const consoleLogSpy = vi
|
||||||
|
.spyOn(console, 'debug')
|
||||||
|
.mockImplementation(() => {});
|
||||||
|
|
||||||
|
await discoverPrompts('test-server', mockedClient, mockedPromptRegistry);
|
||||||
|
|
||||||
|
expect(mockGetServerCapabilities).toHaveBeenCalledOnce();
|
||||||
|
expect(mockRequest).not.toHaveBeenCalled();
|
||||||
|
expect(consoleLogSpy).not.toHaveBeenCalled();
|
||||||
|
|
||||||
|
consoleLogSpy.mockRestore();
|
||||||
|
});
|
||||||
|
|
||||||
it('should log an error if discovery fails', async () => {
|
it('should log an error if discovery fails', async () => {
|
||||||
const testError = new Error('test error');
|
const testError = new Error('test error');
|
||||||
testError.message = 'test error';
|
testError.message = 'test error';
|
||||||
const mockRequest = vi.fn().mockRejectedValue(testError);
|
const mockRequest = vi.fn().mockRejectedValue(testError);
|
||||||
|
const mockGetServerCapabilities = vi.fn().mockReturnValue({
|
||||||
|
prompts: {},
|
||||||
|
});
|
||||||
const mockedClient = {
|
const mockedClient = {
|
||||||
|
getServerCapabilities: mockGetServerCapabilities,
|
||||||
request: mockRequest,
|
request: mockRequest,
|
||||||
} as unknown as ClientLib.Client;
|
} as unknown as ClientLib.Client;
|
||||||
|
|
||||||
const consoleErrorSpy = vi
|
const consoleErrorSpy = vi
|
||||||
.spyOn(console, 'error')
|
.spyOn(console, 'error')
|
||||||
.mockImplementation(() => {
|
.mockImplementation(() => {});
|
||||||
// no-op
|
|
||||||
});
|
|
||||||
|
|
||||||
await discoverPrompts('test-server', mockedClient, mockedPromptRegistry);
|
await discoverPrompts('test-server', mockedClient, mockedPromptRegistry);
|
||||||
|
|
||||||
|
|
|
@ -496,6 +496,9 @@ export async function discoverPrompts(
|
||||||
promptRegistry: PromptRegistry,
|
promptRegistry: PromptRegistry,
|
||||||
): Promise<Prompt[]> {
|
): Promise<Prompt[]> {
|
||||||
try {
|
try {
|
||||||
|
// Only request prompts if the server supports them.
|
||||||
|
if (mcpClient.getServerCapabilities()?.prompts == null) return [];
|
||||||
|
|
||||||
const response = await mcpClient.request(
|
const response = await mcpClient.request(
|
||||||
{ method: 'prompts/list', params: {} },
|
{ method: 'prompts/list', params: {} },
|
||||||
ListPromptsResultSchema,
|
ListPromptsResultSchema,
|
||||||
|
|
Loading…
Reference in New Issue