485 lines
16 KiB
TypeScript
485 lines
16 KiB
TypeScript
/**
 * @license
 * Copyright 2025 Google LLC
 * SPDX-License-Identifier: Apache-2.0
 */
|
|
|
|
import { exec as _exec } from 'child_process';
|
|
import { useState, useRef, useCallback, useEffect } from 'react';
|
|
import { useInput } from 'ink';
|
|
// Import server-side client and types
|
|
import {
|
|
GeminiClient,
|
|
GeminiEventType as ServerGeminiEventType, // Rename to avoid conflict
|
|
getErrorMessage,
|
|
isNodeError,
|
|
ToolResult,
|
|
Config,
|
|
ToolCallConfirmationDetails,
|
|
ToolCallResponseInfo,
|
|
ServerToolCallConfirmationDetails,
|
|
ToolConfirmationOutcome,
|
|
ToolResultDisplay,
|
|
ToolEditConfirmationDetails,
|
|
ToolExecuteConfirmationDetails,
|
|
} from '@gemini-code/server';
|
|
import {
|
|
type Chat,
|
|
type PartListUnion,
|
|
type FunctionDeclaration,
|
|
type Part,
|
|
} from '@google/genai';
|
|
import {
|
|
StreamingState,
|
|
HistoryItem,
|
|
IndividualToolCallDisplay,
|
|
ToolCallStatus,
|
|
} from '../types.js';
|
|
|
|
/**
 * Appends a single entry to the UI history.
 *
 * @param setHistory - State setter for the history list; it is invoked with a
 *   functional updater so the append is based on the latest state.
 * @param itemData - The entry to append, without its `id`.
 * @param id - Identifier assigned to the new entry.
 */
const addHistoryItem = (
  setHistory: React.Dispatch<React.SetStateAction<HistoryItem[]>>,
  itemData: Omit<HistoryItem, 'id'>,
  id: number,
) => {
  setHistory((prevHistory) => [
    ...prevHistory,
    // The cast re-attaches `id` to the id-less payload to form a full item.
    { ...itemData, id } as HistoryItem,
  ]);
};
|
|
|
|
// Hook takes the history setter and a Config (which supplies the API key and model)
|
|
/**
 * React hook that owns the conversation with Gemini for the CLI UI.
 *
 * It creates a `GeminiClient` and chat session on demand, streams model
 * output into the shared history, renders tool calls (including ones that
 * wait for user confirmation), and runs configured passthrough shell
 * commands directly instead of sending them to the model.
 *
 * @param setHistory - State setter for the UI history list; every message,
 *   tool group and error is appended or updated through it.
 * @param config - Supplies the API key, model, tool registry and the list of
 *   passthrough commands.
 * @returns The current `streamingState`, the `submitQuery` callback, the
 *   latest `initError` (or `null`) and a `debugMessage` string.
 */
export const useGeminiStream = (
  setHistory: React.Dispatch<React.SetStateAction<HistoryItem[]>>,
  config: Config,
) => {
  const toolRegistry = config.getToolRegistry();
  // Read once per render so the values can be used as plain effect deps.
  const apiKey = config.getApiKey();
  const model = config.getModel();

  const [streamingState, setStreamingState] = useState<StreamingState>(
    StreamingState.Idle,
  );
  const [debugMessage, setDebugMessage] = useState<string>('');
  const [initError, setInitError] = useState<string | null>(null);
  const abortControllerRef = useRef<AbortController | null>(null);
  const chatSessionRef = useRef<Chat | null>(null);
  const geminiClientRef = useRef<GeminiClient | null>(null);
  const messageIdCounterRef = useRef(0);
  const currentGeminiMessageIdRef = useRef<number | null>(null);

  // Create the client once. The guard means an already-created client is
  // kept even if the effect re-runs because the key or model changed.
  useEffect(() => {
    setInitError(null);
    if (!geminiClientRef.current) {
      try {
        geminiClientRef.current = new GeminiClient(apiKey, model);
      } catch (error: unknown) {
        setInitError(
          `Failed to initialize client: ${getErrorMessage(error) || 'Unknown error'}`,
        );
      }
    }
  }, [apiKey, model]);

  // Escape aborts the in-flight request while a response is streaming.
  useInput((_input, key) => {
    if (streamingState === StreamingState.Responding && key.escape) {
      abortControllerRef.current?.abort();
    }
  });

  // Produces ids of the form `baseTimestamp + n`, where n is a counter that
  // submitQuery resets at the start of each Gemini submission.
  const getNextMessageId = useCallback((baseTimestamp: number): number => {
    messageIdCounterRef.current += 1;
    return baseTimestamp + messageIdCounterRef.current;
  }, []);

  // Replaces the text of the 'gemini' history item with the given id.
  const updateGeminiMessage = useCallback(
    (messageId: number, newContent: string) => {
      setHistory((prevHistory) =>
        prevHistory.map((item) =>
          item.id === messageId && item.type === 'gemini'
            ? { ...item, text: newContent }
            : item,
        ),
      );
    },
    [setHistory],
  );

  /**
   * Submits a user query (string) or a function response (parts) and
   * processes the resulting event stream.
   *
   * String queries are first checked for the local `clear` command and for
   * configured passthrough shell commands; neither reaches the model.
   */
  const submitQuery = useCallback(
    async (query: PartListUnion) => {
      if (streamingState === StreamingState.Responding) return;
      if (typeof query === 'string' && query.trim().length === 0) return;

      if (typeof query === 'string') {
        setDebugMessage(`User query: ${query}`);
        // Trim before extracting the command word: splitting an untrimmed
        // query with leading whitespace yields '' as the first element and
        // the passthrough check would never match.
        const trimmedQuery = query.trim();
        const [maybeCommand = ''] = trimmedQuery.split(/\s+/);

        if (trimmedQuery === 'clear') {
          // This just clears the *UI* history, not the model history.
          // TODO: add a slash command for that.
          setDebugMessage('Clearing terminal.');
          setHistory((_) => []);
          return;
        }

        if (config.getPassthroughCommands().includes(maybeCommand)) {
          setDebugMessage(`Executing shell command directly: ${query}`);
          // Mark as Responding while the command runs; the exec callback
          // switches back to Idle once the output has been added.
          setStreamingState(StreamingState.Responding);
          _exec(query, (error, stdout, stderr) => {
            const timestamp = getNextMessageId(Date.now());
            if (error) {
              addHistoryItem(
                setHistory,
                { type: 'error', text: error.message },
                timestamp,
              );
            } else if (stderr) {
              addHistoryItem(
                setHistory,
                { type: 'error', text: stderr },
                timestamp,
              );
            } else {
              // Add stdout as an info message
              addHistoryItem(
                setHistory,
                { type: 'info', text: stdout || '' },
                timestamp,
              );
            }
            setStreamingState(StreamingState.Idle);
          });
          return; // Prevent Gemini call
        }
      }

      const userMessageTimestamp = Date.now();
      const client = geminiClientRef.current;
      if (!client) {
        setInitError('Gemini client is not available.');
        return;
      }

      // Start the chat session on first use.
      if (!chatSessionRef.current) {
        try {
          const toolSchemas = toolRegistry.getFunctionDeclarations();
          chatSessionRef.current = await client.startChat(toolSchemas);
        } catch (err: unknown) {
          setInitError(`Failed to start chat: ${getErrorMessage(err)}`);
          setStreamingState(StreamingState.Idle);
          return;
        }
      }

      setStreamingState(StreamingState.Responding);
      setInitError(null);
      messageIdCounterRef.current = 0; // Reset counter for new submission
      const chat = chatSessionRef.current;
      // Id of the 'tool_group' history item for this submission; assigned
      // when the first tool call request arrives and read by the UI helpers
      // below.
      let currentToolGroupId: number | null = null;

      /** Marks a tool call as awaiting confirmation and attaches its details. */
      function updateConfirmingFunctionStatusUI(
        callId: string,
        confirmationDetails: ToolCallConfirmationDetails | undefined,
      ) {
        setHistory((prevHistory) =>
          prevHistory.map((item) => {
            if (item.id === currentToolGroupId && item.type === 'tool_group') {
              return {
                ...item,
                tools: item.tools.map((tool) =>
                  tool.callId === callId
                    ? {
                        ...tool,
                        status: ToolCallStatus.Confirming,
                        confirmationDetails,
                      }
                    : tool,
                ),
              };
            }
            return item;
          }),
        );
      }

      /** Stores a tool call's final status and display result in its group. */
      function updateFunctionResponseUI(
        toolResponse: ToolCallResponseInfo,
        status: ToolCallStatus,
      ) {
        setHistory((prevHistory) =>
          prevHistory.map((item) => {
            if (item.id === currentToolGroupId && item.type === 'tool_group') {
              return {
                ...item,
                tools: item.tools.map((tool) =>
                  tool.callId === toolResponse.callId
                    ? {
                        ...tool,
                        status,
                        resultDisplay: toolResponse.resultDisplay,
                      }
                    : tool,
                ),
              };
            }
            return item;
          }),
        );
      }

      /**
       * Wraps the server's confirmation details so that the user's choice
       * also updates the UI and sends the resulting function response back
       * to the model through `submitQuery`.
       */
      function wireConfirmationSubmission(
        confirmationDetails: ServerToolCallConfirmationDetails,
      ): ToolCallConfirmationDetails {
        const originalConfirmationDetails = confirmationDetails.details;
        const request = confirmationDetails.request;

        const resubmittingConfirm = async (
          outcome: ToolConfirmationOutcome,
        ) => {
          originalConfirmationDetails.onConfirm(outcome);

          // Reset streaming state since confirmation has been chosen.
          setStreamingState(StreamingState.Idle);

          let functionResponse: Part;
          let resultDisplay: ToolResultDisplay | undefined;
          let status: ToolCallStatus;

          if (outcome === ToolConfirmationOutcome.Cancel) {
            // Show what was rejected: the diff for edits, otherwise the
            // struck-through command.
            resultDisplay =
              'fileDiff' in originalConfirmationDetails
                ? {
                    fileDiff: (
                      originalConfirmationDetails as ToolEditConfirmationDetails
                    ).fileDiff,
                  }
                : `~~${(originalConfirmationDetails as ToolExecuteConfirmationDetails).command}~~`;
            functionResponse = {
              functionResponse: {
                id: request.callId,
                name: request.name,
                response: { error: 'User rejected function call.' },
              },
            };
            status = ToolCallStatus.Error;
          } else {
            try {
              const tool = toolRegistry.getTool(request.name);
              if (!tool) {
                throw new Error(
                  `Tool "${request.name}" not found or is not registered.`,
                );
              }
              const result = await tool.execute(request.args);
              functionResponse = {
                functionResponse: {
                  name: request.name,
                  id: request.callId,
                  response: { output: result.llmContent },
                },
              };
              resultDisplay = result.returnDisplay;
              status = ToolCallStatus.Success;
            } catch (err: unknown) {
              // A missing tool or a failing execution would otherwise be an
              // unhandled rejection that leaves the call stuck in
              // "Confirming". Report it in the UI and to the model instead.
              const message = getErrorMessage(err);
              functionResponse = {
                functionResponse: {
                  name: request.name,
                  id: request.callId,
                  response: { error: message },
                },
              };
              resultDisplay = `Error: ${message}`;
              status = ToolCallStatus.Error;
            }
          }

          // The UI helper only reads callId and resultDisplay; the outcome
          // is conveyed through `status`.
          const responseInfo: ToolCallResponseInfo = {
            callId: request.callId,
            responsePart: functionResponse,
            resultDisplay,
            error: undefined,
          };
          updateFunctionResponseUI(responseInfo, status);

          await submitQuery(functionResponse);
        };

        return {
          ...originalConfirmationDetails,
          onConfirm: resubmittingConfirm,
        };
      }

      // For function responses, we don't need to add a user message
      if (typeof query === 'string') {
        addHistoryItem(
          setHistory,
          { type: 'user', text: query },
          userMessageTimestamp,
        );
      }

      try {
        abortControllerRef.current = new AbortController();
        const signal = abortControllerRef.current.signal;

        // Get ServerTool descriptions for the server call
        const serverTools: ServerTool[] = toolRegistry.getAllTools();

        const stream = client.sendMessageStream(
          chat,
          query,
          serverTools,
          signal,
        );

        // Text accumulated for the Gemini message currently being streamed.
        let currentGeminiText = '';
        let hasInitialGeminiResponse = false;

        for await (const event of stream) {
          if (signal.aborted) break;

          if (event.type === ServerGeminiEventType.Content) {
            currentGeminiText += event.value;

            if (!hasInitialGeminiResponse) {
              // First content event: create the Gemini message.
              hasInitialGeminiResponse = true;
              const eventTimestamp = getNextMessageId(userMessageTimestamp);
              currentGeminiMessageIdRef.current = eventTimestamp;

              addHistoryItem(
                setHistory,
                { type: 'gemini', text: currentGeminiText },
                eventTimestamp,
              );
            } else if (currentGeminiMessageIdRef.current !== null) {
              // Later content events: update the existing message in place.
              updateGeminiMessage(
                currentGeminiMessageIdRef.current,
                currentGeminiText,
              );
            }
          } else if (event.type === ServerGeminiEventType.ToolCallRequest) {
            // Reset the Gemini message tracking for the next response
            currentGeminiText = '';
            hasInitialGeminiResponse = false;
            currentGeminiMessageIdRef.current = null;

            const { callId, name, args } = event.value;

            const cliTool = toolRegistry.getTool(name); // Get the full CLI tool
            if (!cliTool) {
              console.error(`CLI Tool "${name}" not found!`);
              continue;
            }

            // All tool calls of one submission share a single history group.
            if (currentToolGroupId === null) {
              currentToolGroupId = getNextMessageId(userMessageTimestamp);
              addHistoryItem(
                setHistory,
                { type: 'tool_group', tools: [] } as Omit<HistoryItem, 'id'>,
                currentToolGroupId,
              );
            }

            let description: string;
            try {
              description = cliTool.getDescription(args);
            } catch (e: unknown) {
              description = `Error: Unable to get description: ${getErrorMessage(e)}`;
            }

            const toolCallDisplay: IndividualToolCallDisplay = {
              callId,
              name: cliTool.displayName,
              description,
              status: ToolCallStatus.Pending,
              resultDisplay: undefined,
              confirmationDetails: undefined,
            };

            // Add pending tool call to the UI history group
            setHistory((prevHistory) =>
              prevHistory.map((item) => {
                if (
                  item.id === currentToolGroupId &&
                  item.type === 'tool_group'
                ) {
                  // Ensure item.tools is an array before spreading
                  const currentTools = Array.isArray(item.tools)
                    ? item.tools
                    : [];
                  return {
                    ...item,
                    tools: [...currentTools, toolCallDisplay],
                  };
                }
                return item;
              }),
            );
          } else if (event.type === ServerGeminiEventType.ToolCallResponse) {
            const status = event.value.error
              ? ToolCallStatus.Error
              : ToolCallStatus.Success;
            updateFunctionResponseUI(event.value, status);
          } else if (
            event.type === ServerGeminiEventType.ToolCallConfirmation
          ) {
            const confirmationDetails = wireConfirmationSubmission(event.value);
            updateConfirmingFunctionStatusUI(
              event.value.request.callId,
              confirmationDetails,
            );
            // Stop consuming the stream here; the wrapped onConfirm callback
            // resubmits once the user has chosen.
            setStreamingState(StreamingState.WaitingForConfirmation);
            return;
          }
        }

        setStreamingState(StreamingState.Idle);
      } catch (error: unknown) {
        // Aborts are not surfaced as errors.
        if (!isNodeError(error) || error.name !== 'AbortError') {
          console.error('Error processing stream or executing tool:', error);
          addHistoryItem(
            setHistory,
            {
              type: 'error',
              text: `[Error: ${getErrorMessage(error)}]`,
            },
            getNextMessageId(userMessageTimestamp),
          );
        }
        setStreamingState(StreamingState.Idle);
      } finally {
        abortControllerRef.current = null;
      }
    },
    [
      streamingState,
      setHistory,
      config,
      toolRegistry,
      getNextMessageId,
      updateGeminiMessage,
    ],
  );

  return { streamingState, submitQuery, initError, debugMessage };
};
|
|
|
|
// Define ServerTool interface here if not importing from server (circular dep issue?)
|
|
/**
 * Local description of a tool as passed to `sendMessageStream`.
 */
interface ServerTool {
  /** Tool name. */
  name: string;
  /** Function declaration describing the tool. */
  schema: FunctionDeclaration;
  /**
   * Resolves to confirmation details, or to `false`, for a call with the
   * given params.
   */
  shouldConfirmExecute(
    params: Record<string, unknown>,
  ): Promise<ToolCallConfirmationDetails | false>;
  /** Runs the tool with the given params and resolves to its result. */
  execute(params: Record<string, unknown>): Promise<ToolResult>;
}
|