CCPA Count Token support (#1170)

This commit is contained in:
Tommaso Sciortino 2025-06-18 10:29:42 -07:00 committed by GitHub
parent 332512853e
commit 4662b058e8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 95 additions and 50 deletions

View File

@ -6,9 +6,9 @@
import { describe, it, expect } from 'vitest';
import {
toCodeAssistRequest,
fromCodeAsistResponse,
CodeAssistResponse,
toGenerateContentRequest,
fromGenerateContentResponse,
CaGenerateContentResponse,
} from './converter.js';
import {
GenerateContentParameters,
@ -24,7 +24,7 @@ describe('converter', () => {
model: 'gemini-pro',
contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
};
const codeAssistReq = toCodeAssistRequest(genaiReq, 'my-project');
const codeAssistReq = toGenerateContentRequest(genaiReq, 'my-project');
expect(codeAssistReq).toEqual({
model: 'gemini-pro',
project: 'my-project',
@ -46,7 +46,7 @@ describe('converter', () => {
model: 'gemini-pro',
contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
};
const codeAssistReq = toCodeAssistRequest(genaiReq);
const codeAssistReq = toGenerateContentRequest(genaiReq);
expect(codeAssistReq).toEqual({
model: 'gemini-pro',
project: undefined,
@ -68,7 +68,7 @@ describe('converter', () => {
model: 'gemini-pro',
contents: 'Hello',
};
const codeAssistReq = toCodeAssistRequest(genaiReq);
const codeAssistReq = toGenerateContentRequest(genaiReq);
expect(codeAssistReq.request.contents).toEqual([
{ role: 'user', parts: [{ text: 'Hello' }] },
]);
@ -79,7 +79,7 @@ describe('converter', () => {
model: 'gemini-pro',
contents: [{ text: 'Hello' }, { text: 'World' }],
};
const codeAssistReq = toCodeAssistRequest(genaiReq);
const codeAssistReq = toGenerateContentRequest(genaiReq);
expect(codeAssistReq.request.contents).toEqual([
{ role: 'user', parts: [{ text: 'Hello' }] },
{ role: 'user', parts: [{ text: 'World' }] },
@ -94,7 +94,7 @@ describe('converter', () => {
systemInstruction: 'You are a helpful assistant.',
},
};
const codeAssistReq = toCodeAssistRequest(genaiReq);
const codeAssistReq = toGenerateContentRequest(genaiReq);
expect(codeAssistReq.request.systemInstruction).toEqual({
role: 'user',
parts: [{ text: 'You are a helpful assistant.' }],
@ -110,7 +110,7 @@ describe('converter', () => {
topK: 40,
},
};
const codeAssistReq = toCodeAssistRequest(genaiReq);
const codeAssistReq = toGenerateContentRequest(genaiReq);
expect(codeAssistReq.request.generationConfig).toEqual({
temperature: 0.8,
topK: 40,
@ -136,7 +136,7 @@ describe('converter', () => {
responseMimeType: 'application/json',
},
};
const codeAssistReq = toCodeAssistRequest(genaiReq);
const codeAssistReq = toGenerateContentRequest(genaiReq);
expect(codeAssistReq.request.generationConfig).toEqual({
temperature: 0.1,
topP: 0.2,
@ -156,7 +156,7 @@ describe('converter', () => {
describe('fromCodeAssistResponse', () => {
it('should convert a simple response', () => {
const codeAssistRes: CodeAssistResponse = {
const codeAssistRes: CaGenerateContentResponse = {
response: {
candidates: [
{
@ -171,13 +171,13 @@ describe('converter', () => {
],
},
};
const genaiRes = fromCodeAsistResponse(codeAssistRes);
const genaiRes = fromGenerateContentResponse(codeAssistRes);
expect(genaiRes).toBeInstanceOf(GenerateContentResponse);
expect(genaiRes.candidates).toEqual(codeAssistRes.response.candidates);
});
it('should handle prompt feedback and usage metadata', () => {
const codeAssistRes: CodeAssistResponse = {
const codeAssistRes: CaGenerateContentResponse = {
response: {
candidates: [],
promptFeedback: {
@ -191,7 +191,7 @@ describe('converter', () => {
},
},
};
const genaiRes = fromCodeAsistResponse(codeAssistRes);
const genaiRes = fromGenerateContentResponse(codeAssistRes);
expect(genaiRes.promptFeedback).toEqual(
codeAssistRes.response.promptFeedback,
);
@ -201,7 +201,7 @@ describe('converter', () => {
});
it('should handle automatic function calling history', () => {
const codeAssistRes: CodeAssistResponse = {
const codeAssistRes: CaGenerateContentResponse = {
response: {
candidates: [],
automaticFunctionCallingHistory: [
@ -221,7 +221,7 @@ describe('converter', () => {
],
},
};
const genaiRes = fromCodeAsistResponse(codeAssistRes);
const genaiRes = fromGenerateContentResponse(codeAssistRes);
expect(genaiRes.automaticFunctionCallingHistory).toEqual(
codeAssistRes.response.automaticFunctionCallingHistory,
);

View File

@ -10,6 +10,8 @@ import {
ContentUnion,
GenerateContentConfig,
GenerateContentParameters,
CountTokensParameters,
CountTokensResponse,
GenerateContentResponse,
GenerationConfigRoutingConfig,
MediaResolution,
@ -27,13 +29,13 @@ import {
ToolConfig,
} from '@google/genai';
export interface CodeAssistRequest {
export interface CAGenerateContentRequest {
model: string;
project?: string;
request: CodeAssistGenerateContentRequest;
request: VertexGenerateContentRequest;
}
interface CodeAssistGenerateContentRequest {
interface VertexGenerateContentRequest {
contents: Content[];
systemInstruction?: Content;
cachedContent?: string;
@ -41,10 +43,10 @@ interface CodeAssistGenerateContentRequest {
toolConfig?: ToolConfig;
labels?: Record<string, string>;
safetySettings?: SafetySetting[];
generationConfig?: CodeAssistGenerationConfig;
generationConfig?: VertexGenerationConfig;
}
interface CodeAssistGenerationConfig {
interface VertexGenerationConfig {
temperature?: number;
topP?: number;
topK?: number;
@ -67,30 +69,61 @@ interface CodeAssistGenerationConfig {
thinkingConfig?: ThinkingConfig;
}
export interface CodeAssistResponse {
response: VertexResponse;
export interface CaGenerateContentResponse {
response: VertexGenerateContentResponse;
}
interface VertexResponse {
interface VertexGenerateContentResponse {
candidates: Candidate[];
automaticFunctionCallingHistory?: Content[];
promptFeedback?: GenerateContentResponsePromptFeedback;
usageMetadata?: GenerateContentResponseUsageMetadata;
}
export interface CaCountTokenRequest {
request: VertexCountTokenRequest;
}
export function toCodeAssistRequest(
req: GenerateContentParameters,
project?: string,
): CodeAssistRequest {
interface VertexCountTokenRequest {
model: string;
contents: Content[];
}
export interface CaCountTokenResponse {
totalTokens: number;
}
export function toCountTokenRequest(
req: CountTokensParameters,
): CaCountTokenRequest {
return {
model: req.model,
project,
request: toCodeAssistGenerateContentRequest(req),
request: {
model: 'models/' + req.model,
contents: toContents(req.contents),
},
};
}
export function fromCodeAsistResponse(
res: CodeAssistResponse,
export function fromCountTokenResponse(
res: CaCountTokenResponse,
): CountTokensResponse {
return {
totalTokens: res.totalTokens,
};
}
export function toGenerateContentRequest(
req: GenerateContentParameters,
project?: string,
): CAGenerateContentRequest {
return {
model: req.model,
project,
request: toVertexGenerateContentRequest(req),
};
}
export function fromGenerateContentResponse(
res: CaGenerateContentResponse,
): GenerateContentResponse {
const inres = res.response;
const out = new GenerateContentResponse();
@ -101,9 +134,9 @@ export function fromCodeAsistResponse(
return out;
}
function toCodeAssistGenerateContentRequest(
function toVertexGenerateContentRequest(
req: GenerateContentParameters,
): CodeAssistGenerateContentRequest {
): VertexGenerateContentRequest {
return {
contents: toContents(req.contents),
systemInstruction: maybeToContent(req.config?.systemInstruction),
@ -112,7 +145,7 @@ function toCodeAssistGenerateContentRequest(
toolConfig: req.config?.toolConfig,
labels: req.config?.labels,
safetySettings: req.config?.safetySettings,
generationConfig: toCodeAssistGenerationConfig(req.config),
generationConfig: toVertexGenerationConfig(req.config),
};
}
@ -170,9 +203,9 @@ function toPart(part: PartUnion): Part {
return part;
}
function toCodeAssistGenerationConfig(
function toVertexGenerationConfig(
config?: GenerateContentConfig,
): CodeAssistGenerationConfig | undefined {
): VertexGenerationConfig | undefined {
if (!config) {
return undefined;
}

View File

@ -133,11 +133,16 @@ describe('CodeAssistServer', () => {
it('should return 0 for countTokens', async () => {
const auth = new OAuth2Client();
const server = new CodeAssistServer(auth, 'test-project');
const mockResponse = {
totalTokens: 100,
};
vi.spyOn(server, 'callEndpoint').mockResolvedValue(mockResponse);
const response = await server.countTokens({
model: 'test-model',
contents: [{ role: 'user', parts: [{ text: 'request' }] }],
});
expect(response.totalTokens).toBe(0);
expect(response.totalTokens).toBe(100);
});
it('should throw an error for embedContent', async () => {

View File

@ -22,9 +22,12 @@ import {
import * as readline from 'readline';
import { ContentGenerator } from '../core/contentGenerator.js';
import {
CodeAssistResponse,
toCodeAssistRequest,
fromCodeAsistResponse,
CaGenerateContentResponse,
toGenerateContentRequest,
fromGenerateContentResponse,
toCountTokenRequest,
fromCountTokenResponse,
CaCountTokenResponse,
} from './converter.js';
import { PassThrough } from 'node:stream';
@ -50,14 +53,14 @@ export class CodeAssistServer implements ContentGenerator {
async generateContentStream(
req: GenerateContentParameters,
): Promise<AsyncGenerator<GenerateContentResponse>> {
const resps = await this.streamEndpoint<CodeAssistResponse>(
const resps = await this.streamEndpoint<CaGenerateContentResponse>(
'streamGenerateContent',
toCodeAssistRequest(req, this.projectId),
toGenerateContentRequest(req, this.projectId),
req.config?.abortSignal,
);
return (async function* (): AsyncGenerator<GenerateContentResponse> {
for await (const resp of resps) {
yield fromCodeAsistResponse(resp);
yield fromGenerateContentResponse(resp);
}
})();
}
@ -65,12 +68,12 @@ export class CodeAssistServer implements ContentGenerator {
async generateContent(
req: GenerateContentParameters,
): Promise<GenerateContentResponse> {
const resp = await this.callEndpoint<CodeAssistResponse>(
const resp = await this.callEndpoint<CaGenerateContentResponse>(
'generateContent',
toCodeAssistRequest(req, this.projectId),
toGenerateContentRequest(req, this.projectId),
req.config?.abortSignal,
);
return fromCodeAsistResponse(resp);
return fromGenerateContentResponse(resp);
}
async onboardUser(
@ -91,8 +94,12 @@ export class CodeAssistServer implements ContentGenerator {
);
}
async countTokens(_req: CountTokensParameters): Promise<CountTokensResponse> {
return { totalTokens: 0 };
async countTokens(req: CountTokensParameters): Promise<CountTokensResponse> {
const resp = await this.callEndpoint<CaCountTokenResponse>(
'countTokens',
toCountTokenRequest(req),
);
return fromCountTokenResponse(resp);
}
async embedContent(