Fix oauth credential caching. (#2709)

This commit is contained in:
Tommaso Sciortino 2025-06-30 08:47:01 -07:00 committed by GitHub
parent f3849627fc
commit 5c4c833ddd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 23 additions and 25 deletions

View File

@ -64,6 +64,7 @@ describe('oauth2', () => {
setCredentials: mockSetCredentials, setCredentials: mockSetCredentials,
getAccessToken: mockGetAccessToken, getAccessToken: mockGetAccessToken,
credentials: mockTokens, credentials: mockTokens,
on: vi.fn(),
} as unknown as OAuth2Client; } as unknown as OAuth2Client;
vi.mocked(OAuth2Client).mockImplementation(() => mockOAuth2Client); vi.mocked(OAuth2Client).mockImplementation(() => mockOAuth2Client);
@ -136,10 +137,6 @@ describe('oauth2', () => {
}); });
expect(mockSetCredentials).toHaveBeenCalledWith(mockTokens); expect(mockSetCredentials).toHaveBeenCalledWith(mockTokens);
const tokenPath = path.join(tempHomeDir, '.gemini', 'oauth_creds.json');
const tokenData = JSON.parse(fs.readFileSync(tokenPath, 'utf-8'));
expect(tokenData).toEqual(mockTokens);
// Verify Google Account ID was cached // Verify Google Account ID was cached
const googleAccountIdPath = path.join( const googleAccountIdPath = path.join(
tempHomeDir, tempHomeDir,

View File

@ -58,6 +58,9 @@ export async function getOauthClient(): Promise<OAuth2Client> {
clientId: OAUTH_CLIENT_ID, clientId: OAUTH_CLIENT_ID,
clientSecret: OAUTH_CLIENT_SECRET, clientSecret: OAUTH_CLIENT_SECRET,
}); });
client.on('tokens', async (tokens: Credentials) => {
await cacheCredentials(tokens);
});
if (await loadCachedCredentials(client)) { if (await loadCachedCredentials(client)) {
// Found valid cached credentials. // Found valid cached credentials.
@ -130,8 +133,6 @@ async function authWithWeb(client: OAuth2Client): Promise<OauthWebLogin> {
redirect_uri: redirectUri, redirect_uri: redirectUri,
}); });
client.setCredentials(tokens); client.setCredentials(tokens);
await cacheCredentials(client.credentials);
// Retrieve and cache Google Account ID during authentication // Retrieve and cache Google Account ID during authentication
try { try {
const googleAccountId = await getGoogleAccountId(client); const googleAccountId = await getGoogleAccountId(client);

View File

@ -18,8 +18,8 @@ describe('CodeAssistServer', () => {
}); });
it('should call the generateContent endpoint', async () => { it('should call the generateContent endpoint', async () => {
const auth = new OAuth2Client(); const client = new OAuth2Client();
const server = new CodeAssistServer(auth, 'test-project'); const server = new CodeAssistServer(client, 'test-project');
const mockResponse = { const mockResponse = {
response: { response: {
candidates: [ candidates: [
@ -53,8 +53,8 @@ describe('CodeAssistServer', () => {
}); });
it('should call the generateContentStream endpoint', async () => { it('should call the generateContentStream endpoint', async () => {
const auth = new OAuth2Client(); const client = new OAuth2Client();
const server = new CodeAssistServer(auth, 'test-project'); const server = new CodeAssistServer(client, 'test-project');
const mockResponse = (async function* () { const mockResponse = (async function* () {
yield { yield {
response: { response: {
@ -90,8 +90,8 @@ describe('CodeAssistServer', () => {
}); });
it('should call the onboardUser endpoint', async () => { it('should call the onboardUser endpoint', async () => {
const auth = new OAuth2Client(); const client = new OAuth2Client();
const server = new CodeAssistServer(auth, 'test-project'); const server = new CodeAssistServer(client, 'test-project');
const mockResponse = { const mockResponse = {
name: 'operations/123', name: 'operations/123',
done: true, done: true,
@ -112,8 +112,8 @@ describe('CodeAssistServer', () => {
}); });
it('should call the loadCodeAssist endpoint', async () => { it('should call the loadCodeAssist endpoint', async () => {
const auth = new OAuth2Client(); const client = new OAuth2Client();
const server = new CodeAssistServer(auth, 'test-project'); const server = new CodeAssistServer(client, 'test-project');
const mockResponse = { const mockResponse = {
// TODO: Add mock response // TODO: Add mock response
}; };
@ -131,8 +131,8 @@ describe('CodeAssistServer', () => {
}); });
it('should return 0 for countTokens', async () => { it('should return 0 for countTokens', async () => {
const auth = new OAuth2Client(); const client = new OAuth2Client();
const server = new CodeAssistServer(auth, 'test-project'); const server = new CodeAssistServer(client, 'test-project');
const mockResponse = { const mockResponse = {
totalTokens: 100, totalTokens: 100,
}; };
@ -146,8 +146,8 @@ describe('CodeAssistServer', () => {
}); });
it('should throw an error for embedContent', async () => { it('should throw an error for embedContent', async () => {
const auth = new OAuth2Client(); const client = new OAuth2Client();
const server = new CodeAssistServer(auth, 'test-project'); const server = new CodeAssistServer(client, 'test-project');
await expect( await expect(
server.embedContent({ server.embedContent({
model: 'test-model', model: 'test-model',

View File

@ -4,7 +4,7 @@
* SPDX-License-Identifier: Apache-2.0 * SPDX-License-Identifier: Apache-2.0
*/ */
import { AuthClient } from 'google-auth-library'; import { OAuth2Client } from 'google-auth-library';
import { import {
CodeAssistGlobalUserSettingResponse, CodeAssistGlobalUserSettingResponse,
LoadCodeAssistRequest, LoadCodeAssistRequest,
@ -46,7 +46,7 @@ export const CODE_ASSIST_API_VERSION = 'v1internal';
export class CodeAssistServer implements ContentGenerator { export class CodeAssistServer implements ContentGenerator {
constructor( constructor(
readonly auth: AuthClient, readonly client: OAuth2Client,
readonly projectId?: string, readonly projectId?: string,
readonly httpOptions: HttpOptions = {}, readonly httpOptions: HttpOptions = {},
) {} ) {}
@ -129,7 +129,7 @@ export class CodeAssistServer implements ContentGenerator {
req: object, req: object,
signal?: AbortSignal, signal?: AbortSignal,
): Promise<T> { ): Promise<T> {
const res = await this.auth.request({ const res = await this.client.request({
url: `${CODE_ASSIST_ENDPOINT}/${CODE_ASSIST_API_VERSION}:${method}`, url: `${CODE_ASSIST_ENDPOINT}/${CODE_ASSIST_API_VERSION}:${method}`,
method: 'POST', method: 'POST',
headers: { headers: {
@ -144,7 +144,7 @@ export class CodeAssistServer implements ContentGenerator {
} }
async getEndpoint<T>(method: string, signal?: AbortSignal): Promise<T> { async getEndpoint<T>(method: string, signal?: AbortSignal): Promise<T> {
const res = await this.auth.request({ const res = await this.client.request({
url: `${CODE_ASSIST_ENDPOINT}/${CODE_ASSIST_API_VERSION}:${method}`, url: `${CODE_ASSIST_ENDPOINT}/${CODE_ASSIST_API_VERSION}:${method}`,
method: 'GET', method: 'GET',
headers: { headers: {
@ -162,7 +162,7 @@ export class CodeAssistServer implements ContentGenerator {
req: object, req: object,
signal?: AbortSignal, signal?: AbortSignal,
): Promise<AsyncGenerator<T>> { ): Promise<AsyncGenerator<T>> {
const res = await this.auth.request({ const res = await this.client.request({
url: `${CODE_ASSIST_ENDPOINT}/${CODE_ASSIST_API_VERSION}:${method}`, url: `${CODE_ASSIST_ENDPOINT}/${CODE_ASSIST_API_VERSION}:${method}`,
method: 'POST', method: 'POST',
params: { params: {

View File

@ -27,9 +27,9 @@ export class ProjectIdRequiredError extends Error {
* @param projectId the user's project id, if any * @param projectId the user's project id, if any
* @returns the user's actual project id * @returns the user's actual project id
*/ */
export async function setupUser(authClient: OAuth2Client): Promise<string> { export async function setupUser(client: OAuth2Client): Promise<string> {
let projectId = process.env.GOOGLE_CLOUD_PROJECT; let projectId = process.env.GOOGLE_CLOUD_PROJECT;
const caServer = new CodeAssistServer(authClient, projectId); const caServer = new CodeAssistServer(client, projectId);
const clientMetadata: ClientMetadata = { const clientMetadata: ClientMetadata = {
ideType: 'IDE_UNSPECIFIED', ideType: 'IDE_UNSPECIFIED',