Add multi-user speaker profiles, collective voice, and mode toggle
Task 1 — SpeakerProfileManager (src/voice/speaker_profiles.py):
SpeechBrain ECAPA-TDNN identifies speakers by cosine similarity of
192-d embeddings (threshold 0.75). Each user gets user_N_sb.npy
(running-average SpeechBrain embedding) and user_N_ov.npy
(running-average OpenVoice V2 tone-color SE). New utterances update
the running average so profiles sharpen over time.
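In numpy terms the match-and-update step is tiny. A minimal runnable sketch of
the rule described above (variable names are illustrative; the real helpers are
_cosine and _running_avg in the speaker_profiles.py diff below):

    import numpy as np

    def cosine(a, b):
        denom = np.linalg.norm(a) * np.linalg.norm(b)
        return float(np.dot(a, b) / denom) if denom > 1e-8 else 0.0

    rng = np.random.default_rng(0)
    profile = rng.standard_normal(192)                  # stored user_N_sb.npy embedding
    new_emb = profile + 0.1 * rng.standard_normal(192)  # new utterance, same speaker
    count = 4                                           # utterances averaged so far

    if cosine(new_emb, profile) >= 0.75:                # COSINE_THRESHOLD
        alpha = 1.0 / (count + 1)
        profile = (1 - alpha) * profile + alpha * new_emb  # running average sharpens
    else:
        print("below threshold — create a new user_N profile")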
Task 2 — get_collective_embedding() (speaker_profiles.py):
Loads all user_N_ov.npy files and returns their mean vector — the
"Collective Voice" that blends all known speakers into one SE.
Task 3 — Voice Mode toggle (app_lab.py + src/tts/voice_cloner.py):
Added Individual / Collective radio to the UI. In Individual mode
the TTS output is cloned into the last detected speaker's voice via
OpenVoice V2 ToneColorConverter. In Collective mode the mean SE is
used. The text-input path falls back to the collective SE. Both modes
degrade gracefully to the base VITS voice if OpenVoice is not ready.
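Condensed, the selection-and-fallback logic looks like the sketch below (free
functions standing in for the module-level singletons wired up in app_lab.py):

    def pick_target_se(voice_mode: str, uid, profiles):
        # Individual → last detected speaker's SE; otherwise (or if no speaker
        # was identified) fall back to the collective SE.
        if voice_mode == "Individual" and uid is not None:
            return profiles.get_openvoice_se(uid)
        return profiles.get_collective_embedding()

    def maybe_clone(cloner, audio_np, sr, target_se):
        # Graceful degradation: convert() returns None while OpenVoice is still
        # loading (or failed to load), so the base VITS audio passes through.
        if target_se is None:
            return audio_np, sr
        cloned = cloner.convert(audio_np, sr, target_se)
        return cloned if cloned is not None else (audio_np, sr)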
requirements.txt: add speechbrain>=0.5.15 and openvoice (installed from git at
Docker build time — HF Spaces blocks GitHub access at runtime but allows it
during the image build).
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
- app_lab.py +77 -19
- requirements.txt +7 -0
- src/tts/voice_cloner.py +223 -0
- src/voice/__init__.py +0 -0
- src/voice/speaker_profiles.py +215 -0

app_lab.py

@@ -49,6 +49,8 @@ LANGUAGE_NAMES = {
 from src.memory.memory_manager import MemoryManager
 from src.llm.gemma_client import GemmaClient
 from src.tts.waxal_tts import WaxalTTSEngine
+from src.tts.voice_cloner import VoiceCloner
+from src.voice.speaker_profiles import SpeakerProfileManager
 from src.engine.stt_processor import (
     transcribe_with_confidence,
     LOW_CONFIDENCE_THRESHOLD,
@@ -56,10 +58,12 @@ from src.engine.stt_processor import (
 )
 from src.engine.curiosity import CuriosityEngine
 
-_memory
-_gemma
-_tts
-
+_memory = MemoryManager(repo_id=FEEDBACK_REPO_ID, hf_token=HF_TOKEN)
+_gemma = GemmaClient(model_id=LLM_MODEL_ID, hf_token=HF_TOKEN)
+_tts = WaxalTTSEngine()
+_voice_cloner = VoiceCloner()
+_speaker_profiles = SpeakerProfileManager()
+_curiosity = CuriosityEngine(interval=5)
 
 # Whisper — loaded lazily in background
 _whisper_model = None
@@ -149,10 +153,14 @@ def _run_llm_and_tts(
     lang_code: str,
     history: list,
     source_label: str,
+    active_se=None,
 ) -> tuple:
     """
-    Shared core: Gemma → memory update → TTS.
+    Shared core: Gemma → memory update → TTS → optional voice cloning.
     Returns: (history, recent_words_md, status_msg, audio_tuple_or_None)
+
+    active_se: OpenVoice V2 tone-color SE (numpy array) to clone into, or None
+    for the base VITS voice.
     """
     # 1. Ask Gemma (with vocabulary context)
     vocab_ctx = _memory.get_vocabulary_context()
@@ -169,11 +177,16 @@ def _run_llm_and_tts(
     if word and trans:
         _memory.add_word_pair(word, lang, trans, trans_l, source="user_taught")
 
-    # 3. TTS
+    # 3. TTS → optional voice cloning
     audio_out = None
     tts_result = _tts.synthesize(response, lang_code)
     if tts_result is not None:
-
+        audio_np, sr = tts_result
+        if active_se is not None:
+            cloned = _voice_cloner.convert(audio_np, sr, active_se)
+            if cloned is not None:
+                audio_np, sr = cloned
+        audio_out = WaxalTTSEngine.audio_to_gradio(audio_np, sr)
 
     # 4. Update chat history
     history = list(history or [])
@@ -196,9 +209,14 @@ def _run_llm_and_tts(
     return history, _render_recent_words(), status_msg, audio_out
 
 
-def process_audio(audio_path, language_label: str, history: list) -> tuple:
+def process_audio(
+    audio_path,
+    language_label: str,
+    voice_mode: str,
+    history: list,
+) -> tuple:
     """
-    Full pipeline: audio → Whisper STT → Gemma → TTS.
+    Full pipeline: audio → speaker ID → Whisper STT → Gemma → TTS → voice clone.
    Returns: (history, recent_words_md, status_msg, audio_out)
     """
     try:
@@ -211,11 +229,30 @@ def process_audio(audio_path, language_label: str, history: list) -> tuple:
         if _whisper_model is None:
             return history, _render_recent_words(), f"⏳ {status} — wait a moment and try again.", None
 
+        # Load audio once — used for both speaker ID and STT
+        import librosa
+        audio_np, _ = librosa.load(audio_path, sr=16_000, mono=True)
+
+        # ── Speaker identification (Task 1) ───────────────────────────────────
+        uid, _ = _speaker_profiles.identify_or_create(audio_np)
+
+        # Extract OpenVoice SE and update the user's profile
+        if uid is not None:
+            ov_se = _voice_cloner.extract_se(audio_np, 16_000)
+            if ov_se is not None:
+                _speaker_profiles.update_ov_embedding(uid, ov_se)
+
+        # ── Select target SE based on mode (Task 3) ───────────────────────────
+        if voice_mode == "Individual" and uid is not None:
+            active_se = _speaker_profiles.get_openvoice_se(uid)
+        else:
+            active_se = _speaker_profiles.get_collective_embedding()
+
+        # ── Transcription with confidence scoring ─────────────────────────────
         transcript, avg_logprob = _transcribe(audio_path, lang_code)
         if not transcript:
             return history, _render_recent_words(), "⚠️ Could not transcribe audio.", None
 
-        # Low-confidence transcription → ask user to repeat and explain
         if avg_logprob < LOW_CONFIDENCE_THRESHOLD:
             logger.info(
                 "Low STT confidence (avg_logprob=%.3f) — switching to confusion prompt",
@@ -223,20 +260,24 @@ def process_audio(audio_path, language_label: str, history: list) -> tuple:
             )
             transcript = CONFUSION_PROMPT
 
-        return _run_llm_and_tts(transcript, lang_code, history, "voice")
+        return _run_llm_and_tts(transcript, lang_code, history, "voice", active_se)
     except Exception as exc:
         logger.exception("process_audio error")
         return history, _render_recent_words(), f"❌ Error: {exc}", None
 
 
-def process_text(text: str, language_label: str, history: list) -> tuple:
-    """Text input path — Gemma → TTS
+def process_text(text: str, language_label: str, voice_mode: str, history: list) -> tuple:
+    """Text input path — Gemma → TTS → optional voice clone."""
     try:
         if not text.strip():
             return history, _render_recent_words(), "⚠️ Please type something.", None
 
         lang_code = _label_to_code(language_label)
-
+
+        # Text has no speaker signal — use Collective in both modes as fallback
+        active_se = _speaker_profiles.get_collective_embedding()
+
+        return _run_llm_and_tts(text.strip(), lang_code, history, "text", active_se)
     except Exception as exc:
         logger.exception("process_text error")
         return history, _render_recent_words(), f"❌ Error: {exc}", None
@@ -289,13 +330,16 @@ def build_ui() -> gr.Blocks:
         tts = _tts.get_status()
         bam = "🟢" if tts["bam"] == "ready" else ("🟡" if "not" in tts["bam"] else "🔴")
         ful = "🟢" if tts["ful"] == "ready" else ("🟡" if "not" in tts["ful"] else "🔴")
-
+        spk = _speaker_profiles.get_status()
+        cln = "🟢 Cloner" if _voice_cloner._ready else (
+            "🔴 Cloner" if _voice_cloner._error else "🟡 Cloner")
+        return f"{stt} | TTS Bambara {bam} | TTS Fula {ful}\n{spk} | {cln}"
 
     status_box = gr.Textbox(
         value=_full_status(),
         label="System status",
         interactive=False,
-        max_lines=
+        max_lines=2,
     )
     status_timer = gr.Timer(value=4)
     status_timer.tick(fn=_full_status, outputs=status_box)
@@ -306,6 +350,16 @@ def build_ui() -> gr.Blocks:
         label="Language you are speaking",
     )
 
+    voice_mode_radio = gr.Radio(
+        choices=["Individual", "Collective"],
+        value="Individual",
+        label="Voice Mode",
+        info=(
+            "Individual — respond in the voice of the last speaker detected. "
+            "Collective — blend all known voices into one shared voice."
+        ),
+    )
+
     with gr.Tab("🎙️ Push-to-Talk"):
         audio_input = gr.Audio(
             sources=["microphone"],
@@ -370,7 +424,7 @@ def build_ui() -> gr.Blocks:
 
     talk_btn.click(
         fn=process_audio,
-        inputs=[audio_input, language_dd, history_state],
+        inputs=[audio_input, language_dd, voice_mode_radio, history_state],
         outputs=[history_state, recent_words, action_status, audio_output],
     ).then(
         fn=lambda h: h,
@@ -380,7 +434,7 @@ def build_ui() -> gr.Blocks:
 
     text_btn.click(
         fn=process_text,
-        inputs=[text_input, language_dd, history_state],
+        inputs=[text_input, language_dd, voice_mode_radio, history_state],
         outputs=[history_state, recent_words, action_status, audio_output],
     ).then(
         fn=lambda h: (h, ""),
@@ -390,7 +444,7 @@ def build_ui() -> gr.Blocks:
 
     text_input.submit(
         fn=process_text,
-        inputs=[text_input, language_dd, history_state],
+        inputs=[text_input, language_dd, voice_mode_radio, history_state],
         outputs=[history_state, recent_words, action_status, audio_output],
     ).then(
         fn=lambda h: (h, ""),
@@ -414,6 +468,10 @@ threading.Thread(target=_memory.load, daemon=True).start()
 _ensure_whisper()
 # Preload TTS models in background
 _tts.preload()
+# Preload speaker identification (SpeechBrain ECAPA-TDNN)
+_speaker_profiles.preload()
+# Preload voice cloner (OpenVoice V2) — gracefully degrades if unavailable
+_voice_cloner.preload()
 
 if __name__ == "__main__":
     from dotenv import load_dotenv

requirements.txt

@@ -52,5 +52,12 @@ scipy==1.15.2
 # Phrase matching (fuzzy match for Whisper mis-transcriptions of Bambara/Fula)
 rapidfuzz==3.13.0
 
+# Speaker identification (ECAPA-TDNN 192-d embeddings, used by SpeakerProfileManager)
+speechbrain>=0.5.15
+
+# Voice cloning (OpenVoice V2 ToneColorConverter — installed at build time from GitHub)
+# HF Spaces blocks GitHub at *runtime* but allows it during Docker build via requirements.txt
+openvoice @ git+https://github.com/myshell-ai/OpenVoice.git
+
 # maliba-ai is NOT listed here — it has strict conflicting pins (librosa, soundfile).
 # It is installed lazily at runtime on first Bambara TTS call (see src/tts/waxal_tts.py).

src/tts/voice_cloner.py (new file)

@@ -0,0 +1,223 @@
+"""
+VoiceCloner — OpenVoice V2 tone-color converter.
+
+Sits downstream of WaxalTTSEngine: takes the base VITS audio and reshapes it
+to match a target speaker's tone color.
+
+Usage:
+    cloner = VoiceCloner()
+    cloner.preload()  # background thread
+
+    # After WaxalTTS produces (audio_np, sr) …
+    se = cloner.extract_se(audio_np, sr)       # extract SE from user's mic audio
+    result = cloner.convert(audio_np, sr, se)  # returns (cloned_audio, sr) or None
+
+The OpenVoice V2 checkpoint is downloaded from myshell-ai/openvoice-v2 on
+HuggingFace Hub at first use (cached in data/openvoice_v2/).
+
+Falls back gracefully (returns None) if openvoice is not installed or the
+checkpoint download fails — in that case the caller uses the raw VITS output.
+"""
+from __future__ import annotations
+
+import logging
+import os
+import tempfile
+import threading
+from pathlib import Path
+from typing import Optional
+
+import numpy as np
+
+logger = logging.getLogger(__name__)
+
+OV_HF_REPO = "myshell-ai/openvoice-v2"
+OV_CKPT_DIR = Path("data/openvoice_v2")
+HF_TOKEN = os.environ.get("HF_TOKEN")
+
+
+class VoiceCloner:
+    """
+    Thin wrapper around OpenVoice V2 ToneColorConverter.
+
+    Thread-safety: convert() holds _lock so parallel calls are serialised;
+    the lock is released while waiting for subprocess/IO.
+    """
+
+    def __init__(self) -> None:
+        self._lock = threading.Lock()
+        self._converter = None
+        self._src_se = None  # cached base-TTS source SE (computed on first convert)
+        self._ready = False
+        self._error: Optional[str] = None
+
+    def preload(self) -> None:
+        threading.Thread(target=self._load, daemon=True).start()
+
+    def get_status(self) -> str:
+        if self._ready: return "ready"
+        if self._error: return f"error: {self._error}"
+        return "loading…"
+
+    # ── Loading ───────────────────────────────────────────────────────────────
+
+    def _load(self) -> None:
+        try:
+            from openvoice.api import ToneColorConverter  # noqa: F401 — validate import
+            OV_CKPT_DIR.mkdir(parents=True, exist_ok=True)
+
+            # Download checkpoint from HF Hub once, then use local cache
+            converter_cfg = OV_CKPT_DIR / "converter" / "config.json"
+            if not converter_cfg.exists():
+                logger.info("VoiceCloner: downloading OpenVoice V2 from HF Hub …")
+                from huggingface_hub import snapshot_download
+                snapshot_download(
+                    repo_id=OV_HF_REPO,
+                    local_dir=str(OV_CKPT_DIR),
+                    token=HF_TOKEN,
+                )
+
+            # Find config — repo layout may vary
+            cfg_path = self._find_converter_config()
+            if cfg_path is None:
+                raise FileNotFoundError(
+                    f"converter/config.json not found under {OV_CKPT_DIR}"
+                )
+
+            from openvoice.api import ToneColorConverter
+            logger.info("VoiceCloner: loading ToneColorConverter from %s …", cfg_path)
+            converter = ToneColorConverter(str(cfg_path), device="cpu")
+            ckpt = cfg_path.parent / "checkpoint.pth"
+            converter.load_ckpt(str(ckpt))
+
+            with self._lock:
+                self._converter = converter
+                self._ready = True
+            logger.info("VoiceCloner: OpenVoice V2 ready")
+
+        except Exception as exc:
+            self._error = str(exc)
+            logger.warning(
+                "VoiceCloner: load failed — voice cloning disabled: %s", exc
+            )
+
+    def _find_converter_config(self) -> Optional[Path]:
+        """Probe known checkpoint layouts to locate converter/config.json."""
+        candidates = [
+            OV_CKPT_DIR / "converter" / "config.json",
+            OV_CKPT_DIR / "checkpoints_v2" / "converter" / "config.json",
+        ]
+        for p in candidates:
+            if p.exists():
+                return p
+        # Walk one level deep as fallback
+        for p in OV_CKPT_DIR.rglob("config.json"):
+            if p.parent.name == "converter":
+                return p
+        return None
+
+    # ── SE extraction ─────────────────────────────────────────────────────────
+
+    def extract_se(self, audio_np: np.ndarray, sr: int) -> Optional[np.ndarray]:
+        """
+        Extract OpenVoice V2 tone-color SE from raw float32 audio.
+        Returns a numpy array (shape depends on OV model, typically (1, 256)),
+        or None if not ready.
+        """
+        if not self._ready:
+            return None
+        try:
+            import soundfile as sf
+            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
+                tmp = f.name
+            sf.write(tmp, audio_np, sr)
+            se = self._extract_se_from_file(tmp)
+            Path(tmp).unlink(missing_ok=True)
+            return se
+        except Exception as exc:
+            logger.debug("VoiceCloner.extract_se: %s", exc)
+            return None
+
+    def _extract_se_from_file(self, audio_path: str) -> Optional[np.ndarray]:
+        try:
+            from openvoice import se_extractor
+            se, _ = se_extractor.get_se(
+                audio_path,
+                self._converter,
+                target_dir=str(OV_CKPT_DIR / "tmp"),
+                vad=False,
+            )
+            arr = se.cpu().numpy() if hasattr(se, "cpu") else np.array(se)
+            return arr
+        except Exception as exc:
+            logger.debug("VoiceCloner._extract_se_from_file: %s", exc)
+            return None
+
+    # ── Voice conversion ──────────────────────────────────────────────────────
+
+    def convert(
+        self,
+        audio_np: np.ndarray,
+        sr: int,
+        target_se: np.ndarray,
+    ) -> Optional[tuple[np.ndarray, int]]:
+        """
+        Reshape audio to match the target speaker's tone color.
+
+        Args:
+            audio_np: float32 audio from WaxalTTS (base voice).
+            sr: sample rate of audio_np.
+            target_se: OpenVoice SE from SpeakerProfileManager (Individual or
+                Collective).
+
+        Returns (cloned_audio_float32, sample_rate) or None if not ready.
+        """
+        if not self._ready:
+            return None
+        try:
+            import soundfile as sf
+            import torch
+
+            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
+                src_path = f.name
+            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
+                out_path = f.name
+
+            sf.write(src_path, audio_np, sr)
+
+            with self._lock:
+                # Extract source SE on first call, then cache it for the session
+                if self._src_se is None:
+                    se = self._extract_se_from_file(src_path)
+                    if se is not None:
+                        self._src_se = se
+
+                if self._src_se is None:
+                    logger.warning("VoiceCloner: could not extract source SE")
+                    return None
+
+                src_se_t = torch.tensor(self._src_se)
+                tgt_se_t = torch.tensor(target_se)
+
+                # Ensure batch dim matches what the converter expects
+                if src_se_t.dim() == 1:
+                    src_se_t = src_se_t.unsqueeze(0)
+                if tgt_se_t.dim() == 1:
+                    tgt_se_t = tgt_se_t.unsqueeze(0)
+
+                self._converter.convert(
+                    audio_src_path=src_path,
+                    src_se=src_se_t,
+                    tgt_se=tgt_se_t,
+                    output_path=out_path,
+                    message="@MyShell",
+                )
+
+            audio_out, out_sr = sf.read(out_path, dtype="float32")
+            Path(src_path).unlink(missing_ok=True)
+            Path(out_path).unlink(missing_ok=True)
+            return audio_out.astype(np.float32), out_sr
+
+        except Exception as exc:
+            logger.error("VoiceCloner.convert: %s", exc)
+            return None

src/voice/__init__.py (new empty file — no content to show)

src/voice/speaker_profiles.py (new file)

@@ -0,0 +1,215 @@
+"""
+SpeakerProfileManager — multi-user voice identity.
+
+SpeechBrain ECAPA-TDNN extracts 192-d embeddings for speaker identification.
+Each confirmed user gets three files in data/profiles/:
+
+    user_N_sb.npy    — running-average SpeechBrain embedding (identification)
+    user_N_ov.npy    — running-average OpenVoice V2 tone-color SE (cloning)
+    user_N_count.txt — number of utterances averaged so far
+
+Speaker matching uses cosine similarity. If similarity ≥ COSINE_THRESHOLD the
+new utterance is attributed to that user and their embedding is updated;
+otherwise a new profile is created.
+
+get_collective_embedding() (Task 2) returns the mean of all stored OV SEs.
+"""
+from __future__ import annotations
+
+import logging
+import threading
+from pathlib import Path
+from typing import Optional
+
+import numpy as np
+
+logger = logging.getLogger(__name__)
+
+PROFILES_DIR = Path("data/profiles")
+COSINE_THRESHOLD = 0.75  # empirical threshold for ECAPA-TDNN
+
+
+def _cosine(a: np.ndarray, b: np.ndarray) -> float:
+    denom = np.linalg.norm(a) * np.linalg.norm(b)
+    return float(np.dot(a, b) / denom) if denom > 1e-8 else 0.0
+
+
+def _running_avg(old: np.ndarray, new: np.ndarray, count: int) -> np.ndarray:
+    """Weighted running average — older observations decay gently."""
+    alpha = 1.0 / (count + 1)
+    return (1.0 - alpha) * old + alpha * new
+
+
+class SpeakerProfileManager:
+    """Thread-safe multi-user voice profile store backed by .npy files."""
+
+    def __init__(self, profiles_dir: Path = PROFILES_DIR) -> None:
+        self._dir = Path(profiles_dir)
+        self._dir.mkdir(parents=True, exist_ok=True)
+        self._lock = threading.Lock()
+
+        # SpeechBrain state
+        self._sb_model = None
+        self._sb_ready = False
+        self._sb_error: Optional[str] = None
+
+        # In-memory cache: { "user_0": {"sb": ndarray, "ov": ndarray|None, "count": int} }
+        self._profiles: dict[str, dict] = {}
+        self._load_profiles()
+
+    # ── SpeechBrain loading ───────────────────────────────────────────────────
+
+    def preload(self) -> None:
+        threading.Thread(target=self._load_sb, daemon=True).start()
+
+    def _load_sb(self) -> None:
+        try:
+            try:
+                from speechbrain.inference.classifiers import EncoderClassifier
+            except ImportError:
+                from speechbrain.pretrained import EncoderClassifier
+
+            logger.info("SpeakerProfiles: loading SpeechBrain ECAPA-TDNN …")
+            self._sb_model = EncoderClassifier.from_hparams(
+                source="speechbrain/spkrec-ecapa-voxceleb",
+                run_opts={"device": "cpu"},
+                savedir="data/speechbrain_cache",
+            )
+            self._sb_ready = True
+            logger.info("SpeakerProfiles: SpeechBrain ready")
+        except Exception as exc:
+            self._sb_error = str(exc)
+            logger.error("SpeakerProfiles: SpeechBrain load failed: %s", exc)
+
+    def _extract_sb(self, audio_np: np.ndarray) -> Optional[np.ndarray]:
+        """Return 192-d ECAPA embedding, or None if model not ready."""
+        if not self._sb_ready:
+            self._load_sb()
+        if not self._sb_ready:
+            return None
+        try:
+            import torch
+            wav = torch.tensor(audio_np, dtype=torch.float32).unsqueeze(0)
+            with torch.no_grad():
+                emb = self._sb_model.encode_batch(wav)  # (1, 1, 192)
+            return emb.squeeze().cpu().numpy()
+        except Exception as exc:
+            logger.error("SpeakerProfiles: SpeechBrain inference error: %s", exc)
+            return None
+
+    # ── Profile I/O ───────────────────────────────────────────────────────────
+
+    def _load_profiles(self) -> None:
+        profiles = {}
+        for sb_path in sorted(self._dir.glob("user_*_sb.npy")):
+            uid = sb_path.stem[:-3]  # "user_N_sb" → "user_N"
+            ov_path = self._dir / f"{uid}_ov.npy"
+            cnt_path = self._dir / f"{uid}_count.txt"
+            profiles[uid] = {
+                "sb": np.load(sb_path),
+                "ov": np.load(ov_path) if ov_path.exists() else None,
+                "count": int(cnt_path.read_text()) if cnt_path.exists() else 1,
+            }
+        with self._lock:
+            self._profiles = profiles
+        logger.info("SpeakerProfiles: loaded %d profile(s)", len(profiles))
+
+    def _save_profile(self, uid: str) -> None:
+        p = self._profiles[uid]
+        np.save(self._dir / f"{uid}_sb.npy", p["sb"])
+        if p["ov"] is not None:
+            np.save(self._dir / f"{uid}_ov.npy", p["ov"])
+        (self._dir / f"{uid}_count.txt").write_text(str(p["count"]))
+
+    # ── Task 1: Speaker identification ────────────────────────────────────────
+
+    def identify_or_create(
+        self, audio_np: np.ndarray
+    ) -> tuple[Optional[str], Optional[np.ndarray]]:
+        """
+        Extract a SpeechBrain embedding and match it to an existing profile
+        (cosine similarity ≥ threshold) or create a new one.
+
+        Returns (user_id, sb_embedding).
+        Returns (None, None) if SpeechBrain is not available.
+        """
+        sb_emb = self._extract_sb(audio_np)
+        if sb_emb is None:
+            return None, None
+
+        with self._lock:
+            best_uid, best_sim = None, -1.0
+            for uid, profile in self._profiles.items():
+                sim = _cosine(sb_emb, profile["sb"])
+                if sim > best_sim:
+                    best_sim, best_uid = sim, uid
+
+            if best_uid is not None and best_sim >= COSINE_THRESHOLD:
+                # Known speaker — update running average
+                p = self._profiles[best_uid]
+                new_count = p["count"] + 1
+                p["sb"] = _running_avg(p["sb"], sb_emb, p["count"])
+                p["count"] = new_count
+                uid = best_uid
+                logger.debug(
+                    "SpeakerProfiles: recognised %s (sim=%.3f, n=%d)",
+                    uid, best_sim, new_count,
+                )
+            else:
+                # New speaker
+                uid = f"user_{len(self._profiles)}"
+                self._profiles[uid] = {"sb": sb_emb, "ov": None, "count": 1}
+                logger.info("SpeakerProfiles: new profile → %s", uid)
+
+            self._save_profile(uid)
+
+        return uid, sb_emb
+
+    # ── OpenVoice SE management ───────────────────────────────────────────────
+
+    def update_ov_embedding(self, uid: str, ov_emb: np.ndarray) -> None:
+        """Store or running-average the OpenVoice tone-color SE for a user."""
+        with self._lock:
+            if uid not in self._profiles:
+                return
+            p = self._profiles[uid]
+            if p["ov"] is None:
+                p["ov"] = ov_emb.copy()
+            else:
+                p["ov"] = _running_avg(p["ov"], ov_emb, p["count"])
+            self._save_profile(uid)
+
+    def get_openvoice_se(self, uid: str) -> Optional[np.ndarray]:
+        """Return the stored OpenVoice SE for this user, or None."""
+        with self._lock:
+            p = self._profiles.get(uid)
+            return p["ov"].copy() if p and p["ov"] is not None else None
+
+    # ── Task 2: Collective Voice ──────────────────────────────────────────────
+
+    def get_collective_embedding(self) -> Optional[np.ndarray]:
+        """
+        Load all user_N_ov.npy files, return the mean vector.
+        This is the "Collective Voice" embedding that represents all known speakers.
+        Returns None if no OpenVoice SEs have been collected yet.
+        """
+        # Prefer in-memory cache
+        with self._lock:
+            ov_list = [p["ov"] for p in self._profiles.values() if p["ov"] is not None]
+
+        if not ov_list:
+            # Fall back to disk scan (e.g. after a restart that didn't re-identify)
+            ov_list = [np.load(p) for p in sorted(self._dir.glob("user_*_ov.npy"))]
+
+        if not ov_list:
+            return None
+
+        stacked = np.stack(ov_list, axis=0)
+        return stacked.mean(axis=0)
+
+    # ── Status ────────────────────────────────────────────────────────────────
+
+    def get_status(self) -> str:
+        n = len(self._profiles)
+        sb = "🟢" if self._sb_ready else ("🔴" if self._sb_error else "🟡")
+        return f"{sb} SpeechBrain | {n} speaker profile{'s' if n != 1 else ''}"