143 lines
4.6 KiB
TypeScript
143 lines
4.6 KiB
TypeScript
import { millisecondsToTime } from '@peertube/peertube-core-utils'
|
|
import { SUUID, buildAbsoluteFixturePath, buildSUUID } from '@peertube/peertube-node-utils'
|
|
import {
|
|
TranscriptFile,
|
|
TranscriptionEngine,
|
|
TranscriptionEngineName,
|
|
TranscriptionModel,
|
|
transcriberFactory
|
|
} from '@peertube/peertube-transcription'
|
|
import { ensureDir, remove } from 'fs-extra/esm'
|
|
import { tmpdir } from 'node:os'
|
|
import { join } from 'node:path'
|
|
import { PerformanceObserver, performance } from 'node:perf_hooks'
|
|
import { createLogger, format, transports } from 'winston'
|
|
import { TranscriptFileEvaluator } from './transcript-file-evaluator.js'
|
|
|
|
interface BenchmarkResult {
|
|
uuid: SUUID
|
|
WER?: number
|
|
CER?: number
|
|
duration?: number
|
|
engine?: TranscriptionEngine
|
|
model?: string
|
|
}
|
|
|
|
type Benchmark = Record<SUUID, BenchmarkResult>
|
|
|
|
const benchmarkReducer = (benchmark: Benchmark = {}, benchmarkResult: BenchmarkResult) => ({
|
|
...benchmark,
|
|
[benchmarkResult.uuid]: {
|
|
...benchmark[benchmarkResult.uuid],
|
|
...benchmarkResult
|
|
}
|
|
})
|
|
|
|
const groupBenchmarkResultsByModel = (benchmarkResults: Record<string, BenchmarkResult>) => (benchmarksGroupedByModel, uuid) => ({
|
|
...benchmarksGroupedByModel,
|
|
[benchmarkResults[uuid].model]: {
|
|
...benchmarksGroupedByModel[benchmarkResults[uuid].model],
|
|
[uuid]: formatBenchmarkResult(benchmarkResults[uuid])
|
|
}
|
|
})
|
|
|
|
interface FormattedBenchmarkResult {
|
|
WER?: string
|
|
CER?: string
|
|
duration?: string
|
|
model?: string
|
|
engine?: string
|
|
}
|
|
|
|
const formatBenchmarkResult = ({ WER, CER, duration, engine, model }: Partial<BenchmarkResult>): FormattedBenchmarkResult => ({
|
|
WER: WER ? `${WER * 100}%` : undefined,
|
|
CER: CER ? `${CER * 100}%` : undefined,
|
|
duration: duration ? millisecondsToTime(duration) : undefined,
|
|
model,
|
|
engine: engine.name
|
|
})
|
|
|
|
void (async () => {
|
|
const logger = createLogger()
|
|
logger.add(new transports.Console({ format: format.printf(log => log.message) }))
|
|
|
|
const transcribers: TranscriptionEngineName[] = [ 'openai-whisper', 'whisper-ctranslate2' ]
|
|
const models = process.env.MODELS
|
|
? process.env.MODELS.trim().split(',').map(modelName => modelName.trim()).filter(modelName => modelName)
|
|
: [ 'tiny' ]
|
|
|
|
const transcriptDirectory = join(tmpdir(), 'peertube-transcription', 'benchmark')
|
|
const pipDirectory = join(tmpdir(), 'peertube-transcription', 'pip')
|
|
|
|
const mediaFilePath = buildAbsoluteFixturePath('transcription/videos/derive_sectaire.mp4')
|
|
const referenceTranscriptFile = new TranscriptFile({
|
|
path: buildAbsoluteFixturePath('transcription/videos/derive_sectaire.txt'),
|
|
language: 'fr',
|
|
format: 'txt'
|
|
})
|
|
|
|
let benchmarkResults: Record<string, BenchmarkResult> = {}
|
|
|
|
// before
|
|
await ensureDir(transcriptDirectory)
|
|
const performanceObserver = new PerformanceObserver((items) => {
|
|
items
|
|
.getEntries()
|
|
.forEach((entry) => {
|
|
benchmarkResults = benchmarkReducer(benchmarkResults, {
|
|
uuid: entry.name as SUUID,
|
|
duration: entry.duration
|
|
})
|
|
})
|
|
})
|
|
performanceObserver.observe({ type: 'measure' })
|
|
|
|
// benchmark
|
|
logger.info(`Running transcribers benchmark with the following models: ${models.join(', ')}`)
|
|
for (const transcriberName of transcribers) {
|
|
logger.info(`Create "${transcriberName}" transcriber for the benchmark...`)
|
|
|
|
const transcriber = transcriberFactory.createFromEngineName({
|
|
engineName: transcriberName,
|
|
logger: createLogger({ transports: [ new transports.Console() ] }),
|
|
binDirectory: join(pipDirectory, 'bin')
|
|
})
|
|
|
|
await transcriber.install(pipDirectory)
|
|
|
|
for (const modelName of models) {
|
|
logger.info(`Run benchmark with "${modelName}" model:`)
|
|
const model = new TranscriptionModel(modelName)
|
|
const uuid = buildSUUID()
|
|
const transcriptFile = await transcriber.transcribe({
|
|
mediaFilePath,
|
|
model,
|
|
transcriptDirectory,
|
|
language: 'fr',
|
|
format: 'txt',
|
|
runId: uuid
|
|
})
|
|
const evaluator = new TranscriptFileEvaluator(referenceTranscriptFile, transcriptFile)
|
|
await new Promise(resolve => setTimeout(resolve, 1))
|
|
|
|
benchmarkResults = benchmarkReducer(benchmarkResults, {
|
|
uuid,
|
|
engine: transcriber.engine,
|
|
WER: await evaluator.wer(),
|
|
CER: await evaluator.cer(),
|
|
model: model.name
|
|
})
|
|
}
|
|
}
|
|
|
|
// display
|
|
const benchmarkResultsGroupedByModel = Object
|
|
.keys(benchmarkResults)
|
|
.reduce(groupBenchmarkResultsByModel(benchmarkResults), {})
|
|
Object.values(benchmarkResultsGroupedByModel).forEach(benchmark => console.table(benchmark))
|
|
|
|
// after
|
|
await remove(transcriptDirectory)
|
|
performance.clearMarks()
|
|
})()
|