CCPA Count Token support (#1170)

This commit is contained in:
Tommaso Sciortino 2025-06-18 10:29:42 -07:00 committed by GitHub
parent 332512853e
commit 4662b058e8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 95 additions and 50 deletions

View File

@ -6,9 +6,9 @@
import { describe, it, expect } from 'vitest'; import { describe, it, expect } from 'vitest';
import { import {
toCodeAssistRequest, toGenerateContentRequest,
fromCodeAsistResponse, fromGenerateContentResponse,
CodeAssistResponse, CaGenerateContentResponse,
} from './converter.js'; } from './converter.js';
import { import {
GenerateContentParameters, GenerateContentParameters,
@ -24,7 +24,7 @@ describe('converter', () => {
model: 'gemini-pro', model: 'gemini-pro',
contents: [{ role: 'user', parts: [{ text: 'Hello' }] }], contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
}; };
const codeAssistReq = toCodeAssistRequest(genaiReq, 'my-project'); const codeAssistReq = toGenerateContentRequest(genaiReq, 'my-project');
expect(codeAssistReq).toEqual({ expect(codeAssistReq).toEqual({
model: 'gemini-pro', model: 'gemini-pro',
project: 'my-project', project: 'my-project',
@ -46,7 +46,7 @@ describe('converter', () => {
model: 'gemini-pro', model: 'gemini-pro',
contents: [{ role: 'user', parts: [{ text: 'Hello' }] }], contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
}; };
const codeAssistReq = toCodeAssistRequest(genaiReq); const codeAssistReq = toGenerateContentRequest(genaiReq);
expect(codeAssistReq).toEqual({ expect(codeAssistReq).toEqual({
model: 'gemini-pro', model: 'gemini-pro',
project: undefined, project: undefined,
@ -68,7 +68,7 @@ describe('converter', () => {
model: 'gemini-pro', model: 'gemini-pro',
contents: 'Hello', contents: 'Hello',
}; };
const codeAssistReq = toCodeAssistRequest(genaiReq); const codeAssistReq = toGenerateContentRequest(genaiReq);
expect(codeAssistReq.request.contents).toEqual([ expect(codeAssistReq.request.contents).toEqual([
{ role: 'user', parts: [{ text: 'Hello' }] }, { role: 'user', parts: [{ text: 'Hello' }] },
]); ]);
@ -79,7 +79,7 @@ describe('converter', () => {
model: 'gemini-pro', model: 'gemini-pro',
contents: [{ text: 'Hello' }, { text: 'World' }], contents: [{ text: 'Hello' }, { text: 'World' }],
}; };
const codeAssistReq = toCodeAssistRequest(genaiReq); const codeAssistReq = toGenerateContentRequest(genaiReq);
expect(codeAssistReq.request.contents).toEqual([ expect(codeAssistReq.request.contents).toEqual([
{ role: 'user', parts: [{ text: 'Hello' }] }, { role: 'user', parts: [{ text: 'Hello' }] },
{ role: 'user', parts: [{ text: 'World' }] }, { role: 'user', parts: [{ text: 'World' }] },
@ -94,7 +94,7 @@ describe('converter', () => {
systemInstruction: 'You are a helpful assistant.', systemInstruction: 'You are a helpful assistant.',
}, },
}; };
const codeAssistReq = toCodeAssistRequest(genaiReq); const codeAssistReq = toGenerateContentRequest(genaiReq);
expect(codeAssistReq.request.systemInstruction).toEqual({ expect(codeAssistReq.request.systemInstruction).toEqual({
role: 'user', role: 'user',
parts: [{ text: 'You are a helpful assistant.' }], parts: [{ text: 'You are a helpful assistant.' }],
@ -110,7 +110,7 @@ describe('converter', () => {
topK: 40, topK: 40,
}, },
}; };
const codeAssistReq = toCodeAssistRequest(genaiReq); const codeAssistReq = toGenerateContentRequest(genaiReq);
expect(codeAssistReq.request.generationConfig).toEqual({ expect(codeAssistReq.request.generationConfig).toEqual({
temperature: 0.8, temperature: 0.8,
topK: 40, topK: 40,
@ -136,7 +136,7 @@ describe('converter', () => {
responseMimeType: 'application/json', responseMimeType: 'application/json',
}, },
}; };
const codeAssistReq = toCodeAssistRequest(genaiReq); const codeAssistReq = toGenerateContentRequest(genaiReq);
expect(codeAssistReq.request.generationConfig).toEqual({ expect(codeAssistReq.request.generationConfig).toEqual({
temperature: 0.1, temperature: 0.1,
topP: 0.2, topP: 0.2,
@ -156,7 +156,7 @@ describe('converter', () => {
describe('fromCodeAssistResponse', () => { describe('fromCodeAssistResponse', () => {
it('should convert a simple response', () => { it('should convert a simple response', () => {
const codeAssistRes: CodeAssistResponse = { const codeAssistRes: CaGenerateContentResponse = {
response: { response: {
candidates: [ candidates: [
{ {
@ -171,13 +171,13 @@ describe('converter', () => {
], ],
}, },
}; };
const genaiRes = fromCodeAsistResponse(codeAssistRes); const genaiRes = fromGenerateContentResponse(codeAssistRes);
expect(genaiRes).toBeInstanceOf(GenerateContentResponse); expect(genaiRes).toBeInstanceOf(GenerateContentResponse);
expect(genaiRes.candidates).toEqual(codeAssistRes.response.candidates); expect(genaiRes.candidates).toEqual(codeAssistRes.response.candidates);
}); });
it('should handle prompt feedback and usage metadata', () => { it('should handle prompt feedback and usage metadata', () => {
const codeAssistRes: CodeAssistResponse = { const codeAssistRes: CaGenerateContentResponse = {
response: { response: {
candidates: [], candidates: [],
promptFeedback: { promptFeedback: {
@ -191,7 +191,7 @@ describe('converter', () => {
}, },
}, },
}; };
const genaiRes = fromCodeAsistResponse(codeAssistRes); const genaiRes = fromGenerateContentResponse(codeAssistRes);
expect(genaiRes.promptFeedback).toEqual( expect(genaiRes.promptFeedback).toEqual(
codeAssistRes.response.promptFeedback, codeAssistRes.response.promptFeedback,
); );
@ -201,7 +201,7 @@ describe('converter', () => {
}); });
it('should handle automatic function calling history', () => { it('should handle automatic function calling history', () => {
const codeAssistRes: CodeAssistResponse = { const codeAssistRes: CaGenerateContentResponse = {
response: { response: {
candidates: [], candidates: [],
automaticFunctionCallingHistory: [ automaticFunctionCallingHistory: [
@ -221,7 +221,7 @@ describe('converter', () => {
], ],
}, },
}; };
const genaiRes = fromCodeAsistResponse(codeAssistRes); const genaiRes = fromGenerateContentResponse(codeAssistRes);
expect(genaiRes.automaticFunctionCallingHistory).toEqual( expect(genaiRes.automaticFunctionCallingHistory).toEqual(
codeAssistRes.response.automaticFunctionCallingHistory, codeAssistRes.response.automaticFunctionCallingHistory,
); );

View File

@ -10,6 +10,8 @@ import {
ContentUnion, ContentUnion,
GenerateContentConfig, GenerateContentConfig,
GenerateContentParameters, GenerateContentParameters,
CountTokensParameters,
CountTokensResponse,
GenerateContentResponse, GenerateContentResponse,
GenerationConfigRoutingConfig, GenerationConfigRoutingConfig,
MediaResolution, MediaResolution,
@ -27,13 +29,13 @@ import {
ToolConfig, ToolConfig,
} from '@google/genai'; } from '@google/genai';
export interface CodeAssistRequest { export interface CAGenerateContentRequest {
model: string; model: string;
project?: string; project?: string;
request: CodeAssistGenerateContentRequest; request: VertexGenerateContentRequest;
} }
interface CodeAssistGenerateContentRequest { interface VertexGenerateContentRequest {
contents: Content[]; contents: Content[];
systemInstruction?: Content; systemInstruction?: Content;
cachedContent?: string; cachedContent?: string;
@ -41,10 +43,10 @@ interface CodeAssistGenerateContentRequest {
toolConfig?: ToolConfig; toolConfig?: ToolConfig;
labels?: Record<string, string>; labels?: Record<string, string>;
safetySettings?: SafetySetting[]; safetySettings?: SafetySetting[];
generationConfig?: CodeAssistGenerationConfig; generationConfig?: VertexGenerationConfig;
} }
interface CodeAssistGenerationConfig { interface VertexGenerationConfig {
temperature?: number; temperature?: number;
topP?: number; topP?: number;
topK?: number; topK?: number;
@ -67,30 +69,61 @@ interface CodeAssistGenerationConfig {
thinkingConfig?: ThinkingConfig; thinkingConfig?: ThinkingConfig;
} }
export interface CodeAssistResponse { export interface CaGenerateContentResponse {
response: VertexResponse; response: VertexGenerateContentResponse;
} }
interface VertexResponse { interface VertexGenerateContentResponse {
candidates: Candidate[]; candidates: Candidate[];
automaticFunctionCallingHistory?: Content[]; automaticFunctionCallingHistory?: Content[];
promptFeedback?: GenerateContentResponsePromptFeedback; promptFeedback?: GenerateContentResponsePromptFeedback;
usageMetadata?: GenerateContentResponseUsageMetadata; usageMetadata?: GenerateContentResponseUsageMetadata;
} }
/**
 * Request envelope produced by `toCountTokenRequest` for the `countTokens`
 * endpoint call. The model name and contents live on the nested `request`
 * object rather than at the top level.
 */
export interface CaCountTokenRequest {
/** Nested payload holding the model name and the contents to count. */
request: VertexCountTokenRequest;
}
export function toCodeAssistRequest( interface VertexCountTokenRequest {
req: GenerateContentParameters, model: string;
project?: string, contents: Content[];
): CodeAssistRequest { }
/**
 * Response shape expected back from the `countTokens` endpoint call;
 * `fromCountTokenResponse` copies it into a `CountTokensResponse`.
 */
export interface CaCountTokenResponse {
/** Token count reported for the submitted contents. */
totalTokens: number;
}
export function toCountTokenRequest(
req: CountTokensParameters,
): CaCountTokenRequest {
return { return {
model: req.model, request: {
project, model: 'models/' + req.model,
request: toCodeAssistGenerateContentRequest(req), contents: toContents(req.contents),
},
}; };
} }
export function fromCodeAsistResponse( export function fromCountTokenResponse(
res: CodeAssistResponse, res: CaCountTokenResponse,
): CountTokensResponse {
return {
totalTokens: res.totalTokens,
};
}
/**
 * Builds the generate-content request envelope from SDK-style parameters.
 *
 * @param req - Parameters supplying the model name, contents and config.
 * @param project - Optional project id; passed through unchanged, so it is
 *   `undefined` in the result when the caller omits it.
 * @returns An object holding the model name, the project id, and the nested
 *   request produced by `toVertexGenerateContentRequest`.
 */
export function toGenerateContentRequest(
  req: GenerateContentParameters,
  project?: string,
): CAGenerateContentRequest {
  // Read the model first, then convert, matching the original evaluation order.
  const { model } = req;
  const request = toVertexGenerateContentRequest(req);
  return { model, project, request };
}
export function fromGenerateContentResponse(
res: CaGenerateContentResponse,
): GenerateContentResponse { ): GenerateContentResponse {
const inres = res.response; const inres = res.response;
const out = new GenerateContentResponse(); const out = new GenerateContentResponse();
@ -101,9 +134,9 @@ export function fromCodeAsistResponse(
return out; return out;
} }
function toCodeAssistGenerateContentRequest( function toVertexGenerateContentRequest(
req: GenerateContentParameters, req: GenerateContentParameters,
): CodeAssistGenerateContentRequest { ): VertexGenerateContentRequest {
return { return {
contents: toContents(req.contents), contents: toContents(req.contents),
systemInstruction: maybeToContent(req.config?.systemInstruction), systemInstruction: maybeToContent(req.config?.systemInstruction),
@ -112,7 +145,7 @@ function toCodeAssistGenerateContentRequest(
toolConfig: req.config?.toolConfig, toolConfig: req.config?.toolConfig,
labels: req.config?.labels, labels: req.config?.labels,
safetySettings: req.config?.safetySettings, safetySettings: req.config?.safetySettings,
generationConfig: toCodeAssistGenerationConfig(req.config), generationConfig: toVertexGenerationConfig(req.config),
}; };
} }
@ -170,9 +203,9 @@ function toPart(part: PartUnion): Part {
return part; return part;
} }
function toCodeAssistGenerationConfig( function toVertexGenerationConfig(
config?: GenerateContentConfig, config?: GenerateContentConfig,
): CodeAssistGenerationConfig | undefined { ): VertexGenerationConfig | undefined {
if (!config) { if (!config) {
return undefined; return undefined;
} }

View File

@ -133,11 +133,16 @@ describe('CodeAssistServer', () => {
it('should return 0 for countTokens', async () => { it('should return 0 for countTokens', async () => {
const auth = new OAuth2Client(); const auth = new OAuth2Client();
const server = new CodeAssistServer(auth, 'test-project'); const server = new CodeAssistServer(auth, 'test-project');
const mockResponse = {
totalTokens: 100,
};
vi.spyOn(server, 'callEndpoint').mockResolvedValue(mockResponse);
const response = await server.countTokens({ const response = await server.countTokens({
model: 'test-model', model: 'test-model',
contents: [{ role: 'user', parts: [{ text: 'request' }] }], contents: [{ role: 'user', parts: [{ text: 'request' }] }],
}); });
expect(response.totalTokens).toBe(0); expect(response.totalTokens).toBe(100);
}); });
it('should throw an error for embedContent', async () => { it('should throw an error for embedContent', async () => {

View File

@ -22,9 +22,12 @@ import {
import * as readline from 'readline'; import * as readline from 'readline';
import { ContentGenerator } from '../core/contentGenerator.js'; import { ContentGenerator } from '../core/contentGenerator.js';
import { import {
CodeAssistResponse, CaGenerateContentResponse,
toCodeAssistRequest, toGenerateContentRequest,
fromCodeAsistResponse, fromGenerateContentResponse,
toCountTokenRequest,
fromCountTokenResponse,
CaCountTokenResponse,
} from './converter.js'; } from './converter.js';
import { PassThrough } from 'node:stream'; import { PassThrough } from 'node:stream';
@ -50,14 +53,14 @@ export class CodeAssistServer implements ContentGenerator {
async generateContentStream( async generateContentStream(
req: GenerateContentParameters, req: GenerateContentParameters,
): Promise<AsyncGenerator<GenerateContentResponse>> { ): Promise<AsyncGenerator<GenerateContentResponse>> {
const resps = await this.streamEndpoint<CodeAssistResponse>( const resps = await this.streamEndpoint<CaGenerateContentResponse>(
'streamGenerateContent', 'streamGenerateContent',
toCodeAssistRequest(req, this.projectId), toGenerateContentRequest(req, this.projectId),
req.config?.abortSignal, req.config?.abortSignal,
); );
return (async function* (): AsyncGenerator<GenerateContentResponse> { return (async function* (): AsyncGenerator<GenerateContentResponse> {
for await (const resp of resps) { for await (const resp of resps) {
yield fromCodeAsistResponse(resp); yield fromGenerateContentResponse(resp);
} }
})(); })();
} }
@ -65,12 +68,12 @@ export class CodeAssistServer implements ContentGenerator {
async generateContent( async generateContent(
req: GenerateContentParameters, req: GenerateContentParameters,
): Promise<GenerateContentResponse> { ): Promise<GenerateContentResponse> {
const resp = await this.callEndpoint<CodeAssistResponse>( const resp = await this.callEndpoint<CaGenerateContentResponse>(
'generateContent', 'generateContent',
toCodeAssistRequest(req, this.projectId), toGenerateContentRequest(req, this.projectId),
req.config?.abortSignal, req.config?.abortSignal,
); );
return fromCodeAsistResponse(resp); return fromGenerateContentResponse(resp);
} }
async onboardUser( async onboardUser(
@ -91,8 +94,12 @@ export class CodeAssistServer implements ContentGenerator {
); );
} }
async countTokens(_req: CountTokensParameters): Promise<CountTokensResponse> { async countTokens(req: CountTokensParameters): Promise<CountTokensResponse> {
return { totalTokens: 0 }; const resp = await this.callEndpoint<CaCountTokenResponse>(
'countTokens',
toCountTokenRequest(req),
);
return fromCountTokenResponse(resp);
} }
async embedContent( async embedContent(