From b3e26de8624daae8662fced1eadf805a6135b089 Mon Sep 17 00:00:00 2001 From: Tommaso Sciortino Date: Mon, 16 Jun 2025 19:31:32 -0700 Subject: [PATCH] Cache credentials in home dir, not working dir (#1122) --- packages/core/src/code_assist/oauth2.test.ts | 56 +++++++++++++------ packages/core/src/code_assist/oauth2.ts | 57 +++++++++++--------- 2 files changed, 72 insertions(+), 41 deletions(-) diff --git a/packages/core/src/code_assist/oauth2.test.ts b/packages/core/src/code_assist/oauth2.test.ts index 80949203..cd06a11a 100644 --- a/packages/core/src/code_assist/oauth2.test.ts +++ b/packages/core/src/code_assist/oauth2.test.ts @@ -4,12 +4,23 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { describe, it, expect, vi } from 'vitest'; -import { webLoginClient } from './oauth2.js'; +import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'; +import { getOauthClient } from './oauth2.js'; import { OAuth2Client } from 'google-auth-library'; +import * as fs from 'fs'; +import * as path from 'path'; import http from 'http'; import open from 'open'; import crypto from 'crypto'; +import * as os from 'os'; + +vi.mock('os', async (importOriginal) => { + const os = await importOriginal(); + return { + ...os, + homedir: vi.fn(), + }; +}); vi.mock('google-auth-library'); vi.mock('http'); @@ -17,6 +28,18 @@ vi.mock('open'); vi.mock('crypto'); describe('oauth2', () => { + let tempHomeDir: string; + + beforeEach(() => { + tempHomeDir = fs.mkdtempSync( + path.join(os.tmpdir(), 'gemini-cli-test-home-'), + ); + vi.mocked(os.homedir).mockReturnValue(tempHomeDir); + }); + afterEach(() => { + fs.rmSync(tempHomeDir, { recursive: true, force: true }); + }); + it('should perform a web login', async () => { const mockAuthUrl = 'https://example.com/auth'; const mockCode = 'test-code'; @@ -33,16 +56,17 @@ describe('oauth2', () => { generateAuthUrl: mockGenerateAuthUrl, getToken: mockGetToken, setCredentials: mockSetCredentials, + credentials: mockTokens, } as unknown as OAuth2Client; vi.mocked(OAuth2Client).mockImplementation(() => mockOAuth2Client); vi.spyOn(crypto, 'randomBytes').mockReturnValue(mockState as never); vi.mocked(open).mockImplementation(async () => ({}) as never); - let requestCallback!: ( - req: http.IncomingMessage, - res: http.ServerResponse, - ) => void; + let requestCallback!: http.RequestListener< + typeof http.IncomingMessage, + typeof http.ServerResponse + >; const mockHttpServer = { listen: vi.fn((port: number, callback?: () => void) => { if (callback) { @@ -58,14 +82,14 @@ describe('oauth2', () => { address: () => ({ port: 1234 }), }; vi.mocked(http.createServer).mockImplementation((cb) => { - requestCallback = cb as ( - req: http.IncomingMessage, - res: http.ServerResponse, - ) => void; + requestCallback = cb as http.RequestListener< + typeof http.IncomingMessage, + typeof http.ServerResponse + >; return mockHttpServer as unknown as http.Server; }); - const clientPromise = webLoginClient(); + const clientPromise = getOauthClient(); // Wait for the server to be created await new Promise((resolve) => setTimeout(resolve, 0)); @@ -78,15 +102,17 @@ describe('oauth2', () => { end: vi.fn(), } as unknown as http.ServerResponse; - if (requestCallback) { - await requestCallback(mockReq, mockRes); - } + await requestCallback(mockReq, mockRes); const client = await clientPromise; + expect(client).toBe(mockOAuth2Client); expect(open).toHaveBeenCalledWith(mockAuthUrl); expect(mockGetToken).toHaveBeenCalledWith(mockCode); expect(mockSetCredentials).toHaveBeenCalledWith(mockTokens); - expect(client).toBe(mockOAuth2Client); + + const tokenPath = path.join(tempHomeDir, '.gemini', 'oauth_creds.json'); + const tokenData = JSON.parse(fs.readFileSync(tokenPath, 'utf-8')); + expect(tokenData).toEqual(mockTokens); }); }); diff --git a/packages/core/src/code_assist/oauth2.ts b/packages/core/src/code_assist/oauth2.ts index 7d65d260..84c72fca 100644 --- a/packages/core/src/code_assist/oauth2.ts +++ b/packages/core/src/code_assist/oauth2.ts @@ -12,6 +12,7 @@ import * as net from 'net'; import open from 'open'; import path from 'node:path'; import { promises as fs } from 'node:fs'; +import * as os from 'os'; // OAuth Client ID used to initiate OAuth2Client class. const OAUTH_CLIENT_ID = @@ -41,30 +42,8 @@ const SIGN_IN_FAILURE_URL = const GEMINI_DIR = '.gemini'; const CREDENTIAL_FILENAME = 'oauth_creds.json'; -export async function getCachedCredentialClient(): Promise { - try { - const creds = await fs.readFile( - path.join(process.cwd(), GEMINI_DIR, CREDENTIAL_FILENAME), - 'utf-8', - ); - - const oAuth2Client = new OAuth2Client({ - clientId: OAUTH_CLIENT_ID, - clientSecret: OAUTH_CLIENT_SECRET, - }); - oAuth2Client.setCredentials(JSON.parse(creds)); - // This will either return the existing token or refresh it. - await oAuth2Client.getAccessToken(); - // If we are here, the token is valid. - return oAuth2Client; - } catch (_) { - // Could not load credentials. - throw new Error('Could not load credentials'); - } -} - export async function clearCachedCredentials(): Promise { - await fs.rm(path.join(process.cwd(), GEMINI_DIR, CREDENTIAL_FILENAME)); + await fs.rm(getCachedCredentialPath()); } export async function getOauthClient(): Promise { @@ -72,16 +51,19 @@ export async function getOauthClient(): Promise { return await getCachedCredentialClient(); } catch (_) { const loggedInClient = await webLoginClient(); - await fs.mkdir(path.join(process.cwd(), GEMINI_DIR), { recursive: true }); + + await fs.mkdir(path.dirname(getCachedCredentialPath()), { + recursive: true, + }); await fs.writeFile( - path.join(process.cwd(), GEMINI_DIR, CREDENTIAL_FILENAME), + getCachedCredentialPath(), JSON.stringify(loggedInClient.credentials, null, 2), ); return loggedInClient; } } -export async function webLoginClient(): Promise { +async function webLoginClient(): Promise { const port = await getAvailablePort(); const oAuth2Client = new OAuth2Client({ clientId: OAUTH_CLIENT_ID, @@ -163,3 +145,26 @@ function getAvailablePort(): Promise { } }); } + +async function getCachedCredentialClient(): Promise { + try { + const creds = await fs.readFile(getCachedCredentialPath(), 'utf-8'); + + const oAuth2Client = new OAuth2Client({ + clientId: OAUTH_CLIENT_ID, + clientSecret: OAUTH_CLIENT_SECRET, + }); + oAuth2Client.setCredentials(JSON.parse(creds)); + // This will either return the existing token or refresh it. + await oAuth2Client.getAccessToken(); + // If we are here, the token is valid. + return oAuth2Client; + } catch (_) { + // Could not load credentials. + throw new Error('Could not load credentials'); + } +} + +function getCachedCredentialPath(): string { + return path.join(os.homedir(), GEMINI_DIR, CREDENTIAL_FILENAME); +}