Basic code assist support (#910)

This commit is contained in:
Tommaso Sciortino 2025-06-10 16:00:13 -07:00 committed by GitHub
parent 4e84431df3
commit d79dafc577
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
15 changed files with 434 additions and 340 deletions

1
package-lock.json generated
View File

@ -10761,6 +10761,7 @@
"diff": "^7.0.0",
"dotenv": "^16.4.7",
"fast-glob": "^3.3.3",
"google-auth-library": "^9.11.0",
"ignore": "^7.0.0",
"shell-quote": "^1.8.2",
"strip-ansi": "^7.1.0",

View File

@ -235,6 +235,7 @@ async function createContentGeneratorConfig(
model: argv.model || DEFAULT_GEMINI_MODEL,
apiKey: googleApiKey || geminiApiKey || '',
vertexai: hasGeminiApiKey ? false : undefined,
codeAssist: !!process.env.GEMINI_CODE_ASSIST,
};
if (config.apiKey) {

View File

@ -35,6 +35,7 @@
"ignore": "^7.0.0",
"shell-quote": "^1.8.2",
"strip-ansi": "^7.1.0",
"google-auth-library": "^9.11.0",
"undici": "^7.10.0"
},
"devDependencies": {

View File

@ -0,0 +1,146 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { OAuth2Client } from 'google-auth-library';
import {
LoadCodeAssistResponse,
LoadCodeAssistRequest,
OnboardUserRequest,
LongrunningOperationResponse,
} from './types.js';
import {
GenerateContentResponse,
GenerateContentParameters,
CountTokensParameters,
EmbedContentResponse,
CountTokensResponse,
EmbedContentParameters,
} from '@google/genai';
import { Readable } from 'stream';
import * as readline from 'readline';
import type { ReadableStream } from 'node:stream/web';
import { ContentGenerator } from '../core/contentGenerator.js';
// TODO: Use production endpoint once it supports our methods.
export const CCPA_ENDPOINT =
'https://staging-cloudcode-pa.sandbox.googleapis.com';
export const CCPA_API_VERSION = '/v1internal';
export class CcpaServer implements ContentGenerator {
constructor(
readonly auth: OAuth2Client,
readonly projectId?: string,
) {}
async generateContentStream(
req: GenerateContentParameters,
): Promise<AsyncGenerator<GenerateContentResponse>> {
return await this.streamEndpoint<GenerateContentResponse>(
'streamGenerateContent',
req,
);
}
async generateContent(
req: GenerateContentParameters,
): Promise<GenerateContentResponse> {
return await this.callEndpoint<GenerateContentResponse>(
'generateContent',
req,
);
}
async onboardUser(
req: OnboardUserRequest,
): Promise<LongrunningOperationResponse> {
return await this.callEndpoint<LongrunningOperationResponse>(
'onboardUser',
req,
);
}
async loadCodeAssist(
req: LoadCodeAssistRequest,
): Promise<LoadCodeAssistResponse> {
return await this.callEndpoint<LoadCodeAssistResponse>(
'loadCodeAssist',
req,
);
}
async countTokens(_req: CountTokensParameters): Promise<CountTokensResponse> {
return { totalTokens: 0 };
}
async embedContent(
_req: EmbedContentParameters,
): Promise<EmbedContentResponse> {
throw Error();
}
async callEndpoint<T>(method: string, req: object): Promise<T> {
const res = await this.auth.request({
url: `${CCPA_ENDPOINT}/${CCPA_API_VERSION}:${method}`,
method: 'POST',
headers: {
'Content-Type': 'application/json',
'X-Goog-User-Project': this.projectId || '',
},
responseType: 'json',
body: JSON.stringify(req),
});
if (res.status !== 200) {
throw new Error(
`Failed to fetch from ${method}: ${res.status} ${res.data}`,
);
}
return res.data as T;
}
async streamEndpoint<T>(
method: string,
req: object,
): Promise<AsyncGenerator<T>> {
const res = await this.auth.request({
url: `${CCPA_ENDPOINT}/${CCPA_API_VERSION}:${method}`,
method: 'POST',
params: {
alt: 'sse',
},
headers: { 'Content-Type': 'application/json' },
responseType: 'stream',
body: JSON.stringify(req),
});
if (res.status !== 200) {
throw new Error(
`Failed to fetch from ${method}: ${res.status} ${res.data}`,
);
}
return (async function* (): AsyncGenerator<T> {
const rl = readline.createInterface({
input: Readable.fromWeb(res.data as ReadableStream<Uint8Array>),
crlfDelay: Infinity, // Recognizes '\r\n' and '\n' as line breaks
});
let bufferedLines: string[] = [];
for await (const line of rl) {
// blank lines are used to separate JSON objects in the stream
if (line === '') {
if (bufferedLines.length === 0) {
continue; // no data to yield
}
yield JSON.parse(bufferedLines.join('\n')) as T;
bufferedLines = []; // Reset the buffer after yielding
} else if (line.startsWith('data: ')) {
bufferedLines.push(line.slice(6).trim());
} else {
throw new Error(`Unexpected line format in response: ${line}`);
}
}
})();
}
}

View File

@ -0,0 +1,19 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { ContentGenerator } from '../core/contentGenerator.js';
import { loginWithOauth } from './oauth2.js';
import { setupUser } from './setup.js';
import { CcpaServer } from './ccpaServer.js';
export async function createCodeAssistContentGenerator(): Promise<ContentGenerator> {
const oauth2Client = await loginWithOauth();
const projectId = await setupUser(
oauth2Client,
process.env.GOOGLE_CLOUD_PROJECT,
);
return new CcpaServer(oauth2Client, projectId);
}

View File

@ -1,7 +0,0 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
export const DEFAULT_ENDPOINT = 'https://cloudcode-pa.googleapis.com';

View File

@ -1,119 +0,0 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { OAuth2Client } from 'google-auth-library';
import * as http from 'http';
import url from 'url';
import crypto from 'crypto';
import * as net from 'net';
// OAuth Client ID used to initiate OAuth2Client class.
const OAUTH_CLIENT_ID =
'681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com';
// OAuth Secret value used to initiate OAuth2Client class.
const OAUTH_CLIENT_NOT_SO_SECRET = process.env.GCA_OAUTH_SECRET;
// OAuth Scopes for Cloud Code authorization.
const OAUTH_SCOPE = [
'https://www.googleapis.com/auth/cloud-platform',
'https://www.googleapis.com/auth/userinfo.email',
'https://www.googleapis.com/auth/userinfo.profile',
];
const HTTP_REDIRECT = 301;
const SIGN_IN_SUCCESS_URL =
'https://developers.google.com/gemini-code-assist/auth_success_gemini';
const SIGN_IN_FAILURE_URL =
'https://developers.google.com/gemini-code-assist/auth_failure_gemini';
export async function doGCALogin(): Promise<OAuth2Client> {
const redirectPort: number = await getAvailablePort();
const client: OAuth2Client = await createOAuth2Client(redirectPort);
await login(client, redirectPort);
return client;
}
function createOAuth2Client(redirectPort: number): OAuth2Client {
return new OAuth2Client({
clientId: OAUTH_CLIENT_ID,
clientSecret: OAUTH_CLIENT_NOT_SO_SECRET,
redirectUri: `http://localhost:${redirectPort}/oauth2redirect`,
});
}
/**
* Returns first available port in user's machine
* @returns port number
*/
function getAvailablePort(): Promise<number> {
return new Promise((resolve, reject) => {
let port = 0;
try {
const server = net.createServer();
server.listen(0, () => {
const address = server.address()! as net.AddressInfo;
port = address.port;
});
server.on('listening', () => {
server.close();
server.unref();
});
server.on('error', (e) => reject(e));
server.on('close', () => resolve(port));
} catch (e) {
reject(e);
}
});
}
function login(oAuth2Client: OAuth2Client, port: number): Promise<boolean> {
return new Promise((resolve, reject) => {
const state = crypto.randomBytes(32).toString('hex');
const authURL: string = oAuth2Client.generateAuthUrl({
access_type: 'offline',
scope: OAUTH_SCOPE,
state,
});
console.log('Login:\n\n', authURL);
const server = http
.createServer(async (req, res) => {
try {
if (req.url!.indexOf('/oauth2callback') > -1) {
// acquire the code from the querystring, and close the web server.
const qs = new url.URL(req.url!).searchParams;
if (qs.get('error')) {
console.error(`Error during authentication: ${qs.get('error')}`);
res.writeHead(HTTP_REDIRECT, { Location: SIGN_IN_FAILURE_URL });
res.end();
resolve(false);
} else if (qs.get('state') !== state) {
//check state value
console.log('State mismatch. Possible CSRF attack');
res.end('State mismatch. Possible CSRF attack');
resolve(false);
} else if (!qs.get('code')) {
const { tokens } = await oAuth2Client.getToken(qs.get('code')!);
console.log('Logged in! Tokens:\n\n', tokens);
oAuth2Client.setCredentials(tokens);
res.writeHead(HTTP_REDIRECT, { Location: SIGN_IN_SUCCESS_URL });
res.end();
resolve(true);
}
}
} catch (e) {
reject(e);
}
server.close();
})
.listen(port);
});
}

View File

@ -1,37 +0,0 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
export interface ClientMetadata {
ideType?: ClientMetadataIdeType | null;
ideVersion?: string | null;
pluginVersion?: string | null;
platform?: ClientMetadataPlatform | null;
updateChannel?: string | null;
duetProject?: string | null;
pluginType?: ClientMetadataPluginType | null;
ideName?: string | null;
}
export type ClientMetadataIdeType =
| 'IDE_UNSPECIFIED'
| 'VSCODE'
| 'INTELLIJ'
| 'VSCODE_CLOUD_WORKSTATION'
| 'INTELLIJ_CLOUD_WORKSTATION'
| 'CLOUD_SHELL';
export type ClientMetadataPlatform =
| 'PLATFORM_UNSPECIFIED'
| 'DARWIN_AMD64'
| 'DARWIN_ARM64'
| 'LINUX_AMD64'
| 'LINUX_ARM64'
| 'WINDOWS_AMD64';
export type ClientMetadataPluginType =
| 'PLUGIN_UNSPECIFIED'
| 'CLOUD_CODE'
| 'GEMINI'
| 'AIPLUGIN_INTELLIJ'
| 'AIPLUGIN_STUDIO';

View File

@ -0,0 +1,116 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { OAuth2Client } from 'google-auth-library';
import * as http from 'http';
import url from 'url';
import crypto from 'crypto';
import * as net from 'net';
import open from 'open';
// OAuth Client ID used to initiate OAuth2Client class.
const OAUTH_CLIENT_ID =
'681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com';
// OAuth Secret value used to initiate OAuth2Client class.
// Note: It's ok to save this in git because this is an installed application
// as described here: https://developers.google.com/identity/protocols/oauth2#installed
// "The process results in a client ID and, in some cases, a client secret,
// which you embed in the source code of your application. (In this context,
// the client secret is obviously not treated as a secret.)"
const OAUTH_CLIENT_SECRET = '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';
// OAuth Scopes for Cloud Code authorization.
const OAUTH_SCOPE = [
'https://www.googleapis.com/auth/cloud-platform',
'https://www.googleapis.com/auth/userinfo.email',
'https://www.googleapis.com/auth/userinfo.profile',
];
const HTTP_REDIRECT = 301;
const SIGN_IN_SUCCESS_URL =
'https://developers.google.com/gemini-code-assist/auth_success_gemini';
const SIGN_IN_FAILURE_URL =
'https://developers.google.com/gemini-code-assist/auth_failure_gemini';
export async function loginWithOauth(): Promise<OAuth2Client> {
const port = await getAvailablePort();
const oAuth2Client = new OAuth2Client({
clientId: OAUTH_CLIENT_ID,
clientSecret: OAUTH_CLIENT_SECRET,
redirectUri: `http://localhost:${port}/oauth2callback`,
});
return new Promise((resolve, reject) => {
const state = crypto.randomBytes(32).toString('hex');
const authURL: string = oAuth2Client.generateAuthUrl({
access_type: 'offline',
scope: OAUTH_SCOPE,
state,
});
open(authURL);
const server = http.createServer(async (req, res) => {
try {
if (req.url!.indexOf('/oauth2callback') === -1) {
console.log('Unexpected request:', req.url);
res.writeHead(HTTP_REDIRECT, { Location: SIGN_IN_FAILURE_URL });
res.end();
reject(new Error('Unexpected request: ' + req.url));
}
// acquire the code from the querystring, and close the web server.
const qs = new url.URL(req.url!, 'http://localhost:3000').searchParams;
console.log('Processing request:', qs);
if (qs.get('error')) {
res.writeHead(HTTP_REDIRECT, { Location: SIGN_IN_FAILURE_URL });
res.end();
reject(new Error(`Error during authentication: ${qs.get('error')}`));
} else if (qs.get('state') !== state) {
res.end('State mismatch. Possible CSRF attack');
reject(new Error('State mismatch. Possible CSRF attack'));
} else if (qs.get('code')) {
const code: string = qs.get('code')!;
console.log();
const { tokens } = await oAuth2Client.getToken(code);
console.log('Logged in! Tokens:\n\n', tokens);
oAuth2Client.setCredentials(tokens);
res.writeHead(HTTP_REDIRECT, { Location: SIGN_IN_SUCCESS_URL });
res.end();
resolve(oAuth2Client);
} else {
reject(new Error('No code found in request'));
}
} catch (e) {
reject(e);
} finally {
server.close();
}
});
server.listen(port);
});
}
function getAvailablePort(): Promise<number> {
return new Promise((resolve, reject) => {
let port = 0;
try {
const server = net.createServer();
server.listen(0, () => {
const address = server.address()! as net.AddressInfo;
port = address.port;
});
server.on('listening', () => {
server.close();
server.unref();
});
server.on('error', (e) => reject(e));
server.on('close', () => resolve(port));
} catch (e) {
reject(e);
}
});
}

View File

@ -1,90 +0,0 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { OAuth2Client } from 'google-auth-library';
import { ClientMetadata } from './metadata.js';
import { DEFAULT_ENDPOINT } from './constants.js';
const ONBOARD_USER_ENDPOINT = '/v1internal:onboardUser';
export async function doOnboardUser(
req: OnboardUserRequest,
oauth2Client: OAuth2Client,
): Promise<LongrunningOperationResponse> {
console.log('OnboardUser req: ', JSON.stringify(req));
const authHeaders = await oauth2Client.getRequestHeaders();
const headers = { 'Content-Type': 'application/json', ...authHeaders };
const res: Response = await fetch(
new URL(ONBOARD_USER_ENDPOINT, DEFAULT_ENDPOINT),
{
method: 'POST',
headers,
body: JSON.stringify(req),
},
);
const data: LongrunningOperationResponse =
(await res.json()) as LongrunningOperationResponse;
console.log('OnboardUser res: ', JSON.stringify(data));
return data;
}
/**
* Proto signature of OnboardUserRequest as payload to OnboardUser call
*/
export interface OnboardUserRequest {
tierId: string | undefined;
cloudaicompanionProject: string | undefined;
metadata: ClientMetadata | undefined;
}
/**
* Represents LongrunningOperation proto
* http://google3/google/longrunning/operations.proto;rcl=698857719;l=107
*/
export interface LongrunningOperationResponse {
name: string;
done?: boolean;
response?: OnboardUserResponse;
}
/**
* Represents OnboardUserResponse proto
* http://google3/google/internal/cloud/code/v1internal/cloudcode.proto;l=215
*/
export interface OnboardUserResponse {
// tslint:disable-next-line:enforce-name-casing This is the name of the field in the proto.
cloudaicompanionProject?: {
id: string;
name: string;
};
}
/**
* Status code of user license status
* it does not stricly correspond to the proto
* Error value is an additional value assigned to error responses from OnboardUser
*/
export enum OnboardUserStatusCode {
Default = 'DEFAULT',
Notice = 'NOTICE',
Warning = 'WARNING',
Error = 'ERROR',
}
/**
* Status of user onboarded to gemini
*/
export interface OnboardUserStatus {
statusCode: OnboardUserStatusCode;
displayMessage: string;
helpLink: HelpLinkUrl | undefined;
}
export interface HelpLinkUrl {
description: string;
url: string;
}

View File

@ -4,52 +4,46 @@
* SPDX-License-Identifier: Apache-2.0
*/
import { ClientMetadata, OnboardUserRequest } from './types.js';
import { CcpaServer } from './ccpaServer.js';
import { OAuth2Client } from 'google-auth-library';
import { ClientMetadata } from './metadata.js';
import { doLoadCodeAssist, LoadCodeAssistResponse } from './load.js';
import { doGCALogin } from './login.js';
import {
doOnboardUser,
LongrunningOperationResponse,
OnboardUserRequest,
} from './onboard.js';
export async function doSetup(): Promise<string> {
const oauth2Client: OAuth2Client = await doGCALogin();
/**
*
* @param projectId the user's project id, if any
* @returns the user's actual project id
*/
export async function setupUser(
oAuth2Client: OAuth2Client,
projectId?: string,
): Promise<string> {
const ccpaServer: CcpaServer = new CcpaServer(oAuth2Client, projectId);
const clientMetadata: ClientMetadata = {
ideType: 'IDE_UNSPECIFIED',
ideVersion: null,
pluginVersion: null,
platform: 'PLATFORM_UNSPECIFIED',
updateChannel: null,
duetProject: 'aipp-internal-testing',
pluginType: 'GEMINI',
ideName: null,
};
if (process.env.GOOGLE_CLOUD_PROJECT) {
clientMetadata.duetProject = process.env.GOOGLE_CLOUD_PROJECT;
}
// Call LoadCodeAssist.
const loadCodeAssistRes: LoadCodeAssistResponse = await doLoadCodeAssist(
{
cloudaicompanionProject: 'aipp-internal-testing',
metadata: clientMetadata,
},
oauth2Client,
);
// TODO: Support Free Tier user without projectId.
const loadRes = await ccpaServer.loadCodeAssist({
cloudaicompanionProject: process.env.GOOGLE_CLOUD_PROJECT,
metadata: clientMetadata,
});
// Call OnboardUser until long running operation is complete.
const onboardUserReq: OnboardUserRequest = {
const onboardReq: OnboardUserRequest = {
tierId: 'legacy-tier',
cloudaicompanionProject: loadCodeAssistRes.cloudaicompanionProject || '',
cloudaicompanionProject: loadRes.cloudaicompanionProject || '',
metadata: clientMetadata,
};
let lroRes: LongrunningOperationResponse = await doOnboardUser(
onboardUserReq,
oauth2Client,
);
// Poll onboardUser until long running operation is complete.
let lroRes = await ccpaServer.onboardUser(onboardReq);
while (!lroRes.done) {
await new Promise((f) => setTimeout(f, 5000));
lroRes = await doOnboardUser(onboardUserReq, oauth2Client);
lroRes = await ccpaServer.onboardUser(onboardReq);
}
return lroRes.response?.cloudaicompanionProject?.id || '';

View File

@ -4,34 +4,38 @@
* SPDX-License-Identifier: Apache-2.0
*/
import { OAuth2Client } from 'google-auth-library';
import { ClientMetadata } from './metadata.js';
import { DEFAULT_ENDPOINT } from './constants.js';
const LOAD_CODE_ASSIST_ENDPOINT = '/v1internal:loadCodeAssist';
export async function doLoadCodeAssist(
req: LoadCodeAssistRequest,
oauth2Client: OAuth2Client,
): Promise<LoadCodeAssistResponse> {
console.log('LoadCodeAssist req: ', JSON.stringify(req));
const authHeaders = await oauth2Client.getRequestHeaders();
const headers = { 'Content-Type': 'application/json', ...authHeaders };
const res: Response = await fetch(
new URL(LOAD_CODE_ASSIST_ENDPOINT, DEFAULT_ENDPOINT),
{
method: 'POST',
headers,
body: JSON.stringify(req),
},
);
const data: LoadCodeAssistResponse =
(await res.json()) as LoadCodeAssistResponse;
console.log('LoadCodeAssist res: ', JSON.stringify(data));
return data;
export interface ClientMetadata {
ideType?: ClientMetadataIdeType;
ideVersion?: string;
pluginVersion?: string;
platform?: ClientMetadataPlatform;
updateChannel?: string;
duetProject?: string;
pluginType?: ClientMetadataPluginType;
ideName?: string;
}
export type ClientMetadataIdeType =
| 'IDE_UNSPECIFIED'
| 'VSCODE'
| 'INTELLIJ'
| 'VSCODE_CLOUD_WORKSTATION'
| 'INTELLIJ_CLOUD_WORKSTATION'
| 'CLOUD_SHELL';
export type ClientMetadataPlatform =
| 'PLATFORM_UNSPECIFIED'
| 'DARWIN_AMD64'
| 'DARWIN_ARM64'
| 'LINUX_AMD64'
| 'LINUX_ARM64'
| 'WINDOWS_AMD64';
export type ClientMetadataPluginType =
| 'PLUGIN_UNSPECIFIED'
| 'CLOUD_CODE'
| 'GEMINI'
| 'AIPLUGIN_INTELLIJ'
| 'AIPLUGIN_STUDIO';
export interface LoadCodeAssistRequest {
cloudaicompanionProject?: string;
metadata: ClientMetadata;
@ -63,6 +67,20 @@ export interface GeminiUserTier {
hasOnboardedPreviously?: boolean;
}
/**
* Includes information specifying the reasons for a user's ineligibility for a specific tier.
* @param reasonCode mnemonic code representing the reason for in-eligibility.
* @param reasonMessage message to display to the user.
* @param tierId id of the tier.
* @param tierName name of the tier.
*/
export interface IneligibleTier {
reasonCode: IneligibleTierReasonCode;
reasonMessage: string;
tierId: UserTierId;
tierName: string;
}
/**
* List of predefined reason codes when a tier is blocked from a specific tier.
* https://source.corp.google.com/piper///depot/google3/google/internal/cloud/code/v1internal/cloudcode.proto;l=378
@ -79,21 +97,6 @@ export enum IneligibleTierReasonCode {
UNSUPPORTED_LOCATION = 'UNSUPPORTED_LOCATION',
// go/keep-sorted end
}
/**
* Includes information specifying the reasons for a user's ineligibility for a specific tier.
* @param reasonCode mnemonic code representing the reason for in-eligibility.
* @param reasonMessage message to display to the user.
* @param tierId id of the tier.
* @param tierName name of the tier.
*/
export interface IneligibleTier {
reasonCode: IneligibleTierReasonCode;
reasonMessage: string;
tierId: UserTierId;
tierName: string;
}
/**
* UserTierId represents IDs returned from the Cloud Code Private API representing a user's tier
*
@ -113,3 +116,60 @@ export interface PrivacyNotice {
showNotice: boolean;
noticeText?: string;
}
/**
* Proto signature of OnboardUserRequest as payload to OnboardUser call
*/
export interface OnboardUserRequest {
tierId: string | undefined;
cloudaicompanionProject: string | undefined;
metadata: ClientMetadata | undefined;
}
/**
* Represents LongrunningOperation proto
* http://google3/google/longrunning/operations.proto;rcl=698857719;l=107
*/
export interface LongrunningOperationResponse {
name: string;
done?: boolean;
response?: OnboardUserResponse;
}
/**
* Represents OnboardUserResponse proto
* http://google3/google/internal/cloud/code/v1internal/cloudcode.proto;l=215
*/
export interface OnboardUserResponse {
// tslint:disable-next-line:enforce-name-casing This is the name of the field in the proto.
cloudaicompanionProject?: {
id: string;
name: string;
};
}
/**
* Status code of user license status
* it does not stricly correspond to the proto
* Error value is an additional value assigned to error responses from OnboardUser
*/
export enum OnboardUserStatusCode {
Default = 'DEFAULT',
Notice = 'NOTICE',
Warning = 'WARNING',
Error = 'ERROR',
}
/**
* Status of user onboarded to gemini
*/
export interface OnboardUserStatus {
statusCode: OnboardUserStatusCode;
displayMessage: string;
helpLink: HelpLinkUrl | undefined;
}
export interface HelpLinkUrl {
description: string;
url: string;
}

View File

@ -273,7 +273,9 @@ describe('Gemini Client (client.ts)', () => {
const mockGenerator: Partial<ContentGenerator> = {
countTokens: vi.fn().mockResolvedValue({ totalTokens: 0 }),
};
client['contentGenerator'] = mockGenerator as ContentGenerator;
client['contentGenerator'] = Promise.resolve(
mockGenerator as ContentGenerator,
);
// Act
const stream = client.sendMessageStream(

View File

@ -5,7 +5,6 @@
*/
import {
EmbedContentResponse,
EmbedContentParameters,
GenerateContentConfig,
Part,
@ -37,7 +36,6 @@ import {
ContentGenerator,
createContentGenerator,
} from './contentGenerator.js';
import { ProxyAgent, setGlobalDispatcher } from 'undici';
const proxy =
@ -52,7 +50,7 @@ if (proxy) {
export class GeminiClient {
private chat: Promise<GeminiChat>;
private contentGenerator: ContentGenerator;
private contentGenerator: Promise<ContentGenerator>;
private model: string;
private embeddingModel: string;
private generateContentConfig: GenerateContentConfig = {
@ -162,7 +160,7 @@ export class GeminiClient {
const systemInstruction = getCoreSystemPrompt(userMemory);
return new GeminiChat(
this.contentGenerator,
await this.contentGenerator,
this.model,
{
systemInstruction,
@ -289,6 +287,7 @@ export class GeminiClient {
model: string = 'gemini-2.0-flash',
config: GenerateContentConfig = {},
): Promise<Record<string, unknown>> {
const cg = await this.contentGenerator;
const attempt = 1;
const startTime = Date.now();
try {
@ -302,7 +301,7 @@ export class GeminiClient {
let inputTokenCount = 0;
try {
const { totalTokens } = await this.contentGenerator.countTokens({
const { totalTokens } = await cg.countTokens({
model,
contents,
});
@ -317,7 +316,7 @@ export class GeminiClient {
this._logApiRequest(model, inputTokenCount);
const apiCall = () =>
this.contentGenerator.generateContent({
cg.generateContent({
model,
config: {
...requestConfig,
@ -397,6 +396,7 @@ export class GeminiClient {
generationConfig: GenerateContentConfig,
abortSignal: AbortSignal,
): Promise<GenerateContentResponse> {
const cg = await this.contentGenerator;
const modelToUse = this.model;
const configToUse: GenerateContentConfig = {
...this.generateContentConfig,
@ -417,7 +417,7 @@ export class GeminiClient {
let inputTokenCount = 0;
try {
const { totalTokens } = await this.contentGenerator.countTokens({
const { totalTokens } = await cg.countTokens({
model: modelToUse,
contents,
});
@ -432,7 +432,7 @@ export class GeminiClient {
this._logApiRequest(modelToUse, inputTokenCount);
const apiCall = () =>
this.contentGenerator.generateContent({
cg.generateContent({
model: modelToUse,
config: requestConfig,
contents,
@ -478,8 +478,9 @@ export class GeminiClient {
model: this.embeddingModel,
contents: texts,
};
const embedContentResponse: EmbedContentResponse =
await this.contentGenerator.embedContent(embedModelParams);
const cg = await this.contentGenerator;
const embedContentResponse = await cg.embedContent(embedModelParams);
if (
!embedContentResponse.embeddings ||
embedContentResponse.embeddings.length === 0
@ -508,7 +509,8 @@ export class GeminiClient {
const chat = await this.chat;
const history = chat.getHistory(true); // Get curated history
const { totalTokens } = await this.contentGenerator.countTokens({
const cg = await this.contentGenerator;
const { totalTokens } = await cg.countTokens({
model: this.model,
contents: history,
});

View File

@ -13,6 +13,7 @@ import {
EmbedContentParameters,
GoogleGenAI,
} from '@google/genai';
import { createCodeAssistContentGenerator } from '../code_assist/codeAssist.js';
/**
* Interface abstracting the core functionalities for generating content and counting tokens.
@ -35,11 +36,15 @@ export type ContentGeneratorConfig = {
model: string;
apiKey?: string;
vertexai?: boolean;
codeAssist?: boolean;
};
export function createContentGenerator(
export async function createContentGenerator(
config: ContentGeneratorConfig,
): ContentGenerator {
): Promise<ContentGenerator> {
if (config.codeAssist) {
return createCodeAssistContentGenerator();
}
const version = process.env.CLI_VERSION || process.version;
const googleGenAI = new GoogleGenAI({
apiKey: config.apiKey === '' ? undefined : config.apiKey,