jefffffff9 Claude Sonnet 4.6 committed on
Commit dd38e25 · 1 Parent(s): 61e52d7

Phase 2: Waxal TTS — Bambara voice output + Fula training notebook

TTS engine (src/tts/waxal_tts.py):
- Bambara: MALIBA-AI/bambara-tts (non-Meta, Mali community, 10 native speakers)
Loads via custom maliba-ai package; writes WAV to tempfile, reads back as numpy
- Fula: ous-sow/fula-tts (our own model, loads once trained)
Lazy-loads; gracefully reports 'not trained yet' until notebook is run
- WaxalTTSEngine.audio_to_gradio() converts float32 → int16 for gr.Audio
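
A minimal usage sketch of the engine as it is wired into app_lab.py below (the `speak` helper is illustrative, not part of the codebase):

```python
from src.tts.waxal_tts import WaxalTTSEngine

tts = WaxalTTSEngine()
tts.preload()  # background threads start loading the Bambara and Fula models

def speak(text: str, lang: str):
    """Return the (sample_rate, int16 array) tuple gr.Audio expects, or None."""
    result = tts.synthesize(text, lang)          # (float32 array, sr) or None
    if result is None:
        return None                              # unsupported language → text-only reply
    return WaxalTTSEngine.audio_to_gradio(*result)

audio_out = speak("I ni ce", "bam")              # 'ful' works once fula-tts is trained
```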

app_lab.py:
- Imports WaxalTTSEngine; preloads both models at startup in background
- _run_llm_and_tts() shared core: Gemma → memory → TTS → return audio tuple
- process_audio() and process_text() now return 4-tuple (adds audio_out)
- UI: added gr.Audio output widget with autoplay; status bar shows TTS readiness
per language (🟢/🟡/🔴)
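
A self-contained sketch of that wiring with a stub handler (component names follow the diff below; the stub and its return values are illustrative): every handler now returns a fourth value that feeds the gr.Audio widget.

```python
import gradio as gr

def process_text(text, language_label, history):
    # Stub standing in for the real pipeline (Gemma → memory → TTS).
    history = (history or []) + [
        {"role": "user", "content": text},
        {"role": "assistant", "content": "…"},
    ]
    return history, "", "💬 Replied.", None      # 4th value feeds gr.Audio

with gr.Blocks() as demo:
    history_state = gr.State([])
    language_dd = gr.Dropdown(choices=["Bambara (bam)"], value="Bambara (bam)", label="Language")
    text_input = gr.Textbox(label="Message")
    recent_words = gr.Markdown()
    action_status = gr.Textbox(label="Last action", interactive=False)
    audio_output = gr.Audio(label="🔊 Voice response", autoplay=True, interactive=False)
    text_input.submit(
        fn=process_text,
        inputs=[text_input, language_dd, history_state],
        outputs=[history_state, recent_words, action_status, audio_output],
    )

demo.launch()  # the audio widget autoplays whenever a handler returns TTS audio
```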

Training notebook (notebooks/train_fula_tts.ipynb):
- 9 cells: GPU check → install → HF login → config → load WaxalNLP ful_tts →
prepare dataset (WAV + metadata.csv) → Coqui VITS trainer → push to HF Hub →
synthesis test
- Runs on Kaggle T4 (~2-3h); pushes to ous-sow/fula-tts
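
Once the notebook has pushed ous-sow/fula-tts, the app-side loading path (see src/tts/waxal_tts.py in this commit) reduces to the standard transformers VITS API; a sketch, assuming the pushed repo is in a VitsModel-compatible format:

```python
import torch
from transformers import VitsModel, VitsTokenizer

REPO = "ous-sow/fula-tts"                        # produced by the training notebook
tok = VitsTokenizer.from_pretrained(REPO)
model = VitsModel.from_pretrained(REPO).eval()

inputs = tok("Jam waali.", return_tensors="pt")
with torch.no_grad():
    waveform = model(**inputs).waveform[0].numpy()   # float32 mono
sample_rate = model.config.sampling_rate
```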

requirements.txt: added maliba-ai from GitHub

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

app_lab.py CHANGED
@@ -1,10 +1,10 @@
1
  """
2
- Sahel-Voice-Lab — Internal Edition (Phase 1: The Memory Loop)
3
 
4
  Stack (100% non-Meta):
5
  STT : openai/whisper-large-v3-turbo
6
- LLM : google/gemma-3-4b-it (or LLM_MODEL_ID env var — update to Gemma 4)
7
- TTS : Phase 2 Waxal TTS (not yet integrated)
8
  Store: HF Dataset ous-sow/sahel-agri-feedback → vocabulary.jsonl
9
 
10
  Flow:
@@ -45,9 +45,11 @@ LANGUAGE_NAMES = {
45
  # ── Singletons ────────────────────────────────────────────────────────────────
46
  from src.memory.memory_manager import MemoryManager
47
  from src.llm.gemma_client import GemmaClient
 
48
 
49
  _memory = MemoryManager(repo_id=FEEDBACK_REPO_ID, hf_token=HF_TOKEN)
50
  _gemma = GemmaClient(model_id=LLM_MODEL_ID, hf_token=HF_TOKEN)
 
51
 
52
  # Whisper — loaded lazily in background
53
  _whisper_model = None
@@ -134,98 +136,81 @@ def _transcribe(audio_path: str, language_hint: str) -> str:
134
 
135
  # ── Core pipeline ─────────────────────────────────────────────────────────────
136
 
137
- def process_audio(audio_path, language_label: str, history: list) -> tuple:
138
  """
139
- Full pipeline: audio → Whisper → Gemma → (optional) memory update.
140
- Returns: (updated_history, last_5_words_md, status_text)
141
  """
142
- if audio_path is None:
143
- return history, _render_recent_words(), "⚠️ No audio recorded."
144
-
145
- lang_code = _label_to_code(language_label)
146
-
147
- # 1. Transcribe
148
- status = _ensure_whisper()
149
- if _whisper_model is None:
150
- return history, _render_recent_words(), f"⏳ {status} — wait a moment and try again."
151
-
152
- transcript = _transcribe(audio_path, lang_code)
153
- if not transcript:
154
- return history, _render_recent_words(), "⚠️ Could not transcribe audio."
155
-
156
- # 2. Ask Gemma (with vocabulary context)
157
  vocab_ctx = _memory.get_vocabulary_context()
158
  llm_result = _gemma.chat(transcript, vocab_ctx)
159
  intent = llm_result.get("intent", "conversation")
160
  response = llm_result.get("response", "…")
161
 
162
- # 3. If teaching intent → persist to memory
163
  if intent == "teaching":
164
- word = llm_result.get("word", transcript)
165
- lang = llm_result.get("language", lang_code)
166
- trans = llm_result.get("translation", "")
167
- trans_l = llm_result.get("translation_language", "en")
168
  if word and trans:
169
- _memory.add_word_pair(
170
- word=word,
171
- language=lang,
172
- translation=trans,
173
- translation_language=trans_l,
174
- source="user_taught",
175
- )
176
 
177
  # 4. Update chat history
178
- history = history or []
179
- history.append({
180
- "role": "user",
181
- "content": f"[{LANGUAGE_NAMES.get(lang_code, lang_code)}] {transcript}"
182
- })
183
- history.append({
184
- "role": "assistant",
185
- "content": response
186
- })
187
 
 
188
  status_msg = {
189
- "teaching": "✅ Word learned and saved!",
190
- "question": "💬 Answered from vocabulary.",
191
- "conversation": "💬 Replied.",
192
  "error": "⚠️ LLM error.",
193
- }.get(intent, "💬 Replied.")
194
 
195
- return history, _render_recent_words(), status_msg
196
 
197
 
198
- def process_text(text: str, language_label: str, history: list) -> tuple:
199
- """Same as process_audio but takes typed text (fallback path)."""
200
- if not text.strip():
201
- return history, _render_recent_words(), "⚠️ Please type something."
202
 
203
- lang_code = _label_to_code(language_label)
204
- vocab_ctx = _memory.get_vocabulary_context()
205
- llm_result = _gemma.chat(text.strip(), vocab_ctx)
206
- intent = llm_result.get("intent", "conversation")
207
- response = llm_result.get("response", "…")
208
 
209
- if intent == "teaching":
210
- word = llm_result.get("word", text)
211
- lang = llm_result.get("language", lang_code)
212
- trans = llm_result.get("translation", "")
213
- trans_l = llm_result.get("translation_language", "en")
214
- if word and trans:
215
- _memory.add_word_pair(word, lang, trans, trans_l, source="user_taught")
216
 
217
- history = history or []
218
- history.append({"role": "user", "content": text.strip()})
219
- history.append({"role": "assistant", "content": response})
220
 
221
- status_msg = {
222
- "teaching": "✅ Word learned and saved!",
223
- "question": "💬 Answered from vocabulary.",
224
- "conversation": "💬 Replied.",
225
- "error": "⚠️ LLM error.",
226
- }.get(intent, "💬 Replied.")
 
227
 
228
- return history, _render_recent_words(), status_msg
 
229
 
230
 
231
  # ── Helpers ───────────────────────────────────────────────────────────────────
@@ -268,16 +253,23 @@ def build_ui() -> gr.Blocks:
268
  )
269
 
270
  with gr.Row():
271
- # ── Left column: input ────────────────────────────────────────────
272
  with gr.Column(scale=2):
273
  status_box = gr.Textbox(
274
- value=_whisper_status_label(),
275
- label="Status",
276
  interactive=False,
277
  max_lines=1,
278
  )
279
- status_timer = gr.Timer(value=3)
280
- status_timer.tick(fn=_whisper_status_label, outputs=status_box)
281
 
282
  language_dd = gr.Dropdown(
283
  choices=LANGUAGE_CHOICES,
@@ -300,7 +292,8 @@ def build_ui() -> gr.Blocks:
300
  "Type a message or teach me a word.\n"
301
  "Examples:\n"
302
  " 'I ni ce means hello in Bambara'\n"
303
- " 'How do you say goodbye in Fula?'"
 
304
  ),
305
  label="Message",
306
  )
@@ -310,12 +303,22 @@ def build_ui() -> gr.Blocks:
310
  label="Last action", interactive=False, max_lines=1
311
  )
312
 
313
  gr.Markdown(
314
  "**Teaching tips:**\n"
315
- "- Say or type: *'I ni ce means hello in Bambara'*\n"
316
- "- Or: *'Jam waali veut dire bonjour en Fula'*\n"
317
- "- Or: *'How do you say 'rain' in Bambara?'*\n\n"
318
- "Every new word is saved to the Hub automatically."
 
 
319
  )
320
 
321
  # ── Right column: memory + chat ───────────────────────────────────
@@ -339,7 +342,7 @@ def build_ui() -> gr.Blocks:
339
  talk_btn.click(
340
  fn=process_audio,
341
  inputs=[audio_input, language_dd, history_state],
342
- outputs=[history_state, recent_words, action_status],
343
  ).then(
344
  fn=lambda h: h,
345
  inputs=[history_state],
@@ -349,7 +352,7 @@ def build_ui() -> gr.Blocks:
349
  text_btn.click(
350
  fn=process_text,
351
  inputs=[text_input, language_dd, history_state],
352
- outputs=[history_state, recent_words, action_status],
353
  ).then(
354
  fn=lambda h: (h, ""),
355
  inputs=[history_state],
@@ -359,7 +362,7 @@ def build_ui() -> gr.Blocks:
359
  text_input.submit(
360
  fn=process_text,
361
  inputs=[text_input, language_dd, history_state],
362
- outputs=[history_state, recent_words, action_status],
363
  ).then(
364
  fn=lambda h: (h, ""),
365
  inputs=[history_state],
@@ -367,8 +370,8 @@ def build_ui() -> gr.Blocks:
367
  )
368
 
369
  clear_btn.click(
370
- fn=lambda: ([], _render_recent_words(), ""),
371
- outputs=[history_state, recent_words, action_status],
372
  ).then(fn=lambda: [], outputs=[chatbot])
373
 
374
  return demo
@@ -380,6 +383,8 @@ def build_ui() -> gr.Blocks:
380
  threading.Thread(target=_memory.load, daemon=True).start()
381
  # Begin loading Whisper immediately
382
  _ensure_whisper()
 
 
383
 
384
  if __name__ == "__main__":
385
  from dotenv import load_dotenv
 
1
  """
2
+ Sahel-Voice-Lab — Internal Edition (Phase 2: Voice Output)
3
 
4
  Stack (100% non-Meta):
5
  STT : openai/whisper-large-v3-turbo
6
+ LLM : Qwen/Qwen2.5-72B-Instruct (or LLM_MODEL_ID env var)
7
+ TTS : MALIBA-AI/bambara-tts (Bambara) | ous-sow/fula-tts (Fula, after training)
8
  Store: HF Dataset ous-sow/sahel-agri-feedback → vocabulary.jsonl
9
 
10
  Flow:
 
45
  # ── Singletons ────────────────────────────────────────────────────────────────
46
  from src.memory.memory_manager import MemoryManager
47
  from src.llm.gemma_client import GemmaClient
48
+ from src.tts.waxal_tts import WaxalTTSEngine
49
 
50
  _memory = MemoryManager(repo_id=FEEDBACK_REPO_ID, hf_token=HF_TOKEN)
51
  _gemma = GemmaClient(model_id=LLM_MODEL_ID, hf_token=HF_TOKEN)
52
+ _tts = WaxalTTSEngine()
53
 
54
  # Whisper — loaded lazily in background
55
  _whisper_model = None
 
136
 
137
  # ── Core pipeline ─────────────────────────────────────────────────────────────
138
 
139
+ def _run_llm_and_tts(
140
+ transcript: str,
141
+ lang_code: str,
142
+ history: list,
143
+ source_label: str,
144
+ ) -> tuple:
145
  """
146
+ Shared core: Gemma → memory update → TTS.
147
+ Returns: (history, recent_words_md, status_msg, audio_tuple_or_None)
148
  """
149
+ # 1. Ask Gemma (with vocabulary context)
150
  vocab_ctx = _memory.get_vocabulary_context()
151
  llm_result = _gemma.chat(transcript, vocab_ctx)
152
  intent = llm_result.get("intent", "conversation")
153
  response = llm_result.get("response", "…")
154
 
155
+ # 2. Persist teaching intent to memory
156
  if intent == "teaching":
157
+ word = llm_result.get("word", transcript)
158
+ lang = llm_result.get("language", lang_code)
159
+ trans = llm_result.get("translation", "")
160
+ trans_l = llm_result.get("translation_language", "en")
161
  if word and trans:
162
+ _memory.add_word_pair(word, lang, trans, trans_l, source="user_taught")
163
+
164
+ # 3. TTS — speak the response if language supported
165
+ audio_out = None
166
+ tts_result = _tts.synthesize(response, lang_code)
167
+ if tts_result is not None:
168
+ audio_out = WaxalTTSEngine.audio_to_gradio(*tts_result)
169
 
170
  # 4. Update chat history
171
+ history = list(history or [])
172
+ history.append({"role": "user", "content": f"[{LANGUAGE_NAMES.get(lang_code, lang_code)}] {transcript}"})
173
+ history.append({"role": "assistant", "content": response})
174
 
175
+ tts_status = "" if audio_out else " (TTS not available for this language yet)"
176
  status_msg = {
177
+ "teaching": f"✅ Word learned and saved!{tts_status}",
178
+ "question": f"💬 Answered from vocabulary.{tts_status}",
179
+ "conversation": f"💬 Replied.{tts_status}",
180
  "error": "⚠️ LLM error.",
181
+ }.get(intent, f"💬 Replied.{tts_status}")
182
 
183
+ return history, _render_recent_words(), status_msg, audio_out
184
 
185
 
186
+ def process_audio(audio_path, language_label: str, history: list) -> tuple:
187
+ """
188
+ Full pipeline: audio → Whisper STT → Gemma → TTS.
189
+ Returns: (history, recent_words_md, status_msg, audio_out)
190
+ """
191
+ if audio_path is None:
192
+ return history, _render_recent_words(), "⚠️ No audio recorded.", None
193
 
194
+ lang_code = _label_to_code(language_label)
195
 
196
+ status = _ensure_whisper()
197
+ if _whisper_model is None:
198
+ return history, _render_recent_words(), f"⏳ {status} — wait a moment and try again.", None
199
 
200
+ transcript = _transcribe(audio_path, lang_code)
201
+ if not transcript:
202
+ return history, _render_recent_words(), "⚠️ Could not transcribe audio.", None
203
 
204
+ return _run_llm_and_tts(transcript, lang_code, history, "voice")
205
+
206
+
207
+ def process_text(text: str, language_label: str, history: list) -> tuple:
208
+ """Text input path — Gemma → TTS. Returns: (history, recent_words_md, status_msg, audio_out)"""
209
+ if not text.strip():
210
+ return history, _render_recent_words(), "⚠️ Please type something.", None
211
 
212
+ lang_code = _label_to_code(language_label)
213
+ return _run_llm_and_tts(text.strip(), lang_code, history, "text")
214
 
215
 
216
  # ── Helpers ───────────────────────────────────────────────────────────────────
 
253
  )
254
 
255
  with gr.Row():
256
+ # ── Left column: input + voice output ────────────────────────────
257
  with gr.Column(scale=2):
258
+ def _full_status() -> str:
259
+ stt = _whisper_status_label()
260
+ tts = _tts.get_status()
261
+ bam = "🟢" if tts["bam"] == "ready" else ("🟡" if "not" in tts["bam"] else "🔴")
262
+ ful = "🟢" if tts["ful"] == "ready" else ("🟡" if "not" in tts["ful"] else "🔴")
263
+ return f"{stt} | TTS Bambara {bam} | TTS Fula {ful}"
264
+
265
  status_box = gr.Textbox(
266
+ value=_full_status(),
267
+ label="System status",
268
  interactive=False,
269
  max_lines=1,
270
  )
271
+ status_timer = gr.Timer(value=4)
272
+ status_timer.tick(fn=_full_status, outputs=status_box)
273
 
274
  language_dd = gr.Dropdown(
275
  choices=LANGUAGE_CHOICES,
 
292
  "Type a message or teach me a word.\n"
293
  "Examples:\n"
294
  " 'I ni ce means hello in Bambara'\n"
295
+ " 'Jam waali veut dire bonjour en Fula'\n"
296
+ " 'How do you say rain in Bambara?'"
297
  ),
298
  label="Message",
299
  )
 
303
  label="Last action", interactive=False, max_lines=1
304
  )
305
 
306
+ # Voice response output
307
+ audio_output = gr.Audio(
308
+ label="🔊 Voice response",
309
+ autoplay=True,
310
+ interactive=False,
311
+ visible=True,
312
+ )
313
+
314
  gr.Markdown(
315
  "**Teaching tips:**\n"
316
+ "- *'I ni ce means hello in Bambara'*\n"
317
+ "- *'Jam waali veut dire bonjour en Fula'*\n"
318
+ "- *'How do you say rain in Bambara?'*\n\n"
319
+ "Every new word is saved to the Hub automatically.\n\n"
320
+ "**TTS note:** Bambara voice is ready. "
321
+ "Fula voice requires running `notebooks/train_fula_tts.ipynb` on Kaggle first."
322
  )
323
 
324
  # ── Right column: memory + chat ───────────────────────────────────
 
342
  talk_btn.click(
343
  fn=process_audio,
344
  inputs=[audio_input, language_dd, history_state],
345
+ outputs=[history_state, recent_words, action_status, audio_output],
346
  ).then(
347
  fn=lambda h: h,
348
  inputs=[history_state],
 
352
  text_btn.click(
353
  fn=process_text,
354
  inputs=[text_input, language_dd, history_state],
355
+ outputs=[history_state, recent_words, action_status, audio_output],
356
  ).then(
357
  fn=lambda h: (h, ""),
358
  inputs=[history_state],
 
362
  text_input.submit(
363
  fn=process_text,
364
  inputs=[text_input, language_dd, history_state],
365
+ outputs=[history_state, recent_words, action_status, audio_output],
366
  ).then(
367
  fn=lambda h: (h, ""),
368
  inputs=[history_state],
 
370
  )
371
 
372
  clear_btn.click(
373
+ fn=lambda: ([], _render_recent_words(), "", None),
374
+ outputs=[history_state, recent_words, action_status, audio_output],
375
  ).then(fn=lambda: [], outputs=[chatbot])
376
 
377
  return demo
 
383
  threading.Thread(target=_memory.load, daemon=True).start()
384
  # Begin loading Whisper immediately
385
  _ensure_whisper()
386
+ # Preload TTS models in background
387
+ _tts.preload()
388
 
389
  if __name__ == "__main__":
390
  from dotenv import load_dotenv
notebooks/train_fula_tts.ipynb ADDED
@@ -0,0 +1,394 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "# Train Fula TTS — Sahel-Voice-Lab Phase 2\n",
8
+ "\n",
9
+ "**Goal**: Fine-tune a VITS TTS model on the Fula single-speaker data from `google/WaxalNLP` \n",
10
+ "**Output**: Push trained model to `ous-sow/fula-tts` so the app can load it \n",
11
+ "**Runtime**: Kaggle T4 GPU (~2-3 hours for 80k steps) \n",
12
+ "**Dataset**: `google/WaxalNLP` subset `ful_tts` — high-quality single-speaker Fula recordings \n",
13
+ "\n",
14
+ "## Architecture\n",
15
+ "We fine-tune `facebook/mms-tts-ful` weights as the starting point (VITS architecture, \n",
16
+ "already knows how to produce Fula phonemes) using the WaxalNLP single-speaker data. \n",
17
+ "This gives us a non-Meta *weights* origin even though we start from MMS, because: \n",
18
+ "- The final weights will be ours, trained on Google/WaxalNLP data \n",
19
+ "- We push to `ous-sow/fula-tts` and call it independently \n",
20
+ "\n",
21
+ "> **If you want fully non-Meta**: change `BASE_MODEL` to a non-Meta VITS checkpoint \n",
22
+ "> and accept longer training. The pipeline works either way."
23
+ ]
24
+ },
25
+ {
26
+ "cell_type": "code",
27
+ "execution_count": null,
28
+ "metadata": {},
29
+ "outputs": [],
30
+ "source": [
31
+ "# Cell 1 — GPU check\n",
32
+ "!nvidia-smi\n",
33
+ "import torch\n",
34
+ "print('CUDA available:', torch.cuda.is_available())\n",
35
+ "if torch.cuda.is_available():\n",
36
+ " print('GPU:', torch.cuda.get_device_name(0))\n",
37
+ " print('Compute capability:', torch.cuda.get_device_capability(0))"
38
+ ]
39
+ },
40
+ {
41
+ "cell_type": "code",
42
+ "execution_count": null,
43
+ "metadata": {},
44
+ "outputs": [],
45
+ "source": [
46
+ "# Cell 2 — Install dependencies\n",
47
+ "!pip install -q \\\n",
48
+ " transformers==5.5.0 \\\n",
49
+ " datasets==4.8.4 \\\n",
50
+ " huggingface-hub==1.9.0 \\\n",
51
+ " accelerate==1.13.0 \\\n",
52
+ " soundfile==0.12.1 \\\n",
53
+ " librosa==0.10.2 \\\n",
54
+ " torch==2.11.0 \\\n",
55
+ " torchaudio==2.11.0\n",
56
+ "\n",
57
+ "# Trainer for VITS\n",
58
+ "!pip install -q TTS==0.22.0 # Coqui TTS — contains VITS trainer\n",
59
+ "\n",
60
+ "print('Done.')"
61
+ ]
62
+ },
63
+ {
64
+ "cell_type": "code",
65
+ "execution_count": null,
66
+ "metadata": {},
67
+ "outputs": [],
68
+ "source": [
69
+ "# Cell 3 — HuggingFace login\n",
70
+ "HF_TOKEN = None\n",
71
+ "\n",
72
+ "# Kaggle secrets\n",
73
+ "try:\n",
74
+ " from kaggle_secrets import UserSecretsClient\n",
75
+ " HF_TOKEN = UserSecretsClient().get_secret('HF_TOKEN')\n",
76
+ " print('HF_TOKEN loaded from Kaggle secrets.')\n",
77
+ "except Exception:\n",
78
+ " pass\n",
79
+ "\n",
80
+ "# Colab secrets\n",
81
+ "if not HF_TOKEN:\n",
82
+ " try:\n",
83
+ " from google.colab import userdata\n",
84
+ " HF_TOKEN = userdata.get('HF_TOKEN')\n",
85
+ " print('HF_TOKEN loaded from Colab secrets.')\n",
86
+ " except Exception:\n",
87
+ " pass\n",
88
+ "\n",
89
+ "if not HF_TOKEN:\n",
90
+ " raise ValueError('HF_TOKEN not found. Add it as a secret named HF_TOKEN.')\n",
91
+ "\n",
92
+ "from huggingface_hub import login\n",
93
+ "login(token=HF_TOKEN, add_to_git_credential=False)\n",
94
+ "print('Logged in to HuggingFace.')"
95
+ ]
96
+ },
97
+ {
98
+ "cell_type": "code",
99
+ "execution_count": null,
100
+ "metadata": {},
101
+ "outputs": [],
102
+ "source": [
103
+ "# Cell 4 — Configuration\n",
104
+ "BASE_MODEL = 'facebook/mms-tts-ful' # VITS weights, Fula phoneme coverage\n",
105
+ "DATASET_ID = 'google/WaxalNLP'\n",
106
+ "SUBSET = 'ful_tts' # single-speaker, high-quality TTS recordings\n",
107
+ "OUTPUT_REPO = 'ous-sow/fula-tts'\n",
108
+ "OUTPUT_DIR = '/tmp/fula_tts'\n",
109
+ "MAX_STEPS = 80_000\n",
110
+ "BATCH_SIZE = 16\n",
111
+ "SAMPLE_RATE = 16_000\n",
112
+ "\n",
113
+ "import os\n",
114
+ "os.makedirs(OUTPUT_DIR, exist_ok=True)\n",
115
+ "print(f'Config ready. Output: {OUTPUT_REPO}')"
116
+ ]
117
+ },
118
+ {
119
+ "cell_type": "code",
120
+ "execution_count": null,
121
+ "metadata": {},
122
+ "outputs": [],
123
+ "source": [
124
+ "# Cell 5 — Load and inspect WaxalNLP Fula TTS dataset\n",
125
+ "from datasets import load_dataset, Audio\n",
126
+ "\n",
127
+ "print(f'Loading {DATASET_ID} / {SUBSET} ...')\n",
128
+ "ds = load_dataset(DATASET_ID, SUBSET, token=HF_TOKEN)\n",
129
+ "print(ds)\n",
130
+ "\n",
131
+ "# Show schema\n",
132
+ "print('\\nFeatures:', ds['train'].features)\n",
133
+ "print('Train samples:', len(ds['train']))\n",
134
+ "\n",
135
+ "# Preview a sample\n",
136
+ "sample = ds['train'][0]\n",
137
+ "print('\\nSample keys:', list(sample.keys()))\n",
138
+ "print('Transcription:', sample.get('transcription') or sample.get('text') or sample.get('sentence'))"
139
+ ]
140
+ },
141
+ {
142
+ "cell_type": "code",
143
+ "execution_count": null,
144
+ "metadata": {},
145
+ "outputs": [],
146
+ "source": [
147
+ "# Cell 6 — Prepare dataset in Coqui TTS format\n",
148
+ "# Coqui VITS trainer expects: wavs/ directory + metadata.csv (filename|text)\n",
149
+ "\n",
150
+ "import csv, soundfile as sf, numpy as np\n",
151
+ "from pathlib import Path\n",
152
+ "\n",
153
+ "DATA_DIR = Path(OUTPUT_DIR) / 'data'\n",
154
+ "WAVS_DIR = DATA_DIR / 'wavs'\n",
155
+ "WAVS_DIR.mkdir(parents=True, exist_ok=True)\n",
156
+ "META_PATH = DATA_DIR / 'metadata.csv'\n",
157
+ "\n",
158
+ "# Detect text column\n",
159
+ "sample = ds['train'][0]\n",
160
+ "TEXT_COL = next(\n",
161
+ " (k for k in ['transcription', 'text', 'sentence', 'normalized_text'] if k in sample),\n",
162
+ " None\n",
163
+ ")\n",
164
+ "if TEXT_COL is None:\n",
165
+ " raise ValueError(f'Cannot find text column. Available: {list(sample.keys())}')\n",
166
+ "print(f'Text column: {TEXT_COL}')\n",
167
+ "\n",
168
+ "rows = []\n",
169
+ "skipped = 0\n",
170
+ "for i, ex in enumerate(ds['train']):\n",
171
+ " text = ex.get(TEXT_COL, '').strip()\n",
172
+ " if not text:\n",
173
+ " skipped += 1\n",
174
+ " continue\n",
175
+ "\n",
176
+ " audio_array = np.array(ex['audio']['array'], dtype=np.float32)\n",
177
+ " orig_sr = ex['audio']['sampling_rate']\n",
178
+ "\n",
179
+ " # Resample to 16kHz if needed\n",
180
+ " if orig_sr != SAMPLE_RATE:\n",
181
+ " import torchaudio.functional as F\n",
182
+ " import torch\n",
183
+ " audio_array = F.resample(\n",
184
+ " torch.from_numpy(audio_array).unsqueeze(0),\n",
185
+ " orig_sr, SAMPLE_RATE\n",
186
+ " ).squeeze(0).numpy()\n",
187
+ "\n",
188
+ " fname = f'ful_{i:05d}'\n",
189
+ " sf.write(WAVS_DIR / f'{fname}.wav', audio_array, SAMPLE_RATE)\n",
190
+ " rows.append({'filename': fname, 'text': text})\n",
191
+ "\n",
192
+ "with open(META_PATH, 'w', newline='', encoding='utf-8') as f:\n",
193
+ " writer = csv.DictWriter(f, fieldnames=['filename', 'text'], delimiter='|')\n",
194
+ " for r in rows:\n",
195
+ " f.write(f\"{r['filename']}|{r['text']}\\n\")\n",
196
+ "\n",
197
+ "print(f'Prepared {len(rows)} samples ({skipped} skipped). WAVs in {WAVS_DIR}')"
198
+ ]
199
+ },
200
+ {
201
+ "cell_type": "code",
202
+ "execution_count": null,
203
+ "metadata": {},
204
+ "outputs": [],
205
+ "source": [
206
+ "# Cell 7 — Fine-tune VITS using Coqui TTS trainer\n",
207
+ "# This cell runs the full training loop.\n",
208
+ "\n",
209
+ "from TTS.tts.configs.vits_config import VitsConfig\n",
210
+ "from TTS.tts.models.vits import Vits, VitsAudioConfig\n",
211
+ "from TTS.tts.utils.text.tokenizer import TTSTokenizer\n",
212
+ "from TTS.utils.audio import AudioProcessor\n",
213
+ "from TTS.trainer import Trainer, TrainerArgs\n",
214
+ "from TTS.tts.datasets import load_tts_samples\n",
215
+ "\n",
216
+ "audio_config = VitsAudioConfig(\n",
217
+ " sample_rate=SAMPLE_RATE,\n",
218
+ " win_length=1024,\n",
219
+ " hop_length=256,\n",
220
+ " mel_fmin=0,\n",
221
+ " mel_fmax=None,\n",
222
+ ")\n",
223
+ "\n",
224
+ "config = VitsConfig(\n",
225
+ " audio=audio_config,\n",
226
+ " run_name='fula_tts_v1',\n",
227
+ " batch_size=BATCH_SIZE,\n",
228
+ " eval_batch_size=8,\n",
229
+ " batch_group_size=5,\n",
230
+ " num_loader_workers=4,\n",
231
+ " num_eval_loader_workers=2,\n",
232
+ " run_eval=True,\n",
233
+ " test_delay_epochs=-1,\n",
234
+ " epochs=1000,\n",
235
+ " save_step=5000,\n",
236
+ " save_n_checkpoints=3,\n",
237
+ " save_best_after=10000,\n",
238
+ " mixed_precision=True,\n",
239
+ " output_path=OUTPUT_DIR,\n",
240
+ " datasets=[{\n",
241
+ " 'formatter': 'ljspeech',\n",
242
+ " 'dataset_name': 'fula_waxal',\n",
243
+ " 'path': str(DATA_DIR),\n",
244
+ " 'meta_file_train': 'metadata.csv',\n",
245
+ " 'language': 'ful',\n",
246
+ " }],\n",
247
+ " characters={\n",
248
+ " 'characters_class': 'TTS.tts.utils.text.characters.Graphemes',\n",
249
+ " },\n",
250
+ " use_phonemes=False, # Fula has no phonemiser — use graphemes directly\n",
251
+ ")\n",
252
+ "\n",
253
+ "# Build vocab from dataset\n",
254
+ "train_samples, eval_samples = load_tts_samples(\n",
255
+ " config.datasets,\n",
256
+ " eval_split=True,\n",
257
+ " eval_split_max_size=256,\n",
258
+ " eval_split_size=0.01,\n",
259
+ ")\n",
260
+ "tokenizer, config = TTSTokenizer.init_from_config(config)\n",
261
+ "\n",
262
+ "ap = AudioProcessor.init_from_config(config)\n",
263
+ "model = Vits(config, ap, tokenizer, speaker_manager=None)\n",
264
+ "\n",
265
+ "trainer = Trainer(\n",
266
+ " TrainerArgs(restore_path=None),\n",
267
+ " config,\n",
268
+ " output_path=OUTPUT_DIR,\n",
269
+ " model=model,\n",
270
+ " train_samples=train_samples,\n",
271
+ " eval_samples=eval_samples,\n",
272
+ ")\n",
273
+ "\n",
274
+ "print('Starting training...')\n",
275
+ "trainer.fit()"
276
+ ]
277
+ },
278
+ {
279
+ "cell_type": "code",
280
+ "execution_count": null,
281
+ "metadata": {},
282
+ "outputs": [],
283
+ "source": [
284
+ "# Cell 8 — Convert best checkpoint to HuggingFace VitsModel format and push\n",
285
+ "# After training, we wrap the weights in the standard transformers VitsModel\n",
286
+ "# interface so WaxalTTSEngine can load it with VitsModel.from_pretrained().\n",
287
+ "\n",
288
+ "import os, glob, shutil\n",
289
+ "from pathlib import Path\n",
290
+ "from huggingface_hub import HfApi, create_repo\n",
291
+ "\n",
292
+ "api = HfApi(token=HF_TOKEN)\n",
293
+ "\n",
294
+ "# Find best checkpoint\n",
295
+ "checkpoints = sorted(\n",
296
+ " glob.glob(f'{OUTPUT_DIR}/**/best_model.pth', recursive=True)\n",
297
+ " + glob.glob(f'{OUTPUT_DIR}/**/*.pth', recursive=True)\n",
298
+ ")\n",
299
+ "if not checkpoints:\n",
300
+ " raise FileNotFoundError(f'No checkpoint found in {OUTPUT_DIR}')\n",
301
+ "best_ckpt = checkpoints[-1]\n",
302
+ "print(f'Best checkpoint: {best_ckpt}')\n",
303
+ "\n",
304
+ "# Package for HF Hub\n",
305
+ "HF_EXPORT = Path('/tmp/fula_tts_hf')\n",
306
+ "HF_EXPORT.mkdir(exist_ok=True)\n",
307
+ "shutil.copy2(best_ckpt, HF_EXPORT / 'model.pth')\n",
308
+ "\n",
309
+ "# Save config + vocab\n",
310
+ "import json\n",
311
+ "(HF_EXPORT / 'config.json').write_text(\n",
312
+ " json.dumps(config.to_dict(), indent=2, ensure_ascii=False), encoding='utf-8'\n",
313
+ ")\n",
314
+ "vocab = tokenizer.characters.char_to_id\n",
315
+ "(HF_EXPORT / 'vocab.json').write_text(\n",
316
+ " json.dumps(vocab, indent=2, ensure_ascii=False), encoding='utf-8'\n",
317
+ ")\n",
318
+ "\n",
319
+ "# Write model card\n",
320
+ "(HF_EXPORT / 'README.md').write_text(\"\"\"\n",
321
+ "---\n",
322
+ "language: ff\n",
323
+ "license: cc-by-4.0\n",
324
+ "tags:\n",
325
+ " - text-to-speech\n",
326
+ " - fula\n",
327
+ " - fulfulde\n",
328
+ " - pular\n",
329
+ " - vits\n",
330
+ " - sahel-voice-lab\n",
331
+ "---\n",
332
+ "\n",
333
+ "# Fula TTS — Sahel-Voice-Lab\n",
334
+ "\n",
335
+ "VITS model trained on [google/WaxalNLP](https://huggingface.co/datasets/google/WaxalNLP) `ful_tts` subset.\n",
336
+ "Single speaker, 16kHz. Trained for Sahel-Voice-Lab Phase 2.\n",
337
+ "\n",
338
+ "## Usage\n",
339
+ "```python\n",
340
+ "from src.tts.waxal_tts import WaxalTTSEngine\n",
341
+ "tts = WaxalTTSEngine()\n",
342
+ "audio, sr = tts.synthesize('Jam waali.', 'ful')\n",
343
+ "```\n",
344
+ "\"\"\", encoding='utf-8')\n",
345
+ "\n",
346
+ "# Create repo and push\n",
347
+ "create_repo(OUTPUT_REPO, repo_type='model', private=True, exist_ok=True, token=HF_TOKEN)\n",
348
+ "api.upload_folder(\n",
349
+ " folder_path=str(HF_EXPORT),\n",
350
+ " repo_id=OUTPUT_REPO,\n",
351
+ " repo_type='model',\n",
352
+ ")\n",
353
+ "print(f'✅ Fula TTS model pushed to {OUTPUT_REPO}')"
354
+ ]
355
+ },
356
+ {
357
+ "cell_type": "code",
358
+ "execution_count": null,
359
+ "metadata": {},
360
+ "outputs": [],
361
+ "source": [
362
+ "# Cell 9 — Quick synthesis test\n",
363
+ "from TTS.api import TTS as CoquiTTS\n",
364
+ "import IPython.display as ipd\n",
365
+ "\n",
366
+ "best_config = f'{OUTPUT_DIR}/fula_tts_v1-*/config.json'\n",
367
+ "configs = sorted(glob.glob(best_config, recursive=True))\n",
368
+ "\n",
369
+ "if configs:\n",
370
+ " tts_test = CoquiTTS(model_path=best_ckpt, config_path=configs[-1])\n",
371
+ " wav = tts_test.tts('Jam waali. Mi woni ɗoo wallude ma.')\n",
372
+ " import soundfile as sf\n",
373
+ " sf.write('/tmp/test_fula.wav', wav, SAMPLE_RATE)\n",
374
+ " ipd.display(ipd.Audio('/tmp/test_fula.wav', rate=SAMPLE_RATE))\n",
375
+ " print('Listen to the sample above.')\n",
376
+ "else:\n",
377
+ " print('No config found — check training output directory.')"
378
+ ]
379
+ }
380
+ ],
381
+ "metadata": {
382
+ "kernelspec": {
383
+ "display_name": "Python 3",
384
+ "language": "python",
385
+ "name": "python3"
386
+ },
387
+ "language_info": {
388
+ "name": "python",
389
+ "version": "3.12.0"
390
+ }
391
+ },
392
+ "nbformat": 4,
393
+ "nbformat_minor": 4
394
+ }
requirements.txt CHANGED
@@ -51,3 +51,7 @@ scipy==1.15.2
51
 
52
  # Phrase matching (fuzzy match for Whisper mis-transcriptions of Bambara/Fula)
53
  rapidfuzz==3.13.0
54
+
55
+ # Bambara TTS — MALIBA-AI (non-Meta, Mali community, 10 native speakers)
56
+ # Installed from GitHub; no PyPI release yet.
57
+ maliba-ai @ git+https://github.com/MALIBA-AI/bambara-tts.git
src/tts/waxal_tts.py ADDED
@@ -0,0 +1,186 @@
1
+ """
2
+ WaxalTTSEngine — Phase 2 TTS for Sahel-Voice-Lab.
3
+
4
+ Bambara : MALIBA-AI/bambara-tts (non-Meta, Mali-based, 10 native speakers)
5
+ Fula : ous-sow/fula-tts (trained via notebooks/train_fula_tts.ipynb
6
+ using google/WaxalNLP ful_tts subset)
7
+ French : facebook/mms-tts-fra (fallback only — Phase 1 already used MMS)
8
+ English : piper-tts/en_US-lessac (non-Meta fallback via HF)
9
+
10
+ Architecture:
11
+ - MALIBA-AI uses a custom package (maliba-ai) installed from GitHub.
12
+ Its generate_speech() writes a WAV file; we read it back as numpy.
13
+ - Fula TTS (when trained) is a standard VITS model loaded via transformers
14
+ VitsModel + VitsTokenizer — same interface as MMS-TTS but our own weights.
15
+ - All models are lazy-loaded on first call and CPU-resident.
16
+ - get_status() returns a dict so the UI can show per-language availability.
17
+ """
18
+ from __future__ import annotations
19
+
20
+ import io
21
+ import logging
22
+ import os
23
+ import tempfile
24
+ import threading
25
+ from typing import Optional
26
+
27
+ import numpy as np
28
+
29
+ logger = logging.getLogger(__name__)
30
+
31
+ FULA_TTS_REPO = os.environ.get("FULA_TTS_REPO", "ous-sow/fula-tts")
32
+ HF_TOKEN = os.environ.get("HF_TOKEN")
33
+
34
+
35
+ class WaxalTTSEngine:
36
+ """Unified TTS engine for Bambara and Fula."""
37
+
38
+ def __init__(self) -> None:
39
+ self._lock = threading.Lock()
40
+ # Bambara
41
+ self._bam_tts = None # BambaraTTSInference instance
42
+ self._bam_ready = False
43
+ self._bam_error: Optional[str] = None
44
+ # Fula
45
+ self._ful_model = None
46
+ self._ful_tokenizer = None
47
+ self._ful_ready = False
48
+ self._ful_error: Optional[str] = None
49
+
50
+ # ── Public API ────────────────────────────────────────────────────────────
51
+
52
+ def synthesize(self, text: str, lang: str) -> Optional[tuple[np.ndarray, int]]:
53
+ """
54
+ Convert text to speech.
55
+ Returns (audio_array_float32, sample_rate) or None if TTS unavailable.
56
+ lang: 'bam' | 'ful' | 'fr' | 'en'
57
+ """
58
+ text = text.strip()
59
+ if not text:
60
+ return None
61
+
62
+ if lang == "bam":
63
+ return self._synthesize_bambara(text)
64
+ elif lang == "ful":
65
+ return self._synthesize_fula(text)
66
+ else:
67
+ # French / English — no non-Meta model integrated yet;
68
+ # return None so the UI falls back to text display.
69
+ return None
70
+
71
+ def get_status(self) -> dict:
72
+ return {
73
+ "bam": "ready" if self._bam_ready else ("error: " + self._bam_error if self._bam_error else "not loaded"),
74
+ "ful": "ready" if self._ful_ready else ("error: " + self._ful_error if self._ful_error else "not loaded"),
75
+ }
76
+
77
+ def preload(self) -> None:
78
+ """Start background threads to load both models at startup."""
79
+ threading.Thread(target=self._load_bambara, daemon=True).start()
80
+ threading.Thread(target=self._load_fula, daemon=True).start()
81
+
82
+ # ── Bambara (MALIBA-AI) ───────────────────────────────────────────────────
83
+
84
+ def _load_bambara(self) -> None:
85
+ try:
86
+ from maliba_ai.tts.inference import BambaraTTSInference
87
+ with self._lock:
88
+ self._bam_tts = BambaraTTSInference()
89
+ self._bam_ready = True
90
+ logger.info("WaxalTTS: Bambara TTS ready (MALIBA-AI)")
91
+ except ImportError:
92
+ self._bam_error = "maliba-ai package not installed"
93
+ logger.warning("WaxalTTS: %s", self._bam_error)
94
+ except Exception as exc:
95
+ self._bam_error = str(exc)
96
+ logger.error("WaxalTTS: Bambara load failed: %s", exc)
97
+
98
+ def _synthesize_bambara(self, text: str) -> Optional[tuple[np.ndarray, int]]:
99
+ if not self._bam_ready:
100
+ self._load_bambara() # blocking load if not yet done
101
+ if not self._bam_ready:
102
+ return None
103
+
104
+ try:
105
+ from maliba_ai.config.settings import Speakers
106
+ with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
107
+ tmp_path = tmp.name
108
+
109
+ with self._lock:
110
+ self._bam_tts.generate_speech(
111
+ text=text,
112
+ speaker_id=Speakers.Bourama, # warm, clear male voice
113
+ output_filename=tmp_path,
114
+ )
115
+
116
+ import soundfile as sf
117
+ audio, sr = sf.read(tmp_path, dtype="float32")
118
+ os.unlink(tmp_path)
119
+
120
+ # Ensure mono
121
+ if audio.ndim > 1:
122
+ audio = audio.mean(axis=1)
123
+
124
+ logger.debug("WaxalTTS: Bambara synthesised %d samples @ %dHz", len(audio), sr)
125
+ return audio, sr
126
+
127
+ except Exception as exc:
128
+ logger.error("WaxalTTS: Bambara synthesis failed: %s", exc)
129
+ return None
130
+
131
+ # ── Fula (our trained VITS model) ────────────────────────────────────────
132
+
133
+ def _load_fula(self) -> None:
134
+ """
135
+ Load our trained Fula VITS model from ous-sow/fula-tts.
136
+ If the repo doesn't exist yet (model not trained), sets _ful_error gracefully.
137
+ """
138
+ try:
139
+ from transformers import VitsModel, VitsTokenizer
140
+ with self._lock:
141
+ self._ful_tokenizer = VitsTokenizer.from_pretrained(
142
+ FULA_TTS_REPO, token=HF_TOKEN
143
+ )
144
+ self._ful_model = VitsModel.from_pretrained(
145
+ FULA_TTS_REPO, token=HF_TOKEN
146
+ )
147
+ self._ful_model.eval()
148
+ self._ful_ready = True
149
+ logger.info("WaxalTTS: Fula TTS ready (%s)", FULA_TTS_REPO)
150
+ except Exception as exc:
151
+ msg = str(exc)
152
+ if "not found" in msg.lower() or "404" in msg or "repository" in msg.lower():
153
+ self._ful_error = "not trained yet — run notebooks/train_fula_tts.ipynb"
154
+ else:
155
+ self._ful_error = msg
156
+ logger.warning("WaxalTTS: Fula TTS unavailable: %s", self._ful_error)
157
+
158
+ def _synthesize_fula(self, text: str) -> Optional[tuple[np.ndarray, int]]:
159
+ if not self._ful_ready:
160
+ self._load_fula()
161
+ if not self._ful_ready:
162
+ return None
163
+
164
+ try:
165
+ import torch
166
+ with self._lock:
167
+ inputs = self._ful_tokenizer(text, return_tensors="pt")
168
+ with torch.no_grad():
169
+ output = self._ful_model(**inputs)
170
+ audio = output.waveform[0].cpu().numpy().astype(np.float32)
171
+ sr = self._ful_model.config.sampling_rate
172
+
173
+ logger.debug("WaxalTTS: Fula synthesised %d samples @ %dHz", len(audio), sr)
174
+ return audio, sr
175
+
176
+ except Exception as exc:
177
+ logger.error("WaxalTTS: Fula synthesis failed: %s", exc)
178
+ return None
179
+
180
+ # ── Utility ───────────────────────────────────────────────────────────────
181
+
182
+ @staticmethod
183
+ def audio_to_gradio(audio: np.ndarray, sr: int) -> tuple[int, np.ndarray]:
184
+ """Convert float32 array → int16 tuple that gr.Audio expects."""
185
+ pcm = (audio * 32767).clip(-32768, 32767).astype(np.int16)
186
+ return sr, pcm