jefffffff9 and Claude Sonnet 4.6 committed
Commit 49910a9 · Parent: 3657607

Add multi-user speaker profiles, collective voice, and mode toggle

Task 1 — SpeakerProfileManager (src/voice/speaker_profiles.py):
SpeechBrain ECAPA-TDNN identifies speakers by cosine similarity of
192-d embeddings (threshold 0.75). Each user gets user_N_sb.npy
(running-average SpeechBrain embedding) and user_N_ov.npy
(running-average OpenVoice V2 tone-color SE). New utterances update
the running average so profiles sharpen over time.
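
A minimal standalone sketch of the matching rule (mirrors the helpers that
ship in speaker_profiles.py below; illustrative, not the full class):

    import numpy as np

    def cosine(a: np.ndarray, b: np.ndarray) -> float:
        # Cosine similarity, guarded against zero-norm embeddings.
        denom = np.linalg.norm(a) * np.linalg.norm(b)
        return float(np.dot(a, b) / denom) if denom > 1e-8 else 0.0

    def running_avg(old: np.ndarray, new: np.ndarray, count: int) -> np.ndarray:
        # Incremental mean: after n updates this equals the mean of all n samples.
        alpha = 1.0 / (count + 1)
        return (1.0 - alpha) * old + alpha * new

    # A fresh 192-d ECAPA embedding joins the best-matching profile when
    # similarity >= 0.75; otherwise it seeds a new user_N profile.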

Task 2 — get_collective_embedding() (speaker_profiles.py):
Loads all user_N_ov.npy files and returns their mean vector — the
"Collective Voice" that blends all known speakers into one SE.

Task 3 — Voice Mode toggle (app_lab.py + src/tts/voice_cloner.py):
Added Individual / Collective radio to the UI. In Individual mode
the TTS output is cloned into the last detected speaker's voice via
OpenVoice V2 ToneColorConverter. In Collective mode the mean SE is
used. Text-input path falls back to the collective SE. Both modes
degrade gracefully to base VITS voice if OpenVoice is not ready.
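
In sketch form, the dispatch that picks the target SE (mirrors process_audio
below; uid is None when SpeechBrain is unavailable):

    if voice_mode == "Individual" and uid is not None:
        active_se = _speaker_profiles.get_openvoice_se(uid)
    else:
        active_se = _speaker_profiles.get_collective_embedding()
    # active_se of None means the TTS output passes through as the base voice.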

requirements.txt: add speechbrain>=0.5.15 and openvoice (git install
at Docker build time — HF Spaces blocks GitHub only at runtime).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

app_lab.py CHANGED

@@ -49,6 +49,8 @@ LANGUAGE_NAMES = {
 from src.memory.memory_manager import MemoryManager
 from src.llm.gemma_client import GemmaClient
 from src.tts.waxal_tts import WaxalTTSEngine
+from src.tts.voice_cloner import VoiceCloner
+from src.voice.speaker_profiles import SpeakerProfileManager
 from src.engine.stt_processor import (
     transcribe_with_confidence,
     LOW_CONFIDENCE_THRESHOLD,
@@ -56,10 +58,12 @@ from src.engine.stt_processor import (
 )
 from src.engine.curiosity import CuriosityEngine
 
-_memory = MemoryManager(repo_id=FEEDBACK_REPO_ID, hf_token=HF_TOKEN)
-_gemma = GemmaClient(model_id=LLM_MODEL_ID, hf_token=HF_TOKEN)
-_tts = WaxalTTSEngine()
-_curiosity = CuriosityEngine(interval=5)
+_memory = MemoryManager(repo_id=FEEDBACK_REPO_ID, hf_token=HF_TOKEN)
+_gemma = GemmaClient(model_id=LLM_MODEL_ID, hf_token=HF_TOKEN)
+_tts = WaxalTTSEngine()
+_voice_cloner = VoiceCloner()
+_speaker_profiles = SpeakerProfileManager()
+_curiosity = CuriosityEngine(interval=5)
 
 # Whisper — loaded lazily in background
 _whisper_model = None
@@ -149,10 +153,14 @@ def _run_llm_and_tts(
     lang_code: str,
     history: list,
     source_label: str,
+    active_se=None,
 ) -> tuple:
     """
-    Shared core: Gemma → memory update → TTS.
+    Shared core: Gemma → memory update → TTS → optional voice cloning.
     Returns: (history, recent_words_md, status_msg, audio_tuple_or_None)
+
+    active_se: OpenVoice V2 tone-color SE (numpy array) to clone into, or None
+    for the base VITS voice.
     """
     # 1. Ask Gemma (with vocabulary context)
     vocab_ctx = _memory.get_vocabulary_context()
@@ -169,11 +177,16 @@ def _run_llm_and_tts(
     if word and trans:
         _memory.add_word_pair(word, lang, trans, trans_l, source="user_taught")
 
-    # 3. TTS speak the response if language supported
+    # 3. TTS + optional voice cloning
    audio_out = None
     tts_result = _tts.synthesize(response, lang_code)
     if tts_result is not None:
-        audio_out = WaxalTTSEngine.audio_to_gradio(*tts_result)
+        audio_np, sr = tts_result
+        if active_se is not None:
+            cloned = _voice_cloner.convert(audio_np, sr, active_se)
+            if cloned is not None:
+                audio_np, sr = cloned
+        audio_out = WaxalTTSEngine.audio_to_gradio(audio_np, sr)
 
     # 4. Update chat history
     history = list(history or [])
@@ -196,9 +209,14 @@ def _run_llm_and_tts(
     return history, _render_recent_words(), status_msg, audio_out
 
 
-def process_audio(audio_path, language_label: str, history: list) -> tuple:
+def process_audio(
+    audio_path,
+    language_label: str,
+    voice_mode: str,
+    history: list,
+) -> tuple:
     """
-    Full pipeline: audio → Whisper STT → Gemma → TTS.
+    Full pipeline: audio → speaker ID → Whisper STT → Gemma → TTS → voice clone.
     Returns: (history, recent_words_md, status_msg, audio_out)
     """
     try:
@@ -211,11 +229,30 @@ def process_audio(audio_path, language_label: str, history: list) -> tuple:
         if _whisper_model is None:
             return history, _render_recent_words(), f"⏳ {status} — wait a moment and try again.", None
 
+        # Load audio once — used for both speaker ID and STT
+        import librosa
+        audio_np, _ = librosa.load(audio_path, sr=16_000, mono=True)
+
+        # ── Speaker identification (Task 1) ───────────────────────────────────
+        uid, _ = _speaker_profiles.identify_or_create(audio_np)
+
+        # Extract OpenVoice SE and update the user's profile
+        if uid is not None:
+            ov_se = _voice_cloner.extract_se(audio_np, 16_000)
+            if ov_se is not None:
+                _speaker_profiles.update_ov_embedding(uid, ov_se)
+
+        # ── Select target SE based on mode (Task 3) ───────────────────────────
+        if voice_mode == "Individual" and uid is not None:
+            active_se = _speaker_profiles.get_openvoice_se(uid)
+        else:
+            active_se = _speaker_profiles.get_collective_embedding()
+
+        # ── Transcription with confidence scoring ─────────────────────────────
         transcript, avg_logprob = _transcribe(audio_path, lang_code)
         if not transcript:
             return history, _render_recent_words(), "⚠️ Could not transcribe audio.", None
 
-        # Low-confidence transcription → ask user to repeat and explain
         if avg_logprob < LOW_CONFIDENCE_THRESHOLD:
             logger.info(
                 "Low STT confidence (avg_logprob=%.3f) — switching to confusion prompt",
@@ -223,20 +260,24 @@ def process_audio(audio_path, language_label: str, history: list) -> tuple:
             )
             transcript = CONFUSION_PROMPT
 
-        return _run_llm_and_tts(transcript, lang_code, history, "voice")
+        return _run_llm_and_tts(transcript, lang_code, history, "voice", active_se)
     except Exception as exc:
         logger.exception("process_audio error")
         return history, _render_recent_words(), f"❌ Error: {exc}", None
 
 
-def process_text(text: str, language_label: str, history: list) -> tuple:
-    """Text input path — Gemma → TTS. Returns: (history, recent_words_md, status_msg, audio_out)"""
+def process_text(text: str, language_label: str, voice_mode: str, history: list) -> tuple:
+    """Text input path — Gemma → TTS → optional voice clone."""
     try:
         if not text.strip():
             return history, _render_recent_words(), "⚠️ Please type something.", None
 
         lang_code = _label_to_code(language_label)
-        return _run_llm_and_tts(text.strip(), lang_code, history, "text")
+
+        # Text has no speaker signal — use Collective in both modes as fallback
+        active_se = _speaker_profiles.get_collective_embedding()
+
+        return _run_llm_and_tts(text.strip(), lang_code, history, "text", active_se)
     except Exception as exc:
         logger.exception("process_text error")
         return history, _render_recent_words(), f"❌ Error: {exc}", None
@@ -289,13 +330,16 @@ def build_ui() -> gr.Blocks:
         tts = _tts.get_status()
         bam = "🟢" if tts["bam"] == "ready" else ("🟡" if "not" in tts["bam"] else "🔴")
         ful = "🟢" if tts["ful"] == "ready" else ("🟡" if "not" in tts["ful"] else "🔴")
-        return f"{stt} | TTS Bambara {bam} | TTS Fula {ful}"
+        spk = _speaker_profiles.get_status()
+        cln = "🟢 Cloner" if _voice_cloner._ready else (
+            "🔴 Cloner" if _voice_cloner._error else "🟡 Cloner")
+        return f"{stt} | TTS Bambara {bam} | TTS Fula {ful}\n{spk} | {cln}"
 
     status_box = gr.Textbox(
         value=_full_status(),
         label="System status",
         interactive=False,
-        max_lines=1,
+        max_lines=2,
     )
     status_timer = gr.Timer(value=4)
     status_timer.tick(fn=_full_status, outputs=status_box)
@@ -306,6 +350,16 @@ def build_ui() -> gr.Blocks:
         label="Language you are speaking",
     )
 
+    voice_mode_radio = gr.Radio(
+        choices=["Individual", "Collective"],
+        value="Individual",
+        label="Voice Mode",
+        info=(
+            "Individual — respond in the voice of the last speaker detected. "
+            "Collective — blend all known voices into one shared voice."
+        ),
+    )
+
     with gr.Tab("🎙️ Push-to-Talk"):
         audio_input = gr.Audio(
             sources=["microphone"],
@@ -370,7 +424,7 @@ def build_ui() -> gr.Blocks:
 
     talk_btn.click(
         fn=process_audio,
-        inputs=[audio_input, language_dd, history_state],
+        inputs=[audio_input, language_dd, voice_mode_radio, history_state],
         outputs=[history_state, recent_words, action_status, audio_output],
     ).then(
         fn=lambda h: h,
@@ -380,7 +434,7 @@ def build_ui() -> gr.Blocks:
 
     text_btn.click(
         fn=process_text,
-        inputs=[text_input, language_dd, history_state],
+        inputs=[text_input, language_dd, voice_mode_radio, history_state],
        outputs=[history_state, recent_words, action_status, audio_output],
     ).then(
         fn=lambda h: (h, ""),
@@ -390,7 +444,7 @@ def build_ui() -> gr.Blocks:
 
     text_input.submit(
         fn=process_text,
-        inputs=[text_input, language_dd, history_state],
+        inputs=[text_input, language_dd, voice_mode_radio, history_state],
         outputs=[history_state, recent_words, action_status, audio_output],
     ).then(
         fn=lambda h: (h, ""),
@@ -414,6 +468,10 @@ threading.Thread(target=_memory.load, daemon=True).start()
 _ensure_whisper()
 # Preload TTS models in background
 _tts.preload()
+# Preload speaker identification (SpeechBrain ECAPA-TDNN)
+_speaker_profiles.preload()
+# Preload voice cloner (OpenVoice V2) — gracefully degrades if unavailable
+_voice_cloner.preload()
 
 if __name__ == "__main__":
     from dotenv import load_dotenv
requirements.txt CHANGED

@@ -52,5 +52,12 @@ scipy==1.15.2
 # Phrase matching (fuzzy match for Whisper mis-transcriptions of Bambara/Fula)
 rapidfuzz==3.13.0
 
+# Speaker identification (ECAPA-TDNN 192-d embeddings, used by SpeakerProfileManager)
+speechbrain>=0.5.15
+
+# Voice cloning (OpenVoice V2 ToneColorConverter — installed at build time from GitHub)
+# HF Spaces blocks GitHub at *runtime* but allows it during Docker build via requirements.txt
+openvoice @ git+https://github.com/myshell-ai/OpenVoice.git
+
 # maliba-ai is NOT listed here — it has strict conflicting pins (librosa, soundfile).
 # It is installed lazily at runtime on first Bambara TTS call (see src/tts/waxal_tts.py).
src/tts/voice_cloner.py ADDED

@@ -0,0 +1,223 @@
+"""
+VoiceCloner — OpenVoice V2 tone-color converter.
+
+Sits downstream of WaxalTTSEngine: takes the base VITS audio and reshapes it
+to match a target speaker's tone color.
+
+Usage:
+    cloner = VoiceCloner()
+    cloner.preload()  # background thread
+
+    # After WaxalTTS produces (audio_np, sr) …
+    se = cloner.extract_se(audio_np, sr)       # extract SE from user's mic audio
+    result = cloner.convert(audio_np, sr, se)  # returns (cloned_audio, sr) or None
+
+The OpenVoice V2 checkpoint is downloaded from myshell-ai/openvoice-v2 on
+HuggingFace Hub at first use (cached in data/openvoice_v2/).
+
+Falls back gracefully (returns None) if openvoice is not installed or the
+checkpoint download fails — in that case the caller uses the raw VITS output.
+"""
+from __future__ import annotations
+
+import logging
+import os
+import tempfile
+import threading
+from pathlib import Path
+from typing import Optional
+
+import numpy as np
+
+logger = logging.getLogger(__name__)
+
+OV_HF_REPO = "myshell-ai/openvoice-v2"
+OV_CKPT_DIR = Path("data/openvoice_v2")
+HF_TOKEN = os.environ.get("HF_TOKEN")
+
+
+class VoiceCloner:
+    """
+    Thin wrapper around OpenVoice V2 ToneColorConverter.
+
+    Thread-safety: convert() holds _lock so parallel calls are serialised;
+    temp-file writes and the final read happen outside the lock.
+    """
+
+    def __init__(self) -> None:
+        self._lock = threading.Lock()
+        self._converter = None
+        self._src_se = None  # cached base-TTS source SE (computed on first convert)
+        self._ready = False
+        self._error: Optional[str] = None
+
+    def preload(self) -> None:
+        threading.Thread(target=self._load, daemon=True).start()
+
+    def get_status(self) -> str:
+        if self._ready: return "ready"
+        if self._error: return f"error: {self._error}"
+        return "loading…"
+
+    # ── Loading ───────────────────────────────────────────────────────────────
+
+    def _load(self) -> None:
+        try:
+            from openvoice.api import ToneColorConverter  # noqa: F401 — validate import
+            OV_CKPT_DIR.mkdir(parents=True, exist_ok=True)
+
+            # Download checkpoint from HF Hub once, then use local cache
+            converter_cfg = OV_CKPT_DIR / "converter" / "config.json"
+            if not converter_cfg.exists():
+                logger.info("VoiceCloner: downloading OpenVoice V2 from HF Hub …")
+                from huggingface_hub import snapshot_download
+                snapshot_download(
+                    repo_id=OV_HF_REPO,
+                    local_dir=str(OV_CKPT_DIR),
+                    token=HF_TOKEN,
+                )
+
+            # Find config — repo layout may vary
+            cfg_path = self._find_converter_config()
+            if cfg_path is None:
+                raise FileNotFoundError(
+                    f"converter/config.json not found under {OV_CKPT_DIR}"
+                )
+
+            from openvoice.api import ToneColorConverter
+            logger.info("VoiceCloner: loading ToneColorConverter from %s …", cfg_path)
+            converter = ToneColorConverter(str(cfg_path), device="cpu")
+            ckpt = cfg_path.parent / "checkpoint.pth"
+            converter.load_ckpt(str(ckpt))
+
+            with self._lock:
+                self._converter = converter
+                self._ready = True
+            logger.info("VoiceCloner: OpenVoice V2 ready")
+
+        except Exception as exc:
+            self._error = str(exc)
+            logger.warning(
+                "VoiceCloner: load failed — voice cloning disabled: %s", exc
+            )
+
+    def _find_converter_config(self) -> Optional[Path]:
+        """Probe known checkpoint layouts to locate converter/config.json."""
+        candidates = [
+            OV_CKPT_DIR / "converter" / "config.json",
+            OV_CKPT_DIR / "checkpoints_v2" / "converter" / "config.json",
+        ]
+        for p in candidates:
+            if p.exists():
+                return p
+        # Walk one level deep as fallback
+        for p in OV_CKPT_DIR.rglob("config.json"):
+            if p.parent.name == "converter":
+                return p
+        return None
+
+    # ── SE extraction ─────────────────────────────────────────────────────────
+
+    def extract_se(self, audio_np: np.ndarray, sr: int) -> Optional[np.ndarray]:
+        """
+        Extract OpenVoice V2 tone-color SE from raw float32 audio.
+        Returns a numpy array (shape depends on OV model, typically (1, 256)),
+        or None if not ready.
+        """
+        if not self._ready:
+            return None
+        try:
+            import soundfile as sf
+            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
+                tmp = f.name
+            sf.write(tmp, audio_np, sr)
+            se = self._extract_se_from_file(tmp)
+            Path(tmp).unlink(missing_ok=True)
+            return se
+        except Exception as exc:
+            logger.debug("VoiceCloner.extract_se: %s", exc)
+            return None
+
+    def _extract_se_from_file(self, audio_path: str) -> Optional[np.ndarray]:
+        try:
+            from openvoice import se_extractor
+            se, _ = se_extractor.get_se(
+                audio_path,
+                self._converter,
+                target_dir=str(OV_CKPT_DIR / "tmp"),
+                vad=False,
+            )
+            arr = se.cpu().numpy() if hasattr(se, "cpu") else np.array(se)
+            return arr
+        except Exception as exc:
+            logger.debug("VoiceCloner._extract_se_from_file: %s", exc)
+            return None
+
+    # ── Voice conversion ──────────────────────────────────────────────────────
+
+    def convert(
+        self,
+        audio_np: np.ndarray,
+        sr: int,
+        target_se: np.ndarray,
+    ) -> Optional[tuple[np.ndarray, int]]:
+        """
+        Reshape audio to match the target speaker's tone color.
+
+        Args:
+            audio_np: float32 audio from WaxalTTS (base voice).
+            sr: sample rate of audio_np.
+            target_se: OpenVoice SE from SpeakerProfileManager (Individual or
+                Collective).
+
+        Returns (cloned_audio_float32, sample_rate) or None if not ready.
+        """
+        if not self._ready:
+            return None
+        try:
+            import soundfile as sf
+            import torch
+
+            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
+                src_path = f.name
+            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
+                out_path = f.name
+
+            sf.write(src_path, audio_np, sr)
+
+            with self._lock:
+                # Extract source SE on first call, then cache it for the session
+                if self._src_se is None:
+                    se = self._extract_se_from_file(src_path)
+                    if se is not None:
+                        self._src_se = se
+
+                if self._src_se is None:
+                    logger.warning("VoiceCloner: could not extract source SE")
+                    return None
+
+                src_se_t = torch.tensor(self._src_se)
+                tgt_se_t = torch.tensor(target_se)
+
+                # Ensure batch dim matches what the converter expects
+                if src_se_t.dim() == 1:
+                    src_se_t = src_se_t.unsqueeze(0)
+                if tgt_se_t.dim() == 1:
+                    tgt_se_t = tgt_se_t.unsqueeze(0)
+
+                self._converter.convert(
+                    audio_src_path=src_path,
+                    src_se=src_se_t,
+                    tgt_se=tgt_se_t,
+                    output_path=out_path,
+                    message="@MyShell",
+                )
+
+            audio_out, out_sr = sf.read(out_path, dtype="float32")
+            Path(src_path).unlink(missing_ok=True)
+            Path(out_path).unlink(missing_ok=True)
+            return audio_out.astype(np.float32), out_sr
+
+        except Exception as exc:
+            logger.error("VoiceCloner.convert: %s", exc)
+            return None
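
Typical call pattern around the cloner, as wired in _run_llm_and_tts (a sketch;
audio_np/sr come from WaxalTTSEngine.synthesize, target_se from the profiles):

    # Every failure path returns None, so the base voice always survives.
    if target_se is not None:
        cloned = cloner.convert(audio_np, sr, target_se)
        if cloned is not None:
            audio_np, sr = cloned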
src/voice/__init__.py ADDED
File without changes
src/voice/speaker_profiles.py ADDED

@@ -0,0 +1,215 @@
+"""
+SpeakerProfileManager — multi-user voice identity.
+
+SpeechBrain ECAPA-TDNN extracts 192-d embeddings for speaker identification.
+Each confirmed user gets two files in data/profiles/:
+
+    user_N_sb.npy    — running-average SpeechBrain embedding (identification)
+    user_N_ov.npy    — running-average OpenVoice V2 tone-color SE (cloning)
+    user_N_count.txt — number of utterances averaged so far
+
+Speaker matching uses cosine similarity. If similarity ≥ COSINE_THRESHOLD the
+new utterance is attributed to that user and their embedding is updated;
+otherwise a new profile is created.
+
+get_collective_embedding() (Task 2) returns the mean of all stored OV SEs.
+"""
+from __future__ import annotations
+
+import logging
+import threading
+from pathlib import Path
+from typing import Optional
+
+import numpy as np
+
+logger = logging.getLogger(__name__)
+
+PROFILES_DIR = Path("data/profiles")
+COSINE_THRESHOLD = 0.75  # empirical threshold for ECAPA-TDNN
+
+
+def _cosine(a: np.ndarray, b: np.ndarray) -> float:
+    denom = np.linalg.norm(a) * np.linalg.norm(b)
+    return float(np.dot(a, b) / denom) if denom > 1e-8 else 0.0
+
+
+def _running_avg(old: np.ndarray, new: np.ndarray, count: int) -> np.ndarray:
+    """Incremental mean — after n updates this equals the mean of all n samples."""
+    alpha = 1.0 / (count + 1)
+    return (1.0 - alpha) * old + alpha * new
+
+
+class SpeakerProfileManager:
+    """Thread-safe multi-user voice profile store backed by .npy files."""
+
+    def __init__(self, profiles_dir: Path = PROFILES_DIR) -> None:
+        self._dir = Path(profiles_dir)
+        self._dir.mkdir(parents=True, exist_ok=True)
+        self._lock = threading.Lock()
+
+        # SpeechBrain state
+        self._sb_model = None
+        self._sb_ready = False
+        self._sb_error: Optional[str] = None
+
+        # In-memory cache: { "user_0": {"sb": ndarray, "ov": ndarray|None, "count": int} }
+        self._profiles: dict[str, dict] = {}
+        self._load_profiles()
+
+    # ── SpeechBrain loading ───────────────────────────────────────────────────
+
+    def preload(self) -> None:
+        threading.Thread(target=self._load_sb, daemon=True).start()
+
+    def _load_sb(self) -> None:
+        try:
+            try:
+                from speechbrain.inference.classifiers import EncoderClassifier
+            except ImportError:
+                from speechbrain.pretrained import EncoderClassifier
+
+            logger.info("SpeakerProfiles: loading SpeechBrain ECAPA-TDNN …")
+            self._sb_model = EncoderClassifier.from_hparams(
+                source="speechbrain/spkrec-ecapa-voxceleb",
+                run_opts={"device": "cpu"},
+                savedir="data/speechbrain_cache",
+            )
+            self._sb_ready = True
+            logger.info("SpeakerProfiles: SpeechBrain ready")
+        except Exception as exc:
+            self._sb_error = str(exc)
+            logger.error("SpeakerProfiles: SpeechBrain load failed: %s", exc)
+
+    def _extract_sb(self, audio_np: np.ndarray) -> Optional[np.ndarray]:
+        """Return 192-d ECAPA embedding, or None if model not ready."""
+        if not self._sb_ready:
+            self._load_sb()
+        if not self._sb_ready:
+            return None
+        try:
+            import torch
+            wav = torch.tensor(audio_np, dtype=torch.float32).unsqueeze(0)
+            with torch.no_grad():
+                emb = self._sb_model.encode_batch(wav)  # (1, 1, 192)
+            return emb.squeeze().cpu().numpy()
+        except Exception as exc:
+            logger.error("SpeakerProfiles: SpeechBrain inference error: %s", exc)
+            return None
+
+    # ── Profile I/O ───────────────────────────────────────────────────────────
+
+    def _load_profiles(self) -> None:
+        profiles = {}
+        for sb_path in sorted(self._dir.glob("user_*_sb.npy")):
+            uid = sb_path.stem[:-3]  # "user_N_sb" → "user_N"
+            ov_path = self._dir / f"{uid}_ov.npy"
+            cnt_path = self._dir / f"{uid}_count.txt"
+            profiles[uid] = {
+                "sb": np.load(sb_path),
+                "ov": np.load(ov_path) if ov_path.exists() else None,
+                "count": int(cnt_path.read_text()) if cnt_path.exists() else 1,
+            }
+        with self._lock:
+            self._profiles = profiles
+        logger.info("SpeakerProfiles: loaded %d profile(s)", len(profiles))
+
+    def _save_profile(self, uid: str) -> None:
+        p = self._profiles[uid]
+        np.save(self._dir / f"{uid}_sb.npy", p["sb"])
+        if p["ov"] is not None:
+            np.save(self._dir / f"{uid}_ov.npy", p["ov"])
+        (self._dir / f"{uid}_count.txt").write_text(str(p["count"]))
+
+    # ── Task 1: Speaker identification ────────────────────────────────────────
+
+    def identify_or_create(
+        self, audio_np: np.ndarray
+    ) -> tuple[Optional[str], Optional[np.ndarray]]:
+        """
+        Extract a SpeechBrain embedding and match it to an existing profile
+        (cosine similarity ≥ threshold) or create a new one.
+
+        Returns (user_id, sb_embedding).
+        Returns (None, None) if SpeechBrain is not available.
+        """
+        sb_emb = self._extract_sb(audio_np)
+        if sb_emb is None:
+            return None, None
+
+        with self._lock:
+            best_uid, best_sim = None, -1.0
+            for uid, profile in self._profiles.items():
+                sim = _cosine(sb_emb, profile["sb"])
+                if sim > best_sim:
+                    best_sim, best_uid = sim, uid
+
+            if best_uid is not None and best_sim >= COSINE_THRESHOLD:
+                # Known speaker — update running average
+                p = self._profiles[best_uid]
+                new_count = p["count"] + 1
+                p["sb"] = _running_avg(p["sb"], sb_emb, p["count"])
+                p["count"] = new_count
+                uid = best_uid
+                logger.debug(
+                    "SpeakerProfiles: recognised %s (sim=%.3f, n=%d)",
+                    uid, best_sim, new_count,
+                )
+            else:
+                # New speaker
+                uid = f"user_{len(self._profiles)}"
+                self._profiles[uid] = {"sb": sb_emb, "ov": None, "count": 1}
+                logger.info("SpeakerProfiles: new profile → %s", uid)
+
+            self._save_profile(uid)
+
+        return uid, sb_emb
+
+    # ── OpenVoice SE management ───────────────────────────────────────────────
+
+    def update_ov_embedding(self, uid: str, ov_emb: np.ndarray) -> None:
+        """Store or running-average the OpenVoice tone-color SE for a user."""
+        with self._lock:
+            if uid not in self._profiles:
+                return
+            p = self._profiles[uid]
+            if p["ov"] is None:
+                p["ov"] = ov_emb.copy()
+            else:
+                p["ov"] = _running_avg(p["ov"], ov_emb, p["count"])
+            self._save_profile(uid)
+
+    def get_openvoice_se(self, uid: str) -> Optional[np.ndarray]:
+        """Return the stored OpenVoice SE for this user, or None."""
+        with self._lock:
+            p = self._profiles.get(uid)
+            return p["ov"].copy() if p and p["ov"] is not None else None
+
+    # ── Task 2: Collective Voice ──────────────────────────────────────────────
+
+    def get_collective_embedding(self) -> Optional[np.ndarray]:
+        """
+        Load all user_N_ov.npy files, return the mean vector.
+        This is the "Collective Voice" embedding that represents all known speakers.
+        Returns None if no OpenVoice SEs have been collected yet.
+        """
+        # Prefer in-memory cache
+        with self._lock:
+            ov_list = [p["ov"] for p in self._profiles.values() if p["ov"] is not None]
+
+        if not ov_list:
+            # Fall back to disk scan (e.g. after a restart that didn't re-identify)
+            ov_list = [np.load(p) for p in sorted(self._dir.glob("user_*_ov.npy"))]
+
+        if not ov_list:
+            return None
+
+        stacked = np.stack(ov_list, axis=0)
+        return stacked.mean(axis=0)
+
+    # ── Status ────────────────────────────────────────────────────────────────
+
+    def get_status(self) -> str:
+        n = len(self._profiles)
+        sb = "🟢" if self._sb_ready else ("🔴" if self._sb_error else "🟡")
+        return f"{sb} SpeechBrain | {n} speaker profile{'s' if n != 1 else ''}"