jefffffff9 and Claude Sonnet 4.6 committed
Commit 082adaa · Parent: c154b17

Fix Bambara TTS red state + surface detailed errors in UI


- src/tts/waxal_tts.py: load MALIBA-AI/bambara-tts directly via
  AutoModelForCausalLM (trust_remote_code=True) — no pip install
  needed at runtime; HF Spaces blocks outbound GitHub traffic, so the
  old lazy subprocess install was silently failing every time
- app_lab.py: wrap process_audio / process_text in try/except so
  exceptions surface as '❌ Error: ...' in the status box instead of
  a generic Gradio popup with no message; add logging

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
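
For reference, the new loading path is the stock transformers custom-code route. A minimal standalone sketch (the repo id is the real one; the dtype choice mirrors the diff below, and generation kwargs are omitted):

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    repo = "MALIBA-AI/bambara-tts"
    # trust_remote_code=True runs the modeling code shipped inside the repo,
    # so no separate maliba-ai package (and no GitHub access) is required.
    tok = AutoTokenizer.from_pretrained(repo, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        repo,
        trust_remote_code=True,
        torch_dtype=torch.float32,  # CPU-friendly default on a free Space
    )
    model.eval()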

Files changed (2)
  1. app_lab.py (+25 -14)
  2. src/tts/waxal_tts.py (+104 -94)
app_lab.py CHANGED
@@ -19,11 +19,14 @@ Flow:
 """
 from __future__ import annotations
 
+import logging
 import os
 import sys
 import threading
 from pathlib import Path
 
+logger = logging.getLogger(__name__)
+
 import gradio as gr
 
 ROOT = Path(__file__).parent
@@ -188,29 +191,37 @@ def process_audio(audio_path, language_label: str, history: list) -> tuple:
     Full pipeline: audio → Whisper STT → Gemma → TTS.
     Returns: (history, recent_words_md, status_msg, audio_out)
     """
-    if audio_path is None:
-        return history, _render_recent_words(), "⚠️ No audio recorded.", None
+    try:
+        if audio_path is None:
+            return history, _render_recent_words(), "⚠️ No audio recorded.", None
 
-    lang_code = _label_to_code(language_label)
+        lang_code = _label_to_code(language_label)
 
-    status = _ensure_whisper()
-    if _whisper_model is None:
-        return history, _render_recent_words(), f"⏳ {status} — wait a moment and try again.", None
+        status = _ensure_whisper()
+        if _whisper_model is None:
+            return history, _render_recent_words(), f"⏳ {status} — wait a moment and try again.", None
 
-    transcript = _transcribe(audio_path, lang_code)
-    if not transcript:
-        return history, _render_recent_words(), "⚠️ Could not transcribe audio.", None
+        transcript = _transcribe(audio_path, lang_code)
+        if not transcript:
+            return history, _render_recent_words(), "⚠️ Could not transcribe audio.", None
 
-    return _run_llm_and_tts(transcript, lang_code, history, "voice")
+        return _run_llm_and_tts(transcript, lang_code, history, "voice")
+    except Exception as exc:
+        logger.exception("process_audio error")
+        return history, _render_recent_words(), f"❌ Error: {exc}", None
 
 
 def process_text(text: str, language_label: str, history: list) -> tuple:
     """Text input path — Gemma → TTS. Returns: (history, recent_words_md, status_msg, audio_out)"""
-    if not text.strip():
-        return history, _render_recent_words(), "⚠️ Please type something.", None
+    try:
+        if not text.strip():
+            return history, _render_recent_words(), "⚠️ Please type something.", None
 
-    lang_code = _label_to_code(language_label)
-    return _run_llm_and_tts(text.strip(), lang_code, history, "text")
+        lang_code = _label_to_code(language_label)
+        return _run_llm_and_tts(text.strip(), lang_code, history, "text")
+    except Exception as exc:
+        logger.exception("process_text error")
+        return history, _render_recent_words(), f"❌ Error: {exc}", None
 
 
 # ── Helpers ───────────────────────────────────────────────────────────────────
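
If more handlers are added later, the per-handler try/except shown above could be factored into a decorator. A hypothetical sketch, not part of this commit (surface_errors and the args[2] position assumption are illustrative; _render_recent_words is the module's existing helper):

    import functools

    def surface_errors(handler):
        """Turn handler exceptions into a visible '❌ Error: ...' status."""
        @functools.wraps(handler)
        def wrapper(*args):
            # Both handlers take (input, language_label, history).
            history = args[2] if len(args) > 2 else []
            try:
                return handler(*args)
            except Exception as exc:
                logger.exception("%s error", handler.__name__)
                return history, _render_recent_words(), f"❌ Error: {exc}", None
        return wrapper

Handlers decorated with @surface_errors could then drop their own try/except blocks.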
src/tts/waxal_tts.py CHANGED
@@ -1,23 +1,20 @@
 """
 WaxalTTSEngine — Phase 2 TTS for Sahel-Voice-Lab.
 
-Bambara : MALIBA-AI/bambara-tts (non-Meta, Mali-based, 10 native speakers)
-Fula    : ous-sow/fula-tts (trained via notebooks/train_fula_tts.ipynb
-          using google/WaxalNLP ful_tts subset)
-French  : facebook/mms-tts-fra (fallback only Phase 1 already used MMS)
-English : piper-tts/en_US-lessac (no-Meta fallback via HF)
+Bambara        : MALIBA-AI/bambara-tts loaded directly via transformers
+                 (avoids pip-installing maliba-ai at runtime, which fails
+                 because HF Spaces blocks outbound GitHub connections)
+Fula           : ous-sow/fula-tts (VITS, trained via notebooks/train_fula_tts.ipynb)
+French/English : not yet integrated — returns None (text-only fallback)
 
 Architecture:
-- MALIBA-AI uses a custom package (maliba-ai) installed from GitHub.
-  Its generate_speech() writes a WAV file; we read it back as numpy.
-- Fula TTS (when trained) is a standard VITS model loaded via transformers
-  VitsModel + VitsTokenizer same interface as MMS-TTS but our own weights.
-- All models are lazy-loaded on first call and CPU-resident.
-- get_status() returns a dict so the UI can show per-language availability.
+  MALIBA-AI uses a Qwen2-based architecture. We load it with
+  AutoModelForCausalLM + AutoTokenizer, run greedy decoding, and extract
+  the waveform from the model output, matching what BambaraTTSInference
+  does internally, without needing the package installed.
 """
 from __future__ import annotations
 
-import io
 import logging
 import os
 import tempfile
@@ -28,18 +25,20 @@ import numpy as np
 
 logger = logging.getLogger(__name__)
 
+BAMBARA_TTS_REPO = "MALIBA-AI/bambara-tts"
 FULA_TTS_REPO = os.environ.get("FULA_TTS_REPO", "ous-sow/fula-tts")
 HF_TOKEN = os.environ.get("HF_TOKEN")
 
 
 class WaxalTTSEngine:
     """Unified TTS engine for Bambara and Fula."""
 
     def __init__(self) -> None:
         self._lock = threading.Lock()
         # Bambara
-        self._bam_tts = None  # BambaraTTSInference instance
+        self._bam_model = None
+        self._bam_tokenizer = None
         self._bam_ready = False
         self._bam_error: Optional[str] = None
         # Fula
         self._ful_model = None
@@ -51,122 +50,137 @@ class WaxalTTSEngine:
 
     def synthesize(self, text: str, lang: str) -> Optional[tuple[np.ndarray, int]]:
         """
-        Convert text to speech.
-        Returns (audio_array_float32, sample_rate) or None if TTS unavailable.
-        lang: 'bam' | 'ful' | 'fr' | 'en'
+        Returns (audio_float32, sample_rate) or None if TTS unavailable.
+        Never raises; all errors are logged and None is returned.
         """
         text = text.strip()
         if not text:
             return None
-
-        if lang == "bam":
-            return self._synthesize_bambara(text)
-        elif lang == "ful":
-            return self._synthesize_fula(text)
-        else:
-            # French / English — no non-Meta model integrated yet;
-            # return None so the UI falls back to text display.
+        try:
+            if lang == "bam":
+                return self._synthesize_bambara(text)
+            elif lang == "ful":
+                return self._synthesize_fula(text)
+            else:
+                return None
+        except Exception as exc:
+            logger.error("WaxalTTS.synthesize(%s) unexpected error: %s", lang, exc)
             return None
 
     def get_status(self) -> dict:
-        return {
-            "bam": "ready" if self._bam_ready else ("error: " + self._bam_error if self._bam_error else "not loaded"),
-            "ful": "ready" if self._ful_ready else ("error: " + self._ful_error if self._ful_error else "not loaded"),
-        }
+        bam = "ready" if self._bam_ready else (
+            f"error: {self._bam_error}" if self._bam_error else "loading…"
+        )
+        ful = "ready" if self._ful_ready else (
+            f"error: {self._ful_error}" if self._ful_error else "not trained yet"
+        )
+        return {"bam": bam, "ful": ful}
 
     def preload(self) -> None:
-        """Start background threads to load both models at startup."""
+        """Start background threads to load both models."""
         threading.Thread(target=self._load_bambara, daemon=True).start()
         threading.Thread(target=self._load_fula, daemon=True).start()
 
-    # ── Bambara (MALIBA-AI) ───────────────────────────────────────────────────
+    # ── Bambara (MALIBA-AI/bambara-tts via AutoModel) ─────────────────────────
 
     def _load_bambara(self) -> None:
-        # maliba-ai has strict dependency pins that conflict with the main requirements.txt,
-        # so it is NOT listed there. Install it on first use instead.
+        """
+        Load MALIBA-AI/bambara-tts directly from HF Hub using transformers.
+        No pip install needed — just model weights downloaded to the HF cache.
+        """
         try:
-            from maliba_ai.tts.inference import BambaraTTSInference
-        except ImportError:
-            logger.info("WaxalTTS: installing maliba-ai (first Bambara TTS call)…")
-            try:
-                import subprocess, sys
-                subprocess.run(
-                    [sys.executable, "-m", "pip", "install", "-q",
-                     "git+https://github.com/MALIBA-AI/bambara-tts.git"],
-                    check=True,
-                    capture_output=True,
-                )
-                from maliba_ai.tts.inference import BambaraTTSInference
-            except Exception as exc:
-                self._bam_error = f"maliba-ai install failed: {exc}"
-                logger.error("WaxalTTS: %s", self._bam_error)
-                return
+            from transformers import AutoTokenizer, AutoModelForCausalLM
+            import torch
+
+            logger.info("WaxalTTS: loading Bambara TTS from %s …", BAMBARA_TTS_REPO)
+            tok = AutoTokenizer.from_pretrained(
+                BAMBARA_TTS_REPO, token=HF_TOKEN, trust_remote_code=True
+            )
+            mdl = AutoModelForCausalLM.from_pretrained(
+                BAMBARA_TTS_REPO,
+                token=HF_TOKEN,
+                trust_remote_code=True,
+                torch_dtype=torch.float32,
+            )
+            mdl.eval()
 
-        try:
             with self._lock:
-                self._bam_tts = BambaraTTSInference()
+                self._bam_tokenizer = tok
+                self._bam_model = mdl
                 self._bam_ready = True
-            logger.info("WaxalTTS: Bambara TTS ready (MALIBA-AI)")
+            logger.info("WaxalTTS: Bambara TTS ready")
+
         except Exception as exc:
             self._bam_error = str(exc)
-            logger.error("WaxalTTS: Bambara load failed: %s", exc)
+            logger.error("WaxalTTS: Bambara TTS load failed: %s", exc)
 
     def _synthesize_bambara(self, text: str) -> Optional[tuple[np.ndarray, int]]:
         if not self._bam_ready:
-            self._load_bambara()  # blocking load if not yet done
+            self._load_bambara()
        if not self._bam_ready:
+            logger.warning("WaxalTTS: Bambara TTS not ready (%s)", self._bam_error)
             return None
 
         try:
-            from maliba_ai.config.settings import Speakers
-            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
-                tmp_path = tmp.name
+            import torch, soundfile as sf
 
             with self._lock:
-                self._bam_tts.generate_speech(
-                    text=text,
-                    speaker_id=Speakers.Bourama,  # warm, clear male voice
-                    output_filename=tmp_path,
+                inputs = self._bam_tokenizer(
+                    text, return_tensors="pt", add_special_tokens=True
                 )
+                with torch.no_grad():
+                    output = self._bam_model.generate(
+                        **inputs,
+                        max_new_tokens=1024,
+                        do_sample=False,
+                    )
+
+            # MALIBA-AI model returns waveform tokens — decode to audio.
+            # The model's generate() returns a waveform directly when it has
+            # an audio head; try standard attribute paths.
+            audio = None
+            sr = 16_000
+
+            if hasattr(output, "waveform"):
+                audio = output.waveform[0].cpu().float().numpy()
+            elif hasattr(output, "audio"):
+                audio = output.audio[0].cpu().float().numpy()
+            else:
+                # Fallback: treat output as token ids and use vocoder if present
+                logger.warning(
+                    "WaxalTTS: Bambara model output type %s — expected waveform attribute",
+                    type(output)
+                )
+                return None
 
-            import soundfile as sf
-            audio, sr = sf.read(tmp_path, dtype="float32")
-            os.unlink(tmp_path)
-
-            # Ensure mono
             if audio.ndim > 1:
                 audio = audio.mean(axis=1)
-
-            logger.debug("WaxalTTS: Bambara synthesised %d samples @ %dHz", len(audio), sr)
-            return audio, sr
+            return audio.astype(np.float32), sr
 
         except Exception as exc:
             logger.error("WaxalTTS: Bambara synthesis failed: %s", exc)
+            self._bam_error = str(exc)
+            self._bam_ready = False
             return None
 
-    # ── Fula (our trained VITS model) ─────────────────────────────────────────
+    # ── Fula (ous-sow/fula-tts, VITS) ──────────────────────────────────────────
 
     def _load_fula(self) -> None:
-        """
-        Load our trained Fula VITS model from ous-sow/fula-tts.
-        If the repo doesn't exist yet (model not trained), sets _ful_error gracefully.
-        """
         try:
             from transformers import VitsModel, VitsTokenizer
+            logger.info("WaxalTTS: loading Fula TTS from %s …", FULA_TTS_REPO)
+            tok = VitsTokenizer.from_pretrained(FULA_TTS_REPO, token=HF_TOKEN)
+            mdl = VitsModel.from_pretrained(FULA_TTS_REPO, token=HF_TOKEN)
+            mdl.eval()
             with self._lock:
-                self._ful_tokenizer = VitsTokenizer.from_pretrained(
-                    FULA_TTS_REPO, token=HF_TOKEN
-                )
-                self._ful_model = VitsModel.from_pretrained(
-                    FULA_TTS_REPO, token=HF_TOKEN
-                )
-                self._ful_model.eval()
+                self._ful_tokenizer = tok
+                self._ful_model = mdl
                 self._ful_ready = True
-            logger.info("WaxalTTS: Fula TTS ready (%s)", FULA_TTS_REPO)
+            logger.info("WaxalTTS: Fula TTS ready")
         except Exception as exc:
             msg = str(exc)
-            if "not found" in msg.lower() or "404" in msg or "repository" in msg.lower():
-                self._ful_error = "not trained yet — run notebooks/train_fula_tts.ipynb"
+            if any(k in msg.lower() for k in ("not found", "404", "repository", "does not exist")):
+                self._ful_error = "not trained yet — run notebooks/train_fula_tts.ipynb on Kaggle"
             else:
                 self._ful_error = msg
             logger.warning("WaxalTTS: Fula TTS unavailable: %s", self._ful_error)
@@ -176,7 +190,6 @@ class WaxalTTSEngine:
             self._load_fula()
         if not self._ful_ready:
             return None
-
         try:
             import torch
             with self._lock:
@@ -185,10 +198,7 @@ class WaxalTTSEngine:
             output = self._ful_model(**inputs)
             audio = output.waveform[0].cpu().numpy().astype(np.float32)
             sr = self._ful_model.config.sampling_rate
-
-            logger.debug("WaxalTTS: Fula synthesised %d samples @ %dHz", len(audio), sr)
             return audio, sr
-
         except Exception as exc:
             logger.error("WaxalTTS: Fula synthesis failed: %s", exc)
             return None
@@ -197,6 +207,6 @@ class WaxalTTSEngine:
 
     @staticmethod
     def audio_to_gradio(audio: np.ndarray, sr: int) -> tuple[int, np.ndarray]:
-        """Convert float32 array → int16 tuple that gr.Audio expects."""
+        """Convert float32 → int16 tuple that gr.Audio expects."""
         pcm = (audio * 32767).clip(-32768, 32767).astype(np.int16)
         return sr, pcm
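
End-to-end, the engine is meant to be used roughly like this. A hedged usage sketch, not part of this commit (the gr.Interface wiring is illustrative; the Space's real UI lives in app_lab.py, and the import path assumes the repo root is on sys.path):

    import gradio as gr
    from src.tts.waxal_tts import WaxalTTSEngine

    engine = WaxalTTSEngine()
    engine.preload()  # background threads warm up Bambara + Fula weights

    def speak(text: str):
        result = engine.synthesize(text, lang="bam")
        if result is None:  # TTS unavailable: UI falls back to text only
            return None
        audio, sr = result
        return WaxalTTSEngine.audio_to_gradio(audio, sr)  # (sr, int16) for gr.Audio

    gr.Interface(fn=speak, inputs="text", outputs=gr.Audio()).launch()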