refactor: derive streaming state from tool calls and isresponding state (#376)

This commit is contained in:
Brandon Keiji 2025-05-16 16:45:58 +00:00 committed by GitHub
parent 609757f911
commit 458fd86429
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 62 additions and 58 deletions

View File

@ -8,7 +8,6 @@ import { exec as _exec } from 'child_process';
import { useCallback } from 'react';
import { Config } from '@gemini-code/server';
import { type PartListUnion } from '@google/genai';
import { StreamingState } from '../types.js';
import { getCommandFromQuery } from '../utils/commandUtils.js';
import { UseHistoryManagerReturn } from './useHistoryManager.js';
@ -18,7 +17,7 @@ import { UseHistoryManagerReturn } from './useHistoryManager.js';
*/
export const useShellCommandProcessor = (
addItemToHistory: UseHistoryManagerReturn['addItem'],
setStreamingState: React.Dispatch<React.SetStateAction<StreamingState>>,
onExec: (command: Promise<void>) => void,
onDebugMessage: (message: string) => void,
config: Config,
) => {
@ -57,30 +56,36 @@ export const useShellCommandProcessor = (
cwd: targetDir,
};
setStreamingState(StreamingState.Responding);
const execPromise = new Promise<void>((resolve) => {
_exec(commandToExecute, execOptions, (error, stdout, stderr) => {
if (error) {
addItemToHistory(
{ type: 'error', text: error.message },
userMessageTimestamp,
);
} else {
let output = '';
if (stdout) output += stdout;
if (stderr) output += (output ? '\n' : '') + stderr; // Include stderr as info
_exec(commandToExecute, execOptions, (error, stdout, stderr) => {
if (error) {
addItemToHistory(
{ type: 'error', text: error.message },
userMessageTimestamp,
);
} else {
let output = '';
if (stdout) output += stdout;
if (stderr) output += (output ? '\n' : '') + stderr; // Include stderr as info
addItemToHistory(
{ type: 'info', text: output || '(Command produced no output)' },
userMessageTimestamp,
);
}
setStreamingState(StreamingState.Idle);
addItemToHistory(
{ type: 'info', text: output || '(Command produced no output)' },
userMessageTimestamp,
);
}
resolve();
});
});
try {
onExec(execPromise);
} catch (_e) {
// silently ignore errors from this since it's from the caller
}
return true; // Command was initiated
},
[config, onDebugMessage, addItemToHistory, setStreamingState],
[config, onDebugMessage, addItemToHistory, onExec],
);
return { handleShellCommand };

View File

@ -62,19 +62,22 @@ export const useGeminiStream = (
handleSlashCommand: (cmd: PartListUnion) => boolean,
) => {
const toolRegistry = config.getToolRegistry();
const [streamingState, setStreamingState] = useState<StreamingState>(
StreamingState.Idle,
);
const [initError, setInitError] = useState<string | null>(null);
const abortControllerRef = useRef<AbortController | null>(null);
const chatSessionRef = useRef<Chat | null>(null);
const geminiClientRef = useRef<GeminiClient | null>(null);
const [isResponding, setIsResponding] = useState<boolean>(false);
const [pendingHistoryItemRef, setPendingHistoryItem] =
useStateAndRef<HistoryItemWithoutId | null>(null);
const onExec = useCallback(async (done: Promise<void>) => {
setIsResponding(true);
await done;
setIsResponding(false);
}, []);
const { handleShellCommand } = useShellCommandProcessor(
addItem,
setStreamingState,
onExec,
onDebugMessage,
config,
);
@ -176,7 +179,6 @@ export const useGeminiStream = (
const errorMsg = `Failed to start chat: ${getErrorMessage(err)}`;
setInitError(errorMsg);
addItem({ type: MessageType.ERROR, text: errorMsg }, Date.now());
setStreamingState(StreamingState.Idle);
return { client: currentClient, chat: null };
}
}
@ -192,19 +194,17 @@ export const useGeminiStream = (
item?.type === 'tool_group'
? {
...item,
tools: item.tools.map((tool) => {
if (tool.callId === toolResponse.callId) {
return {
...tool,
status,
resultDisplay: toolResponse.resultDisplay,
};
} else {
return tool;
}
}),
tools: item.tools.map((tool) =>
tool.callId === toolResponse.callId
? {
...tool,
status,
resultDisplay: toolResponse.resultDisplay,
}
: tool,
),
}
: null,
: item,
);
};
@ -212,7 +212,6 @@ export const useGeminiStream = (
callId: string,
confirmationDetails: ToolCallConfirmationDetails | undefined,
) => {
if (pendingHistoryItemRef.current?.type !== 'tool_group') return;
setPendingHistoryItem((item) =>
item?.type === 'tool_group'
? {
@ -227,11 +226,10 @@ export const useGeminiStream = (
: tool,
),
}
: null,
: item,
);
};
// This function will be fully refactored in a later step
const wireConfirmationSubmission = (
confirmationDetails: ServerToolCallConfirmationDetails,
): ToolCallConfirmationDetails => {
@ -306,7 +304,7 @@ export const useGeminiStream = (
addItem(pendingHistoryItemRef.current, Date.now());
setPendingHistoryItem(null);
}
setStreamingState(StreamingState.Idle);
setIsResponding(false);
await submitQuery(functionResponse); // Recursive call
} finally {
if (streamingState !== StreamingState.WaitingForConfirmation) {
@ -354,7 +352,7 @@ export const useGeminiStream = (
addItem(pendingHistoryItemRef.current, Date.now());
setPendingHistoryItem(null);
}
setStreamingState(StreamingState.Idle);
setIsResponding(false);
}
return { ...originalConfirmationDetails, onConfirm: resubmittingConfirm };
@ -466,7 +464,6 @@ export const useGeminiStream = (
eventValue.request.callId,
confirmationDetails,
);
setStreamingState(StreamingState.WaitingForConfirmation);
};
const handleUserCancelledEvent = (userMessageTimestamp: number) => {
@ -493,7 +490,7 @@ export const useGeminiStream = (
{ type: MessageType.INFO, text: 'User cancelled the request.' },
userMessageTimestamp,
);
setStreamingState(StreamingState.Idle);
setIsResponding(false);
};
const handleErrorEvent = (
@ -529,7 +526,7 @@ export const useGeminiStream = (
handleToolCallResponseEvent(event.value);
} else if (event.type === ServerGeminiEventType.ToolCallConfirmation) {
handleToolCallConfirmationEvent(event.value);
return StreamProcessingStatus.PausedForConfirmation; // Explicit return as this pauses the stream
return StreamProcessingStatus.PausedForConfirmation;
} else if (event.type === ServerGeminiEventType.UserCancelled) {
handleUserCancelledEvent(userMessageTimestamp);
return StreamProcessingStatus.UserCancelled;
@ -543,7 +540,7 @@ export const useGeminiStream = (
const submitQuery = useCallback(
async (query: PartListUnion) => {
if (streamingState === StreamingState.Responding) return;
if (isResponding) return;
const userMessageTimestamp = Date.now();
setShowHelp(false);
@ -567,7 +564,7 @@ export const useGeminiStream = (
return;
}
setStreamingState(StreamingState.Responding);
setIsResponding(true);
setInitError(null);
try {
@ -588,13 +585,6 @@ export const useGeminiStream = (
addItem(pendingHistoryItemRef.current, userMessageTimestamp);
setPendingHistoryItem(null);
}
if (
processingStatus === StreamProcessingStatus.Completed ||
processingStatus === StreamProcessingStatus.Error
) {
setStreamingState(StreamingState.Idle);
}
} catch (error: unknown) {
if (!isNodeError(error) || error.name !== 'AbortError') {
addItem(
@ -605,16 +595,16 @@ export const useGeminiStream = (
userMessageTimestamp,
);
}
setStreamingState(StreamingState.Idle);
} finally {
if (streamingState !== StreamingState.WaitingForConfirmation) {
abortControllerRef.current = null;
}
setIsResponding(false);
}
},
// eslint-disable-next-line react-hooks/exhaustive-deps
[
streamingState,
isResponding,
setShowHelp,
handleSlashCommand,
handleShellCommand,
@ -623,10 +613,15 @@ export const useGeminiStream = (
onDebugMessage,
refreshStatic,
setInitError,
setStreamingState,
],
);
const streamingState: StreamingState = isResponding
? StreamingState.Responding
: pendingConfirmations(pendingHistoryItemRef.current)
? StreamingState.WaitingForConfirmation
: StreamingState.Idle;
return {
streamingState,
submitQuery,
@ -634,3 +629,7 @@ export const useGeminiStream = (
pendingHistoryItem: pendingHistoryItemRef.current,
};
};
const pendingConfirmations = (item: HistoryItemWithoutId | null): boolean =>
item?.type === 'tool_group' &&
item.tools.some((t) => t.status === ToolCallStatus.Confirming);