Add NO_BROWSER environment variable to trigger offline oauth flow (#3713)
This commit is contained in:
parent
ab66e3a24e
commit
8a128d8dc6
|
@ -317,6 +317,7 @@ export async function loadCliConfig(
|
||||||
name: e.config.name,
|
name: e.config.name,
|
||||||
version: e.config.version,
|
version: e.config.version,
|
||||||
})),
|
})),
|
||||||
|
noBrowser: !!process.env.NO_BROWSER,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -35,6 +35,7 @@ import {
|
||||||
sessionId,
|
sessionId,
|
||||||
logUserPrompt,
|
logUserPrompt,
|
||||||
AuthType,
|
AuthType,
|
||||||
|
getOauthClient,
|
||||||
} from '@google/gemini-cli-core';
|
} from '@google/gemini-cli-core';
|
||||||
import { validateAuthMethod } from './config/auth.js';
|
import { validateAuthMethod } from './config/auth.js';
|
||||||
import { setMaxSizedBoxDebugging } from './ui/components/shared/MaxSizedBox.js';
|
import { setMaxSizedBoxDebugging } from './ui/components/shared/MaxSizedBox.js';
|
||||||
|
@ -165,6 +166,15 @@ export async function main() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (
|
||||||
|
settings.merged.selectedAuthType === AuthType.LOGIN_WITH_GOOGLE &&
|
||||||
|
config.getNoBrowser()
|
||||||
|
) {
|
||||||
|
// Do oauth before app renders to make copying the link possible.
|
||||||
|
await getOauthClient(settings.merged.selectedAuthType, config);
|
||||||
|
}
|
||||||
|
|
||||||
let input = config.getQuestion();
|
let input = config.getQuestion();
|
||||||
const startupWarnings = [
|
const startupWarnings = [
|
||||||
...(await getStartupWarnings()),
|
...(await getStartupWarnings()),
|
||||||
|
|
|
@ -728,6 +728,7 @@ const App = ({ config, settings, startupWarnings = [] }: AppProps) => {
|
||||||
/>
|
/>
|
||||||
</Box>
|
</Box>
|
||||||
) : isAuthenticating ? (
|
) : isAuthenticating ? (
|
||||||
|
<>
|
||||||
<AuthInProgress
|
<AuthInProgress
|
||||||
onTimeout={() => {
|
onTimeout={() => {
|
||||||
setAuthError('Authentication timed out. Please try again.');
|
setAuthError('Authentication timed out. Please try again.');
|
||||||
|
@ -735,6 +736,21 @@ const App = ({ config, settings, startupWarnings = [] }: AppProps) => {
|
||||||
openAuthDialog();
|
openAuthDialog();
|
||||||
}}
|
}}
|
||||||
/>
|
/>
|
||||||
|
{showErrorDetails && (
|
||||||
|
<OverflowProvider>
|
||||||
|
<Box flexDirection="column">
|
||||||
|
<DetailedMessagesDisplay
|
||||||
|
messages={filteredConsoleMessages}
|
||||||
|
maxHeight={
|
||||||
|
constrainHeight ? debugConsoleMaxHeight : undefined
|
||||||
|
}
|
||||||
|
width={inputWidth}
|
||||||
|
/>
|
||||||
|
<ShowMoreLines constrainHeight={constrainHeight} />
|
||||||
|
</Box>
|
||||||
|
</OverflowProvider>
|
||||||
|
)}
|
||||||
|
</>
|
||||||
) : isAuthDialogOpen ? (
|
) : isAuthDialogOpen ? (
|
||||||
<Box flexDirection="column">
|
<Box flexDirection="column">
|
||||||
<AuthDialog
|
<AuthDialog
|
||||||
|
|
|
@ -8,17 +8,19 @@ import { AuthType, ContentGenerator } from '../core/contentGenerator.js';
|
||||||
import { getOauthClient } from './oauth2.js';
|
import { getOauthClient } from './oauth2.js';
|
||||||
import { setupUser } from './setup.js';
|
import { setupUser } from './setup.js';
|
||||||
import { CodeAssistServer, HttpOptions } from './server.js';
|
import { CodeAssistServer, HttpOptions } from './server.js';
|
||||||
|
import { Config } from '../config/config.js';
|
||||||
|
|
||||||
export async function createCodeAssistContentGenerator(
|
export async function createCodeAssistContentGenerator(
|
||||||
httpOptions: HttpOptions,
|
httpOptions: HttpOptions,
|
||||||
authType: AuthType,
|
authType: AuthType,
|
||||||
|
config: Config,
|
||||||
sessionId?: string,
|
sessionId?: string,
|
||||||
): Promise<ContentGenerator> {
|
): Promise<ContentGenerator> {
|
||||||
if (
|
if (
|
||||||
authType === AuthType.LOGIN_WITH_GOOGLE ||
|
authType === AuthType.LOGIN_WITH_GOOGLE ||
|
||||||
authType === AuthType.CLOUD_SHELL
|
authType === AuthType.CLOUD_SHELL
|
||||||
) {
|
) {
|
||||||
const authClient = await getOauthClient(authType);
|
const authClient = await getOauthClient(authType, config);
|
||||||
const projectId = await setupUser(authClient);
|
const projectId = await setupUser(authClient);
|
||||||
return new CodeAssistServer(authClient, projectId, httpOptions, sessionId);
|
return new CodeAssistServer(authClient, projectId, httpOptions, sessionId);
|
||||||
}
|
}
|
||||||
|
|
|
@ -14,6 +14,7 @@ import open from 'open';
|
||||||
import crypto from 'crypto';
|
import crypto from 'crypto';
|
||||||
import * as os from 'os';
|
import * as os from 'os';
|
||||||
import { AuthType } from '../core/contentGenerator.js';
|
import { AuthType } from '../core/contentGenerator.js';
|
||||||
|
import { Config } from '../config/config.js';
|
||||||
|
|
||||||
vi.mock('os', async (importOriginal) => {
|
vi.mock('os', async (importOriginal) => {
|
||||||
const os = await importOriginal<typeof import('os')>();
|
const os = await importOriginal<typeof import('os')>();
|
||||||
|
@ -28,6 +29,10 @@ vi.mock('http');
|
||||||
vi.mock('open');
|
vi.mock('open');
|
||||||
vi.mock('crypto');
|
vi.mock('crypto');
|
||||||
|
|
||||||
|
const mockConfig = {
|
||||||
|
getNoBrowser: () => false,
|
||||||
|
} as unknown as Config;
|
||||||
|
|
||||||
// Mock fetch globally
|
// Mock fetch globally
|
||||||
global.fetch = vi.fn();
|
global.fetch = vi.fn();
|
||||||
|
|
||||||
|
@ -136,7 +141,10 @@ describe('oauth2', () => {
|
||||||
return mockHttpServer as unknown as http.Server;
|
return mockHttpServer as unknown as http.Server;
|
||||||
});
|
});
|
||||||
|
|
||||||
const clientPromise = getOauthClient(AuthType.LOGIN_WITH_GOOGLE);
|
const clientPromise = getOauthClient(
|
||||||
|
AuthType.LOGIN_WITH_GOOGLE,
|
||||||
|
mockConfig,
|
||||||
|
);
|
||||||
|
|
||||||
// wait for server to start listening.
|
// wait for server to start listening.
|
||||||
await serverListeningPromise;
|
await serverListeningPromise;
|
||||||
|
@ -214,7 +222,7 @@ describe('oauth2', () => {
|
||||||
() => mockClient as unknown as OAuth2Client,
|
() => mockClient as unknown as OAuth2Client,
|
||||||
);
|
);
|
||||||
|
|
||||||
await getOauthClient(AuthType.LOGIN_WITH_GOOGLE);
|
await getOauthClient(AuthType.LOGIN_WITH_GOOGLE, mockConfig);
|
||||||
|
|
||||||
expect(fs.promises.readFile).toHaveBeenCalledWith(
|
expect(fs.promises.readFile).toHaveBeenCalledWith(
|
||||||
'/user/home/.gemini/oauth_creds.json',
|
'/user/home/.gemini/oauth_creds.json',
|
||||||
|
@ -227,7 +235,7 @@ describe('oauth2', () => {
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should use Compute to get a client if no cached credentials exist', async () => {
|
it('should use Compute to get a client if no cached credentials exist', async () => {
|
||||||
await getOauthClient(AuthType.CLOUD_SHELL);
|
await getOauthClient(AuthType.CLOUD_SHELL, mockConfig);
|
||||||
|
|
||||||
expect(Compute).toHaveBeenCalledWith({});
|
expect(Compute).toHaveBeenCalledWith({});
|
||||||
expect(mockGetAccessToken).toHaveBeenCalled();
|
expect(mockGetAccessToken).toHaveBeenCalled();
|
||||||
|
@ -238,13 +246,13 @@ describe('oauth2', () => {
|
||||||
mockComputeClient.credentials = newCredentials;
|
mockComputeClient.credentials = newCredentials;
|
||||||
mockGetAccessToken.mockResolvedValue({ token: 'new-adc-token' });
|
mockGetAccessToken.mockResolvedValue({ token: 'new-adc-token' });
|
||||||
|
|
||||||
await getOauthClient(AuthType.CLOUD_SHELL);
|
await getOauthClient(AuthType.CLOUD_SHELL, mockConfig);
|
||||||
|
|
||||||
expect(fs.promises.writeFile).not.toHaveBeenCalled();
|
expect(fs.promises.writeFile).not.toHaveBeenCalled();
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should return the Compute client on successful ADC authentication', async () => {
|
it('should return the Compute client on successful ADC authentication', async () => {
|
||||||
const client = await getOauthClient(AuthType.CLOUD_SHELL);
|
const client = await getOauthClient(AuthType.CLOUD_SHELL, mockConfig);
|
||||||
expect(client).toBe(mockComputeClient);
|
expect(client).toBe(mockComputeClient);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -252,7 +260,9 @@ describe('oauth2', () => {
|
||||||
const testError = new Error('ADC Failed');
|
const testError = new Error('ADC Failed');
|
||||||
mockGetAccessToken.mockRejectedValue(testError);
|
mockGetAccessToken.mockRejectedValue(testError);
|
||||||
|
|
||||||
await expect(getOauthClient(AuthType.CLOUD_SHELL)).rejects.toThrow(
|
await expect(
|
||||||
|
getOauthClient(AuthType.CLOUD_SHELL, mockConfig),
|
||||||
|
).rejects.toThrow(
|
||||||
'Could not authenticate using Cloud Shell credentials. Please select a different authentication method or ensure you are in a properly configured environment. Error: ADC Failed',
|
'Could not authenticate using Cloud Shell credentials. Please select a different authentication method or ensure you are in a properly configured environment. Error: ADC Failed',
|
||||||
);
|
);
|
||||||
});
|
});
|
||||||
|
|
|
@ -4,7 +4,12 @@
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
*/
|
*/
|
||||||
|
|
||||||
import { OAuth2Client, Credentials, Compute } from 'google-auth-library';
|
import {
|
||||||
|
OAuth2Client,
|
||||||
|
Credentials,
|
||||||
|
Compute,
|
||||||
|
CodeChallengeMethod,
|
||||||
|
} from 'google-auth-library';
|
||||||
import * as http from 'http';
|
import * as http from 'http';
|
||||||
import url from 'url';
|
import url from 'url';
|
||||||
import crypto from 'crypto';
|
import crypto from 'crypto';
|
||||||
|
@ -13,8 +18,10 @@ import open from 'open';
|
||||||
import path from 'node:path';
|
import path from 'node:path';
|
||||||
import { promises as fs, existsSync, readFileSync } from 'node:fs';
|
import { promises as fs, existsSync, readFileSync } from 'node:fs';
|
||||||
import * as os from 'os';
|
import * as os from 'os';
|
||||||
|
import { Config } from '../config/config.js';
|
||||||
import { getErrorMessage } from '../utils/errors.js';
|
import { getErrorMessage } from '../utils/errors.js';
|
||||||
import { AuthType } from '../core/contentGenerator.js';
|
import { AuthType } from '../core/contentGenerator.js';
|
||||||
|
import readline from 'node:readline';
|
||||||
|
|
||||||
// OAuth Client ID used to initiate OAuth2Client class.
|
// OAuth Client ID used to initiate OAuth2Client class.
|
||||||
const OAUTH_CLIENT_ID =
|
const OAUTH_CLIENT_ID =
|
||||||
|
@ -57,6 +64,7 @@ export interface OauthWebLogin {
|
||||||
|
|
||||||
export async function getOauthClient(
|
export async function getOauthClient(
|
||||||
authType: AuthType,
|
authType: AuthType,
|
||||||
|
config: Config,
|
||||||
): Promise<OAuth2Client> {
|
): Promise<OAuth2Client> {
|
||||||
const client = new OAuth2Client({
|
const client = new OAuth2Client({
|
||||||
clientId: OAUTH_CLIENT_ID,
|
clientId: OAUTH_CLIENT_ID,
|
||||||
|
@ -109,9 +117,25 @@ export async function getOauthClient(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Otherwise, obtain creds using standard web flow
|
if (config.getNoBrowser()) {
|
||||||
|
let success = false;
|
||||||
|
const maxRetries = 2;
|
||||||
|
for (let i = 0; !success && i < maxRetries; i++) {
|
||||||
|
success = await authWithUserCode(client);
|
||||||
|
if (!success) {
|
||||||
|
console.error(
|
||||||
|
'\nFailed to authenticate with user code.',
|
||||||
|
i === maxRetries - 1 ? '' : 'Retrying...\n',
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (!success) {
|
||||||
|
process.exit(1);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
const webLogin = await authWithWeb(client);
|
const webLogin = await authWithWeb(client);
|
||||||
|
|
||||||
|
// This does basically nothing, as it isn't show to the user.
|
||||||
console.log(
|
console.log(
|
||||||
`\n\nCode Assist login required.\n` +
|
`\n\nCode Assist login required.\n` +
|
||||||
`Attempting to open authentication page in your browser.\n` +
|
`Attempting to open authentication page in your browser.\n` +
|
||||||
|
@ -121,15 +145,65 @@ export async function getOauthClient(
|
||||||
console.log('Waiting for authentication...');
|
console.log('Waiting for authentication...');
|
||||||
|
|
||||||
await webLogin.loginCompletePromise;
|
await webLogin.loginCompletePromise;
|
||||||
|
}
|
||||||
|
|
||||||
return client;
|
return client;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async function authWithUserCode(client: OAuth2Client): Promise<boolean> {
|
||||||
|
const redirectUri = 'https://sdk.cloud.google.com/authcode_cloudcode.html';
|
||||||
|
const codeVerifier = await client.generateCodeVerifierAsync();
|
||||||
|
const state = crypto.randomBytes(32).toString('hex');
|
||||||
|
const authUrl: string = client.generateAuthUrl({
|
||||||
|
redirect_uri: redirectUri,
|
||||||
|
access_type: 'offline',
|
||||||
|
scope: OAUTH_SCOPE,
|
||||||
|
code_challenge_method: CodeChallengeMethod.S256,
|
||||||
|
code_challenge: codeVerifier.codeChallenge,
|
||||||
|
state,
|
||||||
|
});
|
||||||
|
console.error('Please visit the following URL to authorize the application:');
|
||||||
|
console.error('');
|
||||||
|
console.error(authUrl);
|
||||||
|
console.error('');
|
||||||
|
|
||||||
|
const code = await new Promise<string>((resolve) => {
|
||||||
|
const rl = readline.createInterface({
|
||||||
|
input: process.stdin,
|
||||||
|
output: process.stdout,
|
||||||
|
});
|
||||||
|
rl.question('Enter the authorization code: ', (answer) => {
|
||||||
|
rl.close();
|
||||||
|
resolve(answer.trim());
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
if (!code) {
|
||||||
|
console.error('Authorization code is required.');
|
||||||
|
return false;
|
||||||
|
} else {
|
||||||
|
console.error(`Received authorization code: "${code}"`);
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
const response = await client.getToken({
|
||||||
|
code,
|
||||||
|
codeVerifier: codeVerifier.codeVerifier,
|
||||||
|
redirect_uri: redirectUri,
|
||||||
|
});
|
||||||
|
client.setCredentials(response.tokens);
|
||||||
|
} catch (_error) {
|
||||||
|
// Consider logging the error.
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
async function authWithWeb(client: OAuth2Client): Promise<OauthWebLogin> {
|
async function authWithWeb(client: OAuth2Client): Promise<OauthWebLogin> {
|
||||||
const port = await getAvailablePort();
|
const port = await getAvailablePort();
|
||||||
const redirectUri = `http://localhost:${port}/oauth2callback`;
|
const redirectUri = `http://localhost:${port}/oauth2callback`;
|
||||||
const state = crypto.randomBytes(32).toString('hex');
|
const state = crypto.randomBytes(32).toString('hex');
|
||||||
const authUrl: string = client.generateAuthUrl({
|
const authUrl = client.generateAuthUrl({
|
||||||
redirect_uri: redirectUri,
|
redirect_uri: redirectUri,
|
||||||
access_type: 'offline',
|
access_type: 'offline',
|
||||||
scope: OAUTH_SCOPE,
|
scope: OAUTH_SCOPE,
|
||||||
|
|
|
@ -141,6 +141,7 @@ export interface ConfigParameters {
|
||||||
extensionContextFilePaths?: string[];
|
extensionContextFilePaths?: string[];
|
||||||
listExtensions?: boolean;
|
listExtensions?: boolean;
|
||||||
activeExtensions?: ActiveExtension[];
|
activeExtensions?: ActiveExtension[];
|
||||||
|
noBrowser?: boolean;
|
||||||
}
|
}
|
||||||
|
|
||||||
export class Config {
|
export class Config {
|
||||||
|
@ -179,6 +180,7 @@ export class Config {
|
||||||
private readonly bugCommand: BugCommandSettings | undefined;
|
private readonly bugCommand: BugCommandSettings | undefined;
|
||||||
private readonly model: string;
|
private readonly model: string;
|
||||||
private readonly extensionContextFilePaths: string[];
|
private readonly extensionContextFilePaths: string[];
|
||||||
|
private readonly noBrowser: boolean;
|
||||||
private modelSwitchedDuringSession: boolean = false;
|
private modelSwitchedDuringSession: boolean = false;
|
||||||
private readonly listExtensions: boolean;
|
private readonly listExtensions: boolean;
|
||||||
private readonly _activeExtensions: ActiveExtension[];
|
private readonly _activeExtensions: ActiveExtension[];
|
||||||
|
@ -227,6 +229,7 @@ export class Config {
|
||||||
this.extensionContextFilePaths = params.extensionContextFilePaths ?? [];
|
this.extensionContextFilePaths = params.extensionContextFilePaths ?? [];
|
||||||
this.listExtensions = params.listExtensions ?? false;
|
this.listExtensions = params.listExtensions ?? false;
|
||||||
this._activeExtensions = params.activeExtensions ?? [];
|
this._activeExtensions = params.activeExtensions ?? [];
|
||||||
|
this.noBrowser = params.noBrowser ?? false;
|
||||||
|
|
||||||
if (params.contextFileName) {
|
if (params.contextFileName) {
|
||||||
setGeminiMdFilename(params.contextFileName);
|
setGeminiMdFilename(params.contextFileName);
|
||||||
|
@ -475,6 +478,10 @@ export class Config {
|
||||||
return this._activeExtensions;
|
return this._activeExtensions;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
getNoBrowser(): boolean {
|
||||||
|
return this.noBrowser;
|
||||||
|
}
|
||||||
|
|
||||||
async getGitService(): Promise<GitService> {
|
async getGitService(): Promise<GitService> {
|
||||||
if (!this.gitService) {
|
if (!this.gitService) {
|
||||||
this.gitService = new GitService(this.targetDir);
|
this.gitService = new GitService(this.targetDir);
|
||||||
|
|
|
@ -180,6 +180,7 @@ describe('Gemini Client (client.ts)', () => {
|
||||||
getFileService: vi.fn().mockReturnValue(fileService),
|
getFileService: vi.fn().mockReturnValue(fileService),
|
||||||
getQuotaErrorOccurred: vi.fn().mockReturnValue(false),
|
getQuotaErrorOccurred: vi.fn().mockReturnValue(false),
|
||||||
setQuotaErrorOccurred: vi.fn(),
|
setQuotaErrorOccurred: vi.fn(),
|
||||||
|
getNoBrowser: vi.fn().mockReturnValue(false),
|
||||||
};
|
};
|
||||||
return mock as unknown as Config;
|
return mock as unknown as Config;
|
||||||
});
|
});
|
||||||
|
|
|
@ -109,6 +109,7 @@ export class GeminiClient {
|
||||||
async initialize(contentGeneratorConfig: ContentGeneratorConfig) {
|
async initialize(contentGeneratorConfig: ContentGeneratorConfig) {
|
||||||
this.contentGenerator = await createContentGenerator(
|
this.contentGenerator = await createContentGenerator(
|
||||||
contentGeneratorConfig,
|
contentGeneratorConfig,
|
||||||
|
this.config,
|
||||||
this.config.getSessionId(),
|
this.config.getSessionId(),
|
||||||
);
|
);
|
||||||
this.chat = await this.startChat();
|
this.chat = await this.startChat();
|
||||||
|
|
|
@ -12,20 +12,26 @@ import {
|
||||||
} from './contentGenerator.js';
|
} from './contentGenerator.js';
|
||||||
import { createCodeAssistContentGenerator } from '../code_assist/codeAssist.js';
|
import { createCodeAssistContentGenerator } from '../code_assist/codeAssist.js';
|
||||||
import { GoogleGenAI } from '@google/genai';
|
import { GoogleGenAI } from '@google/genai';
|
||||||
|
import { Config } from '../config/config.js';
|
||||||
|
|
||||||
vi.mock('../code_assist/codeAssist.js');
|
vi.mock('../code_assist/codeAssist.js');
|
||||||
vi.mock('@google/genai');
|
vi.mock('@google/genai');
|
||||||
|
|
||||||
|
const mockConfig = {} as unknown as Config;
|
||||||
|
|
||||||
describe('createContentGenerator', () => {
|
describe('createContentGenerator', () => {
|
||||||
it('should create a CodeAssistContentGenerator', async () => {
|
it('should create a CodeAssistContentGenerator', async () => {
|
||||||
const mockGenerator = {} as unknown;
|
const mockGenerator = {} as unknown;
|
||||||
vi.mocked(createCodeAssistContentGenerator).mockResolvedValue(
|
vi.mocked(createCodeAssistContentGenerator).mockResolvedValue(
|
||||||
mockGenerator as never,
|
mockGenerator as never,
|
||||||
);
|
);
|
||||||
const generator = await createContentGenerator({
|
const generator = await createContentGenerator(
|
||||||
|
{
|
||||||
model: 'test-model',
|
model: 'test-model',
|
||||||
authType: AuthType.LOGIN_WITH_GOOGLE,
|
authType: AuthType.LOGIN_WITH_GOOGLE,
|
||||||
});
|
},
|
||||||
|
mockConfig,
|
||||||
|
);
|
||||||
expect(createCodeAssistContentGenerator).toHaveBeenCalled();
|
expect(createCodeAssistContentGenerator).toHaveBeenCalled();
|
||||||
expect(generator).toBe(mockGenerator);
|
expect(generator).toBe(mockGenerator);
|
||||||
});
|
});
|
||||||
|
@ -35,11 +41,14 @@ describe('createContentGenerator', () => {
|
||||||
models: {},
|
models: {},
|
||||||
} as unknown;
|
} as unknown;
|
||||||
vi.mocked(GoogleGenAI).mockImplementation(() => mockGenerator as never);
|
vi.mocked(GoogleGenAI).mockImplementation(() => mockGenerator as never);
|
||||||
const generator = await createContentGenerator({
|
const generator = await createContentGenerator(
|
||||||
|
{
|
||||||
model: 'test-model',
|
model: 'test-model',
|
||||||
apiKey: 'test-api-key',
|
apiKey: 'test-api-key',
|
||||||
authType: AuthType.USE_GEMINI,
|
authType: AuthType.USE_GEMINI,
|
||||||
});
|
},
|
||||||
|
mockConfig,
|
||||||
|
);
|
||||||
expect(GoogleGenAI).toHaveBeenCalledWith({
|
expect(GoogleGenAI).toHaveBeenCalledWith({
|
||||||
apiKey: 'test-api-key',
|
apiKey: 'test-api-key',
|
||||||
vertexai: undefined,
|
vertexai: undefined,
|
||||||
|
|
|
@ -15,6 +15,7 @@ import {
|
||||||
} from '@google/genai';
|
} from '@google/genai';
|
||||||
import { createCodeAssistContentGenerator } from '../code_assist/codeAssist.js';
|
import { createCodeAssistContentGenerator } from '../code_assist/codeAssist.js';
|
||||||
import { DEFAULT_GEMINI_MODEL } from '../config/models.js';
|
import { DEFAULT_GEMINI_MODEL } from '../config/models.js';
|
||||||
|
import { Config } from '../config/config.js';
|
||||||
import { getEffectiveModel } from './modelCheck.js';
|
import { getEffectiveModel } from './modelCheck.js';
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -99,6 +100,7 @@ export async function createContentGeneratorConfig(
|
||||||
|
|
||||||
export async function createContentGenerator(
|
export async function createContentGenerator(
|
||||||
config: ContentGeneratorConfig,
|
config: ContentGeneratorConfig,
|
||||||
|
gcConfig: Config,
|
||||||
sessionId?: string,
|
sessionId?: string,
|
||||||
): Promise<ContentGenerator> {
|
): Promise<ContentGenerator> {
|
||||||
const version = process.env.CLI_VERSION || process.version;
|
const version = process.env.CLI_VERSION || process.version;
|
||||||
|
@ -114,6 +116,7 @@ export async function createContentGenerator(
|
||||||
return createCodeAssistContentGenerator(
|
return createCodeAssistContentGenerator(
|
||||||
httpOptions,
|
httpOptions,
|
||||||
config.authType,
|
config.authType,
|
||||||
|
gcConfig,
|
||||||
sessionId,
|
sessionId,
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue