Added support for session_id in API calls (#2886)
This commit is contained in:
parent
3492c429b9
commit
dbe88f6e0e
|
@ -12,11 +12,12 @@ import { CodeAssistServer, HttpOptions } from './server.js';
|
|||
export async function createCodeAssistContentGenerator(
|
||||
httpOptions: HttpOptions,
|
||||
authType: AuthType,
|
||||
sessionId?: string,
|
||||
): Promise<ContentGenerator> {
|
||||
if (authType === AuthType.LOGIN_WITH_GOOGLE) {
|
||||
const authClient = await getOauthClient();
|
||||
const projectId = await setupUser(authClient);
|
||||
return new CodeAssistServer(authClient, projectId, httpOptions);
|
||||
return new CodeAssistServer(authClient, projectId, httpOptions, sessionId);
|
||||
}
|
||||
|
||||
throw new Error(`Unsupported authType: ${authType}`);
|
||||
|
|
|
@ -37,6 +37,7 @@ describe('converter', () => {
|
|||
labels: undefined,
|
||||
safetySettings: undefined,
|
||||
generationConfig: undefined,
|
||||
session_id: undefined,
|
||||
},
|
||||
});
|
||||
});
|
||||
|
@ -59,6 +60,34 @@ describe('converter', () => {
|
|||
labels: undefined,
|
||||
safetySettings: undefined,
|
||||
generationConfig: undefined,
|
||||
session_id: undefined,
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
it('should convert a request with sessionId', () => {
|
||||
const genaiReq: GenerateContentParameters = {
|
||||
model: 'gemini-pro',
|
||||
contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
|
||||
};
|
||||
const codeAssistReq = toGenerateContentRequest(
|
||||
genaiReq,
|
||||
'my-project',
|
||||
'session-123',
|
||||
);
|
||||
expect(codeAssistReq).toEqual({
|
||||
model: 'gemini-pro',
|
||||
project: 'my-project',
|
||||
request: {
|
||||
contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
|
||||
systemInstruction: undefined,
|
||||
cachedContent: undefined,
|
||||
tools: undefined,
|
||||
toolConfig: undefined,
|
||||
labels: undefined,
|
||||
safetySettings: undefined,
|
||||
generationConfig: undefined,
|
||||
session_id: 'session-123',
|
||||
},
|
||||
});
|
||||
});
|
||||
|
|
|
@ -44,6 +44,7 @@ interface VertexGenerateContentRequest {
|
|||
labels?: Record<string, string>;
|
||||
safetySettings?: SafetySetting[];
|
||||
generationConfig?: VertexGenerationConfig;
|
||||
session_id?: string;
|
||||
}
|
||||
|
||||
interface VertexGenerationConfig {
|
||||
|
@ -114,11 +115,12 @@ export function fromCountTokenResponse(
|
|||
export function toGenerateContentRequest(
|
||||
req: GenerateContentParameters,
|
||||
project?: string,
|
||||
sessionId?: string,
|
||||
): CAGenerateContentRequest {
|
||||
return {
|
||||
model: req.model,
|
||||
project,
|
||||
request: toVertexGenerateContentRequest(req),
|
||||
request: toVertexGenerateContentRequest(req, sessionId),
|
||||
};
|
||||
}
|
||||
|
||||
|
@ -136,6 +138,7 @@ export function fromGenerateContentResponse(
|
|||
|
||||
function toVertexGenerateContentRequest(
|
||||
req: GenerateContentParameters,
|
||||
sessionId?: string,
|
||||
): VertexGenerateContentRequest {
|
||||
return {
|
||||
contents: toContents(req.contents),
|
||||
|
@ -146,6 +149,7 @@ function toVertexGenerateContentRequest(
|
|||
labels: req.config?.labels,
|
||||
safetySettings: req.config?.safetySettings,
|
||||
generationConfig: toVertexGenerationConfig(req.config),
|
||||
session_id: sessionId,
|
||||
};
|
||||
}
|
||||
|
||||
|
|
|
@ -48,6 +48,7 @@ export class CodeAssistServer implements ContentGenerator {
|
|||
readonly client: OAuth2Client,
|
||||
readonly projectId?: string,
|
||||
readonly httpOptions: HttpOptions = {},
|
||||
readonly sessionId?: string,
|
||||
) {}
|
||||
|
||||
async generateContentStream(
|
||||
|
@ -55,7 +56,7 @@ export class CodeAssistServer implements ContentGenerator {
|
|||
): Promise<AsyncGenerator<GenerateContentResponse>> {
|
||||
const resps = await this.requestStreamingPost<CaGenerateContentResponse>(
|
||||
'streamGenerateContent',
|
||||
toGenerateContentRequest(req, this.projectId),
|
||||
toGenerateContentRequest(req, this.projectId, this.sessionId),
|
||||
req.config?.abortSignal,
|
||||
);
|
||||
return (async function* (): AsyncGenerator<GenerateContentResponse> {
|
||||
|
@ -70,7 +71,7 @@ export class CodeAssistServer implements ContentGenerator {
|
|||
): Promise<GenerateContentResponse> {
|
||||
const resp = await this.requestPost<CaGenerateContentResponse>(
|
||||
'generateContent',
|
||||
toGenerateContentRequest(req, this.projectId),
|
||||
toGenerateContentRequest(req, this.projectId, this.sessionId),
|
||||
req.config?.abortSignal,
|
||||
);
|
||||
return fromGenerateContentResponse(resp);
|
||||
|
|
|
@ -68,6 +68,7 @@ export class GeminiClient {
|
|||
async initialize(contentGeneratorConfig: ContentGeneratorConfig) {
|
||||
this.contentGenerator = await createContentGenerator(
|
||||
contentGeneratorConfig,
|
||||
this.config.getSessionId(),
|
||||
);
|
||||
this.chat = await this.startChat();
|
||||
}
|
||||
|
|
|
@ -101,6 +101,7 @@ export async function createContentGeneratorConfig(
|
|||
|
||||
export async function createContentGenerator(
|
||||
config: ContentGeneratorConfig,
|
||||
sessionId?: string,
|
||||
): Promise<ContentGenerator> {
|
||||
const version = process.env.CLI_VERSION || process.version;
|
||||
const httpOptions = {
|
||||
|
@ -109,7 +110,11 @@ export async function createContentGenerator(
|
|||
},
|
||||
};
|
||||
if (config.authType === AuthType.LOGIN_WITH_GOOGLE) {
|
||||
return createCodeAssistContentGenerator(httpOptions, config.authType);
|
||||
return createCodeAssistContentGenerator(
|
||||
httpOptions,
|
||||
config.authType,
|
||||
sessionId,
|
||||
);
|
||||
}
|
||||
|
||||
if (
|
||||
|
|
Loading…
Reference in New Issue