CCPA Count Token support (#1170)
This commit is contained in:
parent
332512853e
commit
4662b058e8
|
@ -6,9 +6,9 @@
|
|||
|
||||
import { describe, it, expect } from 'vitest';
|
||||
import {
|
||||
toCodeAssistRequest,
|
||||
fromCodeAsistResponse,
|
||||
CodeAssistResponse,
|
||||
toGenerateContentRequest,
|
||||
fromGenerateContentResponse,
|
||||
CaGenerateContentResponse,
|
||||
} from './converter.js';
|
||||
import {
|
||||
GenerateContentParameters,
|
||||
|
@ -24,7 +24,7 @@ describe('converter', () => {
|
|||
model: 'gemini-pro',
|
||||
contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
|
||||
};
|
||||
const codeAssistReq = toCodeAssistRequest(genaiReq, 'my-project');
|
||||
const codeAssistReq = toGenerateContentRequest(genaiReq, 'my-project');
|
||||
expect(codeAssistReq).toEqual({
|
||||
model: 'gemini-pro',
|
||||
project: 'my-project',
|
||||
|
@ -46,7 +46,7 @@ describe('converter', () => {
|
|||
model: 'gemini-pro',
|
||||
contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
|
||||
};
|
||||
const codeAssistReq = toCodeAssistRequest(genaiReq);
|
||||
const codeAssistReq = toGenerateContentRequest(genaiReq);
|
||||
expect(codeAssistReq).toEqual({
|
||||
model: 'gemini-pro',
|
||||
project: undefined,
|
||||
|
@ -68,7 +68,7 @@ describe('converter', () => {
|
|||
model: 'gemini-pro',
|
||||
contents: 'Hello',
|
||||
};
|
||||
const codeAssistReq = toCodeAssistRequest(genaiReq);
|
||||
const codeAssistReq = toGenerateContentRequest(genaiReq);
|
||||
expect(codeAssistReq.request.contents).toEqual([
|
||||
{ role: 'user', parts: [{ text: 'Hello' }] },
|
||||
]);
|
||||
|
@ -79,7 +79,7 @@ describe('converter', () => {
|
|||
model: 'gemini-pro',
|
||||
contents: [{ text: 'Hello' }, { text: 'World' }],
|
||||
};
|
||||
const codeAssistReq = toCodeAssistRequest(genaiReq);
|
||||
const codeAssistReq = toGenerateContentRequest(genaiReq);
|
||||
expect(codeAssistReq.request.contents).toEqual([
|
||||
{ role: 'user', parts: [{ text: 'Hello' }] },
|
||||
{ role: 'user', parts: [{ text: 'World' }] },
|
||||
|
@ -94,7 +94,7 @@ describe('converter', () => {
|
|||
systemInstruction: 'You are a helpful assistant.',
|
||||
},
|
||||
};
|
||||
const codeAssistReq = toCodeAssistRequest(genaiReq);
|
||||
const codeAssistReq = toGenerateContentRequest(genaiReq);
|
||||
expect(codeAssistReq.request.systemInstruction).toEqual({
|
||||
role: 'user',
|
||||
parts: [{ text: 'You are a helpful assistant.' }],
|
||||
|
@ -110,7 +110,7 @@ describe('converter', () => {
|
|||
topK: 40,
|
||||
},
|
||||
};
|
||||
const codeAssistReq = toCodeAssistRequest(genaiReq);
|
||||
const codeAssistReq = toGenerateContentRequest(genaiReq);
|
||||
expect(codeAssistReq.request.generationConfig).toEqual({
|
||||
temperature: 0.8,
|
||||
topK: 40,
|
||||
|
@ -136,7 +136,7 @@ describe('converter', () => {
|
|||
responseMimeType: 'application/json',
|
||||
},
|
||||
};
|
||||
const codeAssistReq = toCodeAssistRequest(genaiReq);
|
||||
const codeAssistReq = toGenerateContentRequest(genaiReq);
|
||||
expect(codeAssistReq.request.generationConfig).toEqual({
|
||||
temperature: 0.1,
|
||||
topP: 0.2,
|
||||
|
@ -156,7 +156,7 @@ describe('converter', () => {
|
|||
|
||||
describe('fromCodeAssistResponse', () => {
|
||||
it('should convert a simple response', () => {
|
||||
const codeAssistRes: CodeAssistResponse = {
|
||||
const codeAssistRes: CaGenerateContentResponse = {
|
||||
response: {
|
||||
candidates: [
|
||||
{
|
||||
|
@ -171,13 +171,13 @@ describe('converter', () => {
|
|||
],
|
||||
},
|
||||
};
|
||||
const genaiRes = fromCodeAsistResponse(codeAssistRes);
|
||||
const genaiRes = fromGenerateContentResponse(codeAssistRes);
|
||||
expect(genaiRes).toBeInstanceOf(GenerateContentResponse);
|
||||
expect(genaiRes.candidates).toEqual(codeAssistRes.response.candidates);
|
||||
});
|
||||
|
||||
it('should handle prompt feedback and usage metadata', () => {
|
||||
const codeAssistRes: CodeAssistResponse = {
|
||||
const codeAssistRes: CaGenerateContentResponse = {
|
||||
response: {
|
||||
candidates: [],
|
||||
promptFeedback: {
|
||||
|
@ -191,7 +191,7 @@ describe('converter', () => {
|
|||
},
|
||||
},
|
||||
};
|
||||
const genaiRes = fromCodeAsistResponse(codeAssistRes);
|
||||
const genaiRes = fromGenerateContentResponse(codeAssistRes);
|
||||
expect(genaiRes.promptFeedback).toEqual(
|
||||
codeAssistRes.response.promptFeedback,
|
||||
);
|
||||
|
@ -201,7 +201,7 @@ describe('converter', () => {
|
|||
});
|
||||
|
||||
it('should handle automatic function calling history', () => {
|
||||
const codeAssistRes: CodeAssistResponse = {
|
||||
const codeAssistRes: CaGenerateContentResponse = {
|
||||
response: {
|
||||
candidates: [],
|
||||
automaticFunctionCallingHistory: [
|
||||
|
@ -221,7 +221,7 @@ describe('converter', () => {
|
|||
],
|
||||
},
|
||||
};
|
||||
const genaiRes = fromCodeAsistResponse(codeAssistRes);
|
||||
const genaiRes = fromGenerateContentResponse(codeAssistRes);
|
||||
expect(genaiRes.automaticFunctionCallingHistory).toEqual(
|
||||
codeAssistRes.response.automaticFunctionCallingHistory,
|
||||
);
|
||||
|
|
|
@ -10,6 +10,8 @@ import {
|
|||
ContentUnion,
|
||||
GenerateContentConfig,
|
||||
GenerateContentParameters,
|
||||
CountTokensParameters,
|
||||
CountTokensResponse,
|
||||
GenerateContentResponse,
|
||||
GenerationConfigRoutingConfig,
|
||||
MediaResolution,
|
||||
|
@ -27,13 +29,13 @@ import {
|
|||
ToolConfig,
|
||||
} from '@google/genai';
|
||||
|
||||
export interface CodeAssistRequest {
|
||||
export interface CAGenerateContentRequest {
|
||||
model: string;
|
||||
project?: string;
|
||||
request: CodeAssistGenerateContentRequest;
|
||||
request: VertexGenerateContentRequest;
|
||||
}
|
||||
|
||||
interface CodeAssistGenerateContentRequest {
|
||||
interface VertexGenerateContentRequest {
|
||||
contents: Content[];
|
||||
systemInstruction?: Content;
|
||||
cachedContent?: string;
|
||||
|
@ -41,10 +43,10 @@ interface CodeAssistGenerateContentRequest {
|
|||
toolConfig?: ToolConfig;
|
||||
labels?: Record<string, string>;
|
||||
safetySettings?: SafetySetting[];
|
||||
generationConfig?: CodeAssistGenerationConfig;
|
||||
generationConfig?: VertexGenerationConfig;
|
||||
}
|
||||
|
||||
interface CodeAssistGenerationConfig {
|
||||
interface VertexGenerationConfig {
|
||||
temperature?: number;
|
||||
topP?: number;
|
||||
topK?: number;
|
||||
|
@ -67,30 +69,61 @@ interface CodeAssistGenerationConfig {
|
|||
thinkingConfig?: ThinkingConfig;
|
||||
}
|
||||
|
||||
export interface CodeAssistResponse {
|
||||
response: VertexResponse;
|
||||
export interface CaGenerateContentResponse {
|
||||
response: VertexGenerateContentResponse;
|
||||
}
|
||||
|
||||
interface VertexResponse {
|
||||
interface VertexGenerateContentResponse {
|
||||
candidates: Candidate[];
|
||||
automaticFunctionCallingHistory?: Content[];
|
||||
promptFeedback?: GenerateContentResponsePromptFeedback;
|
||||
usageMetadata?: GenerateContentResponseUsageMetadata;
|
||||
}
|
||||
export interface CaCountTokenRequest {
|
||||
request: VertexCountTokenRequest;
|
||||
}
|
||||
|
||||
export function toCodeAssistRequest(
|
||||
req: GenerateContentParameters,
|
||||
project?: string,
|
||||
): CodeAssistRequest {
|
||||
interface VertexCountTokenRequest {
|
||||
model: string;
|
||||
contents: Content[];
|
||||
}
|
||||
|
||||
export interface CaCountTokenResponse {
|
||||
totalTokens: number;
|
||||
}
|
||||
|
||||
export function toCountTokenRequest(
|
||||
req: CountTokensParameters,
|
||||
): CaCountTokenRequest {
|
||||
return {
|
||||
model: req.model,
|
||||
project,
|
||||
request: toCodeAssistGenerateContentRequest(req),
|
||||
request: {
|
||||
model: 'models/' + req.model,
|
||||
contents: toContents(req.contents),
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
export function fromCodeAsistResponse(
|
||||
res: CodeAssistResponse,
|
||||
export function fromCountTokenResponse(
|
||||
res: CaCountTokenResponse,
|
||||
): CountTokensResponse {
|
||||
return {
|
||||
totalTokens: res.totalTokens,
|
||||
};
|
||||
}
|
||||
|
||||
export function toGenerateContentRequest(
|
||||
req: GenerateContentParameters,
|
||||
project?: string,
|
||||
): CAGenerateContentRequest {
|
||||
return {
|
||||
model: req.model,
|
||||
project,
|
||||
request: toVertexGenerateContentRequest(req),
|
||||
};
|
||||
}
|
||||
|
||||
export function fromGenerateContentResponse(
|
||||
res: CaGenerateContentResponse,
|
||||
): GenerateContentResponse {
|
||||
const inres = res.response;
|
||||
const out = new GenerateContentResponse();
|
||||
|
@ -101,9 +134,9 @@ export function fromCodeAsistResponse(
|
|||
return out;
|
||||
}
|
||||
|
||||
function toCodeAssistGenerateContentRequest(
|
||||
function toVertexGenerateContentRequest(
|
||||
req: GenerateContentParameters,
|
||||
): CodeAssistGenerateContentRequest {
|
||||
): VertexGenerateContentRequest {
|
||||
return {
|
||||
contents: toContents(req.contents),
|
||||
systemInstruction: maybeToContent(req.config?.systemInstruction),
|
||||
|
@ -112,7 +145,7 @@ function toCodeAssistGenerateContentRequest(
|
|||
toolConfig: req.config?.toolConfig,
|
||||
labels: req.config?.labels,
|
||||
safetySettings: req.config?.safetySettings,
|
||||
generationConfig: toCodeAssistGenerationConfig(req.config),
|
||||
generationConfig: toVertexGenerationConfig(req.config),
|
||||
};
|
||||
}
|
||||
|
||||
|
@ -170,9 +203,9 @@ function toPart(part: PartUnion): Part {
|
|||
return part;
|
||||
}
|
||||
|
||||
function toCodeAssistGenerationConfig(
|
||||
function toVertexGenerationConfig(
|
||||
config?: GenerateContentConfig,
|
||||
): CodeAssistGenerationConfig | undefined {
|
||||
): VertexGenerationConfig | undefined {
|
||||
if (!config) {
|
||||
return undefined;
|
||||
}
|
||||
|
|
|
@ -133,11 +133,16 @@ describe('CodeAssistServer', () => {
|
|||
it('should return 0 for countTokens', async () => {
|
||||
const auth = new OAuth2Client();
|
||||
const server = new CodeAssistServer(auth, 'test-project');
|
||||
const mockResponse = {
|
||||
totalTokens: 100,
|
||||
};
|
||||
vi.spyOn(server, 'callEndpoint').mockResolvedValue(mockResponse);
|
||||
|
||||
const response = await server.countTokens({
|
||||
model: 'test-model',
|
||||
contents: [{ role: 'user', parts: [{ text: 'request' }] }],
|
||||
});
|
||||
expect(response.totalTokens).toBe(0);
|
||||
expect(response.totalTokens).toBe(100);
|
||||
});
|
||||
|
||||
it('should throw an error for embedContent', async () => {
|
||||
|
|
|
@ -22,9 +22,12 @@ import {
|
|||
import * as readline from 'readline';
|
||||
import { ContentGenerator } from '../core/contentGenerator.js';
|
||||
import {
|
||||
CodeAssistResponse,
|
||||
toCodeAssistRequest,
|
||||
fromCodeAsistResponse,
|
||||
CaGenerateContentResponse,
|
||||
toGenerateContentRequest,
|
||||
fromGenerateContentResponse,
|
||||
toCountTokenRequest,
|
||||
fromCountTokenResponse,
|
||||
CaCountTokenResponse,
|
||||
} from './converter.js';
|
||||
import { PassThrough } from 'node:stream';
|
||||
|
||||
|
@ -50,14 +53,14 @@ export class CodeAssistServer implements ContentGenerator {
|
|||
async generateContentStream(
|
||||
req: GenerateContentParameters,
|
||||
): Promise<AsyncGenerator<GenerateContentResponse>> {
|
||||
const resps = await this.streamEndpoint<CodeAssistResponse>(
|
||||
const resps = await this.streamEndpoint<CaGenerateContentResponse>(
|
||||
'streamGenerateContent',
|
||||
toCodeAssistRequest(req, this.projectId),
|
||||
toGenerateContentRequest(req, this.projectId),
|
||||
req.config?.abortSignal,
|
||||
);
|
||||
return (async function* (): AsyncGenerator<GenerateContentResponse> {
|
||||
for await (const resp of resps) {
|
||||
yield fromCodeAsistResponse(resp);
|
||||
yield fromGenerateContentResponse(resp);
|
||||
}
|
||||
})();
|
||||
}
|
||||
|
@ -65,12 +68,12 @@ export class CodeAssistServer implements ContentGenerator {
|
|||
async generateContent(
|
||||
req: GenerateContentParameters,
|
||||
): Promise<GenerateContentResponse> {
|
||||
const resp = await this.callEndpoint<CodeAssistResponse>(
|
||||
const resp = await this.callEndpoint<CaGenerateContentResponse>(
|
||||
'generateContent',
|
||||
toCodeAssistRequest(req, this.projectId),
|
||||
toGenerateContentRequest(req, this.projectId),
|
||||
req.config?.abortSignal,
|
||||
);
|
||||
return fromCodeAsistResponse(resp);
|
||||
return fromGenerateContentResponse(resp);
|
||||
}
|
||||
|
||||
async onboardUser(
|
||||
|
@ -91,8 +94,12 @@ export class CodeAssistServer implements ContentGenerator {
|
|||
);
|
||||
}
|
||||
|
||||
async countTokens(_req: CountTokensParameters): Promise<CountTokensResponse> {
|
||||
return { totalTokens: 0 };
|
||||
async countTokens(req: CountTokensParameters): Promise<CountTokensResponse> {
|
||||
const resp = await this.callEndpoint<CaCountTokenResponse>(
|
||||
'countTokens',
|
||||
toCountTokenRequest(req),
|
||||
);
|
||||
return fromCountTokenResponse(resp);
|
||||
}
|
||||
|
||||
async embedContent(
|
||||
|
|
Loading…
Reference in New Issue