diff --git a/packages/core/src/code_assist/ccpaServer.ts b/packages/core/src/code_assist/ccpaServer.ts index 7a542db4..3ef8b084 100644 --- a/packages/core/src/code_assist/ccpaServer.ts +++ b/packages/core/src/code_assist/ccpaServer.ts @@ -19,10 +19,10 @@ import { CountTokensResponse, EmbedContentParameters, } from '@google/genai'; -import { Readable } from 'stream'; import * as readline from 'readline'; -import type { ReadableStream } from 'node:stream/web'; import { ContentGenerator } from '../core/contentGenerator.js'; +import { CcpaResponse, toCcpaRequest, fromCcpaResponse } from './converter.js'; +import { PassThrough } from 'node:stream'; // TODO: Use production endpoint once it supports our methods. export const CCPA_ENDPOINT = @@ -38,19 +38,25 @@ export class CcpaServer implements ContentGenerator { async generateContentStream( req: GenerateContentParameters, ): Promise> { - return await this.streamEndpoint( + const resps = await this.streamEndpoint( 'streamGenerateContent', - req, + toCcpaRequest(req, this.projectId), ); + return (async function* (): AsyncGenerator { + for await (const resp of resps) { + yield fromCcpaResponse(resp); + } + })(); } async generateContent( req: GenerateContentParameters, ): Promise { - return await this.callEndpoint( + const resp = await this.callEndpoint( 'generateContent', - req, + toCcpaRequest(req, this.projectId), ); + return fromCcpaResponse(resp); } async onboardUser( @@ -92,11 +98,6 @@ export class CcpaServer implements ContentGenerator { responseType: 'json', body: JSON.stringify(req), }); - if (res.status !== 200) { - throw new Error( - `Failed to fetch from ${method}: ${res.status} ${res.data}`, - ); - } return res.data as T; } @@ -114,15 +115,10 @@ export class CcpaServer implements ContentGenerator { responseType: 'stream', body: JSON.stringify(req), }); - if (res.status !== 200) { - throw new Error( - `Failed to fetch from ${method}: ${res.status} ${res.data}`, - ); - } return (async function* (): AsyncGenerator { const rl = readline.createInterface({ - input: Readable.fromWeb(res.data as ReadableStream), + input: res.data as PassThrough, crlfDelay: Infinity, // Recognizes '\r\n' and '\n' as line breaks }); diff --git a/packages/core/src/code_assist/converter.test.ts b/packages/core/src/code_assist/converter.test.ts new file mode 100644 index 00000000..4536d65f --- /dev/null +++ b/packages/core/src/code_assist/converter.test.ts @@ -0,0 +1,222 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { describe, it, expect } from 'vitest'; +import { toCcpaRequest, fromCcpaResponse, CcpaResponse } from './converter.js'; +import { + GenerateContentParameters, + GenerateContentResponse, + FinishReason, + BlockedReason, +} from '@google/genai'; + +describe('converter', () => { + describe('toCcpaRequest', () => { + it('should convert a simple request with project', () => { + const genaiReq: GenerateContentParameters = { + model: 'gemini-pro', + contents: [{ role: 'user', parts: [{ text: 'Hello' }] }], + }; + const ccpaReq = toCcpaRequest(genaiReq, 'my-project'); + expect(ccpaReq).toEqual({ + model: 'gemini-pro', + project: 'my-project', + request: { + contents: [{ role: 'user', parts: [{ text: 'Hello' }] }], + systemInstruction: undefined, + cachedContent: undefined, + tools: undefined, + toolConfig: undefined, + labels: undefined, + safetySettings: undefined, + generationConfig: undefined, + }, + }); + }); + + it('should convert a request without a project', () => { + const genaiReq: GenerateContentParameters = { + model: 'gemini-pro', + contents: [{ role: 'user', parts: [{ text: 'Hello' }] }], + }; + const ccpaReq = toCcpaRequest(genaiReq); + expect(ccpaReq).toEqual({ + model: 'gemini-pro', + project: undefined, + request: { + contents: [{ role: 'user', parts: [{ text: 'Hello' }] }], + systemInstruction: undefined, + cachedContent: undefined, + tools: undefined, + toolConfig: undefined, + labels: undefined, + safetySettings: undefined, + generationConfig: undefined, + }, + }); + }); + + it('should handle string content', () => { + const genaiReq: GenerateContentParameters = { + model: 'gemini-pro', + contents: 'Hello', + }; + const ccpaReq = toCcpaRequest(genaiReq); + expect(ccpaReq.request.contents).toEqual([ + { role: 'user', parts: [{ text: 'Hello' }] }, + ]); + }); + + it('should handle Part[] content', () => { + const genaiReq: GenerateContentParameters = { + model: 'gemini-pro', + contents: [{ text: 'Hello' }, { text: 'World' }], + }; + const ccpaReq = toCcpaRequest(genaiReq); + expect(ccpaReq.request.contents).toEqual([ + { role: 'user', parts: [{ text: 'Hello' }] }, + { role: 'user', parts: [{ text: 'World' }] }, + ]); + }); + + it('should handle system instructions', () => { + const genaiReq: GenerateContentParameters = { + model: 'gemini-pro', + contents: 'Hello', + config: { + systemInstruction: 'You are a helpful assistant.', + }, + }; + const ccpaReq = toCcpaRequest(genaiReq); + expect(ccpaReq.request.systemInstruction).toEqual({ + role: 'user', + parts: [{ text: 'You are a helpful assistant.' }], + }); + }); + + it('should handle generation config', () => { + const genaiReq: GenerateContentParameters = { + model: 'gemini-pro', + contents: 'Hello', + config: { + temperature: 0.8, + topK: 40, + }, + }; + const ccpaReq = toCcpaRequest(genaiReq); + expect(ccpaReq.request.generationConfig).toEqual({ + temperature: 0.8, + topK: 40, + }); + }); + + it('should handle all generation config fields', () => { + const genaiReq: GenerateContentParameters = { + model: 'gemini-pro', + contents: 'Hello', + config: { + temperature: 0.1, + topP: 0.2, + topK: 3, + candidateCount: 4, + maxOutputTokens: 5, + stopSequences: ['a'], + responseLogprobs: true, + logprobs: 6, + presencePenalty: 0.7, + frequencyPenalty: 0.8, + seed: 9, + responseMimeType: 'application/json', + }, + }; + const ccpaReq = toCcpaRequest(genaiReq); + expect(ccpaReq.request.generationConfig).toEqual({ + temperature: 0.1, + topP: 0.2, + topK: 3, + candidateCount: 4, + maxOutputTokens: 5, + stopSequences: ['a'], + responseLogprobs: true, + logprobs: 6, + presencePenalty: 0.7, + frequencyPenalty: 0.8, + seed: 9, + responseMimeType: 'application/json', + }); + }); + }); + + describe('fromCcpaResponse', () => { + it('should convert a simple response', () => { + const ccpaRes: CcpaResponse = { + response: { + candidates: [ + { + index: 0, + content: { + role: 'model', + parts: [{ text: 'Hi there!' }], + }, + finishReason: FinishReason.STOP, + safetyRatings: [], + }, + ], + }, + }; + const genaiRes = fromCcpaResponse(ccpaRes); + expect(genaiRes).toBeInstanceOf(GenerateContentResponse); + expect(genaiRes.candidates).toEqual(ccpaRes.response.candidates); + }); + + it('should handle prompt feedback and usage metadata', () => { + const ccpaRes: CcpaResponse = { + response: { + candidates: [], + promptFeedback: { + blockReason: BlockedReason.SAFETY, + safetyRatings: [], + }, + usageMetadata: { + promptTokenCount: 10, + candidatesTokenCount: 20, + totalTokenCount: 30, + }, + }, + }; + const genaiRes = fromCcpaResponse(ccpaRes); + expect(genaiRes.promptFeedback).toEqual(ccpaRes.response.promptFeedback); + expect(genaiRes.usageMetadata).toEqual(ccpaRes.response.usageMetadata); + }); + + it('should handle automatic function calling history', () => { + const ccpaRes: CcpaResponse = { + response: { + candidates: [], + automaticFunctionCallingHistory: [ + { + role: 'model', + parts: [ + { + functionCall: { + name: 'test_function', + args: { + foo: 'bar', + }, + }, + }, + ], + }, + ], + }, + }; + const genaiRes = fromCcpaResponse(ccpaRes); + expect(genaiRes.automaticFunctionCallingHistory).toEqual( + ccpaRes.response.automaticFunctionCallingHistory, + ); + }); + }); +}); diff --git a/packages/core/src/code_assist/converter.ts b/packages/core/src/code_assist/converter.ts new file mode 100644 index 00000000..c7b0e7c7 --- /dev/null +++ b/packages/core/src/code_assist/converter.ts @@ -0,0 +1,199 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { + Content, + ContentListUnion, + ContentUnion, + GenerateContentConfig, + GenerateContentParameters, + GenerateContentResponse, + GenerationConfigRoutingConfig, + MediaResolution, + Candidate, + ModelSelectionConfig, + GenerateContentResponsePromptFeedback, + GenerateContentResponseUsageMetadata, + Part, + SafetySetting, + PartUnion, + SchemaUnion, + SpeechConfigUnion, + ThinkingConfig, + ToolListUnion, + ToolConfig, +} from '@google/genai'; + +export interface CcpaRequest { + model: string; + project?: string; + request: CcpaGenerateContentRequest; +} + +interface CcpaGenerateContentRequest { + contents: Content[]; + systemInstruction?: Content; + cachedContent?: string; + tools?: ToolListUnion; + toolConfig?: ToolConfig; + labels?: Record; + safetySettings?: SafetySetting[]; + generationConfig?: CcpaGenerationConfig; +} + +interface CcpaGenerationConfig { + temperature?: number; + topP?: number; + topK?: number; + candidateCount?: number; + maxOutputTokens?: number; + stopSequences?: string[]; + responseLogprobs?: boolean; + logprobs?: number; + presencePenalty?: number; + frequencyPenalty?: number; + seed?: number; + responseMimeType?: string; + responseSchema?: SchemaUnion; + routingConfig?: GenerationConfigRoutingConfig; + modelSelectionConfig?: ModelSelectionConfig; + responseModalities?: string[]; + mediaResolution?: MediaResolution; + speechConfig?: SpeechConfigUnion; + audioTimestamp?: boolean; + thinkingConfig?: ThinkingConfig; +} + +export interface CcpaResponse { + response: VertexResponse; +} + +interface VertexResponse { + candidates: Candidate[]; + automaticFunctionCallingHistory?: Content[]; + promptFeedback?: GenerateContentResponsePromptFeedback; + usageMetadata?: GenerateContentResponseUsageMetadata; +} + +export function toCcpaRequest( + req: GenerateContentParameters, + project?: string, +): CcpaRequest { + return { + model: req.model, + project, + request: toCcpaGenerateContentRequest(req), + }; +} + +export function fromCcpaResponse(res: CcpaResponse): GenerateContentResponse { + const inres = res.response; + const out = new GenerateContentResponse(); + out.candidates = inres.candidates; + out.automaticFunctionCallingHistory = inres.automaticFunctionCallingHistory; + out.promptFeedback = inres.promptFeedback; + out.usageMetadata = inres.usageMetadata; + return out; +} + +function toCcpaGenerateContentRequest( + req: GenerateContentParameters, +): CcpaGenerateContentRequest { + return { + contents: toContents(req.contents), + systemInstruction: maybeToContent(req.config?.systemInstruction), + cachedContent: req.config?.cachedContent, + tools: req.config?.tools, + toolConfig: req.config?.toolConfig, + labels: req.config?.labels, + safetySettings: req.config?.safetySettings, + generationConfig: toCcpaGenerationConfig(req.config), + }; +} + +function toContents(contents: ContentListUnion): Content[] { + if (Array.isArray(contents)) { + // it's a Content[] or a PartsUnion[] + return contents.map(toContent); + } + // it's a Content or a PartsUnion + return [toContent(contents)]; +} + +function maybeToContent(content?: ContentUnion): Content | undefined { + if (!content) { + return undefined; + } + return toContent(content); +} + +function toContent(content: ContentUnion): Content { + if (Array.isArray(content)) { + // it's a PartsUnion[] + return { + role: 'user', + parts: toParts(content), + }; + } + if (typeof content === 'string') { + // it's a string + return { + role: 'user', + parts: [{ text: content }], + }; + } + if ('parts' in content) { + // it's a Content + return content; + } + // it's a Part + return { + role: 'user', + parts: [content as Part], + }; +} + +function toParts(parts: PartUnion[]): Part[] { + return parts.map(toPart); +} + +function toPart(part: PartUnion): Part { + if (typeof part === 'string') { + // it's a string + return { text: part }; + } + return part; +} + +function toCcpaGenerationConfig( + config?: GenerateContentConfig, +): CcpaGenerationConfig | undefined { + if (!config) { + return undefined; + } + return { + temperature: config.temperature, + topP: config.topP, + topK: config.topK, + candidateCount: config.candidateCount, + maxOutputTokens: config.maxOutputTokens, + stopSequences: config.stopSequences, + responseLogprobs: config.responseLogprobs, + logprobs: config.logprobs, + presencePenalty: config.presencePenalty, + frequencyPenalty: config.frequencyPenalty, + seed: config.seed, + responseMimeType: config.responseMimeType, + responseSchema: config.responseSchema, + routingConfig: config.routingConfig, + modelSelectionConfig: config.modelSelectionConfig, + responseModalities: config.responseModalities, + mediaResolution: config.mediaResolution, + speechConfig: config.speechConfig, + audioTimestamp: config.audioTimestamp, + thinkingConfig: config.thinkingConfig, + }; +}