Refactor in preparation for Reauth (#1196)

This commit is contained in:
Tommaso Sciortino 2025-06-18 16:34:00 -07:00 committed by GitHub
parent b96fbd913e
commit 8bc3b415c9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 128 additions and 98 deletions

View File

@ -47,6 +47,10 @@ export async function main() {
const extensions = loadExtensions(workspaceRoot);
const config = await loadCliConfig(settings.merged, extensions, sessionId);
// When using Code Assist this triggers the Oauth login.
// Do this now, before sandboxing, so web redirect works.
await config.getGeminiClient().initialize();
// Initialize centralized FileDiscoveryService
config.getFileService();
if (config.getCheckpointEnabled()) {
@ -65,10 +69,6 @@ export async function main() {
}
}
// When using Code Assist this triggers the Oauth login.
// Do this now, before sandboxing, so web redirect works.
await config.getGeminiClient().getChat();
// hop into sandbox if we are outside and sandboxing is enabled
if (!process.env.SANDBOX) {
const sandboxConfig = config.getSandbox();

View File

@ -73,8 +73,10 @@ describe('oauth2', () => {
(resolve) => (serverListeningCallback = resolve),
);
let capturedPort = 0;
const mockHttpServer = {
listen: vi.fn((port: number, callback?: () => void) => {
capturedPort = port;
if (callback) {
callback();
}
@ -86,7 +88,7 @@ describe('oauth2', () => {
}
}),
on: vi.fn(),
address: () => ({ port: 1234 }),
address: () => ({ port: capturedPort }),
};
vi.mocked(http.createServer).mockImplementation((cb) => {
requestCallback = cb as http.RequestListener<
@ -115,7 +117,10 @@ describe('oauth2', () => {
expect(client).toBe(mockOAuth2Client);
expect(open).toHaveBeenCalledWith(mockAuthUrl);
expect(mockGetToken).toHaveBeenCalledWith(mockCode);
expect(mockGetToken).toHaveBeenCalledWith({
code: mockCode,
redirect_uri: `http://localhost:${capturedPort}/oauth2callback`,
});
expect(mockSetCredentials).toHaveBeenCalledWith(mockTokens);
const tokenPath = path.join(tempHomeDir, '.gemini', 'oauth_creds.json');

View File

@ -42,39 +42,54 @@ const SIGN_IN_FAILURE_URL =
const GEMINI_DIR = '.gemini';
const CREDENTIAL_FILENAME = 'oauth_creds.json';
export async function getOauthClient(): Promise<OAuth2Client> {
try {
return await getCachedCredentialClient();
} catch (_) {
const loggedInClient = await webLoginClient();
await setCachedCredentials(loggedInClient.credentials);
return loggedInClient;
}
/**
 * The result of kicking off a web-based OAuth login flow.
 *
 * Contains the authentication URL the user must visit to grant access to an
 * OAuth2Client, together with a promise that resolves once the local callback
 * server has received the authorization code and the client's credentials
 * have been updated (or rejects if the code exchange / verification fails).
 */
export interface OauthWebLogin {
// URL to open in the user's browser to begin authentication.
authUrl: string;
// Resolves when login completes; rejects on state mismatch or token-exchange failure.
loginCompletePromise: Promise<void>;
}
async function webLoginClient(): Promise<OAuth2Client> {
const port = await getAvailablePort();
const oAuth2Client = new OAuth2Client({
export async function getOauthClient(): Promise<OAuth2Client> {
const client = new OAuth2Client({
clientId: OAUTH_CLIENT_ID,
clientSecret: OAUTH_CLIENT_SECRET,
redirectUri: `http://localhost:${port}/oauth2callback`,
});
return new Promise((resolve, reject) => {
const state = crypto.randomBytes(32).toString('hex');
const authURL: string = oAuth2Client.generateAuthUrl({
access_type: 'offline',
scope: OAUTH_SCOPE,
state,
});
console.log(
`\n\nCode Assist login required.\n` +
`Attempting to open authentication page in your browser.\n` +
`Otherwise navigate to:\n\n${authURL}\n\n`,
);
open(authURL);
console.log('Waiting for authentication...');
if (await loadCachedCredentials(client)) {
// Found valid cached credentials.
return client;
}
const webLogin = await authWithWeb(client);
console.log(
`\n\nCode Assist login required.\n` +
`Attempting to open authentication page in your browser.\n` +
`Otherwise navigate to:\n\n${webLogin.authUrl}\n\n`,
);
await open(webLogin.authUrl);
console.log('Waiting for authentication...');
await webLogin.loginCompletePromise;
return client;
}
async function authWithWeb(client: OAuth2Client): Promise<OauthWebLogin> {
const port = await getAvailablePort();
const redirectUri = `http://localhost:${port}/oauth2callback`;
const state = crypto.randomBytes(32).toString('hex');
const authUrl: string = client.generateAuthUrl({
redirect_uri: redirectUri,
access_type: 'offline',
scope: OAUTH_SCOPE,
state,
});
const loginCompletePromise = new Promise<void>((resolve, reject) => {
const server = http.createServer(async (req, res) => {
try {
if (req.url!.indexOf('/oauth2callback') === -1) {
@ -94,13 +109,16 @@ async function webLoginClient(): Promise<OAuth2Client> {
reject(new Error('State mismatch. Possible CSRF attack'));
} else if (qs.get('code')) {
const code: string = qs.get('code')!;
const { tokens } = await oAuth2Client.getToken(code);
oAuth2Client.setCredentials(tokens);
const { tokens } = await client.getToken({
code: qs.get('code')!,
redirect_uri: redirectUri,
});
client.setCredentials(tokens);
await cacheCredentials(client.credentials);
res.writeHead(HTTP_REDIRECT, { Location: SIGN_IN_SUCCESS_URL });
res.end();
resolve(oAuth2Client);
resolve();
} else {
reject(new Error('No code found in request'));
}
@ -112,9 +130,14 @@ async function webLoginClient(): Promise<OAuth2Client> {
});
server.listen(port);
});
return {
authUrl,
loginCompletePromise,
};
}
function getAvailablePort(): Promise<number> {
export function getAvailablePort(): Promise<number> {
return new Promise((resolve, reject) => {
let port = 0;
try {
@ -135,25 +158,20 @@ function getAvailablePort(): Promise<number> {
});
}
async function getCachedCredentialClient(): Promise<OAuth2Client> {
async function loadCachedCredentials(client: OAuth2Client): Promise<boolean> {
try {
const creds = await fs.readFile(getCachedCredentialPath(), 'utf-8');
const oAuth2Client = new OAuth2Client({
clientId: OAUTH_CLIENT_ID,
clientSecret: OAUTH_CLIENT_SECRET,
});
oAuth2Client.setCredentials(JSON.parse(creds));
client.setCredentials(JSON.parse(creds));
// This will either return the existing token or refresh it.
await oAuth2Client.getAccessToken();
// If we are here, the token is valid.
return oAuth2Client;
await client.getAccessToken();
return true;
} catch (_) {
// Could not load credentials.
throw new Error('Could not load credentials');
return false;
}
}
async function setCachedCredentials(credentials: Credentials) {
async function cacheCredentials(credentials: Credentials) {
const filePath = getCachedCredentialPath();
await fs.mkdir(path.dirname(filePath), { recursive: true });

View File

@ -65,7 +65,7 @@ vi.mock('../telemetry/index.js', () => ({
describe('Gemini Client (client.ts)', () => {
let client: GeminiClient;
beforeEach(() => {
beforeEach(async () => {
vi.resetAllMocks();
// Set up the mock for GoogleGenAI constructor and its methods
@ -131,6 +131,7 @@ describe('Gemini Client (client.ts)', () => {
// eslint-disable-next-line @typescript-eslint/no-explicit-any
const mockConfig = new Config({} as any);
client = new GeminiClient(mockConfig);
await client.initialize();
});
afterEach(() => {
@ -262,9 +263,7 @@ describe('Gemini Client (client.ts)', () => {
countTokens: vi.fn().mockResolvedValue({ totalTokens: 1 }),
generateContent: mockGenerateContentFn,
};
client['contentGenerator'] = Promise.resolve(
mockGenerator as ContentGenerator,
);
client['contentGenerator'] = mockGenerator as ContentGenerator;
await client.generateContent(contents, generationConfig, abortSignal);
@ -292,9 +291,7 @@ describe('Gemini Client (client.ts)', () => {
countTokens: vi.fn().mockResolvedValue({ totalTokens: 1 }),
generateContent: mockGenerateContentFn,
};
client['contentGenerator'] = Promise.resolve(
mockGenerator as ContentGenerator,
);
client['contentGenerator'] = mockGenerator as ContentGenerator;
await client.generateJson(contents, schema, abortSignal);
@ -319,7 +316,7 @@ describe('Gemini Client (client.ts)', () => {
addHistory: vi.fn(),
};
// eslint-disable-next-line @typescript-eslint/no-explicit-any
client['chat'] = Promise.resolve(mockChat as any);
client['chat'] = mockChat as any;
const newContent = {
role: 'user',
@ -371,14 +368,12 @@ describe('Gemini Client (client.ts)', () => {
addHistory: vi.fn(),
getHistory: vi.fn().mockReturnValue([]),
};
client['chat'] = Promise.resolve(mockChat as GeminiChat);
client['chat'] = mockChat as GeminiChat;
const mockGenerator: Partial<ContentGenerator> = {
countTokens: vi.fn().mockResolvedValue({ totalTokens: 0 }),
};
client['contentGenerator'] = Promise.resolve(
mockGenerator as ContentGenerator,
);
client['contentGenerator'] = mockGenerator as ContentGenerator;
// Act
const stream = client.sendMessageStream(

View File

@ -44,8 +44,8 @@ function isThinkingSupported(model: string) {
}
export class GeminiClient {
private chat: Promise<GeminiChat>;
private contentGenerator: Promise<ContentGenerator>;
private chat?: GeminiChat;
private contentGenerator?: ContentGenerator;
private model: string;
private embeddingModel: string;
private generateContentConfig: GenerateContentConfig = {
@ -59,35 +59,45 @@ export class GeminiClient {
setGlobalDispatcher(new ProxyAgent(config.getProxy() as string));
}
this.contentGenerator = createContentGenerator(
this.config.getContentGeneratorConfig(),
);
this.model = config.getModel();
this.embeddingModel = config.getEmbeddingModel();
this.chat = this.startChat();
}
async initialize() {
this.contentGenerator = await createContentGenerator(
this.config.getContentGeneratorConfig(),
);
this.chat = await this.startChat();
}
async addHistory(content: Content) {
const chat = await this.chat;
chat.addHistory(content);
this.getChat().addHistory(content);
}
getChat(): Promise<GeminiChat> {
getChat(): GeminiChat {
if (!this.chat) {
throw new Error('Chat not initialized');
}
return this.chat;
}
private getContentGenerator(): ContentGenerator {
if (!this.contentGenerator) {
throw new Error('Content generator not initialized');
}
return this.contentGenerator;
}
async getHistory(): Promise<Content[]> {
const chat = await this.chat;
return chat.getHistory();
return this.getChat().getHistory();
}
async setHistory(history: Content[]): Promise<void> {
const chat = await this.chat;
chat.setHistory(history);
this.getChat().setHistory(history);
}
async resetChat(): Promise<void> {
this.chat = this.startChat();
this.chat = await this.startChat();
await this.chat;
}
@ -184,7 +194,7 @@ export class GeminiClient {
: this.generateContentConfig;
return new GeminiChat(
this.config,
await this.contentGenerator,
this.getContentGenerator(),
this.model,
{
systemInstruction,
@ -210,22 +220,24 @@ export class GeminiClient {
turns: number = this.MAX_TURNS,
): AsyncGenerator<ServerGeminiStreamEvent, Turn> {
if (!turns) {
const chat = await this.chat;
return new Turn(chat);
return new Turn(this.getChat());
}
const compressed = await this.tryCompressChat();
if (compressed) {
yield { type: GeminiEventType.ChatCompressed, value: compressed };
}
const chat = await this.chat;
const turn = new Turn(chat);
const turn = new Turn(this.getChat());
const resultStream = turn.run(request, signal);
for await (const event of resultStream) {
yield event;
}
if (!turn.pendingToolCalls.length && signal && !signal.aborted) {
const nextSpeakerCheck = await checkNextSpeaker(chat, this, signal);
const nextSpeakerCheck = await checkNextSpeaker(
this.getChat(),
this,
signal,
);
if (nextSpeakerCheck?.next_speaker === 'model') {
const nextRequest = [{ text: 'Please continue.' }];
// This recursive call's events will be yielded out, but the final
@ -243,7 +255,6 @@ export class GeminiClient {
model: string = DEFAULT_GEMINI_FLASH_MODEL,
config: GenerateContentConfig = {},
): Promise<Record<string, unknown>> {
const cg = await this.contentGenerator;
try {
const userMemory = this.config.getUserMemory();
const systemInstruction = getCoreSystemPrompt(userMemory);
@ -254,7 +265,7 @@ export class GeminiClient {
};
const apiCall = () =>
cg.generateContent({
this.getContentGenerator().generateContent({
model,
config: {
...requestConfig,
@ -327,7 +338,6 @@ export class GeminiClient {
generationConfig: GenerateContentConfig,
abortSignal: AbortSignal,
): Promise<GenerateContentResponse> {
const cg = await this.contentGenerator;
const modelToUse = this.model;
const configToUse: GenerateContentConfig = {
...this.generateContentConfig,
@ -345,7 +355,7 @@ export class GeminiClient {
};
const apiCall = () =>
cg.generateContent({
this.getContentGenerator().generateContent({
model: modelToUse,
config: requestConfig,
contents,
@ -386,8 +396,8 @@ export class GeminiClient {
contents: texts,
};
const cg = await this.contentGenerator;
const embedContentResponse = await cg.embedContent(embedModelParams);
const embedContentResponse =
await this.getContentGenerator().embedContent(embedModelParams);
if (
!embedContentResponse.embeddings ||
embedContentResponse.embeddings.length === 0
@ -415,19 +425,18 @@ export class GeminiClient {
async tryCompressChat(
force: boolean = false,
): Promise<ChatCompressionInfo | null> {
const chat = await this.chat;
const history = chat.getHistory(true); // Get curated history
const history = this.getChat().getHistory(true); // Get curated history
// Regardless of `force`, don't do anything if the history is empty.
if (history.length === 0) {
return null;
}
const cg = await this.contentGenerator;
const { totalTokens: originalTokenCount } = await cg.countTokens({
model: this.model,
contents: history,
});
const { totalTokens: originalTokenCount } =
await this.getContentGenerator().countTokens({
model: this.model,
contents: history,
});
// If not forced, check if we should compress based on context size.
if (!force) {
@ -457,7 +466,7 @@ export class GeminiClient {
const summarizationRequestMessage = {
text: 'Summarize our conversation up to this point. The summary should be a concise yet comprehensive overview of all key topics, questions, answers, and important details discussed. This summary will replace the current chat history to conserve tokens, so it must capture everything essential to understand the context and continue our conversation effectively as if no information was lost.',
};
const response = await chat.sendMessage({
const response = await this.getChat().sendMessage({
message: summarizationRequestMessage,
});
const newHistory = [
@ -470,9 +479,12 @@ export class GeminiClient {
parts: [{ text: response.text }],
},
];
this.chat = this.startChat(newHistory);
this.chat = await this.startChat(newHistory);
const newTokenCount = (
await cg.countTokens({ model: this.model, contents: newHistory })
await this.getContentGenerator().countTokens({
model: this.model,
contents: newHistory,
})
).totalTokens;
return originalTokenCount && newTokenCount

View File

@ -49,7 +49,7 @@ export async function createContentGenerator(
},
};
if (config.codeAssist) {
return createCodeAssistContentGenerator(httpOptions);
return await createCodeAssistContentGenerator(httpOptions);
}
const googleGenAI = new GoogleGenAI({
apiKey: config.apiKey === '' ? undefined : config.apiKey,