Added support for session_id in API calls (#2886)

This commit is contained in:
Bryan Morgan 2025-07-01 19:16:09 -04:00 committed by GitHub
parent 3492c429b9
commit dbe88f6e0e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 46 additions and 5 deletions

View File

@ -12,11 +12,12 @@ import { CodeAssistServer, HttpOptions } from './server.js';
export async function createCodeAssistContentGenerator( export async function createCodeAssistContentGenerator(
httpOptions: HttpOptions, httpOptions: HttpOptions,
authType: AuthType, authType: AuthType,
sessionId?: string,
): Promise<ContentGenerator> { ): Promise<ContentGenerator> {
if (authType === AuthType.LOGIN_WITH_GOOGLE) { if (authType === AuthType.LOGIN_WITH_GOOGLE) {
const authClient = await getOauthClient(); const authClient = await getOauthClient();
const projectId = await setupUser(authClient); const projectId = await setupUser(authClient);
return new CodeAssistServer(authClient, projectId, httpOptions); return new CodeAssistServer(authClient, projectId, httpOptions, sessionId);
} }
throw new Error(`Unsupported authType: ${authType}`); throw new Error(`Unsupported authType: ${authType}`);

View File

@ -37,6 +37,7 @@ describe('converter', () => {
labels: undefined, labels: undefined,
safetySettings: undefined, safetySettings: undefined,
generationConfig: undefined, generationConfig: undefined,
session_id: undefined,
}, },
}); });
}); });
@ -59,6 +60,34 @@ describe('converter', () => {
labels: undefined, labels: undefined,
safetySettings: undefined, safetySettings: undefined,
generationConfig: undefined, generationConfig: undefined,
session_id: undefined,
},
});
});
it('should convert a request with sessionId', () => {
const genaiReq: GenerateContentParameters = {
model: 'gemini-pro',
contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
};
const codeAssistReq = toGenerateContentRequest(
genaiReq,
'my-project',
'session-123',
);
expect(codeAssistReq).toEqual({
model: 'gemini-pro',
project: 'my-project',
request: {
contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
systemInstruction: undefined,
cachedContent: undefined,
tools: undefined,
toolConfig: undefined,
labels: undefined,
safetySettings: undefined,
generationConfig: undefined,
session_id: 'session-123',
}, },
}); });
}); });

View File

@ -44,6 +44,7 @@ interface VertexGenerateContentRequest {
labels?: Record<string, string>; labels?: Record<string, string>;
safetySettings?: SafetySetting[]; safetySettings?: SafetySetting[];
generationConfig?: VertexGenerationConfig; generationConfig?: VertexGenerationConfig;
session_id?: string;
} }
interface VertexGenerationConfig { interface VertexGenerationConfig {
@ -114,11 +115,12 @@ export function fromCountTokenResponse(
export function toGenerateContentRequest( export function toGenerateContentRequest(
req: GenerateContentParameters, req: GenerateContentParameters,
project?: string, project?: string,
sessionId?: string,
): CAGenerateContentRequest { ): CAGenerateContentRequest {
return { return {
model: req.model, model: req.model,
project, project,
request: toVertexGenerateContentRequest(req), request: toVertexGenerateContentRequest(req, sessionId),
}; };
} }
@ -136,6 +138,7 @@ export function fromGenerateContentResponse(
function toVertexGenerateContentRequest( function toVertexGenerateContentRequest(
req: GenerateContentParameters, req: GenerateContentParameters,
sessionId?: string,
): VertexGenerateContentRequest { ): VertexGenerateContentRequest {
return { return {
contents: toContents(req.contents), contents: toContents(req.contents),
@ -146,6 +149,7 @@ function toVertexGenerateContentRequest(
labels: req.config?.labels, labels: req.config?.labels,
safetySettings: req.config?.safetySettings, safetySettings: req.config?.safetySettings,
generationConfig: toVertexGenerationConfig(req.config), generationConfig: toVertexGenerationConfig(req.config),
session_id: sessionId,
}; };
} }

View File

@ -48,6 +48,7 @@ export class CodeAssistServer implements ContentGenerator {
readonly client: OAuth2Client, readonly client: OAuth2Client,
readonly projectId?: string, readonly projectId?: string,
readonly httpOptions: HttpOptions = {}, readonly httpOptions: HttpOptions = {},
readonly sessionId?: string,
) {} ) {}
async generateContentStream( async generateContentStream(
@ -55,7 +56,7 @@ export class CodeAssistServer implements ContentGenerator {
): Promise<AsyncGenerator<GenerateContentResponse>> { ): Promise<AsyncGenerator<GenerateContentResponse>> {
const resps = await this.requestStreamingPost<CaGenerateContentResponse>( const resps = await this.requestStreamingPost<CaGenerateContentResponse>(
'streamGenerateContent', 'streamGenerateContent',
toGenerateContentRequest(req, this.projectId), toGenerateContentRequest(req, this.projectId, this.sessionId),
req.config?.abortSignal, req.config?.abortSignal,
); );
return (async function* (): AsyncGenerator<GenerateContentResponse> { return (async function* (): AsyncGenerator<GenerateContentResponse> {
@ -70,7 +71,7 @@ export class CodeAssistServer implements ContentGenerator {
): Promise<GenerateContentResponse> { ): Promise<GenerateContentResponse> {
const resp = await this.requestPost<CaGenerateContentResponse>( const resp = await this.requestPost<CaGenerateContentResponse>(
'generateContent', 'generateContent',
toGenerateContentRequest(req, this.projectId), toGenerateContentRequest(req, this.projectId, this.sessionId),
req.config?.abortSignal, req.config?.abortSignal,
); );
return fromGenerateContentResponse(resp); return fromGenerateContentResponse(resp);

View File

@ -68,6 +68,7 @@ export class GeminiClient {
async initialize(contentGeneratorConfig: ContentGeneratorConfig) { async initialize(contentGeneratorConfig: ContentGeneratorConfig) {
this.contentGenerator = await createContentGenerator( this.contentGenerator = await createContentGenerator(
contentGeneratorConfig, contentGeneratorConfig,
this.config.getSessionId(),
); );
this.chat = await this.startChat(); this.chat = await this.startChat();
} }

View File

@ -101,6 +101,7 @@ export async function createContentGeneratorConfig(
export async function createContentGenerator( export async function createContentGenerator(
config: ContentGeneratorConfig, config: ContentGeneratorConfig,
sessionId?: string,
): Promise<ContentGenerator> { ): Promise<ContentGenerator> {
const version = process.env.CLI_VERSION || process.version; const version = process.env.CLI_VERSION || process.version;
const httpOptions = { const httpOptions = {
@ -109,7 +110,11 @@ export async function createContentGenerator(
}, },
}; };
if (config.authType === AuthType.LOGIN_WITH_GOOGLE) { if (config.authType === AuthType.LOGIN_WITH_GOOGLE) {
return createCodeAssistContentGenerator(httpOptions, config.authType); return createCodeAssistContentGenerator(
httpOptions,
config.authType,
sessionId,
);
} }
if ( if (