import os
import json
import re
import time
from typing import Any, Dict, List, Optional, Tuple
import gradio as gr
import numpy as np
# Audio processing
import librosa
# Models
import torch
from transformers import pipeline
from gtts import gTTS
import spaces
import threading
# ---------------------------
# Configuration
# ---------------------------
DEFAULT_CHAT_MODEL_ID = os.getenv("LLM_MODEL_ID", "google/gemma-2-2b-it")
DEFAULT_ASR_MODEL_ID = os.getenv("ASR_MODEL_ID", "openai/whisper-tiny.en")
CONFIDENCE_THRESHOLD_DEFAULT = float(os.getenv("CONFIDENCE_THRESHOLD", "0.8"))
MAX_TURNS = int(os.getenv("MAX_TURNS", "12"))
USE_TTS_DEFAULT = os.getenv("USE_TTS", "true").strip().lower() == "true"
CONFIG_PATH = os.getenv("MODEL_CONFIG_PATH", "model_config.json")
def _load_model_id_from_config() -> str:
try:
if os.path.exists(CONFIG_PATH):
with open(CONFIG_PATH, "r") as f:
data = json.load(f)
if isinstance(data, dict) and data.get("model_id"):
return str(data["model_id"])
except Exception:
pass
return DEFAULT_CHAT_MODEL_ID
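# The config file written by persist_model_id() below is a single-key JSON object, e.g.:
#   {"model_id": "google/gemma-2-2b-it"}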
current_model_id = _load_model_id_from_config()
# ---------------------------
# Lazy singletons for pipelines
# ---------------------------
_asr_pipe = None
_gen_pipe = None
_tokenizer = None
def _hf_device() -> int:
return 0 if torch.cuda.is_available() else -1
def get_asr_pipeline():
global _asr_pipe
if _asr_pipe is None:
_asr_pipe = pipeline(
"automatic-speech-recognition",
model=DEFAULT_ASR_MODEL_ID,
device=_hf_device(),
)
return _asr_pipe
def get_textgen_pipeline():
global _gen_pipe
if _gen_pipe is None:
# Use a small default chat model for Spaces CPU; override via LLM_MODEL_ID
if torch.cuda.is_available() and hasattr(torch.cuda, "is_bf16_supported") and torch.cuda.is_bf16_supported():
_dtype = torch.bfloat16
elif torch.cuda.is_available():
_dtype = torch.float16
else:
_dtype = torch.float32
_gen_pipe = pipeline(
task="text-generation",
model=current_model_id,
tokenizer=current_model_id,
device=_hf_device(),
torch_dtype=_dtype,
)
return _gen_pipe
def set_current_model_id(new_model_id: str) -> str:
global current_model_id, _gen_pipe
new_model_id = (new_model_id or "").strip()
if not new_model_id:
return "Model id is empty; keeping current model."
if new_model_id == current_model_id:
return f"Model unchanged: `{current_model_id}`"
current_model_id = new_model_id
_gen_pipe = None # force reload on next use
return f"Model switched to `{current_model_id}` (pipeline will reload on next generation)."
def persist_model_id(new_model_id: str) -> None:
try:
with open(CONFIG_PATH, "w") as f:
json.dump({"model_id": new_model_id}, f)
except Exception:
pass
def apply_model_and_restart(new_model_id: str) -> str:
mid = (new_model_id or "").strip()
if not mid:
return "Model id is empty; not restarting."
persist_model_id(mid)
set_current_model_id(mid)
# Graceful delayed exit so response can flush
def _exit_later():
time.sleep(0.25)
os._exit(0)
threading.Thread(target=_exit_later, daemon=True).start()
return f"Restarting with model `{mid}`..."
# ---------------------------
# Utilities
# ---------------------------
def safe_json_extract(text: str) -> Optional[Dict[str, Any]]:
"""Extract first JSON object from text."""
if not text:
return None
try:
return json.loads(text)
except Exception:
pass
    # Fallback: greedily capture from the first '{' to the last '}' so nested objects survive
    match = re.search(r"\{[\s\S]*\}", text)
if match:
try:
return json.loads(match.group(0))
except Exception:
return None
return None
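# Illustrative use (hypothetical LLM output), showing why the regex fallback matters:
#   safe_json_extract('Sure! Here you go: {"Total_Score": 7}')  ->  {"Total_Score": 7}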
def compute_audio_features(audio_path: str) -> Dict[str, float]:
"""Compute lightweight prosodic features as a proxy for OpenSMILE.
Returns a dictionary with summary statistics.
"""
try:
y, sr = librosa.load(audio_path, sr=16000, mono=True)
if len(y) == 0:
return {}
# Frame-based features
hop_length = 512
frame_length = 1024
rms = librosa.feature.rms(y=y, frame_length=frame_length, hop_length=hop_length)[0]
zcr = librosa.feature.zero_crossing_rate(y, frame_length=frame_length, hop_length=hop_length)[0]
centroid = librosa.feature.spectral_centroid(y=y, sr=sr, n_fft=2048, hop_length=hop_length)[0]
# Pitch estimation (coarse)
f0 = None
try:
f0 = librosa.yin(y, fmin=50, fmax=400, sr=sr, frame_length=frame_length, hop_length=hop_length)
f0 = f0[np.isfinite(f0)]
except Exception:
f0 = None
        # Speaking-rate rough proxy: fraction of frames with above-median RMS energy
        # (reuses the rms array computed above rather than recomputing it)
        voiced = rms > (np.median(rms) * 1.2)
        voiced_ratio = float(np.mean(voiced))
features = {
"rms_mean": float(np.mean(rms)),
"rms_std": float(np.std(rms)),
"zcr_mean": float(np.mean(zcr)),
"zcr_std": float(np.std(zcr)),
"centroid_mean": float(np.mean(centroid)),
"centroid_std": float(np.std(centroid)),
"voiced_ratio": voiced_ratio,
"duration_sec": float(len(y) / sr),
}
if f0 is not None and f0.size > 0:
features.update({
"f0_median": float(np.median(f0)),
"f0_iqr": float(np.percentile(f0, 75) - np.percentile(f0, 25)),
})
return features
except Exception:
return {}
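# Sketch of a typical feature dict for a short voiced clip (values are illustrative only):
#   {"rms_mean": 0.03, "zcr_mean": 0.08, "centroid_mean": 1800.0, "voiced_ratio": 0.6,
#    "duration_sec": 4.2, "f0_median": 140.0, "f0_iqr": 25.0, ...}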
def detect_explicit_suicidality(text: Optional[str]) -> bool:
if not text:
return False
t = text.lower()
patterns = [
r"\bkill myself\b",
r"\bend my life\b",
r"\bend it all\b",
r"\bcommit suicide\b",
r"\bsuicide\b",
r"\bself[-\s]?harm\b",
r"\bhurt myself\b",
r"\bno reason to live\b",
r"\bwant to die\b",
r"\bending it\b",
]
for pat in patterns:
if re.search(pat, t):
return True
return False
def synthesize_tts(text: Optional[str]) -> Optional[str]:
if not text:
return None
try:
# Save MP3 to tmp and return filepath
ts = int(time.time() * 1000)
out_path = f"/tmp/tts_{ts}.mp3"
tts = gTTS(text=text, lang="en")
tts.save(out_path)
return out_path
except Exception:
return None
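# Standard PHQ-9 severity bands (Kroenke et al., 2001), as implemented below:
#   0-4 minimal, 5-9 mild, 10-14 moderate, 15-19 moderately severe, 20-27 severe.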
def severity_from_total(total_score: int) -> str:
if total_score <= 4:
return "Minimal Depression"
if total_score <= 9:
return "Mild Depression"
if total_score <= 14:
return "Moderate Depression"
if total_score <= 19:
return "Moderately Severe Depression"
return "Severe Depression"
def transcript_to_text(chat_history: List[Tuple[str, str]]) -> str:
"""Convert chatbot history [(user, assistant), ...] to a plain text transcript."""
lines = []
for user, assistant in chat_history:
if user:
lines.append(f"Patient: {user}")
if assistant:
lines.append(f"Clinician: {assistant}")
return "\n".join(lines)
def generate_recording_agent_reply(chat_history: List[Tuple[str, str]]) -> str:
transcript = transcript_to_text(chat_history)
system_prompt = (
"You are a clinician conducting a conversational assessment to infer PHQ-9 symptoms "
"without listing the nine questions explicitly. Keep tone empathetic, natural, and human. "
"Ask one concise, natural follow-up question at a time that helps infer symptoms such as mood, "
"sleep, appetite, energy, concentration, self-worth, psychomotor changes, and suicidal thoughts."
)
user_prompt = (
"Conversation so far (Patient and Clinician turns):\n\n" + transcript +
"\n\nRespond with a single short clinician-style question for the patient."
)
pipe = get_textgen_pipeline()
tokenizer = pipe.tokenizer
combined_prompt = system_prompt + "\n\n" + user_prompt
messages = [
{"role": "user", "content": combined_prompt},
]
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
gen = pipe(
prompt,
max_new_tokens=96,
temperature=0.7,
do_sample=True,
top_p=0.9,
top_k=50,
pad_token_id=tokenizer.eos_token_id,
return_full_text=False,
)
reply = gen[0]["generated_text"].strip()
# Ensure it's a single concise question/sentence
if len(reply) > 300:
reply = reply[:300].rstrip() + "…"
return reply
def scoring_agent_infer(chat_history: List[Tuple[str, str]], features: Dict[str, float]) -> Dict[str, Any]:
"""Ask the LLM to produce PHQ-9 scores and confidences as JSON. Fallback if parsing fails."""
transcript = transcript_to_text(chat_history)
features_json = json.dumps(features, ensure_ascii=False)
system_prompt = (
"You evaluate an on-going clinician-patient conversation to infer a PHQ-9 assessment. "
"Return ONLY a JSON object with: PHQ9_Scores (interest,mood,sleep,energy,appetite,self_worth,concentration,motor,suicidal_thoughts; each 0-3), "
"Confidences (list of 9 floats 0-1 in the same order), Total_Score (0-27), Severity (string), Confidence (min of confidences), "
"and High_Risk (boolean, true if any suicidal risk)."
)
user_prompt = (
"Conversation transcript:"\
f"\n{transcript}\n\n"
f"Acoustic features summary (approximate):\n{features_json}\n\n"
"Instructions: Infer PHQ9_Scores (0-3 per item), estimate Confidences per item, compute Total_Score and overall Severity. "
"Set High_Risk=true if any suicidal ideation or risk is present. Return ONLY JSON, no prose."
)
pipe = get_textgen_pipeline()
tokenizer = pipe.tokenizer
combined_prompt = system_prompt + "\n\n" + user_prompt
messages = [
{"role": "user", "content": combined_prompt},
]
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
# Use deterministic decoding to avoid CUDA sampling edge cases on some models
gen = pipe(
prompt,
max_new_tokens=256,
do_sample=False,
pad_token_id=tokenizer.eos_token_id,
return_full_text=False,
)
out_text = gen[0]["generated_text"]
parsed = safe_json_extract(out_text)
# Validate and coerce
if parsed is None or "PHQ9_Scores" not in parsed:
# Simple fallback heuristic: neutral scores with low confidence
scores = {
"interest": 1,
"mood": 1,
"sleep": 1,
"energy": 1,
"appetite": 1,
"self_worth": 1,
"concentration": 1,
"motor": 1,
"suicidal_thoughts": 0,
}
confidences = [0.5] * 9
total = int(sum(scores.values()))
return {
"PHQ9_Scores": scores,
"Confidences": confidences,
"Total_Score": total,
"Severity": severity_from_total(total),
"Confidence": float(min(confidences)),
"High_Risk": False,
}
try:
# Coerce types and compute derived values if missing
scores = parsed.get("PHQ9_Scores", {})
# Ensure all keys present
keys = [
"interest","mood","sleep","energy","appetite","self_worth","concentration","motor","suicidal_thoughts"
]
for k in keys:
scores[k] = int(max(0, min(3, int(scores.get(k, 0)))))
confidences = parsed.get("Confidences", [])
if not isinstance(confidences, list) or len(confidences) != 9:
confidences = [float(parsed.get("Confidence", 0.5))] * 9
confidences = [float(max(0.0, min(1.0, c))) for c in confidences]
total = int(sum(scores.values()))
severity = parsed.get("Severity") or severity_from_total(total)
overall_conf = float(parsed.get("Confidence", min(confidences)))
# Conservative high-risk detection: require explicit language or high suicidal_thoughts score
# Extract last patient message
last_patient = ""
for user_text, assistant_text in reversed(chat_history):
if user_text:
last_patient = user_text
break
explicit_flag = detect_explicit_suicidality(last_patient) or detect_explicit_suicidality(transcript)
high_risk = bool(explicit_flag or (scores.get("suicidal_thoughts", 0) >= 2))
return {
"PHQ9_Scores": scores,
"Confidences": confidences,
"Total_Score": total,
"Severity": severity,
"Confidence": overall_conf,
"High_Risk": high_risk,
}
except Exception:
# Final fallback
scores = parsed.get("PHQ9_Scores", {}) if isinstance(parsed, dict) else {}
if not scores:
scores = {k: 1 for k in [
"interest","mood","sleep","energy","appetite","self_worth","concentration","motor","suicidal_thoughts"
]}
confidences = [0.5] * 9
total = int(sum(scores.values()))
return {
"PHQ9_Scores": scores,
"Confidences": confidences,
"Total_Score": total,
"Severity": severity_from_total(total),
"Confidence": float(min(confidences)),
"High_Risk": False,
}
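# Shape of the dict scoring_agent_infer() returns (values here are illustrative):
#   {"PHQ9_Scores": {"interest": 1, ..., "suicidal_thoughts": 0},
#    "Confidences": [0.7, ... nine floats ...], "Total_Score": 8,
#    "Severity": "Mild Depression", "Confidence": 0.55, "High_Risk": False}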
def transcribe_audio(audio_path: Optional[str]) -> str:
if not audio_path:
return ""
try:
asr = get_asr_pipeline()
result = asr(audio_path)
if isinstance(result, dict) and "text" in result:
return result["text"].strip()
if isinstance(result, list) and len(result) > 0 and "text" in result[0]:
return result[0]["text"].strip()
except Exception:
pass
return ""
# ---------------------------
# Gradio app logic
# ---------------------------
INTRO_MESSAGE = (
    "Hi, I'm an assistant, and I will ask you some questions about how you've been doing. "
    "We'll record our conversation, and you will receive a written copy of it. "
    "From our conversation, we will send the clinician a written transcript, a summary of what "
    "you are experiencing based on a questionnaire called the Patient Health Questionnaire (PHQ-9), "
    "and a summary of what your voice is like. The clinician will then follow up with you. "
    "To start, how has your mood been over the past couple of weeks?"
)
def init_state() -> Tuple[List[Tuple[str, str]], Dict[str, Any], Dict[str, Any], bool, int]:
chat_history: List[Tuple[str, str]] = [("", INTRO_MESSAGE)]
scores = {}
meta = {"Severity": None, "Total_Score": None, "Confidence": 0.0}
finished = False
turns = 0
return chat_history, scores, meta, finished, turns
@spaces.GPU
def process_turn(
audio_path: Optional[str],
text_input: Optional[str],
chat_history: List[Tuple[str, str]],
threshold: float,
tts_enabled: bool,
finished: Optional[bool],
turns: Optional[int],
prev_scores: Dict[str, Any],
prev_meta: Dict[str, Any],
):
# If already finished, do nothing
finished = bool(finished) if finished is not None else False
turns = int(turns) if isinstance(turns, int) else 0
if finished:
return (
chat_history,
{"info": "Assessment complete."},
prev_meta.get("Severity", ""),
finished,
turns,
None,
None,
None,
None,
)
patient_text = (text_input or "").strip()
audio_features: Dict[str, float] = {}
if audio_path:
# Transcribe first
transcribed = transcribe_audio(audio_path)
if transcribed:
            # Join typed text and transcription with a separating space
            patient_text = f"{patient_text} {transcribed}".strip() if patient_text else transcribed
# Extract features
audio_features = compute_audio_features(audio_path)
if not patient_text:
# Ask user for input
chat_history.append(("", "I didn't catch that. Could you share a bit about how you've been feeling?"))
return (
chat_history,
prev_scores or {},
prev_meta.get("Severity", ""),
finished,
turns,
None,
None,
None,
None,
)
# Add patient's message
chat_history.append((patient_text, None))
# Scoring agent
scoring = scoring_agent_infer(chat_history, audio_features)
scores = scoring.get("PHQ9_Scores", {})
confidences = scoring.get("Confidences", [])
total = scoring.get("Total_Score", 0)
severity = scoring.get("Severity", severity_from_total(total))
overall_conf = float(scoring.get("Confidence", min(confidences) if confidences else 0.0))
# Override high-risk to reduce false positives: rely on explicit text or high item score
# Extract last patient message
last_patient = ""
for user_text, assistant_text in reversed(chat_history):
if user_text:
last_patient = user_text
break
explicit_flag = detect_explicit_suicidality(last_patient)
high_risk = bool(explicit_flag or (scores.get("suicidal_thoughts", 0) >= 2))
meta = {"Severity": severity, "Total_Score": total, "Confidence": overall_conf}
# Termination conditions
min_conf = float(min(confidences)) if confidences else 0.0
turns += 1
done = high_risk or (min_conf >= threshold) or (turns >= MAX_TURNS)
if high_risk:
closing = (
"I’m concerned about your safety based on what you shared. "
"If you are in danger or need immediate help, please call 988 in the U.S. or your local emergency number. "
"I'll end the assessment now and display emergency resources."
)
chat_history[-1] = (chat_history[-1][0], closing)
finished = True
elif done:
summary = (
f"Thank you for sharing. Based on our conversation, your responses suggest {severity.lower()}. "
"We can stop here."
)
chat_history[-1] = (chat_history[-1][0], summary)
finished = True
else:
# Generate next clinician question
reply = generate_recording_agent_reply(chat_history)
chat_history[-1] = (chat_history[-1][0], reply)
# TTS for the latest clinician message, if enabled
tts_path = synthesize_tts(chat_history[-1][1]) if tts_enabled else None
# Build a compact JSON for display
display_json = {
"PHQ9_Scores": scores,
"Confidences": confidences,
"Total_Score": total,
"Severity": severity,
"Confidence": overall_conf,
"High_Risk": high_risk,
}
# Clear inputs after processing
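    # Output order must match the `outputs=` list wired to audio_main.stop_recording in the UI:
    # (chatbot, score_json, severity_label, finished_state, turns_state,
    #  cleared audio_main, cleared text_main, tts_audio, tts_audio_main)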
return (
chat_history,
display_json,
severity,
finished,
turns,
None,
None,
tts_path,
tts_path,
)
def reset_app():  # NOTE: currently unwired; no reset button remains in the Main tab
return init_state()
# ---------------------------
# UI
# ---------------------------
def _on_load_init():
return init_state()
def _on_load_init_with_tts(tts_on: bool):  # NOTE: currently unused by the UI wiring below
chat_history, scores_state, meta_state, finished_state, turns_state = init_state()
# Play the intro message via TTS if enabled
tts_path = synthesize_tts(chat_history[-1][1]) if bool(tts_on) else None
return chat_history, scores_state, meta_state, finished_state, turns_state, tts_path
def _play_intro_tts(tts_on: bool):
if not bool(tts_on):
return None
try:
return synthesize_tts(INTRO_MESSAGE)
except Exception:
return None
def create_demo():
with gr.Blocks(
theme=gr.themes.Soft(),
css='''
/* Voice bubble styles - clean and centered */
#voice-bubble {
width: 240px; height: 240px; border-radius: 9999px; margin: 40px auto;
display: flex; align-items: center; justify-content: center;
background: linear-gradient(135deg, #6ee7b7 0%, #34d399 100%);
box-shadow: 0 20px 40px rgba(16,185,129,0.3), 0 0 0 1px rgba(255,255,255,0.1) inset;
transition: all 250ms cubic-bezier(0.4, 0, 0.2, 1);
cursor: default; /* green circle itself is not clickable */
pointer-events: none; /* ignore clicks on the green circle */
position: relative;
}
#voice-bubble:hover {
transform: translateY(-2px) scale(1.02);
box-shadow: 0 25px 50px rgba(16,185,129,0.4), 0 0 0 1px rgba(255,255,255,0.15) inset;
}
#voice-bubble:active { transform: translateY(0px) scale(0.98); }
#voice-bubble.listening {
animation: bubble-pulse 1.5s ease-in-out infinite;
background: linear-gradient(135deg, #fb7185 0%, #ef4444 100%);
box-shadow: 0 20px 40px rgba(239,68,68,0.4), 0 0 0 1px rgba(255,255,255,0.1) inset;
}
@keyframes bubble-pulse {
0%, 100% { transform: scale(1.0); box-shadow: 0 20px 40px rgba(239,68,68,0.4), 0 0 0 0 rgba(239,68,68,0.5); }
50% { transform: scale(1.05); box-shadow: 0 25px 50px rgba(239,68,68,0.5), 0 0 0 15px rgba(239,68,68,0.0); }
}
/* Hide microphone dropdown selector only */
#voice-bubble select { display: none !important; }
#voice-bubble .source-selection { display: none !important; }
#voice-bubble label[for] { display: none !important; }
/* Make the inner button the only clickable target */
#voice-bubble button { pointer-events: auto; cursor: pointer; }
/* Hide TTS player UI but keep it in DOM for autoplay */
#tts-player { width: 0 !important; height: 0 !important; opacity: 0 !important; position: absolute; pointer-events: none; }
'''
) as demo:
gr.Markdown(
"""
### Conversational Assessment for Responsive Engagement (CARE) Notes
Tap on 'Record' to start speaking, then tap on 'Stop' to stop recording.
"""
)
intro_play_btn = gr.Button("▶️ Play Intro", variant="secondary")
with gr.Tabs():
with gr.TabItem("Main"):
with gr.Column():
# Microphone component styled as central bubble (tap to record/stop)
audio_main = gr.Microphone(type="filepath", label=None, elem_id="voice-bubble", show_label=False)
# Hidden text input placeholder for pipeline compatibility
text_main = gr.Textbox(value="", visible=False)
# Autoplay clinician voice output (player hidden with CSS)
tts_audio_main = gr.Audio(label=None, interactive=False, autoplay=True, show_label=False, elem_id="tts-player")
with gr.TabItem("Advanced"):
with gr.Column():
chatbot = gr.Chatbot(height=360, type="tuples", label="Conversation")
score_json = gr.JSON(label="PHQ-9 Assessment (live)")
severity_label = gr.Label(label="Severity")
threshold = gr.Slider(0.5, 1.0, value=CONFIDENCE_THRESHOLD_DEFAULT, step=0.05, label="Confidence Threshold (stop when min ≥ τ)")
tts_enable = gr.Checkbox(label="Speak clinician responses (TTS)", value=USE_TTS_DEFAULT)
tts_audio = gr.Audio(label="Clinician voice", interactive=False, autoplay=False, visible=False)
model_id_tb = gr.Textbox(value=current_model_id, label="Chat Model ID", info="e.g., google/gemma-2-2b-it or google/medgemma-4b-it")
with gr.Row():
apply_model_btn = gr.Button("Apply model (no restart)")
# apply_model_restart_btn = gr.Button("Apply model and restart")
model_status = gr.Markdown(value=f"Current model: `{current_model_id}`")
# App state
chat_state = gr.State()
scores_state = gr.State()
meta_state = gr.State()
finished_state = gr.State()
turns_state = gr.State()
# Initialize on load (no autoplay due to browser policies)
demo.load(_on_load_init, inputs=None, outputs=[chatbot, scores_state, meta_state, finished_state, turns_state])
# Explicit user gesture to play intro TTS (works across browsers)
intro_play_btn.click(fn=_play_intro_tts, inputs=[tts_enable], outputs=[tts_audio_main])
# Wire interactions
audio_main.stop_recording(
fn=process_turn,
inputs=[audio_main, text_main, chatbot, threshold, tts_enable, finished_state, turns_state, scores_state, meta_state],
outputs=[chatbot, score_json, severity_label, finished_state, turns_state, audio_main, text_main, tts_audio, tts_audio_main],
queue=True,
api_name="message",
)
        # NOTE: The old JS click-toggle for the voice bubble was removed; the bubble itself
        # is now the Microphone component, so no manual record/stop wiring is needed.
# No reset button in Main tab anymore
# Model switch handlers
def _on_apply_model(mid: str):
msg = set_current_model_id(mid)
return f"Current model: `{current_model_id}`\n\n{msg}"
def _on_apply_model_restart(mid: str):
msg = apply_model_and_restart(mid)
return f"{msg}"
apply_model_btn.click(fn=_on_apply_model, inputs=[model_id_tb], outputs=[model_status])
# apply_model_restart_btn.click(fn=_on_apply_model_restart, inputs=[model_id_tb], outputs=[model_status])
return demo
demo = create_demo()
if __name__ == "__main__":
# For local dev
demo.queue(max_size=16).launch(server_name="0.0.0.0", server_port=int(os.getenv("PORT", "7860")))