| """ |
| Batch inference for TTS evaluation benchmarks (Seed-TTS-eval & CV3-eval). |
| |
| Uses the llama.cpp backend (MOSS-TTS-Delay). |
| |
| Expected benchmark layout (per case):: |
| |
| {benchmark_dir}/{task}/{case_id}/prompt.wav |
| {benchmark_dir}/{task}/{case_id}/label.txt |
| |
| Output layout:: |
| |
| {result_dir}/{task}/{case_id}/pred.wav |
| |
| Usage:: |
| |
| python scripts/batch_eval_llama_cpp.py \\ |
| --config configs/llama_cpp/default.yaml \\ |
| --benchmark-dir /path/to/eval/tts \\ |
| --result-dir results/my_run \\ |
| --tasks seed-tts-zeroshot-zh seed-tts-zeroshot-en |
| """ |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| import logging |
| import sys |
| import time |
| from dataclasses import dataclass |
| from pathlib import Path |
|
|
| import numpy as np |
| import soundfile as sf |
| from tqdm import tqdm |
|
|
| from moss_tts_delay.llama_cpp import LlamaCppPipeline, PipelineConfig |
| from moss_tts_delay.llama_cpp._constants import SAMPLE_RATE |
|
|
| log = logging.getLogger(__name__) |
|
|
# Seed-TTS-eval zero-shot voice-cloning tasks (zh / en / hard-zh splits).
SEED_TTS_TASKS = [
    "seed-tts-zeroshot-zh",
    "seed-tts-zeroshot-en",
    "seed-tts-zeroshot-hard-zh",
]


# CV3-eval tasks: cross-lingual and zero-shot splits, each with a "hard" variant.
CV3_TASKS = [
    "cv3-crosslingual-en",
    "cv3-crosslingual-hard-en",
    "cv3-zeroshot-en",
    "cv3-zeroshot-hard-en",
    "cv3-crosslingual-zh",
    "cv3-crosslingual-hard-zh",
    "cv3-zeroshot-zh",
    "cv3-zeroshot-hard-zh",
]


# Full task universe accepted by --tasks; "demo-*" are presumably small
# smoke-test sets — confirm against the benchmark directory layout.
ALL_TASKS = SEED_TTS_TASKS + CV3_TASKS + ["demo-zh", "demo-en"]


# Target synthesis language passed to the pipeline for each task.
# NOTE(review): the value always equals the task-name suffix; keep this map
# in sync when adding tasks.
TASK_LANGUAGE = {
    "seed-tts-zeroshot-zh": "zh",
    "seed-tts-zeroshot-en": "en",
    "seed-tts-zeroshot-hard-zh": "zh",
    "cv3-crosslingual-en": "en",
    "cv3-crosslingual-hard-en": "en",
    "cv3-zeroshot-en": "en",
    "cv3-zeroshot-hard-en": "en",
    "cv3-crosslingual-zh": "zh",
    "cv3-crosslingual-hard-zh": "zh",
    "cv3-zeroshot-zh": "zh",
    "cv3-zeroshot-hard-zh": "zh",
    "demo-zh": "zh",
    "demo-en": "en",
}
|
|
|
|
@dataclass
class CaseResult:
    """Outcome of one evaluation case, as recorded by :func:`run_batch`."""

    task: str  # task name, e.g. "seed-tts-zeroshot-zh"
    case_id: str  # case directory name under the task directory
    success: bool  # True if pred.wav was produced (or already existed)
    audio_duration: float = 0.0  # seconds of generated audio; 0.0 when skipped/failed
    generation_time: float = 0.0  # wall-clock seconds spent generating; 0.0 when skipped
    error: str = ""  # error description when success is False, else empty
|
|
|
|
def discover_cases(benchmark_dir: Path, tasks: list[str]) -> list[tuple[str, str, Path, str]]:
    """Collect evaluation cases from the benchmark directory.

    Each case is a ``{benchmark_dir}/{task}/{case_id}/`` directory holding a
    mandatory ``label.txt`` (target text) and an optional ``prompt.wav``
    (reference audio — its possible absence is handled by the caller).

    Args:
        benchmark_dir: Root directory with one sub-directory per task.
        tasks: Task names to scan (sub-directory names under *benchmark_dir*).

    Returns:
        ``(task, case_id, prompt_wav_path, text)`` tuples, ordered by case
        directory name within each task. Missing task directories and cases
        without ``label.txt`` are skipped with a warning.
    """
    cases: list[tuple[str, str, Path, str]] = []
    for task in tasks:
        task_dir = benchmark_dir / task
        if not task_dir.is_dir():
            log.warning("Task directory not found: %s", task_dir)
            continue
        for case_dir in sorted(task_dir.iterdir()):
            if not case_dir.is_dir():
                continue
            prompt_wav = case_dir / "prompt.wav"
            label_txt = case_dir / "label.txt"
            if not label_txt.exists():
                log.warning("Missing label.txt: %s", case_dir)
                continue
            # Explicit UTF-8: labels contain CJK text, and the platform
            # default encoding (e.g. cp1252 on Windows) would crash or garble.
            text = label_txt.read_text(encoding="utf-8").strip()
            cases.append((task, case_dir.name, prompt_wav, text))
    return cases
|
|
|
|
def run_batch(
    pipeline: LlamaCppPipeline,
    cases: list[tuple[str, str, Path, str]],
    result_dir: Path,
    max_cases: int = 0,
    skip_existing: bool = True,
) -> list[CaseResult]:
    """Synthesize audio for each case and write ``pred.wav`` files.

    Args:
        pipeline: Loaded TTS pipeline used for generation.
        cases: ``(task, case_id, prompt_wav_path, text)`` tuples as produced
            by :func:`discover_cases`.
        result_dir: Output root; each result lands in
            ``{result_dir}/{task}/{case_id}/pred.wav``.
        max_cases: If > 0, only the first *max_cases* cases are run.
        skip_existing: When True, a case whose ``pred.wav`` already exists is
            recorded as a success with zero durations and not regenerated.

    Returns:
        One :class:`CaseResult` per attempted case, in input order.
    """
    results: list[CaseResult] = []
    total = len(cases) if max_cases <= 0 else min(max_cases, len(cases))
    cases = cases[:total]

    log.info("Running %d evaluation cases, output -> %s", total, result_dir)

    pbar = tqdm(cases, desc="Evaluation", unit="case", total=total, dynamic_ncols=True)
    for i, (task, case_id, prompt_wav, text) in enumerate(pbar):
        pbar.set_postfix_str(f"{task}/{case_id}")
        out_dir = result_dir / task / case_id
        out_wav = out_dir / "pred.wav"

        if skip_existing and out_wav.exists():
            log.info("[%d/%d] %s/%s — skipped (exists)", i + 1, total, task, case_id)
            results.append(CaseResult(task=task, case_id=case_id, success=True))
            continue

        log.info("[%d/%d] %s/%s — %s", i + 1, total, task, case_id, text[:60])
        # perf_counter is monotonic; time.time() can jump under NTP/clock
        # adjustment and skew the elapsed/RTF measurement.
        t0 = time.perf_counter()

        try:
            lang = TASK_LANGUAGE.get(task)
            # A missing prompt.wav is allowed: generate without reference audio.
            ref_audio = str(prompt_wav) if prompt_wav.exists() else None

            waveform = pipeline.generate(
                text=text, reference_audio=ref_audio, language=lang,
            )
            elapsed = time.perf_counter() - t0

            if waveform.size == 0:
                results.append(CaseResult(
                    task=task, case_id=case_id, success=False,
                    generation_time=elapsed, error="empty waveform",
                ))
                continue

            out_dir.mkdir(parents=True, exist_ok=True)
            sf.write(str(out_wav), waveform, SAMPLE_RATE)
            audio_dur = len(waveform) / SAMPLE_RATE

            results.append(CaseResult(
                task=task, case_id=case_id, success=True,
                audio_duration=audio_dur, generation_time=elapsed,
            ))
            log.info(
                " -> %.2fs audio in %.2fs (RTF=%.2f)",
                audio_dur, elapsed, elapsed / max(audio_dur, 1e-6),
            )

        except Exception as e:
            elapsed = time.perf_counter() - t0
            # log.exception keeps the traceback, which log.error would drop.
            log.exception(" -> FAILED: %s", e)
            results.append(CaseResult(
                task=task, case_id=case_id, success=False,
                generation_time=elapsed, error=str(e),
            ))

    return results
|
|
|
|
def write_summary(results: list[CaseResult], result_dir: Path) -> None:
    """Aggregate per-case results, write JSON summary, and print a report.

    Writes ``{result_dir}/inference_summary.json`` with overall counts,
    per-task stats (including ``avg_rtf`` = generation time / audio time
    where audio was produced), and a list of failures.

    Args:
        results: Per-case outcomes from :func:`run_batch`.
        result_dir: Directory that receives ``inference_summary.json``.
    """
    succeeded = [r for r in results if r.success]
    failed = [r for r in results if not r.success]

    per_task: dict[str, dict] = {}
    for r in results:
        if r.task not in per_task:
            per_task[r.task] = {"total": 0, "success": 0, "failed": 0, "total_audio_s": 0.0, "total_gen_s": 0.0}
        per_task[r.task]["total"] += 1
        if r.success:
            per_task[r.task]["success"] += 1
            per_task[r.task]["total_audio_s"] += r.audio_duration
            per_task[r.task]["total_gen_s"] += r.generation_time
        else:
            per_task[r.task]["failed"] += 1

    # RTF is only meaningful when some audio was actually generated
    # (skipped-as-existing cases contribute zero audio time).
    for task, stats in per_task.items():
        if stats["total_audio_s"] > 0:
            stats["avg_rtf"] = round(stats["total_gen_s"] / stats["total_audio_s"], 3)

    summary = {
        "total_cases": len(results),
        "succeeded": len(succeeded),
        "failed": len(failed),
        "per_task": per_task,
    }

    if failed:
        summary["failures"] = [
            {"task": r.task, "case_id": r.case_id, "error": r.error}
            for r in failed
        ]

    summary_path = result_dir / "inference_summary.json"
    # ensure_ascii=False emits raw CJK characters, so the file handle must be
    # UTF-8 explicitly — the platform default (e.g. cp1252) would raise
    # UnicodeEncodeError or produce mojibake.
    with open(summary_path, "w", encoding="utf-8") as f:
        json.dump(summary, f, indent=2, ensure_ascii=False)
    log.info("Summary written to %s", summary_path)

    print("\n" + "=" * 60)
    print(" BATCH INFERENCE SUMMARY")
    print("=" * 60)
    print(f" Total: {len(results)}")
    print(f" Succeeded: {len(succeeded)}")
    print(f" Failed: {len(failed)}")
    for task, stats in per_task.items():
        rtf = stats.get("avg_rtf", "N/A")
        print(f" {task}: {stats['success']}/{stats['total']} RTF={rtf}")
    print("=" * 60 + "\n")
|
|
|
|
def main():
    """CLI entry point: parse args, run batch inference, write summary.

    Flow: load the pipeline YAML config, apply CLI sampling/runtime
    overrides, resolve the task list (--tasks beats --suite; neither means
    all tasks), discover cases, record run metadata, run the batch inside
    the pipeline context manager, and emit the summary.
    """
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
    )

    parser = argparse.ArgumentParser(
        description="Batch TTS evaluation (llama.cpp backend)",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument("--config", required=True, help="Pipeline YAML config")
    parser.add_argument(
        "--benchmark-dir",
        default="/inspire/hdd/project/embodied-multimodality/public/speech_generation/data/eval/tts",
    )
    parser.add_argument("--result-dir", required=True)
    parser.add_argument("--tasks", nargs="+", default=None)
    parser.add_argument("--suite", choices=["seed-tts", "cv3", "all"], default=None)
    parser.add_argument("--max-cases", type=int, default=0)
    parser.add_argument("--no-skip", action="store_true")

    # Sampling / runtime overrides; None means "keep the YAML value".
    parser.add_argument("--text-temp", type=float, default=None)
    parser.add_argument("--audio-temp", type=float, default=None)
    parser.add_argument("--audio-top-p", type=float, default=None)
    parser.add_argument("--audio-top-k", type=int, default=None)
    parser.add_argument("--audio-rep-penalty", type=float, default=None)
    parser.add_argument("--n-gpu-layers", type=int, default=None)
    parser.add_argument("--max-tokens", type=int, default=None)
    parser.add_argument("--heads-backend", choices=["auto", "numpy", "torch"], default=None)

    args = parser.parse_args()
    config = PipelineConfig.from_yaml(args.config)

    # Apply only the overrides the user actually passed.
    overrides = {
        "text_temperature": args.text_temp,
        "audio_temperature": args.audio_temp,
        "audio_top_p": args.audio_top_p,
        "audio_top_k": args.audio_top_k,
        "audio_repetition_penalty": args.audio_rep_penalty,
        "n_gpu_layers": args.n_gpu_layers,
        "max_new_tokens": args.max_tokens,
        "heads_backend": args.heads_backend,
    }
    for attr, value in overrides.items():
        if value is not None:
            setattr(config, attr, value)

    # --tasks takes precedence over --suite; with neither, run everything.
    if args.tasks:
        tasks = args.tasks
    elif args.suite == "seed-tts":
        tasks = SEED_TTS_TASKS
    elif args.suite == "cv3":
        tasks = CV3_TASKS
    else:
        tasks = ALL_TASKS

    for t in tasks:
        if t not in ALL_TASKS:
            log.error("Unknown task: %s. Valid tasks: %s", t, ALL_TASKS)
            sys.exit(1)

    benchmark_dir = Path(args.benchmark_dir)
    result_dir = Path(args.result_dir)
    result_dir.mkdir(parents=True, exist_ok=True)

    cases = discover_cases(benchmark_dir, tasks)
    if not cases:
        log.error("No cases found in %s for tasks %s", benchmark_dir, tasks)
        sys.exit(1)
    log.info("Discovered %d cases across %d tasks", len(cases), len(tasks))

    # Record the effective configuration so results stay reproducible.
    run_meta = {
        "config": args.config,
        "benchmark_dir": str(benchmark_dir),
        "tasks": tasks,
        "sampling": {
            "text_temperature": config.text_temperature,
            "text_top_p": config.text_top_p,
            "text_top_k": config.text_top_k,
            "audio_temperature": config.audio_temperature,
            "audio_top_p": config.audio_top_p,
            "audio_top_k": config.audio_top_k,
            "audio_repetition_penalty": config.audio_repetition_penalty,
        },
        "max_new_tokens": config.max_new_tokens,
        "backbone_gguf": config.backbone_gguf,
        "heads_backend": config.heads_backend,
    }
    # Explicit UTF-8 because ensure_ascii=False may emit non-ASCII characters,
    # which the platform default encoding might not be able to write.
    with open(result_dir / "run_meta.json", "w", encoding="utf-8") as f:
        json.dump(run_meta, f, indent=2, ensure_ascii=False)

    # Context manager guarantees pipeline resources are released even if
    # the batch raises.
    with LlamaCppPipeline(config) as pipeline:
        results = run_batch(
            pipeline, cases, result_dir,
            max_cases=args.max_cases,
            skip_existing=not args.no_skip,
        )

    write_summary(results, result_dir)
|
|