feat(core): Continue declarative tool migration. (#6114)

This commit is contained in:
joshualitt 2025-08-13 11:57:37 -07:00 committed by GitHub
parent 22109db320
commit 904f4623b6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 623 additions and 697 deletions

View File

@ -74,9 +74,11 @@ describe('LSTool', () => {
const params = { const params = {
path: '/home/user/project/src', path: '/home/user/project/src',
}; };
vi.mocked(fs.statSync).mockReturnValue({
const error = lsTool.validateToolParams(params); isDirectory: () => true,
expect(error).toBeNull(); } as fs.Stats);
const invocation = lsTool.build(params);
expect(invocation).toBeDefined();
}); });
it('should reject relative paths', () => { it('should reject relative paths', () => {
@ -84,8 +86,9 @@ describe('LSTool', () => {
path: './src', path: './src',
}; };
const error = lsTool.validateToolParams(params); expect(() => lsTool.build(params)).toThrow(
expect(error).toBe('Path must be absolute: ./src'); 'Path must be absolute: ./src',
);
}); });
it('should reject paths outside workspace with clear error message', () => { it('should reject paths outside workspace with clear error message', () => {
@ -93,8 +96,7 @@ describe('LSTool', () => {
path: '/etc/passwd', path: '/etc/passwd',
}; };
const error = lsTool.validateToolParams(params); expect(() => lsTool.build(params)).toThrow(
expect(error).toBe(
'Path must be within one of the workspace directories: /home/user/project, /home/user/other-project', 'Path must be within one of the workspace directories: /home/user/project, /home/user/other-project',
); );
}); });
@ -103,9 +105,11 @@ describe('LSTool', () => {
const params = { const params = {
path: '/home/user/other-project/lib', path: '/home/user/other-project/lib',
}; };
vi.mocked(fs.statSync).mockReturnValue({
const error = lsTool.validateToolParams(params); isDirectory: () => true,
expect(error).toBeNull(); } as fs.Stats);
const invocation = lsTool.build(params);
expect(invocation).toBeDefined();
}); });
}); });
@ -133,10 +137,8 @@ describe('LSTool', () => {
vi.mocked(fs.readdirSync).mockReturnValue(mockFiles as any); vi.mocked(fs.readdirSync).mockReturnValue(mockFiles as any);
const result = await lsTool.execute( const invocation = lsTool.build({ path: testPath });
{ path: testPath }, const result = await invocation.execute(new AbortController().signal);
new AbortController().signal,
);
expect(result.llmContent).toContain('[DIR] subdir'); expect(result.llmContent).toContain('[DIR] subdir');
expect(result.llmContent).toContain('file1.ts'); expect(result.llmContent).toContain('file1.ts');
@ -161,10 +163,8 @@ describe('LSTool', () => {
vi.mocked(fs.readdirSync).mockReturnValue(mockFiles as any); vi.mocked(fs.readdirSync).mockReturnValue(mockFiles as any);
const result = await lsTool.execute( const invocation = lsTool.build({ path: testPath });
{ path: testPath }, const result = await invocation.execute(new AbortController().signal);
new AbortController().signal,
);
expect(result.llmContent).toContain('module1.js'); expect(result.llmContent).toContain('module1.js');
expect(result.llmContent).toContain('module2.js'); expect(result.llmContent).toContain('module2.js');
@ -179,10 +179,8 @@ describe('LSTool', () => {
} as fs.Stats); } as fs.Stats);
vi.mocked(fs.readdirSync).mockReturnValue([]); vi.mocked(fs.readdirSync).mockReturnValue([]);
const result = await lsTool.execute( const invocation = lsTool.build({ path: testPath });
{ path: testPath }, const result = await invocation.execute(new AbortController().signal);
new AbortController().signal,
);
expect(result.llmContent).toBe( expect(result.llmContent).toBe(
'Directory /home/user/project/empty is empty.', 'Directory /home/user/project/empty is empty.',
@ -207,10 +205,11 @@ describe('LSTool', () => {
}); });
vi.mocked(fs.readdirSync).mockReturnValue(mockFiles as any); vi.mocked(fs.readdirSync).mockReturnValue(mockFiles as any);
const result = await lsTool.execute( const invocation = lsTool.build({
{ path: testPath, ignore: ['*.spec.js'] }, path: testPath,
new AbortController().signal, ignore: ['*.spec.js'],
); });
const result = await invocation.execute(new AbortController().signal);
expect(result.llmContent).toContain('test.js'); expect(result.llmContent).toContain('test.js');
expect(result.llmContent).toContain('index.js'); expect(result.llmContent).toContain('index.js');
@ -238,10 +237,8 @@ describe('LSTool', () => {
(path: string) => path.includes('ignored.js'), (path: string) => path.includes('ignored.js'),
); );
const result = await lsTool.execute( const invocation = lsTool.build({ path: testPath });
{ path: testPath }, const result = await invocation.execute(new AbortController().signal);
new AbortController().signal,
);
expect(result.llmContent).toContain('file1.js'); expect(result.llmContent).toContain('file1.js');
expect(result.llmContent).toContain('file2.js'); expect(result.llmContent).toContain('file2.js');
@ -269,10 +266,8 @@ describe('LSTool', () => {
(path: string) => path.includes('private.js'), (path: string) => path.includes('private.js'),
); );
const result = await lsTool.execute( const invocation = lsTool.build({ path: testPath });
{ path: testPath }, const result = await invocation.execute(new AbortController().signal);
new AbortController().signal,
);
expect(result.llmContent).toContain('file1.js'); expect(result.llmContent).toContain('file1.js');
expect(result.llmContent).toContain('file2.js'); expect(result.llmContent).toContain('file2.js');
@ -287,10 +282,8 @@ describe('LSTool', () => {
isDirectory: () => false, isDirectory: () => false,
} as fs.Stats); } as fs.Stats);
const result = await lsTool.execute( const invocation = lsTool.build({ path: testPath });
{ path: testPath }, const result = await invocation.execute(new AbortController().signal);
new AbortController().signal,
);
expect(result.llmContent).toContain('Path is not a directory'); expect(result.llmContent).toContain('Path is not a directory');
expect(result.returnDisplay).toBe('Error: Path is not a directory.'); expect(result.returnDisplay).toBe('Error: Path is not a directory.');
@ -303,10 +296,8 @@ describe('LSTool', () => {
throw new Error('ENOENT: no such file or directory'); throw new Error('ENOENT: no such file or directory');
}); });
const result = await lsTool.execute( const invocation = lsTool.build({ path: testPath });
{ path: testPath }, const result = await invocation.execute(new AbortController().signal);
new AbortController().signal,
);
expect(result.llmContent).toContain('Error listing directory'); expect(result.llmContent).toContain('Error listing directory');
expect(result.returnDisplay).toBe('Error: Failed to list directory.'); expect(result.returnDisplay).toBe('Error: Failed to list directory.');
@ -336,10 +327,8 @@ describe('LSTool', () => {
vi.mocked(fs.readdirSync).mockReturnValue(mockFiles as any); vi.mocked(fs.readdirSync).mockReturnValue(mockFiles as any);
const result = await lsTool.execute( const invocation = lsTool.build({ path: testPath });
{ path: testPath }, const result = await invocation.execute(new AbortController().signal);
new AbortController().signal,
);
const lines = ( const lines = (
typeof result.llmContent === 'string' ? result.llmContent : '' typeof result.llmContent === 'string' ? result.llmContent : ''
@ -361,24 +350,18 @@ describe('LSTool', () => {
throw new Error('EACCES: permission denied'); throw new Error('EACCES: permission denied');
}); });
const result = await lsTool.execute( const invocation = lsTool.build({ path: testPath });
{ path: testPath }, const result = await invocation.execute(new AbortController().signal);
new AbortController().signal,
);
expect(result.llmContent).toContain('Error listing directory'); expect(result.llmContent).toContain('Error listing directory');
expect(result.llmContent).toContain('permission denied'); expect(result.llmContent).toContain('permission denied');
expect(result.returnDisplay).toBe('Error: Failed to list directory.'); expect(result.returnDisplay).toBe('Error: Failed to list directory.');
}); });
it('should validate parameters and return error for invalid params', async () => { it('should throw for invalid params at build time', async () => {
const result = await lsTool.execute( expect(() => lsTool.build({ path: '../outside' })).toThrow(
{ path: '../outside' }, 'Path must be absolute: ../outside',
new AbortController().signal,
); );
expect(result.llmContent).toContain('Invalid parameters provided');
expect(result.returnDisplay).toBe('Error: Failed to execute tool.');
}); });
it('should handle errors accessing individual files during listing', async () => { it('should handle errors accessing individual files during listing', async () => {
@ -406,10 +389,8 @@ describe('LSTool', () => {
.spyOn(console, 'error') .spyOn(console, 'error')
.mockImplementation(() => {}); .mockImplementation(() => {});
const result = await lsTool.execute( const invocation = lsTool.build({ path: testPath });
{ path: testPath }, const result = await invocation.execute(new AbortController().signal);
new AbortController().signal,
);
// Should still list the accessible file // Should still list the accessible file
expect(result.llmContent).toContain('accessible.ts'); expect(result.llmContent).toContain('accessible.ts');
@ -428,19 +409,25 @@ describe('LSTool', () => {
describe('getDescription', () => { describe('getDescription', () => {
it('should return shortened relative path', () => { it('should return shortened relative path', () => {
const params = { const params = {
path: path.join(mockPrimaryDir, 'deeply', 'nested', 'directory'), path: `${mockPrimaryDir}/deeply/nested/directory`,
}; };
vi.mocked(fs.statSync).mockReturnValue({
const description = lsTool.getDescription(params); isDirectory: () => true,
} as fs.Stats);
const invocation = lsTool.build(params);
const description = invocation.getDescription();
expect(description).toBe(path.join('deeply', 'nested', 'directory')); expect(description).toBe(path.join('deeply', 'nested', 'directory'));
}); });
it('should handle paths in secondary workspace', () => { it('should handle paths in secondary workspace', () => {
const params = { const params = {
path: path.join(mockSecondaryDir, 'lib'), path: `${mockSecondaryDir}/lib`,
}; };
vi.mocked(fs.statSync).mockReturnValue({
const description = lsTool.getDescription(params); isDirectory: () => true,
} as fs.Stats);
const invocation = lsTool.build(params);
const description = invocation.getDescription();
expect(description).toBe(path.join('..', 'other-project', 'lib')); expect(description).toBe(path.join('..', 'other-project', 'lib'));
}); });
}); });
@ -448,22 +435,25 @@ describe('LSTool', () => {
describe('workspace boundary validation', () => { describe('workspace boundary validation', () => {
it('should accept paths in primary workspace directory', () => { it('should accept paths in primary workspace directory', () => {
const params = { path: `${mockPrimaryDir}/src` }; const params = { path: `${mockPrimaryDir}/src` };
expect(lsTool.validateToolParams(params)).toBeNull(); vi.mocked(fs.statSync).mockReturnValue({
isDirectory: () => true,
} as fs.Stats);
expect(lsTool.build(params)).toBeDefined();
}); });
it('should accept paths in secondary workspace directory', () => { it('should accept paths in secondary workspace directory', () => {
const params = { path: `${mockSecondaryDir}/lib` }; const params = { path: `${mockSecondaryDir}/lib` };
expect(lsTool.validateToolParams(params)).toBeNull(); vi.mocked(fs.statSync).mockReturnValue({
isDirectory: () => true,
} as fs.Stats);
expect(lsTool.build(params)).toBeDefined();
}); });
it('should reject paths outside all workspace directories', () => { it('should reject paths outside all workspace directories', () => {
const params = { path: '/etc/passwd' }; const params = { path: '/etc/passwd' };
const error = lsTool.validateToolParams(params); expect(() => lsTool.build(params)).toThrow(
expect(error).toContain(
'Path must be within one of the workspace directories', 'Path must be within one of the workspace directories',
); );
expect(error).toContain(mockPrimaryDir);
expect(error).toContain(mockSecondaryDir);
}); });
it('should list files from secondary workspace directory', async () => { it('should list files from secondary workspace directory', async () => {
@ -483,10 +473,8 @@ describe('LSTool', () => {
vi.mocked(fs.readdirSync).mockReturnValue(mockFiles as any); vi.mocked(fs.readdirSync).mockReturnValue(mockFiles as any);
const result = await lsTool.execute( const invocation = lsTool.build({ path: testPath });
{ path: testPath }, const result = await invocation.execute(new AbortController().signal);
new AbortController().signal,
);
expect(result.llmContent).toContain('test1.spec.ts'); expect(result.llmContent).toContain('test1.spec.ts');
expect(result.llmContent).toContain('test2.spec.ts'); expect(result.llmContent).toContain('test2.spec.ts');

View File

@ -6,7 +6,13 @@
import fs from 'fs'; import fs from 'fs';
import path from 'path'; import path from 'path';
import { BaseTool, Kind, ToolResult } from './tools.js'; import {
BaseDeclarativeTool,
BaseToolInvocation,
Kind,
ToolInvocation,
ToolResult,
} from './tools.js';
import { SchemaValidator } from '../utils/schemaValidator.js'; import { SchemaValidator } from '../utils/schemaValidator.js';
import { makeRelative, shortenPath } from '../utils/paths.js'; import { makeRelative, shortenPath } from '../utils/paths.js';
import { Config, DEFAULT_FILE_FILTERING_OPTIONS } from '../config/config.js'; import { Config, DEFAULT_FILE_FILTERING_OPTIONS } from '../config/config.js';
@ -64,10 +70,199 @@ export interface FileEntry {
modifiedTime: Date; modifiedTime: Date;
} }
class LSToolInvocation extends BaseToolInvocation<LSToolParams, ToolResult> {
constructor(
private readonly config: Config,
params: LSToolParams,
) {
super(params);
}
/**
* Checks if a filename matches any of the ignore patterns
* @param filename Filename to check
* @param patterns Array of glob patterns to check against
* @returns True if the filename should be ignored
*/
private shouldIgnore(filename: string, patterns?: string[]): boolean {
if (!patterns || patterns.length === 0) {
return false;
}
for (const pattern of patterns) {
// Convert glob pattern to RegExp
const regexPattern = pattern
.replace(/[.+^${}()|[\]\\]/g, '\\$&')
.replace(/\*/g, '.*')
.replace(/\?/g, '.');
const regex = new RegExp(`^${regexPattern}$`);
if (regex.test(filename)) {
return true;
}
}
return false;
}
/**
* Gets a description of the file reading operation
* @returns A string describing the file being read
*/
getDescription(): string {
const relativePath = makeRelative(
this.params.path,
this.config.getTargetDir(),
);
return shortenPath(relativePath);
}
// Helper for consistent error formatting
private errorResult(llmContent: string, returnDisplay: string): ToolResult {
return {
llmContent,
// Keep returnDisplay simpler in core logic
returnDisplay: `Error: ${returnDisplay}`,
};
}
/**
* Executes the LS operation with the given parameters
* @returns Result of the LS operation
*/
async execute(_signal: AbortSignal): Promise<ToolResult> {
try {
const stats = fs.statSync(this.params.path);
if (!stats) {
// fs.statSync throws on non-existence, so this check might be redundant
// but keeping for clarity. Error message adjusted.
return this.errorResult(
`Error: Directory not found or inaccessible: ${this.params.path}`,
`Directory not found or inaccessible.`,
);
}
if (!stats.isDirectory()) {
return this.errorResult(
`Error: Path is not a directory: ${this.params.path}`,
`Path is not a directory.`,
);
}
const files = fs.readdirSync(this.params.path);
const defaultFileIgnores =
this.config.getFileFilteringOptions() ?? DEFAULT_FILE_FILTERING_OPTIONS;
const fileFilteringOptions = {
respectGitIgnore:
this.params.file_filtering_options?.respect_git_ignore ??
defaultFileIgnores.respectGitIgnore,
respectGeminiIgnore:
this.params.file_filtering_options?.respect_gemini_ignore ??
defaultFileIgnores.respectGeminiIgnore,
};
// Get centralized file discovery service
const fileDiscovery = this.config.getFileService();
const entries: FileEntry[] = [];
let gitIgnoredCount = 0;
let geminiIgnoredCount = 0;
if (files.length === 0) {
// Changed error message to be more neutral for LLM
return {
llmContent: `Directory ${this.params.path} is empty.`,
returnDisplay: `Directory is empty.`,
};
}
for (const file of files) {
if (this.shouldIgnore(file, this.params.ignore)) {
continue;
}
const fullPath = path.join(this.params.path, file);
const relativePath = path.relative(
this.config.getTargetDir(),
fullPath,
);
// Check if this file should be ignored based on git or gemini ignore rules
if (
fileFilteringOptions.respectGitIgnore &&
fileDiscovery.shouldGitIgnoreFile(relativePath)
) {
gitIgnoredCount++;
continue;
}
if (
fileFilteringOptions.respectGeminiIgnore &&
fileDiscovery.shouldGeminiIgnoreFile(relativePath)
) {
geminiIgnoredCount++;
continue;
}
try {
const stats = fs.statSync(fullPath);
const isDir = stats.isDirectory();
entries.push({
name: file,
path: fullPath,
isDirectory: isDir,
size: isDir ? 0 : stats.size,
modifiedTime: stats.mtime,
});
} catch (error) {
// Log error internally but don't fail the whole listing
console.error(`Error accessing ${fullPath}: ${error}`);
}
}
// Sort entries (directories first, then alphabetically)
entries.sort((a, b) => {
if (a.isDirectory && !b.isDirectory) return -1;
if (!a.isDirectory && b.isDirectory) return 1;
return a.name.localeCompare(b.name);
});
// Create formatted content for LLM
const directoryContent = entries
.map((entry) => `${entry.isDirectory ? '[DIR] ' : ''}${entry.name}`)
.join('\n');
let resultMessage = `Directory listing for ${this.params.path}:\n${directoryContent}`;
const ignoredMessages = [];
if (gitIgnoredCount > 0) {
ignoredMessages.push(`${gitIgnoredCount} git-ignored`);
}
if (geminiIgnoredCount > 0) {
ignoredMessages.push(`${geminiIgnoredCount} gemini-ignored`);
}
if (ignoredMessages.length > 0) {
resultMessage += `\n\n(${ignoredMessages.join(', ')})`;
}
let displayMessage = `Listed ${entries.length} item(s).`;
if (ignoredMessages.length > 0) {
displayMessage += ` (${ignoredMessages.join(', ')})`;
}
return {
llmContent: resultMessage,
returnDisplay: displayMessage,
};
} catch (error) {
const errorMsg = `Error listing directory: ${error instanceof Error ? error.message : String(error)}`;
return this.errorResult(errorMsg, 'Failed to list directory.');
}
}
}
/** /**
* Implementation of the LS tool logic * Implementation of the LS tool logic
*/ */
export class LSTool extends BaseTool<LSToolParams, ToolResult> { export class LSTool extends BaseDeclarativeTool<LSToolParams, ToolResult> {
static readonly Name = 'list_directory'; static readonly Name = 'list_directory';
constructor(private config: Config) { constructor(private config: Config) {
@ -134,198 +329,16 @@ export class LSTool extends BaseTool<LSToolParams, ToolResult> {
const workspaceContext = this.config.getWorkspaceContext(); const workspaceContext = this.config.getWorkspaceContext();
if (!workspaceContext.isPathWithinWorkspace(params.path)) { if (!workspaceContext.isPathWithinWorkspace(params.path)) {
const directories = workspaceContext.getDirectories(); const directories = workspaceContext.getDirectories();
return `Path must be within one of the workspace directories: ${directories.join(', ')}`; return `Path must be within one of the workspace directories: ${directories.join(
', ',
)}`;
} }
return null; return null;
} }
/** protected createInvocation(
* Checks if a filename matches any of the ignore patterns
* @param filename Filename to check
* @param patterns Array of glob patterns to check against
* @returns True if the filename should be ignored
*/
private shouldIgnore(filename: string, patterns?: string[]): boolean {
if (!patterns || patterns.length === 0) {
return false;
}
for (const pattern of patterns) {
// Convert glob pattern to RegExp
const regexPattern = pattern
.replace(/[.+^${}()|[\]\\]/g, '\\$&')
.replace(/\*/g, '.*')
.replace(/\?/g, '.');
const regex = new RegExp(`^${regexPattern}$`);
if (regex.test(filename)) {
return true;
}
}
return false;
}
/**
* Gets a description of the file reading operation
* @param params Parameters for the file reading
* @returns A string describing the file being read
*/
getDescription(params: LSToolParams): string {
const relativePath = makeRelative(params.path, this.config.getTargetDir());
return shortenPath(relativePath);
}
// Helper for consistent error formatting
private errorResult(llmContent: string, returnDisplay: string): ToolResult {
return {
llmContent,
// Keep returnDisplay simpler in core logic
returnDisplay: `Error: ${returnDisplay}`,
};
}
/**
* Executes the LS operation with the given parameters
* @param params Parameters for the LS operation
* @returns Result of the LS operation
*/
async execute(
params: LSToolParams, params: LSToolParams,
_signal: AbortSignal, ): ToolInvocation<LSToolParams, ToolResult> {
): Promise<ToolResult> { return new LSToolInvocation(this.config, params);
const validationError = this.validateToolParams(params);
if (validationError) {
return this.errorResult(
`Error: Invalid parameters provided. Reason: ${validationError}`,
`Failed to execute tool.`,
);
}
try {
const stats = fs.statSync(params.path);
if (!stats) {
// fs.statSync throws on non-existence, so this check might be redundant
// but keeping for clarity. Error message adjusted.
return this.errorResult(
`Error: Directory not found or inaccessible: ${params.path}`,
`Directory not found or inaccessible.`,
);
}
if (!stats.isDirectory()) {
return this.errorResult(
`Error: Path is not a directory: ${params.path}`,
`Path is not a directory.`,
);
}
const files = fs.readdirSync(params.path);
const defaultFileIgnores =
this.config.getFileFilteringOptions() ?? DEFAULT_FILE_FILTERING_OPTIONS;
const fileFilteringOptions = {
respectGitIgnore:
params.file_filtering_options?.respect_git_ignore ??
defaultFileIgnores.respectGitIgnore,
respectGeminiIgnore:
params.file_filtering_options?.respect_gemini_ignore ??
defaultFileIgnores.respectGeminiIgnore,
};
// Get centralized file discovery service
const fileDiscovery = this.config.getFileService();
const entries: FileEntry[] = [];
let gitIgnoredCount = 0;
let geminiIgnoredCount = 0;
if (files.length === 0) {
// Changed error message to be more neutral for LLM
return {
llmContent: `Directory ${params.path} is empty.`,
returnDisplay: `Directory is empty.`,
};
}
for (const file of files) {
if (this.shouldIgnore(file, params.ignore)) {
continue;
}
const fullPath = path.join(params.path, file);
const relativePath = path.relative(
this.config.getTargetDir(),
fullPath,
);
// Check if this file should be ignored based on git or gemini ignore rules
if (
fileFilteringOptions.respectGitIgnore &&
fileDiscovery.shouldGitIgnoreFile(relativePath)
) {
gitIgnoredCount++;
continue;
}
if (
fileFilteringOptions.respectGeminiIgnore &&
fileDiscovery.shouldGeminiIgnoreFile(relativePath)
) {
geminiIgnoredCount++;
continue;
}
try {
const stats = fs.statSync(fullPath);
const isDir = stats.isDirectory();
entries.push({
name: file,
path: fullPath,
isDirectory: isDir,
size: isDir ? 0 : stats.size,
modifiedTime: stats.mtime,
});
} catch (error) {
// Log error internally but don't fail the whole listing
console.error(`Error accessing ${fullPath}: ${error}`);
}
}
// Sort entries (directories first, then alphabetically)
entries.sort((a, b) => {
if (a.isDirectory && !b.isDirectory) return -1;
if (!a.isDirectory && b.isDirectory) return 1;
return a.name.localeCompare(b.name);
});
// Create formatted content for LLM
const directoryContent = entries
.map((entry) => `${entry.isDirectory ? '[DIR] ' : ''}${entry.name}`)
.join('\n');
let resultMessage = `Directory listing for ${params.path}:\n${directoryContent}`;
const ignoredMessages = [];
if (gitIgnoredCount > 0) {
ignoredMessages.push(`${gitIgnoredCount} git-ignored`);
}
if (geminiIgnoredCount > 0) {
ignoredMessages.push(`${geminiIgnoredCount} gemini-ignored`);
}
if (ignoredMessages.length > 0) {
resultMessage += `\n\n(${ignoredMessages.join(', ')})`;
}
let displayMessage = `Listed ${entries.length} item(s).`;
if (ignoredMessages.length > 0) {
displayMessage += ` (${ignoredMessages.join(', ')})`;
}
return {
llmContent: resultMessage,
returnDisplay: displayMessage,
};
} catch (error) {
const errorMsg = `Error listing directory: ${error instanceof Error ? error.message : String(error)}`;
return this.errorResult(errorMsg, 'Failed to list directory.');
}
} }
} }

View File

@ -73,11 +73,21 @@ describe('DiscoveredMCPTool', () => {
required: ['param'], required: ['param'],
}; };
let tool: DiscoveredMCPTool;
beforeEach(() => { beforeEach(() => {
mockCallTool.mockClear(); mockCallTool.mockClear();
mockToolMethod.mockClear(); mockToolMethod.mockClear();
tool = new DiscoveredMCPTool(
mockCallableToolInstance,
serverName,
serverToolName,
baseDescription,
inputSchema,
);
// Clear allowlist before each relevant test, especially for shouldConfirmExecute // Clear allowlist before each relevant test, especially for shouldConfirmExecute
(DiscoveredMCPTool as any).allowlist.clear(); const invocation = tool.build({}) as any;
invocation.constructor.allowlist.clear();
}); });
afterEach(() => { afterEach(() => {
@ -86,14 +96,6 @@ describe('DiscoveredMCPTool', () => {
describe('constructor', () => { describe('constructor', () => {
it('should set properties correctly', () => { it('should set properties correctly', () => {
const tool = new DiscoveredMCPTool(
mockCallableToolInstance,
serverName,
serverToolName,
baseDescription,
inputSchema,
);
expect(tool.name).toBe(serverToolName); expect(tool.name).toBe(serverToolName);
expect(tool.schema.name).toBe(serverToolName); expect(tool.schema.name).toBe(serverToolName);
expect(tool.schema.description).toBe(baseDescription); expect(tool.schema.description).toBe(baseDescription);
@ -105,7 +107,7 @@ describe('DiscoveredMCPTool', () => {
it('should accept and store a custom timeout', () => { it('should accept and store a custom timeout', () => {
const customTimeout = 5000; const customTimeout = 5000;
const tool = new DiscoveredMCPTool( const toolWithTimeout = new DiscoveredMCPTool(
mockCallableToolInstance, mockCallableToolInstance,
serverName, serverName,
serverToolName, serverToolName,
@ -113,19 +115,12 @@ describe('DiscoveredMCPTool', () => {
inputSchema, inputSchema,
customTimeout, customTimeout,
); );
expect(tool.timeout).toBe(customTimeout); expect(toolWithTimeout.timeout).toBe(customTimeout);
}); });
}); });
describe('execute', () => { describe('execute', () => {
it('should call mcpTool.callTool with correct parameters and format display output', async () => { it('should call mcpTool.callTool with correct parameters and format display output', async () => {
const tool = new DiscoveredMCPTool(
mockCallableToolInstance,
serverName,
serverToolName,
baseDescription,
inputSchema,
);
const params = { param: 'testValue' }; const params = { param: 'testValue' };
const mockToolSuccessResultObject = { const mockToolSuccessResultObject = {
success: true, success: true,
@ -147,7 +142,10 @@ describe('DiscoveredMCPTool', () => {
]; ];
mockCallTool.mockResolvedValue(mockMcpToolResponseParts); mockCallTool.mockResolvedValue(mockMcpToolResponseParts);
const toolResult: ToolResult = await tool.execute(params); const invocation = tool.build(params);
const toolResult: ToolResult = await invocation.execute(
new AbortController().signal,
);
expect(mockCallTool).toHaveBeenCalledWith([ expect(mockCallTool).toHaveBeenCalledWith([
{ name: serverToolName, args: params }, { name: serverToolName, args: params },
@ -163,17 +161,13 @@ describe('DiscoveredMCPTool', () => {
}); });
it('should handle empty result from getStringifiedResultForDisplay', async () => { it('should handle empty result from getStringifiedResultForDisplay', async () => {
const tool = new DiscoveredMCPTool(
mockCallableToolInstance,
serverName,
serverToolName,
baseDescription,
inputSchema,
);
const params = { param: 'testValue' }; const params = { param: 'testValue' };
const mockMcpToolResponsePartsEmpty: Part[] = []; const mockMcpToolResponsePartsEmpty: Part[] = [];
mockCallTool.mockResolvedValue(mockMcpToolResponsePartsEmpty); mockCallTool.mockResolvedValue(mockMcpToolResponsePartsEmpty);
const toolResult: ToolResult = await tool.execute(params); const invocation = tool.build(params);
const toolResult: ToolResult = await invocation.execute(
new AbortController().signal,
);
expect(toolResult.returnDisplay).toBe('```json\n[]\n```'); expect(toolResult.returnDisplay).toBe('```json\n[]\n```');
expect(toolResult.llmContent).toEqual([ expect(toolResult.llmContent).toEqual([
{ text: '[Error: Could not parse tool response]' }, { text: '[Error: Could not parse tool response]' },
@ -181,28 +175,17 @@ describe('DiscoveredMCPTool', () => {
}); });
it('should propagate rejection if mcpTool.callTool rejects', async () => { it('should propagate rejection if mcpTool.callTool rejects', async () => {
const tool = new DiscoveredMCPTool(
mockCallableToolInstance,
serverName,
serverToolName,
baseDescription,
inputSchema,
);
const params = { param: 'failCase' }; const params = { param: 'failCase' };
const expectedError = new Error('MCP call failed'); const expectedError = new Error('MCP call failed');
mockCallTool.mockRejectedValue(expectedError); mockCallTool.mockRejectedValue(expectedError);
await expect(tool.execute(params)).rejects.toThrow(expectedError); const invocation = tool.build(params);
await expect(
invocation.execute(new AbortController().signal),
).rejects.toThrow(expectedError);
}); });
it('should handle a simple text response correctly', async () => { it('should handle a simple text response correctly', async () => {
const tool = new DiscoveredMCPTool(
mockCallableToolInstance,
serverName,
serverToolName,
baseDescription,
inputSchema,
);
const params = { query: 'test' }; const params = { query: 'test' };
const successMessage = 'This is a success message.'; const successMessage = 'This is a success message.';
@ -221,7 +204,8 @@ describe('DiscoveredMCPTool', () => {
]; ];
mockCallTool.mockResolvedValue(sdkResponse); mockCallTool.mockResolvedValue(sdkResponse);
const toolResult = await tool.execute(params); const invocation = tool.build(params);
const toolResult = await invocation.execute(new AbortController().signal);
// 1. Assert that the llmContent sent to the scheduler is a clean Part array. // 1. Assert that the llmContent sent to the scheduler is a clean Part array.
expect(toolResult.llmContent).toEqual([{ text: successMessage }]); expect(toolResult.llmContent).toEqual([{ text: successMessage }]);
@ -236,13 +220,6 @@ describe('DiscoveredMCPTool', () => {
}); });
it('should handle an AudioBlock response', async () => { it('should handle an AudioBlock response', async () => {
const tool = new DiscoveredMCPTool(
mockCallableToolInstance,
serverName,
serverToolName,
baseDescription,
inputSchema,
);
const params = { action: 'play' }; const params = { action: 'play' };
const sdkResponse: Part[] = [ const sdkResponse: Part[] = [
{ {
@ -262,7 +239,8 @@ describe('DiscoveredMCPTool', () => {
]; ];
mockCallTool.mockResolvedValue(sdkResponse); mockCallTool.mockResolvedValue(sdkResponse);
const toolResult = await tool.execute(params); const invocation = tool.build(params);
const toolResult = await invocation.execute(new AbortController().signal);
expect(toolResult.llmContent).toEqual([ expect(toolResult.llmContent).toEqual([
{ {
@ -279,13 +257,6 @@ describe('DiscoveredMCPTool', () => {
}); });
it('should handle a ResourceLinkBlock response', async () => { it('should handle a ResourceLinkBlock response', async () => {
const tool = new DiscoveredMCPTool(
mockCallableToolInstance,
serverName,
serverToolName,
baseDescription,
inputSchema,
);
const params = { resource: 'get' }; const params = { resource: 'get' };
const sdkResponse: Part[] = [ const sdkResponse: Part[] = [
{ {
@ -306,7 +277,8 @@ describe('DiscoveredMCPTool', () => {
]; ];
mockCallTool.mockResolvedValue(sdkResponse); mockCallTool.mockResolvedValue(sdkResponse);
const toolResult = await tool.execute(params); const invocation = tool.build(params);
const toolResult = await invocation.execute(new AbortController().signal);
expect(toolResult.llmContent).toEqual([ expect(toolResult.llmContent).toEqual([
{ {
@ -319,13 +291,6 @@ describe('DiscoveredMCPTool', () => {
}); });
it('should handle an embedded text ResourceBlock response', async () => { it('should handle an embedded text ResourceBlock response', async () => {
const tool = new DiscoveredMCPTool(
mockCallableToolInstance,
serverName,
serverToolName,
baseDescription,
inputSchema,
);
const params = { resource: 'get' }; const params = { resource: 'get' };
const sdkResponse: Part[] = [ const sdkResponse: Part[] = [
{ {
@ -348,7 +313,8 @@ describe('DiscoveredMCPTool', () => {
]; ];
mockCallTool.mockResolvedValue(sdkResponse); mockCallTool.mockResolvedValue(sdkResponse);
const toolResult = await tool.execute(params); const invocation = tool.build(params);
const toolResult = await invocation.execute(new AbortController().signal);
expect(toolResult.llmContent).toEqual([ expect(toolResult.llmContent).toEqual([
{ text: 'This is the text content.' }, { text: 'This is the text content.' },
@ -357,13 +323,6 @@ describe('DiscoveredMCPTool', () => {
}); });
it('should handle an embedded binary ResourceBlock response', async () => { it('should handle an embedded binary ResourceBlock response', async () => {
const tool = new DiscoveredMCPTool(
mockCallableToolInstance,
serverName,
serverToolName,
baseDescription,
inputSchema,
);
const params = { resource: 'get' }; const params = { resource: 'get' };
const sdkResponse: Part[] = [ const sdkResponse: Part[] = [
{ {
@ -386,7 +345,8 @@ describe('DiscoveredMCPTool', () => {
]; ];
mockCallTool.mockResolvedValue(sdkResponse); mockCallTool.mockResolvedValue(sdkResponse);
const toolResult = await tool.execute(params); const invocation = tool.build(params);
const toolResult = await invocation.execute(new AbortController().signal);
expect(toolResult.llmContent).toEqual([ expect(toolResult.llmContent).toEqual([
{ {
@ -405,13 +365,6 @@ describe('DiscoveredMCPTool', () => {
}); });
it('should handle a mix of content block types', async () => { it('should handle a mix of content block types', async () => {
const tool = new DiscoveredMCPTool(
mockCallableToolInstance,
serverName,
serverToolName,
baseDescription,
inputSchema,
);
const params = { action: 'complex' }; const params = { action: 'complex' };
const sdkResponse: Part[] = [ const sdkResponse: Part[] = [
{ {
@ -433,7 +386,8 @@ describe('DiscoveredMCPTool', () => {
]; ];
mockCallTool.mockResolvedValue(sdkResponse); mockCallTool.mockResolvedValue(sdkResponse);
const toolResult = await tool.execute(params); const invocation = tool.build(params);
const toolResult = await invocation.execute(new AbortController().signal);
expect(toolResult.llmContent).toEqual([ expect(toolResult.llmContent).toEqual([
{ text: 'First part.' }, { text: 'First part.' },
@ -454,13 +408,6 @@ describe('DiscoveredMCPTool', () => {
}); });
it('should ignore unknown content block types', async () => { it('should ignore unknown content block types', async () => {
const tool = new DiscoveredMCPTool(
mockCallableToolInstance,
serverName,
serverToolName,
baseDescription,
inputSchema,
);
const params = { action: 'test' }; const params = { action: 'test' };
const sdkResponse: Part[] = [ const sdkResponse: Part[] = [
{ {
@ -477,7 +424,8 @@ describe('DiscoveredMCPTool', () => {
]; ];
mockCallTool.mockResolvedValue(sdkResponse); mockCallTool.mockResolvedValue(sdkResponse);
const toolResult = await tool.execute(params); const invocation = tool.build(params);
const toolResult = await invocation.execute(new AbortController().signal);
expect(toolResult.llmContent).toEqual([{ text: 'Valid part.' }]); expect(toolResult.llmContent).toEqual([{ text: 'Valid part.' }]);
expect(toolResult.returnDisplay).toBe( expect(toolResult.returnDisplay).toBe(
@ -486,13 +434,6 @@ describe('DiscoveredMCPTool', () => {
}); });
it('should handle a complex mix of content block types', async () => { it('should handle a complex mix of content block types', async () => {
const tool = new DiscoveredMCPTool(
mockCallableToolInstance,
serverName,
serverToolName,
baseDescription,
inputSchema,
);
const params = { action: 'super-complex' }; const params = { action: 'super-complex' };
const sdkResponse: Part[] = [ const sdkResponse: Part[] = [
{ {
@ -527,7 +468,8 @@ describe('DiscoveredMCPTool', () => {
]; ];
mockCallTool.mockResolvedValue(sdkResponse); mockCallTool.mockResolvedValue(sdkResponse);
const toolResult = await tool.execute(params); const invocation = tool.build(params);
const toolResult = await invocation.execute(new AbortController().signal);
expect(toolResult.llmContent).toEqual([ expect(toolResult.llmContent).toEqual([
{ text: 'Here is a resource.' }, { text: 'Here is a resource.' },
@ -552,10 +494,8 @@ describe('DiscoveredMCPTool', () => {
}); });
describe('shouldConfirmExecute', () => { describe('shouldConfirmExecute', () => {
// beforeEach is already clearing allowlist
it('should return false if trust is true', async () => { it('should return false if trust is true', async () => {
const tool = new DiscoveredMCPTool( const trustedTool = new DiscoveredMCPTool(
mockCallableToolInstance, mockCallableToolInstance,
serverName, serverName,
serverToolName, serverToolName,
@ -564,50 +504,32 @@ describe('DiscoveredMCPTool', () => {
undefined, undefined,
true, true,
); );
const invocation = trustedTool.build({});
expect( expect(
await tool.shouldConfirmExecute({}, new AbortController().signal), await invocation.shouldConfirmExecute(new AbortController().signal),
).toBe(false); ).toBe(false);
}); });
it('should return false if server is allowlisted', async () => { it('should return false if server is allowlisted', async () => {
(DiscoveredMCPTool as any).allowlist.add(serverName); const invocation = tool.build({}) as any;
const tool = new DiscoveredMCPTool( invocation.constructor.allowlist.add(serverName);
mockCallableToolInstance,
serverName,
serverToolName,
baseDescription,
inputSchema,
);
expect( expect(
await tool.shouldConfirmExecute({}, new AbortController().signal), await invocation.shouldConfirmExecute(new AbortController().signal),
).toBe(false); ).toBe(false);
}); });
it('should return false if tool is allowlisted', async () => { it('should return false if tool is allowlisted', async () => {
const toolAllowlistKey = `${serverName}.${serverToolName}`; const toolAllowlistKey = `${serverName}.${serverToolName}`;
(DiscoveredMCPTool as any).allowlist.add(toolAllowlistKey); const invocation = tool.build({}) as any;
const tool = new DiscoveredMCPTool( invocation.constructor.allowlist.add(toolAllowlistKey);
mockCallableToolInstance,
serverName,
serverToolName,
baseDescription,
inputSchema,
);
expect( expect(
await tool.shouldConfirmExecute({}, new AbortController().signal), await invocation.shouldConfirmExecute(new AbortController().signal),
).toBe(false); ).toBe(false);
}); });
it('should return confirmation details if not trusted and not allowlisted', async () => { it('should return confirmation details if not trusted and not allowlisted', async () => {
const tool = new DiscoveredMCPTool( const invocation = tool.build({});
mockCallableToolInstance, const confirmation = await invocation.shouldConfirmExecute(
serverName,
serverToolName,
baseDescription,
inputSchema,
);
const confirmation = await tool.shouldConfirmExecute(
{},
new AbortController().signal, new AbortController().signal,
); );
expect(confirmation).not.toBe(false); expect(confirmation).not.toBe(false);
@ -629,15 +551,8 @@ describe('DiscoveredMCPTool', () => {
}); });
it('should add server to allowlist on ProceedAlwaysServer', async () => { it('should add server to allowlist on ProceedAlwaysServer', async () => {
const tool = new DiscoveredMCPTool( const invocation = tool.build({}) as any;
mockCallableToolInstance, const confirmation = await invocation.shouldConfirmExecute(
serverName,
serverToolName,
baseDescription,
inputSchema,
);
const confirmation = await tool.shouldConfirmExecute(
{},
new AbortController().signal, new AbortController().signal,
); );
expect(confirmation).not.toBe(false); expect(confirmation).not.toBe(false);
@ -650,7 +565,7 @@ describe('DiscoveredMCPTool', () => {
await confirmation.onConfirm( await confirmation.onConfirm(
ToolConfirmationOutcome.ProceedAlwaysServer, ToolConfirmationOutcome.ProceedAlwaysServer,
); );
expect((DiscoveredMCPTool as any).allowlist.has(serverName)).toBe(true); expect(invocation.constructor.allowlist.has(serverName)).toBe(true);
} else { } else {
throw new Error( throw new Error(
'Confirmation details or onConfirm not in expected format', 'Confirmation details or onConfirm not in expected format',
@ -659,16 +574,9 @@ describe('DiscoveredMCPTool', () => {
}); });
it('should add tool to allowlist on ProceedAlwaysTool', async () => { it('should add tool to allowlist on ProceedAlwaysTool', async () => {
const tool = new DiscoveredMCPTool(
mockCallableToolInstance,
serverName,
serverToolName,
baseDescription,
inputSchema,
);
const toolAllowlistKey = `${serverName}.${serverToolName}`; const toolAllowlistKey = `${serverName}.${serverToolName}`;
const confirmation = await tool.shouldConfirmExecute( const invocation = tool.build({}) as any;
{}, const confirmation = await invocation.shouldConfirmExecute(
new AbortController().signal, new AbortController().signal,
); );
expect(confirmation).not.toBe(false); expect(confirmation).not.toBe(false);
@ -679,7 +587,7 @@ describe('DiscoveredMCPTool', () => {
typeof confirmation.onConfirm === 'function' typeof confirmation.onConfirm === 'function'
) { ) {
await confirmation.onConfirm(ToolConfirmationOutcome.ProceedAlwaysTool); await confirmation.onConfirm(ToolConfirmationOutcome.ProceedAlwaysTool);
expect((DiscoveredMCPTool as any).allowlist.has(toolAllowlistKey)).toBe( expect(invocation.constructor.allowlist.has(toolAllowlistKey)).toBe(
true, true,
); );
} else { } else {
@ -690,15 +598,8 @@ describe('DiscoveredMCPTool', () => {
}); });
it('should handle Cancel confirmation outcome', async () => { it('should handle Cancel confirmation outcome', async () => {
const tool = new DiscoveredMCPTool( const invocation = tool.build({}) as any;
mockCallableToolInstance, const confirmation = await invocation.shouldConfirmExecute(
serverName,
serverToolName,
baseDescription,
inputSchema,
);
const confirmation = await tool.shouldConfirmExecute(
{},
new AbortController().signal, new AbortController().signal,
); );
expect(confirmation).not.toBe(false); expect(confirmation).not.toBe(false);
@ -710,11 +611,9 @@ describe('DiscoveredMCPTool', () => {
) { ) {
// Cancel should not add anything to allowlist // Cancel should not add anything to allowlist
await confirmation.onConfirm(ToolConfirmationOutcome.Cancel); await confirmation.onConfirm(ToolConfirmationOutcome.Cancel);
expect((DiscoveredMCPTool as any).allowlist.has(serverName)).toBe( expect(invocation.constructor.allowlist.has(serverName)).toBe(false);
false,
);
expect( expect(
(DiscoveredMCPTool as any).allowlist.has( invocation.constructor.allowlist.has(
`${serverName}.${serverToolName}`, `${serverName}.${serverToolName}`,
), ),
).toBe(false); ).toBe(false);
@ -726,15 +625,8 @@ describe('DiscoveredMCPTool', () => {
}); });
it('should handle ProceedOnce confirmation outcome', async () => { it('should handle ProceedOnce confirmation outcome', async () => {
const tool = new DiscoveredMCPTool( const invocation = tool.build({}) as any;
mockCallableToolInstance, const confirmation = await invocation.shouldConfirmExecute(
serverName,
serverToolName,
baseDescription,
inputSchema,
);
const confirmation = await tool.shouldConfirmExecute(
{},
new AbortController().signal, new AbortController().signal,
); );
expect(confirmation).not.toBe(false); expect(confirmation).not.toBe(false);
@ -746,11 +638,9 @@ describe('DiscoveredMCPTool', () => {
) { ) {
// ProceedOnce should not add anything to allowlist // ProceedOnce should not add anything to allowlist
await confirmation.onConfirm(ToolConfirmationOutcome.ProceedOnce); await confirmation.onConfirm(ToolConfirmationOutcome.ProceedOnce);
expect((DiscoveredMCPTool as any).allowlist.has(serverName)).toBe( expect(invocation.constructor.allowlist.has(serverName)).toBe(false);
false,
);
expect( expect(
(DiscoveredMCPTool as any).allowlist.has( invocation.constructor.allowlist.has(
`${serverName}.${serverToolName}`, `${serverName}.${serverToolName}`,
), ),
).toBe(false); ).toBe(false);

View File

@ -5,14 +5,16 @@
*/ */
import { import {
BaseTool, BaseDeclarativeTool,
ToolResult, BaseToolInvocation,
Kind,
ToolCallConfirmationDetails, ToolCallConfirmationDetails,
ToolConfirmationOutcome, ToolConfirmationOutcome,
ToolInvocation,
ToolMcpConfirmationDetails, ToolMcpConfirmationDetails,
Kind, ToolResult,
} from './tools.js'; } from './tools.js';
import { CallableTool, Part, FunctionCall } from '@google/genai'; import { CallableTool, FunctionCall, Part } from '@google/genai';
type ToolParams = Record<string, unknown>; type ToolParams = Record<string, unknown>;
@ -50,9 +52,84 @@ type McpContentBlock =
| McpResourceBlock | McpResourceBlock
| McpResourceLinkBlock; | McpResourceLinkBlock;
export class DiscoveredMCPTool extends BaseTool<ToolParams, ToolResult> { class DiscoveredMCPToolInvocation extends BaseToolInvocation<
ToolParams,
ToolResult
> {
private static readonly allowlist: Set<string> = new Set(); private static readonly allowlist: Set<string> = new Set();
constructor(
private readonly mcpTool: CallableTool,
readonly serverName: string,
readonly serverToolName: string,
readonly displayName: string,
readonly timeout?: number,
readonly trust?: boolean,
params: ToolParams = {},
) {
super(params);
}
async shouldConfirmExecute(
_abortSignal: AbortSignal,
): Promise<ToolCallConfirmationDetails | false> {
const serverAllowListKey = this.serverName;
const toolAllowListKey = `${this.serverName}.${this.serverToolName}`;
if (this.trust) {
return false; // server is trusted, no confirmation needed
}
if (
DiscoveredMCPToolInvocation.allowlist.has(serverAllowListKey) ||
DiscoveredMCPToolInvocation.allowlist.has(toolAllowListKey)
) {
return false; // server and/or tool already allowlisted
}
const confirmationDetails: ToolMcpConfirmationDetails = {
type: 'mcp',
title: 'Confirm MCP Tool Execution',
serverName: this.serverName,
toolName: this.serverToolName, // Display original tool name in confirmation
toolDisplayName: this.displayName, // Display global registry name exposed to model and user
onConfirm: async (outcome: ToolConfirmationOutcome) => {
if (outcome === ToolConfirmationOutcome.ProceedAlwaysServer) {
DiscoveredMCPToolInvocation.allowlist.add(serverAllowListKey);
} else if (outcome === ToolConfirmationOutcome.ProceedAlwaysTool) {
DiscoveredMCPToolInvocation.allowlist.add(toolAllowListKey);
}
},
};
return confirmationDetails;
}
async execute(): Promise<ToolResult> {
const functionCalls: FunctionCall[] = [
{
name: this.serverToolName,
args: this.params,
},
];
const rawResponseParts = await this.mcpTool.callTool(functionCalls);
const transformedParts = transformMcpContentToParts(rawResponseParts);
return {
llmContent: transformedParts,
returnDisplay: getStringifiedResultForDisplay(rawResponseParts),
};
}
getDescription(): string {
return this.displayName;
}
}
export class DiscoveredMCPTool extends BaseDeclarativeTool<
ToolParams,
ToolResult
> {
constructor( constructor(
private readonly mcpTool: CallableTool, private readonly mcpTool: CallableTool,
readonly serverName: string, readonly serverName: string,
@ -87,56 +164,18 @@ export class DiscoveredMCPTool extends BaseTool<ToolParams, ToolResult> {
); );
} }
async shouldConfirmExecute( protected createInvocation(
_params: ToolParams, params: ToolParams,
_abortSignal: AbortSignal, ): ToolInvocation<ToolParams, ToolResult> {
): Promise<ToolCallConfirmationDetails | false> { return new DiscoveredMCPToolInvocation(
const serverAllowListKey = this.serverName; this.mcpTool,
const toolAllowListKey = `${this.serverName}.${this.serverToolName}`; this.serverName,
this.serverToolName,
if (this.trust) { this.displayName,
return false; // server is trusted, no confirmation needed this.timeout,
} this.trust,
params,
if ( );
DiscoveredMCPTool.allowlist.has(serverAllowListKey) ||
DiscoveredMCPTool.allowlist.has(toolAllowListKey)
) {
return false; // server and/or tool already allowlisted
}
const confirmationDetails: ToolMcpConfirmationDetails = {
type: 'mcp',
title: 'Confirm MCP Tool Execution',
serverName: this.serverName,
toolName: this.serverToolName, // Display original tool name in confirmation
toolDisplayName: this.name, // Display global registry name exposed to model and user
onConfirm: async (outcome: ToolConfirmationOutcome) => {
if (outcome === ToolConfirmationOutcome.ProceedAlwaysServer) {
DiscoveredMCPTool.allowlist.add(serverAllowListKey);
} else if (outcome === ToolConfirmationOutcome.ProceedAlwaysTool) {
DiscoveredMCPTool.allowlist.add(toolAllowListKey);
}
},
};
return confirmationDetails;
}
async execute(params: ToolParams): Promise<ToolResult> {
const functionCalls: FunctionCall[] = [
{
name: this.serverToolName,
args: params,
},
];
const rawResponseParts = await this.mcpTool.callTool(functionCalls);
const transformedParts = transformMcpContentToParts(rawResponseParts);
return {
llmContent: transformedParts,
returnDisplay: getStringifiedResultForDisplay(rawResponseParts),
};
} }
} }

View File

@ -218,7 +218,8 @@ describe('MemoryTool', () => {
it('should call performAddMemoryEntry with correct parameters and return success', async () => { it('should call performAddMemoryEntry with correct parameters and return success', async () => {
const params = { fact: 'The sky is blue' }; const params = { fact: 'The sky is blue' };
const result = await memoryTool.execute(params, mockAbortSignal); const invocation = memoryTool.build(params);
const result = await invocation.execute(mockAbortSignal);
// Use getCurrentGeminiMdFilename for the default expectation before any setGeminiMdFilename calls in a test // Use getCurrentGeminiMdFilename for the default expectation before any setGeminiMdFilename calls in a test
const expectedFilePath = path.join( const expectedFilePath = path.join(
os.homedir(), os.homedir(),
@ -247,14 +248,12 @@ describe('MemoryTool', () => {
it('should return an error if fact is empty', async () => { it('should return an error if fact is empty', async () => {
const params = { fact: ' ' }; // Empty fact const params = { fact: ' ' }; // Empty fact
const result = await memoryTool.execute(params, mockAbortSignal); expect(memoryTool.validateToolParams(params)).toBe(
const errorMessage = 'Parameter "fact" must be a non-empty string.'; 'Parameter "fact" must be a non-empty string.',
);
expect(performAddMemoryEntrySpy).not.toHaveBeenCalled(); expect(() => memoryTool.build(params)).toThrow(
expect(result.llmContent).toBe( 'Parameter "fact" must be a non-empty string.',
JSON.stringify({ success: false, error: errorMessage }),
); );
expect(result.returnDisplay).toBe(`Error: ${errorMessage}`);
}); });
it('should handle errors from performAddMemoryEntry', async () => { it('should handle errors from performAddMemoryEntry', async () => {
@ -264,7 +263,8 @@ describe('MemoryTool', () => {
); );
performAddMemoryEntrySpy.mockRejectedValue(underlyingError); performAddMemoryEntrySpy.mockRejectedValue(underlyingError);
const result = await memoryTool.execute(params, mockAbortSignal); const invocation = memoryTool.build(params);
const result = await invocation.execute(mockAbortSignal);
expect(result.llmContent).toBe( expect(result.llmContent).toBe(
JSON.stringify({ JSON.stringify({
@ -284,17 +284,17 @@ describe('MemoryTool', () => {
beforeEach(() => { beforeEach(() => {
memoryTool = new MemoryTool(); memoryTool = new MemoryTool();
// Clear the allowlist before each test // Clear the allowlist before each test
(MemoryTool as unknown as { allowlist: Set<string> }).allowlist.clear(); const invocation = memoryTool.build({ fact: 'mock-fact' });
// eslint-disable-next-line @typescript-eslint/no-explicit-any
(invocation.constructor as any).allowlist.clear();
// Mock fs.readFile to return empty string (file doesn't exist) // Mock fs.readFile to return empty string (file doesn't exist)
vi.mocked(fs.readFile).mockResolvedValue(''); vi.mocked(fs.readFile).mockResolvedValue('');
}); });
it('should return confirmation details when memory file is not allowlisted', async () => { it('should return confirmation details when memory file is not allowlisted', async () => {
const params = { fact: 'Test fact' }; const params = { fact: 'Test fact' };
const result = await memoryTool.shouldConfirmExecute( const invocation = memoryTool.build(params);
params, const result = await invocation.shouldConfirmExecute(mockAbortSignal);
mockAbortSignal,
);
expect(result).toBeDefined(); expect(result).toBeDefined();
expect(result).not.toBe(false); expect(result).not.toBe(false);
@ -321,15 +321,12 @@ describe('MemoryTool', () => {
getCurrentGeminiMdFilename(), getCurrentGeminiMdFilename(),
); );
const invocation = memoryTool.build(params);
// Add the memory file to the allowlist // Add the memory file to the allowlist
(MemoryTool as unknown as { allowlist: Set<string> }).allowlist.add( // eslint-disable-next-line @typescript-eslint/no-explicit-any
memoryFilePath, (invocation.constructor as any).allowlist.add(memoryFilePath);
);
const result = await memoryTool.shouldConfirmExecute( const result = await invocation.shouldConfirmExecute(mockAbortSignal);
params,
mockAbortSignal,
);
expect(result).toBe(false); expect(result).toBe(false);
}); });
@ -342,10 +339,8 @@ describe('MemoryTool', () => {
getCurrentGeminiMdFilename(), getCurrentGeminiMdFilename(),
); );
const result = await memoryTool.shouldConfirmExecute( const invocation = memoryTool.build(params);
params, const result = await invocation.shouldConfirmExecute(mockAbortSignal);
mockAbortSignal,
);
expect(result).toBeDefined(); expect(result).toBeDefined();
expect(result).not.toBe(false); expect(result).not.toBe(false);
@ -356,9 +351,8 @@ describe('MemoryTool', () => {
// Check that the memory file was added to the allowlist // Check that the memory file was added to the allowlist
expect( expect(
(MemoryTool as unknown as { allowlist: Set<string> }).allowlist.has( // eslint-disable-next-line @typescript-eslint/no-explicit-any
memoryFilePath, (invocation.constructor as any).allowlist.has(memoryFilePath),
),
).toBe(true); ).toBe(true);
} }
}); });
@ -371,10 +365,8 @@ describe('MemoryTool', () => {
getCurrentGeminiMdFilename(), getCurrentGeminiMdFilename(),
); );
const result = await memoryTool.shouldConfirmExecute( const invocation = memoryTool.build(params);
params, const result = await invocation.shouldConfirmExecute(mockAbortSignal);
mockAbortSignal,
);
expect(result).toBeDefined(); expect(result).toBeDefined();
expect(result).not.toBe(false); expect(result).not.toBe(false);
@ -382,18 +374,12 @@ describe('MemoryTool', () => {
if (result && result.type === 'edit') { if (result && result.type === 'edit') {
// Simulate the onConfirm callback with different outcomes // Simulate the onConfirm callback with different outcomes
await result.onConfirm(ToolConfirmationOutcome.ProceedOnce); await result.onConfirm(ToolConfirmationOutcome.ProceedOnce);
expect( // eslint-disable-next-line @typescript-eslint/no-explicit-any
(MemoryTool as unknown as { allowlist: Set<string> }).allowlist.has( const allowlist = (invocation.constructor as any).allowlist;
memoryFilePath, expect(allowlist.has(memoryFilePath)).toBe(false);
),
).toBe(false);
await result.onConfirm(ToolConfirmationOutcome.Cancel); await result.onConfirm(ToolConfirmationOutcome.Cancel);
expect( expect(allowlist.has(memoryFilePath)).toBe(false);
(MemoryTool as unknown as { allowlist: Set<string> }).allowlist.has(
memoryFilePath,
),
).toBe(false);
} }
}); });
@ -405,10 +391,8 @@ describe('MemoryTool', () => {
// Mock fs.readFile to return existing content // Mock fs.readFile to return existing content
vi.mocked(fs.readFile).mockResolvedValue(existingContent); vi.mocked(fs.readFile).mockResolvedValue(existingContent);
const result = await memoryTool.shouldConfirmExecute( const invocation = memoryTool.build(params);
params, const result = await invocation.shouldConfirmExecute(mockAbortSignal);
mockAbortSignal,
);
expect(result).toBeDefined(); expect(result).toBeDefined();
expect(result).not.toBe(false); expect(result).not.toBe(false);

View File

@ -5,11 +5,12 @@
*/ */
import { import {
BaseTool, BaseDeclarativeTool,
BaseToolInvocation,
Kind, Kind,
ToolResult,
ToolEditConfirmationDetails, ToolEditConfirmationDetails,
ToolConfirmationOutcome, ToolConfirmationOutcome,
ToolResult,
} from './tools.js'; } from './tools.js';
import { FunctionDeclaration } from '@google/genai'; import { FunctionDeclaration } from '@google/genai';
import * as fs from 'fs/promises'; import * as fs from 'fs/promises';
@ -19,6 +20,7 @@ import * as Diff from 'diff';
import { DEFAULT_DIFF_OPTIONS } from './diffOptions.js'; import { DEFAULT_DIFF_OPTIONS } from './diffOptions.js';
import { tildeifyPath } from '../utils/paths.js'; import { tildeifyPath } from '../utils/paths.js';
import { ModifiableDeclarativeTool, ModifyContext } from './modifiable-tool.js'; import { ModifiableDeclarativeTool, ModifyContext } from './modifiable-tool.js';
import { SchemaValidator } from '../utils/schemaValidator.js';
const memoryToolSchemaData: FunctionDeclaration = { const memoryToolSchemaData: FunctionDeclaration = {
name: 'save_memory', name: 'save_memory',
@ -110,32 +112,10 @@ function ensureNewlineSeparation(currentContent: string): string {
return '\n\n'; return '\n\n';
} }
export class MemoryTool /**
extends BaseTool<SaveMemoryParams, ToolResult>
implements ModifiableDeclarativeTool<SaveMemoryParams>
{
private static readonly allowlist: Set<string> = new Set();
static readonly Name: string = memoryToolSchemaData.name!;
constructor() {
super(
MemoryTool.Name,
'Save Memory',
memoryToolDescription,
Kind.Think,
memoryToolSchemaData.parametersJsonSchema as Record<string, unknown>,
);
}
getDescription(_params: SaveMemoryParams): string {
const memoryFilePath = getGlobalMemoryFilePath();
return `in ${tildeifyPath(memoryFilePath)}`;
}
/**
* Reads the current content of the memory file * Reads the current content of the memory file
*/ */
private async readMemoryFileContent(): Promise<string> { async function readMemoryFileContent(): Promise<string> {
try { try {
return await fs.readFile(getGlobalMemoryFilePath(), 'utf-8'); return await fs.readFile(getGlobalMemoryFilePath(), 'utf-8');
} catch (err) { } catch (err) {
@ -143,12 +123,12 @@ export class MemoryTool
if (!(error instanceof Error) || error.code !== 'ENOENT') throw err; if (!(error instanceof Error) || error.code !== 'ENOENT') throw err;
return ''; return '';
} }
} }
/** /**
* Computes the new content that would result from adding a memory entry * Computes the new content that would result from adding a memory entry
*/ */
private computeNewContent(currentContent: string, fact: string): string { function computeNewContent(currentContent: string, fact: string): string {
let processedText = fact.trim(); let processedText = fact.trim();
processedText = processedText.replace(/^(-+\s*)+/, '').trim(); processedText = processedText.replace(/^(-+\s*)+/, '').trim();
const newMemoryItem = `- ${processedText}`; const newMemoryItem = `- ${processedText}`;
@ -187,24 +167,31 @@ export class MemoryTool
'\n' '\n'
); );
} }
}
class MemoryToolInvocation extends BaseToolInvocation<
SaveMemoryParams,
ToolResult
> {
private static readonly allowlist: Set<string> = new Set();
getDescription(): string {
const memoryFilePath = getGlobalMemoryFilePath();
return `in ${tildeifyPath(memoryFilePath)}`;
} }
async shouldConfirmExecute( async shouldConfirmExecute(
params: SaveMemoryParams,
_abortSignal: AbortSignal, _abortSignal: AbortSignal,
): Promise<ToolEditConfirmationDetails | false> { ): Promise<ToolEditConfirmationDetails | false> {
const memoryFilePath = getGlobalMemoryFilePath(); const memoryFilePath = getGlobalMemoryFilePath();
const allowlistKey = memoryFilePath; const allowlistKey = memoryFilePath;
if (MemoryTool.allowlist.has(allowlistKey)) { if (MemoryToolInvocation.allowlist.has(allowlistKey)) {
return false; return false;
} }
// Read current content of the memory file const currentContent = await readMemoryFileContent();
const currentContent = await this.readMemoryFileContent(); const newContent = computeNewContent(currentContent, this.params.fact);
// Calculate the new content that will be written to the memory file
const newContent = this.computeNewContent(currentContent, params.fact);
const fileName = path.basename(memoryFilePath); const fileName = path.basename(memoryFilePath);
const fileDiff = Diff.createPatch( const fileDiff = Diff.createPatch(
@ -226,13 +213,107 @@ export class MemoryTool
newContent, newContent,
onConfirm: async (outcome: ToolConfirmationOutcome) => { onConfirm: async (outcome: ToolConfirmationOutcome) => {
if (outcome === ToolConfirmationOutcome.ProceedAlways) { if (outcome === ToolConfirmationOutcome.ProceedAlways) {
MemoryTool.allowlist.add(allowlistKey); MemoryToolInvocation.allowlist.add(allowlistKey);
} }
}, },
}; };
return confirmationDetails; return confirmationDetails;
} }
async execute(_signal: AbortSignal): Promise<ToolResult> {
const { fact, modified_by_user, modified_content } = this.params;
try {
if (modified_by_user && modified_content !== undefined) {
// User modified the content in external editor, write it directly
await fs.mkdir(path.dirname(getGlobalMemoryFilePath()), {
recursive: true,
});
await fs.writeFile(
getGlobalMemoryFilePath(),
modified_content,
'utf-8',
);
const successMessage = `Okay, I've updated the memory file with your modifications.`;
return {
llmContent: JSON.stringify({
success: true,
message: successMessage,
}),
returnDisplay: successMessage,
};
} else {
// Use the normal memory entry logic
await MemoryTool.performAddMemoryEntry(
fact,
getGlobalMemoryFilePath(),
{
readFile: fs.readFile,
writeFile: fs.writeFile,
mkdir: fs.mkdir,
},
);
const successMessage = `Okay, I've remembered that: "${fact}"`;
return {
llmContent: JSON.stringify({
success: true,
message: successMessage,
}),
returnDisplay: successMessage,
};
}
} catch (error) {
const errorMessage =
error instanceof Error ? error.message : String(error);
console.error(
`[MemoryTool] Error executing save_memory for fact "${fact}": ${errorMessage}`,
);
return {
llmContent: JSON.stringify({
success: false,
error: `Failed to save memory. Detail: ${errorMessage}`,
}),
returnDisplay: `Error saving memory: ${errorMessage}`,
};
}
}
}
export class MemoryTool
extends BaseDeclarativeTool<SaveMemoryParams, ToolResult>
implements ModifiableDeclarativeTool<SaveMemoryParams>
{
static readonly Name: string = memoryToolSchemaData.name!;
constructor() {
super(
MemoryTool.Name,
'Save Memory',
memoryToolDescription,
Kind.Think,
memoryToolSchemaData.parametersJsonSchema as Record<string, unknown>,
);
}
validateToolParams(params: SaveMemoryParams): string | null {
const errors = SchemaValidator.validate(
this.schema.parametersJsonSchema,
params,
);
if (errors) {
return errors;
}
if (params.fact.trim() === '') {
return 'Parameter "fact" must be a non-empty string.';
}
return null;
}
protected createInvocation(params: SaveMemoryParams) {
return new MemoryToolInvocation(params);
}
static async performAddMemoryEntry( static async performAddMemoryEntry(
text: string, text: string,
memoryFilePath: string, memoryFilePath: string,
@ -303,83 +384,14 @@ export class MemoryTool
} }
} }
async execute(
params: SaveMemoryParams,
_signal: AbortSignal,
): Promise<ToolResult> {
const { fact, modified_by_user, modified_content } = params;
if (!fact || typeof fact !== 'string' || fact.trim() === '') {
const errorMessage = 'Parameter "fact" must be a non-empty string.';
return {
llmContent: JSON.stringify({ success: false, error: errorMessage }),
returnDisplay: `Error: ${errorMessage}`,
};
}
try {
if (modified_by_user && modified_content !== undefined) {
// User modified the content in external editor, write it directly
await fs.mkdir(path.dirname(getGlobalMemoryFilePath()), {
recursive: true,
});
await fs.writeFile(
getGlobalMemoryFilePath(),
modified_content,
'utf-8',
);
const successMessage = `Okay, I've updated the memory file with your modifications.`;
return {
llmContent: JSON.stringify({
success: true,
message: successMessage,
}),
returnDisplay: successMessage,
};
} else {
// Use the normal memory entry logic
await MemoryTool.performAddMemoryEntry(
fact,
getGlobalMemoryFilePath(),
{
readFile: fs.readFile,
writeFile: fs.writeFile,
mkdir: fs.mkdir,
},
);
const successMessage = `Okay, I've remembered that: "${fact}"`;
return {
llmContent: JSON.stringify({
success: true,
message: successMessage,
}),
returnDisplay: successMessage,
};
}
} catch (error) {
const errorMessage =
error instanceof Error ? error.message : String(error);
console.error(
`[MemoryTool] Error executing save_memory for fact "${fact}": ${errorMessage}`,
);
return {
llmContent: JSON.stringify({
success: false,
error: `Failed to save memory. Detail: ${errorMessage}`,
}),
returnDisplay: `Error saving memory: ${errorMessage}`,
};
}
}
getModifyContext(_abortSignal: AbortSignal): ModifyContext<SaveMemoryParams> { getModifyContext(_abortSignal: AbortSignal): ModifyContext<SaveMemoryParams> {
return { return {
getFilePath: (_params: SaveMemoryParams) => getGlobalMemoryFilePath(), getFilePath: (_params: SaveMemoryParams) => getGlobalMemoryFilePath(),
getCurrentContent: async (_params: SaveMemoryParams): Promise<string> => getCurrentContent: async (_params: SaveMemoryParams): Promise<string> =>
this.readMemoryFileContent(), readMemoryFileContent(),
getProposedContent: async (params: SaveMemoryParams): Promise<string> => { getProposedContent: async (params: SaveMemoryParams): Promise<string> => {
const currentContent = await this.readMemoryFileContent(); const currentContent = await readMemoryFileContent();
return this.computeNewContent(currentContent, params.fact); return computeNewContent(currentContent, params.fact);
}, },
createUpdatedParams: ( createUpdatedParams: (
_oldContent: string, _oldContent: string,