diff --git a/packages/core/src/code_assist/converter.test.ts b/packages/core/src/code_assist/converter.test.ts index 3d3a8ef3..03f388dc 100644 --- a/packages/core/src/code_assist/converter.test.ts +++ b/packages/core/src/code_assist/converter.test.ts @@ -24,12 +24,7 @@ describe('converter', () => { model: 'gemini-pro', contents: [{ role: 'user', parts: [{ text: 'Hello' }] }], }; - const codeAssistReq = toGenerateContentRequest( - genaiReq, - 'my-prompt', - 'my-project', - 'my-session', - ); + const codeAssistReq = toGenerateContentRequest(genaiReq, 'my-project'); expect(codeAssistReq).toEqual({ model: 'gemini-pro', project: 'my-project', @@ -42,9 +37,8 @@ describe('converter', () => { labels: undefined, safetySettings: undefined, generationConfig: undefined, - session_id: 'my-session', + session_id: undefined, }, - user_prompt_id: 'my-prompt', }); }); @@ -53,12 +47,7 @@ describe('converter', () => { model: 'gemini-pro', contents: [{ role: 'user', parts: [{ text: 'Hello' }] }], }; - const codeAssistReq = toGenerateContentRequest( - genaiReq, - 'my-prompt', - undefined, - 'my-session', - ); + const codeAssistReq = toGenerateContentRequest(genaiReq); expect(codeAssistReq).toEqual({ model: 'gemini-pro', project: undefined, @@ -71,9 +60,8 @@ describe('converter', () => { labels: undefined, safetySettings: undefined, generationConfig: undefined, - session_id: 'my-session', + session_id: undefined, }, - user_prompt_id: 'my-prompt', }); }); @@ -84,7 +72,6 @@ describe('converter', () => { }; const codeAssistReq = toGenerateContentRequest( genaiReq, - 'my-prompt', 'my-project', 'session-123', ); @@ -102,7 +89,6 @@ describe('converter', () => { generationConfig: undefined, session_id: 'session-123', }, - user_prompt_id: 'my-prompt', }); }); @@ -111,12 +97,7 @@ describe('converter', () => { model: 'gemini-pro', contents: 'Hello', }; - const codeAssistReq = toGenerateContentRequest( - genaiReq, - 'my-prompt', - 'my-project', - 'my-session', - ); + const codeAssistReq = toGenerateContentRequest(genaiReq); expect(codeAssistReq.request.contents).toEqual([ { role: 'user', parts: [{ text: 'Hello' }] }, ]); @@ -127,12 +108,7 @@ describe('converter', () => { model: 'gemini-pro', contents: [{ text: 'Hello' }, { text: 'World' }], }; - const codeAssistReq = toGenerateContentRequest( - genaiReq, - 'my-prompt', - 'my-project', - 'my-session', - ); + const codeAssistReq = toGenerateContentRequest(genaiReq); expect(codeAssistReq.request.contents).toEqual([ { role: 'user', parts: [{ text: 'Hello' }] }, { role: 'user', parts: [{ text: 'World' }] }, @@ -147,12 +123,7 @@ describe('converter', () => { systemInstruction: 'You are a helpful assistant.', }, }; - const codeAssistReq = toGenerateContentRequest( - genaiReq, - 'my-prompt', - 'my-project', - 'my-session', - ); + const codeAssistReq = toGenerateContentRequest(genaiReq); expect(codeAssistReq.request.systemInstruction).toEqual({ role: 'user', parts: [{ text: 'You are a helpful assistant.' }], @@ -168,12 +139,7 @@ describe('converter', () => { topK: 40, }, }; - const codeAssistReq = toGenerateContentRequest( - genaiReq, - 'my-prompt', - 'my-project', - 'my-session', - ); + const codeAssistReq = toGenerateContentRequest(genaiReq); expect(codeAssistReq.request.generationConfig).toEqual({ temperature: 0.8, topK: 40, @@ -199,12 +165,7 @@ describe('converter', () => { responseMimeType: 'application/json', }, }; - const codeAssistReq = toGenerateContentRequest( - genaiReq, - 'my-prompt', - 'my-project', - 'my-session', - ); + const codeAssistReq = toGenerateContentRequest(genaiReq); expect(codeAssistReq.request.generationConfig).toEqual({ temperature: 0.1, topP: 0.2, diff --git a/packages/core/src/code_assist/converter.ts b/packages/core/src/code_assist/converter.ts index ffd471da..8340cfc1 100644 --- a/packages/core/src/code_assist/converter.ts +++ b/packages/core/src/code_assist/converter.ts @@ -32,7 +32,6 @@ import { export interface CAGenerateContentRequest { model: string; project?: string; - user_prompt_id?: string; request: VertexGenerateContentRequest; } @@ -116,14 +115,12 @@ export function fromCountTokenResponse( export function toGenerateContentRequest( req: GenerateContentParameters, - userPromptId: string, project?: string, sessionId?: string, ): CAGenerateContentRequest { return { model: req.model, project, - user_prompt_id: userPromptId, request: toVertexGenerateContentRequest(req, sessionId), }; } diff --git a/packages/core/src/code_assist/server.test.ts b/packages/core/src/code_assist/server.test.ts index 3fc1891f..6246fd4e 100644 --- a/packages/core/src/code_assist/server.test.ts +++ b/packages/core/src/code_assist/server.test.ts @@ -14,25 +14,13 @@ vi.mock('google-auth-library'); describe('CodeAssistServer', () => { it('should be able to be constructed', () => { const auth = new OAuth2Client(); - const server = new CodeAssistServer( - auth, - 'test-project', - {}, - 'test-session', - UserTierId.FREE, - ); + const server = new CodeAssistServer(auth, 'test-project'); expect(server).toBeInstanceOf(CodeAssistServer); }); it('should call the generateContent endpoint', async () => { const client = new OAuth2Client(); - const server = new CodeAssistServer( - client, - 'test-project', - {}, - 'test-session', - UserTierId.FREE, - ); + const server = new CodeAssistServer(client, 'test-project'); const mockResponse = { response: { candidates: [ @@ -50,13 +38,10 @@ describe('CodeAssistServer', () => { }; vi.spyOn(server, 'requestPost').mockResolvedValue(mockResponse); - const response = await server.generateContent( - { - model: 'test-model', - contents: [{ role: 'user', parts: [{ text: 'request' }] }], - }, - 'user-prompt-id', - ); + const response = await server.generateContent({ + model: 'test-model', + contents: [{ role: 'user', parts: [{ text: 'request' }] }], + }); expect(server.requestPost).toHaveBeenCalledWith( 'generateContent', @@ -70,13 +55,7 @@ describe('CodeAssistServer', () => { it('should call the generateContentStream endpoint', async () => { const client = new OAuth2Client(); - const server = new CodeAssistServer( - client, - 'test-project', - {}, - 'test-session', - UserTierId.FREE, - ); + const server = new CodeAssistServer(client, 'test-project'); const mockResponse = (async function* () { yield { response: { @@ -96,13 +75,10 @@ describe('CodeAssistServer', () => { })(); vi.spyOn(server, 'requestStreamingPost').mockResolvedValue(mockResponse); - const stream = await server.generateContentStream( - { - model: 'test-model', - contents: [{ role: 'user', parts: [{ text: 'request' }] }], - }, - 'user-prompt-id', - ); + const stream = await server.generateContentStream({ + model: 'test-model', + contents: [{ role: 'user', parts: [{ text: 'request' }] }], + }); for await (const res of stream) { expect(server.requestStreamingPost).toHaveBeenCalledWith( @@ -116,13 +92,7 @@ describe('CodeAssistServer', () => { it('should call the onboardUser endpoint', async () => { const client = new OAuth2Client(); - const server = new CodeAssistServer( - client, - 'test-project', - {}, - 'test-session', - UserTierId.FREE, - ); + const server = new CodeAssistServer(client, 'test-project'); const mockResponse = { name: 'operations/123', done: true, @@ -144,13 +114,7 @@ describe('CodeAssistServer', () => { it('should call the loadCodeAssist endpoint', async () => { const client = new OAuth2Client(); - const server = new CodeAssistServer( - client, - 'test-project', - {}, - 'test-session', - UserTierId.FREE, - ); + const server = new CodeAssistServer(client, 'test-project'); const mockResponse = { currentTier: { id: UserTierId.FREE, @@ -176,13 +140,7 @@ describe('CodeAssistServer', () => { it('should return 0 for countTokens', async () => { const client = new OAuth2Client(); - const server = new CodeAssistServer( - client, - 'test-project', - {}, - 'test-session', - UserTierId.FREE, - ); + const server = new CodeAssistServer(client, 'test-project'); const mockResponse = { totalTokens: 100, }; @@ -197,13 +155,7 @@ describe('CodeAssistServer', () => { it('should throw an error for embedContent', async () => { const client = new OAuth2Client(); - const server = new CodeAssistServer( - client, - 'test-project', - {}, - 'test-session', - UserTierId.FREE, - ); + const server = new CodeAssistServer(client, 'test-project'); await expect( server.embedContent({ model: 'test-model', diff --git a/packages/core/src/code_assist/server.ts b/packages/core/src/code_assist/server.ts index 08339bdc..7af643f7 100644 --- a/packages/core/src/code_assist/server.ts +++ b/packages/core/src/code_assist/server.ts @@ -53,16 +53,10 @@ export class CodeAssistServer implements ContentGenerator { async generateContentStream( req: GenerateContentParameters, - userPromptId: string, ): Promise> { const resps = await this.requestStreamingPost( 'streamGenerateContent', - toGenerateContentRequest( - req, - userPromptId, - this.projectId, - this.sessionId, - ), + toGenerateContentRequest(req, this.projectId, this.sessionId), req.config?.abortSignal, ); return (async function* (): AsyncGenerator { @@ -74,16 +68,10 @@ export class CodeAssistServer implements ContentGenerator { async generateContent( req: GenerateContentParameters, - userPromptId: string, ): Promise { const resp = await this.requestPost( 'generateContent', - toGenerateContentRequest( - req, - userPromptId, - this.projectId, - this.sessionId, - ), + toGenerateContentRequest(req, this.projectId, this.sessionId), req.config?.abortSignal, ); return fromGenerateContentResponse(resp); diff --git a/packages/core/src/code_assist/setup.test.ts b/packages/core/src/code_assist/setup.test.ts index c1260e3f..6db5fd88 100644 --- a/packages/core/src/code_assist/setup.test.ts +++ b/packages/core/src/code_assist/setup.test.ts @@ -49,11 +49,8 @@ describe('setupUser', () => { }); await setupUser({} as OAuth2Client); expect(CodeAssistServer).toHaveBeenCalledWith( - {}, + expect.any(Object), 'test-project', - {}, - '', - undefined, ); }); @@ -65,10 +62,7 @@ describe('setupUser', () => { }); const projectId = await setupUser({} as OAuth2Client); expect(CodeAssistServer).toHaveBeenCalledWith( - {}, - undefined, - {}, - '', + expect.any(Object), undefined, ); expect(projectId).toEqual({ diff --git a/packages/core/src/code_assist/setup.ts b/packages/core/src/code_assist/setup.ts index 9c7a8043..8831d24b 100644 --- a/packages/core/src/code_assist/setup.ts +++ b/packages/core/src/code_assist/setup.ts @@ -34,7 +34,7 @@ export interface UserData { */ export async function setupUser(client: OAuth2Client): Promise { let projectId = process.env.GOOGLE_CLOUD_PROJECT || undefined; - const caServer = new CodeAssistServer(client, projectId, {}, '', undefined); + const caServer = new CodeAssistServer(client, projectId); const clientMetadata: ClientMetadata = { ideType: 'IDE_UNSPECIFIED', diff --git a/packages/core/src/core/client.test.ts b/packages/core/src/core/client.test.ts index 5101f98b..25ea9bc1 100644 --- a/packages/core/src/core/client.test.ts +++ b/packages/core/src/core/client.test.ts @@ -209,9 +209,7 @@ describe('Gemini Client (client.ts)', () => { // We can instantiate the client here since Config is mocked // and the constructor will use the mocked GoogleGenAI - client = new GeminiClient( - new Config({ sessionId: 'test-session-id' } as never), - ); + client = new GeminiClient(new Config({} as never)); mockConfigObject.getGeminiClient.mockReturnValue(client); await client.initialize(contentGeneratorConfig); @@ -350,19 +348,16 @@ describe('Gemini Client (client.ts)', () => { await client.generateContent(contents, generationConfig, abortSignal); - expect(mockGenerateContentFn).toHaveBeenCalledWith( - { - model: 'test-model', - config: { - abortSignal, - systemInstruction: getCoreSystemPrompt(''), - temperature: 0.5, - topP: 1, - }, - contents, + expect(mockGenerateContentFn).toHaveBeenCalledWith({ + model: 'test-model', + config: { + abortSignal, + systemInstruction: getCoreSystemPrompt(''), + temperature: 0.5, + topP: 1, }, - 'test-session-id', - ); + contents, + }); }); }); @@ -381,21 +376,18 @@ describe('Gemini Client (client.ts)', () => { await client.generateJson(contents, schema, abortSignal); - expect(mockGenerateContentFn).toHaveBeenCalledWith( - { - model: 'test-model', // Should use current model from config - config: { - abortSignal, - systemInstruction: getCoreSystemPrompt(''), - temperature: 0, - topP: 1, - responseSchema: schema, - responseMimeType: 'application/json', - }, - contents, + expect(mockGenerateContentFn).toHaveBeenCalledWith({ + model: 'test-model', // Should use current model from config + config: { + abortSignal, + systemInstruction: getCoreSystemPrompt(''), + temperature: 0, + topP: 1, + responseSchema: schema, + responseMimeType: 'application/json', }, - 'test-session-id', - ); + contents, + }); }); it('should allow overriding model and config', async () => { @@ -419,22 +411,19 @@ describe('Gemini Client (client.ts)', () => { customConfig, ); - expect(mockGenerateContentFn).toHaveBeenCalledWith( - { - model: customModel, - config: { - abortSignal, - systemInstruction: getCoreSystemPrompt(''), - temperature: 0.9, - topP: 1, // from default - topK: 20, - responseSchema: schema, - responseMimeType: 'application/json', - }, - contents, + expect(mockGenerateContentFn).toHaveBeenCalledWith({ + model: customModel, + config: { + abortSignal, + systemInstruction: getCoreSystemPrompt(''), + temperature: 0.9, + topP: 1, // from default + topK: 20, + responseSchema: schema, + responseMimeType: 'application/json', }, - 'test-session-id', - ); + contents, + }); }); }); @@ -1017,14 +1006,11 @@ Here are files the user has recently opened, with the most recent at the top: config: expect.any(Object), contents, }); - expect(mockGenerateContentFn).toHaveBeenCalledWith( - { - model: currentModel, - config: expect.any(Object), - contents, - }, - 'test-session-id', - ); + expect(mockGenerateContentFn).toHaveBeenCalledWith({ + model: currentModel, + config: expect.any(Object), + contents, + }); }); }); diff --git a/packages/core/src/core/client.ts b/packages/core/src/core/client.ts index 02fbeb38..77683a45 100644 --- a/packages/core/src/core/client.ts +++ b/packages/core/src/core/client.ts @@ -106,7 +106,7 @@ export class GeminiClient { private readonly COMPRESSION_PRESERVE_THRESHOLD = 0.3; private readonly loopDetector: LoopDetectionService; - private lastPromptId: string; + private lastPromptId?: string; constructor(private config: Config) { if (config.getProxy()) { @@ -115,7 +115,6 @@ export class GeminiClient { this.embeddingModel = config.getEmbeddingModel(); this.loopDetector = new LoopDetectionService(config); - this.lastPromptId = this.config.getSessionId(); } async initialize(contentGeneratorConfig: ContentGeneratorConfig) { @@ -428,19 +427,16 @@ export class GeminiClient { }; const apiCall = () => - this.getContentGenerator().generateContent( - { - model: modelToUse, - config: { - ...requestConfig, - systemInstruction, - responseSchema: schema, - responseMimeType: 'application/json', - }, - contents, + this.getContentGenerator().generateContent({ + model: modelToUse, + config: { + ...requestConfig, + systemInstruction, + responseSchema: schema, + responseMimeType: 'application/json', }, - this.lastPromptId, - ); + contents, + }); const result = await retryWithBackoff(apiCall, { onPersistent429: async (authType?: string, error?: unknown) => @@ -525,14 +521,11 @@ export class GeminiClient { }; const apiCall = () => - this.getContentGenerator().generateContent( - { - model: modelToUse, - config: requestConfig, - contents, - }, - this.lastPromptId, - ); + this.getContentGenerator().generateContent({ + model: modelToUse, + config: requestConfig, + contents, + }); const result = await retryWithBackoff(apiCall, { onPersistent429: async (authType?: string, error?: unknown) => diff --git a/packages/core/src/core/contentGenerator.ts b/packages/core/src/core/contentGenerator.ts index 797bad73..44ed7beb 100644 --- a/packages/core/src/core/contentGenerator.ts +++ b/packages/core/src/core/contentGenerator.ts @@ -25,12 +25,10 @@ import { UserTierId } from '../code_assist/types.js'; export interface ContentGenerator { generateContent( request: GenerateContentParameters, - userPromptId: string, ): Promise; generateContentStream( request: GenerateContentParameters, - userPromptId: string, ): Promise>; countTokens(request: CountTokensParameters): Promise; diff --git a/packages/core/src/core/geminiChat.test.ts b/packages/core/src/core/geminiChat.test.ts index cd5e3841..39dd883e 100644 --- a/packages/core/src/core/geminiChat.test.ts +++ b/packages/core/src/core/geminiChat.test.ts @@ -79,14 +79,11 @@ describe('GeminiChat', () => { await chat.sendMessage({ message: 'hello' }, 'prompt-id-1'); - expect(mockModelsModule.generateContent).toHaveBeenCalledWith( - { - model: 'gemini-pro', - contents: [{ role: 'user', parts: [{ text: 'hello' }] }], - config: {}, - }, - 'prompt-id-1', - ); + expect(mockModelsModule.generateContent).toHaveBeenCalledWith({ + model: 'gemini-pro', + contents: [{ role: 'user', parts: [{ text: 'hello' }] }], + config: {}, + }); }); }); @@ -114,14 +111,11 @@ describe('GeminiChat', () => { await chat.sendMessageStream({ message: 'hello' }, 'prompt-id-1'); - expect(mockModelsModule.generateContentStream).toHaveBeenCalledWith( - { - model: 'gemini-pro', - contents: [{ role: 'user', parts: [{ text: 'hello' }] }], - config: {}, - }, - 'prompt-id-1', - ); + expect(mockModelsModule.generateContentStream).toHaveBeenCalledWith({ + model: 'gemini-pro', + contents: [{ role: 'user', parts: [{ text: 'hello' }] }], + config: {}, + }); }); }); diff --git a/packages/core/src/core/geminiChat.ts b/packages/core/src/core/geminiChat.ts index 14e8d946..4c3cd4c8 100644 --- a/packages/core/src/core/geminiChat.ts +++ b/packages/core/src/core/geminiChat.ts @@ -286,14 +286,11 @@ export class GeminiChat { ); } - return this.contentGenerator.generateContent( - { - model: modelToUse, - contents: requestContents, - config: { ...this.generationConfig, ...params.config }, - }, - prompt_id, - ); + return this.contentGenerator.generateContent({ + model: modelToUse, + contents: requestContents, + config: { ...this.generationConfig, ...params.config }, + }); }; response = await retryWithBackoff(apiCall, { @@ -396,14 +393,11 @@ export class GeminiChat { ); } - return this.contentGenerator.generateContentStream( - { - model: modelToUse, - contents: requestContents, - config: { ...this.generationConfig, ...params.config }, - }, - prompt_id, - ); + return this.contentGenerator.generateContentStream({ + model: modelToUse, + contents: requestContents, + config: { ...this.generationConfig, ...params.config }, + }); }; // Note: Retrying streams can be complex. If generateContentStream itself doesn't handle retries