Refactor in preparation for Reauth (#1196)

This commit is contained in:
Tommaso Sciortino 2025-06-18 16:34:00 -07:00 committed by GitHub
parent b96fbd913e
commit 8bc3b415c9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 128 additions and 98 deletions

View File

@ -47,6 +47,10 @@ export async function main() {
const extensions = loadExtensions(workspaceRoot); const extensions = loadExtensions(workspaceRoot);
const config = await loadCliConfig(settings.merged, extensions, sessionId); const config = await loadCliConfig(settings.merged, extensions, sessionId);
// When using Code Assist this triggers the Oauth login.
// Do this now, before sandboxing, so web redirect works.
await config.getGeminiClient().initialize();
// Initialize centralized FileDiscoveryService // Initialize centralized FileDiscoveryService
config.getFileService(); config.getFileService();
if (config.getCheckpointEnabled()) { if (config.getCheckpointEnabled()) {
@ -65,10 +69,6 @@ export async function main() {
} }
} }
// When using Code Assist this triggers the Oauth login.
// Do this now, before sandboxing, so web redirect works.
await config.getGeminiClient().getChat();
// hop into sandbox if we are outside and sandboxing is enabled // hop into sandbox if we are outside and sandboxing is enabled
if (!process.env.SANDBOX) { if (!process.env.SANDBOX) {
const sandboxConfig = config.getSandbox(); const sandboxConfig = config.getSandbox();

View File

@ -73,8 +73,10 @@ describe('oauth2', () => {
(resolve) => (serverListeningCallback = resolve), (resolve) => (serverListeningCallback = resolve),
); );
let capturedPort = 0;
const mockHttpServer = { const mockHttpServer = {
listen: vi.fn((port: number, callback?: () => void) => { listen: vi.fn((port: number, callback?: () => void) => {
capturedPort = port;
if (callback) { if (callback) {
callback(); callback();
} }
@ -86,7 +88,7 @@ describe('oauth2', () => {
} }
}), }),
on: vi.fn(), on: vi.fn(),
address: () => ({ port: 1234 }), address: () => ({ port: capturedPort }),
}; };
vi.mocked(http.createServer).mockImplementation((cb) => { vi.mocked(http.createServer).mockImplementation((cb) => {
requestCallback = cb as http.RequestListener< requestCallback = cb as http.RequestListener<
@ -115,7 +117,10 @@ describe('oauth2', () => {
expect(client).toBe(mockOAuth2Client); expect(client).toBe(mockOAuth2Client);
expect(open).toHaveBeenCalledWith(mockAuthUrl); expect(open).toHaveBeenCalledWith(mockAuthUrl);
expect(mockGetToken).toHaveBeenCalledWith(mockCode); expect(mockGetToken).toHaveBeenCalledWith({
code: mockCode,
redirect_uri: `http://localhost:${capturedPort}/oauth2callback`,
});
expect(mockSetCredentials).toHaveBeenCalledWith(mockTokens); expect(mockSetCredentials).toHaveBeenCalledWith(mockTokens);
const tokenPath = path.join(tempHomeDir, '.gemini', 'oauth_creds.json'); const tokenPath = path.join(tempHomeDir, '.gemini', 'oauth_creds.json');

View File

@ -42,39 +42,54 @@ const SIGN_IN_FAILURE_URL =
const GEMINI_DIR = '.gemini'; const GEMINI_DIR = '.gemini';
const CREDENTIAL_FILENAME = 'oauth_creds.json'; const CREDENTIAL_FILENAME = 'oauth_creds.json';
export async function getOauthClient(): Promise<OAuth2Client> { /**
try { * An Authentication URL for updating the credentials of a Oauth2Client
return await getCachedCredentialClient(); * as well as a promise that will resolve when the credentials have
} catch (_) { * been refreshed (or which throws error when refreshing credentials failed).
const loggedInClient = await webLoginClient(); */
await setCachedCredentials(loggedInClient.credentials); export interface OauthWebLogin {
return loggedInClient; authUrl: string;
} loginCompletePromise: Promise<void>;
} }
async function webLoginClient(): Promise<OAuth2Client> { export async function getOauthClient(): Promise<OAuth2Client> {
const port = await getAvailablePort(); const client = new OAuth2Client({
const oAuth2Client = new OAuth2Client({
clientId: OAUTH_CLIENT_ID, clientId: OAUTH_CLIENT_ID,
clientSecret: OAUTH_CLIENT_SECRET, clientSecret: OAUTH_CLIENT_SECRET,
redirectUri: `http://localhost:${port}/oauth2callback`,
}); });
return new Promise((resolve, reject) => { if (await loadCachedCredentials(client)) {
const state = crypto.randomBytes(32).toString('hex'); // Found valid cached credentials.
const authURL: string = oAuth2Client.generateAuthUrl({ return client;
access_type: 'offline', }
scope: OAUTH_SCOPE,
state,
});
console.log(
`\n\nCode Assist login required.\n` +
`Attempting to open authentication page in your browser.\n` +
`Otherwise navigate to:\n\n${authURL}\n\n`,
);
open(authURL);
console.log('Waiting for authentication...');
const webLogin = await authWithWeb(client);
console.log(
`\n\nCode Assist login required.\n` +
`Attempting to open authentication page in your browser.\n` +
`Otherwise navigate to:\n\n${webLogin.authUrl}\n\n`,
);
await open(webLogin.authUrl);
console.log('Waiting for authentication...');
await webLogin.loginCompletePromise;
return client;
}
async function authWithWeb(client: OAuth2Client): Promise<OauthWebLogin> {
const port = await getAvailablePort();
const redirectUri = `http://localhost:${port}/oauth2callback`;
const state = crypto.randomBytes(32).toString('hex');
const authUrl: string = client.generateAuthUrl({
redirect_uri: redirectUri,
access_type: 'offline',
scope: OAUTH_SCOPE,
state,
});
const loginCompletePromise = new Promise<void>((resolve, reject) => {
const server = http.createServer(async (req, res) => { const server = http.createServer(async (req, res) => {
try { try {
if (req.url!.indexOf('/oauth2callback') === -1) { if (req.url!.indexOf('/oauth2callback') === -1) {
@ -94,13 +109,16 @@ async function webLoginClient(): Promise<OAuth2Client> {
reject(new Error('State mismatch. Possible CSRF attack')); reject(new Error('State mismatch. Possible CSRF attack'));
} else if (qs.get('code')) { } else if (qs.get('code')) {
const code: string = qs.get('code')!; const { tokens } = await client.getToken({
const { tokens } = await oAuth2Client.getToken(code); code: qs.get('code')!,
oAuth2Client.setCredentials(tokens); redirect_uri: redirectUri,
});
client.setCredentials(tokens);
await cacheCredentials(client.credentials);
res.writeHead(HTTP_REDIRECT, { Location: SIGN_IN_SUCCESS_URL }); res.writeHead(HTTP_REDIRECT, { Location: SIGN_IN_SUCCESS_URL });
res.end(); res.end();
resolve(oAuth2Client); resolve();
} else { } else {
reject(new Error('No code found in request')); reject(new Error('No code found in request'));
} }
@ -112,9 +130,14 @@ async function webLoginClient(): Promise<OAuth2Client> {
}); });
server.listen(port); server.listen(port);
}); });
return {
authUrl,
loginCompletePromise,
};
} }
function getAvailablePort(): Promise<number> { export function getAvailablePort(): Promise<number> {
return new Promise((resolve, reject) => { return new Promise((resolve, reject) => {
let port = 0; let port = 0;
try { try {
@ -135,25 +158,20 @@ function getAvailablePort(): Promise<number> {
}); });
} }
async function getCachedCredentialClient(): Promise<OAuth2Client> { async function loadCachedCredentials(client: OAuth2Client): Promise<boolean> {
try { try {
const creds = await fs.readFile(getCachedCredentialPath(), 'utf-8'); const creds = await fs.readFile(getCachedCredentialPath(), 'utf-8');
const oAuth2Client = new OAuth2Client({ client.setCredentials(JSON.parse(creds));
clientId: OAUTH_CLIENT_ID,
clientSecret: OAUTH_CLIENT_SECRET,
});
oAuth2Client.setCredentials(JSON.parse(creds));
// This will either return the existing token or refresh it. // This will either return the existing token or refresh it.
await oAuth2Client.getAccessToken(); await client.getAccessToken();
// If we are here, the token is valid.
return oAuth2Client; return true;
} catch (_) { } catch (_) {
// Could not load credentials. return false;
throw new Error('Could not load credentials');
} }
} }
async function setCachedCredentials(credentials: Credentials) { async function cacheCredentials(credentials: Credentials) {
const filePath = getCachedCredentialPath(); const filePath = getCachedCredentialPath();
await fs.mkdir(path.dirname(filePath), { recursive: true }); await fs.mkdir(path.dirname(filePath), { recursive: true });

View File

@ -65,7 +65,7 @@ vi.mock('../telemetry/index.js', () => ({
describe('Gemini Client (client.ts)', () => { describe('Gemini Client (client.ts)', () => {
let client: GeminiClient; let client: GeminiClient;
beforeEach(() => { beforeEach(async () => {
vi.resetAllMocks(); vi.resetAllMocks();
// Set up the mock for GoogleGenAI constructor and its methods // Set up the mock for GoogleGenAI constructor and its methods
@ -131,6 +131,7 @@ describe('Gemini Client (client.ts)', () => {
// eslint-disable-next-line @typescript-eslint/no-explicit-any // eslint-disable-next-line @typescript-eslint/no-explicit-any
const mockConfig = new Config({} as any); const mockConfig = new Config({} as any);
client = new GeminiClient(mockConfig); client = new GeminiClient(mockConfig);
await client.initialize();
}); });
afterEach(() => { afterEach(() => {
@ -262,9 +263,7 @@ describe('Gemini Client (client.ts)', () => {
countTokens: vi.fn().mockResolvedValue({ totalTokens: 1 }), countTokens: vi.fn().mockResolvedValue({ totalTokens: 1 }),
generateContent: mockGenerateContentFn, generateContent: mockGenerateContentFn,
}; };
client['contentGenerator'] = Promise.resolve( client['contentGenerator'] = mockGenerator as ContentGenerator;
mockGenerator as ContentGenerator,
);
await client.generateContent(contents, generationConfig, abortSignal); await client.generateContent(contents, generationConfig, abortSignal);
@ -292,9 +291,7 @@ describe('Gemini Client (client.ts)', () => {
countTokens: vi.fn().mockResolvedValue({ totalTokens: 1 }), countTokens: vi.fn().mockResolvedValue({ totalTokens: 1 }),
generateContent: mockGenerateContentFn, generateContent: mockGenerateContentFn,
}; };
client['contentGenerator'] = Promise.resolve( client['contentGenerator'] = mockGenerator as ContentGenerator;
mockGenerator as ContentGenerator,
);
await client.generateJson(contents, schema, abortSignal); await client.generateJson(contents, schema, abortSignal);
@ -319,7 +316,7 @@ describe('Gemini Client (client.ts)', () => {
addHistory: vi.fn(), addHistory: vi.fn(),
}; };
// eslint-disable-next-line @typescript-eslint/no-explicit-any // eslint-disable-next-line @typescript-eslint/no-explicit-any
client['chat'] = Promise.resolve(mockChat as any); client['chat'] = mockChat as any;
const newContent = { const newContent = {
role: 'user', role: 'user',
@ -371,14 +368,12 @@ describe('Gemini Client (client.ts)', () => {
addHistory: vi.fn(), addHistory: vi.fn(),
getHistory: vi.fn().mockReturnValue([]), getHistory: vi.fn().mockReturnValue([]),
}; };
client['chat'] = Promise.resolve(mockChat as GeminiChat); client['chat'] = mockChat as GeminiChat;
const mockGenerator: Partial<ContentGenerator> = { const mockGenerator: Partial<ContentGenerator> = {
countTokens: vi.fn().mockResolvedValue({ totalTokens: 0 }), countTokens: vi.fn().mockResolvedValue({ totalTokens: 0 }),
}; };
client['contentGenerator'] = Promise.resolve( client['contentGenerator'] = mockGenerator as ContentGenerator;
mockGenerator as ContentGenerator,
);
// Act // Act
const stream = client.sendMessageStream( const stream = client.sendMessageStream(

View File

@ -44,8 +44,8 @@ function isThinkingSupported(model: string) {
} }
export class GeminiClient { export class GeminiClient {
private chat: Promise<GeminiChat>; private chat?: GeminiChat;
private contentGenerator: Promise<ContentGenerator>; private contentGenerator?: ContentGenerator;
private model: string; private model: string;
private embeddingModel: string; private embeddingModel: string;
private generateContentConfig: GenerateContentConfig = { private generateContentConfig: GenerateContentConfig = {
@ -59,35 +59,45 @@ export class GeminiClient {
setGlobalDispatcher(new ProxyAgent(config.getProxy() as string)); setGlobalDispatcher(new ProxyAgent(config.getProxy() as string));
} }
this.contentGenerator = createContentGenerator(
this.config.getContentGeneratorConfig(),
);
this.model = config.getModel(); this.model = config.getModel();
this.embeddingModel = config.getEmbeddingModel(); this.embeddingModel = config.getEmbeddingModel();
this.chat = this.startChat(); }
async initialize() {
this.contentGenerator = await createContentGenerator(
this.config.getContentGeneratorConfig(),
);
this.chat = await this.startChat();
} }
async addHistory(content: Content) { async addHistory(content: Content) {
const chat = await this.chat; this.getChat().addHistory(content);
chat.addHistory(content);
} }
getChat(): Promise<GeminiChat> { getChat(): GeminiChat {
if (!this.chat) {
throw new Error('Chat not initialized');
}
return this.chat; return this.chat;
} }
private getContentGenerator(): ContentGenerator {
if (!this.contentGenerator) {
throw new Error('Content generator not initialized');
}
return this.contentGenerator;
}
async getHistory(): Promise<Content[]> { async getHistory(): Promise<Content[]> {
const chat = await this.chat; return this.getChat().getHistory();
return chat.getHistory();
} }
async setHistory(history: Content[]): Promise<void> { async setHistory(history: Content[]): Promise<void> {
const chat = await this.chat; this.getChat().setHistory(history);
chat.setHistory(history);
} }
async resetChat(): Promise<void> { async resetChat(): Promise<void> {
this.chat = this.startChat(); this.chat = await this.startChat();
await this.chat; await this.chat;
} }
@ -184,7 +194,7 @@ export class GeminiClient {
: this.generateContentConfig; : this.generateContentConfig;
return new GeminiChat( return new GeminiChat(
this.config, this.config,
await this.contentGenerator, this.getContentGenerator(),
this.model, this.model,
{ {
systemInstruction, systemInstruction,
@ -210,22 +220,24 @@ export class GeminiClient {
turns: number = this.MAX_TURNS, turns: number = this.MAX_TURNS,
): AsyncGenerator<ServerGeminiStreamEvent, Turn> { ): AsyncGenerator<ServerGeminiStreamEvent, Turn> {
if (!turns) { if (!turns) {
const chat = await this.chat; return new Turn(this.getChat());
return new Turn(chat);
} }
const compressed = await this.tryCompressChat(); const compressed = await this.tryCompressChat();
if (compressed) { if (compressed) {
yield { type: GeminiEventType.ChatCompressed, value: compressed }; yield { type: GeminiEventType.ChatCompressed, value: compressed };
} }
const chat = await this.chat; const turn = new Turn(this.getChat());
const turn = new Turn(chat);
const resultStream = turn.run(request, signal); const resultStream = turn.run(request, signal);
for await (const event of resultStream) { for await (const event of resultStream) {
yield event; yield event;
} }
if (!turn.pendingToolCalls.length && signal && !signal.aborted) { if (!turn.pendingToolCalls.length && signal && !signal.aborted) {
const nextSpeakerCheck = await checkNextSpeaker(chat, this, signal); const nextSpeakerCheck = await checkNextSpeaker(
this.getChat(),
this,
signal,
);
if (nextSpeakerCheck?.next_speaker === 'model') { if (nextSpeakerCheck?.next_speaker === 'model') {
const nextRequest = [{ text: 'Please continue.' }]; const nextRequest = [{ text: 'Please continue.' }];
// This recursive call's events will be yielded out, but the final // This recursive call's events will be yielded out, but the final
@ -243,7 +255,6 @@ export class GeminiClient {
model: string = DEFAULT_GEMINI_FLASH_MODEL, model: string = DEFAULT_GEMINI_FLASH_MODEL,
config: GenerateContentConfig = {}, config: GenerateContentConfig = {},
): Promise<Record<string, unknown>> { ): Promise<Record<string, unknown>> {
const cg = await this.contentGenerator;
try { try {
const userMemory = this.config.getUserMemory(); const userMemory = this.config.getUserMemory();
const systemInstruction = getCoreSystemPrompt(userMemory); const systemInstruction = getCoreSystemPrompt(userMemory);
@ -254,7 +265,7 @@ export class GeminiClient {
}; };
const apiCall = () => const apiCall = () =>
cg.generateContent({ this.getContentGenerator().generateContent({
model, model,
config: { config: {
...requestConfig, ...requestConfig,
@ -327,7 +338,6 @@ export class GeminiClient {
generationConfig: GenerateContentConfig, generationConfig: GenerateContentConfig,
abortSignal: AbortSignal, abortSignal: AbortSignal,
): Promise<GenerateContentResponse> { ): Promise<GenerateContentResponse> {
const cg = await this.contentGenerator;
const modelToUse = this.model; const modelToUse = this.model;
const configToUse: GenerateContentConfig = { const configToUse: GenerateContentConfig = {
...this.generateContentConfig, ...this.generateContentConfig,
@ -345,7 +355,7 @@ export class GeminiClient {
}; };
const apiCall = () => const apiCall = () =>
cg.generateContent({ this.getContentGenerator().generateContent({
model: modelToUse, model: modelToUse,
config: requestConfig, config: requestConfig,
contents, contents,
@ -386,8 +396,8 @@ export class GeminiClient {
contents: texts, contents: texts,
}; };
const cg = await this.contentGenerator; const embedContentResponse =
const embedContentResponse = await cg.embedContent(embedModelParams); await this.getContentGenerator().embedContent(embedModelParams);
if ( if (
!embedContentResponse.embeddings || !embedContentResponse.embeddings ||
embedContentResponse.embeddings.length === 0 embedContentResponse.embeddings.length === 0
@ -415,19 +425,18 @@ export class GeminiClient {
async tryCompressChat( async tryCompressChat(
force: boolean = false, force: boolean = false,
): Promise<ChatCompressionInfo | null> { ): Promise<ChatCompressionInfo | null> {
const chat = await this.chat; const history = this.getChat().getHistory(true); // Get curated history
const history = chat.getHistory(true); // Get curated history
// Regardless of `force`, don't do anything if the history is empty. // Regardless of `force`, don't do anything if the history is empty.
if (history.length === 0) { if (history.length === 0) {
return null; return null;
} }
const cg = await this.contentGenerator; const { totalTokens: originalTokenCount } =
const { totalTokens: originalTokenCount } = await cg.countTokens({ await this.getContentGenerator().countTokens({
model: this.model, model: this.model,
contents: history, contents: history,
}); });
// If not forced, check if we should compress based on context size. // If not forced, check if we should compress based on context size.
if (!force) { if (!force) {
@ -457,7 +466,7 @@ export class GeminiClient {
const summarizationRequestMessage = { const summarizationRequestMessage = {
text: 'Summarize our conversation up to this point. The summary should be a concise yet comprehensive overview of all key topics, questions, answers, and important details discussed. This summary will replace the current chat history to conserve tokens, so it must capture everything essential to understand the context and continue our conversation effectively as if no information was lost.', text: 'Summarize our conversation up to this point. The summary should be a concise yet comprehensive overview of all key topics, questions, answers, and important details discussed. This summary will replace the current chat history to conserve tokens, so it must capture everything essential to understand the context and continue our conversation effectively as if no information was lost.',
}; };
const response = await chat.sendMessage({ const response = await this.getChat().sendMessage({
message: summarizationRequestMessage, message: summarizationRequestMessage,
}); });
const newHistory = [ const newHistory = [
@ -470,9 +479,12 @@ export class GeminiClient {
parts: [{ text: response.text }], parts: [{ text: response.text }],
}, },
]; ];
this.chat = this.startChat(newHistory); this.chat = await this.startChat(newHistory);
const newTokenCount = ( const newTokenCount = (
await cg.countTokens({ model: this.model, contents: newHistory }) await this.getContentGenerator().countTokens({
model: this.model,
contents: newHistory,
})
).totalTokens; ).totalTokens;
return originalTokenCount && newTokenCount return originalTokenCount && newTokenCount

View File

@ -49,7 +49,7 @@ export async function createContentGenerator(
}, },
}; };
if (config.codeAssist) { if (config.codeAssist) {
return createCodeAssistContentGenerator(httpOptions); return await createCodeAssistContentGenerator(httpOptions);
} }
const googleGenAI = new GoogleGenAI({ const googleGenAI = new GoogleGenAI({
apiKey: config.apiKey === '' ? undefined : config.apiKey, apiKey: config.apiKey === '' ? undefined : config.apiKey,