Propagate prompt (#5033)
This commit is contained in:
parent
67d16992cf
commit
a6a386f72a
|
@ -24,7 +24,12 @@ describe('converter', () => {
|
||||||
model: 'gemini-pro',
|
model: 'gemini-pro',
|
||||||
contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
|
contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
|
||||||
};
|
};
|
||||||
const codeAssistReq = toGenerateContentRequest(genaiReq, 'my-project');
|
const codeAssistReq = toGenerateContentRequest(
|
||||||
|
genaiReq,
|
||||||
|
'my-prompt',
|
||||||
|
'my-project',
|
||||||
|
'my-session',
|
||||||
|
);
|
||||||
expect(codeAssistReq).toEqual({
|
expect(codeAssistReq).toEqual({
|
||||||
model: 'gemini-pro',
|
model: 'gemini-pro',
|
||||||
project: 'my-project',
|
project: 'my-project',
|
||||||
|
@ -37,8 +42,9 @@ describe('converter', () => {
|
||||||
labels: undefined,
|
labels: undefined,
|
||||||
safetySettings: undefined,
|
safetySettings: undefined,
|
||||||
generationConfig: undefined,
|
generationConfig: undefined,
|
||||||
session_id: undefined,
|
session_id: 'my-session',
|
||||||
},
|
},
|
||||||
|
user_prompt_id: 'my-prompt',
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -47,7 +53,12 @@ describe('converter', () => {
|
||||||
model: 'gemini-pro',
|
model: 'gemini-pro',
|
||||||
contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
|
contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
|
||||||
};
|
};
|
||||||
const codeAssistReq = toGenerateContentRequest(genaiReq);
|
const codeAssistReq = toGenerateContentRequest(
|
||||||
|
genaiReq,
|
||||||
|
'my-prompt',
|
||||||
|
undefined,
|
||||||
|
'my-session',
|
||||||
|
);
|
||||||
expect(codeAssistReq).toEqual({
|
expect(codeAssistReq).toEqual({
|
||||||
model: 'gemini-pro',
|
model: 'gemini-pro',
|
||||||
project: undefined,
|
project: undefined,
|
||||||
|
@ -60,8 +71,9 @@ describe('converter', () => {
|
||||||
labels: undefined,
|
labels: undefined,
|
||||||
safetySettings: undefined,
|
safetySettings: undefined,
|
||||||
generationConfig: undefined,
|
generationConfig: undefined,
|
||||||
session_id: undefined,
|
session_id: 'my-session',
|
||||||
},
|
},
|
||||||
|
user_prompt_id: 'my-prompt',
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -72,6 +84,7 @@ describe('converter', () => {
|
||||||
};
|
};
|
||||||
const codeAssistReq = toGenerateContentRequest(
|
const codeAssistReq = toGenerateContentRequest(
|
||||||
genaiReq,
|
genaiReq,
|
||||||
|
'my-prompt',
|
||||||
'my-project',
|
'my-project',
|
||||||
'session-123',
|
'session-123',
|
||||||
);
|
);
|
||||||
|
@ -89,6 +102,7 @@ describe('converter', () => {
|
||||||
generationConfig: undefined,
|
generationConfig: undefined,
|
||||||
session_id: 'session-123',
|
session_id: 'session-123',
|
||||||
},
|
},
|
||||||
|
user_prompt_id: 'my-prompt',
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -97,7 +111,12 @@ describe('converter', () => {
|
||||||
model: 'gemini-pro',
|
model: 'gemini-pro',
|
||||||
contents: 'Hello',
|
contents: 'Hello',
|
||||||
};
|
};
|
||||||
const codeAssistReq = toGenerateContentRequest(genaiReq);
|
const codeAssistReq = toGenerateContentRequest(
|
||||||
|
genaiReq,
|
||||||
|
'my-prompt',
|
||||||
|
'my-project',
|
||||||
|
'my-session',
|
||||||
|
);
|
||||||
expect(codeAssistReq.request.contents).toEqual([
|
expect(codeAssistReq.request.contents).toEqual([
|
||||||
{ role: 'user', parts: [{ text: 'Hello' }] },
|
{ role: 'user', parts: [{ text: 'Hello' }] },
|
||||||
]);
|
]);
|
||||||
|
@ -108,7 +127,12 @@ describe('converter', () => {
|
||||||
model: 'gemini-pro',
|
model: 'gemini-pro',
|
||||||
contents: [{ text: 'Hello' }, { text: 'World' }],
|
contents: [{ text: 'Hello' }, { text: 'World' }],
|
||||||
};
|
};
|
||||||
const codeAssistReq = toGenerateContentRequest(genaiReq);
|
const codeAssistReq = toGenerateContentRequest(
|
||||||
|
genaiReq,
|
||||||
|
'my-prompt',
|
||||||
|
'my-project',
|
||||||
|
'my-session',
|
||||||
|
);
|
||||||
expect(codeAssistReq.request.contents).toEqual([
|
expect(codeAssistReq.request.contents).toEqual([
|
||||||
{ role: 'user', parts: [{ text: 'Hello' }] },
|
{ role: 'user', parts: [{ text: 'Hello' }] },
|
||||||
{ role: 'user', parts: [{ text: 'World' }] },
|
{ role: 'user', parts: [{ text: 'World' }] },
|
||||||
|
@ -123,7 +147,12 @@ describe('converter', () => {
|
||||||
systemInstruction: 'You are a helpful assistant.',
|
systemInstruction: 'You are a helpful assistant.',
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
const codeAssistReq = toGenerateContentRequest(genaiReq);
|
const codeAssistReq = toGenerateContentRequest(
|
||||||
|
genaiReq,
|
||||||
|
'my-prompt',
|
||||||
|
'my-project',
|
||||||
|
'my-session',
|
||||||
|
);
|
||||||
expect(codeAssistReq.request.systemInstruction).toEqual({
|
expect(codeAssistReq.request.systemInstruction).toEqual({
|
||||||
role: 'user',
|
role: 'user',
|
||||||
parts: [{ text: 'You are a helpful assistant.' }],
|
parts: [{ text: 'You are a helpful assistant.' }],
|
||||||
|
@ -139,7 +168,12 @@ describe('converter', () => {
|
||||||
topK: 40,
|
topK: 40,
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
const codeAssistReq = toGenerateContentRequest(genaiReq);
|
const codeAssistReq = toGenerateContentRequest(
|
||||||
|
genaiReq,
|
||||||
|
'my-prompt',
|
||||||
|
'my-project',
|
||||||
|
'my-session',
|
||||||
|
);
|
||||||
expect(codeAssistReq.request.generationConfig).toEqual({
|
expect(codeAssistReq.request.generationConfig).toEqual({
|
||||||
temperature: 0.8,
|
temperature: 0.8,
|
||||||
topK: 40,
|
topK: 40,
|
||||||
|
@ -165,7 +199,12 @@ describe('converter', () => {
|
||||||
responseMimeType: 'application/json',
|
responseMimeType: 'application/json',
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
const codeAssistReq = toGenerateContentRequest(genaiReq);
|
const codeAssistReq = toGenerateContentRequest(
|
||||||
|
genaiReq,
|
||||||
|
'my-prompt',
|
||||||
|
'my-project',
|
||||||
|
'my-session',
|
||||||
|
);
|
||||||
expect(codeAssistReq.request.generationConfig).toEqual({
|
expect(codeAssistReq.request.generationConfig).toEqual({
|
||||||
temperature: 0.1,
|
temperature: 0.1,
|
||||||
topP: 0.2,
|
topP: 0.2,
|
||||||
|
|
|
@ -32,6 +32,7 @@ import {
|
||||||
export interface CAGenerateContentRequest {
|
export interface CAGenerateContentRequest {
|
||||||
model: string;
|
model: string;
|
||||||
project?: string;
|
project?: string;
|
||||||
|
user_prompt_id?: string;
|
||||||
request: VertexGenerateContentRequest;
|
request: VertexGenerateContentRequest;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -115,12 +116,14 @@ export function fromCountTokenResponse(
|
||||||
|
|
||||||
export function toGenerateContentRequest(
|
export function toGenerateContentRequest(
|
||||||
req: GenerateContentParameters,
|
req: GenerateContentParameters,
|
||||||
|
userPromptId: string,
|
||||||
project?: string,
|
project?: string,
|
||||||
sessionId?: string,
|
sessionId?: string,
|
||||||
): CAGenerateContentRequest {
|
): CAGenerateContentRequest {
|
||||||
return {
|
return {
|
||||||
model: req.model,
|
model: req.model,
|
||||||
project,
|
project,
|
||||||
|
user_prompt_id: userPromptId,
|
||||||
request: toVertexGenerateContentRequest(req, sessionId),
|
request: toVertexGenerateContentRequest(req, sessionId),
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
|
@ -14,13 +14,25 @@ vi.mock('google-auth-library');
|
||||||
describe('CodeAssistServer', () => {
|
describe('CodeAssistServer', () => {
|
||||||
it('should be able to be constructed', () => {
|
it('should be able to be constructed', () => {
|
||||||
const auth = new OAuth2Client();
|
const auth = new OAuth2Client();
|
||||||
const server = new CodeAssistServer(auth, 'test-project');
|
const server = new CodeAssistServer(
|
||||||
|
auth,
|
||||||
|
'test-project',
|
||||||
|
{},
|
||||||
|
'test-session',
|
||||||
|
UserTierId.FREE,
|
||||||
|
);
|
||||||
expect(server).toBeInstanceOf(CodeAssistServer);
|
expect(server).toBeInstanceOf(CodeAssistServer);
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should call the generateContent endpoint', async () => {
|
it('should call the generateContent endpoint', async () => {
|
||||||
const client = new OAuth2Client();
|
const client = new OAuth2Client();
|
||||||
const server = new CodeAssistServer(client, 'test-project');
|
const server = new CodeAssistServer(
|
||||||
|
client,
|
||||||
|
'test-project',
|
||||||
|
{},
|
||||||
|
'test-session',
|
||||||
|
UserTierId.FREE,
|
||||||
|
);
|
||||||
const mockResponse = {
|
const mockResponse = {
|
||||||
response: {
|
response: {
|
||||||
candidates: [
|
candidates: [
|
||||||
|
@ -38,10 +50,13 @@ describe('CodeAssistServer', () => {
|
||||||
};
|
};
|
||||||
vi.spyOn(server, 'requestPost').mockResolvedValue(mockResponse);
|
vi.spyOn(server, 'requestPost').mockResolvedValue(mockResponse);
|
||||||
|
|
||||||
const response = await server.generateContent({
|
const response = await server.generateContent(
|
||||||
model: 'test-model',
|
{
|
||||||
contents: [{ role: 'user', parts: [{ text: 'request' }] }],
|
model: 'test-model',
|
||||||
});
|
contents: [{ role: 'user', parts: [{ text: 'request' }] }],
|
||||||
|
},
|
||||||
|
'user-prompt-id',
|
||||||
|
);
|
||||||
|
|
||||||
expect(server.requestPost).toHaveBeenCalledWith(
|
expect(server.requestPost).toHaveBeenCalledWith(
|
||||||
'generateContent',
|
'generateContent',
|
||||||
|
@ -55,7 +70,13 @@ describe('CodeAssistServer', () => {
|
||||||
|
|
||||||
it('should call the generateContentStream endpoint', async () => {
|
it('should call the generateContentStream endpoint', async () => {
|
||||||
const client = new OAuth2Client();
|
const client = new OAuth2Client();
|
||||||
const server = new CodeAssistServer(client, 'test-project');
|
const server = new CodeAssistServer(
|
||||||
|
client,
|
||||||
|
'test-project',
|
||||||
|
{},
|
||||||
|
'test-session',
|
||||||
|
UserTierId.FREE,
|
||||||
|
);
|
||||||
const mockResponse = (async function* () {
|
const mockResponse = (async function* () {
|
||||||
yield {
|
yield {
|
||||||
response: {
|
response: {
|
||||||
|
@ -75,10 +96,13 @@ describe('CodeAssistServer', () => {
|
||||||
})();
|
})();
|
||||||
vi.spyOn(server, 'requestStreamingPost').mockResolvedValue(mockResponse);
|
vi.spyOn(server, 'requestStreamingPost').mockResolvedValue(mockResponse);
|
||||||
|
|
||||||
const stream = await server.generateContentStream({
|
const stream = await server.generateContentStream(
|
||||||
model: 'test-model',
|
{
|
||||||
contents: [{ role: 'user', parts: [{ text: 'request' }] }],
|
model: 'test-model',
|
||||||
});
|
contents: [{ role: 'user', parts: [{ text: 'request' }] }],
|
||||||
|
},
|
||||||
|
'user-prompt-id',
|
||||||
|
);
|
||||||
|
|
||||||
for await (const res of stream) {
|
for await (const res of stream) {
|
||||||
expect(server.requestStreamingPost).toHaveBeenCalledWith(
|
expect(server.requestStreamingPost).toHaveBeenCalledWith(
|
||||||
|
@ -92,7 +116,13 @@ describe('CodeAssistServer', () => {
|
||||||
|
|
||||||
it('should call the onboardUser endpoint', async () => {
|
it('should call the onboardUser endpoint', async () => {
|
||||||
const client = new OAuth2Client();
|
const client = new OAuth2Client();
|
||||||
const server = new CodeAssistServer(client, 'test-project');
|
const server = new CodeAssistServer(
|
||||||
|
client,
|
||||||
|
'test-project',
|
||||||
|
{},
|
||||||
|
'test-session',
|
||||||
|
UserTierId.FREE,
|
||||||
|
);
|
||||||
const mockResponse = {
|
const mockResponse = {
|
||||||
name: 'operations/123',
|
name: 'operations/123',
|
||||||
done: true,
|
done: true,
|
||||||
|
@ -114,7 +144,13 @@ describe('CodeAssistServer', () => {
|
||||||
|
|
||||||
it('should call the loadCodeAssist endpoint', async () => {
|
it('should call the loadCodeAssist endpoint', async () => {
|
||||||
const client = new OAuth2Client();
|
const client = new OAuth2Client();
|
||||||
const server = new CodeAssistServer(client, 'test-project');
|
const server = new CodeAssistServer(
|
||||||
|
client,
|
||||||
|
'test-project',
|
||||||
|
{},
|
||||||
|
'test-session',
|
||||||
|
UserTierId.FREE,
|
||||||
|
);
|
||||||
const mockResponse = {
|
const mockResponse = {
|
||||||
currentTier: {
|
currentTier: {
|
||||||
id: UserTierId.FREE,
|
id: UserTierId.FREE,
|
||||||
|
@ -140,7 +176,13 @@ describe('CodeAssistServer', () => {
|
||||||
|
|
||||||
it('should return 0 for countTokens', async () => {
|
it('should return 0 for countTokens', async () => {
|
||||||
const client = new OAuth2Client();
|
const client = new OAuth2Client();
|
||||||
const server = new CodeAssistServer(client, 'test-project');
|
const server = new CodeAssistServer(
|
||||||
|
client,
|
||||||
|
'test-project',
|
||||||
|
{},
|
||||||
|
'test-session',
|
||||||
|
UserTierId.FREE,
|
||||||
|
);
|
||||||
const mockResponse = {
|
const mockResponse = {
|
||||||
totalTokens: 100,
|
totalTokens: 100,
|
||||||
};
|
};
|
||||||
|
@ -155,7 +197,13 @@ describe('CodeAssistServer', () => {
|
||||||
|
|
||||||
it('should throw an error for embedContent', async () => {
|
it('should throw an error for embedContent', async () => {
|
||||||
const client = new OAuth2Client();
|
const client = new OAuth2Client();
|
||||||
const server = new CodeAssistServer(client, 'test-project');
|
const server = new CodeAssistServer(
|
||||||
|
client,
|
||||||
|
'test-project',
|
||||||
|
{},
|
||||||
|
'test-session',
|
||||||
|
UserTierId.FREE,
|
||||||
|
);
|
||||||
await expect(
|
await expect(
|
||||||
server.embedContent({
|
server.embedContent({
|
||||||
model: 'test-model',
|
model: 'test-model',
|
||||||
|
|
|
@ -53,10 +53,16 @@ export class CodeAssistServer implements ContentGenerator {
|
||||||
|
|
||||||
async generateContentStream(
|
async generateContentStream(
|
||||||
req: GenerateContentParameters,
|
req: GenerateContentParameters,
|
||||||
|
userPromptId: string,
|
||||||
): Promise<AsyncGenerator<GenerateContentResponse>> {
|
): Promise<AsyncGenerator<GenerateContentResponse>> {
|
||||||
const resps = await this.requestStreamingPost<CaGenerateContentResponse>(
|
const resps = await this.requestStreamingPost<CaGenerateContentResponse>(
|
||||||
'streamGenerateContent',
|
'streamGenerateContent',
|
||||||
toGenerateContentRequest(req, this.projectId, this.sessionId),
|
toGenerateContentRequest(
|
||||||
|
req,
|
||||||
|
userPromptId,
|
||||||
|
this.projectId,
|
||||||
|
this.sessionId,
|
||||||
|
),
|
||||||
req.config?.abortSignal,
|
req.config?.abortSignal,
|
||||||
);
|
);
|
||||||
return (async function* (): AsyncGenerator<GenerateContentResponse> {
|
return (async function* (): AsyncGenerator<GenerateContentResponse> {
|
||||||
|
@ -68,10 +74,16 @@ export class CodeAssistServer implements ContentGenerator {
|
||||||
|
|
||||||
async generateContent(
|
async generateContent(
|
||||||
req: GenerateContentParameters,
|
req: GenerateContentParameters,
|
||||||
|
userPromptId: string,
|
||||||
): Promise<GenerateContentResponse> {
|
): Promise<GenerateContentResponse> {
|
||||||
const resp = await this.requestPost<CaGenerateContentResponse>(
|
const resp = await this.requestPost<CaGenerateContentResponse>(
|
||||||
'generateContent',
|
'generateContent',
|
||||||
toGenerateContentRequest(req, this.projectId, this.sessionId),
|
toGenerateContentRequest(
|
||||||
|
req,
|
||||||
|
userPromptId,
|
||||||
|
this.projectId,
|
||||||
|
this.sessionId,
|
||||||
|
),
|
||||||
req.config?.abortSignal,
|
req.config?.abortSignal,
|
||||||
);
|
);
|
||||||
return fromGenerateContentResponse(resp);
|
return fromGenerateContentResponse(resp);
|
||||||
|
|
|
@ -49,8 +49,11 @@ describe('setupUser', () => {
|
||||||
});
|
});
|
||||||
await setupUser({} as OAuth2Client);
|
await setupUser({} as OAuth2Client);
|
||||||
expect(CodeAssistServer).toHaveBeenCalledWith(
|
expect(CodeAssistServer).toHaveBeenCalledWith(
|
||||||
expect.any(Object),
|
{},
|
||||||
'test-project',
|
'test-project',
|
||||||
|
{},
|
||||||
|
'',
|
||||||
|
undefined,
|
||||||
);
|
);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -62,7 +65,10 @@ describe('setupUser', () => {
|
||||||
});
|
});
|
||||||
const projectId = await setupUser({} as OAuth2Client);
|
const projectId = await setupUser({} as OAuth2Client);
|
||||||
expect(CodeAssistServer).toHaveBeenCalledWith(
|
expect(CodeAssistServer).toHaveBeenCalledWith(
|
||||||
expect.any(Object),
|
{},
|
||||||
|
undefined,
|
||||||
|
{},
|
||||||
|
'',
|
||||||
undefined,
|
undefined,
|
||||||
);
|
);
|
||||||
expect(projectId).toEqual({
|
expect(projectId).toEqual({
|
||||||
|
|
|
@ -34,7 +34,7 @@ export interface UserData {
|
||||||
*/
|
*/
|
||||||
export async function setupUser(client: OAuth2Client): Promise<UserData> {
|
export async function setupUser(client: OAuth2Client): Promise<UserData> {
|
||||||
let projectId = process.env.GOOGLE_CLOUD_PROJECT || undefined;
|
let projectId = process.env.GOOGLE_CLOUD_PROJECT || undefined;
|
||||||
const caServer = new CodeAssistServer(client, projectId);
|
const caServer = new CodeAssistServer(client, projectId, {}, '', undefined);
|
||||||
|
|
||||||
const clientMetadata: ClientMetadata = {
|
const clientMetadata: ClientMetadata = {
|
||||||
ideType: 'IDE_UNSPECIFIED',
|
ideType: 'IDE_UNSPECIFIED',
|
||||||
|
|
|
@ -214,7 +214,9 @@ describe('Gemini Client (client.ts)', () => {
|
||||||
|
|
||||||
// We can instantiate the client here since Config is mocked
|
// We can instantiate the client here since Config is mocked
|
||||||
// and the constructor will use the mocked GoogleGenAI
|
// and the constructor will use the mocked GoogleGenAI
|
||||||
client = new GeminiClient(new Config({} as never));
|
client = new GeminiClient(
|
||||||
|
new Config({ sessionId: 'test-session-id' } as never),
|
||||||
|
);
|
||||||
mockConfigObject.getGeminiClient.mockReturnValue(client);
|
mockConfigObject.getGeminiClient.mockReturnValue(client);
|
||||||
|
|
||||||
await client.initialize(contentGeneratorConfig);
|
await client.initialize(contentGeneratorConfig);
|
||||||
|
@ -353,16 +355,19 @@ describe('Gemini Client (client.ts)', () => {
|
||||||
|
|
||||||
await client.generateContent(contents, generationConfig, abortSignal);
|
await client.generateContent(contents, generationConfig, abortSignal);
|
||||||
|
|
||||||
expect(mockGenerateContentFn).toHaveBeenCalledWith({
|
expect(mockGenerateContentFn).toHaveBeenCalledWith(
|
||||||
model: 'test-model',
|
{
|
||||||
config: {
|
model: 'test-model',
|
||||||
abortSignal,
|
config: {
|
||||||
systemInstruction: getCoreSystemPrompt(''),
|
abortSignal,
|
||||||
temperature: 0.5,
|
systemInstruction: getCoreSystemPrompt(''),
|
||||||
topP: 1,
|
temperature: 0.5,
|
||||||
|
topP: 1,
|
||||||
|
},
|
||||||
|
contents,
|
||||||
},
|
},
|
||||||
contents,
|
'test-session-id',
|
||||||
});
|
);
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -381,18 +386,21 @@ describe('Gemini Client (client.ts)', () => {
|
||||||
|
|
||||||
await client.generateJson(contents, schema, abortSignal);
|
await client.generateJson(contents, schema, abortSignal);
|
||||||
|
|
||||||
expect(mockGenerateContentFn).toHaveBeenCalledWith({
|
expect(mockGenerateContentFn).toHaveBeenCalledWith(
|
||||||
model: 'test-model', // Should use current model from config
|
{
|
||||||
config: {
|
model: 'test-model', // Should use current model from config
|
||||||
abortSignal,
|
config: {
|
||||||
systemInstruction: getCoreSystemPrompt(''),
|
abortSignal,
|
||||||
temperature: 0,
|
systemInstruction: getCoreSystemPrompt(''),
|
||||||
topP: 1,
|
temperature: 0,
|
||||||
responseSchema: schema,
|
topP: 1,
|
||||||
responseMimeType: 'application/json',
|
responseSchema: schema,
|
||||||
|
responseMimeType: 'application/json',
|
||||||
|
},
|
||||||
|
contents,
|
||||||
},
|
},
|
||||||
contents,
|
'test-session-id',
|
||||||
});
|
);
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should allow overriding model and config', async () => {
|
it('should allow overriding model and config', async () => {
|
||||||
|
@ -416,19 +424,22 @@ describe('Gemini Client (client.ts)', () => {
|
||||||
customConfig,
|
customConfig,
|
||||||
);
|
);
|
||||||
|
|
||||||
expect(mockGenerateContentFn).toHaveBeenCalledWith({
|
expect(mockGenerateContentFn).toHaveBeenCalledWith(
|
||||||
model: customModel,
|
{
|
||||||
config: {
|
model: customModel,
|
||||||
abortSignal,
|
config: {
|
||||||
systemInstruction: getCoreSystemPrompt(''),
|
abortSignal,
|
||||||
temperature: 0.9,
|
systemInstruction: getCoreSystemPrompt(''),
|
||||||
topP: 1, // from default
|
temperature: 0.9,
|
||||||
topK: 20,
|
topP: 1, // from default
|
||||||
responseSchema: schema,
|
topK: 20,
|
||||||
responseMimeType: 'application/json',
|
responseSchema: schema,
|
||||||
|
responseMimeType: 'application/json',
|
||||||
|
},
|
||||||
|
contents,
|
||||||
},
|
},
|
||||||
contents,
|
'test-session-id',
|
||||||
});
|
);
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -1196,11 +1207,14 @@ Here are some files the user has open, with the most recent at the top:
|
||||||
config: expect.any(Object),
|
config: expect.any(Object),
|
||||||
contents,
|
contents,
|
||||||
});
|
});
|
||||||
expect(mockGenerateContentFn).toHaveBeenCalledWith({
|
expect(mockGenerateContentFn).toHaveBeenCalledWith(
|
||||||
model: currentModel,
|
{
|
||||||
config: expect.any(Object),
|
model: currentModel,
|
||||||
contents,
|
config: expect.any(Object),
|
||||||
});
|
contents,
|
||||||
|
},
|
||||||
|
'test-session-id',
|
||||||
|
);
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|
|
@ -110,7 +110,7 @@ export class GeminiClient {
|
||||||
private readonly COMPRESSION_PRESERVE_THRESHOLD = 0.3;
|
private readonly COMPRESSION_PRESERVE_THRESHOLD = 0.3;
|
||||||
|
|
||||||
private readonly loopDetector: LoopDetectionService;
|
private readonly loopDetector: LoopDetectionService;
|
||||||
private lastPromptId?: string;
|
private lastPromptId: string;
|
||||||
|
|
||||||
constructor(private config: Config) {
|
constructor(private config: Config) {
|
||||||
if (config.getProxy()) {
|
if (config.getProxy()) {
|
||||||
|
@ -119,6 +119,7 @@ export class GeminiClient {
|
||||||
|
|
||||||
this.embeddingModel = config.getEmbeddingModel();
|
this.embeddingModel = config.getEmbeddingModel();
|
||||||
this.loopDetector = new LoopDetectionService(config);
|
this.loopDetector = new LoopDetectionService(config);
|
||||||
|
this.lastPromptId = this.config.getSessionId();
|
||||||
}
|
}
|
||||||
|
|
||||||
async initialize(contentGeneratorConfig: ContentGeneratorConfig) {
|
async initialize(contentGeneratorConfig: ContentGeneratorConfig) {
|
||||||
|
@ -493,16 +494,19 @@ export class GeminiClient {
|
||||||
};
|
};
|
||||||
|
|
||||||
const apiCall = () =>
|
const apiCall = () =>
|
||||||
this.getContentGenerator().generateContent({
|
this.getContentGenerator().generateContent(
|
||||||
model: modelToUse,
|
{
|
||||||
config: {
|
model: modelToUse,
|
||||||
...requestConfig,
|
config: {
|
||||||
systemInstruction,
|
...requestConfig,
|
||||||
responseSchema: schema,
|
systemInstruction,
|
||||||
responseMimeType: 'application/json',
|
responseSchema: schema,
|
||||||
|
responseMimeType: 'application/json',
|
||||||
|
},
|
||||||
|
contents,
|
||||||
},
|
},
|
||||||
contents,
|
this.lastPromptId,
|
||||||
});
|
);
|
||||||
|
|
||||||
const result = await retryWithBackoff(apiCall, {
|
const result = await retryWithBackoff(apiCall, {
|
||||||
onPersistent429: async (authType?: string, error?: unknown) =>
|
onPersistent429: async (authType?: string, error?: unknown) =>
|
||||||
|
@ -601,11 +605,14 @@ export class GeminiClient {
|
||||||
};
|
};
|
||||||
|
|
||||||
const apiCall = () =>
|
const apiCall = () =>
|
||||||
this.getContentGenerator().generateContent({
|
this.getContentGenerator().generateContent(
|
||||||
model: modelToUse,
|
{
|
||||||
config: requestConfig,
|
model: modelToUse,
|
||||||
contents,
|
config: requestConfig,
|
||||||
});
|
contents,
|
||||||
|
},
|
||||||
|
this.lastPromptId,
|
||||||
|
);
|
||||||
|
|
||||||
const result = await retryWithBackoff(apiCall, {
|
const result = await retryWithBackoff(apiCall, {
|
||||||
onPersistent429: async (authType?: string, error?: unknown) =>
|
onPersistent429: async (authType?: string, error?: unknown) =>
|
||||||
|
|
|
@ -25,10 +25,12 @@ import { UserTierId } from '../code_assist/types.js';
|
||||||
export interface ContentGenerator {
|
export interface ContentGenerator {
|
||||||
generateContent(
|
generateContent(
|
||||||
request: GenerateContentParameters,
|
request: GenerateContentParameters,
|
||||||
|
userPromptId: string,
|
||||||
): Promise<GenerateContentResponse>;
|
): Promise<GenerateContentResponse>;
|
||||||
|
|
||||||
generateContentStream(
|
generateContentStream(
|
||||||
request: GenerateContentParameters,
|
request: GenerateContentParameters,
|
||||||
|
userPromptId: string,
|
||||||
): Promise<AsyncGenerator<GenerateContentResponse>>;
|
): Promise<AsyncGenerator<GenerateContentResponse>>;
|
||||||
|
|
||||||
countTokens(request: CountTokensParameters): Promise<CountTokensResponse>;
|
countTokens(request: CountTokensParameters): Promise<CountTokensResponse>;
|
||||||
|
|
|
@ -79,11 +79,14 @@ describe('GeminiChat', () => {
|
||||||
|
|
||||||
await chat.sendMessage({ message: 'hello' }, 'prompt-id-1');
|
await chat.sendMessage({ message: 'hello' }, 'prompt-id-1');
|
||||||
|
|
||||||
expect(mockModelsModule.generateContent).toHaveBeenCalledWith({
|
expect(mockModelsModule.generateContent).toHaveBeenCalledWith(
|
||||||
model: 'gemini-pro',
|
{
|
||||||
contents: [{ role: 'user', parts: [{ text: 'hello' }] }],
|
model: 'gemini-pro',
|
||||||
config: {},
|
contents: [{ role: 'user', parts: [{ text: 'hello' }] }],
|
||||||
});
|
config: {},
|
||||||
|
},
|
||||||
|
'prompt-id-1',
|
||||||
|
);
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -111,11 +114,14 @@ describe('GeminiChat', () => {
|
||||||
|
|
||||||
await chat.sendMessageStream({ message: 'hello' }, 'prompt-id-1');
|
await chat.sendMessageStream({ message: 'hello' }, 'prompt-id-1');
|
||||||
|
|
||||||
expect(mockModelsModule.generateContentStream).toHaveBeenCalledWith({
|
expect(mockModelsModule.generateContentStream).toHaveBeenCalledWith(
|
||||||
model: 'gemini-pro',
|
{
|
||||||
contents: [{ role: 'user', parts: [{ text: 'hello' }] }],
|
model: 'gemini-pro',
|
||||||
config: {},
|
contents: [{ role: 'user', parts: [{ text: 'hello' }] }],
|
||||||
});
|
config: {},
|
||||||
|
},
|
||||||
|
'prompt-id-1',
|
||||||
|
);
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|
|
@ -287,11 +287,14 @@ export class GeminiChat {
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
return this.contentGenerator.generateContent({
|
return this.contentGenerator.generateContent(
|
||||||
model: modelToUse,
|
{
|
||||||
contents: requestContents,
|
model: modelToUse,
|
||||||
config: { ...this.generationConfig, ...params.config },
|
contents: requestContents,
|
||||||
});
|
config: { ...this.generationConfig, ...params.config },
|
||||||
|
},
|
||||||
|
prompt_id,
|
||||||
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
response = await retryWithBackoff(apiCall, {
|
response = await retryWithBackoff(apiCall, {
|
||||||
|
@ -394,11 +397,14 @@ export class GeminiChat {
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
return this.contentGenerator.generateContentStream({
|
return this.contentGenerator.generateContentStream(
|
||||||
model: modelToUse,
|
{
|
||||||
contents: requestContents,
|
model: modelToUse,
|
||||||
config: { ...this.generationConfig, ...params.config },
|
contents: requestContents,
|
||||||
});
|
config: { ...this.generationConfig, ...params.config },
|
||||||
|
},
|
||||||
|
prompt_id,
|
||||||
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
// Note: Retrying streams can be complex. If generateContentStream itself doesn't handle retries
|
// Note: Retrying streams can be complex. If generateContentStream itself doesn't handle retries
|
||||||
|
|
Loading…
Reference in New Issue