CCPA Count Token support (#1170)
This commit is contained in:
parent
332512853e
commit
4662b058e8
|
@ -6,9 +6,9 @@
|
||||||
|
|
||||||
import { describe, it, expect } from 'vitest';
|
import { describe, it, expect } from 'vitest';
|
||||||
import {
|
import {
|
||||||
toCodeAssistRequest,
|
toGenerateContentRequest,
|
||||||
fromCodeAsistResponse,
|
fromGenerateContentResponse,
|
||||||
CodeAssistResponse,
|
CaGenerateContentResponse,
|
||||||
} from './converter.js';
|
} from './converter.js';
|
||||||
import {
|
import {
|
||||||
GenerateContentParameters,
|
GenerateContentParameters,
|
||||||
|
@ -24,7 +24,7 @@ describe('converter', () => {
|
||||||
model: 'gemini-pro',
|
model: 'gemini-pro',
|
||||||
contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
|
contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
|
||||||
};
|
};
|
||||||
const codeAssistReq = toCodeAssistRequest(genaiReq, 'my-project');
|
const codeAssistReq = toGenerateContentRequest(genaiReq, 'my-project');
|
||||||
expect(codeAssistReq).toEqual({
|
expect(codeAssistReq).toEqual({
|
||||||
model: 'gemini-pro',
|
model: 'gemini-pro',
|
||||||
project: 'my-project',
|
project: 'my-project',
|
||||||
|
@ -46,7 +46,7 @@ describe('converter', () => {
|
||||||
model: 'gemini-pro',
|
model: 'gemini-pro',
|
||||||
contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
|
contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
|
||||||
};
|
};
|
||||||
const codeAssistReq = toCodeAssistRequest(genaiReq);
|
const codeAssistReq = toGenerateContentRequest(genaiReq);
|
||||||
expect(codeAssistReq).toEqual({
|
expect(codeAssistReq).toEqual({
|
||||||
model: 'gemini-pro',
|
model: 'gemini-pro',
|
||||||
project: undefined,
|
project: undefined,
|
||||||
|
@ -68,7 +68,7 @@ describe('converter', () => {
|
||||||
model: 'gemini-pro',
|
model: 'gemini-pro',
|
||||||
contents: 'Hello',
|
contents: 'Hello',
|
||||||
};
|
};
|
||||||
const codeAssistReq = toCodeAssistRequest(genaiReq);
|
const codeAssistReq = toGenerateContentRequest(genaiReq);
|
||||||
expect(codeAssistReq.request.contents).toEqual([
|
expect(codeAssistReq.request.contents).toEqual([
|
||||||
{ role: 'user', parts: [{ text: 'Hello' }] },
|
{ role: 'user', parts: [{ text: 'Hello' }] },
|
||||||
]);
|
]);
|
||||||
|
@ -79,7 +79,7 @@ describe('converter', () => {
|
||||||
model: 'gemini-pro',
|
model: 'gemini-pro',
|
||||||
contents: [{ text: 'Hello' }, { text: 'World' }],
|
contents: [{ text: 'Hello' }, { text: 'World' }],
|
||||||
};
|
};
|
||||||
const codeAssistReq = toCodeAssistRequest(genaiReq);
|
const codeAssistReq = toGenerateContentRequest(genaiReq);
|
||||||
expect(codeAssistReq.request.contents).toEqual([
|
expect(codeAssistReq.request.contents).toEqual([
|
||||||
{ role: 'user', parts: [{ text: 'Hello' }] },
|
{ role: 'user', parts: [{ text: 'Hello' }] },
|
||||||
{ role: 'user', parts: [{ text: 'World' }] },
|
{ role: 'user', parts: [{ text: 'World' }] },
|
||||||
|
@ -94,7 +94,7 @@ describe('converter', () => {
|
||||||
systemInstruction: 'You are a helpful assistant.',
|
systemInstruction: 'You are a helpful assistant.',
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
const codeAssistReq = toCodeAssistRequest(genaiReq);
|
const codeAssistReq = toGenerateContentRequest(genaiReq);
|
||||||
expect(codeAssistReq.request.systemInstruction).toEqual({
|
expect(codeAssistReq.request.systemInstruction).toEqual({
|
||||||
role: 'user',
|
role: 'user',
|
||||||
parts: [{ text: 'You are a helpful assistant.' }],
|
parts: [{ text: 'You are a helpful assistant.' }],
|
||||||
|
@ -110,7 +110,7 @@ describe('converter', () => {
|
||||||
topK: 40,
|
topK: 40,
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
const codeAssistReq = toCodeAssistRequest(genaiReq);
|
const codeAssistReq = toGenerateContentRequest(genaiReq);
|
||||||
expect(codeAssistReq.request.generationConfig).toEqual({
|
expect(codeAssistReq.request.generationConfig).toEqual({
|
||||||
temperature: 0.8,
|
temperature: 0.8,
|
||||||
topK: 40,
|
topK: 40,
|
||||||
|
@ -136,7 +136,7 @@ describe('converter', () => {
|
||||||
responseMimeType: 'application/json',
|
responseMimeType: 'application/json',
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
const codeAssistReq = toCodeAssistRequest(genaiReq);
|
const codeAssistReq = toGenerateContentRequest(genaiReq);
|
||||||
expect(codeAssistReq.request.generationConfig).toEqual({
|
expect(codeAssistReq.request.generationConfig).toEqual({
|
||||||
temperature: 0.1,
|
temperature: 0.1,
|
||||||
topP: 0.2,
|
topP: 0.2,
|
||||||
|
@ -156,7 +156,7 @@ describe('converter', () => {
|
||||||
|
|
||||||
describe('fromCodeAssistResponse', () => {
|
describe('fromCodeAssistResponse', () => {
|
||||||
it('should convert a simple response', () => {
|
it('should convert a simple response', () => {
|
||||||
const codeAssistRes: CodeAssistResponse = {
|
const codeAssistRes: CaGenerateContentResponse = {
|
||||||
response: {
|
response: {
|
||||||
candidates: [
|
candidates: [
|
||||||
{
|
{
|
||||||
|
@ -171,13 +171,13 @@ describe('converter', () => {
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
const genaiRes = fromCodeAsistResponse(codeAssistRes);
|
const genaiRes = fromGenerateContentResponse(codeAssistRes);
|
||||||
expect(genaiRes).toBeInstanceOf(GenerateContentResponse);
|
expect(genaiRes).toBeInstanceOf(GenerateContentResponse);
|
||||||
expect(genaiRes.candidates).toEqual(codeAssistRes.response.candidates);
|
expect(genaiRes.candidates).toEqual(codeAssistRes.response.candidates);
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should handle prompt feedback and usage metadata', () => {
|
it('should handle prompt feedback and usage metadata', () => {
|
||||||
const codeAssistRes: CodeAssistResponse = {
|
const codeAssistRes: CaGenerateContentResponse = {
|
||||||
response: {
|
response: {
|
||||||
candidates: [],
|
candidates: [],
|
||||||
promptFeedback: {
|
promptFeedback: {
|
||||||
|
@ -191,7 +191,7 @@ describe('converter', () => {
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
const genaiRes = fromCodeAsistResponse(codeAssistRes);
|
const genaiRes = fromGenerateContentResponse(codeAssistRes);
|
||||||
expect(genaiRes.promptFeedback).toEqual(
|
expect(genaiRes.promptFeedback).toEqual(
|
||||||
codeAssistRes.response.promptFeedback,
|
codeAssistRes.response.promptFeedback,
|
||||||
);
|
);
|
||||||
|
@ -201,7 +201,7 @@ describe('converter', () => {
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should handle automatic function calling history', () => {
|
it('should handle automatic function calling history', () => {
|
||||||
const codeAssistRes: CodeAssistResponse = {
|
const codeAssistRes: CaGenerateContentResponse = {
|
||||||
response: {
|
response: {
|
||||||
candidates: [],
|
candidates: [],
|
||||||
automaticFunctionCallingHistory: [
|
automaticFunctionCallingHistory: [
|
||||||
|
@ -221,7 +221,7 @@ describe('converter', () => {
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
const genaiRes = fromCodeAsistResponse(codeAssistRes);
|
const genaiRes = fromGenerateContentResponse(codeAssistRes);
|
||||||
expect(genaiRes.automaticFunctionCallingHistory).toEqual(
|
expect(genaiRes.automaticFunctionCallingHistory).toEqual(
|
||||||
codeAssistRes.response.automaticFunctionCallingHistory,
|
codeAssistRes.response.automaticFunctionCallingHistory,
|
||||||
);
|
);
|
||||||
|
|
|
@ -10,6 +10,8 @@ import {
|
||||||
ContentUnion,
|
ContentUnion,
|
||||||
GenerateContentConfig,
|
GenerateContentConfig,
|
||||||
GenerateContentParameters,
|
GenerateContentParameters,
|
||||||
|
CountTokensParameters,
|
||||||
|
CountTokensResponse,
|
||||||
GenerateContentResponse,
|
GenerateContentResponse,
|
||||||
GenerationConfigRoutingConfig,
|
GenerationConfigRoutingConfig,
|
||||||
MediaResolution,
|
MediaResolution,
|
||||||
|
@ -27,13 +29,13 @@ import {
|
||||||
ToolConfig,
|
ToolConfig,
|
||||||
} from '@google/genai';
|
} from '@google/genai';
|
||||||
|
|
||||||
export interface CodeAssistRequest {
|
export interface CAGenerateContentRequest {
|
||||||
model: string;
|
model: string;
|
||||||
project?: string;
|
project?: string;
|
||||||
request: CodeAssistGenerateContentRequest;
|
request: VertexGenerateContentRequest;
|
||||||
}
|
}
|
||||||
|
|
||||||
interface CodeAssistGenerateContentRequest {
|
interface VertexGenerateContentRequest {
|
||||||
contents: Content[];
|
contents: Content[];
|
||||||
systemInstruction?: Content;
|
systemInstruction?: Content;
|
||||||
cachedContent?: string;
|
cachedContent?: string;
|
||||||
|
@ -41,10 +43,10 @@ interface CodeAssistGenerateContentRequest {
|
||||||
toolConfig?: ToolConfig;
|
toolConfig?: ToolConfig;
|
||||||
labels?: Record<string, string>;
|
labels?: Record<string, string>;
|
||||||
safetySettings?: SafetySetting[];
|
safetySettings?: SafetySetting[];
|
||||||
generationConfig?: CodeAssistGenerationConfig;
|
generationConfig?: VertexGenerationConfig;
|
||||||
}
|
}
|
||||||
|
|
||||||
interface CodeAssistGenerationConfig {
|
interface VertexGenerationConfig {
|
||||||
temperature?: number;
|
temperature?: number;
|
||||||
topP?: number;
|
topP?: number;
|
||||||
topK?: number;
|
topK?: number;
|
||||||
|
@ -67,30 +69,61 @@ interface CodeAssistGenerationConfig {
|
||||||
thinkingConfig?: ThinkingConfig;
|
thinkingConfig?: ThinkingConfig;
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface CodeAssistResponse {
|
export interface CaGenerateContentResponse {
|
||||||
response: VertexResponse;
|
response: VertexGenerateContentResponse;
|
||||||
}
|
}
|
||||||
|
|
||||||
interface VertexResponse {
|
interface VertexGenerateContentResponse {
|
||||||
candidates: Candidate[];
|
candidates: Candidate[];
|
||||||
automaticFunctionCallingHistory?: Content[];
|
automaticFunctionCallingHistory?: Content[];
|
||||||
promptFeedback?: GenerateContentResponsePromptFeedback;
|
promptFeedback?: GenerateContentResponsePromptFeedback;
|
||||||
usageMetadata?: GenerateContentResponseUsageMetadata;
|
usageMetadata?: GenerateContentResponseUsageMetadata;
|
||||||
}
|
}
|
||||||
|
export interface CaCountTokenRequest {
|
||||||
|
request: VertexCountTokenRequest;
|
||||||
|
}
|
||||||
|
|
||||||
export function toCodeAssistRequest(
|
interface VertexCountTokenRequest {
|
||||||
req: GenerateContentParameters,
|
model: string;
|
||||||
project?: string,
|
contents: Content[];
|
||||||
): CodeAssistRequest {
|
}
|
||||||
|
|
||||||
|
export interface CaCountTokenResponse {
|
||||||
|
totalTokens: number;
|
||||||
|
}
|
||||||
|
|
||||||
|
export function toCountTokenRequest(
|
||||||
|
req: CountTokensParameters,
|
||||||
|
): CaCountTokenRequest {
|
||||||
return {
|
return {
|
||||||
model: req.model,
|
request: {
|
||||||
project,
|
model: 'models/' + req.model,
|
||||||
request: toCodeAssistGenerateContentRequest(req),
|
contents: toContents(req.contents),
|
||||||
|
},
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
export function fromCodeAsistResponse(
|
export function fromCountTokenResponse(
|
||||||
res: CodeAssistResponse,
|
res: CaCountTokenResponse,
|
||||||
|
): CountTokensResponse {
|
||||||
|
return {
|
||||||
|
totalTokens: res.totalTokens,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
export function toGenerateContentRequest(
|
||||||
|
req: GenerateContentParameters,
|
||||||
|
project?: string,
|
||||||
|
): CAGenerateContentRequest {
|
||||||
|
return {
|
||||||
|
model: req.model,
|
||||||
|
project,
|
||||||
|
request: toVertexGenerateContentRequest(req),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
export function fromGenerateContentResponse(
|
||||||
|
res: CaGenerateContentResponse,
|
||||||
): GenerateContentResponse {
|
): GenerateContentResponse {
|
||||||
const inres = res.response;
|
const inres = res.response;
|
||||||
const out = new GenerateContentResponse();
|
const out = new GenerateContentResponse();
|
||||||
|
@ -101,9 +134,9 @@ export function fromCodeAsistResponse(
|
||||||
return out;
|
return out;
|
||||||
}
|
}
|
||||||
|
|
||||||
function toCodeAssistGenerateContentRequest(
|
function toVertexGenerateContentRequest(
|
||||||
req: GenerateContentParameters,
|
req: GenerateContentParameters,
|
||||||
): CodeAssistGenerateContentRequest {
|
): VertexGenerateContentRequest {
|
||||||
return {
|
return {
|
||||||
contents: toContents(req.contents),
|
contents: toContents(req.contents),
|
||||||
systemInstruction: maybeToContent(req.config?.systemInstruction),
|
systemInstruction: maybeToContent(req.config?.systemInstruction),
|
||||||
|
@ -112,7 +145,7 @@ function toCodeAssistGenerateContentRequest(
|
||||||
toolConfig: req.config?.toolConfig,
|
toolConfig: req.config?.toolConfig,
|
||||||
labels: req.config?.labels,
|
labels: req.config?.labels,
|
||||||
safetySettings: req.config?.safetySettings,
|
safetySettings: req.config?.safetySettings,
|
||||||
generationConfig: toCodeAssistGenerationConfig(req.config),
|
generationConfig: toVertexGenerationConfig(req.config),
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -170,9 +203,9 @@ function toPart(part: PartUnion): Part {
|
||||||
return part;
|
return part;
|
||||||
}
|
}
|
||||||
|
|
||||||
function toCodeAssistGenerationConfig(
|
function toVertexGenerationConfig(
|
||||||
config?: GenerateContentConfig,
|
config?: GenerateContentConfig,
|
||||||
): CodeAssistGenerationConfig | undefined {
|
): VertexGenerationConfig | undefined {
|
||||||
if (!config) {
|
if (!config) {
|
||||||
return undefined;
|
return undefined;
|
||||||
}
|
}
|
||||||
|
|
|
@ -133,11 +133,16 @@ describe('CodeAssistServer', () => {
|
||||||
it('should return 0 for countTokens', async () => {
|
it('should return 0 for countTokens', async () => {
|
||||||
const auth = new OAuth2Client();
|
const auth = new OAuth2Client();
|
||||||
const server = new CodeAssistServer(auth, 'test-project');
|
const server = new CodeAssistServer(auth, 'test-project');
|
||||||
|
const mockResponse = {
|
||||||
|
totalTokens: 100,
|
||||||
|
};
|
||||||
|
vi.spyOn(server, 'callEndpoint').mockResolvedValue(mockResponse);
|
||||||
|
|
||||||
const response = await server.countTokens({
|
const response = await server.countTokens({
|
||||||
model: 'test-model',
|
model: 'test-model',
|
||||||
contents: [{ role: 'user', parts: [{ text: 'request' }] }],
|
contents: [{ role: 'user', parts: [{ text: 'request' }] }],
|
||||||
});
|
});
|
||||||
expect(response.totalTokens).toBe(0);
|
expect(response.totalTokens).toBe(100);
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should throw an error for embedContent', async () => {
|
it('should throw an error for embedContent', async () => {
|
||||||
|
|
|
@ -22,9 +22,12 @@ import {
|
||||||
import * as readline from 'readline';
|
import * as readline from 'readline';
|
||||||
import { ContentGenerator } from '../core/contentGenerator.js';
|
import { ContentGenerator } from '../core/contentGenerator.js';
|
||||||
import {
|
import {
|
||||||
CodeAssistResponse,
|
CaGenerateContentResponse,
|
||||||
toCodeAssistRequest,
|
toGenerateContentRequest,
|
||||||
fromCodeAsistResponse,
|
fromGenerateContentResponse,
|
||||||
|
toCountTokenRequest,
|
||||||
|
fromCountTokenResponse,
|
||||||
|
CaCountTokenResponse,
|
||||||
} from './converter.js';
|
} from './converter.js';
|
||||||
import { PassThrough } from 'node:stream';
|
import { PassThrough } from 'node:stream';
|
||||||
|
|
||||||
|
@ -50,14 +53,14 @@ export class CodeAssistServer implements ContentGenerator {
|
||||||
async generateContentStream(
|
async generateContentStream(
|
||||||
req: GenerateContentParameters,
|
req: GenerateContentParameters,
|
||||||
): Promise<AsyncGenerator<GenerateContentResponse>> {
|
): Promise<AsyncGenerator<GenerateContentResponse>> {
|
||||||
const resps = await this.streamEndpoint<CodeAssistResponse>(
|
const resps = await this.streamEndpoint<CaGenerateContentResponse>(
|
||||||
'streamGenerateContent',
|
'streamGenerateContent',
|
||||||
toCodeAssistRequest(req, this.projectId),
|
toGenerateContentRequest(req, this.projectId),
|
||||||
req.config?.abortSignal,
|
req.config?.abortSignal,
|
||||||
);
|
);
|
||||||
return (async function* (): AsyncGenerator<GenerateContentResponse> {
|
return (async function* (): AsyncGenerator<GenerateContentResponse> {
|
||||||
for await (const resp of resps) {
|
for await (const resp of resps) {
|
||||||
yield fromCodeAsistResponse(resp);
|
yield fromGenerateContentResponse(resp);
|
||||||
}
|
}
|
||||||
})();
|
})();
|
||||||
}
|
}
|
||||||
|
@ -65,12 +68,12 @@ export class CodeAssistServer implements ContentGenerator {
|
||||||
async generateContent(
|
async generateContent(
|
||||||
req: GenerateContentParameters,
|
req: GenerateContentParameters,
|
||||||
): Promise<GenerateContentResponse> {
|
): Promise<GenerateContentResponse> {
|
||||||
const resp = await this.callEndpoint<CodeAssistResponse>(
|
const resp = await this.callEndpoint<CaGenerateContentResponse>(
|
||||||
'generateContent',
|
'generateContent',
|
||||||
toCodeAssistRequest(req, this.projectId),
|
toGenerateContentRequest(req, this.projectId),
|
||||||
req.config?.abortSignal,
|
req.config?.abortSignal,
|
||||||
);
|
);
|
||||||
return fromCodeAsistResponse(resp);
|
return fromGenerateContentResponse(resp);
|
||||||
}
|
}
|
||||||
|
|
||||||
async onboardUser(
|
async onboardUser(
|
||||||
|
@ -91,8 +94,12 @@ export class CodeAssistServer implements ContentGenerator {
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
async countTokens(_req: CountTokensParameters): Promise<CountTokensResponse> {
|
async countTokens(req: CountTokensParameters): Promise<CountTokensResponse> {
|
||||||
return { totalTokens: 0 };
|
const resp = await this.callEndpoint<CaCountTokenResponse>(
|
||||||
|
'countTokens',
|
||||||
|
toCountTokenRequest(req),
|
||||||
|
);
|
||||||
|
return fromCountTokenResponse(resp);
|
||||||
}
|
}
|
||||||
|
|
||||||
async embedContent(
|
async embedContent(
|
||||||
|
|
Loading…
Reference in New Issue