From 4662b058e8eda8588aa2ab272820da6ab0b3f7cd Mon Sep 17 00:00:00 2001 From: Tommaso Sciortino Date: Wed, 18 Jun 2025 10:29:42 -0700 Subject: [PATCH] CCPA Count Token support (#1170) --- .../core/src/code_assist/converter.test.ts | 32 ++++---- packages/core/src/code_assist/converter.ts | 77 +++++++++++++------ packages/core/src/code_assist/server.test.ts | 7 +- packages/core/src/code_assist/server.ts | 29 ++++--- 4 files changed, 95 insertions(+), 50 deletions(-) diff --git a/packages/core/src/code_assist/converter.test.ts b/packages/core/src/code_assist/converter.test.ts index d0c05015..2170c960 100644 --- a/packages/core/src/code_assist/converter.test.ts +++ b/packages/core/src/code_assist/converter.test.ts @@ -6,9 +6,9 @@ import { describe, it, expect } from 'vitest'; import { - toCodeAssistRequest, - fromCodeAsistResponse, - CodeAssistResponse, + toGenerateContentRequest, + fromGenerateContentResponse, + CaGenerateContentResponse, } from './converter.js'; import { GenerateContentParameters, @@ -24,7 +24,7 @@ describe('converter', () => { model: 'gemini-pro', contents: [{ role: 'user', parts: [{ text: 'Hello' }] }], }; - const codeAssistReq = toCodeAssistRequest(genaiReq, 'my-project'); + const codeAssistReq = toGenerateContentRequest(genaiReq, 'my-project'); expect(codeAssistReq).toEqual({ model: 'gemini-pro', project: 'my-project', @@ -46,7 +46,7 @@ describe('converter', () => { model: 'gemini-pro', contents: [{ role: 'user', parts: [{ text: 'Hello' }] }], }; - const codeAssistReq = toCodeAssistRequest(genaiReq); + const codeAssistReq = toGenerateContentRequest(genaiReq); expect(codeAssistReq).toEqual({ model: 'gemini-pro', project: undefined, @@ -68,7 +68,7 @@ describe('converter', () => { model: 'gemini-pro', contents: 'Hello', }; - const codeAssistReq = toCodeAssistRequest(genaiReq); + const codeAssistReq = toGenerateContentRequest(genaiReq); expect(codeAssistReq.request.contents).toEqual([ { role: 'user', parts: [{ text: 'Hello' }] }, ]); @@ -79,7 +79,7 @@ describe('converter', () => { model: 'gemini-pro', contents: [{ text: 'Hello' }, { text: 'World' }], }; - const codeAssistReq = toCodeAssistRequest(genaiReq); + const codeAssistReq = toGenerateContentRequest(genaiReq); expect(codeAssistReq.request.contents).toEqual([ { role: 'user', parts: [{ text: 'Hello' }] }, { role: 'user', parts: [{ text: 'World' }] }, @@ -94,7 +94,7 @@ describe('converter', () => { systemInstruction: 'You are a helpful assistant.', }, }; - const codeAssistReq = toCodeAssistRequest(genaiReq); + const codeAssistReq = toGenerateContentRequest(genaiReq); expect(codeAssistReq.request.systemInstruction).toEqual({ role: 'user', parts: [{ text: 'You are a helpful assistant.' }], @@ -110,7 +110,7 @@ describe('converter', () => { topK: 40, }, }; - const codeAssistReq = toCodeAssistRequest(genaiReq); + const codeAssistReq = toGenerateContentRequest(genaiReq); expect(codeAssistReq.request.generationConfig).toEqual({ temperature: 0.8, topK: 40, @@ -136,7 +136,7 @@ describe('converter', () => { responseMimeType: 'application/json', }, }; - const codeAssistReq = toCodeAssistRequest(genaiReq); + const codeAssistReq = toGenerateContentRequest(genaiReq); expect(codeAssistReq.request.generationConfig).toEqual({ temperature: 0.1, topP: 0.2, @@ -156,7 +156,7 @@ describe('converter', () => { describe('fromCodeAssistResponse', () => { it('should convert a simple response', () => { - const codeAssistRes: CodeAssistResponse = { + const codeAssistRes: CaGenerateContentResponse = { response: { candidates: [ { @@ -171,13 +171,13 @@ describe('converter', () => { ], }, }; - const genaiRes = fromCodeAsistResponse(codeAssistRes); + const genaiRes = fromGenerateContentResponse(codeAssistRes); expect(genaiRes).toBeInstanceOf(GenerateContentResponse); expect(genaiRes.candidates).toEqual(codeAssistRes.response.candidates); }); it('should handle prompt feedback and usage metadata', () => { - const codeAssistRes: CodeAssistResponse = { + const codeAssistRes: CaGenerateContentResponse = { response: { candidates: [], promptFeedback: { @@ -191,7 +191,7 @@ describe('converter', () => { }, }, }; - const genaiRes = fromCodeAsistResponse(codeAssistRes); + const genaiRes = fromGenerateContentResponse(codeAssistRes); expect(genaiRes.promptFeedback).toEqual( codeAssistRes.response.promptFeedback, ); @@ -201,7 +201,7 @@ describe('converter', () => { }); it('should handle automatic function calling history', () => { - const codeAssistRes: CodeAssistResponse = { + const codeAssistRes: CaGenerateContentResponse = { response: { candidates: [], automaticFunctionCallingHistory: [ @@ -221,7 +221,7 @@ describe('converter', () => { ], }, }; - const genaiRes = fromCodeAsistResponse(codeAssistRes); + const genaiRes = fromGenerateContentResponse(codeAssistRes); expect(genaiRes.automaticFunctionCallingHistory).toEqual( codeAssistRes.response.automaticFunctionCallingHistory, ); diff --git a/packages/core/src/code_assist/converter.ts b/packages/core/src/code_assist/converter.ts index 495cbfae..b9b854fc 100644 --- a/packages/core/src/code_assist/converter.ts +++ b/packages/core/src/code_assist/converter.ts @@ -10,6 +10,8 @@ import { ContentUnion, GenerateContentConfig, GenerateContentParameters, + CountTokensParameters, + CountTokensResponse, GenerateContentResponse, GenerationConfigRoutingConfig, MediaResolution, @@ -27,13 +29,13 @@ import { ToolConfig, } from '@google/genai'; -export interface CodeAssistRequest { +export interface CAGenerateContentRequest { model: string; project?: string; - request: CodeAssistGenerateContentRequest; + request: VertexGenerateContentRequest; } -interface CodeAssistGenerateContentRequest { +interface VertexGenerateContentRequest { contents: Content[]; systemInstruction?: Content; cachedContent?: string; @@ -41,10 +43,10 @@ interface CodeAssistGenerateContentRequest { toolConfig?: ToolConfig; labels?: Record; safetySettings?: SafetySetting[]; - generationConfig?: CodeAssistGenerationConfig; + generationConfig?: VertexGenerationConfig; } -interface CodeAssistGenerationConfig { +interface VertexGenerationConfig { temperature?: number; topP?: number; topK?: number; @@ -67,30 +69,61 @@ interface CodeAssistGenerationConfig { thinkingConfig?: ThinkingConfig; } -export interface CodeAssistResponse { - response: VertexResponse; +export interface CaGenerateContentResponse { + response: VertexGenerateContentResponse; } -interface VertexResponse { +interface VertexGenerateContentResponse { candidates: Candidate[]; automaticFunctionCallingHistory?: Content[]; promptFeedback?: GenerateContentResponsePromptFeedback; usageMetadata?: GenerateContentResponseUsageMetadata; } +export interface CaCountTokenRequest { + request: VertexCountTokenRequest; +} -export function toCodeAssistRequest( - req: GenerateContentParameters, - project?: string, -): CodeAssistRequest { +interface VertexCountTokenRequest { + model: string; + contents: Content[]; +} + +export interface CaCountTokenResponse { + totalTokens: number; +} + +export function toCountTokenRequest( + req: CountTokensParameters, +): CaCountTokenRequest { return { - model: req.model, - project, - request: toCodeAssistGenerateContentRequest(req), + request: { + model: 'models/' + req.model, + contents: toContents(req.contents), + }, }; } -export function fromCodeAsistResponse( - res: CodeAssistResponse, +export function fromCountTokenResponse( + res: CaCountTokenResponse, +): CountTokensResponse { + return { + totalTokens: res.totalTokens, + }; +} + +export function toGenerateContentRequest( + req: GenerateContentParameters, + project?: string, +): CAGenerateContentRequest { + return { + model: req.model, + project, + request: toVertexGenerateContentRequest(req), + }; +} + +export function fromGenerateContentResponse( + res: CaGenerateContentResponse, ): GenerateContentResponse { const inres = res.response; const out = new GenerateContentResponse(); @@ -101,9 +134,9 @@ export function fromCodeAsistResponse( return out; } -function toCodeAssistGenerateContentRequest( +function toVertexGenerateContentRequest( req: GenerateContentParameters, -): CodeAssistGenerateContentRequest { +): VertexGenerateContentRequest { return { contents: toContents(req.contents), systemInstruction: maybeToContent(req.config?.systemInstruction), @@ -112,7 +145,7 @@ function toCodeAssistGenerateContentRequest( toolConfig: req.config?.toolConfig, labels: req.config?.labels, safetySettings: req.config?.safetySettings, - generationConfig: toCodeAssistGenerationConfig(req.config), + generationConfig: toVertexGenerationConfig(req.config), }; } @@ -170,9 +203,9 @@ function toPart(part: PartUnion): Part { return part; } -function toCodeAssistGenerationConfig( +function toVertexGenerationConfig( config?: GenerateContentConfig, -): CodeAssistGenerationConfig | undefined { +): VertexGenerationConfig | undefined { if (!config) { return undefined; } diff --git a/packages/core/src/code_assist/server.test.ts b/packages/core/src/code_assist/server.test.ts index 922d20fb..d8d9c10a 100644 --- a/packages/core/src/code_assist/server.test.ts +++ b/packages/core/src/code_assist/server.test.ts @@ -133,11 +133,16 @@ describe('CodeAssistServer', () => { it('should return 0 for countTokens', async () => { const auth = new OAuth2Client(); const server = new CodeAssistServer(auth, 'test-project'); + const mockResponse = { + totalTokens: 100, + }; + vi.spyOn(server, 'callEndpoint').mockResolvedValue(mockResponse); + const response = await server.countTokens({ model: 'test-model', contents: [{ role: 'user', parts: [{ text: 'request' }] }], }); - expect(response.totalTokens).toBe(0); + expect(response.totalTokens).toBe(100); }); it('should throw an error for embedContent', async () => { diff --git a/packages/core/src/code_assist/server.ts b/packages/core/src/code_assist/server.ts index d700353c..4f8bb643 100644 --- a/packages/core/src/code_assist/server.ts +++ b/packages/core/src/code_assist/server.ts @@ -22,9 +22,12 @@ import { import * as readline from 'readline'; import { ContentGenerator } from '../core/contentGenerator.js'; import { - CodeAssistResponse, - toCodeAssistRequest, - fromCodeAsistResponse, + CaGenerateContentResponse, + toGenerateContentRequest, + fromGenerateContentResponse, + toCountTokenRequest, + fromCountTokenResponse, + CaCountTokenResponse, } from './converter.js'; import { PassThrough } from 'node:stream'; @@ -50,14 +53,14 @@ export class CodeAssistServer implements ContentGenerator { async generateContentStream( req: GenerateContentParameters, ): Promise> { - const resps = await this.streamEndpoint( + const resps = await this.streamEndpoint( 'streamGenerateContent', - toCodeAssistRequest(req, this.projectId), + toGenerateContentRequest(req, this.projectId), req.config?.abortSignal, ); return (async function* (): AsyncGenerator { for await (const resp of resps) { - yield fromCodeAsistResponse(resp); + yield fromGenerateContentResponse(resp); } })(); } @@ -65,12 +68,12 @@ export class CodeAssistServer implements ContentGenerator { async generateContent( req: GenerateContentParameters, ): Promise { - const resp = await this.callEndpoint( + const resp = await this.callEndpoint( 'generateContent', - toCodeAssistRequest(req, this.projectId), + toGenerateContentRequest(req, this.projectId), req.config?.abortSignal, ); - return fromCodeAsistResponse(resp); + return fromGenerateContentResponse(resp); } async onboardUser( @@ -91,8 +94,12 @@ export class CodeAssistServer implements ContentGenerator { ); } - async countTokens(_req: CountTokensParameters): Promise { - return { totalTokens: 0 }; + async countTokens(req: CountTokensParameters): Promise { + const resp = await this.callEndpoint( + 'countTokens', + toCountTokenRequest(req), + ); + return fromCountTokenResponse(resp); } async embedContent(