From 8bc3b415c973794654d64d434949a93fb3239acb Mon Sep 17 00:00:00 2001 From: Tommaso Sciortino Date: Wed, 18 Jun 2025 16:34:00 -0700 Subject: [PATCH] Refactor in preparation for Reauth (#1196) --- packages/cli/src/gemini.tsx | 8 +- packages/core/src/code_assist/oauth2.test.ts | 9 +- packages/core/src/code_assist/oauth2.ts | 104 +++++++++++-------- packages/core/src/core/client.test.ts | 19 ++-- packages/core/src/core/client.ts | 84 ++++++++------- packages/core/src/core/contentGenerator.ts | 2 +- 6 files changed, 128 insertions(+), 98 deletions(-) diff --git a/packages/cli/src/gemini.tsx b/packages/cli/src/gemini.tsx index d87a8a6a..5b5bfa67 100644 --- a/packages/cli/src/gemini.tsx +++ b/packages/cli/src/gemini.tsx @@ -47,6 +47,10 @@ export async function main() { const extensions = loadExtensions(workspaceRoot); const config = await loadCliConfig(settings.merged, extensions, sessionId); + // When using Code Assist this triggers the Oauth login. + // Do this now, before sandboxing, so web redirect works. + await config.getGeminiClient().initialize(); + // Initialize centralized FileDiscoveryService config.getFileService(); if (config.getCheckpointEnabled()) { @@ -65,10 +69,6 @@ export async function main() { } } - // When using Code Assist this triggers the Oauth login. - // Do this now, before sandboxing, so web redirect works. - await config.getGeminiClient().getChat(); - // hop into sandbox if we are outside and sandboxing is enabled if (!process.env.SANDBOX) { const sandboxConfig = config.getSandbox(); diff --git a/packages/core/src/code_assist/oauth2.test.ts b/packages/core/src/code_assist/oauth2.test.ts index 47bd45b3..0f5b791b 100644 --- a/packages/core/src/code_assist/oauth2.test.ts +++ b/packages/core/src/code_assist/oauth2.test.ts @@ -73,8 +73,10 @@ describe('oauth2', () => { (resolve) => (serverListeningCallback = resolve), ); + let capturedPort = 0; const mockHttpServer = { listen: vi.fn((port: number, callback?: () => void) => { + capturedPort = port; if (callback) { callback(); } @@ -86,7 +88,7 @@ describe('oauth2', () => { } }), on: vi.fn(), - address: () => ({ port: 1234 }), + address: () => ({ port: capturedPort }), }; vi.mocked(http.createServer).mockImplementation((cb) => { requestCallback = cb as http.RequestListener< @@ -115,7 +117,10 @@ describe('oauth2', () => { expect(client).toBe(mockOAuth2Client); expect(open).toHaveBeenCalledWith(mockAuthUrl); - expect(mockGetToken).toHaveBeenCalledWith(mockCode); + expect(mockGetToken).toHaveBeenCalledWith({ + code: mockCode, + redirect_uri: `http://localhost:${capturedPort}/oauth2callback`, + }); expect(mockSetCredentials).toHaveBeenCalledWith(mockTokens); const tokenPath = path.join(tempHomeDir, '.gemini', 'oauth_creds.json'); diff --git a/packages/core/src/code_assist/oauth2.ts b/packages/core/src/code_assist/oauth2.ts index 9e15f65b..6527f957 100644 --- a/packages/core/src/code_assist/oauth2.ts +++ b/packages/core/src/code_assist/oauth2.ts @@ -42,39 +42,54 @@ const SIGN_IN_FAILURE_URL = const GEMINI_DIR = '.gemini'; const CREDENTIAL_FILENAME = 'oauth_creds.json'; -export async function getOauthClient(): Promise { - try { - return await getCachedCredentialClient(); - } catch (_) { - const loggedInClient = await webLoginClient(); - await setCachedCredentials(loggedInClient.credentials); - return loggedInClient; - } +/** + * An Authentication URL for updating the credentials of a Oauth2Client + * as well as a promise that will resolve when the credentials have + * been refreshed (or which throws error when refreshing credentials failed). + */ +export interface OauthWebLogin { + authUrl: string; + loginCompletePromise: Promise; } -async function webLoginClient(): Promise { - const port = await getAvailablePort(); - const oAuth2Client = new OAuth2Client({ +export async function getOauthClient(): Promise { + const client = new OAuth2Client({ clientId: OAUTH_CLIENT_ID, clientSecret: OAUTH_CLIENT_SECRET, - redirectUri: `http://localhost:${port}/oauth2callback`, }); - return new Promise((resolve, reject) => { - const state = crypto.randomBytes(32).toString('hex'); - const authURL: string = oAuth2Client.generateAuthUrl({ - access_type: 'offline', - scope: OAUTH_SCOPE, - state, - }); - console.log( - `\n\nCode Assist login required.\n` + - `Attempting to open authentication page in your browser.\n` + - `Otherwise navigate to:\n\n${authURL}\n\n`, - ); - open(authURL); - console.log('Waiting for authentication...'); + if (await loadCachedCredentials(client)) { + // Found valid cached credentials. + return client; + } + const webLogin = await authWithWeb(client); + + console.log( + `\n\nCode Assist login required.\n` + + `Attempting to open authentication page in your browser.\n` + + `Otherwise navigate to:\n\n${webLogin.authUrl}\n\n`, + ); + await open(webLogin.authUrl); + console.log('Waiting for authentication...'); + + await webLogin.loginCompletePromise; + + return client; +} + +async function authWithWeb(client: OAuth2Client): Promise { + const port = await getAvailablePort(); + const redirectUri = `http://localhost:${port}/oauth2callback`; + const state = crypto.randomBytes(32).toString('hex'); + const authUrl: string = client.generateAuthUrl({ + redirect_uri: redirectUri, + access_type: 'offline', + scope: OAUTH_SCOPE, + state, + }); + + const loginCompletePromise = new Promise((resolve, reject) => { const server = http.createServer(async (req, res) => { try { if (req.url!.indexOf('/oauth2callback') === -1) { @@ -94,13 +109,16 @@ async function webLoginClient(): Promise { reject(new Error('State mismatch. Possible CSRF attack')); } else if (qs.get('code')) { - const code: string = qs.get('code')!; - const { tokens } = await oAuth2Client.getToken(code); - oAuth2Client.setCredentials(tokens); + const { tokens } = await client.getToken({ + code: qs.get('code')!, + redirect_uri: redirectUri, + }); + client.setCredentials(tokens); + await cacheCredentials(client.credentials); res.writeHead(HTTP_REDIRECT, { Location: SIGN_IN_SUCCESS_URL }); res.end(); - resolve(oAuth2Client); + resolve(); } else { reject(new Error('No code found in request')); } @@ -112,9 +130,14 @@ async function webLoginClient(): Promise { }); server.listen(port); }); + + return { + authUrl, + loginCompletePromise, + }; } -function getAvailablePort(): Promise { +export function getAvailablePort(): Promise { return new Promise((resolve, reject) => { let port = 0; try { @@ -135,25 +158,20 @@ function getAvailablePort(): Promise { }); } -async function getCachedCredentialClient(): Promise { +async function loadCachedCredentials(client: OAuth2Client): Promise { try { const creds = await fs.readFile(getCachedCredentialPath(), 'utf-8'); - const oAuth2Client = new OAuth2Client({ - clientId: OAUTH_CLIENT_ID, - clientSecret: OAUTH_CLIENT_SECRET, - }); - oAuth2Client.setCredentials(JSON.parse(creds)); + client.setCredentials(JSON.parse(creds)); // This will either return the existing token or refresh it. - await oAuth2Client.getAccessToken(); - // If we are here, the token is valid. - return oAuth2Client; + await client.getAccessToken(); + + return true; } catch (_) { - // Could not load credentials. - throw new Error('Could not load credentials'); + return false; } } -async function setCachedCredentials(credentials: Credentials) { +async function cacheCredentials(credentials: Credentials) { const filePath = getCachedCredentialPath(); await fs.mkdir(path.dirname(filePath), { recursive: true }); diff --git a/packages/core/src/core/client.test.ts b/packages/core/src/core/client.test.ts index 675d1c26..96346e99 100644 --- a/packages/core/src/core/client.test.ts +++ b/packages/core/src/core/client.test.ts @@ -65,7 +65,7 @@ vi.mock('../telemetry/index.js', () => ({ describe('Gemini Client (client.ts)', () => { let client: GeminiClient; - beforeEach(() => { + beforeEach(async () => { vi.resetAllMocks(); // Set up the mock for GoogleGenAI constructor and its methods @@ -131,6 +131,7 @@ describe('Gemini Client (client.ts)', () => { // eslint-disable-next-line @typescript-eslint/no-explicit-any const mockConfig = new Config({} as any); client = new GeminiClient(mockConfig); + await client.initialize(); }); afterEach(() => { @@ -262,9 +263,7 @@ describe('Gemini Client (client.ts)', () => { countTokens: vi.fn().mockResolvedValue({ totalTokens: 1 }), generateContent: mockGenerateContentFn, }; - client['contentGenerator'] = Promise.resolve( - mockGenerator as ContentGenerator, - ); + client['contentGenerator'] = mockGenerator as ContentGenerator; await client.generateContent(contents, generationConfig, abortSignal); @@ -292,9 +291,7 @@ describe('Gemini Client (client.ts)', () => { countTokens: vi.fn().mockResolvedValue({ totalTokens: 1 }), generateContent: mockGenerateContentFn, }; - client['contentGenerator'] = Promise.resolve( - mockGenerator as ContentGenerator, - ); + client['contentGenerator'] = mockGenerator as ContentGenerator; await client.generateJson(contents, schema, abortSignal); @@ -319,7 +316,7 @@ describe('Gemini Client (client.ts)', () => { addHistory: vi.fn(), }; // eslint-disable-next-line @typescript-eslint/no-explicit-any - client['chat'] = Promise.resolve(mockChat as any); + client['chat'] = mockChat as any; const newContent = { role: 'user', @@ -371,14 +368,12 @@ describe('Gemini Client (client.ts)', () => { addHistory: vi.fn(), getHistory: vi.fn().mockReturnValue([]), }; - client['chat'] = Promise.resolve(mockChat as GeminiChat); + client['chat'] = mockChat as GeminiChat; const mockGenerator: Partial = { countTokens: vi.fn().mockResolvedValue({ totalTokens: 0 }), }; - client['contentGenerator'] = Promise.resolve( - mockGenerator as ContentGenerator, - ); + client['contentGenerator'] = mockGenerator as ContentGenerator; // Act const stream = client.sendMessageStream( diff --git a/packages/core/src/core/client.ts b/packages/core/src/core/client.ts index 9cc8f328..d9b30835 100644 --- a/packages/core/src/core/client.ts +++ b/packages/core/src/core/client.ts @@ -44,8 +44,8 @@ function isThinkingSupported(model: string) { } export class GeminiClient { - private chat: Promise; - private contentGenerator: Promise; + private chat?: GeminiChat; + private contentGenerator?: ContentGenerator; private model: string; private embeddingModel: string; private generateContentConfig: GenerateContentConfig = { @@ -59,35 +59,45 @@ export class GeminiClient { setGlobalDispatcher(new ProxyAgent(config.getProxy() as string)); } - this.contentGenerator = createContentGenerator( - this.config.getContentGeneratorConfig(), - ); this.model = config.getModel(); this.embeddingModel = config.getEmbeddingModel(); - this.chat = this.startChat(); + } + + async initialize() { + this.contentGenerator = await createContentGenerator( + this.config.getContentGeneratorConfig(), + ); + this.chat = await this.startChat(); } async addHistory(content: Content) { - const chat = await this.chat; - chat.addHistory(content); + this.getChat().addHistory(content); } - getChat(): Promise { + getChat(): GeminiChat { + if (!this.chat) { + throw new Error('Chat not initialized'); + } return this.chat; } + private getContentGenerator(): ContentGenerator { + if (!this.contentGenerator) { + throw new Error('Content generator not initialized'); + } + return this.contentGenerator; + } + async getHistory(): Promise { - const chat = await this.chat; - return chat.getHistory(); + return this.getChat().getHistory(); } async setHistory(history: Content[]): Promise { - const chat = await this.chat; - chat.setHistory(history); + this.getChat().setHistory(history); } async resetChat(): Promise { - this.chat = this.startChat(); + this.chat = await this.startChat(); await this.chat; } @@ -184,7 +194,7 @@ export class GeminiClient { : this.generateContentConfig; return new GeminiChat( this.config, - await this.contentGenerator, + this.getContentGenerator(), this.model, { systemInstruction, @@ -210,22 +220,24 @@ export class GeminiClient { turns: number = this.MAX_TURNS, ): AsyncGenerator { if (!turns) { - const chat = await this.chat; - return new Turn(chat); + return new Turn(this.getChat()); } const compressed = await this.tryCompressChat(); if (compressed) { yield { type: GeminiEventType.ChatCompressed, value: compressed }; } - const chat = await this.chat; - const turn = new Turn(chat); + const turn = new Turn(this.getChat()); const resultStream = turn.run(request, signal); for await (const event of resultStream) { yield event; } if (!turn.pendingToolCalls.length && signal && !signal.aborted) { - const nextSpeakerCheck = await checkNextSpeaker(chat, this, signal); + const nextSpeakerCheck = await checkNextSpeaker( + this.getChat(), + this, + signal, + ); if (nextSpeakerCheck?.next_speaker === 'model') { const nextRequest = [{ text: 'Please continue.' }]; // This recursive call's events will be yielded out, but the final @@ -243,7 +255,6 @@ export class GeminiClient { model: string = DEFAULT_GEMINI_FLASH_MODEL, config: GenerateContentConfig = {}, ): Promise> { - const cg = await this.contentGenerator; try { const userMemory = this.config.getUserMemory(); const systemInstruction = getCoreSystemPrompt(userMemory); @@ -254,7 +265,7 @@ export class GeminiClient { }; const apiCall = () => - cg.generateContent({ + this.getContentGenerator().generateContent({ model, config: { ...requestConfig, @@ -327,7 +338,6 @@ export class GeminiClient { generationConfig: GenerateContentConfig, abortSignal: AbortSignal, ): Promise { - const cg = await this.contentGenerator; const modelToUse = this.model; const configToUse: GenerateContentConfig = { ...this.generateContentConfig, @@ -345,7 +355,7 @@ export class GeminiClient { }; const apiCall = () => - cg.generateContent({ + this.getContentGenerator().generateContent({ model: modelToUse, config: requestConfig, contents, @@ -386,8 +396,8 @@ export class GeminiClient { contents: texts, }; - const cg = await this.contentGenerator; - const embedContentResponse = await cg.embedContent(embedModelParams); + const embedContentResponse = + await this.getContentGenerator().embedContent(embedModelParams); if ( !embedContentResponse.embeddings || embedContentResponse.embeddings.length === 0 @@ -415,19 +425,18 @@ export class GeminiClient { async tryCompressChat( force: boolean = false, ): Promise { - const chat = await this.chat; - const history = chat.getHistory(true); // Get curated history + const history = this.getChat().getHistory(true); // Get curated history // Regardless of `force`, don't do anything if the history is empty. if (history.length === 0) { return null; } - const cg = await this.contentGenerator; - const { totalTokens: originalTokenCount } = await cg.countTokens({ - model: this.model, - contents: history, - }); + const { totalTokens: originalTokenCount } = + await this.getContentGenerator().countTokens({ + model: this.model, + contents: history, + }); // If not forced, check if we should compress based on context size. if (!force) { @@ -457,7 +466,7 @@ export class GeminiClient { const summarizationRequestMessage = { text: 'Summarize our conversation up to this point. The summary should be a concise yet comprehensive overview of all key topics, questions, answers, and important details discussed. This summary will replace the current chat history to conserve tokens, so it must capture everything essential to understand the context and continue our conversation effectively as if no information was lost.', }; - const response = await chat.sendMessage({ + const response = await this.getChat().sendMessage({ message: summarizationRequestMessage, }); const newHistory = [ @@ -470,9 +479,12 @@ export class GeminiClient { parts: [{ text: response.text }], }, ]; - this.chat = this.startChat(newHistory); + this.chat = await this.startChat(newHistory); const newTokenCount = ( - await cg.countTokens({ model: this.model, contents: newHistory }) + await this.getContentGenerator().countTokens({ + model: this.model, + contents: newHistory, + }) ).totalTokens; return originalTokenCount && newTokenCount diff --git a/packages/core/src/core/contentGenerator.ts b/packages/core/src/core/contentGenerator.ts index 3b276738..a0c8d56a 100644 --- a/packages/core/src/core/contentGenerator.ts +++ b/packages/core/src/core/contentGenerator.ts @@ -49,7 +49,7 @@ export async function createContentGenerator( }, }; if (config.codeAssist) { - return createCodeAssistContentGenerator(httpOptions); + return await createCodeAssistContentGenerator(httpOptions); } const googleGenAI = new GoogleGenAI({ apiKey: config.apiKey === '' ? undefined : config.apiKey,