live output from shell tool (#573)
This commit is contained in:
parent
0d5f7686d7
commit
bfeaac8441
|
@ -70,21 +70,25 @@ export const useGeminiStream = (
|
||||||
const [pendingHistoryItemRef, setPendingHistoryItem] =
|
const [pendingHistoryItemRef, setPendingHistoryItem] =
|
||||||
useStateAndRef<HistoryItemWithoutId | null>(null);
|
useStateAndRef<HistoryItemWithoutId | null>(null);
|
||||||
const logger = useLogger();
|
const logger = useLogger();
|
||||||
const [toolCalls, schedule, cancel] = useToolScheduler((tools) => {
|
const [toolCalls, schedule, cancel] = useToolScheduler(
|
||||||
if (tools.length) {
|
(tools) => {
|
||||||
addItem(mapToDisplay(tools), Date.now());
|
if (tools.length) {
|
||||||
submitQuery(
|
addItem(mapToDisplay(tools), Date.now());
|
||||||
tools
|
submitQuery(
|
||||||
.filter(
|
tools
|
||||||
(t) =>
|
.filter(
|
||||||
t.status === 'error' ||
|
(t) =>
|
||||||
t.status === 'cancelled' ||
|
t.status === 'error' ||
|
||||||
t.status === 'success',
|
t.status === 'cancelled' ||
|
||||||
)
|
t.status === 'success',
|
||||||
.map((t) => t.response.responsePart),
|
)
|
||||||
);
|
.map((t) => t.response.responsePart),
|
||||||
}
|
);
|
||||||
}, config);
|
}
|
||||||
|
},
|
||||||
|
config,
|
||||||
|
setPendingHistoryItem,
|
||||||
|
);
|
||||||
const pendingToolCalls = useMemo(
|
const pendingToolCalls = useMemo(
|
||||||
() => (toolCalls.length ? mapToDisplay(toolCalls) : undefined),
|
() => (toolCalls.length ? mapToDisplay(toolCalls) : undefined),
|
||||||
[toolCalls],
|
[toolCalls],
|
||||||
|
|
|
@ -11,6 +11,7 @@ import {
|
||||||
ToolConfirmationOutcome,
|
ToolConfirmationOutcome,
|
||||||
Tool,
|
Tool,
|
||||||
ToolCallConfirmationDetails,
|
ToolCallConfirmationDetails,
|
||||||
|
ToolResult,
|
||||||
} from '@gemini-code/server';
|
} from '@gemini-code/server';
|
||||||
import { Part } from '@google/genai';
|
import { Part } from '@google/genai';
|
||||||
import { useCallback, useEffect, useState } from 'react';
|
import { useCallback, useEffect, useState } from 'react';
|
||||||
|
@ -18,6 +19,7 @@ import {
|
||||||
HistoryItemToolGroup,
|
HistoryItemToolGroup,
|
||||||
IndividualToolCallDisplay,
|
IndividualToolCallDisplay,
|
||||||
ToolCallStatus,
|
ToolCallStatus,
|
||||||
|
HistoryItemWithoutId,
|
||||||
} from '../types.js';
|
} from '../types.js';
|
||||||
|
|
||||||
type ValidatingToolCall = {
|
type ValidatingToolCall = {
|
||||||
|
@ -45,10 +47,11 @@ type SuccessfulToolCall = {
|
||||||
response: ToolCallResponseInfo;
|
response: ToolCallResponseInfo;
|
||||||
};
|
};
|
||||||
|
|
||||||
type ExecutingToolCall = {
|
export type ExecutingToolCall = {
|
||||||
status: 'executing';
|
status: 'executing';
|
||||||
request: ToolCallRequestInfo;
|
request: ToolCallRequestInfo;
|
||||||
tool: Tool;
|
tool: Tool;
|
||||||
|
liveOutput?: string;
|
||||||
};
|
};
|
||||||
|
|
||||||
type CancelledToolCall = {
|
type CancelledToolCall = {
|
||||||
|
@ -88,6 +91,9 @@ export type CompletedToolCall =
|
||||||
export function useToolScheduler(
|
export function useToolScheduler(
|
||||||
onComplete: (tools: CompletedToolCall[]) => void,
|
onComplete: (tools: CompletedToolCall[]) => void,
|
||||||
config: Config,
|
config: Config,
|
||||||
|
setPendingHistoryItem: React.Dispatch<
|
||||||
|
React.SetStateAction<HistoryItemWithoutId | null>
|
||||||
|
>,
|
||||||
): [ToolCall[], ScheduleFn, CancelFn] {
|
): [ToolCall[], ScheduleFn, CancelFn] {
|
||||||
const [toolRegistry] = useState(() => config.getToolRegistry());
|
const [toolRegistry] = useState(() => config.getToolRegistry());
|
||||||
const [toolCalls, setToolCalls] = useState<ToolCall[]>([]);
|
const [toolCalls, setToolCalls] = useState<ToolCall[]>([]);
|
||||||
|
@ -224,9 +230,48 @@ export function useToolScheduler(
|
||||||
.forEach((t) => {
|
.forEach((t) => {
|
||||||
const callId = t.request.callId;
|
const callId = t.request.callId;
|
||||||
setToolCalls(setStatus(t.request.callId, 'executing'));
|
setToolCalls(setStatus(t.request.callId, 'executing'));
|
||||||
|
|
||||||
|
let accumulatedOutput = '';
|
||||||
|
const onOutputChunk =
|
||||||
|
t.tool.name === 'execute_bash_command'
|
||||||
|
? (chunk: string) => {
|
||||||
|
accumulatedOutput += chunk;
|
||||||
|
setPendingHistoryItem(
|
||||||
|
(prevItem: HistoryItemWithoutId | null) => {
|
||||||
|
if (prevItem?.type === 'tool_group') {
|
||||||
|
return {
|
||||||
|
...prevItem,
|
||||||
|
tools: prevItem.tools.map(
|
||||||
|
(toolDisplay: IndividualToolCallDisplay) =>
|
||||||
|
toolDisplay.callId === callId &&
|
||||||
|
toolDisplay.status === ToolCallStatus.Executing
|
||||||
|
? {
|
||||||
|
...toolDisplay,
|
||||||
|
resultDisplay: accumulatedOutput,
|
||||||
|
}
|
||||||
|
: toolDisplay,
|
||||||
|
),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
return prevItem;
|
||||||
|
},
|
||||||
|
);
|
||||||
|
// Also update the toolCall itself so that mapToDisplay
|
||||||
|
// can pick up the live output if the item is not pending
|
||||||
|
// (e.g. if it's being re-rendered from history)
|
||||||
|
setToolCalls((prevToolCalls) =>
|
||||||
|
prevToolCalls.map((tc) =>
|
||||||
|
tc.request.callId === callId && tc.status === 'executing'
|
||||||
|
? { ...tc, liveOutput: accumulatedOutput }
|
||||||
|
: tc,
|
||||||
|
),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
: undefined;
|
||||||
|
|
||||||
t.tool
|
t.tool
|
||||||
.execute(t.request.args, signal)
|
.execute(t.request.args, signal, onOutputChunk)
|
||||||
.then((result) => {
|
.then((result: ToolResult) => {
|
||||||
if (signal.aborted) {
|
if (signal.aborted) {
|
||||||
setToolCalls(
|
setToolCalls(
|
||||||
setStatus(callId, 'cancelled', String(result.llmContent)),
|
setStatus(callId, 'cancelled', String(result.llmContent)),
|
||||||
|
@ -248,7 +293,7 @@ export function useToolScheduler(
|
||||||
};
|
};
|
||||||
setToolCalls(setStatus(callId, 'success', response));
|
setToolCalls(setStatus(callId, 'success', response));
|
||||||
})
|
})
|
||||||
.catch((e) =>
|
.catch((e: Error) =>
|
||||||
setToolCalls(
|
setToolCalls(
|
||||||
setStatus(
|
setStatus(
|
||||||
callId,
|
callId,
|
||||||
|
@ -262,7 +307,7 @@ export function useToolScheduler(
|
||||||
);
|
);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
}, [toolCalls, toolRegistry, abortController.signal]);
|
}, [toolCalls, toolRegistry, abortController.signal, setPendingHistoryItem]);
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
const allDone = toolCalls.every(
|
const allDone = toolCalls.every(
|
||||||
|
@ -480,7 +525,7 @@ export function mapToDisplay(
|
||||||
callId: t.request.callId,
|
callId: t.request.callId,
|
||||||
name: t.tool.displayName,
|
name: t.tool.displayName,
|
||||||
description: t.tool.getDescription(t.request.args),
|
description: t.tool.getDescription(t.request.args),
|
||||||
resultDisplay: undefined,
|
resultDisplay: t.liveOutput ?? undefined,
|
||||||
status: mapStatus(t.status),
|
status: mapStatus(t.status),
|
||||||
confirmationDetails: undefined,
|
confirmationDetails: undefined,
|
||||||
};
|
};
|
||||||
|
|
|
@ -123,6 +123,7 @@ export class ShellTool extends BaseTool<ShellToolParams, ToolResult> {
|
||||||
async execute(
|
async execute(
|
||||||
params: ShellToolParams,
|
params: ShellToolParams,
|
||||||
abortSignal: AbortSignal,
|
abortSignal: AbortSignal,
|
||||||
|
onOutputChunk?: (chunk: string) => void,
|
||||||
): Promise<ToolResult> {
|
): Promise<ToolResult> {
|
||||||
const validationError = this.validateToolParams(params);
|
const validationError = this.validateToolParams(params);
|
||||||
if (validationError) {
|
if (validationError) {
|
||||||
|
@ -157,6 +158,9 @@ export class ShellTool extends BaseTool<ShellToolParams, ToolResult> {
|
||||||
const str = data.toString();
|
const str = data.toString();
|
||||||
stdout += str;
|
stdout += str;
|
||||||
output += str;
|
output += str;
|
||||||
|
if (onOutputChunk) {
|
||||||
|
onOutputChunk(str);
|
||||||
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
let stderr = '';
|
let stderr = '';
|
||||||
|
@ -174,6 +178,9 @@ export class ShellTool extends BaseTool<ShellToolParams, ToolResult> {
|
||||||
}
|
}
|
||||||
stderr += str;
|
stderr += str;
|
||||||
output += str;
|
output += str;
|
||||||
|
if (onOutputChunk) {
|
||||||
|
onOutputChunk(str);
|
||||||
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
let error: Error | null = null;
|
let error: Error | null = null;
|
||||||
|
|
|
@ -64,7 +64,11 @@ export interface Tool<
|
||||||
* @param params Parameters for the tool execution
|
* @param params Parameters for the tool execution
|
||||||
* @returns Result of the tool execution
|
* @returns Result of the tool execution
|
||||||
*/
|
*/
|
||||||
execute(params: TParams, signal: AbortSignal): Promise<TResult>;
|
execute(
|
||||||
|
params: TParams,
|
||||||
|
signal: AbortSignal,
|
||||||
|
onOutputChunk?: (chunk: string) => void,
|
||||||
|
): Promise<TResult>;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -144,7 +148,11 @@ export abstract class BaseTool<
|
||||||
* @param signal AbortSignal for tool cancellation
|
* @param signal AbortSignal for tool cancellation
|
||||||
* @returns Result of the tool execution
|
* @returns Result of the tool execution
|
||||||
*/
|
*/
|
||||||
abstract execute(params: TParams, signal: AbortSignal): Promise<TResult>;
|
abstract execute(
|
||||||
|
params: TParams,
|
||||||
|
signal: AbortSignal,
|
||||||
|
onOutputChunk?: (chunk: string) => void,
|
||||||
|
): Promise<TResult>;
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface ToolResult {
|
export interface ToolResult {
|
||||||
|
|
Loading…
Reference in New Issue