Adds shell command allowlist (#68)

* Wire through passthrough commands

* Add default passthrough commands

* Clean up config passing to useGeminiStream
This commit is contained in:
Juliette Love 2025-04-20 21:06:22 +01:00 committed by GitHub
parent f480ef4bbc
commit a76d9b4dcf
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 64 additions and 12 deletions

View File

@ -72,6 +72,7 @@ export function loadCliConfig(): Config {
argv.model || DEFAULT_GEMINI_MODEL, argv.model || DEFAULT_GEMINI_MODEL,
argv.target_dir || process.cwd(), argv.target_dir || process.cwd(),
argv.debug_mode || false, argv.debug_mode || false,
// TODO: load passthroughCommands from .env file
); );
} }

View File

@ -32,7 +32,7 @@ export const App = ({ config }: AppProps) => {
const [history, setHistory] = useState<HistoryItem[]>([]); const [history, setHistory] = useState<HistoryItem[]>([]);
const [startupWarnings, setStartupWarnings] = useState<string[]>([]); const [startupWarnings, setStartupWarnings] = useState<string[]>([]);
const { streamingState, submitQuery, initError, debugMessage } = const { streamingState, submitQuery, initError, debugMessage } =
useGeminiStream(setHistory, config.getApiKey(), config.getModel()); useGeminiStream(setHistory, config);
const { elapsedTime, currentLoadingPhrase } = const { elapsedTime, currentLoadingPhrase } =
useLoadingIndicator(streamingState); useLoadingIndicator(streamingState);

View File

@ -14,6 +14,7 @@ import {
getErrorMessage, getErrorMessage,
isNodeError, isNodeError,
ToolResult, ToolResult,
Config,
} from '@gemini-code/server'; } from '@gemini-code/server';
import type { Chat, PartListUnion, FunctionDeclaration } from '@google/genai'; import type { Chat, PartListUnion, FunctionDeclaration } from '@google/genai';
// Import CLI types // Import CLI types
@ -27,8 +28,6 @@ import { StreamingState } from '../../core/gemini-stream.js';
// Import CLI tool registry // Import CLI tool registry
import { toolRegistry } from '../../tools/tool-registry.js'; import { toolRegistry } from '../../tools/tool-registry.js';
const _allowlistedCommands = ['ls']; // Prefix with underscore since it's unused
const addHistoryItem = ( const addHistoryItem = (
setHistory: React.Dispatch<React.SetStateAction<HistoryItem[]>>, setHistory: React.Dispatch<React.SetStateAction<HistoryItem[]>>,
itemData: Omit<HistoryItem, 'id'>, itemData: Omit<HistoryItem, 'id'>,
@ -43,8 +42,7 @@ const addHistoryItem = (
// Hook now accepts apiKey and model // Hook now accepts apiKey and model
export const useGeminiStream = ( export const useGeminiStream = (
setHistory: React.Dispatch<React.SetStateAction<HistoryItem[]>>, setHistory: React.Dispatch<React.SetStateAction<HistoryItem[]>>,
apiKey: string, config: Config,
model: string,
) => { ) => {
const [streamingState, setStreamingState] = useState<StreamingState>( const [streamingState, setStreamingState] = useState<StreamingState>(
StreamingState.Idle, StreamingState.Idle,
@ -62,15 +60,17 @@ export const useGeminiStream = (
setInitError(null); setInitError(null);
if (!geminiClientRef.current) { if (!geminiClientRef.current) {
try { try {
geminiClientRef.current = new GeminiClient(apiKey, model); geminiClientRef.current = new GeminiClient(
config.getApiKey(),
config.getModel(),
);
} catch (error: unknown) { } catch (error: unknown) {
setInitError( setInitError(
`Failed to initialize client: ${getErrorMessage(error) || 'Unknown error'}`, `Failed to initialize client: ${getErrorMessage(error) || 'Unknown error'}`,
); );
} }
} }
// Dependency array includes apiKey and model now }, [config.getApiKey(), config.getModel()]);
}, [apiKey, model]);
// Input Handling Effect (remains the same) // Input Handling Effect (remains the same)
useInput((input, key) => { useInput((input, key) => {
@ -107,6 +107,39 @@ export const useGeminiStream = (
if (typeof query === 'string') { if (typeof query === 'string') {
setDebugMessage(`User query: ${query}`); setDebugMessage(`User query: ${query}`);
const maybeCommand = query.split(/\s+/)[0];
if (config.getPassthroughCommands().includes(maybeCommand)) {
// Execute and capture output
setDebugMessage(`Executing shell command directly: ${query}`);
_exec(query, (error, stdout, stderr) => {
const timestamp = getNextMessageId(Date.now());
if (error) {
addHistoryItem(
setHistory,
{ type: 'error', text: error.message },
timestamp,
);
} else if (stderr) {
addHistoryItem(
setHistory,
{ type: 'error', text: stderr },
timestamp,
);
} else {
// Add stdout as an info message
addHistoryItem(
setHistory,
{ type: 'info', text: stdout || '' },
timestamp,
);
}
// Set state back to Idle *after* command finishes and output is added
setStreamingState(StreamingState.Idle);
});
// Set state to Responding while the command runs
setStreamingState(StreamingState.Responding);
return; // Prevent Gemini call
}
} }
const userMessageTimestamp = Date.now(); const userMessageTimestamp = Date.now();
@ -391,7 +424,8 @@ export const useGeminiStream = (
} }
} finally { } finally {
abortControllerRef.current = null; abortControllerRef.current = null;
// Only set to Idle if not waiting for confirmation // Only set to Idle if not waiting for confirmation.
// Passthrough commands handle their own Idle transition.
if (streamingState !== StreamingState.WaitingForConfirmation) { if (streamingState !== StreamingState.WaitingForConfirmation) {
setStreamingState(StreamingState.Idle); setStreamingState(StreamingState.Idle);
} }
@ -401,8 +435,8 @@ export const useGeminiStream = (
[ [
streamingState, streamingState,
setHistory, setHistory,
apiKey, config.getApiKey(),
model, config.getModel(),
getNextMessageId, getNextMessageId,
updateGeminiMessage, updateGeminiMessage,
], ],

View File

@ -9,22 +9,28 @@ import * as fs from 'node:fs';
import * as path from 'node:path'; import * as path from 'node:path';
import process from 'node:process'; import process from 'node:process';
const DEFAULT_PASSTHROUGH_COMMANDS = ['ls', 'git', 'npm'];
export class Config { export class Config {
private apiKey: string; private apiKey: string;
private model: string; private model: string;
private targetDir: string; private targetDir: string;
private debugMode: boolean; private debugMode: boolean;
private passthroughCommands: string[];
constructor( constructor(
apiKey: string, apiKey: string,
model: string, model: string,
targetDir: string, targetDir: string,
debugMode: boolean, debugMode: boolean,
passthroughCommands?: string[],
) { ) {
this.apiKey = apiKey; this.apiKey = apiKey;
this.model = model; this.model = model;
this.targetDir = targetDir; this.targetDir = targetDir;
this.debugMode = debugMode; this.debugMode = debugMode;
this.passthroughCommands =
passthroughCommands || DEFAULT_PASSTHROUGH_COMMANDS;
} }
getApiKey(): string { getApiKey(): string {
@ -42,6 +48,10 @@ export class Config {
getDebugMode(): boolean { getDebugMode(): boolean {
return this.debugMode; return this.debugMode;
} }
getPassthroughCommands(): string[] {
return this.passthroughCommands;
}
} }
function findEnvFile(startDir: string): string | null { function findEnvFile(startDir: string): string | null {
@ -72,6 +82,13 @@ export function createServerConfig(
model: string, model: string,
targetDir: string, targetDir: string,
debugMode: boolean, debugMode: boolean,
passthroughCommands?: string[],
): Config { ): Config {
return new Config(apiKey, model, path.resolve(targetDir), debugMode); return new Config(
apiKey,
model,
path.resolve(targetDir),
debugMode,
passthroughCommands,
);
} }