diff --git a/packages/core/src/utils/nextSpeakerChecker.test.ts b/packages/core/src/utils/nextSpeakerChecker.test.ts index 475b5662..9141105f 100644 --- a/packages/core/src/utils/nextSpeakerChecker.test.ts +++ b/packages/core/src/utils/nextSpeakerChecker.test.ts @@ -6,6 +6,7 @@ import { describe, it, expect, vi, beforeEach, Mock, afterEach } from 'vitest'; import { Content, GoogleGenAI, Models } from '@google/genai'; +import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js'; import { GeminiClient } from '../core/client.js'; import { Config } from '../config/config.js'; import { checkNextSpeaker, NextSpeakerResponse } from './nextSpeakerChecker.js'; @@ -231,4 +232,22 @@ describe('checkNextSpeaker', () => { ); expect(result).toBeNull(); }); + + it('should call generateJson with DEFAULT_GEMINI_FLASH_MODEL', async () => { + (chatInstance.getHistory as Mock).mockReturnValue([ + { role: 'model', parts: [{ text: 'Some model output.' }] }, + ] as Content[]); + const mockApiResponse: NextSpeakerResponse = { + reasoning: 'Model made a statement, awaiting user input.', + next_speaker: 'user', + }; + (mockGeminiClient.generateJson as Mock).mockResolvedValue(mockApiResponse); + + await checkNextSpeaker(chatInstance, mockGeminiClient, abortSignal); + + expect(mockGeminiClient.generateJson).toHaveBeenCalled(); + const generateJsonCall = (mockGeminiClient.generateJson as Mock).mock + .calls[0]; + expect(generateJsonCall[3]).toBe(DEFAULT_GEMINI_FLASH_MODEL); + }); }); diff --git a/packages/core/src/utils/nextSpeakerChecker.ts b/packages/core/src/utils/nextSpeakerChecker.ts index 165f277a..9d428887 100644 --- a/packages/core/src/utils/nextSpeakerChecker.ts +++ b/packages/core/src/utils/nextSpeakerChecker.ts @@ -5,6 +5,7 @@ */ import { Content, SchemaUnion, Type } from '@google/genai'; +import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js'; import { GeminiClient } from '../core/client.js'; import { GeminiChat } from '../core/geminiChat.js'; import { isFunctionResponse } from './messageInspectors.js'; @@ -131,6 +132,7 @@ export async function checkNextSpeaker( contents, RESPONSE_SCHEMA, abortSignal, + DEFAULT_GEMINI_FLASH_MODEL, )) as unknown as NextSpeakerResponse; if (