Revert "Propagate user_prompt_id to GenerateConentRequest for logging" (#5007)
This commit is contained in:
parent 9ed351260c
commit bd85070411
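Summary of the API change being reverted: the user_prompt_id / userPromptId parameter that the original commit threaded through the Code Assist request path is removed again. CAGenerateContentRequest loses its user_prompt_id field, toGenerateContentRequest drops its userPromptId argument, and CodeAssistServer, ContentGenerator, GeminiClient, and GeminiChat stop passing a prompt id alongside the request. Below is a minimal, self-contained TypeScript sketch of the restored converter shape, reconstructed from the hunks that follow; the parameter and request types here are simplified placeholders, not the real imports from @google/genai.

// Simplified stand-ins for the real types; only the fields used here.
interface GenerateContentParameters {
  model: string;
  contents: unknown;
}

interface VertexGenerateContentRequest {
  contents: unknown;
  session_id?: string;
}

// Post-revert request shape: no user_prompt_id field.
interface CAGenerateContentRequest {
  model: string;
  project?: string;
  request: VertexGenerateContentRequest;
}

// Post-revert signature: (req, project?, sessionId?); userPromptId is gone.
function toGenerateContentRequest(
  req: GenerateContentParameters,
  project?: string,
  sessionId?: string,
): CAGenerateContentRequest {
  return {
    model: req.model,
    project,
    // Placeholder for the real toVertexGenerateContentRequest(req, sessionId).
    request: { contents: req.contents, session_id: sessionId },
  };
}

// Mirrors the updated test expectation:
// toGenerateContentRequest(genaiReq, 'my-project') produces no user_prompt_id.
console.log(
  toGenerateContentRequest(
    { model: 'gemini-pro', contents: [{ role: 'user', parts: [{ text: 'Hello' }] }] },
    'my-project',
  ),
);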
@@ -24,12 +24,7 @@ describe('converter', () => {
       model: 'gemini-pro',
       contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
     };
-    const codeAssistReq = toGenerateContentRequest(
-      genaiReq,
-      'my-prompt',
-      'my-project',
-      'my-session',
-    );
+    const codeAssistReq = toGenerateContentRequest(genaiReq, 'my-project');
     expect(codeAssistReq).toEqual({
       model: 'gemini-pro',
       project: 'my-project',
@@ -42,9 +37,8 @@ describe('converter', () => {
         labels: undefined,
         safetySettings: undefined,
         generationConfig: undefined,
-        session_id: 'my-session',
+        session_id: undefined,
       },
-      user_prompt_id: 'my-prompt',
     });
   });
 
@@ -53,12 +47,7 @@ describe('converter', () => {
       model: 'gemini-pro',
       contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
     };
-    const codeAssistReq = toGenerateContentRequest(
-      genaiReq,
-      'my-prompt',
-      undefined,
-      'my-session',
-    );
+    const codeAssistReq = toGenerateContentRequest(genaiReq);
     expect(codeAssistReq).toEqual({
       model: 'gemini-pro',
       project: undefined,
@@ -71,9 +60,8 @@ describe('converter', () => {
         labels: undefined,
         safetySettings: undefined,
         generationConfig: undefined,
-        session_id: 'my-session',
+        session_id: undefined,
       },
-      user_prompt_id: 'my-prompt',
     });
   });
 
@@ -84,7 +72,6 @@ describe('converter', () => {
     };
     const codeAssistReq = toGenerateContentRequest(
       genaiReq,
-      'my-prompt',
       'my-project',
       'session-123',
     );
@@ -102,7 +89,6 @@ describe('converter', () => {
         generationConfig: undefined,
         session_id: 'session-123',
       },
-      user_prompt_id: 'my-prompt',
     });
   });
 
@@ -111,12 +97,7 @@ describe('converter', () => {
       model: 'gemini-pro',
       contents: 'Hello',
     };
-    const codeAssistReq = toGenerateContentRequest(
-      genaiReq,
-      'my-prompt',
-      'my-project',
-      'my-session',
-    );
+    const codeAssistReq = toGenerateContentRequest(genaiReq);
     expect(codeAssistReq.request.contents).toEqual([
       { role: 'user', parts: [{ text: 'Hello' }] },
     ]);
@@ -127,12 +108,7 @@ describe('converter', () => {
       model: 'gemini-pro',
       contents: [{ text: 'Hello' }, { text: 'World' }],
     };
-    const codeAssistReq = toGenerateContentRequest(
-      genaiReq,
-      'my-prompt',
-      'my-project',
-      'my-session',
-    );
+    const codeAssistReq = toGenerateContentRequest(genaiReq);
     expect(codeAssistReq.request.contents).toEqual([
       { role: 'user', parts: [{ text: 'Hello' }] },
       { role: 'user', parts: [{ text: 'World' }] },
@@ -147,12 +123,7 @@ describe('converter', () => {
         systemInstruction: 'You are a helpful assistant.',
       },
     };
-    const codeAssistReq = toGenerateContentRequest(
-      genaiReq,
-      'my-prompt',
-      'my-project',
-      'my-session',
-    );
+    const codeAssistReq = toGenerateContentRequest(genaiReq);
     expect(codeAssistReq.request.systemInstruction).toEqual({
       role: 'user',
       parts: [{ text: 'You are a helpful assistant.' }],
@@ -168,12 +139,7 @@ describe('converter', () => {
         topK: 40,
       },
     };
-    const codeAssistReq = toGenerateContentRequest(
-      genaiReq,
-      'my-prompt',
-      'my-project',
-      'my-session',
-    );
+    const codeAssistReq = toGenerateContentRequest(genaiReq);
     expect(codeAssistReq.request.generationConfig).toEqual({
       temperature: 0.8,
       topK: 40,
@@ -199,12 +165,7 @@ describe('converter', () => {
         responseMimeType: 'application/json',
       },
     };
-    const codeAssistReq = toGenerateContentRequest(
-      genaiReq,
-      'my-prompt',
-      'my-project',
-      'my-session',
-    );
+    const codeAssistReq = toGenerateContentRequest(genaiReq);
     expect(codeAssistReq.request.generationConfig).toEqual({
       temperature: 0.1,
       topP: 0.2,

@@ -32,7 +32,6 @@ import {
 export interface CAGenerateContentRequest {
   model: string;
   project?: string;
-  user_prompt_id?: string;
   request: VertexGenerateContentRequest;
 }
 
@@ -116,14 +115,12 @@ export function fromCountTokenResponse(
 
 export function toGenerateContentRequest(
   req: GenerateContentParameters,
-  userPromptId: string,
   project?: string,
   sessionId?: string,
 ): CAGenerateContentRequest {
   return {
     model: req.model,
     project,
-    user_prompt_id: userPromptId,
     request: toVertexGenerateContentRequest(req, sessionId),
   };
 }

@@ -14,25 +14,13 @@ vi.mock('google-auth-library');
 describe('CodeAssistServer', () => {
   it('should be able to be constructed', () => {
     const auth = new OAuth2Client();
-    const server = new CodeAssistServer(
-      auth,
-      'test-project',
-      {},
-      'test-session',
-      UserTierId.FREE,
-    );
+    const server = new CodeAssistServer(auth, 'test-project');
     expect(server).toBeInstanceOf(CodeAssistServer);
   });
 
   it('should call the generateContent endpoint', async () => {
     const client = new OAuth2Client();
-    const server = new CodeAssistServer(
-      client,
-      'test-project',
-      {},
-      'test-session',
-      UserTierId.FREE,
-    );
+    const server = new CodeAssistServer(client, 'test-project');
     const mockResponse = {
       response: {
         candidates: [
@@ -50,13 +38,10 @@ describe('CodeAssistServer', () => {
     };
     vi.spyOn(server, 'requestPost').mockResolvedValue(mockResponse);
 
-    const response = await server.generateContent(
-      {
+    const response = await server.generateContent({
       model: 'test-model',
       contents: [{ role: 'user', parts: [{ text: 'request' }] }],
-      },
-      'user-prompt-id',
-    );
+    });
 
     expect(server.requestPost).toHaveBeenCalledWith(
       'generateContent',
@@ -70,13 +55,7 @@ describe('CodeAssistServer', () => {
 
   it('should call the generateContentStream endpoint', async () => {
     const client = new OAuth2Client();
-    const server = new CodeAssistServer(
-      client,
-      'test-project',
-      {},
-      'test-session',
-      UserTierId.FREE,
-    );
+    const server = new CodeAssistServer(client, 'test-project');
     const mockResponse = (async function* () {
       yield {
         response: {
@@ -96,13 +75,10 @@ describe('CodeAssistServer', () => {
     })();
     vi.spyOn(server, 'requestStreamingPost').mockResolvedValue(mockResponse);
 
-    const stream = await server.generateContentStream(
-      {
+    const stream = await server.generateContentStream({
      model: 'test-model',
      contents: [{ role: 'user', parts: [{ text: 'request' }] }],
-      },
-      'user-prompt-id',
-    );
+    });
 
     for await (const res of stream) {
       expect(server.requestStreamingPost).toHaveBeenCalledWith(
@@ -116,13 +92,7 @@ describe('CodeAssistServer', () => {
 
   it('should call the onboardUser endpoint', async () => {
     const client = new OAuth2Client();
-    const server = new CodeAssistServer(
-      client,
-      'test-project',
-      {},
-      'test-session',
-      UserTierId.FREE,
-    );
+    const server = new CodeAssistServer(client, 'test-project');
     const mockResponse = {
       name: 'operations/123',
       done: true,
@@ -144,13 +114,7 @@ describe('CodeAssistServer', () => {
 
   it('should call the loadCodeAssist endpoint', async () => {
     const client = new OAuth2Client();
-    const server = new CodeAssistServer(
-      client,
-      'test-project',
-      {},
-      'test-session',
-      UserTierId.FREE,
-    );
+    const server = new CodeAssistServer(client, 'test-project');
     const mockResponse = {
       currentTier: {
         id: UserTierId.FREE,
@@ -176,13 +140,7 @@ describe('CodeAssistServer', () => {
 
   it('should return 0 for countTokens', async () => {
     const client = new OAuth2Client();
-    const server = new CodeAssistServer(
-      client,
-      'test-project',
-      {},
-      'test-session',
-      UserTierId.FREE,
-    );
+    const server = new CodeAssistServer(client, 'test-project');
     const mockResponse = {
       totalTokens: 100,
     };
@@ -197,13 +155,7 @@ describe('CodeAssistServer', () => {
 
   it('should throw an error for embedContent', async () => {
     const client = new OAuth2Client();
-    const server = new CodeAssistServer(
-      client,
-      'test-project',
-      {},
-      'test-session',
-      UserTierId.FREE,
-    );
+    const server = new CodeAssistServer(client, 'test-project');
     await expect(
       server.embedContent({
         model: 'test-model',

@@ -53,16 +53,10 @@ export class CodeAssistServer implements ContentGenerator {
 
   async generateContentStream(
     req: GenerateContentParameters,
-    userPromptId: string,
   ): Promise<AsyncGenerator<GenerateContentResponse>> {
     const resps = await this.requestStreamingPost<CaGenerateContentResponse>(
       'streamGenerateContent',
-      toGenerateContentRequest(
-        req,
-        userPromptId,
-        this.projectId,
-        this.sessionId,
-      ),
+      toGenerateContentRequest(req, this.projectId, this.sessionId),
       req.config?.abortSignal,
     );
     return (async function* (): AsyncGenerator<GenerateContentResponse> {
@@ -74,16 +68,10 @@ export class CodeAssistServer implements ContentGenerator {
 
   async generateContent(
     req: GenerateContentParameters,
-    userPromptId: string,
   ): Promise<GenerateContentResponse> {
     const resp = await this.requestPost<CaGenerateContentResponse>(
       'generateContent',
-      toGenerateContentRequest(
-        req,
-        userPromptId,
-        this.projectId,
-        this.sessionId,
-      ),
+      toGenerateContentRequest(req, this.projectId, this.sessionId),
       req.config?.abortSignal,
     );
     return fromGenerateContentResponse(resp);

@@ -49,11 +49,8 @@ describe('setupUser', () => {
     });
     await setupUser({} as OAuth2Client);
     expect(CodeAssistServer).toHaveBeenCalledWith(
-      {},
+      expect.any(Object),
       'test-project',
-      {},
-      '',
-      undefined,
     );
   });
 
@@ -65,10 +62,7 @@ describe('setupUser', () => {
     });
     const projectId = await setupUser({} as OAuth2Client);
     expect(CodeAssistServer).toHaveBeenCalledWith(
-      {},
-      undefined,
-      {},
-      '',
+      expect.any(Object),
       undefined,
     );
     expect(projectId).toEqual({

@@ -34,7 +34,7 @@ export interface UserData {
  */
 export async function setupUser(client: OAuth2Client): Promise<UserData> {
   let projectId = process.env.GOOGLE_CLOUD_PROJECT || undefined;
-  const caServer = new CodeAssistServer(client, projectId, {}, '', undefined);
+  const caServer = new CodeAssistServer(client, projectId);
 
   const clientMetadata: ClientMetadata = {
     ideType: 'IDE_UNSPECIFIED',

@@ -209,9 +209,7 @@ describe('Gemini Client (client.ts)', () => {
 
     // We can instantiate the client here since Config is mocked
     // and the constructor will use the mocked GoogleGenAI
-    client = new GeminiClient(
-      new Config({ sessionId: 'test-session-id' } as never),
-    );
+    client = new GeminiClient(new Config({} as never));
     mockConfigObject.getGeminiClient.mockReturnValue(client);
 
     await client.initialize(contentGeneratorConfig);
@@ -350,8 +348,7 @@ describe('Gemini Client (client.ts)', () => {
 
       await client.generateContent(contents, generationConfig, abortSignal);
 
-      expect(mockGenerateContentFn).toHaveBeenCalledWith(
-        {
+      expect(mockGenerateContentFn).toHaveBeenCalledWith({
        model: 'test-model',
        config: {
          abortSignal,
@@ -360,9 +357,7 @@ describe('Gemini Client (client.ts)', () => {
          topP: 1,
        },
        contents,
-        },
-        'test-session-id',
-      );
+      });
    });
  });
 
@@ -381,8 +376,7 @@ describe('Gemini Client (client.ts)', () => {
 
      await client.generateJson(contents, schema, abortSignal);
 
-      expect(mockGenerateContentFn).toHaveBeenCalledWith(
-        {
+      expect(mockGenerateContentFn).toHaveBeenCalledWith({
        model: 'test-model', // Should use current model from config
        config: {
          abortSignal,
@@ -393,9 +387,7 @@ describe('Gemini Client (client.ts)', () => {
          responseMimeType: 'application/json',
        },
        contents,
-        },
-        'test-session-id',
-      );
+      });
    });
 
    it('should allow overriding model and config', async () => {
@@ -419,8 +411,7 @@ describe('Gemini Client (client.ts)', () => {
        customConfig,
      );
 
-      expect(mockGenerateContentFn).toHaveBeenCalledWith(
-        {
+      expect(mockGenerateContentFn).toHaveBeenCalledWith({
        model: customModel,
        config: {
          abortSignal,
@@ -432,9 +423,7 @@ describe('Gemini Client (client.ts)', () => {
          responseMimeType: 'application/json',
        },
        contents,
-        },
-        'test-session-id',
-      );
+      });
    });
  });
 
@@ -1017,14 +1006,11 @@ Here are files the user has recently opened, with the most recent at the top:
        config: expect.any(Object),
        contents,
      });
-      expect(mockGenerateContentFn).toHaveBeenCalledWith(
-        {
+      expect(mockGenerateContentFn).toHaveBeenCalledWith({
        model: currentModel,
        config: expect.any(Object),
        contents,
-        },
-        'test-session-id',
-      );
+      });
    });
  });
 
@@ -106,7 +106,7 @@ export class GeminiClient {
   private readonly COMPRESSION_PRESERVE_THRESHOLD = 0.3;
 
   private readonly loopDetector: LoopDetectionService;
-  private lastPromptId: string;
+  private lastPromptId?: string;
 
   constructor(private config: Config) {
     if (config.getProxy()) {
@@ -115,7 +115,6 @@ export class GeminiClient {
 
     this.embeddingModel = config.getEmbeddingModel();
     this.loopDetector = new LoopDetectionService(config);
-    this.lastPromptId = this.config.getSessionId();
   }
 
   async initialize(contentGeneratorConfig: ContentGeneratorConfig) {
@@ -428,8 +427,7 @@ export class GeminiClient {
     };
 
     const apiCall = () =>
-      this.getContentGenerator().generateContent(
-        {
+      this.getContentGenerator().generateContent({
        model: modelToUse,
        config: {
          ...requestConfig,
@@ -438,9 +436,7 @@ export class GeminiClient {
          responseMimeType: 'application/json',
        },
        contents,
-        },
-        this.lastPromptId,
-      );
+      });
 
     const result = await retryWithBackoff(apiCall, {
       onPersistent429: async (authType?: string, error?: unknown) =>
@@ -525,14 +521,11 @@ export class GeminiClient {
     };
 
     const apiCall = () =>
-      this.getContentGenerator().generateContent(
-        {
+      this.getContentGenerator().generateContent({
        model: modelToUse,
        config: requestConfig,
        contents,
-        },
-        this.lastPromptId,
-      );
+      });
 
     const result = await retryWithBackoff(apiCall, {
       onPersistent429: async (authType?: string, error?: unknown) =>

@@ -25,12 +25,10 @@ import { UserTierId } from '../code_assist/types.js';
 export interface ContentGenerator {
   generateContent(
     request: GenerateContentParameters,
-    userPromptId: string,
   ): Promise<GenerateContentResponse>;
 
   generateContentStream(
     request: GenerateContentParameters,
-    userPromptId: string,
   ): Promise<AsyncGenerator<GenerateContentResponse>>;
 
   countTokens(request: CountTokensParameters): Promise<CountTokensResponse>;

@@ -79,14 +79,11 @@ describe('GeminiChat', () => {
 
       await chat.sendMessage({ message: 'hello' }, 'prompt-id-1');
 
-      expect(mockModelsModule.generateContent).toHaveBeenCalledWith(
-        {
+      expect(mockModelsModule.generateContent).toHaveBeenCalledWith({
        model: 'gemini-pro',
        contents: [{ role: 'user', parts: [{ text: 'hello' }] }],
        config: {},
-        },
-        'prompt-id-1',
-      );
+      });
    });
  });
 
@@ -114,14 +111,11 @@ describe('GeminiChat', () => {
 
       await chat.sendMessageStream({ message: 'hello' }, 'prompt-id-1');
 
-      expect(mockModelsModule.generateContentStream).toHaveBeenCalledWith(
-        {
+      expect(mockModelsModule.generateContentStream).toHaveBeenCalledWith({
        model: 'gemini-pro',
        contents: [{ role: 'user', parts: [{ text: 'hello' }] }],
        config: {},
-        },
-        'prompt-id-1',
-      );
+      });
    });
  });
 
@@ -286,14 +286,11 @@ export class GeminiChat {
         );
       }
 
-      return this.contentGenerator.generateContent(
-        {
+      return this.contentGenerator.generateContent({
        model: modelToUse,
        contents: requestContents,
        config: { ...this.generationConfig, ...params.config },
-        },
-        prompt_id,
-      );
+      });
     };
 
     response = await retryWithBackoff(apiCall, {
@@ -396,14 +393,11 @@ export class GeminiChat {
         );
       }
 
-      return this.contentGenerator.generateContentStream(
-        {
+      return this.contentGenerator.generateContentStream({
        model: modelToUse,
        contents: requestContents,
        config: { ...this.generationConfig, ...params.config },
-        },
-        prompt_id,
-      );
+      });
     };
 
     // Note: Retrying streams can be complex. If generateContentStream itself doesn't handle retries
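For orientation, here is a minimal sketch of the post-revert ContentGenerator contract shown in the hunk above: generateContent and generateContentStream take only the request, with no trailing userPromptId. The placeholder types and the FakeGenerator class are illustrative only and are not part of the codebase.

// Placeholder types standing in for the real @google/genai ones.
type GenerateContentParameters = { model: string; contents: unknown; config?: unknown };
type GenerateContentResponse = { candidates?: unknown[] };

// Post-revert contract: a single request argument per method.
interface ContentGenerator {
  generateContent(request: GenerateContentParameters): Promise<GenerateContentResponse>;
  generateContentStream(
    request: GenerateContentParameters,
  ): Promise<AsyncGenerator<GenerateContentResponse>>;
}

// Purely illustrative implementation so the sketch type-checks and runs.
class FakeGenerator implements ContentGenerator {
  async generateContent(request: GenerateContentParameters): Promise<GenerateContentResponse> {
    return { candidates: [{ echoedModel: request.model }] };
  }
  async generateContentStream(
    request: GenerateContentParameters,
  ): Promise<AsyncGenerator<GenerateContentResponse>> {
    async function* gen() {
      yield { candidates: [{ echoedModel: request.model }] };
    }
    return gen();
  }
}

// Callers such as GeminiChat now pass a single object, mirroring
// contentGenerator.generateContent({ model, contents, config }) in the hunks above.
new FakeGenerator()
  .generateContent({
    model: 'gemini-pro',
    contents: [{ role: 'user', parts: [{ text: 'hello' }] }],
  })
  .then((res) => console.log(res));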