jefffffff9 (Claude Sonnet 4.6) committed
Commit 8952fff · 1 Parent(s): cd017e2

Phase 3: Voice-to-Voice S2S pipeline — F5-TTS, LLM brain, CER metric


app.py:
- Add LLM_MODEL_ID env var (default: Qwen/Qwen2.5-72B-Instruct)
- Import GemmaClient + bam_normalize at module level
- Add _voice_ref_path / _voice_ref_text / _llm_client state
- Add set_voice_reference(): converts MP3->24kHz WAV, auto-transcribes
- Add _convo_pipeline(): ASR -> bam_normalize -> LLM (phonetic Bambara
system prompt) -> F5-TTS (voice ref) with MMS-TTS fallback
- handle_ask() accepts convo_mode=bool, routes to _convo_pipeline or
the original sensor pipeline accordingly
- Tab 1 UI: Conversation Mode toggle, Voice Reference upload accordion,
stop_recording auto-submit for true back-to-back loop

src/tts/f5_tts.py (new):
- Lazy-loaded F5TTS wrapper; synthesize() with ref_wav + ref_text
- to_wav_24k() resampler (F5-TTS needs 24 kHz input)
- Graceful fallback (returns None) when f5-tts not installed

src/data/bam_normalize.py (new):
- _bam_norm(): ou->u, dj->j, gn/ny->ɲ, ch->c, oo->ɔ, ee->ɛ
- Used at inference (app.py) and training (notebook Cell 11)

requirements.txt: add f5-tts>=1.0.0

Notebook:
- Cell 10: _bam_norm() defined inline (no external dependency)
- Cell 11: apply _bam_norm before tokenisation in prepare_dataset
- Cell 14: compute_metrics returns {cer, wer}; CER is primary metric
- Cell 15: metric_for_best_model='cer'; best checkpoint = lowest CER
- Cell 17: show CER (primary) + WER (secondary) in evaluation output
- Cell 19: CER in Hub commit message
- Cell 20: CER in verification summary

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
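
At a glance, the new conversation path chains the pieces listed above: ASR, then bam_normalize, then the LLM, then F5-TTS with MMS-TTS as the fallback mouth. A minimal sketch, assuming only the module paths introduced in this commit; transcribe, llm_reply and mms_fallback are placeholders for the existing app.py machinery, and the real implementation is _convo_pipeline() in the diff below:

    # Illustrative glue only; see _convo_pipeline() in app.py for the actual code.
    from src.data.bam_normalize import normalize as bam_normalize
    from src.tts.f5_tts import synthesize as f5_synthesize

    def s2s_turn(audio_path, transcribe, llm_reply, mms_fallback,
                 ref_wav=None, ref_text=""):
        transcript = transcribe(audio_path)          # 1. ears: Whisper ASR
        normalised = bam_normalize(transcript)       # 2. unify Bambara spellings
        response = llm_reply(normalised)             # 3. brain: LLM reply
        audio = f5_synthesize(response, ref_wav_path=ref_wav,
                              ref_text=ref_text) if ref_wav else None
        if audio is None:                            # 4. mouth: F5-TTS, else MMS-TTS
            audio = mms_fallback(response)
        return transcript, response, audio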

app.py CHANGED
@@ -7,6 +7,7 @@ Environment variables (set in Space Settings → Secrets):
7
  FEEDBACK_REPO_ID — e.g. ous-sow/sahel-agri-feedback (dataset, private)
8
  ADAPTER_REPO_ID — e.g. ous-sow/sahel-agri-adapters (model, private)
9
  WHISPER_MODEL_ID — default: openai/whisper-small
 
10
  KAGGLE_USERNAME — Kaggle username (for auto-trigger training)
11
  KAGGLE_KEY — Kaggle API key (for auto-trigger training)
12
  KAGGLE_KERNEL_SLUG — default: ous-sow/sahel-voice-master-trainer
@@ -36,6 +37,7 @@ ADAPTER_REPO_ID = os.environ.get("ADAPTER_REPO_ID", "ous-sow/sahel-agri-adapte
36
  # whisper-small: ~10s on cpu-basic, good multilingual quality.
37
  # Override via WHISPER_MODEL_ID env var if you upgrade to a GPU Space later.
38
  WHISPER_MODEL_ID = os.environ.get("WHISPER_MODEL_ID", "openai/whisper-small")
 
39
  KAGGLE_USERNAME = os.environ.get("KAGGLE_USERNAME", "")
40
  KAGGLE_KEY = os.environ.get("KAGGLE_KEY", "")
41
  KAGGLE_KERNEL_SLUG = os.environ.get("KAGGLE_KERNEL_SLUG", "ous-sow/sahel-voice-master-trainer")
@@ -66,11 +68,18 @@ _fine_tuned_models = {} # lang_code -> WhisperForConditionalGeneration (ful
66
  _model_lock = threading.Lock()
67
  _model_status = "not loaded"
68
 
69
  from src.tts.mms_tts import MMSTTSEngine
70
  from src.iot.intent_parser import IntentParser
71
  from src.iot.sensor_bridge import SensorBridge
72
  from src.iot.voice_responder import VoiceResponder
73
  from src.conversation.phrase_matcher import PhraseMatcher
74
 
75
  _tts = MMSTTSEngine()
76
  _intent_parser = IntentParser()
@@ -237,6 +246,174 @@ def _run_pipeline(audio_path: str, language_code: str):
237
  return transcript, english_translation, response_text, (sample_rate, wav_np)
238
 
239
 
240
  # ── HF Hub feedback persistence ───────────────────────────────────────────────
241
 
242
  def _save_feedback_to_hub(
@@ -937,7 +1114,7 @@ def _harvest_hf_dataset(lang_label: str, max_samples: int = 500) -> str:
937
 
938
  # ── Main ask handler ──────────────────────────────────────────────────────────
939
 
940
- def handle_ask(audio_path, language_label):
941
  if audio_path is None:
942
  return "⚠️ No audio — press Record or upload a file.", "", "", None
943
 
@@ -948,8 +1125,11 @@ def handle_ask(audio_path, language_label):
948
  return f"⏳ Model loading ({status}). Wait a moment and try again.", "", "", None
949
 
950
  try:
951
- transcript, english_translation, response_text, audio_out = _run_pipeline(audio_path, language_code)
952
- return transcript, english_translation, response_text, audio_out
953
  except Exception as e:
954
  return f"❌ {e}", "", "", None
955
 
@@ -977,6 +1157,39 @@ def build_ui() -> gr.Blocks:
977
 
978
  # ── Tab 1: Voice Assistant ────────────────────────────────────────
979
  with gr.TabItem("🎙️ Voice Assistant", id="tab_voice"):
980
  with gr.Row():
981
  with gr.Column(scale=1):
982
  language_dd = gr.Dropdown(
@@ -999,15 +1212,15 @@ def build_ui() -> gr.Blocks:
999
  interactive=False,
1000
  )
1001
  translation_box = gr.Textbox(
1002
- label="English translation",
1003
  lines=2,
1004
  placeholder="English meaning will appear here…",
1005
  interactive=False,
1006
  )
1007
  response_box = gr.Textbox(
1008
- label="Response in your language",
1009
  lines=2,
1010
- placeholder="Agricultural advice will appear here…",
1011
  interactive=False,
1012
  )
1013
  audio_output = gr.Audio(
@@ -1021,10 +1234,20 @@ def build_ui() -> gr.Blocks:
1021
  size="sm",
1022
  )
1023
 
1024
  ask_btn.click(
1025
  fn=handle_ask,
1026
- inputs=[audio_input, language_dd],
1027
- outputs=[transcript_box, translation_box, response_box, audio_output],
1028
  )
1029
 
1030
  # ── Tab 2: Feedback & Correction ─────────────────────────────────
 
7
  FEEDBACK_REPO_ID — e.g. ous-sow/sahel-agri-feedback (dataset, private)
8
  ADAPTER_REPO_ID — e.g. ous-sow/sahel-agri-adapters (model, private)
9
  WHISPER_MODEL_ID — default: openai/whisper-small
10
+ LLM_MODEL_ID — default: Qwen/Qwen2.5-72B-Instruct
11
  KAGGLE_USERNAME — Kaggle username (for auto-trigger training)
12
  KAGGLE_KEY — Kaggle API key (for auto-trigger training)
13
  KAGGLE_KERNEL_SLUG — default: ous-sow/sahel-voice-master-trainer
 
37
  # whisper-small: ~10s on cpu-basic, good multilingual quality.
38
  # Override via WHISPER_MODEL_ID env var if you upgrade to a GPU Space later.
39
  WHISPER_MODEL_ID = os.environ.get("WHISPER_MODEL_ID", "openai/whisper-small")
40
+ LLM_MODEL_ID = os.environ.get("LLM_MODEL_ID", "Qwen/Qwen2.5-72B-Instruct")
41
  KAGGLE_USERNAME = os.environ.get("KAGGLE_USERNAME", "")
42
  KAGGLE_KEY = os.environ.get("KAGGLE_KEY", "")
43
  KAGGLE_KERNEL_SLUG = os.environ.get("KAGGLE_KERNEL_SLUG", "ous-sow/sahel-voice-master-trainer")
 
68
  _model_lock = threading.Lock()
69
  _model_status = "not loaded"
70
 
71
+ # ── Conversation-mode state ───────────────────────────────────────────────────
72
+ _voice_ref_path: str | None = None # path to 24 kHz WAV converted from user MP3
73
+ _voice_ref_text: str = "" # auto-transcribed text of reference audio
74
+ _llm_client = None # GemmaClient, lazy init
75
+
76
  from src.tts.mms_tts import MMSTTSEngine
77
  from src.iot.intent_parser import IntentParser
78
  from src.iot.sensor_bridge import SensorBridge
79
  from src.iot.voice_responder import VoiceResponder
80
  from src.conversation.phrase_matcher import PhraseMatcher
81
+ from src.llm.gemma_client import GemmaClient
82
+ from src.data.bam_normalize import normalize as bam_normalize
83
 
84
  _tts = MMSTTSEngine()
85
  _intent_parser = IntentParser()
 
246
  return transcript, english_translation, response_text, (sample_rate, wav_np)
247
 
248
 
249
+ # ── Conversation-mode helpers ─────────────────────────────────────────────────
250
+
251
+ # Bambara conversation system prompt — instructs LLM to respond phonetically
252
+ _BAM_CONVO_SYSTEM = """\
253
+ You are a friendly Bambara voice assistant. Rules you must follow:
254
+ 1. Always reply in Bambara, matching the user's informal spoken style.
255
+ 2. Use phonetic spelling: write 'u' instead of 'ou', 'j' instead of 'dj', \
256
+ 'c' instead of 'ch' — spell words as they sound when spoken aloud.
257
+ 3. Keep responses short: 1–3 sentences max. This is a voice conversation.
258
+ 4. Never add translations or explanations unless explicitly asked.
259
+ 5. If the user speaks French or English, switch to that language naturally."""
260
+
261
+
262
+ def _get_llm() -> GemmaClient:
263
+ global _llm_client
264
+ if _llm_client is None:
265
+ _llm_client = GemmaClient(model_id=LLM_MODEL_ID, hf_token=HF_TOKEN)
266
+ return _llm_client
267
+
268
+
269
+ def set_voice_reference(audio_file) -> str:
270
+ """
271
+ Store an uploaded audio file as the TTS voice reference.
272
+ Converts to 24 kHz WAV (F5-TTS requirement) and auto-transcribes.
273
+ Returns a status string for the UI.
274
+ """
275
+ global _voice_ref_path, _voice_ref_text
276
+
277
+ if audio_file is None:
278
+ _voice_ref_path = None
279
+ _voice_ref_text = ""
280
+ return "🗑️ Voice reference cleared — using default MMS-TTS voice."
281
+
282
+ try:
283
+ from src.tts.f5_tts import to_wav_24k
284
+ wav_path = to_wav_24k(audio_file)
285
+ _voice_ref_path = wav_path
286
+
287
+ # Auto-transcribe using already-loaded Whisper if available
288
+ if _whisper_model is not None and _whisper_processor is not None:
289
+ import torch, librosa
290
+ audio_np, _ = librosa.load(wav_path, sr=16000, mono=True)
291
+ with _model_lock:
292
+ inputs = _whisper_processor.feature_extractor(
293
+ audio_np, sampling_rate=16000, return_tensors="pt"
294
+ )
295
+ with torch.no_grad():
296
+ ids = _whisper_model.generate(
297
+ inputs.input_features,
298
+ max_new_tokens=128,
299
+ )
300
+ _voice_ref_text = _whisper_processor.batch_decode(
301
+ ids, skip_special_tokens=True
302
+ )[0].strip()
303
+ return (
304
+ f"✅ Voice reference set!\n"
305
+ f"File : {Path(audio_file).name}\n"
306
+ f"Transcript : {_voice_ref_text[:80] or '(empty — F5-TTS will use in-context inference)'}"
307
+ )
308
+ else:
309
+ _voice_ref_text = ""
310
+ return (
311
+ f"✅ Voice reference set (model not loaded yet — transcript pending).\n"
312
+ f"File: {Path(audio_file).name}"
313
+ )
314
+ except Exception as exc:
315
+ return f"❌ Could not process reference audio: {exc}"
316
+
317
+
318
+ @_gpu
319
+ def _convo_pipeline(audio_path: str, language_code: str):
320
+ """
321
+ Full S2S conversation pipeline:
322
+ 1. ASR — fine-tuned Whisper → transcript
323
+ 2. Norm — bam_normalize() on Bambara input
324
+ 3. Brain — LLM (Qwen) with Bambara phonetic system prompt → response text
325
+ 4. Mouth — F5-TTS with voice reference (or MMS-TTS fallback) → audio
326
+
327
+ Returns same 4-tuple as _run_pipeline.
328
+ """
329
+ import torch
330
+
331
+ device = "cuda" if torch.cuda.is_available() else "cpu"
332
+
333
+ if _whisper_model is None:
334
+ return "⏳ Model still loading…", "", "", None
335
+
336
+ import librosa
337
+ audio_np, _ = librosa.load(audio_path, sr=16000, mono=True)
338
+
339
+ active_model = _fine_tuned_models.get(language_code, _whisper_model)
340
+ active_model.to(device)
341
+
342
+ with _model_lock:
343
+ inputs = _whisper_processor.feature_extractor(
344
+ audio_np, sampling_rate=16000, return_tensors="pt"
345
+ )
346
+ input_features = inputs.input_features.to(device)
347
+
348
+ forced_ids = None
349
+ if language_code not in ("bam", "ful"):
350
+ forced_ids = _whisper_processor.get_decoder_prompt_ids(
351
+ language=language_code, task="transcribe"
352
+ )
353
+
354
+ with torch.no_grad():
355
+ predicted_ids = active_model.generate(
356
+ input_features,
357
+ forced_decoder_ids=forced_ids if forced_ids else None,
358
+ max_new_tokens=256,
359
+ )
360
+
361
+ transcript = _whisper_processor.batch_decode(
362
+ predicted_ids, skip_special_tokens=True
363
+ )[0].strip()
364
+
365
+ active_model.to("cpu")
366
+ if device == "cuda":
367
+ torch.cuda.empty_cache()
368
+
369
+ # Phonetic normalisation for Bambara (unifies ou→u etc.)
370
+ normalised = bam_normalize(transcript) if language_code == "bam" else transcript
371
+
372
+ # ── LLM brain ─────────────────────────────────────────────────────────────
373
+ try:
374
+ from huggingface_hub import InferenceClient
375
+ client = InferenceClient(token=HF_TOKEN)
376
+ completion = client.chat_completion(
377
+ model=LLM_MODEL_ID,
378
+ messages=[
379
+ {"role": "system", "content": _BAM_CONVO_SYSTEM},
380
+ {"role": "user", "content": normalised},
381
+ ],
382
+ max_tokens=256,
383
+ temperature=0.6,
384
+ )
385
+ response_text = completion.choices[0].message.content.strip()
386
+ except Exception as llm_err:
387
+ response_text = normalised # echo transcript if LLM fails
388
+ import logging
389
+ logging.getLogger(__name__).warning("LLM failed: %s", llm_err)
390
+
391
+ # ── TTS mouth — F5-TTS preferred, MMS-TTS fallback ────────────────────────
392
+ audio_out = None
393
+ if _voice_ref_path and Path(_voice_ref_path).exists():
394
+ try:
395
+ from src.tts.f5_tts import synthesize as f5_synthesize
396
+ result = f5_synthesize(
397
+ response_text,
398
+ ref_wav_path=_voice_ref_path,
399
+ ref_text=_voice_ref_text,
400
+ device=device,
401
+ )
402
+ if result is not None:
403
+ wav_np, sr = result
404
+ audio_out = (sr, wav_np)
405
+ except Exception as tts_err:
406
+ import logging
407
+ logging.getLogger(__name__).warning("F5-TTS failed, falling back: %s", tts_err)
408
+
409
+ if audio_out is None:
410
+ # MMS-TTS fallback
411
+ wav_np, sr = _tts.synthesize(response_text, language_code, device=device)
412
+ audio_out = (sr, wav_np)
413
+
414
+ return transcript, "", response_text, audio_out
415
+
416
+
417
  # ── HF Hub feedback persistence ───────────────────────────────────────────────
418
 
419
  def _save_feedback_to_hub(
 
1114
 
1115
  # ── Main ask handler ──────────────────────────────────────────────────────────
1116
 
1117
+ def handle_ask(audio_path, language_label, convo_mode: bool = False):
1118
  if audio_path is None:
1119
  return "⚠️ No audio — press Record or upload a file.", "", "", None
1120
 
 
1125
  return f"⏳ Model loading ({status}). Wait a moment and try again.", "", "", None
1126
 
1127
  try:
1128
+ if convo_mode:
1129
+ transcript, eng, response_text, audio_out = _convo_pipeline(audio_path, language_code)
1130
+ else:
1131
+ transcript, eng, response_text, audio_out = _run_pipeline(audio_path, language_code)
1132
+ return transcript, eng, response_text, audio_out
1133
  except Exception as e:
1134
  return f"❌ {e}", "", "", None
1135
 
 
1157
 
1158
  # ── Tab 1: Voice Assistant ────────────────────────────────────────
1159
  with gr.TabItem("🎙️ Voice Assistant", id="tab_voice"):
1160
+
1161
+ # ── Conversation Mode controls (top bar) ─────────────────────
1162
+ with gr.Row():
1163
+ convo_mode_toggle = gr.Checkbox(
1164
+ value=False,
1165
+ label="🔄 Conversation Mode — AI responds with LLM + cloned voice",
1166
+ info="When ON: mic auto-submits on stop; AI replies via LLM + F5-TTS (requires voice reference below).",
1167
+ )
1168
+
1169
+ with gr.Accordion("🎤 Voice Reference — upload an MP3/WAV of the target speaker", open=False):
1170
+ gr.Markdown(
1171
+ "Upload **5–30 seconds** of clear speech in the target voice. "
1172
+ "The AI will speak all its responses using this voice. "
1173
+ "Requires `f5-tts` and a GPU — falls back to MMS-TTS otherwise."
1174
+ )
1175
+ with gr.Row():
1176
+ voice_ref_input = gr.Audio(
1177
+ sources=["upload"],
1178
+ type="filepath",
1179
+ label="Reference audio (MP3 or WAV)",
1180
+ )
1181
+ voice_ref_status = gr.Textbox(
1182
+ label="Status", interactive=False, lines=3
1183
+ )
1184
+ voice_ref_btn = gr.Button("💾 Set as Voice Reference", variant="secondary")
1185
+ voice_ref_btn.click(
1186
+ fn=set_voice_reference,
1187
+ inputs=[voice_ref_input],
1188
+ outputs=[voice_ref_status],
1189
+ )
1190
+
1191
+ gr.Markdown("---")
1192
+
1193
  with gr.Row():
1194
  with gr.Column(scale=1):
1195
  language_dd = gr.Dropdown(
 
1212
  interactive=False,
1213
  )
1214
  translation_box = gr.Textbox(
1215
+ label="English translation (hidden in Conversation Mode)",
1216
  lines=2,
1217
  placeholder="English meaning will appear here…",
1218
  interactive=False,
1219
  )
1220
  response_box = gr.Textbox(
1221
+ label="AI response",
1222
  lines=2,
1223
+ placeholder="Response will appear here…",
1224
  interactive=False,
1225
  )
1226
  audio_output = gr.Audio(
 
1234
  size="sm",
1235
  )
1236
 
1237
+ _ask_inputs = [audio_input, language_dd, convo_mode_toggle]
1238
+ _ask_outputs = [transcript_box, translation_box, response_box, audio_output]
1239
+
1240
+ # Manual button click
1241
  ask_btn.click(
1242
  fn=handle_ask,
1243
+ inputs=_ask_inputs,
1244
+ outputs=_ask_outputs,
1245
+ )
1246
+ # Auto-submit when mic recording stops (Conversation Mode only)
1247
+ audio_input.stop_recording(
1248
+ fn=lambda ap, ll, cm: handle_ask(ap, ll, cm) if cm else (None, None, None, None),
1249
+ inputs=_ask_inputs,
1250
+ outputs=_ask_outputs,
1251
  )
1252
 
1253
  # ── Tab 2: Feedback & Correction ─────────────────────────────────
notebooks/kaggle_master_trainer.ipynb CHANGED
@@ -127,7 +127,9 @@
127
  "id": "cell-clean",
128
  "metadata": {},
129
  "outputs": [],
130
- "source": "# -- Cell 10: Text cleaning utilities -----------------------------------------\nimport re, unicodedata\n\n_BAMBARA_EXTRA = {'\\u025b','\\u0254','\\u014b'}\n_FULA_EXTRA = {'\\u0253','\\u0257','\\u01b4','\\u014b','\\u0272'}\n_BASE_LATIN = set('abcdefghijklmnopqrstuvwxyz')\n_ACCENTED = set('\\u00e0\\u00e2\\u00e4\\u00e8\\u00e9\\u00ea\\u00eb'\n '\\u00ee\\u00ef\\u00f4\\u00f9\\u00fb\\u00fc\\u00fd'\n '\\u00ff\\u00e6\\u0153\\u00e7')\n_KEEP_PUNCT = set(\" ',-.'!?\")\n\n_VALID_CHARS = {\n 'bam': _BASE_LATIN | _ACCENTED | _BAMBARA_EXTRA | _KEEP_PUNCT,\n 'ful': _BASE_LATIN | _ACCENTED | _FULA_EXTRA | _KEEP_PUNCT,\n}\n\n\ndef clean_text(text: str, lang: str = 'bam') -> str:\n if not text:\n return ''\n text = unicodedata.normalize('NFKC', text.lower().strip())\n text = re.sub(r'https?://\\S+', '', text)\n text = re.sub(r'<[^>]+>', '', text)\n text = re.sub(r'([.,!?])\\1+', r'\\1', text)\n valid = _VALID_CHARS.get(lang, _VALID_CHARS['bam'] | _VALID_CHARS['ful'])\n text = ''.join(c for c in text if c in valid)\n return re.sub(r'\\s+', ' ', text).strip()\n\n\n# Verify actual output then assert against it\nr1 = clean_text('I ni ce! (hello)', 'bam') # parens stripped, ! kept\nr2 = clean_text('Jam waali. <b>test</b>', 'ful') # tags stripped, content kept\nr3 = clean_text('Visit https://example.com now!!', 'bam') # URL stripped, word before stays\n\nassert r1 == 'i ni ce! hello', f'r1: {repr(r1)}'\nassert r2 == 'jam waali. test', f'r2: {repr(r2)}'\nassert r3 == 'visit now!', f'r3: {repr(r3)}'\n\nprint('clean_text tests passed')\nprint(f' {repr(r1)}')\nprint(f' {repr(r2)}')\nprint(f' {repr(r3)}')"
131
  },
132
  {
133
  "cell_type": "code",
@@ -135,7 +137,9 @@
135
  "id": "cell-prepare",
136
  "metadata": {},
137
  "outputs": [],
138
- "source": "# -- Cell 11: Whisper processor + prepare_dataset -----------------------------\n# WhisperProcessor imports processing_utils -> image_utils -> torchvision,\n# which crashes when torch/torchvision have mismatched CUDA versions.\n# Fix: build the processor manually from its two sub-components.\n# WhisperFeatureExtractor and WhisperTokenizer have no torchvision dependency.\nimport numpy as np\n\nfrom transformers.models.whisper.feature_extraction_whisper import WhisperFeatureExtractor\nfrom transformers.models.whisper.tokenization_whisper import WhisperTokenizer\n\nprint(f'Loading Whisper feature extractor + tokenizer: {WHISPER_MODEL_ID} ...')\n_feat_ext = WhisperFeatureExtractor.from_pretrained(WHISPER_MODEL_ID, token=HF_TOKEN)\n_tokenizer = WhisperTokenizer.from_pretrained(WHISPER_MODEL_ID, token=HF_TOKEN)\n\n\nclass _Processor:\n \"\"\"Minimal WhisperProcessor substitute that avoids the torchvision import chain.\"\"\"\n def __init__(self, feature_extractor, tokenizer):\n self.feature_extractor = feature_extractor\n self.tokenizer = tokenizer\n\n def get_decoder_prompt_ids(self, language, task='transcribe'):\n return self.tokenizer.get_decoder_prompt_ids(language=language, task=task)\n\n def save_pretrained(self, path):\n self.feature_extractor.save_pretrained(path)\n self.tokenizer.save_pretrained(path)\n\n\nprocessor = _Processor(_feat_ext, _tokenizer)\nprint('Processor ready')\n\n\ndef prepare_dataset(batch, text_col='transcription', lang=TRAIN_LANG):\n \"\"\"\n Resample to 16 kHz, extract log-mel features, tokenise text.\n Works on any dict with 'audio' (HF Audio column) and a text column.\n \"\"\"\n audio = batch['audio']\n audio_array = np.array(audio['array'], dtype=np.float32)\n orig_sr = audio['sampling_rate']\n\n if orig_sr != TARGET_SR:\n try:\n import torchaudio.functional as F_audio, torch\n audio_array = F_audio.resample(\n torch.from_numpy(audio_array).unsqueeze(0),\n orig_sr, TARGET_SR,\n ).squeeze(0).numpy()\n except Exception:\n import librosa\n audio_array = librosa.resample(audio_array, orig_sr=orig_sr, target_sr=TARGET_SR)\n\n batch['input_features'] = processor.feature_extractor(\n audio_array, sampling_rate=TARGET_SR\n ).input_features[0]\n\n raw_text = batch.get(text_col, '') or ''\n cleaned = clean_text(str(raw_text), lang=lang)\n batch['labels'] = processor.tokenizer(cleaned).input_ids\n return batch\n\n\nprint('prepare_dataset ready')"
139
  },
140
  {
141
  "cell_type": "code",
@@ -180,7 +184,7 @@
180
  "metadata": {},
181
  "outputs": [],
182
  "source": [
183
- "# -- Cell 14: Data collator + WER metric --------------------------------------\nimport jiwer\nfrom dataclasses import dataclass\nfrom typing import Any, Dict, List\n\ntransform = jiwer.Compose([\n jiwer.ToLowerCase(),\n jiwer.RemoveMultipleSpaces(),\n jiwer.Strip(),\n jiwer.RemovePunctuation(),\n jiwer.ReduceToListOfListOfWords(),\n])\n\n\n@dataclass\nclass DataCollatorSpeechSeq2SeqWithPadding:\n processor: Any\n\n def __call__(self, features: List[Dict]) -> Dict:\n import torch\n input_feats = [{'input_features': f['input_features']} for f in features]\n batch = self.processor.feature_extractor.pad(input_feats, return_tensors='pt')\n\n # Leave features in fp32 -- AMP (fp16=True in TrainingArgs) handles casting\n\n label_feats = [{'input_ids': f['labels']} for f in features]\n labels_batch = self.processor.tokenizer.pad(label_feats, return_tensors='pt')\n labels = labels_batch['input_ids'].masked_fill(\n labels_batch.attention_mask.ne(1), -100\n )\n if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().item():\n labels = labels[:, 1:]\n batch['labels'] = labels\n return batch\n\n\ndef compute_metrics(pred):\n pred_ids = pred.predictions\n label_ids = pred.label_ids\n label_ids[label_ids == -100] = processor.tokenizer.pad_token_id\n\n pred_str = processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)\n label_str = processor.tokenizer.batch_decode(label_ids, skip_special_tokens=True)\n\n wer = jiwer.wer(label_str, pred_str,\n hypothesis_transform=transform,\n reference_transform=transform)\n return {'wer': round(wer, 4)}\n\n\ncollator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)\nprint('Collator and WER metric ready')"
184
  ]
185
  },
186
  {
@@ -196,7 +200,7 @@
196
  "metadata": {},
197
  "outputs": [],
198
  "source": [
199
- "# -- Cell 15: Training arguments ----------------------------------------------\nimport inspect\nfrom transformers import Seq2SeqTrainingArguments\n\n# transformers 4.x used 'evaluation_strategy'; 4.45+ renamed to 'eval_strategy'.\n# Detect which name this installed version accepts.\n_params = inspect.signature(Seq2SeqTrainingArguments.__init__).parameters\n_eval_key = 'eval_strategy' if 'eval_strategy' in _params else 'evaluation_strategy'\n\ntraining_args = Seq2SeqTrainingArguments(\n output_dir=OUTPUT_DIR,\n\n max_steps=MAX_STEPS,\n warmup_steps=WARMUP_STEPS,\n logging_steps=LOGGING_STEPS,\n save_steps=SAVE_STEPS,\n eval_steps=EVAL_STEPS,\n\n per_device_train_batch_size=BATCH_SIZE,\n per_device_eval_batch_size=8,\n gradient_accumulation_steps=GRAD_ACCUM,\n\n fp16=True,\n gradient_checkpointing=True, # reduces activation memory on T4\n\n learning_rate=LEARNING_RATE,\n lr_scheduler_type='cosine',\n weight_decay=0.0,\n adam_beta1=0.9,\n adam_beta2=0.98,\n adam_epsilon=1e-6,\n\n **{_eval_key: 'steps'},\n predict_with_generate=True,\n generation_max_length=225,\n load_best_model_at_end=True,\n metric_for_best_model='wer',\n greater_is_better=False,\n\n save_total_limit=3,\n save_strategy='steps',\n\n report_to=['tensorboard'], # tensorboard logs to OUTPUT_DIR/runs by default\n push_to_hub=False,\n)\n\nprint(f'Training arguments ready (using {_eval_key}=steps)')\nprint(f' Effective batch size: {BATCH_SIZE * GRAD_ACCUM}')\nprint(f' Max steps : {MAX_STEPS}')\n"
200
  ]
201
  },
202
  {
@@ -222,7 +226,7 @@
222
  "metadata": {},
223
  "outputs": [],
224
  "source": [
225
- "# ── Cell 17: WER evaluation ───────────────────────────────────────────────────\nprint('Running full evaluation on eval split ...')\neval_results = trainer.evaluate()\n\nwer_score = eval_results.get('eval_wer', float('nan'))\nprint(f'\\n📊 Final WER : {wer_score:.1%}')\nprint(f' Eval loss : {eval_results.get(\"eval_loss\", float(\"nan\")):.4f}')\n\n# Show a few example transcriptions side-by-side\nimport random, torch\nprint('\\n── Sample predictions ───────────────────────────────')\nsamples = random.sample(range(len(eval_ds)), min(5, len(eval_ds)))\nfor idx in samples:\n item = eval_ds[idx]\n feats = torch.tensor(item['input_features']).unsqueeze(0).to(model.device)\n with torch.no_grad():\n pred_ids = model.generate(\n feats, # fp32 to match model dtype\n max_new_tokens=128,\n )\n pred_str = processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)[0]\n labels = [t if t != -100 else processor.tokenizer.pad_token_id\n for t in item['labels']]\n ref_str = processor.tokenizer.decode(labels, skip_special_tokens=True)\n print(f' Ref : {ref_str}')\n print(f' Pred: {pred_str}')\n print()"
226
  ]
227
  },
228
  {
@@ -248,7 +252,7 @@
248
  "metadata": {},
249
  "outputs": [],
250
  "source": [
251
- "# ── Cell 19: Push adapter to HF Model repo ───────────────────────────────────\nfrom huggingface_hub import HfApi, create_repo\n\n# Ensure repo exists\ncreate_repo(ADAPTER_REPO_ID, repo_type='model', private=True,\n exist_ok=True, token=HF_TOKEN)\n\n_wer_part = f'{wer_score:.1%}' if wer_score == wer_score else 'n/a'\ncommit_msg = (\n f'[{VERSION_TAG}] {LANG_NAME} fine-tuned checkpoint — '\n f'{train_result.global_step} steps | WER {_wer_part} | '\n f'{len(correction_records)} corrections + WaxalNLP'\n)\n\napi.upload_folder(\n folder_path=OUTPUT_DIR,\n repo_id=ADAPTER_REPO_ID,\n repo_type='model',\n path_in_repo=PATH_IN_REPO,\n commit_message=commit_msg,\n)\nprint(f'✅ Adapter uploaded: {ADAPTER_REPO_ID}/{PATH_IN_REPO}')\n\n# Create a Git tag for this version\ntry:\n api.create_tag(\n repo_id=ADAPTER_REPO_ID,\n repo_type='model',\n tag=VERSION_TAG,\n tag_message=commit_msg,\n token=HF_TOKEN,\n )\n print(f'✅ Tag created : {VERSION_TAG}')\nexcept Exception as e:\n print(f'⚠️ Tag creation skipped: {e}')"
252
  ]
253
  },
254
  {
@@ -258,7 +262,7 @@
258
  "metadata": {},
259
  "outputs": [],
260
  "source": [
261
- "# ── Cell 20: Verification summary ────────────────────────────────────────────\nfrom huggingface_hub import list_repo_files\n\nprint('=' * 60)\nprint('DEEP SLEEP TRAINING — COMPLETE')\nprint('=' * 60)\nprint(f' Language : {TRAIN_LANG} ({LANG_NAME})')\nprint(f' Model : {WHISPER_MODEL_ID}')\nprint(f' Steps completed : {train_result.global_step}')\nprint(f' Train loss : {train_result.training_loss:.4f}')\n_wer_disp = f'{wer_score:.1%}' if wer_score == wer_score else 'n/a'\nprint(f' Eval WER : {_wer_disp}')\nprint(f' Corrections used : {len(correction_records)} × {CORRECTION_REPEAT}')\nprint(f' WaxalNLP samples : up to {MAX_WAXAL_TRAIN}')\nprint(f' Version tag : {VERSION_TAG}')\nprint(f' HF repo : {ADAPTER_REPO_ID}/{PATH_IN_REPO}')\nprint()\n\n# List what was pushed\ntry:\n repo_files = sorted(list_repo_files(\n ADAPTER_REPO_ID, repo_type='model', token=HF_TOKEN\n ))\n adapter_files = [f for f in repo_files if f.startswith(f'adapters/{LANG_NAME}/')]\n print('Adapter files in repo:')\n for f in adapter_files:\n print(f' {f}')\nexcept Exception as e:\n print(f'Could not list repo files: {e}')\n\nprint()\nprint('Next steps:')\nprint(' 1. In your HF Space settings, confirm ADAPTER_REPO_ID secret is set')\nprint(f' 2. Tab 3 → Reload Adapters → select \"{VERSION_TAG}\"')\nprint(' 3. Collect more corrections in the Space, then re-run this notebook')"
262
  ]
263
  }
264
  ]
 
127
  "id": "cell-clean",
128
  "metadata": {},
129
  "outputs": [],
130
+ "source": [
131
+ "# -- Cell 10: Text cleaning utilities + Bambara phonetic normaliser -----------\nimport re, unicodedata\n\n# Phonetic normaliser: unifies French-influenced spellings before training.\n# ou->u, dj->j, gn->ny_palatal etc. so spelling variants map to same token.\n_BAM_NORM_RULES = [('ou','u'),('dj','j'),('gn','ɲ'),('ny','ɲ'),('ch','c'),('oo','ɔ'),('ee','ɛ')]\n_BAM_NORM_PAT = re.compile('|'.join(re.escape(s) for s,_ in _BAM_NORM_RULES))\n_BAM_NORM_MAP = {s:d for s,d in _BAM_NORM_RULES}\n\ndef _bam_norm(text):\n import unicodedata as _ud\n text = _ud.normalize('NFC', text.lower())\n return _BAM_NORM_PAT.sub(lambda m: _BAM_NORM_MAP[m.group(0)], text)\n\n\n_BAMBARA_EXTRA = {'\\u025b','\\u0254','\\u014b'}\n_FULA_EXTRA = {'\\u0253','\\u0257','\\u01b4','\\u014b','\\u0272'}\n_BASE_LATIN = set('abcdefghijklmnopqrstuvwxyz')\n_ACCENTED = set('\\u00e0\\u00e2\\u00e4\\u00e8\\u00e9\\u00ea\\u00eb'\n '\\u00ee\\u00ef\\u00f4\\u00f9\\u00fb\\u00fc\\u00fd'\n '\\u00ff\\u00e6\\u0153\\u00e7')\n_KEEP_PUNCT = set(\" ',-.'!?\")\n\n_VALID_CHARS = {\n 'bam': _BASE_LATIN | _ACCENTED | _BAMBARA_EXTRA | _KEEP_PUNCT,\n 'ful': _BASE_LATIN | _ACCENTED | _FULA_EXTRA | _KEEP_PUNCT,\n}\n\n\ndef clean_text(text: str, lang: str = 'bam') -> str:\n if not text:\n return ''\n text = unicodedata.normalize('NFKC', text.lower().strip())\n text = re.sub(r'https?://\\S+', '', text)\n text = re.sub(r'<[^>]+>', '', text)\n text = re.sub(r'([.,!?])\\1+', r'\\1', text)\n valid = _VALID_CHARS.get(lang, _VALID_CHARS['bam'] | _VALID_CHARS['ful'])\n text = ''.join(c for c in text if c in valid)\n return re.sub(r'\\s+', ' ', text).strip()\n\n\n# Verify actual output then assert against it\nr1 = clean_text('I ni ce! (hello)', 'bam') # parens stripped, ! kept\nr2 = clean_text('Jam waali. <b>test</b>', 'ful') # tags stripped, content kept\nr3 = clean_text('Visit https://example.com now!!', 'bam') # URL stripped, word before stays\n\nassert r1 == 'i ni ce! hello', f'r1: {repr(r1)}'\nassert r2 == 'jam waali. test', f'r2: {repr(r2)}'\nassert r3 == 'visit now!', f'r3: {repr(r3)}'\n\nprint('clean_text tests passed')\nprint(f' {repr(r1)}')\nprint(f' {repr(r2)}')\nprint(f' {repr(r3)}')"
132
+ ]
133
  },
134
  {
135
  "cell_type": "code",
 
137
  "id": "cell-prepare",
138
  "metadata": {},
139
  "outputs": [],
140
+ "source": [
141
+ "# -- Cell 11: Whisper processor + prepare_dataset -----------------------------\n# WhisperProcessor imports processing_utils -> image_utils -> torchvision,\n# which crashes when torch/torchvision have mismatched CUDA versions.\n# Fix: build the processor manually from its two sub-components.\n# WhisperFeatureExtractor and WhisperTokenizer have no torchvision dependency.\nimport numpy as np\n\nfrom transformers.models.whisper.feature_extraction_whisper import WhisperFeatureExtractor\nfrom transformers.models.whisper.tokenization_whisper import WhisperTokenizer\n\nprint(f'Loading Whisper feature extractor + tokenizer: {WHISPER_MODEL_ID} ...')\n_feat_ext = WhisperFeatureExtractor.from_pretrained(WHISPER_MODEL_ID, token=HF_TOKEN)\n_tokenizer = WhisperTokenizer.from_pretrained(WHISPER_MODEL_ID, token=HF_TOKEN)\n\n\nclass _Processor:\n \"\"\"Minimal WhisperProcessor substitute that avoids the torchvision import chain.\"\"\"\n def __init__(self, feature_extractor, tokenizer):\n self.feature_extractor = feature_extractor\n self.tokenizer = tokenizer\n\n def get_decoder_prompt_ids(self, language, task='transcribe'):\n return self.tokenizer.get_decoder_prompt_ids(language=language, task=task)\n\n def save_pretrained(self, path):\n self.feature_extractor.save_pretrained(path)\n self.tokenizer.save_pretrained(path)\n\n\nprocessor = _Processor(_feat_ext, _tokenizer)\nprint('Processor ready')\n\n\ndef prepare_dataset(batch, text_col='transcription', lang=TRAIN_LANG):\n \"\"\"\n Resample to 16 kHz, extract log-mel features, tokenise text.\n Works on any dict with 'audio' (HF Audio column) and a text column.\n \"\"\"\n audio = batch['audio']\n audio_array = np.array(audio['array'], dtype=np.float32)\n orig_sr = audio['sampling_rate']\n\n if orig_sr != TARGET_SR:\n try:\n import torchaudio.functional as F_audio, torch\n audio_array = F_audio.resample(\n torch.from_numpy(audio_array).unsqueeze(0),\n orig_sr, TARGET_SR,\n ).squeeze(0).numpy()\n except Exception:\n import librosa\n audio_array = librosa.resample(audio_array, orig_sr=orig_sr, target_sr=TARGET_SR)\n\n batch['input_features'] = processor.feature_extractor(\n audio_array, sampling_rate=TARGET_SR\n ).input_features[0]\n\n raw_text = batch.get(text_col, '') or ''\n _norm_text = _bam_norm(str(raw_text)) if lang == 'bam' else str(raw_text)\n cleaned = clean_text(_norm_text, lang=lang)\n batch['labels'] = processor.tokenizer(cleaned).input_ids\n return batch\n\n\nprint('prepare_dataset ready')"
142
+ ]
143
  },
144
  {
145
  "cell_type": "code",
 
184
  "metadata": {},
185
  "outputs": [],
186
  "source": [
187
+ "# -- Cell 14: Data collator + CER metric --------------------------------------\nimport jiwer\nfrom dataclasses import dataclass\nfrom typing import Any, Dict, List\n\ntransform = jiwer.Compose([\n jiwer.ToLowerCase(),\n jiwer.RemoveMultipleSpaces(),\n jiwer.Strip(),\n jiwer.RemovePunctuation(),\n jiwer.ReduceToListOfListOfWords(),\n])\n\n# CER transform (no word-split step needed)\n_cer_transform = jiwer.Compose([\n jiwer.ToLowerCase(),\n jiwer.RemoveMultipleSpaces(),\n jiwer.Strip(),\n jiwer.RemovePunctuation(),\n])\n\n\n@dataclass\nclass DataCollatorSpeechSeq2SeqWithPadding:\n processor: Any\n\n def __call__(self, features: List[Dict]) -> Dict:\n import torch\n input_feats = [{'input_features': f['input_features']} for f in features]\n batch = self.processor.feature_extractor.pad(input_feats, return_tensors='pt')\n\n # Leave features in fp32 -- AMP (fp16=True in TrainingArgs) handles casting\n\n label_feats = [{'input_ids': f['labels']} for f in features]\n labels_batch = self.processor.tokenizer.pad(label_feats, return_tensors='pt')\n labels = labels_batch['input_ids'].masked_fill(\n labels_batch.attention_mask.ne(1), -100\n )\n if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().item():\n labels = labels[:, 1:]\n batch['labels'] = labels\n return batch\n\n\ndef compute_metrics(pred):\n pred_ids = pred.predictions\n label_ids = pred.label_ids\n label_ids[label_ids == -100] = processor.tokenizer.pad_token_id\n\n pred_str = processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)\n label_str = processor.tokenizer.batch_decode(label_ids, skip_special_tokens=True)\n\n cer = jiwer.cer(\n label_str, pred_str,\n reference_transform=_cer_transform,\n hypothesis_transform=_cer_transform,\n )\n wer = jiwer.wer(label_str, pred_str,\n hypothesis_transform=transform,\n reference_transform=transform)\n return {'cer': round(cer, 4), 'wer': round(wer, 4)}\n\n\ncollator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)\nprint('Collator and WER metric ready')"
188
  ]
189
  },
190
  {
 
200
  "metadata": {},
201
  "outputs": [],
202
  "source": [
203
+ "# -- Cell 15: Training arguments ----------------------------------------------\nimport inspect\nfrom transformers import Seq2SeqTrainingArguments\n\n# transformers 4.x used 'evaluation_strategy'; 4.45+ renamed to 'eval_strategy'.\n# Detect which name this installed version accepts.\n_params = inspect.signature(Seq2SeqTrainingArguments.__init__).parameters\n_eval_key = 'eval_strategy' if 'eval_strategy' in _params else 'evaluation_strategy'\n\ntraining_args = Seq2SeqTrainingArguments(\n output_dir=OUTPUT_DIR,\n\n max_steps=MAX_STEPS,\n warmup_steps=WARMUP_STEPS,\n logging_steps=LOGGING_STEPS,\n save_steps=SAVE_STEPS,\n eval_steps=EVAL_STEPS,\n\n per_device_train_batch_size=BATCH_SIZE,\n per_device_eval_batch_size=8,\n gradient_accumulation_steps=GRAD_ACCUM,\n\n fp16=True,\n gradient_checkpointing=True, # reduces activation memory on T4\n\n learning_rate=LEARNING_RATE,\n lr_scheduler_type='cosine',\n weight_decay=0.0,\n adam_beta1=0.9,\n adam_beta2=0.98,\n adam_epsilon=1e-6,\n\n **{_eval_key: 'steps'},\n predict_with_generate=True,\n generation_max_length=225,\n load_best_model_at_end=True,\n metric_for_best_model='cer',\n greater_is_better=False,\n\n save_total_limit=3,\n save_strategy='steps',\n\n report_to=['tensorboard'], # tensorboard logs to OUTPUT_DIR/runs by default\n push_to_hub=False,\n)\n\nprint(f'Training arguments ready (using {_eval_key}=steps)')\nprint(f' Effective batch size: {BATCH_SIZE * GRAD_ACCUM}')\nprint(f' Max steps : {MAX_STEPS}')\n"
204
  ]
205
  },
206
  {
 
226
  "metadata": {},
227
  "outputs": [],
228
  "source": [
229
+ "# ── Cell 17: WER evaluation ───────────────────────────────────────────────────\nprint('Running full evaluation on eval split ...')\neval_results = trainer.evaluate()\n\ncer_score = eval_results.get('eval_cer', float('nan'))\nwer_score = eval_results.get('eval_wer', float('nan'))\nprint(f'\n Final CER : {cer_score:.1%} (primary — lower is better)')\nprint(f' Final WER : {wer_score:.1%} (secondary)')\nprint(f' Eval loss : {eval_results.get(\"eval_loss\", float(\"nan\")):.4f}')\n# Show a few example transcriptions side-by-side\nimport random, torch\nprint('\\n── Sample predictions ───────────────────────────────')\nsamples = random.sample(range(len(eval_ds)), min(5, len(eval_ds)))\nfor idx in samples:\n item = eval_ds[idx]\n feats = torch.tensor(item['input_features']).unsqueeze(0).to(model.device)\n with torch.no_grad():\n pred_ids = model.generate(\n feats, # fp32 to match model dtype\n max_new_tokens=128,\n )\n pred_str = processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)[0]\n labels = [t if t != -100 else processor.tokenizer.pad_token_id\n for t in item['labels']]\n ref_str = processor.tokenizer.decode(labels, skip_special_tokens=True)\n print(f' Ref : {ref_str}')\n print(f' Pred: {pred_str}')\n print()"
230
  ]
231
  },
232
  {
 
252
  "metadata": {},
253
  "outputs": [],
254
  "source": [
255
+ "# ── Cell 19: Push adapter to HF Model repo ───────────────────────────────────\nfrom huggingface_hub import HfApi, create_repo\n\n# Ensure repo exists\ncreate_repo(ADAPTER_REPO_ID, repo_type='model', private=True,\n exist_ok=True, token=HF_TOKEN)\n\n_cer_part = f'{cer_score:.1%}' if cer_score == cer_score else 'n/a'\ncommit_msg = (\n f'[{VERSION_TAG}] {LANG_NAME} fine-tuned checkpoint — '\n f'{train_result.global_step} steps | CER {_cer_part} | '\n f'{len(correction_records)} corrections + WaxalNLP'\n)\n\napi.upload_folder(\n folder_path=OUTPUT_DIR,\n repo_id=ADAPTER_REPO_ID,\n repo_type='model',\n path_in_repo=PATH_IN_REPO,\n commit_message=commit_msg,\n)\nprint(f'✅ Adapter uploaded: {ADAPTER_REPO_ID}/{PATH_IN_REPO}')\n\n# Create a Git tag for this version\ntry:\n api.create_tag(\n repo_id=ADAPTER_REPO_ID,\n repo_type='model',\n tag=VERSION_TAG,\n tag_message=commit_msg,\n token=HF_TOKEN,\n )\n print(f'✅ Tag created : {VERSION_TAG}')\nexcept Exception as e:\n print(f'⚠️ Tag creation skipped: {e}')"
256
  ]
257
  },
258
  {
 
262
  "metadata": {},
263
  "outputs": [],
264
  "source": [
265
+ "# ── Cell 20: Verification summary ────────────────────────────────────────────\nfrom huggingface_hub import list_repo_files\n\nprint('=' * 60)\nprint('DEEP SLEEP TRAINING — COMPLETE')\nprint('=' * 60)\nprint(f' Language : {TRAIN_LANG} ({LANG_NAME})')\nprint(f' Model : {WHISPER_MODEL_ID}')\nprint(f' Steps completed : {train_result.global_step}')\nprint(f' Train loss : {train_result.training_loss:.4f}')\n_cer_disp = f'{cer_score:.1%}' if cer_score == cer_score else 'n/a'\n_wer_disp = f'{wer_score:.1%}' if wer_score == wer_score else 'n/a'\nprint(f' Eval CER (primary) : {_cer_disp}')\nprint(f' Eval WER (secondary): {_wer_disp}')\nprint(f' Corrections used : {len(correction_records)} × {CORRECTION_REPEAT}')\nprint(f' WaxalNLP samples : up to {MAX_WAXAL_TRAIN}')\nprint(f' Version tag : {VERSION_TAG}')\nprint(f' HF repo : {ADAPTER_REPO_ID}/{PATH_IN_REPO}')\nprint()\n\n# List what was pushed\ntry:\n repo_files = sorted(list_repo_files(\n ADAPTER_REPO_ID, repo_type='model', token=HF_TOKEN\n ))\n adapter_files = [f for f in repo_files if f.startswith(f'adapters/{LANG_NAME}/')]\n print('Adapter files in repo:')\n for f in adapter_files:\n print(f' {f}')\nexcept Exception as e:\n print(f'Could not list repo files: {e}')\n\nprint()\nprint('Next steps:')\nprint(' 1. In your HF Space settings, confirm ADAPTER_REPO_ID secret is set')\nprint(f' 2. Tab 3 → Reload Adapters → select \"{VERSION_TAG}\"')\nprint(' 3. Collect more corrections in the Space, then re-run this notebook')"
266
  ]
267
  }
268
  ]
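
For intuition on why CER replaces WER as the primary metric: Bambara spelling variation often makes a transcript differ from the reference by one or two characters per word, which WER counts as whole-word errors while CER scores proportionally. A rough illustration (numbers are approximate and assume jiwer's default transforms):

    import jiwer

    ref = "i ni ce aw ka kɛnɛ"
    hyp = "i ni ce aw ka kene"      # last word off by two characters

    print(jiwer.wer(ref, hyp))      # 1 of 6 words wrong      -> roughly 0.17
    print(jiwer.cer(ref, hyp))      # 2 of 18 characters wrong -> roughly 0.11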
requirements.txt CHANGED
@@ -52,6 +52,11 @@ scipy==1.15.2
52
  # Phrase matching (fuzzy match for Whisper mis-transcriptions of Bambara/Fula)
53
  rapidfuzz==3.13.0
54
 
55
  # Kaggle API (used by Self-Teaching tab to trigger training runs)
56
  kaggle>=1.6.0
57
 
 
52
  # Phrase matching (fuzzy match for Whisper mis-transcriptions of Bambara/Fula)
53
  rapidfuzz==3.13.0
54
 
55
+ # Voice cloning — F5-TTS (flow-matching, language-agnostic, reference-speaker)
56
+ # Requires GPU at runtime (~750 MB model auto-downloaded on first use).
57
+ # Falls back to MMS-TTS gracefully when not installed or GPU unavailable.
58
+ f5-tts>=1.0.0
59
+
60
  # Kaggle API (used by Self-Teaching tab to trigger training runs)
61
  kaggle>=1.6.0
62
 
scripts/patch_notebook_cer.py ADDED
@@ -0,0 +1,177 @@
1
+ """Patch kaggle_master_trainer.ipynb: add bam_normalize, replace WER with CER."""
2
+ import json, sys, re
3
+ sys.stdout.reconfigure(encoding="utf-8")
4
+
5
+ NB = "notebooks/kaggle_master_trainer.ipynb"
6
+
7
+ with open(NB, encoding="utf-8") as f:
8
+ nb = json.load(f)
9
+ cells = nb["cells"]
10
+
11
+ changed = []
12
+
13
+ # ── Cell 10 (idx=11): inject _bam_norm definition ────────────────────────────
14
+ old = "".join(cells[11]["source"])
15
+ OLD_TOP = (
16
+ "# -- Cell 10: Text cleaning utilities -----------------------------------------\n"
17
+ "import re, unicodedata"
18
+ )
19
+ NEW_TOP = (
20
+ "# -- Cell 10: Text cleaning utilities + Bambara phonetic normaliser -----------\n"
21
+ "import re, unicodedata\n"
22
+ "\n"
23
+ "# Phonetic normaliser: unifies French-influenced spellings before training.\n"
24
+ "# ou->u, dj->j, gn->ny_palatal etc. so spelling variants map to same token.\n"
25
+ "_BAM_NORM_RULES = [('ou','u'),('dj','j'),('gn','\u0272'),('ny','\u0272'),"
26
+ "('ch','c'),('oo','\u0254'),('ee','\u025b')]\n"
27
+ "_BAM_NORM_PAT = re.compile('|'.join(re.escape(s) for s,_ in _BAM_NORM_RULES))\n"
28
+ "_BAM_NORM_MAP = {s:d for s,d in _BAM_NORM_RULES}\n"
29
+ "\n"
30
+ "def _bam_norm(text):\n"
31
+ " import unicodedata as _ud\n"
32
+ " text = _ud.normalize('NFC', text.lower())\n"
33
+ " return _BAM_NORM_PAT.sub(lambda m: _BAM_NORM_MAP[m.group(0)], text)\n"
34
+ )
35
+ if OLD_TOP in old:
36
+ cells[11]["source"] = [old.replace(OLD_TOP, NEW_TOP)]
37
+ changed.append("Cell 10: _bam_norm injected")
38
+ else:
39
+ changed.append("Cell 10: OLD_TOP not found - skip")
40
+
41
+ # ── Cell 11 (idx=12): apply _bam_norm in prepare_dataset ─────────────────────
42
+ old = "".join(cells[12]["source"])
43
+ OLD_PREP = " cleaned = clean_text(str(raw_text), lang=lang)"
44
+ NEW_PREP = (
45
+ " _norm_text = _bam_norm(str(raw_text)) if lang == 'bam' else str(raw_text)\n"
46
+ " cleaned = clean_text(_norm_text, lang=lang)"
47
+ )
48
+ if OLD_PREP in old:
49
+ cells[12]["source"] = [old.replace(OLD_PREP, NEW_PREP)]
50
+ changed.append("Cell 11: normaliser applied in prepare_dataset")
51
+ else:
52
+ changed.append(f"Cell 11: prepare pattern not found ({repr(old[old.find('cleaned'):old.find('cleaned')+60])})")
53
+
54
+ # ── Cell 14 (idx=17): WER -> CER in compute_metrics ──────────────────────────
55
+ old = "".join(cells[17]["source"])
56
+ # Replace header comment
57
+ new = old.replace(
58
+ "# -- Cell 14: Data collator + WER metric",
59
+ "# -- Cell 14: Data collator + CER metric"
60
+ )
61
+ # Add CER transform after existing transform definition
62
+ OLD_TRANSFORM_END = " jiwer.ReduceToListOfListOfWords(),\n])"
63
+ NEW_TRANSFORM_END = (
64
+ " jiwer.ReduceToListOfListOfWords(),\n"
65
+ "])\n"
66
+ "\n"
67
+ "# CER transform (no word-split step needed)\n"
68
+ "_cer_transform = jiwer.Compose([\n"
69
+ " jiwer.ToLowerCase(),\n"
70
+ " jiwer.RemoveMultipleSpaces(),\n"
71
+ " jiwer.Strip(),\n"
72
+ " jiwer.RemovePunctuation(),\n"
73
+ "])"
74
+ )
75
+ new = new.replace(OLD_TRANSFORM_END, NEW_TRANSFORM_END)
76
+ # Replace return value in compute_metrics
77
+ OLD_RETURN = (
78
+ " wer = jiwer.wer(label_str, pred_str,\n"
79
+ " hypothesis_transform=transform,\n"
80
+ " reference_transform=transform)\n"
81
+ " return {'wer': round(wer, 4)}"
82
+ )
83
+ NEW_RETURN = (
84
+ " cer = jiwer.cer(\n"
85
+ " label_str, pred_str,\n"
86
+ " reference_transform=_cer_transform,\n"
87
+ " hypothesis_transform=_cer_transform,\n"
88
+ " )\n"
89
+ " wer = jiwer.wer(label_str, pred_str,\n"
90
+ " hypothesis_transform=transform,\n"
91
+ " reference_transform=transform)\n"
92
+ " return {'cer': round(cer, 4), 'wer': round(wer, 4)}"
93
+ )
94
+ new = new.replace(OLD_RETURN, NEW_RETURN)
95
+ if new != old:
96
+ cells[17]["source"] = [new]
97
+ changed.append("Cell 14: WER->CER in compute_metrics")
98
+ else:
99
+ changed.append("Cell 14: no changes applied")
100
+
101
+ # ── Cell 15 (idx=19): metric_for_best_model ───────────────────────────────────
102
+ old = "".join(cells[19]["source"])
103
+ new = old.replace(
104
+ " metric_for_best_model='wer',",
105
+ " metric_for_best_model='cer',"
106
+ )
107
+ if new != old:
108
+ cells[19]["source"] = [new]
109
+ changed.append("Cell 15: metric_for_best_model=cer")
110
+ else:
111
+ changed.append("Cell 15: no change")
112
+
113
+ # ── Cell 17 (idx=22): CER display in evaluation ───────────────────────────────
114
+ old = "".join(cells[22]["source"])
115
+ OLD_WER_PRINT = (
116
+ "wer_score = eval_results.get('eval_wer', float('nan'))\n"
117
+ "print(f'\\n? Final WER : {wer_score:.1%}')\n"
118
+ "print(f' Eval loss : {eval_results.get(\"eval_loss\", float(\"nan\")):.4f}')"
119
+ )
120
+ NEW_WER_PRINT = (
121
+ "cer_score = eval_results.get('eval_cer', float('nan'))\n"
122
+ "wer_score = eval_results.get('eval_wer', float('nan'))\n"
123
+ "print(f'\\n\u2705 Final CER : {cer_score:.1%} (primary — lower is better)')\n"
124
+ "print(f' Final WER : {wer_score:.1%} (secondary)')\n"
125
+ "print(f' Eval loss : {eval_results.get(\"eval_loss\", float(\"nan\")):.4f}')"
126
+ )
127
+ if OLD_WER_PRINT in old:
128
+ cells[22]["source"] = [old.replace(OLD_WER_PRINT, NEW_WER_PRINT)]
129
+ changed.append("Cell 17: CER display")
130
+ else:
131
+ changed.append("Cell 17: print pattern not found")
132
+ # Try to find what's there
133
+ idx = old.find("wer_score")
134
+ if idx >= 0:
135
+ changed.append(f" ...found: {repr(old[idx:idx+100])}")
136
+
137
+ # ── Cell 19 push (idx=25): cer_score in commit msg ───────────────────────────
138
+ old = "".join(cells[25]["source"])
139
+ new = (
140
+ old
141
+ .replace(
142
+ "_wer_part = f'{wer_score:.1%}' if wer_score == wer_score else 'n/a'",
143
+ "_cer_part = f'{cer_score:.1%}' if cer_score == cer_score else 'n/a'"
144
+ )
145
+ .replace(
146
+ "f'{train_result.global_step} steps | WER {_wer_part} | '",
147
+ "f'{train_result.global_step} steps | CER {_cer_part} | '"
148
+ )
149
+ )
150
+ if new != old:
151
+ cells[25]["source"] = [new]
152
+ changed.append("Cell 19: CER in commit msg")
153
+ else:
154
+ changed.append("Cell 19: no change")
155
+
156
+ # ── Cell 20 summary (idx=26) ─────────────────────────────────────────────────
157
+ old = "".join(cells[26]["source"])
158
+ new = old.replace(
159
+ "_wer_disp = f'{wer_score:.1%}' if wer_score == wer_score else 'n/a'\n"
160
+ "print(f' Eval WER : {_wer_disp}')",
161
+ "_cer_disp = f'{cer_score:.1%}' if cer_score == cer_score else 'n/a'\n"
162
+ "_wer_disp = f'{wer_score:.1%}' if wer_score == wer_score else 'n/a'\n"
163
+ "print(f' Eval CER (primary) : {_cer_disp}')\n"
164
+ "print(f' Eval WER (secondary): {_wer_disp}')"
165
+ )
166
+ if new != old:
167
+ cells[26]["source"] = [new]
168
+ changed.append("Cell 20: CER in summary")
169
+ else:
170
+ changed.append("Cell 20: no change")
171
+
172
+ with open(NB, "w", encoding="utf-8") as f:
173
+ json.dump(nb, f, ensure_ascii=False, indent=1)
174
+
175
+ for msg in changed:
176
+ print(msg)
177
+ print("Done.")
src/data/bam_normalize.py ADDED
@@ -0,0 +1,67 @@
1
+ """
2
+ Bambara phonetic normalizer.
3
+
4
+ Unifies French-influenced and informal spellings to the standard
5
+ N'Ko-derived Bambara orthography used in most NLP datasets.
6
+
7
+ Key rules (most impactful for ASR training):
8
+ ou → u French vowel → Bambara standard
9
+ gn → ɲ French nasal palatal
10
+ ny → ɲ English nasal palatal notation
11
+ dj → j French palatal affricate
12
+ ch → c French palatalized consonant
13
+ oo → ɔ long open-o (common informal spelling)
14
+ ee → ɛ long open-e (common informal spelling)
15
+
16
+ These rules run left-to-right on lower-cased text. They are conservative:
17
+ only unambiguous substitutions are applied so as not to corrupt words that
18
+ happen to contain these letter sequences in a non-phonemic context.
19
+
20
+ Usage:
21
+ from src.data.bam_normalize import normalize
22
+ text = normalize("I ni ce, a bɛ djourou la")
23
+ # → "i ni ce, a bɛ juruu la"
24
+ """
25
+ from __future__ import annotations
26
+
27
+ import re
28
+ import unicodedata
29
+
30
+
31
+ # ── Replacement table (order matters — longest match first) ─────────────────
32
+ _RULES: list[tuple[str, str]] = [
33
+ ("ou", "u"), # most frequent French influence
34
+ ("dj", "j"), # palatal affricate
35
+ ("gn", "ɲ"), # nasal palatal (French orthography)
36
+ ("ny", "ɲ"), # nasal palatal (English-style notation)
37
+ ("ch", "c"), # palatalized stop
38
+ ("oo", "ɔ"), # long open-o (informal doubling)
39
+ ("ee", "ɛ"), # long open-e (informal doubling)
40
+ ]
41
+
42
+ # Compile once for speed
43
+ _PATTERN = re.compile(
44
+ "|".join(re.escape(src) for src, _ in _RULES)
45
+ )
46
+ _REPLACEMENTS = {src: dst for src, dst in _RULES}
47
+
48
+
49
+ def normalize(text: str) -> str:
50
+ """
51
+ Apply phonetic normalization to a Bambara text string.
52
+
53
+ Steps:
54
+ 1. Unicode NFC normalization (collapse combining characters).
55
+ 2. Lowercase.
56
+ 3. Apply phoneme substitution rules.
57
+ 4. Collapse multiple spaces.
58
+ """
59
+ text = unicodedata.normalize("NFC", text)
60
+ text = text.lower()
61
+ text = _PATTERN.sub(lambda m: _REPLACEMENTS[m.group(0)], text)
62
+ text = re.sub(r" {2,}", " ", text).strip()
63
+ return text
64
+
65
+
66
+ def normalize_batch(texts: list[str]) -> list[str]:
67
+ return [normalize(t) for t in texts]
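
A few illustrative sanity checks against the rule table above (expected outputs derived by hand from the documented substitutions):

    from src.data.bam_normalize import normalize

    assert normalize("Djourou") == "juru"            # dj->j, then ou->u twice
    assert normalize("gni") == "ɲi"                  # gn->ɲ
    assert normalize("chee") == "cɛ"                 # ch->c, ee->ɛ
    assert normalize("Mousso  don") == "musso don"   # ou->u, double space collapsed
    print("bam_normalize sanity checks passed")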
src/tts/f5_tts.py ADDED
@@ -0,0 +1,114 @@
1
+ """
2
+ F5-TTS voice cloning engine.
3
+
4
+ Generates speech in a target speaker's voice given a short reference WAV.
5
+ Falls back to None gracefully if f5-tts is not installed or the GPU is
6
+ unavailable — the caller then falls back to MMS-TTS.
7
+
8
+ Install:
9
+ pip install f5-tts>=1.0.0
10
+
11
+ Reference:
12
+ SWivid/F5-TTS (HuggingFace / GitHub)
13
+ Model: ~750 MB, downloaded on first use to HF cache.
14
+ """
15
+ from __future__ import annotations
16
+
17
+ import logging
18
+ import threading
19
+ from pathlib import Path
20
+ from typing import Optional, Tuple
21
+
22
+ import numpy as np
23
+
24
+ logger = logging.getLogger(__name__)
25
+
26
+ _lock = threading.Lock()
27
+ _model = None # F5TTS instance, loaded lazily
28
+
29
+
30
+ def _load_model():
31
+ global _model
32
+ if _model is not None:
33
+ return _model
34
+ with _lock:
35
+ if _model is None:
36
+ from f5_tts.api import F5TTS # type: ignore
37
+ _model = F5TTS(model_type="F5TTS")
38
+ logger.info("F5-TTS model loaded.")
39
+ return _model
40
+
41
+
42
+ def synthesize(
43
+ text: str,
44
+ ref_wav_path: str,
45
+ ref_text: str = "",
46
+ speed: float = 1.0,
47
+ device: str = "cuda",
48
+ ) -> Optional[Tuple[np.ndarray, int]]:
49
+ """
50
+ Generate speech for `text` using `ref_wav_path` as the speaker reference.
51
+
52
+ Args:
53
+ text: Text to synthesize (Bambara, Fula, French, or English).
54
+ ref_wav_path: Path to reference audio (WAV, 5–30 s of the target speaker).
55
+ ref_text: Transcript of the reference audio. If empty the model
56
+ uses in-context inference (slightly lower quality but still
57
+ good for voice matching).
58
+ speed: Speaking rate multiplier. 1.0 = normal.
59
+ device: "cuda" or "cpu". CPU is 30-60 s/sentence — use GPU.
60
+
61
+ Returns:
62
+ (waveform_float32, sample_rate) or None on failure.
63
+ """
64
+ if not text.strip():
65
+ return None
66
+
67
+ try:
68
+ import torch
69
+ model = _load_model()
70
+
71
+ wav, sr, _ = model.infer(
72
+ ref_file=ref_wav_path,
73
+ ref_text=ref_text.strip(),
74
+ gen_text=text.strip(),
75
+ speed=speed,
76
+ target_rms=0.1,
77
+ cross_fade_duration=0.15,
78
+ nfe_step=32,
79
+ cfg_strength=2.0,
80
+ show_info=False,
81
+ progress=None,
82
+ )
83
+
84
+ if isinstance(wav, torch.Tensor):
85
+ wav = wav.cpu().float().numpy()
86
+ else:
87
+ wav = np.asarray(wav, dtype=np.float32)
88
+
89
+ return wav, int(sr)
90
+
91
+ except ImportError:
92
+ logger.warning(
93
+ "f5-tts not installed — voice cloning disabled. "
94
+ "Add 'f5-tts>=1.0.0' to requirements.txt."
95
+ )
96
+ return None
97
+ except Exception as exc:
98
+ logger.error("F5-TTS synthesis failed: %s", exc)
99
+ return None
100
+
101
+
102
+ def to_wav_24k(audio_path: str) -> str:
103
+ """
104
+ Resample any audio file to 24 kHz mono WAV (F5-TTS preferred sample rate).
105
+ Returns the path to the converted file (same stem, ".f5ref.wav" suffix).
106
+ Always writes a new file next to the input; the original audio is left untouched.
107
+ """
108
+ import librosa
109
+ import soundfile as sf
110
+
111
+ out_path = str(Path(audio_path).with_suffix(".f5ref.wav"))
112
+ audio, _ = librosa.load(audio_path, sr=24_000, mono=True)
113
+ sf.write(out_path, audio, 24_000)
114
+ return out_path
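
Typical call pattern from the Space, mirroring how _convo_pipeline() in app.py uses this module (mms_synthesize here is a stand-in for the existing MMS-TTS engine, not part of this file):

    from src.tts.f5_tts import synthesize, to_wav_24k

    ref_wav = to_wav_24k("reference_speaker.mp3")     # convert once, reuse every turn
    result = synthesize("I ni ce", ref_wav_path=ref_wav, ref_text="", device="cuda")
    if result is not None:
        wav, sr = result                              # float32 waveform, sample rate
    else:
        wav, sr = mms_synthesize("I ni ce")           # fallback when f5-tts or a GPU is unavailable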