Commit 8d7d9d8 · 1 parent: ad902c6
jefffffff9 (Claude Sonnet 4.6) committed

Fix conversation mode timeout: two-stage pipeline + faster LLM


Root cause: the entire pipeline (ASR + LLM API call + TTS) ran as one
blocking Gradio event before anything appeared in the UI. On cpu-basic:
- Whisper small on CPU: ~5-10s
- Qwen 72B on HF Serverless: 20-40s (queue + generation)
- MMS-TTS on CPU: ~5-10s
Total: 30-60s, hitting Gradio's request timeout → error in all boxes.

Fix 1 — Two-stage pipeline with .then() chaining:
Stage 1 (_do_asr): Whisper only → transcript appears in ~5s
Stage 2 (_do_respond): LLM + TTS → response + audio follow after
User sees the transcript almost immediately; no more blank wait.
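
The chaining mechanism is plain Gradio: .click() returns an event object, and
.then() registers a follow-up handler that only starts after the first handler
has returned and its outputs have rendered. A minimal, self-contained sketch of
the pattern (component and function names are placeholders, not the ones wired
up in app.py):

    import time
    import gradio as gr

    def do_asr(audio_path, lang):
        time.sleep(1)                      # stand-in for the fast Whisper step
        return f"[{lang}] transcript of {audio_path}"

    def do_respond(transcript, lang):
        time.sleep(5)                      # stand-in for the slow LLM + TTS step
        return f"reply to: {transcript}"

    with gr.Blocks() as demo:
        mic = gr.Audio(type="filepath", label="Mic")
        lang = gr.Dropdown(["bam", "ful"], value="bam", label="Language")
        transcript = gr.Textbox(label="Transcript")
        reply = gr.Textbox(label="Response")
        ask = gr.Button("Ask")
        # Stage 1 fills the transcript box first; stage 2 starts only after it,
        # so the UI is never blank while the slow step runs.
        ask.click(do_asr, inputs=[mic, lang], outputs=transcript).then(
            do_respond, inputs=[transcript, lang], outputs=reply
        )

    demo.launch()

Each stage is its own request, so neither has to finish inside a single Gradio
timeout window.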

Fix 2 — LLM model: Qwen 72B → Qwen 7B (10x faster on HF Serverless,
same quality for 1-3 sentence voice responses). The LLM_MODEL_ID env var
still overrides the default with any model.

Fix 3 — max_tokens 300→150: voice responses are short; cutting tokens
in half cuts LLM latency ~40% further.

Fix 4 — Remove @_gpu from _convo_pipeline: the LLM step is a network
request; wrapping it in the GPU time budget wasted the 55s allowance
on network latency instead of actual compute.
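
The underlying principle, assuming _gpu wraps the ZeroGPU spaces.GPU decorator
(an assumption; the 55s allowance suggests it, but the commit does not say):
decorate only work that actually touches the GPU and leave network-bound steps
outside, so a queued GPU slot is never held open while the process waits on an
API. A rough, illustrative sketch with placeholder functions:

    import time
    import spaces  # ZeroGPU helper available on Hugging Face Spaces

    @spaces.GPU(duration=55)           # the 55s budget goes to real compute
    def transcribe_on_gpu(audio_path: str) -> str:
        time.sleep(3)                  # stand-in for Whisper/TTS inference
        return f"transcript of {audio_path}"

    def ask_llm(prompt: str) -> str:   # undecorated: an HTTP call needs no GPU
        time.sleep(20)                 # stand-in for the remote LLM round-trip
        return f"answer to: {prompt}"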

Fix 5 — _do_respond for sensor mode: replicates phrase+intent+sensor+TTS
logic without re-running ASR, so both modes benefit from the split.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

Files changed (1)
  1. app.py +172 -18
app.py CHANGED
@@ -37,7 +37,7 @@ ADAPTER_REPO_ID = os.environ.get("ADAPTER_REPO_ID", "ous-sow/sahel-agri-adapte
 # whisper-small: ~10s on cpu-basic, good multilingual quality.
 # Override via WHISPER_MODEL_ID env var if you upgrade to a GPU Space later.
 WHISPER_MODEL_ID = os.environ.get("WHISPER_MODEL_ID", "openai/whisper-small")
-LLM_MODEL_ID = os.environ.get("LLM_MODEL_ID", "Qwen/Qwen2.5-72B-Instruct")
+LLM_MODEL_ID = os.environ.get("LLM_MODEL_ID", "Qwen/Qwen2.5-7B-Instruct")
 KAGGLE_USERNAME = os.environ.get("KAGGLE_USERNAME", "")
 KAGGLE_KEY = os.environ.get("KAGGLE_KEY", "")
 KAGGLE_KERNEL_SLUG = os.environ.get("KAGGLE_KERNEL_SLUG", "ous-sow/sahel-voice-master-trainer")
@@ -441,7 +441,6 @@ def set_voice_reference(audio_file) -> str:
         return f"❌ Could not process reference audio: {exc}"
 
 
-@_gpu
 def _convo_pipeline(audio_path: str, language_code: str, history: list):
     """
     Full S2S conversation pipeline with memory:
@@ -1288,6 +1287,150 @@ def handle_ask(audio_path, language_label, convo_mode: bool = False, history: li
         return f"❌ {e}", "", "", None, history
 
 
+# ── Two-stage pipeline (shows transcript fast, then response) ─────────────────
+
+def _do_asr(audio_path: str, language_label: str) -> str:
+    """
+    Stage 1 — Whisper only. Returns the transcript string (or error/status).
+    Completes in ~3-8s on cpu-basic so the user sees what was heard immediately.
+    """
+    if audio_path is None:
+        return "⚠️ No audio — press Record or upload a file."
+    lang = SUPPORTED_LANGUAGES.get(language_label, "bam")
+    status = _ensure_whisper_loaded()
+    if _whisper_model is None:
+        return f"⏳ Model loading ({status}). Wait a moment and try again."
+    try:
+        import torch, librosa
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+        audio_np, _ = librosa.load(audio_path, sr=16000, mono=True)
+        active_model = _fine_tuned_models.get(lang, _whisper_model)
+        active_model.to(device)
+        with _model_lock:
+            input_features = _whisper_processor.feature_extractor(
+                audio_np, sampling_rate=16000, return_tensors="pt"
+            ).input_features.to(device)
+            forced_ids = None
+            if lang not in ("bam", "ful"):
+                forced_ids = _whisper_processor.get_decoder_prompt_ids(
+                    language=lang, task="transcribe"
+                )
+            with torch.no_grad():
+                ids = active_model.generate(
+                    input_features,
+                    forced_decoder_ids=forced_ids or None,
+                    max_new_tokens=256,
+                )
+        transcript = _whisper_processor.batch_decode(ids, skip_special_tokens=True)[0].strip()
+        active_model.to("cpu")
+        if device == "cuda":
+            torch.cuda.empty_cache()
+        # Bambara phonetic normalisation
+        return bam_normalize(transcript) if lang == "bam" else transcript
+    except Exception as e:
+        return f"❌ Transcription error: {e}"
+
+
+def _do_respond(
+    transcript: str,
+    language_label: str,
+    convo_mode: bool,
+    history: list,
+) -> tuple:
+    """
+    Stage 2 — LLM or sensor response, runs after transcript is already visible.
+    Returns (eng_translation, response_text, audio_out, new_history, chat_msgs).
+    """
+    history = history or []
+    # Bail early if stage 1 errored
+    if not transcript or transcript[:1] in ("⚠️", "⏳", "❌") or transcript.startswith(("⚠", "⏳", "❌")):
+        chat_msgs = [[u, v] for u, v in history]
+        return "", "", None, history, chat_msgs
+
+    lang = SUPPORTED_LANGUAGES.get(language_label, "bam")
+    import torch
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+
+    if convo_mode:
+        # ── LLM brain ────────────────────────────────────────────────────────
+        response_text = ""
+        try:
+            from huggingface_hub import InferenceClient
+            client = InferenceClient(token=HF_TOKEN)
+            messages = _build_messages(transcript, history, lang)
+            completion = client.chat_completion(
+                model=LLM_MODEL_ID,
+                messages=messages,
+                max_tokens=150,  # short spoken responses, much faster
+                temperature=0.65,
+            )
+            response_text = completion.choices[0].message.content.strip()
+        except Exception as llm_err:
+            import logging
+            logging.getLogger(__name__).warning("LLM error: %s", llm_err)
+            response_text = (
+                "Hakɛ to, tasuma tɛ kɛ sisan. I ka a lasɔrɔ tugu."
+                if lang == "bam"
+                else "Sorry, I could not reach the language model right now."
+            )
+
+        # Strip [LEARNED:] tags, persist async
+        response_text, _ = _parse_and_strip_learned(response_text, lang)
+
+        # Update history
+        new_history = list(history) + [(transcript, response_text)]
+        if len(new_history) > 20:
+            new_history = new_history[-20:]
+        chat_msgs = [[u, v] for u, v in new_history]
+
+        # ── TTS ───────────────────────────────────────────────────────────────
+        audio_out = None
+        if _voice_ref_path and Path(_voice_ref_path).exists():
+            try:
+                from src.tts.f5_tts import synthesize as f5s
+                result = f5s(response_text, ref_wav_path=_voice_ref_path,
+                             ref_text=_voice_ref_text, device=device)
+                if result is not None:
+                    audio_out = (result[1], result[0])
+            except Exception:
+                pass
+        if audio_out is None:
+            wav_np, sr = _tts.synthesize(response_text, lang, device=device)
+            audio_out = (sr, wav_np)
+
+        return "", response_text, audio_out, new_history, chat_msgs
+
+    else:
+        # ── Sensor / phrase pipeline ──────────────────────────────────────────
+        import asyncio
+        phrase_match = _phrase_matcher.match(transcript, lang)
+        if phrase_match:
+            response_text = phrase_match["response"]
+            english_translation = phrase_match["english"]
+        else:
+            intent = _intent_parser.parse(transcript, language=lang)
+            try:
+                loop = asyncio.new_event_loop()
+                sensor_data = loop.run_until_complete(_sensor_bridge.fetch(intent))
+                loop.close()
+            except Exception:
+                from src.iot.sensor_bridge import SensorData
+                sensor_data = SensorData(sensor_type="soil",
+                                         values={"moisture_pct": 45.0, "ph": 6.5, "temperature_c": 28.0})
+            responder = VoiceResponder(language=lang)
+            response_text, english_translation = responder.generate_response(intent, sensor_data)
+            if intent.action == "unknown" and intent.confidence < 0.15:
+                from src.iot.voice_responder import BAMBARA_TEMPLATES, FULA_TEMPLATES
+                if lang == "bam":
+                    response_text, english_translation = BAMBARA_TEMPLATES["not_understood"]
+                elif lang == "ful":
+                    response_text, english_translation = FULA_TEMPLATES["not_understood"]
+
+        wav_np, sr = _tts.synthesize(response_text, lang, device=device)
+        chat_msgs = [[u, v] for u, v in history]
+        return english_translation, response_text, (sr, wav_np), history, chat_msgs
+
+
 # ── Gradio UI ─────────────────────────────────────────────────────────────────
 
 def build_ui() -> gr.Blocks:
@@ -1406,28 +1549,39 @@ def build_ui() -> gr.Blocks:
             outputs=[chatbot],
         )
 
-        _ask_inputs = [audio_input, language_dd, convo_mode_toggle, conv_history]
-        _ask_outputs = [transcript_box, translation_box, response_box,
-                        audio_output, conv_history, chatbot]
+        # ── Stage 1 inputs/outputs (ASR only — fast) ─────────────────
+        _s1_inputs = [audio_input, language_dd]
+        _s1_outputs = [transcript_box]
 
-        def _ask_and_update(ap, ll, cm, hist):
-            t, e, r, a, new_hist = handle_ask(ap, ll, cm, hist)
-            # Convert history tuples to list-of-lists for gr.Chatbot
-            chat_msgs = [[u, v] for u, v in new_hist]
-            return t, e, r, a, new_hist, chat_msgs
+        # ── Stage 2 inputs/outputs (LLM / sensor + TTS) ──────────────
+        _s2_inputs = [transcript_box, language_dd, convo_mode_toggle, conv_history]
+        _s2_outputs = [translation_box, response_box, audio_output,
+                       conv_history, chatbot]
 
+        # Manual button: stage 1 then stage 2
         ask_btn.click(
-            fn=_ask_and_update,
-            inputs=_ask_inputs,
-            outputs=_ask_outputs,
+            fn=_do_asr,
+            inputs=_s1_inputs,
+            outputs=_s1_outputs,
+        ).then(
+            fn=_do_respond,
+            inputs=_s2_inputs,
+            outputs=_s2_outputs,
        )
-        # Auto-submit when mic stops (Conversation Mode)
+
+        # Auto-submit on mic stop: same chain, but stage 2 only runs when
+        # convo_mode is ON (sensor mode has a manual button for deliberate use)
         audio_input.stop_recording(
-            fn=lambda ap, ll, cm, h: _ask_and_update(ap, ll, cm, h) if cm
-            else (None, None, None, None, h, [[u, v] for u, v in h]),
-            inputs=_ask_inputs,
-            outputs=_ask_outputs,
+            fn=_do_asr,
+            inputs=_s1_inputs,
+            outputs=_s1_outputs,
+        ).then(
+            fn=lambda t, ll, cm, h: _do_respond(t, ll, cm, h) if cm
+            else ("", "", None, h, [[u, v] for u, v in (h or [])]),
+            inputs=_s2_inputs,
+            outputs=_s2_outputs,
        )
+
         # Clear conversation
         clear_btn.click(
             fn=lambda: ([], []),