Commit 3657607 (parent 082adaa) — committed by jefffffff9 with Claude Sonnet 4.6

Add confidence loop, curiosity engine, and lightweight TTS


Task 1 — TTS refactor (src/tts/waxal_tts.py):
Switch Bambara TTS from the Qwen2-based MALIBA-AI model to
ynnov/ekodi-bambara-tts-female (VitsModel + AutoTokenizer) — much
lighter on the CPU Basic tier, and no trust_remote_code needed. Fula
is an explicit generate_pular_tts() placeholder that returns None
until the model is trained.
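
For reference, the engine's public surface is unchanged; a minimal usage
sketch (the Bambara sample text is illustrative):

    from src.tts.waxal_tts import WaxalTTSEngine

    tts = WaxalTTSEngine()
    tts.preload()                              # loads the Bambara VITS model in a daemon thread
    result = tts.synthesize("I ni ce", "bam")  # None while loading or on any error
    if result is not None:
        audio, sr = result                     # float32 waveform, model sampling rate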

Task 2 — Active learning / confidence loop (src/engine/stt_processor.py):
transcribe_with_confidence() wraps Whisper's generate() with
output_scores=True and computes avg_logprob via
compute_transition_scores(). If avg_logprob < -1.0, app_lab.py
replaces the transcript with CONFUSION_PROMPT so the LLM asks the
user in English to repeat and explain the word.
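
In sketch form, the wiring in app_lab.py (audio_np, the model/processor
pair, and forced_ids are prepared exactly as in the diff below):

    from src.engine.stt_processor import (
        transcribe_with_confidence, LOW_CONFIDENCE_THRESHOLD, CONFUSION_PROMPT,
    )

    text, avg_logprob = transcribe_with_confidence(audio_np, model, processor, forced_ids)
    if avg_logprob < LOW_CONFIDENCE_THRESHOLD:   # below -1.0: Whisper was guessing
        text = CONFUSION_PROMPT                  # LLM now asks the user to repeat and explain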

Task 3 — Proactive gaps (src/engine/curiosity.py):
CuriosityEngine.maybe_ask() fires every 5 interactions: it sends the
last 10 vocabulary entries to Qwen and appends a 🌱 question to the
chat asking the user to teach a missing agricultural term.
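
The LLM is instructed to reply with strict JSON; maybe_ask() also tolerates
clients that return a plain "response" field. Illustrative round-trip (the
dict values are invented; the .get() fallbacks match the new module):

    result = {"word_suggestion": "fertilizer",
              "question": "How do you say 'fertilizer' in your language?"}
    question = result.get("question") or result.get("response")
    if question:
        history.append({"role": "assistant", "content": f"🌱 {question.strip()}"})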

Task 4 — Zero-cost persistence:
MemoryManager._push_to_hub() was already asynchronous (background
thread + HfApi.upload_file), so no changes were needed.
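
For context, a hedged sketch of the pattern _push_to_hub() already follows.
The real MemoryManager is not part of this diff, so the attribute names
below (_hf_token, _local_path, _repo_id) are illustrative assumptions:

    import threading
    from huggingface_hub import HfApi

    def _push_to_hub(self) -> None:
        def _worker() -> None:
            # Fire-and-forget upload so the Gradio handler never blocks
            HfApi(token=self._hf_token).upload_file(
                path_or_fileobj=self._local_path,  # assumed local vocab file
                path_in_repo="vocabulary.json",    # assumed path in the dataset repo
                repo_id=self._repo_id,
                repo_type="dataset",
            )

        threading.Thread(target=_worker, daemon=True).start()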

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

app_lab.py CHANGED
@@ -46,13 +46,20 @@ LANGUAGE_NAMES = {
 }

 # ── Singletons ────────────────────────────────────────────────────────────────
-from src.memory.memory_manager import MemoryManager
-from src.llm.gemma_client import GemmaClient
-from src.tts.waxal_tts import WaxalTTSEngine
-
-_memory = MemoryManager(repo_id=FEEDBACK_REPO_ID, hf_token=HF_TOKEN)
-_gemma = GemmaClient(model_id=LLM_MODEL_ID, hf_token=HF_TOKEN)
-_tts = WaxalTTSEngine()
+from src.memory.memory_manager import MemoryManager
+from src.llm.gemma_client import GemmaClient
+from src.tts.waxal_tts import WaxalTTSEngine
+from src.engine.stt_processor import (
+    transcribe_with_confidence,
+    LOW_CONFIDENCE_THRESHOLD,
+    CONFUSION_PROMPT,
+)
+from src.engine.curiosity import CuriosityEngine
+
+_memory = MemoryManager(repo_id=FEEDBACK_REPO_ID, hf_token=HF_TOKEN)
+_gemma = GemmaClient(model_id=LLM_MODEL_ID, hf_token=HF_TOKEN)
+_tts = WaxalTTSEngine()
+_curiosity = CuriosityEngine(interval=5)

 # Whisper — loaded lazily in background
 _whisper_model = None
@@ -103,38 +110,36 @@ def _whisper_status_label() -> str:
     return f"⚪ STT {s}"


-def _transcribe(audio_path: str, language_hint: str) -> str:
-    """Run Whisper STT. Returns transcribed text."""
+def _transcribe(audio_path: str, language_hint: str) -> tuple[str, float]:
+    """
+    Run Whisper STT with confidence scoring.
+    Returns (text, avg_logprob). avg_logprob < LOW_CONFIDENCE_THRESHOLD → confused.
+    """
     if _whisper_model is None:
-        return ""
-    import torch, librosa
+        return "", 0.0
+    import librosa
     audio_np, _ = librosa.load(audio_path, sr=16_000, mono=True)

+    # Whisper has no Bambara/Fula tokens — skip forced language for those
+    if language_hint in ("bam", "ful"):
+        forced_ids = None
+    else:
+        try:
+            forced_ids = _whisper_processor.get_decoder_prompt_ids(
+                language=language_hint, task="transcribe"
+            )
+        except Exception:
+            forced_ids = None
+
     with _whisper_lock:
-        inputs = _whisper_processor.feature_extractor(
-            audio_np, sampling_rate=16_000, return_tensors="pt"
+        text, avg_logprob = transcribe_with_confidence(
+            audio_np,
+            _whisper_model,
+            _whisper_processor,
+            forced_ids,
         )
-        input_features = inputs.input_features

-    # Whisper doesn't have Bambara / Fula tokens — let it auto-detect
-    if language_hint in ("bam", "ful"):
-        forced_ids = None
-    else:
-        try:
-            forced_ids = _whisper_processor.get_decoder_prompt_ids(
-                language=language_hint, task="transcribe"
-            )
-        except Exception:
-            forced_ids = None
-
-    with torch.no_grad():
-        predicted_ids = _whisper_model.generate(
-            input_features,
-            forced_decoder_ids=forced_ids,
-            max_new_tokens=256,
-        )
-
-    return _whisper_processor.batch_decode(predicted_ids, skip_special_tokens=True)[0].strip()
+    return text, avg_logprob


 # ── Core pipeline ─────────────────────────────────────────────────────────────
@@ -175,6 +180,11 @@ def _run_llm_and_tts(
     history.append({"role": "user", "content": f"[{LANGUAGE_NAMES.get(lang_code, lang_code)}] {transcript}"})
     history.append({"role": "assistant", "content": response})

+    # 5. Curiosity check — every 5 interactions, ask about a vocabulary gap
+    curiosity_q = _curiosity.maybe_ask(_memory, _gemma)
+    if curiosity_q:
+        history.append({"role": "assistant", "content": f"🌱 {curiosity_q}"})
+
     tts_status = "" if audio_out else " (TTS not available for this language yet)"
     status_msg = {
         "teaching": f"✅ Word learned and saved!{tts_status}",
@@ -201,10 +211,18 @@ def process_audio(audio_path, language_label: str, history: list) -> tuple:
         if _whisper_model is None:
             return history, _render_recent_words(), f"⏳ {status} — wait a moment and try again.", None

-        transcript = _transcribe(audio_path, lang_code)
+        transcript, avg_logprob = _transcribe(audio_path, lang_code)
         if not transcript:
             return history, _render_recent_words(), "⚠️ Could not transcribe audio.", None

+        # Low-confidence transcription → ask user to repeat and explain
+        if avg_logprob < LOW_CONFIDENCE_THRESHOLD:
+            logger.info(
+                "Low STT confidence (avg_logprob=%.3f) — switching to confusion prompt",
+                avg_logprob,
+            )
+            transcript = CONFUSION_PROMPT
+
         return _run_llm_and_tts(transcript, lang_code, history, "voice")
     except Exception as exc:
         logger.exception("process_audio error")
src/engine/curiosity.py ADDED
@@ -0,0 +1,102 @@
+"""
+CuriosityEngine — proactive vocabulary gap analysis.
+
+Every N interactions (default: 5), sends the last 10 vocabulary entries to
+the LLM and asks it to identify one related agricultural / everyday term that
+is missing from the learner's vocabulary, then formulate a question asking the
+user how to say that word in their language.
+
+Usage in app_lab.py:
+    _curiosity = CuriosityEngine(interval=5)
+
+    # Inside _run_llm_and_tts, after the main LLM call:
+    question = _curiosity.maybe_ask(_memory, _gemma)
+    if question:
+        history.append({"role": "assistant", "content": f"🌱 {question}"})
+"""
+from __future__ import annotations
+
+import logging
+from typing import TYPE_CHECKING, Optional
+
+if TYPE_CHECKING:
+    from src.memory.memory_manager import MemoryManager
+    from src.llm.gemma_client import GemmaClient
+
+logger = logging.getLogger(__name__)
+
+_CURIOSITY_SYSTEM = """\
+You are a language-learning assistant that notices gaps in a West African vocabulary list.
+Reply with a single valid JSON object and nothing else.\
+"""
+
+_CURIOSITY_USER_TEMPLATE = """\
+Here are the {n} most recent words I have learned so far:
+{vocab_list}
+
+Based on these words, what is ONE related agricultural or common everyday term \
+I am likely missing? Formulate a short, warm question asking the user how to say \
+that missing word in their language.
+
+Reply only with this JSON:
+{{
+  "word_suggestion": "<the English word you think is missing>",
+  "question": "<one friendly sentence asking the user>"
+}}
+"""
+
+
+class CuriosityEngine:
+    """Fires a vocabulary-gap prompt every `interval` user interactions."""
+
+    def __init__(self, interval: int = 5) -> None:
+        self._interval = interval
+        self._interaction = 0
+
+    def maybe_ask(
+        self,
+        memory: "MemoryManager",
+        gemma: "GemmaClient",
+    ) -> Optional[str]:
+        """
+        Increment the interaction counter. On every `interval`-th call, query
+        the LLM for a missing vocabulary term and return the question string.
+        Returns None on all other calls, or if vocabulary is too sparse, or if
+        the LLM call fails.
+        """
+        self._interaction += 1
+        if self._interaction % self._interval != 0:
+            return None
+
+        entries = memory.get_all()
+        if len(entries) < 3:
+            logger.debug("CuriosityEngine: vocabulary too sparse (%d entries)", len(entries))
+            return None
+
+        recent = entries[-10:]
+        lines = [
+            f"  [{e.get('language','?')}] {e.get('word','')} = {e.get('translation','')}"
+            for e in recent
+        ]
+        prompt = _CURIOSITY_USER_TEMPLATE.format(
+            n=len(lines),
+            vocab_list="\n".join(lines),
+        )
+
+        try:
+            # Pass the curiosity prompt as user text; empty vocab context to avoid
+            # duplicating the word list inside the system prompt.
+            result = gemma.chat(prompt, vocabulary_context="(see above)")
+            question = result.get("question") or result.get("response")
+            if question:
+                word = result.get("word_suggestion", "")
+                logger.info(
+                    "CuriosityEngine: suggesting '%s' — %s",
+                    word,
+                    question[:80],
+                )
+                return question.strip()
+        except Exception as exc:
+            logger.warning("CuriosityEngine: LLM call failed: %s", exc)
+
+        return None
src/engine/stt_processor.py ADDED
@@ -0,0 +1,88 @@
+"""
+STT confidence extractor.
+
+Wraps Whisper's generate() with return_dict_in_generate=True to compute
+avg_logprob — the mean log-probability over generated tokens. This mirrors
+the avg_logprob field returned by the OpenAI Whisper API.
+
+Threshold: avg_logprob < -1.0 signals a low-confidence transcription where
+the model was essentially guessing. The caller should treat this as "confused"
+and prompt the user to repeat and explain the word.
+"""
+from __future__ import annotations
+
+import logging
+
+import numpy as np
+import torch
+
+logger = logging.getLogger(__name__)
+
+# Anything below this is considered "confused" transcription
+LOW_CONFIDENCE_THRESHOLD: float = -1.0
+
+# Message substituted for the transcript when confidence is low
+CONFUSION_PROMPT: str = (
+    "The user spoke, but I am confused. "
+    "Ask the user in English to repeat the local word and explain its meaning."
+)
+
+
+def transcribe_with_confidence(
+    audio_np: np.ndarray,
+    model,
+    processor,
+    forced_ids,
+    max_new_tokens: int = 256,
+) -> tuple[str, float]:
+    """
+    Run Whisper and return (text, avg_logprob).
+
+    avg_logprob is in (-inf, 0]. A value close to 0 means high confidence.
+    Returns avg_logprob = 0.0 if computation fails (treated as confident).
+
+    Args:
+        audio_np: float32 audio at 16 kHz.
+        model: WhisperForConditionalGeneration instance.
+        processor: WhisperProcessor instance.
+        forced_ids: Output of get_decoder_prompt_ids() or None.
+        max_new_tokens: Maximum tokens to generate.
+    """
+    inputs = processor.feature_extractor(
+        audio_np, sampling_rate=16_000, return_tensors="pt"
+    )
+    input_features = inputs.input_features
+
+    with torch.no_grad():
+        output = model.generate(
+            input_features,
+            forced_decoder_ids=forced_ids,
+            max_new_tokens=max_new_tokens,
+            return_dict_in_generate=True,
+            output_scores=True,
+        )
+
+    text = processor.batch_decode(output.sequences, skip_special_tokens=True)[0].strip()
+
+    # Compute avg log-prob via model.compute_transition_scores
+    avg_logprob = 0.0
+    try:
+        transition_scores = model.compute_transition_scores(
+            output.sequences,
+            output.scores,
+            normalize_logits=True,
+        )
+        # Shape: (batch, generated_len). Take batch[0], skip zero-padded positions.
+        scores = transition_scores[0]
+        valid = scores[scores != 0]
+        if valid.numel() > 0:
+            avg_logprob = valid.mean().item()
+    except Exception as exc:
+        logger.debug("avg_logprob computation failed: %s", exc)
+
+    logger.debug(
+        "STT confidence: avg_logprob=%.3f text=%r",
+        avg_logprob,
+        text[:60],
+    )
+    return text, avg_logprob
src/tts/waxal_tts.py CHANGED
@@ -1,23 +1,13 @@
 """
-WaxalTTSEngine — Phase 2 TTS for Sahel-Voice-Lab.
-
-Bambara : MALIBA-AI/bambara-tts loaded directly via transformers
-          (avoids pip-installing maliba-ai at runtime, which fails because
-          HF Spaces blocks outbound GitHub connections)
-Fula    : ous-sow/fula-tts (VITS, trained via notebooks/train_fula_tts.ipynb)
-French/English : not yet integrated — returns None (text-only fallback)
-
-Architecture:
-    MALIBA-AI uses a Qwen2-based architecture. We load it with
-    AutoModelForCausalLM + AutoTokenizer, run greedy decoding, and extract
-    the waveform from the model output — matching what BambaraTTSInference does
-    internally without needing the package installed.
+WaxalTTSEngine — lightweight VITS-based TTS for Sahel-Voice-Lab.
+
+Bambara : ynnov/ekodi-bambara-tts-female (VitsModel + AutoTokenizer)
+Fula    : placeholder returns None until ous-sow/fula-tts is trained
 """
 from __future__ import annotations

 import logging
 import os
-import tempfile
 import threading
 from typing import Optional

@@ -25,13 +15,23 @@ import numpy as np

 logger = logging.getLogger(__name__)

-BAMBARA_TTS_REPO = "MALIBA-AI/bambara-tts"
-FULA_TTS_REPO = os.environ.get("FULA_TTS_REPO", "ous-sow/fula-tts")
+BAMBARA_TTS_REPO = os.environ.get("BAMBARA_TTS_REPO", "ynnov/ekodi-bambara-tts-female")
+FULA_TTS_REPO = os.environ.get("FULA_TTS_REPO", "ous-sow/fula-tts")
 HF_TOKEN = os.environ.get("HF_TOKEN")


+def generate_pular_tts(text: str) -> None:
+    """
+    Placeholder for Fula (Pulaar) TTS.
+    Returns None until ous-sow/fula-tts is trained and pushed to the Hub.
+    Run notebooks/train_fula_tts.ipynb on Kaggle T4 to produce the model.
+    """
+    logger.info("generate_pular_tts: model not yet trained — returning None")
+    return None
+
+
 class WaxalTTSEngine:
-    """Unified TTS engine for Bambara and Fula."""
+    """Unified TTS engine: Bambara (VITS) + Fula (placeholder)."""

     def __init__(self) -> None:
         self._lock = threading.Lock()
@@ -40,18 +40,13 @@ class WaxalTTSEngine:
         self._bam_tokenizer = None
         self._bam_ready = False
         self._bam_error: Optional[str] = None
-        # Fula
-        self._ful_model = None
-        self._ful_tokenizer = None
-        self._ful_ready = False
-        self._ful_error: Optional[str] = None

     # ── Public API ────────────────────────────────────────────────────────────

     def synthesize(self, text: str, lang: str) -> Optional[tuple[np.ndarray, int]]:
         """
         Returns (audio_float32, sample_rate) or None if TTS unavailable.
-        Never raises — all errors are logged and None is returned.
+        Never raises — all errors are logged.
         """
         text = text.strip()
         if not text:
@@ -60,7 +55,7 @@ class WaxalTTSEngine:
             if lang == "bam":
                 return self._synthesize_bambara(text)
             elif lang == "ful":
-                return self._synthesize_fula(text)
+                return generate_pular_tts(text)
             else:
                 return None
         except Exception as exc:
@@ -71,45 +66,26 @@ class WaxalTTSEngine:
         bam = "ready" if self._bam_ready else (
             f"error: {self._bam_error}" if self._bam_error else "loading…"
         )
-        ful = "ready" if self._ful_ready else (
-            f"error: {self._ful_error}" if self._ful_error else "not trained yet"
-        )
-        return {"bam": bam, "ful": ful}
+        return {"bam": bam, "ful": "not trained yet"}

     def preload(self) -> None:
-        """Start background threads to load both models."""
+        """Start background thread to load the Bambara model."""
         threading.Thread(target=self._load_bambara, daemon=True).start()
-        threading.Thread(target=self._load_fula, daemon=True).start()

-    # ── Bambara (MALIBA-AI/bambara-tts via AutoModel) ─────────────────────────
+    # ── Bambara (ynnov/ekodi-bambara-tts-female, VITS) ───────────────────────

     def _load_bambara(self) -> None:
-        """
-        Load MALIBA-AI/bambara-tts directly from HF Hub using transformers.
-        No pip install needed — just model weights downloaded to the HF cache.
-        """
         try:
-            from transformers import AutoTokenizer, AutoModelForCausalLM
-            import torch
-
+            from transformers import VitsModel, AutoTokenizer
             logger.info("WaxalTTS: loading Bambara TTS from %s …", BAMBARA_TTS_REPO)
-            tok = AutoTokenizer.from_pretrained(
-                BAMBARA_TTS_REPO, token=HF_TOKEN, trust_remote_code=True
-            )
-            mdl = AutoModelForCausalLM.from_pretrained(
-                BAMBARA_TTS_REPO,
-                token=HF_TOKEN,
-                trust_remote_code=True,
-                torch_dtype=torch.float32,
-            )
+            tok = AutoTokenizer.from_pretrained(BAMBARA_TTS_REPO, token=HF_TOKEN)
+            mdl = VitsModel.from_pretrained(BAMBARA_TTS_REPO, token=HF_TOKEN)
             mdl.eval()
-
             with self._lock:
                 self._bam_tokenizer = tok
                 self._bam_model = mdl
                 self._bam_ready = True
             logger.info("WaxalTTS: Bambara TTS ready")
-
         except Exception as exc:
             self._bam_error = str(exc)
             logger.error("WaxalTTS: Bambara TTS load failed: %s", exc)
@@ -118,89 +94,21 @@ class WaxalTTSEngine:
         if not self._bam_ready:
             self._load_bambara()
         if not self._bam_ready:
-            logger.warning("WaxalTTS: Bambara TTS not ready (%s)", self._bam_error)
-            return None
-
-        try:
-            import torch, soundfile as sf
-
-            with self._lock:
-                inputs = self._bam_tokenizer(
-                    text, return_tensors="pt", add_special_tokens=True
-                )
-                with torch.no_grad():
-                    output = self._bam_model.generate(
-                        **inputs,
-                        max_new_tokens=1024,
-                        do_sample=False,
-                    )
-
-            # MALIBA-AI model returns waveform tokens — decode to audio
-            # The model's generate() returns a waveform directly when it has
-            # an audio head; try standard attribute paths.
-            audio = None
-            sr = 16_000
-
-            if hasattr(output, "waveform"):
-                audio = output.waveform[0].cpu().float().numpy()
-            elif hasattr(output, "audio"):
-                audio = output.audio[0].cpu().float().numpy()
-            else:
-                # Fallback: treat output as token ids and use vocoder if present
-                logger.warning(
-                    "WaxalTTS: Bambara model output type %s — expected waveform attribute",
-                    type(output)
-                )
-                return None
-
-            if audio.ndim > 1:
-                audio = audio.mean(axis=1)
-            return audio.astype(np.float32), sr
-
-        except Exception as exc:
-            logger.error("WaxalTTS: Bambara synthesis failed: %s", exc)
-            self._bam_error = str(exc)
-            self._bam_ready = False
-            return None
-
-    # ── Fula (ous-sow/fula-tts, VITS) ────────────────────────────────────────
-
-    def _load_fula(self) -> None:
-        try:
-            from transformers import VitsModel, VitsTokenizer
-            logger.info("WaxalTTS: loading Fula TTS from %s …", FULA_TTS_REPO)
-            tok = VitsTokenizer.from_pretrained(FULA_TTS_REPO, token=HF_TOKEN)
-            mdl = VitsModel.from_pretrained(FULA_TTS_REPO, token=HF_TOKEN)
-            mdl.eval()
-            with self._lock:
-                self._ful_tokenizer = tok
-                self._ful_model = mdl
-                self._ful_ready = True
-            logger.info("WaxalTTS: Fula TTS ready")
-        except Exception as exc:
-            msg = str(exc)
-            if any(k in msg.lower() for k in ("not found", "404", "repository", "does not exist")):
-                self._ful_error = "not trained yet — run notebooks/train_fula_tts.ipynb on Kaggle"
-            else:
-                self._ful_error = msg
-            logger.warning("WaxalTTS: Fula TTS unavailable: %s", self._ful_error)
-
-    def _synthesize_fula(self, text: str) -> Optional[tuple[np.ndarray, int]]:
-        if not self._ful_ready:
-            self._load_fula()
-        if not self._ful_ready:
+            logger.warning("WaxalTTS: Bambara TTS not ready %s", self._bam_error)
             return None
         try:
             import torch
             with self._lock:
-                inputs = self._ful_tokenizer(text, return_tensors="pt")
+                inputs = self._bam_tokenizer(text, return_tensors="pt")
                 with torch.no_grad():
-                    output = self._ful_model(**inputs)
+                    output = self._bam_model(**inputs)
                 audio = output.waveform[0].cpu().numpy().astype(np.float32)
-                sr = self._ful_model.config.sampling_rate
+                sr = self._bam_model.config.sampling_rate
             return audio, sr
         except Exception as exc:
-            logger.error("WaxalTTS: Fula synthesis failed: %s", exc)
+            logger.error("WaxalTTS: Bambara synthesis failed: %s", exc)
+            self._bam_error = str(exc)
+            self._bam_ready = False
             return None

     # ── Utility ───────────────────────────────────────────────────────────────