diff --git a/packages/core/src/code_assist/oauth2.test.ts b/packages/core/src/code_assist/oauth2.test.ts index 7e3c38f0..4223bb75 100644 --- a/packages/core/src/code_assist/oauth2.test.ts +++ b/packages/core/src/code_assist/oauth2.test.ts @@ -64,6 +64,7 @@ describe('oauth2', () => { setCredentials: mockSetCredentials, getAccessToken: mockGetAccessToken, credentials: mockTokens, + on: vi.fn(), } as unknown as OAuth2Client; vi.mocked(OAuth2Client).mockImplementation(() => mockOAuth2Client); @@ -136,10 +137,6 @@ describe('oauth2', () => { }); expect(mockSetCredentials).toHaveBeenCalledWith(mockTokens); - const tokenPath = path.join(tempHomeDir, '.gemini', 'oauth_creds.json'); - const tokenData = JSON.parse(fs.readFileSync(tokenPath, 'utf-8')); - expect(tokenData).toEqual(mockTokens); - // Verify Google Account ID was cached const googleAccountIdPath = path.join( tempHomeDir, diff --git a/packages/core/src/code_assist/oauth2.ts b/packages/core/src/code_assist/oauth2.ts index d07c8560..a55f3804 100644 --- a/packages/core/src/code_assist/oauth2.ts +++ b/packages/core/src/code_assist/oauth2.ts @@ -58,6 +58,9 @@ export async function getOauthClient(): Promise { clientId: OAUTH_CLIENT_ID, clientSecret: OAUTH_CLIENT_SECRET, }); + client.on('tokens', async (tokens: Credentials) => { + await cacheCredentials(tokens); + }); if (await loadCachedCredentials(client)) { // Found valid cached credentials. @@ -130,8 +133,6 @@ async function authWithWeb(client: OAuth2Client): Promise { redirect_uri: redirectUri, }); client.setCredentials(tokens); - await cacheCredentials(client.credentials); - // Retrieve and cache Google Account ID during authentication try { const googleAccountId = await getGoogleAccountId(client); diff --git a/packages/core/src/code_assist/server.test.ts b/packages/core/src/code_assist/server.test.ts index d8d9c10a..9bcfa304 100644 --- a/packages/core/src/code_assist/server.test.ts +++ b/packages/core/src/code_assist/server.test.ts @@ -18,8 +18,8 @@ describe('CodeAssistServer', () => { }); it('should call the generateContent endpoint', async () => { - const auth = new OAuth2Client(); - const server = new CodeAssistServer(auth, 'test-project'); + const client = new OAuth2Client(); + const server = new CodeAssistServer(client, 'test-project'); const mockResponse = { response: { candidates: [ @@ -53,8 +53,8 @@ describe('CodeAssistServer', () => { }); it('should call the generateContentStream endpoint', async () => { - const auth = new OAuth2Client(); - const server = new CodeAssistServer(auth, 'test-project'); + const client = new OAuth2Client(); + const server = new CodeAssistServer(client, 'test-project'); const mockResponse = (async function* () { yield { response: { @@ -90,8 +90,8 @@ describe('CodeAssistServer', () => { }); it('should call the onboardUser endpoint', async () => { - const auth = new OAuth2Client(); - const server = new CodeAssistServer(auth, 'test-project'); + const client = new OAuth2Client(); + const server = new CodeAssistServer(client, 'test-project'); const mockResponse = { name: 'operations/123', done: true, @@ -112,8 +112,8 @@ describe('CodeAssistServer', () => { }); it('should call the loadCodeAssist endpoint', async () => { - const auth = new OAuth2Client(); - const server = new CodeAssistServer(auth, 'test-project'); + const client = new OAuth2Client(); + const server = new CodeAssistServer(client, 'test-project'); const mockResponse = { // TODO: Add mock response }; @@ -131,8 +131,8 @@ describe('CodeAssistServer', () => { }); it('should return 0 for countTokens', async () => { - const auth = new OAuth2Client(); - const server = new CodeAssistServer(auth, 'test-project'); + const client = new OAuth2Client(); + const server = new CodeAssistServer(client, 'test-project'); const mockResponse = { totalTokens: 100, }; @@ -146,8 +146,8 @@ describe('CodeAssistServer', () => { }); it('should throw an error for embedContent', async () => { - const auth = new OAuth2Client(); - const server = new CodeAssistServer(auth, 'test-project'); + const client = new OAuth2Client(); + const server = new CodeAssistServer(client, 'test-project'); await expect( server.embedContent({ model: 'test-model', diff --git a/packages/core/src/code_assist/server.ts b/packages/core/src/code_assist/server.ts index 1eaf9217..8e74c8b2 100644 --- a/packages/core/src/code_assist/server.ts +++ b/packages/core/src/code_assist/server.ts @@ -4,7 +4,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { AuthClient } from 'google-auth-library'; +import { OAuth2Client } from 'google-auth-library'; import { CodeAssistGlobalUserSettingResponse, LoadCodeAssistRequest, @@ -46,7 +46,7 @@ export const CODE_ASSIST_API_VERSION = 'v1internal'; export class CodeAssistServer implements ContentGenerator { constructor( - readonly auth: AuthClient, + readonly client: OAuth2Client, readonly projectId?: string, readonly httpOptions: HttpOptions = {}, ) {} @@ -129,7 +129,7 @@ export class CodeAssistServer implements ContentGenerator { req: object, signal?: AbortSignal, ): Promise { - const res = await this.auth.request({ + const res = await this.client.request({ url: `${CODE_ASSIST_ENDPOINT}/${CODE_ASSIST_API_VERSION}:${method}`, method: 'POST', headers: { @@ -144,7 +144,7 @@ export class CodeAssistServer implements ContentGenerator { } async getEndpoint(method: string, signal?: AbortSignal): Promise { - const res = await this.auth.request({ + const res = await this.client.request({ url: `${CODE_ASSIST_ENDPOINT}/${CODE_ASSIST_API_VERSION}:${method}`, method: 'GET', headers: { @@ -162,7 +162,7 @@ export class CodeAssistServer implements ContentGenerator { req: object, signal?: AbortSignal, ): Promise> { - const res = await this.auth.request({ + const res = await this.client.request({ url: `${CODE_ASSIST_ENDPOINT}/${CODE_ASSIST_API_VERSION}:${method}`, method: 'POST', params: { diff --git a/packages/core/src/code_assist/setup.ts b/packages/core/src/code_assist/setup.ts index f0ea60b3..7db6bdcd 100644 --- a/packages/core/src/code_assist/setup.ts +++ b/packages/core/src/code_assist/setup.ts @@ -27,9 +27,9 @@ export class ProjectIdRequiredError extends Error { * @param projectId the user's project id, if any * @returns the user's actual project id */ -export async function setupUser(authClient: OAuth2Client): Promise { +export async function setupUser(client: OAuth2Client): Promise { let projectId = process.env.GOOGLE_CLOUD_PROJECT; - const caServer = new CodeAssistServer(authClient, projectId); + const caServer = new CodeAssistServer(client, projectId); const clientMetadata: ClientMetadata = { ideType: 'IDE_UNSPECIFIED',