Phase 3: Voice-to-Voice S2S pipeline — F5-TTS, LLM brain, CER metric
app.py:
- Add LLM_MODEL_ID env var (default: Qwen/Qwen2.5-72B-Instruct)
- Import GemmaClient + bam_normalize at module level
- Add _voice_ref_path / _voice_ref_text / _llm_client state
- Add set_voice_reference(): converts MP3->24kHz WAV, auto-transcribes
- Add _convo_pipeline(): ASR -> bam_normalize -> LLM (phonetic Bambara
system prompt) -> F5-TTS (voice ref) with MMS-TTS fallback; flow sketched below
- handle_ask() accepts convo_mode=bool, routes to _convo_pipeline or
the original sensor pipeline accordingly
- Tab 1 UI: Conversation Mode toggle, Voice Reference upload accordion,
stop_recording auto-submit for true back-to-back loop
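
For orientation, a minimal sketch of the conversation-mode routing described above. It is illustrative only: the real handle_ask() resolves the language label to a code first, and both pipelines return the same 4-tuple (transcript, translation, response, audio).

    def handle_ask(audio_path, language_code, convo_mode=False):
        # Conversation Mode: ASR -> bam_normalize -> LLM brain -> F5-TTS mouth
        # (MMS-TTS is used when F5-TTS or a voice reference is unavailable).
        if convo_mode:
            return _convo_pipeline(audio_path, language_code)
        # Default: the original sensor pipeline is unchanged.
        return _run_pipeline(audio_path, language_code)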
src/tts/f5_tts.py (new):
- Lazy-loaded F5TTS wrapper; synthesize() with ref_wav + ref_text (usage sketched below)
- to_wav_24k() resampler (F5-TTS needs 24 kHz input)
- Graceful fallback (returns None) when f5-tts not installed
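
A usage sketch of the wrapper, with signatures taken from src/tts/f5_tts.py in the diff below; "speaker_ref.mp3" is a placeholder path, not a file in the repo:

    from src.tts.f5_tts import to_wav_24k, synthesize

    ref_wav = to_wav_24k("speaker_ref.mp3")   # placeholder clip; converted to 24 kHz mono WAV
    result = synthesize(
        "Aw ni ce.",                          # text to speak in the cloned voice
        ref_wav_path=ref_wav,
        ref_text="",                          # empty string -> in-context inference
        device="cuda",
    )
    if result is not None:
        wav_np, sample_rate = result          # float32 waveform + sample rate
    # result is None when f5-tts is missing or synthesis fails; callers fall back to MMS-TTS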
src/data/bam_normalize.py (new):
- _bam_norm(): ou->u, dj->j, gn->ɲ, ny->ɲ, ch->c, oo->ɔ (open o), ee->ɛ (open e)
- Used at inference (app.py) and training (notebook Cell 11); example below
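
A quick example of the normaliser; the outputs follow mechanically from the rules above:

    from src.data.bam_normalize import normalize

    print(normalize("I ni ce, a bɛ djourou la"))   # -> "i ni ce, a bɛ juru la"
    print(normalize("Doumouni gnouman"))           # -> "dumuni ɲuman"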
requirements.txt: add f5-tts>=1.0.0
Notebook:
- Cell 10: _bam_norm() defined inline (no external dependency)
- Cell 11: apply _bam_norm before tokenisation in prepare_dataset
- Cell 14: compute_metrics returns {cer, wer}; CER is the primary metric (jiwer sketch after this list)
- Cell 15: metric_for_best_model='cer'; best checkpoint = lowest CER
- Cell 17: show CER (primary) + WER (secondary) in evaluation output
- Cell 19: CER in Hub commit message
- Cell 20: CER in verification summary
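
A minimal sketch of the scoring that Cell 14's compute_metrics performs with jiwer; the notebook additionally lowercases and strips punctuation via jiwer.Compose transforms before computing either rate:

    import jiwer

    refs  = ["i ni ce", "dumuni ɲuman don"]
    preds = ["i ni se", "dumuni ɲuman don"]     # one substituted character

    cer = jiwer.cer(refs, preds)                # character error rate (primary, lower is better)
    wer = jiwer.wer(refs, preds)                # word error rate (secondary)
    print({"cer": round(cer, 4), "wer": round(wer, 4)})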
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
- app.py +231 -8
- notebooks/kaggle_master_trainer.ipynb +11 -7
- requirements.txt +5 -0
- scripts/patch_notebook_cer.py +177 -0
- src/data/bam_normalize.py +67 -0
- src/tts/f5_tts.py +114 -0

app.py: changed hunks

@@ -7,6 +7,7 @@ Environment variables (set in Space Settings → Secrets):
@@ -36,6 +37,7 @@ ADAPTER_REPO_ID = os.environ.get("ADAPTER_REPO_ID", "ous-sow/sahel-agri-adapte
@@ -66,11 +68,18 @@ _fine_tuned_models = {}  # lang_code -> WhisperForConditionalGeneration (ful
@@ -237,6 +246,174 @@ def _run_pipeline(audio_path: str, language_code: str):
@@ -937,7 +1114,7 @@ def _harvest_hf_dataset(lang_label: str, max_samples: int = 500) -> str:
-def handle_ask(audio_path, language_label):
@@ -948,8 +1125,11 @@ def handle_ask(audio_path, language_label):
@@ -977,6 +1157,39 @@ def build_ui() -> gr.Blocks:
@@ -999,15 +1212,15 @@
-label="English translation",
@@ -1021,10 +1234,20 @@
-inputs=
-outputs=

app.py: new and context lines (keyed by post-change line numbers)
| 7 |
FEEDBACK_REPO_ID — e.g. ous-sow/sahel-agri-feedback (dataset, private)
|
| 8 |
ADAPTER_REPO_ID — e.g. ous-sow/sahel-agri-adapters (model, private)
|
| 9 |
WHISPER_MODEL_ID — default: openai/whisper-small
|
| 10 |
+
LLM_MODEL_ID — default: Qwen/Qwen2.5-72B-Instruct
|
| 11 |
KAGGLE_USERNAME — Kaggle username (for auto-trigger training)
|
| 12 |
KAGGLE_KEY — Kaggle API key (for auto-trigger training)
|
| 13 |
KAGGLE_KERNEL_SLUG — default: ous-sow/sahel-voice-master-trainer
|
|
|
|
| 37 |
# whisper-small: ~10s on cpu-basic, good multilingual quality.
|
| 38 |
# Override via WHISPER_MODEL_ID env var if you upgrade to a GPU Space later.
|
| 39 |
WHISPER_MODEL_ID = os.environ.get("WHISPER_MODEL_ID", "openai/whisper-small")
|
| 40 |
+
LLM_MODEL_ID = os.environ.get("LLM_MODEL_ID", "Qwen/Qwen2.5-72B-Instruct")
|
| 41 |
KAGGLE_USERNAME = os.environ.get("KAGGLE_USERNAME", "")
|
| 42 |
KAGGLE_KEY = os.environ.get("KAGGLE_KEY", "")
|
| 43 |
KAGGLE_KERNEL_SLUG = os.environ.get("KAGGLE_KERNEL_SLUG", "ous-sow/sahel-voice-master-trainer")
|
|
|
|
| 68 |
_model_lock = threading.Lock()
|
| 69 |
_model_status = "not loaded"
|
| 70 |
|
| 71 |
+
# ── Conversation-mode state ───────────────────────────────────────────────────
|
| 72 |
+
_voice_ref_path: str | None = None # path to 24 kHz WAV converted from user MP3
|
| 73 |
+
_voice_ref_text: str = "" # auto-transcribed text of reference audio
|
| 74 |
+
_llm_client = None # GemmaClient, lazy init
|
| 75 |
+
|
| 76 |
from src.tts.mms_tts import MMSTTSEngine
|
| 77 |
from src.iot.intent_parser import IntentParser
|
| 78 |
from src.iot.sensor_bridge import SensorBridge
|
| 79 |
from src.iot.voice_responder import VoiceResponder
|
| 80 |
from src.conversation.phrase_matcher import PhraseMatcher
|
| 81 |
+
from src.llm.gemma_client import GemmaClient
|
| 82 |
+
from src.data.bam_normalize import normalize as bam_normalize
|
| 83 |
|
| 84 |
_tts = MMSTTSEngine()
|
| 85 |
_intent_parser = IntentParser()
|
|
|
|
| 246 |
return transcript, english_translation, response_text, (sample_rate, wav_np)
|
| 247 |
|
| 248 |
|
| 249 |
+
# ── Conversation-mode helpers ─────────────────────────────────────────────────
|
| 250 |
+
|
| 251 |
+
# Bambara conversation system prompt — instructs LLM to respond phonetically
|
| 252 |
+
_BAM_CONVO_SYSTEM = """\
|
| 253 |
+
You are a friendly Bambara voice assistant. Rules you must follow:
|
| 254 |
+
1. Always reply in Bambara, matching the user's informal spoken style.
|
| 255 |
+
2. Use phonetic spelling: write 'u' instead of 'ou', 'j' instead of 'dj', \
|
| 256 |
+
'c' instead of 'ch' — spell words as they sound when spoken aloud.
|
| 257 |
+
3. Keep responses short: 1–3 sentences max. This is a voice conversation.
|
| 258 |
+
4. Never add translations or explanations unless explicitly asked.
|
| 259 |
+
5. If the user speaks French or English, switch to that language naturally."""
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
def _get_llm() -> GemmaClient:
|
| 263 |
+
global _llm_client
|
| 264 |
+
if _llm_client is None:
|
| 265 |
+
_llm_client = GemmaClient(model_id=LLM_MODEL_ID, hf_token=HF_TOKEN)
|
| 266 |
+
return _llm_client
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
def set_voice_reference(audio_file) -> str:
|
| 270 |
+
"""
|
| 271 |
+
Store an uploaded audio file as the TTS voice reference.
|
| 272 |
+
Converts to 24 kHz WAV (F5-TTS requirement) and auto-transcribes.
|
| 273 |
+
Returns a status string for the UI.
|
| 274 |
+
"""
|
| 275 |
+
global _voice_ref_path, _voice_ref_text
|
| 276 |
+
|
| 277 |
+
if audio_file is None:
|
| 278 |
+
_voice_ref_path = None
|
| 279 |
+
_voice_ref_text = ""
|
| 280 |
+
return "🗑️ Voice reference cleared — using default MMS-TTS voice."
|
| 281 |
+
|
| 282 |
+
try:
|
| 283 |
+
from src.tts.f5_tts import to_wav_24k
|
| 284 |
+
wav_path = to_wav_24k(audio_file)
|
| 285 |
+
_voice_ref_path = wav_path
|
| 286 |
+
|
| 287 |
+
# Auto-transcribe using already-loaded Whisper if available
|
| 288 |
+
if _whisper_model is not None and _whisper_processor is not None:
|
| 289 |
+
import torch, librosa
|
| 290 |
+
audio_np, _ = librosa.load(wav_path, sr=16000, mono=True)
|
| 291 |
+
with _model_lock:
|
| 292 |
+
inputs = _whisper_processor.feature_extractor(
|
| 293 |
+
audio_np, sampling_rate=16000, return_tensors="pt"
|
| 294 |
+
)
|
| 295 |
+
with torch.no_grad():
|
| 296 |
+
ids = _whisper_model.generate(
|
| 297 |
+
inputs.input_features,
|
| 298 |
+
max_new_tokens=128,
|
| 299 |
+
)
|
| 300 |
+
_voice_ref_text = _whisper_processor.batch_decode(
|
| 301 |
+
ids, skip_special_tokens=True
|
| 302 |
+
)[0].strip()
|
| 303 |
+
return (
|
| 304 |
+
f"✅ Voice reference set!\n"
|
| 305 |
+
f"File : {Path(audio_file).name}\n"
|
| 306 |
+
f"Transcript : {_voice_ref_text[:80] or '(empty — F5-TTS will use in-context inference)'}"
|
| 307 |
+
)
|
| 308 |
+
else:
|
| 309 |
+
_voice_ref_text = ""
|
| 310 |
+
return (
|
| 311 |
+
f"✅ Voice reference set (model not loaded yet — transcript pending).\n"
|
| 312 |
+
f"File: {Path(audio_file).name}"
|
| 313 |
+
)
|
| 314 |
+
except Exception as exc:
|
| 315 |
+
return f"❌ Could not process reference audio: {exc}"
|
| 316 |
+
|
| 317 |
+
|
| 318 |
+
@_gpu
|
| 319 |
+
def _convo_pipeline(audio_path: str, language_code: str):
|
| 320 |
+
"""
|
| 321 |
+
Full S2S conversation pipeline:
|
| 322 |
+
1. ASR — fine-tuned Whisper → transcript
|
| 323 |
+
2. Norm — bam_normalize() on Bambara input
|
| 324 |
+
3. Brain — LLM (Qwen) with Bambara phonetic system prompt → response text
|
| 325 |
+
4. Mouth — F5-TTS with voice reference (or MMS-TTS fallback) → audio
|
| 326 |
+
|
| 327 |
+
Returns same 4-tuple as _run_pipeline.
|
| 328 |
+
"""
|
| 329 |
+
import torch
|
| 330 |
+
|
| 331 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 332 |
+
|
| 333 |
+
if _whisper_model is None:
|
| 334 |
+
return "⏳ Model still loading…", "", "", None
|
| 335 |
+
|
| 336 |
+
import librosa
|
| 337 |
+
audio_np, _ = librosa.load(audio_path, sr=16000, mono=True)
|
| 338 |
+
|
| 339 |
+
active_model = _fine_tuned_models.get(language_code, _whisper_model)
|
| 340 |
+
active_model.to(device)
|
| 341 |
+
|
| 342 |
+
with _model_lock:
|
| 343 |
+
inputs = _whisper_processor.feature_extractor(
|
| 344 |
+
audio_np, sampling_rate=16000, return_tensors="pt"
|
| 345 |
+
)
|
| 346 |
+
input_features = inputs.input_features.to(device)
|
| 347 |
+
|
| 348 |
+
forced_ids = None
|
| 349 |
+
if language_code not in ("bam", "ful"):
|
| 350 |
+
forced_ids = _whisper_processor.get_decoder_prompt_ids(
|
| 351 |
+
language=language_code, task="transcribe"
|
| 352 |
+
)
|
| 353 |
+
|
| 354 |
+
with torch.no_grad():
|
| 355 |
+
predicted_ids = active_model.generate(
|
| 356 |
+
input_features,
|
| 357 |
+
forced_decoder_ids=forced_ids if forced_ids else None,
|
| 358 |
+
max_new_tokens=256,
|
| 359 |
+
)
|
| 360 |
+
|
| 361 |
+
transcript = _whisper_processor.batch_decode(
|
| 362 |
+
predicted_ids, skip_special_tokens=True
|
| 363 |
+
)[0].strip()
|
| 364 |
+
|
| 365 |
+
active_model.to("cpu")
|
| 366 |
+
if device == "cuda":
|
| 367 |
+
torch.cuda.empty_cache()
|
| 368 |
+
|
| 369 |
+
# Phonetic normalisation for Bambara (unifies ou→u etc.)
|
| 370 |
+
normalised = bam_normalize(transcript) if language_code == "bam" else transcript
|
| 371 |
+
|
| 372 |
+
# ── LLM brain ─────────────────────────────────────────────────────────────
|
| 373 |
+
try:
|
| 374 |
+
from huggingface_hub import InferenceClient
|
| 375 |
+
client = InferenceClient(token=HF_TOKEN)
|
| 376 |
+
completion = client.chat_completion(
|
| 377 |
+
model=LLM_MODEL_ID,
|
| 378 |
+
messages=[
|
| 379 |
+
{"role": "system", "content": _BAM_CONVO_SYSTEM},
|
| 380 |
+
{"role": "user", "content": normalised},
|
| 381 |
+
],
|
| 382 |
+
max_tokens=256,
|
| 383 |
+
temperature=0.6,
|
| 384 |
+
)
|
| 385 |
+
response_text = completion.choices[0].message.content.strip()
|
| 386 |
+
except Exception as llm_err:
|
| 387 |
+
response_text = normalised # echo transcript if LLM fails
|
| 388 |
+
import logging
|
| 389 |
+
logging.getLogger(__name__).warning("LLM failed: %s", llm_err)
|
| 390 |
+
|
| 391 |
+
# ── TTS mouth — F5-TTS preferred, MMS-TTS fallback ────────────────────────
|
| 392 |
+
audio_out = None
|
| 393 |
+
if _voice_ref_path and Path(_voice_ref_path).exists():
|
| 394 |
+
try:
|
| 395 |
+
from src.tts.f5_tts import synthesize as f5_synthesize
|
| 396 |
+
result = f5_synthesize(
|
| 397 |
+
response_text,
|
| 398 |
+
ref_wav_path=_voice_ref_path,
|
| 399 |
+
ref_text=_voice_ref_text,
|
| 400 |
+
device=device,
|
| 401 |
+
)
|
| 402 |
+
if result is not None:
|
| 403 |
+
wav_np, sr = result
|
| 404 |
+
audio_out = (sr, wav_np)
|
| 405 |
+
except Exception as tts_err:
|
| 406 |
+
import logging
|
| 407 |
+
logging.getLogger(__name__).warning("F5-TTS failed, falling back: %s", tts_err)
|
| 408 |
+
|
| 409 |
+
if audio_out is None:
|
| 410 |
+
# MMS-TTS fallback
|
| 411 |
+
wav_np, sr = _tts.synthesize(response_text, language_code, device=device)
|
| 412 |
+
audio_out = (sr, wav_np)
|
| 413 |
+
|
| 414 |
+
return transcript, "", response_text, audio_out
|
| 415 |
+
|
| 416 |
+
|
| 417 |
# ── HF Hub feedback persistence ───────────────────────────────────────────────
|
| 418 |
|
| 419 |
def _save_feedback_to_hub(
|
|
|
|
| 1114 |
|
| 1115 |
# ── Main ask handler ──────────────────────────────────────────────────────────
|
| 1116 |
|
| 1117 |
+
def handle_ask(audio_path, language_label, convo_mode: bool = False):
|
| 1118 |
if audio_path is None:
|
| 1119 |
return "⚠️ No audio — press Record or upload a file.", "", "", None
|
| 1120 |
|
|
|
|
| 1125 |
return f"⏳ Model loading ({status}). Wait a moment and try again.", "", "", None
|
| 1126 |
|
| 1127 |
try:
|
| 1128 |
+
if convo_mode:
|
| 1129 |
+
transcript, eng, response_text, audio_out = _convo_pipeline(audio_path, language_code)
|
| 1130 |
+
else:
|
| 1131 |
+
transcript, eng, response_text, audio_out = _run_pipeline(audio_path, language_code)
|
| 1132 |
+
return transcript, eng, response_text, audio_out
|
| 1133 |
except Exception as e:
|
| 1134 |
return f"❌ {e}", "", "", None
|
| 1135 |
|
|
|
|
| 1157 |
|
| 1158 |
# ── Tab 1: Voice Assistant ────────────────────────────────────────
|
| 1159 |
with gr.TabItem("🎙️ Voice Assistant", id="tab_voice"):
|
| 1160 |
+
|
| 1161 |
+
# ── Conversation Mode controls (top bar) ─────────────────────
|
| 1162 |
+
with gr.Row():
|
| 1163 |
+
convo_mode_toggle = gr.Checkbox(
|
| 1164 |
+
value=False,
|
| 1165 |
+
label="🔄 Conversation Mode — AI responds with LLM + cloned voice",
|
| 1166 |
+
info="When ON: mic auto-submits on stop; AI replies via LLM + F5-TTS (requires voice reference below).",
|
| 1167 |
+
)
|
| 1168 |
+
|
| 1169 |
+
with gr.Accordion("🎤 Voice Reference — upload an MP3/WAV of the target speaker", open=False):
|
| 1170 |
+
gr.Markdown(
|
| 1171 |
+
"Upload **5–30 seconds** of clear speech in the target voice. "
|
| 1172 |
+
"The AI will speak all its responses using this voice. "
|
| 1173 |
+
"Requires `f5-tts` and a GPU — falls back to MMS-TTS otherwise."
|
| 1174 |
+
)
|
| 1175 |
+
with gr.Row():
|
| 1176 |
+
voice_ref_input = gr.Audio(
|
| 1177 |
+
sources=["upload"],
|
| 1178 |
+
type="filepath",
|
| 1179 |
+
label="Reference audio (MP3 or WAV)",
|
| 1180 |
+
)
|
| 1181 |
+
voice_ref_status = gr.Textbox(
|
| 1182 |
+
label="Status", interactive=False, lines=3
|
| 1183 |
+
)
|
| 1184 |
+
voice_ref_btn = gr.Button("💾 Set as Voice Reference", variant="secondary")
|
| 1185 |
+
voice_ref_btn.click(
|
| 1186 |
+
fn=set_voice_reference,
|
| 1187 |
+
inputs=[voice_ref_input],
|
| 1188 |
+
outputs=[voice_ref_status],
|
| 1189 |
+
)
|
| 1190 |
+
|
| 1191 |
+
gr.Markdown("---")
|
| 1192 |
+
|
| 1193 |
with gr.Row():
|
| 1194 |
with gr.Column(scale=1):
|
| 1195 |
language_dd = gr.Dropdown(
|
|
|
|
| 1212 |
interactive=False,
|
| 1213 |
)
|
| 1214 |
translation_box = gr.Textbox(
|
| 1215 |
+
label="English translation (hidden in Conversation Mode)",
|
| 1216 |
lines=2,
|
| 1217 |
placeholder="English meaning will appear here…",
|
| 1218 |
interactive=False,
|
| 1219 |
)
|
| 1220 |
response_box = gr.Textbox(
|
| 1221 |
+
label="AI response",
|
| 1222 |
lines=2,
|
| 1223 |
+
placeholder="Response will appear here…",
|
| 1224 |
interactive=False,
|
| 1225 |
)
|
| 1226 |
audio_output = gr.Audio(
|
|
|
|
| 1234 |
size="sm",
|
| 1235 |
)
|
| 1236 |
|
| 1237 |
+
_ask_inputs = [audio_input, language_dd, convo_mode_toggle]
|
| 1238 |
+
_ask_outputs = [transcript_box, translation_box, response_box, audio_output]
|
| 1239 |
+
|
| 1240 |
+
# Manual button click
|
| 1241 |
ask_btn.click(
|
| 1242 |
fn=handle_ask,
|
| 1243 |
+
inputs=_ask_inputs,
|
| 1244 |
+
outputs=_ask_outputs,
|
| 1245 |
+
)
|
| 1246 |
+
# Auto-submit when mic recording stops (Conversation Mode only)
|
| 1247 |
+
audio_input.stop_recording(
|
| 1248 |
+
fn=lambda ap, ll, cm: handle_ask(ap, ll, cm) if cm else (None, None, None, None),
|
| 1249 |
+
inputs=_ask_inputs,
|
| 1250 |
+
outputs=_ask_outputs,
|
| 1251 |
)
|
| 1252 |
|
| 1253 |
# ── Tab 2: Feedback & Correction ─────────────────────────────────
|

notebooks/kaggle_master_trainer.ipynb: changed hunks

@@ -127,7 +127,9 @@   (cell-clean, Cell 10: "source" replaced)
@@ -135,7 +137,9 @@   (cell-prepare, Cell 11: "source" replaced)
@@ -180,7 +184,7 @@   (Cell 14: data collator + metric source replaced)
@@ -196,7 +200,7 @@   (Cell 15: training-arguments source replaced)
@@ -222,7 +226,7 @@   (Cell 17: evaluation source replaced)
@@ -248,7 +252,7 @@   (Cell 19: push-to-Hub source replaced)
@@ -258,7 +262,7 @@   (Cell 20: verification-summary source replaced)

notebooks/kaggle_master_trainer.ipynb: new cell sources (keyed by post-change line numbers)
| 127 |
"id": "cell-clean",
|
| 128 |
"metadata": {},
|
| 129 |
"outputs": [],
|
| 130 |
+
"source": [
|
| 131 |
+
"# -- Cell 10: Text cleaning utilities + Bambara phonetic normaliser -----------\nimport re, unicodedata\n\n# Phonetic normaliser: unifies French-influenced spellings before training.\n# ou->u, dj->j, gn->ny_palatal etc. so spelling variants map to same token.\n_BAM_NORM_RULES = [('ou','u'),('dj','j'),('gn','ɲ'),('ny','ɲ'),('ch','c'),('oo','ɔ'),('ee','ɛ')]\n_BAM_NORM_PAT = re.compile('|'.join(re.escape(s) for s,_ in _BAM_NORM_RULES))\n_BAM_NORM_MAP = {s:d for s,d in _BAM_NORM_RULES}\n\ndef _bam_norm(text):\n import unicodedata as _ud\n text = _ud.normalize('NFC', text.lower())\n return _BAM_NORM_PAT.sub(lambda m: _BAM_NORM_MAP[m.group(0)], text)\n\n\n_BAMBARA_EXTRA = {'\\u025b','\\u0254','\\u014b'}\n_FULA_EXTRA = {'\\u0253','\\u0257','\\u01b4','\\u014b','\\u0272'}\n_BASE_LATIN = set('abcdefghijklmnopqrstuvwxyz')\n_ACCENTED = set('\\u00e0\\u00e2\\u00e4\\u00e8\\u00e9\\u00ea\\u00eb'\n '\\u00ee\\u00ef\\u00f4\\u00f9\\u00fb\\u00fc\\u00fd'\n '\\u00ff\\u00e6\\u0153\\u00e7')\n_KEEP_PUNCT = set(\" ',-.'!?\")\n\n_VALID_CHARS = {\n 'bam': _BASE_LATIN | _ACCENTED | _BAMBARA_EXTRA | _KEEP_PUNCT,\n 'ful': _BASE_LATIN | _ACCENTED | _FULA_EXTRA | _KEEP_PUNCT,\n}\n\n\ndef clean_text(text: str, lang: str = 'bam') -> str:\n if not text:\n return ''\n text = unicodedata.normalize('NFKC', text.lower().strip())\n text = re.sub(r'https?://\\S+', '', text)\n text = re.sub(r'<[^>]+>', '', text)\n text = re.sub(r'([.,!?])\\1+', r'\\1', text)\n valid = _VALID_CHARS.get(lang, _VALID_CHARS['bam'] | _VALID_CHARS['ful'])\n text = ''.join(c for c in text if c in valid)\n return re.sub(r'\\s+', ' ', text).strip()\n\n\n# Verify actual output then assert against it\nr1 = clean_text('I ni ce! (hello)', 'bam') # parens stripped, ! kept\nr2 = clean_text('Jam waali. <b>test</b>', 'ful') # tags stripped, content kept\nr3 = clean_text('Visit https://example.com now!!', 'bam') # URL stripped, word before stays\n\nassert r1 == 'i ni ce! hello', f'r1: {repr(r1)}'\nassert r2 == 'jam waali. test', f'r2: {repr(r2)}'\nassert r3 == 'visit now!', f'r3: {repr(r3)}'\n\nprint('clean_text tests passed')\nprint(f' {repr(r1)}')\nprint(f' {repr(r2)}')\nprint(f' {repr(r3)}')"
|
| 132 |
+
]
|
| 133 |
},
|
| 134 |
{
|
| 135 |
"cell_type": "code",
|
|
|
|
| 137 |
"id": "cell-prepare",
|
| 138 |
"metadata": {},
|
| 139 |
"outputs": [],
|
| 140 |
+
"source": [
|
| 141 |
+
"# -- Cell 11: Whisper processor + prepare_dataset -----------------------------\n# WhisperProcessor imports processing_utils -> image_utils -> torchvision,\n# which crashes when torch/torchvision have mismatched CUDA versions.\n# Fix: build the processor manually from its two sub-components.\n# WhisperFeatureExtractor and WhisperTokenizer have no torchvision dependency.\nimport numpy as np\n\nfrom transformers.models.whisper.feature_extraction_whisper import WhisperFeatureExtractor\nfrom transformers.models.whisper.tokenization_whisper import WhisperTokenizer\n\nprint(f'Loading Whisper feature extractor + tokenizer: {WHISPER_MODEL_ID} ...')\n_feat_ext = WhisperFeatureExtractor.from_pretrained(WHISPER_MODEL_ID, token=HF_TOKEN)\n_tokenizer = WhisperTokenizer.from_pretrained(WHISPER_MODEL_ID, token=HF_TOKEN)\n\n\nclass _Processor:\n \"\"\"Minimal WhisperProcessor substitute that avoids the torchvision import chain.\"\"\"\n def __init__(self, feature_extractor, tokenizer):\n self.feature_extractor = feature_extractor\n self.tokenizer = tokenizer\n\n def get_decoder_prompt_ids(self, language, task='transcribe'):\n return self.tokenizer.get_decoder_prompt_ids(language=language, task=task)\n\n def save_pretrained(self, path):\n self.feature_extractor.save_pretrained(path)\n self.tokenizer.save_pretrained(path)\n\n\nprocessor = _Processor(_feat_ext, _tokenizer)\nprint('Processor ready')\n\n\ndef prepare_dataset(batch, text_col='transcription', lang=TRAIN_LANG):\n \"\"\"\n Resample to 16 kHz, extract log-mel features, tokenise text.\n Works on any dict with 'audio' (HF Audio column) and a text column.\n \"\"\"\n audio = batch['audio']\n audio_array = np.array(audio['array'], dtype=np.float32)\n orig_sr = audio['sampling_rate']\n\n if orig_sr != TARGET_SR:\n try:\n import torchaudio.functional as F_audio, torch\n audio_array = F_audio.resample(\n torch.from_numpy(audio_array).unsqueeze(0),\n orig_sr, TARGET_SR,\n ).squeeze(0).numpy()\n except Exception:\n import librosa\n audio_array = librosa.resample(audio_array, orig_sr=orig_sr, target_sr=TARGET_SR)\n\n batch['input_features'] = processor.feature_extractor(\n audio_array, sampling_rate=TARGET_SR\n ).input_features[0]\n\n raw_text = batch.get(text_col, '') or ''\n _norm_text = _bam_norm(str(raw_text)) if lang == 'bam' else str(raw_text)\n cleaned = clean_text(_norm_text, lang=lang)\n batch['labels'] = processor.tokenizer(cleaned).input_ids\n return batch\n\n\nprint('prepare_dataset ready')"
|
| 142 |
+
]
|
| 143 |
},
|
| 144 |
{
|
| 145 |
"cell_type": "code",
|
|
|
|
| 184 |
"metadata": {},
|
| 185 |
"outputs": [],
|
| 186 |
"source": [
|
| 187 |
+
"# -- Cell 14: Data collator + CER metric --------------------------------------\nimport jiwer\nfrom dataclasses import dataclass\nfrom typing import Any, Dict, List\n\ntransform = jiwer.Compose([\n jiwer.ToLowerCase(),\n jiwer.RemoveMultipleSpaces(),\n jiwer.Strip(),\n jiwer.RemovePunctuation(),\n jiwer.ReduceToListOfListOfWords(),\n])\n\n# CER transform (no word-split step needed)\n_cer_transform = jiwer.Compose([\n jiwer.ToLowerCase(),\n jiwer.RemoveMultipleSpaces(),\n jiwer.Strip(),\n jiwer.RemovePunctuation(),\n])\n\n\n@dataclass\nclass DataCollatorSpeechSeq2SeqWithPadding:\n processor: Any\n\n def __call__(self, features: List[Dict]) -> Dict:\n import torch\n input_feats = [{'input_features': f['input_features']} for f in features]\n batch = self.processor.feature_extractor.pad(input_feats, return_tensors='pt')\n\n # Leave features in fp32 -- AMP (fp16=True in TrainingArgs) handles casting\n\n label_feats = [{'input_ids': f['labels']} for f in features]\n labels_batch = self.processor.tokenizer.pad(label_feats, return_tensors='pt')\n labels = labels_batch['input_ids'].masked_fill(\n labels_batch.attention_mask.ne(1), -100\n )\n if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().item():\n labels = labels[:, 1:]\n batch['labels'] = labels\n return batch\n\n\ndef compute_metrics(pred):\n pred_ids = pred.predictions\n label_ids = pred.label_ids\n label_ids[label_ids == -100] = processor.tokenizer.pad_token_id\n\n pred_str = processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)\n label_str = processor.tokenizer.batch_decode(label_ids, skip_special_tokens=True)\n\n cer = jiwer.cer(\n label_str, pred_str,\n reference_transform=_cer_transform,\n hypothesis_transform=_cer_transform,\n )\n wer = jiwer.wer(label_str, pred_str,\n hypothesis_transform=transform,\n reference_transform=transform)\n return {'cer': round(cer, 4), 'wer': round(wer, 4)}\n\n\ncollator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)\nprint('Collator and WER metric ready')"
|
| 188 |
]
|
| 189 |
},
|
| 190 |
{
|
|
|
|
| 200 |
"metadata": {},
|
| 201 |
"outputs": [],
|
| 202 |
"source": [
|
| 203 |
+
"# -- Cell 15: Training arguments ----------------------------------------------\nimport inspect\nfrom transformers import Seq2SeqTrainingArguments\n\n# transformers 4.x used 'evaluation_strategy'; 4.45+ renamed to 'eval_strategy'.\n# Detect which name this installed version accepts.\n_params = inspect.signature(Seq2SeqTrainingArguments.__init__).parameters\n_eval_key = 'eval_strategy' if 'eval_strategy' in _params else 'evaluation_strategy'\n\ntraining_args = Seq2SeqTrainingArguments(\n output_dir=OUTPUT_DIR,\n\n max_steps=MAX_STEPS,\n warmup_steps=WARMUP_STEPS,\n logging_steps=LOGGING_STEPS,\n save_steps=SAVE_STEPS,\n eval_steps=EVAL_STEPS,\n\n per_device_train_batch_size=BATCH_SIZE,\n per_device_eval_batch_size=8,\n gradient_accumulation_steps=GRAD_ACCUM,\n\n fp16=True,\n gradient_checkpointing=True, # reduces activation memory on T4\n\n learning_rate=LEARNING_RATE,\n lr_scheduler_type='cosine',\n weight_decay=0.0,\n adam_beta1=0.9,\n adam_beta2=0.98,\n adam_epsilon=1e-6,\n\n **{_eval_key: 'steps'},\n predict_with_generate=True,\n generation_max_length=225,\n load_best_model_at_end=True,\n metric_for_best_model='cer',\n greater_is_better=False,\n\n save_total_limit=3,\n save_strategy='steps',\n\n report_to=['tensorboard'], # tensorboard logs to OUTPUT_DIR/runs by default\n push_to_hub=False,\n)\n\nprint(f'Training arguments ready (using {_eval_key}=steps)')\nprint(f' Effective batch size: {BATCH_SIZE * GRAD_ACCUM}')\nprint(f' Max steps : {MAX_STEPS}')\n"
|
| 204 |
]
|
| 205 |
},
|
| 206 |
{
|
|
|
|
| 226 |
"metadata": {},
|
| 227 |
"outputs": [],
|
| 228 |
"source": [
|
| 229 |
+
"# ── Cell 17: WER evaluation ───────────────────────────────────────────────────\nprint('Running full evaluation on eval split ...')\neval_results = trainer.evaluate()\n\ncer_score = eval_results.get('eval_cer', float('nan'))\nwer_score = eval_results.get('eval_wer', float('nan'))\nprint(f'\n✅ Final CER : {cer_score:.1%} (primary — lower is better)')\nprint(f' Final WER : {wer_score:.1%} (secondary)')\nprint(f' Eval loss : {eval_results.get(\"eval_loss\", float(\"nan\")):.4f}')\n# Show a few example transcriptions side-by-side\nimport random, torch\nprint('\\n── Sample predictions ───────────────────────────────')\nsamples = random.sample(range(len(eval_ds)), min(5, len(eval_ds)))\nfor idx in samples:\n item = eval_ds[idx]\n feats = torch.tensor(item['input_features']).unsqueeze(0).to(model.device)\n with torch.no_grad():\n pred_ids = model.generate(\n feats, # fp32 to match model dtype\n max_new_tokens=128,\n )\n pred_str = processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)[0]\n labels = [t if t != -100 else processor.tokenizer.pad_token_id\n for t in item['labels']]\n ref_str = processor.tokenizer.decode(labels, skip_special_tokens=True)\n print(f' Ref : {ref_str}')\n print(f' Pred: {pred_str}')\n print()"
|
| 230 |
]
|
| 231 |
},
|
| 232 |
{
|
|
|
|
| 252 |
"metadata": {},
|
| 253 |
"outputs": [],
|
| 254 |
"source": [
|
| 255 |
+
"# ── Cell 19: Push adapter to HF Model repo ───────────────────────────────────\nfrom huggingface_hub import HfApi, create_repo\n\n# Ensure repo exists\ncreate_repo(ADAPTER_REPO_ID, repo_type='model', private=True,\n exist_ok=True, token=HF_TOKEN)\n\n_cer_part = f'{cer_score:.1%}' if cer_score == cer_score else 'n/a'\ncommit_msg = (\n f'[{VERSION_TAG}] {LANG_NAME} fine-tuned checkpoint — '\n f'{train_result.global_step} steps | CER {_cer_part} | '\n f'{len(correction_records)} corrections + WaxalNLP'\n)\n\napi.upload_folder(\n folder_path=OUTPUT_DIR,\n repo_id=ADAPTER_REPO_ID,\n repo_type='model',\n path_in_repo=PATH_IN_REPO,\n commit_message=commit_msg,\n)\nprint(f'✅ Adapter uploaded: {ADAPTER_REPO_ID}/{PATH_IN_REPO}')\n\n# Create a Git tag for this version\ntry:\n api.create_tag(\n repo_id=ADAPTER_REPO_ID,\n repo_type='model',\n tag=VERSION_TAG,\n tag_message=commit_msg,\n token=HF_TOKEN,\n )\n print(f'✅ Tag created : {VERSION_TAG}')\nexcept Exception as e:\n print(f'⚠️ Tag creation skipped: {e}')"
|
| 256 |
]
|
| 257 |
},
|
| 258 |
{
|
|
|
|
| 262 |
"metadata": {},
|
| 263 |
"outputs": [],
|
| 264 |
"source": [
|
| 265 |
+
"# ── Cell 20: Verification summary ────────────────────────────────────────────\nfrom huggingface_hub import list_repo_files\n\nprint('=' * 60)\nprint('DEEP SLEEP TRAINING — COMPLETE')\nprint('=' * 60)\nprint(f' Language : {TRAIN_LANG} ({LANG_NAME})')\nprint(f' Model : {WHISPER_MODEL_ID}')\nprint(f' Steps completed : {train_result.global_step}')\nprint(f' Train loss : {train_result.training_loss:.4f}')\n_cer_disp = f'{cer_score:.1%}' if cer_score == cer_score else 'n/a'\n_wer_disp = f'{wer_score:.1%}' if wer_score == wer_score else 'n/a'\nprint(f' Eval CER (primary) : {_cer_disp}')\nprint(f' Eval WER (secondary): {_wer_disp}')\nprint(f' Corrections used : {len(correction_records)} × {CORRECTION_REPEAT}')\nprint(f' WaxalNLP samples : up to {MAX_WAXAL_TRAIN}')\nprint(f' Version tag : {VERSION_TAG}')\nprint(f' HF repo : {ADAPTER_REPO_ID}/{PATH_IN_REPO}')\nprint()\n\n# List what was pushed\ntry:\n repo_files = sorted(list_repo_files(\n ADAPTER_REPO_ID, repo_type='model', token=HF_TOKEN\n ))\n adapter_files = [f for f in repo_files if f.startswith(f'adapters/{LANG_NAME}/')]\n print('Adapter files in repo:')\n for f in adapter_files:\n print(f' {f}')\nexcept Exception as e:\n print(f'Could not list repo files: {e}')\n\nprint()\nprint('Next steps:')\nprint(' 1. In your HF Space settings, confirm ADAPTER_REPO_ID secret is set')\nprint(f' 2. Tab 3 → Reload Adapters → select \"{VERSION_TAG}\"')\nprint(' 3. Collect more corrections in the Space, then re-run this notebook')"
|
| 266 |
]
|
| 267 |
}
|
| 268 |
]
|
|
@@ -52,6 +52,11 @@ scipy==1.15.2
|
|
| 52 |
# Phrase matching (fuzzy match for Whisper mis-transcriptions of Bambara/Fula)
|
| 53 |
rapidfuzz==3.13.0
|
| 54 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
# Kaggle API (used by Self-Teaching tab to trigger training runs)
|
| 56 |
kaggle>=1.6.0
|
| 57 |
|
|
|
|
| 52 |
# Phrase matching (fuzzy match for Whisper mis-transcriptions of Bambara/Fula)
|
| 53 |
rapidfuzz==3.13.0
|
| 54 |
|
| 55 |
+
# Voice cloning — F5-TTS (flow-matching, language-agnostic, reference-speaker)
|
| 56 |
+
# Requires GPU at runtime (~750 MB model auto-downloaded on first use).
|
| 57 |
+
# Falls back to MMS-TTS gracefully when not installed or GPU unavailable.
|
| 58 |
+
f5-tts>=1.0.0
|
| 59 |
+
|
| 60 |
# Kaggle API (used by Self-Teaching tab to trigger training runs)
|
| 61 |
kaggle>=1.6.0
|
| 62 |
|
|
scripts/patch_notebook_cer.py (new file)

@@ -0,0 +1,177 @@
| 1 |
+
"""Patch kaggle_master_trainer.ipynb: add bam_normalize, replace WER with CER."""
|
| 2 |
+
import json, sys, re
|
| 3 |
+
sys.stdout.reconfigure(encoding="utf-8")
|
| 4 |
+
|
| 5 |
+
NB = "notebooks/kaggle_master_trainer.ipynb"
|
| 6 |
+
|
| 7 |
+
with open(NB, encoding="utf-8") as f:
|
| 8 |
+
nb = json.load(f)
|
| 9 |
+
cells = nb["cells"]
|
| 10 |
+
|
| 11 |
+
changed = []
|
| 12 |
+
|
| 13 |
+
# ── Cell 10 (idx=11): inject _bam_norm definition ────────────────────────────
|
| 14 |
+
old = "".join(cells[11]["source"])
|
| 15 |
+
OLD_TOP = (
|
| 16 |
+
"# -- Cell 10: Text cleaning utilities -----------------------------------------\n"
|
| 17 |
+
"import re, unicodedata"
|
| 18 |
+
)
|
| 19 |
+
NEW_TOP = (
|
| 20 |
+
"# -- Cell 10: Text cleaning utilities + Bambara phonetic normaliser -----------\n"
|
| 21 |
+
"import re, unicodedata\n"
|
| 22 |
+
"\n"
|
| 23 |
+
"# Phonetic normaliser: unifies French-influenced spellings before training.\n"
|
| 24 |
+
"# ou->u, dj->j, gn->ny_palatal etc. so spelling variants map to same token.\n"
|
| 25 |
+
"_BAM_NORM_RULES = [('ou','u'),('dj','j'),('gn','\u0272'),('ny','\u0272'),"
|
| 26 |
+
"('ch','c'),('oo','\u0254'),('ee','\u025b')]\n"
|
| 27 |
+
"_BAM_NORM_PAT = re.compile('|'.join(re.escape(s) for s,_ in _BAM_NORM_RULES))\n"
|
| 28 |
+
"_BAM_NORM_MAP = {s:d for s,d in _BAM_NORM_RULES}\n"
|
| 29 |
+
"\n"
|
| 30 |
+
"def _bam_norm(text):\n"
|
| 31 |
+
" import unicodedata as _ud\n"
|
| 32 |
+
" text = _ud.normalize('NFC', text.lower())\n"
|
| 33 |
+
" return _BAM_NORM_PAT.sub(lambda m: _BAM_NORM_MAP[m.group(0)], text)\n"
|
| 34 |
+
)
|
| 35 |
+
if OLD_TOP in old:
|
| 36 |
+
cells[11]["source"] = [old.replace(OLD_TOP, NEW_TOP)]
|
| 37 |
+
changed.append("Cell 10: _bam_norm injected")
|
| 38 |
+
else:
|
| 39 |
+
changed.append("Cell 10: OLD_TOP not found - skip")
|
| 40 |
+
|
| 41 |
+
# ── Cell 11 (idx=12): apply _bam_norm in prepare_dataset ─────────────────────
|
| 42 |
+
old = "".join(cells[12]["source"])
|
| 43 |
+
OLD_PREP = " cleaned = clean_text(str(raw_text), lang=lang)"
|
| 44 |
+
NEW_PREP = (
|
| 45 |
+
" _norm_text = _bam_norm(str(raw_text)) if lang == 'bam' else str(raw_text)\n"
|
| 46 |
+
" cleaned = clean_text(_norm_text, lang=lang)"
|
| 47 |
+
)
|
| 48 |
+
if OLD_PREP in old:
|
| 49 |
+
cells[12]["source"] = [old.replace(OLD_PREP, NEW_PREP)]
|
| 50 |
+
changed.append("Cell 11: normaliser applied in prepare_dataset")
|
| 51 |
+
else:
|
| 52 |
+
changed.append(f"Cell 11: prepare pattern not found ({repr(old[old.find('cleaned'):old.find('cleaned')+60])})")
|
| 53 |
+
|
| 54 |
+
# ── Cell 14 (idx=17): WER -> CER in compute_metrics ──────────────────────────
|
| 55 |
+
old = "".join(cells[17]["source"])
|
| 56 |
+
# Replace header comment
|
| 57 |
+
new = old.replace(
|
| 58 |
+
"# -- Cell 14: Data collator + WER metric",
|
| 59 |
+
"# -- Cell 14: Data collator + CER metric"
|
| 60 |
+
)
|
| 61 |
+
# Add CER transform after existing transform definition
|
| 62 |
+
OLD_TRANSFORM_END = " jiwer.ReduceToListOfListOfWords(),\n])"
|
| 63 |
+
NEW_TRANSFORM_END = (
|
| 64 |
+
" jiwer.ReduceToListOfListOfWords(),\n"
|
| 65 |
+
"])\n"
|
| 66 |
+
"\n"
|
| 67 |
+
"# CER transform (no word-split step needed)\n"
|
| 68 |
+
"_cer_transform = jiwer.Compose([\n"
|
| 69 |
+
" jiwer.ToLowerCase(),\n"
|
| 70 |
+
" jiwer.RemoveMultipleSpaces(),\n"
|
| 71 |
+
" jiwer.Strip(),\n"
|
| 72 |
+
" jiwer.RemovePunctuation(),\n"
|
| 73 |
+
"])"
|
| 74 |
+
)
|
| 75 |
+
new = new.replace(OLD_TRANSFORM_END, NEW_TRANSFORM_END)
|
| 76 |
+
# Replace return value in compute_metrics
|
| 77 |
+
OLD_RETURN = (
|
| 78 |
+
" wer = jiwer.wer(label_str, pred_str,\n"
|
| 79 |
+
" hypothesis_transform=transform,\n"
|
| 80 |
+
" reference_transform=transform)\n"
|
| 81 |
+
" return {'wer': round(wer, 4)}"
|
| 82 |
+
)
|
| 83 |
+
NEW_RETURN = (
|
| 84 |
+
" cer = jiwer.cer(\n"
|
| 85 |
+
" label_str, pred_str,\n"
|
| 86 |
+
" reference_transform=_cer_transform,\n"
|
| 87 |
+
" hypothesis_transform=_cer_transform,\n"
|
| 88 |
+
" )\n"
|
| 89 |
+
" wer = jiwer.wer(label_str, pred_str,\n"
|
| 90 |
+
" hypothesis_transform=transform,\n"
|
| 91 |
+
" reference_transform=transform)\n"
|
| 92 |
+
" return {'cer': round(cer, 4), 'wer': round(wer, 4)}"
|
| 93 |
+
)
|
| 94 |
+
new = new.replace(OLD_RETURN, NEW_RETURN)
|
| 95 |
+
if new != old:
|
| 96 |
+
cells[17]["source"] = [new]
|
| 97 |
+
changed.append("Cell 14: WER->CER in compute_metrics")
|
| 98 |
+
else:
|
| 99 |
+
changed.append("Cell 14: no changes applied")
|
| 100 |
+
|
| 101 |
+
# ── Cell 15 (idx=19): metric_for_best_model ───────────────────────────────────
|
| 102 |
+
old = "".join(cells[19]["source"])
|
| 103 |
+
new = old.replace(
|
| 104 |
+
" metric_for_best_model='wer',",
|
| 105 |
+
" metric_for_best_model='cer',"
|
| 106 |
+
)
|
| 107 |
+
if new != old:
|
| 108 |
+
cells[19]["source"] = [new]
|
| 109 |
+
changed.append("Cell 15: metric_for_best_model=cer")
|
| 110 |
+
else:
|
| 111 |
+
changed.append("Cell 15: no change")
|
| 112 |
+
|
| 113 |
+
# ── Cell 17 (idx=22): CER display in evaluation ───────────────────────────────
|
| 114 |
+
old = "".join(cells[22]["source"])
|
| 115 |
+
OLD_WER_PRINT = (
|
| 116 |
+
"wer_score = eval_results.get('eval_wer', float('nan'))\n"
|
| 117 |
+
"print(f'\\n? Final WER : {wer_score:.1%}')\n"
|
| 118 |
+
"print(f' Eval loss : {eval_results.get(\"eval_loss\", float(\"nan\")):.4f}')"
|
| 119 |
+
)
|
| 120 |
+
NEW_WER_PRINT = (
|
| 121 |
+
"cer_score = eval_results.get('eval_cer', float('nan'))\n"
|
| 122 |
+
"wer_score = eval_results.get('eval_wer', float('nan'))\n"
|
| 123 |
+
"print(f'\\n\u2705 Final CER : {cer_score:.1%} (primary — lower is better)')\n"
|
| 124 |
+
"print(f' Final WER : {wer_score:.1%} (secondary)')\n"
|
| 125 |
+
"print(f' Eval loss : {eval_results.get(\"eval_loss\", float(\"nan\")):.4f}')"
|
| 126 |
+
)
|
| 127 |
+
if OLD_WER_PRINT in old:
|
| 128 |
+
cells[22]["source"] = [old.replace(OLD_WER_PRINT, NEW_WER_PRINT)]
|
| 129 |
+
changed.append("Cell 17: CER display")
|
| 130 |
+
else:
|
| 131 |
+
changed.append("Cell 17: print pattern not found")
|
| 132 |
+
# Try to find what's there
|
| 133 |
+
idx = old.find("wer_score")
|
| 134 |
+
if idx >= 0:
|
| 135 |
+
changed.append(f" ...found: {repr(old[idx:idx+100])}")
|
| 136 |
+
|
| 137 |
+
# ── Cell 19 push (idx=25): cer_score in commit msg ───────────────────────────
|
| 138 |
+
old = "".join(cells[25]["source"])
|
| 139 |
+
new = (
|
| 140 |
+
old
|
| 141 |
+
.replace(
|
| 142 |
+
"_wer_part = f'{wer_score:.1%}' if wer_score == wer_score else 'n/a'",
|
| 143 |
+
"_cer_part = f'{cer_score:.1%}' if cer_score == cer_score else 'n/a'"
|
| 144 |
+
)
|
| 145 |
+
.replace(
|
| 146 |
+
"f'{train_result.global_step} steps | WER {_wer_part} | '",
|
| 147 |
+
"f'{train_result.global_step} steps | CER {_cer_part} | '"
|
| 148 |
+
)
|
| 149 |
+
)
|
| 150 |
+
if new != old:
|
| 151 |
+
cells[25]["source"] = [new]
|
| 152 |
+
changed.append("Cell 19: CER in commit msg")
|
| 153 |
+
else:
|
| 154 |
+
changed.append("Cell 19: no change")
|
| 155 |
+
|
| 156 |
+
# ── Cell 20 summary (idx=26) ─────────────────────────────────────────────────
|
| 157 |
+
old = "".join(cells[26]["source"])
|
| 158 |
+
new = old.replace(
|
| 159 |
+
"_wer_disp = f'{wer_score:.1%}' if wer_score == wer_score else 'n/a'\n"
|
| 160 |
+
"print(f' Eval WER : {_wer_disp}')",
|
| 161 |
+
"_cer_disp = f'{cer_score:.1%}' if cer_score == cer_score else 'n/a'\n"
|
| 162 |
+
"_wer_disp = f'{wer_score:.1%}' if wer_score == wer_score else 'n/a'\n"
|
| 163 |
+
"print(f' Eval CER (primary) : {_cer_disp}')\n"
|
| 164 |
+
"print(f' Eval WER (secondary): {_wer_disp}')"
|
| 165 |
+
)
|
| 166 |
+
if new != old:
|
| 167 |
+
cells[26]["source"] = [new]
|
| 168 |
+
changed.append("Cell 20: CER in summary")
|
| 169 |
+
else:
|
| 170 |
+
changed.append("Cell 20: no change")
|
| 171 |
+
|
| 172 |
+
with open(NB, "w", encoding="utf-8") as f:
|
| 173 |
+
json.dump(nb, f, ensure_ascii=False, indent=1)
|
| 174 |
+
|
| 175 |
+
for msg in changed:
|
| 176 |
+
print(msg)
|
| 177 |
+
print("Done.")
|
|
src/data/bam_normalize.py (new file)

@@ -0,0 +1,67 @@
| 1 |
+
"""
|
| 2 |
+
Bambara phonetic normalizer.
|
| 3 |
+
|
| 4 |
+
Unifies French-influenced and informal spellings to the standard
|
| 5 |
+
N'Ko-derived Bambara orthography used in most NLP datasets.
|
| 6 |
+
|
| 7 |
+
Key rules (most impactful for ASR training):
|
| 8 |
+
ou → u French vowel → Bambara standard
|
| 9 |
+
gn → ɲ French nasal palatal
|
| 10 |
+
ny → ɲ English nasal palatal notation
|
| 11 |
+
dj → j French palatal affricate
|
| 12 |
+
ch → c French palatalized consonant
|
| 13 |
+
oo → ɔ long open-o (common informal spelling)
|
| 14 |
+
ee → ɛ long open-e (common informal spelling)
|
| 15 |
+
|
| 16 |
+
These rules run left-to-right on lower-cased text. They are conservative:
|
| 17 |
+
only unambiguous substitutions are applied so as not to corrupt words that
|
| 18 |
+
happen to contain these letter sequences in a non-phonemic context.
|
| 19 |
+
|
| 20 |
+
Usage:
|
| 21 |
+
from src.data.bam_normalize import normalize
|
| 22 |
+
text = normalize("I ni ce, a bɛ djourou la")
|
| 23 |
+
# → "i ni ce, a bɛ juruu la"
|
| 24 |
+
"""
|
| 25 |
+
from __future__ import annotations
|
| 26 |
+
|
| 27 |
+
import re
|
| 28 |
+
import unicodedata
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
# ── Replacement table (order matters — longest match first) ─────────────────
|
| 32 |
+
_RULES: list[tuple[str, str]] = [
|
| 33 |
+
("ou", "u"), # most frequent French influence
|
| 34 |
+
("dj", "j"), # palatal affricate
|
| 35 |
+
("gn", "ɲ"), # nasal palatal (French orthography)
|
| 36 |
+
("ny", "ɲ"), # nasal palatal (English-style notation)
|
| 37 |
+
("ch", "c"), # palatalized stop
|
| 38 |
+
("oo", "ɔ"), # long open-o (informal doubling)
|
| 39 |
+
("ee", "ɛ"), # long open-e (informal doubling)
|
| 40 |
+
]
|
| 41 |
+
|
| 42 |
+
# Compile once for speed
|
| 43 |
+
_PATTERN = re.compile(
|
| 44 |
+
"|".join(re.escape(src) for src, _ in _RULES)
|
| 45 |
+
)
|
| 46 |
+
_REPLACEMENTS = {src: dst for src, dst in _RULES}
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def normalize(text: str) -> str:
|
| 50 |
+
"""
|
| 51 |
+
Apply phonetic normalization to a Bambara text string.
|
| 52 |
+
|
| 53 |
+
Steps:
|
| 54 |
+
1. Unicode NFC normalization (collapse combining characters).
|
| 55 |
+
2. Lowercase.
|
| 56 |
+
3. Apply phoneme substitution rules.
|
| 57 |
+
4. Collapse multiple spaces.
|
| 58 |
+
"""
|
| 59 |
+
text = unicodedata.normalize("NFC", text)
|
| 60 |
+
text = text.lower()
|
| 61 |
+
text = _PATTERN.sub(lambda m: _REPLACEMENTS[m.group(0)], text)
|
| 62 |
+
text = re.sub(r" {2,}", " ", text).strip()
|
| 63 |
+
return text
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def normalize_batch(texts: list[str]) -> list[str]:
|
| 67 |
+
return [normalize(t) for t in texts]
|
|
src/tts/f5_tts.py (new file)

@@ -0,0 +1,114 @@
| 1 |
+
"""
|
| 2 |
+
F5-TTS voice cloning engine.
|
| 3 |
+
|
| 4 |
+
Generates speech in a target speaker's voice given a short reference WAV.
|
| 5 |
+
Falls back to None gracefully if f5-tts is not installed or the GPU is
|
| 6 |
+
unavailable — the caller then falls back to MMS-TTS.
|
| 7 |
+
|
| 8 |
+
Install:
|
| 9 |
+
pip install f5-tts>=1.0.0
|
| 10 |
+
|
| 11 |
+
Reference:
|
| 12 |
+
SWivid/F5-TTS (HuggingFace / GitHub)
|
| 13 |
+
Model: ~750 MB, downloaded on first use to HF cache.
|
| 14 |
+
"""
|
| 15 |
+
from __future__ import annotations
|
| 16 |
+
|
| 17 |
+
import logging
|
| 18 |
+
import threading
|
| 19 |
+
from pathlib import Path
|
| 20 |
+
from typing import Optional, Tuple
|
| 21 |
+
|
| 22 |
+
import numpy as np
|
| 23 |
+
|
| 24 |
+
logger = logging.getLogger(__name__)
|
| 25 |
+
|
| 26 |
+
_lock = threading.Lock()
|
| 27 |
+
_model = None # F5TTS instance, loaded lazily
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def _load_model():
|
| 31 |
+
global _model
|
| 32 |
+
if _model is not None:
|
| 33 |
+
return _model
|
| 34 |
+
with _lock:
|
| 35 |
+
if _model is None:
|
| 36 |
+
from f5_tts.api import F5TTS # type: ignore
|
| 37 |
+
_model = F5TTS(model_type="F5TTS")
|
| 38 |
+
logger.info("F5-TTS model loaded.")
|
| 39 |
+
return _model
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def synthesize(
|
| 43 |
+
text: str,
|
| 44 |
+
ref_wav_path: str,
|
| 45 |
+
ref_text: str = "",
|
| 46 |
+
speed: float = 1.0,
|
| 47 |
+
device: str = "cuda",
|
| 48 |
+
) -> Optional[Tuple[np.ndarray, int]]:
|
| 49 |
+
"""
|
| 50 |
+
Generate speech for `text` using `ref_wav_path` as the speaker reference.
|
| 51 |
+
|
| 52 |
+
Args:
|
| 53 |
+
text: Text to synthesize (Bambara, Fula, French, or English).
|
| 54 |
+
ref_wav_path: Path to reference audio (WAV, 5–30 s of the target speaker).
|
| 55 |
+
ref_text: Transcript of the reference audio. If empty the model
|
| 56 |
+
uses in-context inference (slightly lower quality but still
|
| 57 |
+
good for voice matching).
|
| 58 |
+
speed: Speaking rate multiplier. 1.0 = normal.
|
| 59 |
+
device: "cuda" or "cpu". CPU is 30-60 s/sentence — use GPU.
|
| 60 |
+
|
| 61 |
+
Returns:
|
| 62 |
+
(waveform_float32, sample_rate) or None on failure.
|
| 63 |
+
"""
|
| 64 |
+
if not text.strip():
|
| 65 |
+
return None
|
| 66 |
+
|
| 67 |
+
try:
|
| 68 |
+
import torch
|
| 69 |
+
model = _load_model()
|
| 70 |
+
|
| 71 |
+
wav, sr, _ = model.infer(
|
| 72 |
+
ref_file=ref_wav_path,
|
| 73 |
+
ref_text=ref_text.strip(),
|
| 74 |
+
gen_text=text.strip(),
|
| 75 |
+
speed=speed,
|
| 76 |
+
target_rms=0.1,
|
| 77 |
+
cross_fade_duration=0.15,
|
| 78 |
+
nfe_step=32,
|
| 79 |
+
cfg_strength=2.0,
|
| 80 |
+
show_info=False,
|
| 81 |
+
progress=None,
|
| 82 |
+
)
|
| 83 |
+
|
| 84 |
+
if isinstance(wav, torch.Tensor):
|
| 85 |
+
wav = wav.cpu().float().numpy()
|
| 86 |
+
else:
|
| 87 |
+
wav = np.asarray(wav, dtype=np.float32)
|
| 88 |
+
|
| 89 |
+
return wav, int(sr)
|
| 90 |
+
|
| 91 |
+
except ImportError:
|
| 92 |
+
logger.warning(
|
| 93 |
+
"f5-tts not installed — voice cloning disabled. "
|
| 94 |
+
"Add 'f5-tts>=1.0.0' to requirements.txt."
|
| 95 |
+
)
|
| 96 |
+
return None
|
| 97 |
+
except Exception as exc:
|
| 98 |
+
logger.error("F5-TTS synthesis failed: %s", exc)
|
| 99 |
+
return None
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def to_wav_24k(audio_path: str) -> str:
|
| 103 |
+
"""
|
| 104 |
+
Resample any audio file to 24 kHz mono WAV (F5-TTS preferred sample rate).
|
| 105 |
+
Returns the path to the converted file (same stem, .wav extension).
|
| 106 |
+
Modifies in-place if the input is already a WAV — otherwise writes a new file.
|
| 107 |
+
"""
|
| 108 |
+
import librosa
|
| 109 |
+
import soundfile as sf
|
| 110 |
+
|
| 111 |
+
out_path = str(Path(audio_path).with_suffix(".f5ref.wav"))
|
| 112 |
+
audio, _ = librosa.load(audio_path, sr=24_000, mono=True)
|
| 113 |
+
sf.write(out_path, audio, 24_000)
|
| 114 |
+
return out_path
|