refactor: derive streaming state from tool calls and isresponding state (#376)

This commit is contained in:
Brandon Keiji 2025-05-16 16:45:58 +00:00 committed by GitHub
parent 609757f911
commit 458fd86429
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 62 additions and 58 deletions

View File

@ -8,7 +8,6 @@ import { exec as _exec } from 'child_process';
import { useCallback } from 'react'; import { useCallback } from 'react';
import { Config } from '@gemini-code/server'; import { Config } from '@gemini-code/server';
import { type PartListUnion } from '@google/genai'; import { type PartListUnion } from '@google/genai';
import { StreamingState } from '../types.js';
import { getCommandFromQuery } from '../utils/commandUtils.js'; import { getCommandFromQuery } from '../utils/commandUtils.js';
import { UseHistoryManagerReturn } from './useHistoryManager.js'; import { UseHistoryManagerReturn } from './useHistoryManager.js';
@ -18,7 +17,7 @@ import { UseHistoryManagerReturn } from './useHistoryManager.js';
*/ */
export const useShellCommandProcessor = ( export const useShellCommandProcessor = (
addItemToHistory: UseHistoryManagerReturn['addItem'], addItemToHistory: UseHistoryManagerReturn['addItem'],
setStreamingState: React.Dispatch<React.SetStateAction<StreamingState>>, onExec: (command: Promise<void>) => void,
onDebugMessage: (message: string) => void, onDebugMessage: (message: string) => void,
config: Config, config: Config,
) => { ) => {
@ -57,8 +56,7 @@ export const useShellCommandProcessor = (
cwd: targetDir, cwd: targetDir,
}; };
setStreamingState(StreamingState.Responding); const execPromise = new Promise<void>((resolve) => {
_exec(commandToExecute, execOptions, (error, stdout, stderr) => { _exec(commandToExecute, execOptions, (error, stdout, stderr) => {
if (error) { if (error) {
addItemToHistory( addItemToHistory(
@ -75,12 +73,19 @@ export const useShellCommandProcessor = (
userMessageTimestamp, userMessageTimestamp,
); );
} }
setStreamingState(StreamingState.Idle); resolve();
}); });
});
try {
onExec(execPromise);
} catch (_e) {
// silently ignore errors from this since it's from the caller
}
return true; // Command was initiated return true; // Command was initiated
}, },
[config, onDebugMessage, addItemToHistory, setStreamingState], [config, onDebugMessage, addItemToHistory, onExec],
); );
return { handleShellCommand }; return { handleShellCommand };

View File

@ -62,19 +62,22 @@ export const useGeminiStream = (
handleSlashCommand: (cmd: PartListUnion) => boolean, handleSlashCommand: (cmd: PartListUnion) => boolean,
) => { ) => {
const toolRegistry = config.getToolRegistry(); const toolRegistry = config.getToolRegistry();
const [streamingState, setStreamingState] = useState<StreamingState>(
StreamingState.Idle,
);
const [initError, setInitError] = useState<string | null>(null); const [initError, setInitError] = useState<string | null>(null);
const abortControllerRef = useRef<AbortController | null>(null); const abortControllerRef = useRef<AbortController | null>(null);
const chatSessionRef = useRef<Chat | null>(null); const chatSessionRef = useRef<Chat | null>(null);
const geminiClientRef = useRef<GeminiClient | null>(null); const geminiClientRef = useRef<GeminiClient | null>(null);
const [isResponding, setIsResponding] = useState<boolean>(false);
const [pendingHistoryItemRef, setPendingHistoryItem] = const [pendingHistoryItemRef, setPendingHistoryItem] =
useStateAndRef<HistoryItemWithoutId | null>(null); useStateAndRef<HistoryItemWithoutId | null>(null);
const onExec = useCallback(async (done: Promise<void>) => {
setIsResponding(true);
await done;
setIsResponding(false);
}, []);
const { handleShellCommand } = useShellCommandProcessor( const { handleShellCommand } = useShellCommandProcessor(
addItem, addItem,
setStreamingState, onExec,
onDebugMessage, onDebugMessage,
config, config,
); );
@ -176,7 +179,6 @@ export const useGeminiStream = (
const errorMsg = `Failed to start chat: ${getErrorMessage(err)}`; const errorMsg = `Failed to start chat: ${getErrorMessage(err)}`;
setInitError(errorMsg); setInitError(errorMsg);
addItem({ type: MessageType.ERROR, text: errorMsg }, Date.now()); addItem({ type: MessageType.ERROR, text: errorMsg }, Date.now());
setStreamingState(StreamingState.Idle);
return { client: currentClient, chat: null }; return { client: currentClient, chat: null };
} }
} }
@ -192,19 +194,17 @@ export const useGeminiStream = (
item?.type === 'tool_group' item?.type === 'tool_group'
? { ? {
...item, ...item,
tools: item.tools.map((tool) => { tools: item.tools.map((tool) =>
if (tool.callId === toolResponse.callId) { tool.callId === toolResponse.callId
return { ? {
...tool, ...tool,
status, status,
resultDisplay: toolResponse.resultDisplay, resultDisplay: toolResponse.resultDisplay,
};
} else {
return tool;
} }
}), : tool,
),
} }
: null, : item,
); );
}; };
@ -212,7 +212,6 @@ export const useGeminiStream = (
callId: string, callId: string,
confirmationDetails: ToolCallConfirmationDetails | undefined, confirmationDetails: ToolCallConfirmationDetails | undefined,
) => { ) => {
if (pendingHistoryItemRef.current?.type !== 'tool_group') return;
setPendingHistoryItem((item) => setPendingHistoryItem((item) =>
item?.type === 'tool_group' item?.type === 'tool_group'
? { ? {
@ -227,11 +226,10 @@ export const useGeminiStream = (
: tool, : tool,
), ),
} }
: null, : item,
); );
}; };
// This function will be fully refactored in a later step
const wireConfirmationSubmission = ( const wireConfirmationSubmission = (
confirmationDetails: ServerToolCallConfirmationDetails, confirmationDetails: ServerToolCallConfirmationDetails,
): ToolCallConfirmationDetails => { ): ToolCallConfirmationDetails => {
@ -306,7 +304,7 @@ export const useGeminiStream = (
addItem(pendingHistoryItemRef.current, Date.now()); addItem(pendingHistoryItemRef.current, Date.now());
setPendingHistoryItem(null); setPendingHistoryItem(null);
} }
setStreamingState(StreamingState.Idle); setIsResponding(false);
await submitQuery(functionResponse); // Recursive call await submitQuery(functionResponse); // Recursive call
} finally { } finally {
if (streamingState !== StreamingState.WaitingForConfirmation) { if (streamingState !== StreamingState.WaitingForConfirmation) {
@ -354,7 +352,7 @@ export const useGeminiStream = (
addItem(pendingHistoryItemRef.current, Date.now()); addItem(pendingHistoryItemRef.current, Date.now());
setPendingHistoryItem(null); setPendingHistoryItem(null);
} }
setStreamingState(StreamingState.Idle); setIsResponding(false);
} }
return { ...originalConfirmationDetails, onConfirm: resubmittingConfirm }; return { ...originalConfirmationDetails, onConfirm: resubmittingConfirm };
@ -466,7 +464,6 @@ export const useGeminiStream = (
eventValue.request.callId, eventValue.request.callId,
confirmationDetails, confirmationDetails,
); );
setStreamingState(StreamingState.WaitingForConfirmation);
}; };
const handleUserCancelledEvent = (userMessageTimestamp: number) => { const handleUserCancelledEvent = (userMessageTimestamp: number) => {
@ -493,7 +490,7 @@ export const useGeminiStream = (
{ type: MessageType.INFO, text: 'User cancelled the request.' }, { type: MessageType.INFO, text: 'User cancelled the request.' },
userMessageTimestamp, userMessageTimestamp,
); );
setStreamingState(StreamingState.Idle); setIsResponding(false);
}; };
const handleErrorEvent = ( const handleErrorEvent = (
@ -529,7 +526,7 @@ export const useGeminiStream = (
handleToolCallResponseEvent(event.value); handleToolCallResponseEvent(event.value);
} else if (event.type === ServerGeminiEventType.ToolCallConfirmation) { } else if (event.type === ServerGeminiEventType.ToolCallConfirmation) {
handleToolCallConfirmationEvent(event.value); handleToolCallConfirmationEvent(event.value);
return StreamProcessingStatus.PausedForConfirmation; // Explicit return as this pauses the stream return StreamProcessingStatus.PausedForConfirmation;
} else if (event.type === ServerGeminiEventType.UserCancelled) { } else if (event.type === ServerGeminiEventType.UserCancelled) {
handleUserCancelledEvent(userMessageTimestamp); handleUserCancelledEvent(userMessageTimestamp);
return StreamProcessingStatus.UserCancelled; return StreamProcessingStatus.UserCancelled;
@ -543,7 +540,7 @@ export const useGeminiStream = (
const submitQuery = useCallback( const submitQuery = useCallback(
async (query: PartListUnion) => { async (query: PartListUnion) => {
if (streamingState === StreamingState.Responding) return; if (isResponding) return;
const userMessageTimestamp = Date.now(); const userMessageTimestamp = Date.now();
setShowHelp(false); setShowHelp(false);
@ -567,7 +564,7 @@ export const useGeminiStream = (
return; return;
} }
setStreamingState(StreamingState.Responding); setIsResponding(true);
setInitError(null); setInitError(null);
try { try {
@ -588,13 +585,6 @@ export const useGeminiStream = (
addItem(pendingHistoryItemRef.current, userMessageTimestamp); addItem(pendingHistoryItemRef.current, userMessageTimestamp);
setPendingHistoryItem(null); setPendingHistoryItem(null);
} }
if (
processingStatus === StreamProcessingStatus.Completed ||
processingStatus === StreamProcessingStatus.Error
) {
setStreamingState(StreamingState.Idle);
}
} catch (error: unknown) { } catch (error: unknown) {
if (!isNodeError(error) || error.name !== 'AbortError') { if (!isNodeError(error) || error.name !== 'AbortError') {
addItem( addItem(
@ -605,16 +595,16 @@ export const useGeminiStream = (
userMessageTimestamp, userMessageTimestamp,
); );
} }
setStreamingState(StreamingState.Idle);
} finally { } finally {
if (streamingState !== StreamingState.WaitingForConfirmation) { if (streamingState !== StreamingState.WaitingForConfirmation) {
abortControllerRef.current = null; abortControllerRef.current = null;
} }
setIsResponding(false);
} }
}, },
// eslint-disable-next-line react-hooks/exhaustive-deps // eslint-disable-next-line react-hooks/exhaustive-deps
[ [
streamingState, isResponding,
setShowHelp, setShowHelp,
handleSlashCommand, handleSlashCommand,
handleShellCommand, handleShellCommand,
@ -623,10 +613,15 @@ export const useGeminiStream = (
onDebugMessage, onDebugMessage,
refreshStatic, refreshStatic,
setInitError, setInitError,
setStreamingState,
], ],
); );
const streamingState: StreamingState = isResponding
? StreamingState.Responding
: pendingConfirmations(pendingHistoryItemRef.current)
? StreamingState.WaitingForConfirmation
: StreamingState.Idle;
return { return {
streamingState, streamingState,
submitQuery, submitQuery,
@ -634,3 +629,7 @@ export const useGeminiStream = (
pendingHistoryItem: pendingHistoryItemRef.current, pendingHistoryItem: pendingHistoryItemRef.current,
}; };
}; };
const pendingConfirmations = (item: HistoryItemWithoutId | null): boolean =>
item?.type === 'tool_group' &&
item.tools.some((t) => t.status === ToolCallStatus.Confirming);