diff --git a/packages/core/src/code_assist/codeAssist.ts b/packages/core/src/code_assist/codeAssist.ts index c3cb9293..80d95ca9 100644 --- a/packages/core/src/code_assist/codeAssist.ts +++ b/packages/core/src/code_assist/codeAssist.ts @@ -12,11 +12,12 @@ import { CodeAssistServer, HttpOptions } from './server.js'; export async function createCodeAssistContentGenerator( httpOptions: HttpOptions, authType: AuthType, + sessionId?: string, ): Promise { if (authType === AuthType.LOGIN_WITH_GOOGLE) { const authClient = await getOauthClient(); const projectId = await setupUser(authClient); - return new CodeAssistServer(authClient, projectId, httpOptions); + return new CodeAssistServer(authClient, projectId, httpOptions, sessionId); } throw new Error(`Unsupported authType: ${authType}`); diff --git a/packages/core/src/code_assist/converter.test.ts b/packages/core/src/code_assist/converter.test.ts index 2170c960..03f388dc 100644 --- a/packages/core/src/code_assist/converter.test.ts +++ b/packages/core/src/code_assist/converter.test.ts @@ -37,6 +37,7 @@ describe('converter', () => { labels: undefined, safetySettings: undefined, generationConfig: undefined, + session_id: undefined, }, }); }); @@ -59,6 +60,34 @@ describe('converter', () => { labels: undefined, safetySettings: undefined, generationConfig: undefined, + session_id: undefined, + }, + }); + }); + + it('should convert a request with sessionId', () => { + const genaiReq: GenerateContentParameters = { + model: 'gemini-pro', + contents: [{ role: 'user', parts: [{ text: 'Hello' }] }], + }; + const codeAssistReq = toGenerateContentRequest( + genaiReq, + 'my-project', + 'session-123', + ); + expect(codeAssistReq).toEqual({ + model: 'gemini-pro', + project: 'my-project', + request: { + contents: [{ role: 'user', parts: [{ text: 'Hello' }] }], + systemInstruction: undefined, + cachedContent: undefined, + tools: undefined, + toolConfig: undefined, + labels: undefined, + safetySettings: undefined, + generationConfig: undefined, + session_id: 'session-123', }, }); }); diff --git a/packages/core/src/code_assist/converter.ts b/packages/core/src/code_assist/converter.ts index b9b854fc..b27617c4 100644 --- a/packages/core/src/code_assist/converter.ts +++ b/packages/core/src/code_assist/converter.ts @@ -44,6 +44,7 @@ interface VertexGenerateContentRequest { labels?: Record; safetySettings?: SafetySetting[]; generationConfig?: VertexGenerationConfig; + session_id?: string; } interface VertexGenerationConfig { @@ -114,11 +115,12 @@ export function fromCountTokenResponse( export function toGenerateContentRequest( req: GenerateContentParameters, project?: string, + sessionId?: string, ): CAGenerateContentRequest { return { model: req.model, project, - request: toVertexGenerateContentRequest(req), + request: toVertexGenerateContentRequest(req, sessionId), }; } @@ -136,6 +138,7 @@ export function fromGenerateContentResponse( function toVertexGenerateContentRequest( req: GenerateContentParameters, + sessionId?: string, ): VertexGenerateContentRequest { return { contents: toContents(req.contents), @@ -146,6 +149,7 @@ function toVertexGenerateContentRequest( labels: req.config?.labels, safetySettings: req.config?.safetySettings, generationConfig: toVertexGenerationConfig(req.config), + session_id: sessionId, }; } diff --git a/packages/core/src/code_assist/server.ts b/packages/core/src/code_assist/server.ts index 3cf0c721..f285dba8 100644 --- a/packages/core/src/code_assist/server.ts +++ b/packages/core/src/code_assist/server.ts @@ -48,6 +48,7 @@ export class CodeAssistServer implements ContentGenerator { readonly client: OAuth2Client, readonly projectId?: string, readonly httpOptions: HttpOptions = {}, + readonly sessionId?: string, ) {} async generateContentStream( @@ -55,7 +56,7 @@ export class CodeAssistServer implements ContentGenerator { ): Promise> { const resps = await this.requestStreamingPost( 'streamGenerateContent', - toGenerateContentRequest(req, this.projectId), + toGenerateContentRequest(req, this.projectId, this.sessionId), req.config?.abortSignal, ); return (async function* (): AsyncGenerator { @@ -70,7 +71,7 @@ export class CodeAssistServer implements ContentGenerator { ): Promise { const resp = await this.requestPost( 'generateContent', - toGenerateContentRequest(req, this.projectId), + toGenerateContentRequest(req, this.projectId, this.sessionId), req.config?.abortSignal, ); return fromGenerateContentResponse(resp); diff --git a/packages/core/src/core/client.ts b/packages/core/src/core/client.ts index b00a689b..fe60112d 100644 --- a/packages/core/src/core/client.ts +++ b/packages/core/src/core/client.ts @@ -68,6 +68,7 @@ export class GeminiClient { async initialize(contentGeneratorConfig: ContentGeneratorConfig) { this.contentGenerator = await createContentGenerator( contentGeneratorConfig, + this.config.getSessionId(), ); this.chat = await this.startChat(); } diff --git a/packages/core/src/core/contentGenerator.ts b/packages/core/src/core/contentGenerator.ts index 4740c4ee..f0c163d2 100644 --- a/packages/core/src/core/contentGenerator.ts +++ b/packages/core/src/core/contentGenerator.ts @@ -101,6 +101,7 @@ export async function createContentGeneratorConfig( export async function createContentGenerator( config: ContentGeneratorConfig, + sessionId?: string, ): Promise { const version = process.env.CLI_VERSION || process.version; const httpOptions = { @@ -109,7 +110,11 @@ export async function createContentGenerator( }, }; if (config.authType === AuthType.LOGIN_WITH_GOOGLE) { - return createCodeAssistContentGenerator(httpOptions, config.authType); + return createCodeAssistContentGenerator( + httpOptions, + config.authType, + sessionId, + ); } if (