From a6a386f72aee3c5509a7c9fe7482060cf7d5884e Mon Sep 17 00:00:00 2001 From: owenofbrien <86964623+owenofbrien@users.noreply.github.com> Date: Fri, 1 Aug 2025 14:37:56 -0500 Subject: [PATCH] Propagate prompt (#5033) --- .../core/src/code_assist/converter.test.ts | 57 ++++++++++-- packages/core/src/code_assist/converter.ts | 3 + packages/core/src/code_assist/server.test.ts | 78 ++++++++++++---- packages/core/src/code_assist/server.ts | 16 +++- packages/core/src/code_assist/setup.test.ts | 10 ++- packages/core/src/code_assist/setup.ts | 2 +- packages/core/src/core/client.test.ts | 90 +++++++++++-------- packages/core/src/core/client.ts | 37 ++++---- packages/core/src/core/contentGenerator.ts | 2 + packages/core/src/core/geminiChat.test.ts | 26 +++--- packages/core/src/core/geminiChat.ts | 26 +++--- 11 files changed, 245 insertions(+), 102 deletions(-) diff --git a/packages/core/src/code_assist/converter.test.ts b/packages/core/src/code_assist/converter.test.ts index 03f388dc..3d3a8ef3 100644 --- a/packages/core/src/code_assist/converter.test.ts +++ b/packages/core/src/code_assist/converter.test.ts @@ -24,7 +24,12 @@ describe('converter', () => { model: 'gemini-pro', contents: [{ role: 'user', parts: [{ text: 'Hello' }] }], }; - const codeAssistReq = toGenerateContentRequest(genaiReq, 'my-project'); + const codeAssistReq = toGenerateContentRequest( + genaiReq, + 'my-prompt', + 'my-project', + 'my-session', + ); expect(codeAssistReq).toEqual({ model: 'gemini-pro', project: 'my-project', @@ -37,8 +42,9 @@ describe('converter', () => { labels: undefined, safetySettings: undefined, generationConfig: undefined, - session_id: undefined, + session_id: 'my-session', }, + user_prompt_id: 'my-prompt', }); }); @@ -47,7 +53,12 @@ describe('converter', () => { model: 'gemini-pro', contents: [{ role: 'user', parts: [{ text: 'Hello' }] }], }; - const codeAssistReq = toGenerateContentRequest(genaiReq); + const codeAssistReq = toGenerateContentRequest( + genaiReq, + 'my-prompt', + undefined, + 'my-session', + ); expect(codeAssistReq).toEqual({ model: 'gemini-pro', project: undefined, @@ -60,8 +71,9 @@ describe('converter', () => { labels: undefined, safetySettings: undefined, generationConfig: undefined, - session_id: undefined, + session_id: 'my-session', }, + user_prompt_id: 'my-prompt', }); }); @@ -72,6 +84,7 @@ describe('converter', () => { }; const codeAssistReq = toGenerateContentRequest( genaiReq, + 'my-prompt', 'my-project', 'session-123', ); @@ -89,6 +102,7 @@ describe('converter', () => { generationConfig: undefined, session_id: 'session-123', }, + user_prompt_id: 'my-prompt', }); }); @@ -97,7 +111,12 @@ describe('converter', () => { model: 'gemini-pro', contents: 'Hello', }; - const codeAssistReq = toGenerateContentRequest(genaiReq); + const codeAssistReq = toGenerateContentRequest( + genaiReq, + 'my-prompt', + 'my-project', + 'my-session', + ); expect(codeAssistReq.request.contents).toEqual([ { role: 'user', parts: [{ text: 'Hello' }] }, ]); @@ -108,7 +127,12 @@ describe('converter', () => { model: 'gemini-pro', contents: [{ text: 'Hello' }, { text: 'World' }], }; - const codeAssistReq = toGenerateContentRequest(genaiReq); + const codeAssistReq = toGenerateContentRequest( + genaiReq, + 'my-prompt', + 'my-project', + 'my-session', + ); expect(codeAssistReq.request.contents).toEqual([ { role: 'user', parts: [{ text: 'Hello' }] }, { role: 'user', parts: [{ text: 'World' }] }, @@ -123,7 +147,12 @@ describe('converter', () => { systemInstruction: 'You are a helpful assistant.', }, }; - const codeAssistReq = toGenerateContentRequest(genaiReq); + const codeAssistReq = toGenerateContentRequest( + genaiReq, + 'my-prompt', + 'my-project', + 'my-session', + ); expect(codeAssistReq.request.systemInstruction).toEqual({ role: 'user', parts: [{ text: 'You are a helpful assistant.' }], @@ -139,7 +168,12 @@ describe('converter', () => { topK: 40, }, }; - const codeAssistReq = toGenerateContentRequest(genaiReq); + const codeAssistReq = toGenerateContentRequest( + genaiReq, + 'my-prompt', + 'my-project', + 'my-session', + ); expect(codeAssistReq.request.generationConfig).toEqual({ temperature: 0.8, topK: 40, @@ -165,7 +199,12 @@ describe('converter', () => { responseMimeType: 'application/json', }, }; - const codeAssistReq = toGenerateContentRequest(genaiReq); + const codeAssistReq = toGenerateContentRequest( + genaiReq, + 'my-prompt', + 'my-project', + 'my-session', + ); expect(codeAssistReq.request.generationConfig).toEqual({ temperature: 0.1, topP: 0.2, diff --git a/packages/core/src/code_assist/converter.ts b/packages/core/src/code_assist/converter.ts index 8340cfc1..ffd471da 100644 --- a/packages/core/src/code_assist/converter.ts +++ b/packages/core/src/code_assist/converter.ts @@ -32,6 +32,7 @@ import { export interface CAGenerateContentRequest { model: string; project?: string; + user_prompt_id?: string; request: VertexGenerateContentRequest; } @@ -115,12 +116,14 @@ export function fromCountTokenResponse( export function toGenerateContentRequest( req: GenerateContentParameters, + userPromptId: string, project?: string, sessionId?: string, ): CAGenerateContentRequest { return { model: req.model, project, + user_prompt_id: userPromptId, request: toVertexGenerateContentRequest(req, sessionId), }; } diff --git a/packages/core/src/code_assist/server.test.ts b/packages/core/src/code_assist/server.test.ts index 6246fd4e..3fc1891f 100644 --- a/packages/core/src/code_assist/server.test.ts +++ b/packages/core/src/code_assist/server.test.ts @@ -14,13 +14,25 @@ vi.mock('google-auth-library'); describe('CodeAssistServer', () => { it('should be able to be constructed', () => { const auth = new OAuth2Client(); - const server = new CodeAssistServer(auth, 'test-project'); + const server = new CodeAssistServer( + auth, + 'test-project', + {}, + 'test-session', + UserTierId.FREE, + ); expect(server).toBeInstanceOf(CodeAssistServer); }); it('should call the generateContent endpoint', async () => { const client = new OAuth2Client(); - const server = new CodeAssistServer(client, 'test-project'); + const server = new CodeAssistServer( + client, + 'test-project', + {}, + 'test-session', + UserTierId.FREE, + ); const mockResponse = { response: { candidates: [ @@ -38,10 +50,13 @@ describe('CodeAssistServer', () => { }; vi.spyOn(server, 'requestPost').mockResolvedValue(mockResponse); - const response = await server.generateContent({ - model: 'test-model', - contents: [{ role: 'user', parts: [{ text: 'request' }] }], - }); + const response = await server.generateContent( + { + model: 'test-model', + contents: [{ role: 'user', parts: [{ text: 'request' }] }], + }, + 'user-prompt-id', + ); expect(server.requestPost).toHaveBeenCalledWith( 'generateContent', @@ -55,7 +70,13 @@ describe('CodeAssistServer', () => { it('should call the generateContentStream endpoint', async () => { const client = new OAuth2Client(); - const server = new CodeAssistServer(client, 'test-project'); + const server = new CodeAssistServer( + client, + 'test-project', + {}, + 'test-session', + UserTierId.FREE, + ); const mockResponse = (async function* () { yield { response: { @@ -75,10 +96,13 @@ describe('CodeAssistServer', () => { })(); vi.spyOn(server, 'requestStreamingPost').mockResolvedValue(mockResponse); - const stream = await server.generateContentStream({ - model: 'test-model', - contents: [{ role: 'user', parts: [{ text: 'request' }] }], - }); + const stream = await server.generateContentStream( + { + model: 'test-model', + contents: [{ role: 'user', parts: [{ text: 'request' }] }], + }, + 'user-prompt-id', + ); for await (const res of stream) { expect(server.requestStreamingPost).toHaveBeenCalledWith( @@ -92,7 +116,13 @@ describe('CodeAssistServer', () => { it('should call the onboardUser endpoint', async () => { const client = new OAuth2Client(); - const server = new CodeAssistServer(client, 'test-project'); + const server = new CodeAssistServer( + client, + 'test-project', + {}, + 'test-session', + UserTierId.FREE, + ); const mockResponse = { name: 'operations/123', done: true, @@ -114,7 +144,13 @@ describe('CodeAssistServer', () => { it('should call the loadCodeAssist endpoint', async () => { const client = new OAuth2Client(); - const server = new CodeAssistServer(client, 'test-project'); + const server = new CodeAssistServer( + client, + 'test-project', + {}, + 'test-session', + UserTierId.FREE, + ); const mockResponse = { currentTier: { id: UserTierId.FREE, @@ -140,7 +176,13 @@ describe('CodeAssistServer', () => { it('should return 0 for countTokens', async () => { const client = new OAuth2Client(); - const server = new CodeAssistServer(client, 'test-project'); + const server = new CodeAssistServer( + client, + 'test-project', + {}, + 'test-session', + UserTierId.FREE, + ); const mockResponse = { totalTokens: 100, }; @@ -155,7 +197,13 @@ describe('CodeAssistServer', () => { it('should throw an error for embedContent', async () => { const client = new OAuth2Client(); - const server = new CodeAssistServer(client, 'test-project'); + const server = new CodeAssistServer( + client, + 'test-project', + {}, + 'test-session', + UserTierId.FREE, + ); await expect( server.embedContent({ model: 'test-model', diff --git a/packages/core/src/code_assist/server.ts b/packages/core/src/code_assist/server.ts index 7af643f7..08339bdc 100644 --- a/packages/core/src/code_assist/server.ts +++ b/packages/core/src/code_assist/server.ts @@ -53,10 +53,16 @@ export class CodeAssistServer implements ContentGenerator { async generateContentStream( req: GenerateContentParameters, + userPromptId: string, ): Promise> { const resps = await this.requestStreamingPost( 'streamGenerateContent', - toGenerateContentRequest(req, this.projectId, this.sessionId), + toGenerateContentRequest( + req, + userPromptId, + this.projectId, + this.sessionId, + ), req.config?.abortSignal, ); return (async function* (): AsyncGenerator { @@ -68,10 +74,16 @@ export class CodeAssistServer implements ContentGenerator { async generateContent( req: GenerateContentParameters, + userPromptId: string, ): Promise { const resp = await this.requestPost( 'generateContent', - toGenerateContentRequest(req, this.projectId, this.sessionId), + toGenerateContentRequest( + req, + userPromptId, + this.projectId, + this.sessionId, + ), req.config?.abortSignal, ); return fromGenerateContentResponse(resp); diff --git a/packages/core/src/code_assist/setup.test.ts b/packages/core/src/code_assist/setup.test.ts index 6db5fd88..c1260e3f 100644 --- a/packages/core/src/code_assist/setup.test.ts +++ b/packages/core/src/code_assist/setup.test.ts @@ -49,8 +49,11 @@ describe('setupUser', () => { }); await setupUser({} as OAuth2Client); expect(CodeAssistServer).toHaveBeenCalledWith( - expect.any(Object), + {}, 'test-project', + {}, + '', + undefined, ); }); @@ -62,7 +65,10 @@ describe('setupUser', () => { }); const projectId = await setupUser({} as OAuth2Client); expect(CodeAssistServer).toHaveBeenCalledWith( - expect.any(Object), + {}, + undefined, + {}, + '', undefined, ); expect(projectId).toEqual({ diff --git a/packages/core/src/code_assist/setup.ts b/packages/core/src/code_assist/setup.ts index 8831d24b..9c7a8043 100644 --- a/packages/core/src/code_assist/setup.ts +++ b/packages/core/src/code_assist/setup.ts @@ -34,7 +34,7 @@ export interface UserData { */ export async function setupUser(client: OAuth2Client): Promise { let projectId = process.env.GOOGLE_CLOUD_PROJECT || undefined; - const caServer = new CodeAssistServer(client, projectId); + const caServer = new CodeAssistServer(client, projectId, {}, '', undefined); const clientMetadata: ClientMetadata = { ideType: 'IDE_UNSPECIFIED', diff --git a/packages/core/src/core/client.test.ts b/packages/core/src/core/client.test.ts index 68d8c231..1e39758a 100644 --- a/packages/core/src/core/client.test.ts +++ b/packages/core/src/core/client.test.ts @@ -214,7 +214,9 @@ describe('Gemini Client (client.ts)', () => { // We can instantiate the client here since Config is mocked // and the constructor will use the mocked GoogleGenAI - client = new GeminiClient(new Config({} as never)); + client = new GeminiClient( + new Config({ sessionId: 'test-session-id' } as never), + ); mockConfigObject.getGeminiClient.mockReturnValue(client); await client.initialize(contentGeneratorConfig); @@ -353,16 +355,19 @@ describe('Gemini Client (client.ts)', () => { await client.generateContent(contents, generationConfig, abortSignal); - expect(mockGenerateContentFn).toHaveBeenCalledWith({ - model: 'test-model', - config: { - abortSignal, - systemInstruction: getCoreSystemPrompt(''), - temperature: 0.5, - topP: 1, + expect(mockGenerateContentFn).toHaveBeenCalledWith( + { + model: 'test-model', + config: { + abortSignal, + systemInstruction: getCoreSystemPrompt(''), + temperature: 0.5, + topP: 1, + }, + contents, }, - contents, - }); + 'test-session-id', + ); }); }); @@ -381,18 +386,21 @@ describe('Gemini Client (client.ts)', () => { await client.generateJson(contents, schema, abortSignal); - expect(mockGenerateContentFn).toHaveBeenCalledWith({ - model: 'test-model', // Should use current model from config - config: { - abortSignal, - systemInstruction: getCoreSystemPrompt(''), - temperature: 0, - topP: 1, - responseSchema: schema, - responseMimeType: 'application/json', + expect(mockGenerateContentFn).toHaveBeenCalledWith( + { + model: 'test-model', // Should use current model from config + config: { + abortSignal, + systemInstruction: getCoreSystemPrompt(''), + temperature: 0, + topP: 1, + responseSchema: schema, + responseMimeType: 'application/json', + }, + contents, }, - contents, - }); + 'test-session-id', + ); }); it('should allow overriding model and config', async () => { @@ -416,19 +424,22 @@ describe('Gemini Client (client.ts)', () => { customConfig, ); - expect(mockGenerateContentFn).toHaveBeenCalledWith({ - model: customModel, - config: { - abortSignal, - systemInstruction: getCoreSystemPrompt(''), - temperature: 0.9, - topP: 1, // from default - topK: 20, - responseSchema: schema, - responseMimeType: 'application/json', + expect(mockGenerateContentFn).toHaveBeenCalledWith( + { + model: customModel, + config: { + abortSignal, + systemInstruction: getCoreSystemPrompt(''), + temperature: 0.9, + topP: 1, // from default + topK: 20, + responseSchema: schema, + responseMimeType: 'application/json', + }, + contents, }, - contents, - }); + 'test-session-id', + ); }); }); @@ -1196,11 +1207,14 @@ Here are some files the user has open, with the most recent at the top: config: expect.any(Object), contents, }); - expect(mockGenerateContentFn).toHaveBeenCalledWith({ - model: currentModel, - config: expect.any(Object), - contents, - }); + expect(mockGenerateContentFn).toHaveBeenCalledWith( + { + model: currentModel, + config: expect.any(Object), + contents, + }, + 'test-session-id', + ); }); }); diff --git a/packages/core/src/core/client.ts b/packages/core/src/core/client.ts index 57457826..3b6b57f9 100644 --- a/packages/core/src/core/client.ts +++ b/packages/core/src/core/client.ts @@ -110,7 +110,7 @@ export class GeminiClient { private readonly COMPRESSION_PRESERVE_THRESHOLD = 0.3; private readonly loopDetector: LoopDetectionService; - private lastPromptId?: string; + private lastPromptId: string; constructor(private config: Config) { if (config.getProxy()) { @@ -119,6 +119,7 @@ export class GeminiClient { this.embeddingModel = config.getEmbeddingModel(); this.loopDetector = new LoopDetectionService(config); + this.lastPromptId = this.config.getSessionId(); } async initialize(contentGeneratorConfig: ContentGeneratorConfig) { @@ -493,16 +494,19 @@ export class GeminiClient { }; const apiCall = () => - this.getContentGenerator().generateContent({ - model: modelToUse, - config: { - ...requestConfig, - systemInstruction, - responseSchema: schema, - responseMimeType: 'application/json', + this.getContentGenerator().generateContent( + { + model: modelToUse, + config: { + ...requestConfig, + systemInstruction, + responseSchema: schema, + responseMimeType: 'application/json', + }, + contents, }, - contents, - }); + this.lastPromptId, + ); const result = await retryWithBackoff(apiCall, { onPersistent429: async (authType?: string, error?: unknown) => @@ -601,11 +605,14 @@ export class GeminiClient { }; const apiCall = () => - this.getContentGenerator().generateContent({ - model: modelToUse, - config: requestConfig, - contents, - }); + this.getContentGenerator().generateContent( + { + model: modelToUse, + config: requestConfig, + contents, + }, + this.lastPromptId, + ); const result = await retryWithBackoff(apiCall, { onPersistent429: async (authType?: string, error?: unknown) => diff --git a/packages/core/src/core/contentGenerator.ts b/packages/core/src/core/contentGenerator.ts index 44ed7beb..797bad73 100644 --- a/packages/core/src/core/contentGenerator.ts +++ b/packages/core/src/core/contentGenerator.ts @@ -25,10 +25,12 @@ import { UserTierId } from '../code_assist/types.js'; export interface ContentGenerator { generateContent( request: GenerateContentParameters, + userPromptId: string, ): Promise; generateContentStream( request: GenerateContentParameters, + userPromptId: string, ): Promise>; countTokens(request: CountTokensParameters): Promise; diff --git a/packages/core/src/core/geminiChat.test.ts b/packages/core/src/core/geminiChat.test.ts index 39dd883e..cd5e3841 100644 --- a/packages/core/src/core/geminiChat.test.ts +++ b/packages/core/src/core/geminiChat.test.ts @@ -79,11 +79,14 @@ describe('GeminiChat', () => { await chat.sendMessage({ message: 'hello' }, 'prompt-id-1'); - expect(mockModelsModule.generateContent).toHaveBeenCalledWith({ - model: 'gemini-pro', - contents: [{ role: 'user', parts: [{ text: 'hello' }] }], - config: {}, - }); + expect(mockModelsModule.generateContent).toHaveBeenCalledWith( + { + model: 'gemini-pro', + contents: [{ role: 'user', parts: [{ text: 'hello' }] }], + config: {}, + }, + 'prompt-id-1', + ); }); }); @@ -111,11 +114,14 @@ describe('GeminiChat', () => { await chat.sendMessageStream({ message: 'hello' }, 'prompt-id-1'); - expect(mockModelsModule.generateContentStream).toHaveBeenCalledWith({ - model: 'gemini-pro', - contents: [{ role: 'user', parts: [{ text: 'hello' }] }], - config: {}, - }); + expect(mockModelsModule.generateContentStream).toHaveBeenCalledWith( + { + model: 'gemini-pro', + contents: [{ role: 'user', parts: [{ text: 'hello' }] }], + config: {}, + }, + 'prompt-id-1', + ); }); }); diff --git a/packages/core/src/core/geminiChat.ts b/packages/core/src/core/geminiChat.ts index d3b2e060..bd81400f 100644 --- a/packages/core/src/core/geminiChat.ts +++ b/packages/core/src/core/geminiChat.ts @@ -287,11 +287,14 @@ export class GeminiChat { ); } - return this.contentGenerator.generateContent({ - model: modelToUse, - contents: requestContents, - config: { ...this.generationConfig, ...params.config }, - }); + return this.contentGenerator.generateContent( + { + model: modelToUse, + contents: requestContents, + config: { ...this.generationConfig, ...params.config }, + }, + prompt_id, + ); }; response = await retryWithBackoff(apiCall, { @@ -394,11 +397,14 @@ export class GeminiChat { ); } - return this.contentGenerator.generateContentStream({ - model: modelToUse, - contents: requestContents, - config: { ...this.generationConfig, ...params.config }, - }); + return this.contentGenerator.generateContentStream( + { + model: modelToUse, + contents: requestContents, + config: { ...this.generationConfig, ...params.config }, + }, + prompt_id, + ); }; // Note: Retrying streams can be complex. If generateContentStream itself doesn't handle retries