Implementation of web search as a tool (#307)

This commit is contained in:
Allen Hutchison 2025-05-20 11:36:21 -07:00 committed by GitHub
parent d1210f2e0a
commit 93e89215e3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 212 additions and 3 deletions

View File

@ -20,6 +20,7 @@ import { WriteFileTool } from '../tools/write-file.js';
import { WebFetchTool } from '../tools/web-fetch.js';
import { ReadManyFilesTool } from '../tools/read-many-files.js';
import { MemoryTool } from '../tools/memoryTool.js';
import { WebSearchTool } from '../tools/web-search.js';
export class MCPServerConfig {
constructor(
@ -202,7 +203,7 @@ export function createServerConfig(
);
}
function createToolRegistry(config: Config): ToolRegistry {
export function createToolRegistry(config: Config): ToolRegistry {
const registry = new ToolRegistry(config);
const targetDir = config.getTargetDir();
const tools = config.getCoreTools()
@ -228,7 +229,7 @@ function createToolRegistry(config: Config): ToolRegistry {
registerCoreTool(ReadManyFilesTool, targetDir);
registerCoreTool(ShellTool, config);
registerCoreTool(MemoryTool);
registerCoreTool(WebSearchTool, config);
registry.discoverTools();
return registry;
}

View File

@ -133,8 +133,11 @@ Returns the MCP server response as a json string.
export class ToolRegistry {
private tools: Map<string, Tool> = new Map();
private config: Config;
constructor(private readonly config: Config) {}
constructor(config: Config) {
this.config = config;
}
/**
* Registers a tool definition.

View File

@ -0,0 +1,205 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { GoogleGenAI, GroundingMetadata } from '@google/genai';
import { BaseTool, ToolResult } from './tools.js';
import { SchemaValidator } from '../utils/schemaValidator.js';
import { getErrorMessage } from '../utils/errors.js';
import { Config } from '../config/config.js';
import { getResponseText } from '../utils/generateContentResponseUtilities.js';
interface GroundingChunkWeb {
uri?: string;
title?: string;
}
interface GroundingChunkItem {
web?: GroundingChunkWeb;
// Other properties might exist if needed in the future
}
interface GroundingSupportSegment {
startIndex: number;
endIndex: number;
text?: string; // text is optional as per the example
}
interface GroundingSupportItem {
segment?: GroundingSupportSegment;
groundingChunkIndices?: number[];
confidenceScores?: number[]; // Optional as per example
}
/**
* Parameters for the WebSearchTool.
*/
export interface WebSearchToolParams {
/**
* The search query.
*/
query: string;
}
/**
* Extends ToolResult to include sources for web search.
*/
export interface WebSearchToolResult extends ToolResult {
sources?: GroundingMetadata extends { groundingChunks: GroundingChunkItem[] }
? GroundingMetadata['groundingChunks']
: GroundingChunkItem[];
}
/**
* A tool to perform web searches using Google Search via the Gemini API.
*/
export class WebSearchTool extends BaseTool<
WebSearchToolParams,
WebSearchToolResult
> {
static readonly Name: string = 'google_web_search';
private ai: GoogleGenAI;
private modelName: string;
constructor(private readonly config: Config) {
super(
WebSearchTool.Name,
'GoogleSearch',
'Performs a web search using Google Search (via the Gemini API) and returns the results. This tool is useful for finding information on the internet based on a query.',
{
type: 'object',
properties: {
query: {
type: 'string',
description: 'The search query to find information on the web.',
},
},
required: ['query'],
},
);
const apiKey = this.config.getApiKey();
if (!apiKey) {
throw new Error(
'Google AI API key is not configured. WebSearchTool cannot be initialized.',
);
}
this.ai = new GoogleGenAI({ apiKey });
this.modelName = this.config.getModel();
}
validateParams(params: WebSearchToolParams): string | null {
if (
this.schema.parameters &&
!SchemaValidator.validate(
this.schema.parameters as Record<string, unknown>,
params,
)
) {
return "Parameters failed schema validation. Ensure 'query' is a string.";
}
if (!params.query || params.query.trim() === '') {
return "The 'query' parameter cannot be empty.";
}
return null;
}
getDescription(params: WebSearchToolParams): string {
return `Searching the web for: "${params.query}"`;
}
async execute(params: WebSearchToolParams): Promise<WebSearchToolResult> {
const validationError = this.validateParams(params);
if (validationError) {
return {
llmContent: `Error: Invalid parameters provided. Reason: ${validationError}`,
returnDisplay: validationError,
};
}
try {
const response = await this.ai.models.generateContent({
model: this.modelName,
contents: [{ role: 'user', parts: [{ text: params.query }] }],
config: {
tools: [{ googleSearch: {} }],
},
});
const responseText = getResponseText(response);
const groundingMetadata = response.candidates?.[0]?.groundingMetadata;
const sources = groundingMetadata?.groundingChunks as
| GroundingChunkItem[]
| undefined;
const groundingSupports = groundingMetadata?.groundingSupports as
| GroundingSupportItem[]
| undefined;
if (!responseText || !responseText.trim()) {
return {
llmContent: `No search results or information found for query: "${params.query}"`,
returnDisplay: 'No information found.',
};
}
let modifiedResponseText = responseText;
const sourceListFormatted: string[] = [];
if (sources && sources.length > 0) {
sources.forEach((source: GroundingChunkItem, index: number) => {
const title = source.web?.title || 'Untitled';
const uri = source.web?.uri || 'No URI';
sourceListFormatted.push(`[${index + 1}] ${title} (${uri})`);
});
if (groundingSupports && groundingSupports.length > 0) {
const insertions: Array<{ index: number; marker: string }> = [];
groundingSupports.forEach((support: GroundingSupportItem) => {
if (support.segment && support.groundingChunkIndices) {
const citationMarker = support.groundingChunkIndices
.map((chunkIndex: number) => `[${chunkIndex + 1}]`)
.join('');
insertions.push({
index: support.segment.endIndex,
marker: citationMarker,
});
}
});
// Sort insertions by index in descending order to avoid shifting subsequent indices
insertions.sort((a, b) => b.index - a.index);
const responseChars = modifiedResponseText.split(''); // Use new variable
insertions.forEach((insertion) => {
// Fixed arrow function syntax
responseChars.splice(insertion.index, 0, insertion.marker);
});
modifiedResponseText = responseChars.join(''); // Assign back to modifiedResponseText
}
if (sourceListFormatted.length > 0) {
modifiedResponseText +=
'\n\nSources:\n' + sourceListFormatted.join('\n'); // Fixed string concatenation
}
}
return {
llmContent: `Web search results for "${params.query}":\n\n${modifiedResponseText}`,
returnDisplay: `Search results for "${params.query}" returned.`,
sources,
};
} catch (error: unknown) {
const errorMessage = `Error during web search for query "${params.query}": ${getErrorMessage(error)}`;
console.error(errorMessage, error);
return {
llmContent: `Error: ${errorMessage}`,
returnDisplay: `Error performing web search.`,
};
}
}
}