# care-notes / app.py
import os
import json
import re
import time
from typing import Any, Dict, List, Optional, Tuple
import gradio as gr
import numpy as np
# Audio processing
import soundfile as sf
import librosa
# Models
import torch
from transformers import pipeline
from gtts import gTTS
import spaces
import threading
# ---------------------------
# Configuration
# ---------------------------
DEFAULT_CHAT_MODEL_ID = os.getenv("LLM_MODEL_ID", "google/gemma-2-2b-it")
DEFAULT_ASR_MODEL_ID = os.getenv("ASR_MODEL_ID", "openai/whisper-tiny.en")
CONFIDENCE_THRESHOLD_DEFAULT = float(os.getenv("CONFIDENCE_THRESHOLD", "0.8"))
MAX_TURNS = int(os.getenv("MAX_TURNS", "12"))
USE_TTS_DEFAULT = os.getenv("USE_TTS", "false").strip().lower() == "true"
CONFIG_PATH = os.getenv("MODEL_CONFIG_PATH", "model_config.json")
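# Any of the above can be overridden via environment variables, e.g.:
#   LLM_MODEL_ID=google/medgemma-4b-it USE_TTS=true python app.py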
def _load_model_id_from_config() -> str:
try:
if os.path.exists(CONFIG_PATH):
with open(CONFIG_PATH, "r") as f:
data = json.load(f)
if isinstance(data, dict) and data.get("model_id"):
return str(data["model_id"])
except Exception:
pass
return DEFAULT_CHAT_MODEL_ID
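# Example model_config.json, as written by persist_model_id() below:
#   {"model_id": "google/gemma-2-2b-it"}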
current_model_id = _load_model_id_from_config()
# ---------------------------
# Lazy singletons for pipelines
# ---------------------------
_asr_pipe = None
_gen_pipe = None
def _hf_device() -> int:
return 0 if torch.cuda.is_available() else -1
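# transformers pipelines take an integer device: 0 is the first CUDA GPU, -1 is CPU.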
def get_asr_pipeline():
global _asr_pipe
if _asr_pipe is None:
_asr_pipe = pipeline(
"automatic-speech-recognition",
model=DEFAULT_ASR_MODEL_ID,
device=_hf_device(),
)
return _asr_pipe
def get_textgen_pipeline():
global _gen_pipe
if _gen_pipe is None:
# Use a small default chat model for Spaces CPU; override via LLM_MODEL_ID
if torch.cuda.is_available() and hasattr(torch.cuda, "is_bf16_supported") and torch.cuda.is_bf16_supported():
_dtype = torch.bfloat16
elif torch.cuda.is_available():
_dtype = torch.float16
else:
_dtype = torch.float32
_gen_pipe = pipeline(
task="text-generation",
model=current_model_id,
tokenizer=current_model_id,
device=_hf_device(),
torch_dtype=_dtype,
)
return _gen_pipe
def set_current_model_id(new_model_id: str) -> str:
global current_model_id, _gen_pipe
new_model_id = (new_model_id or "").strip()
if not new_model_id:
return "Model id is empty; keeping current model."
if new_model_id == current_model_id:
return f"Model unchanged: `{current_model_id}`"
current_model_id = new_model_id
_gen_pipe = None # force reload on next use
return f"Model switched to `{current_model_id}` (pipeline will reload on next generation)."
def persist_model_id(new_model_id: str) -> None:
try:
with open(CONFIG_PATH, "w") as f:
json.dump({"model_id": new_model_id}, f)
except Exception:
pass
def apply_model_and_restart(new_model_id: str) -> str:
mid = (new_model_id or "").strip()
if not mid:
return "Model id is empty; not restarting."
persist_model_id(mid)
set_current_model_id(mid)
# Graceful delayed exit so response can flush
def _exit_later():
time.sleep(0.25)
os._exit(0)
threading.Thread(target=_exit_later, daemon=True).start()
return f"Restarting with model `{mid}`..."
# ---------------------------
# Utilities
# ---------------------------
def safe_json_extract(text: str) -> Optional[Dict[str, Any]]:
"""Extract first JSON object from text."""
if not text:
return None
try:
return json.loads(text)
except Exception:
pass
# Fallback: find the first {...} block
match = re.search(r"\{[\s\S]*\}", text)
if match:
try:
return json.loads(match.group(0))
except Exception:
return None
return None
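# e.g. safe_json_extract('Sure! {"Total_Score": 7}') -> {"Total_Score": 7}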
def compute_audio_features(audio_path: str) -> Dict[str, float]:
"""Compute lightweight prosodic features as a proxy for OpenSMILE.
Returns a dictionary with summary statistics.
"""
try:
y, sr = librosa.load(audio_path, sr=16000, mono=True)
if len(y) == 0:
return {}
# Frame-based features
hop_length = 512
frame_length = 1024
rms = librosa.feature.rms(y=y, frame_length=frame_length, hop_length=hop_length)[0]
zcr = librosa.feature.zero_crossing_rate(y, frame_length=frame_length, hop_length=hop_length)[0]
centroid = librosa.feature.spectral_centroid(y=y, sr=sr, n_fft=2048, hop_length=hop_length)[0]
# Pitch estimation (coarse)
f0 = None
try:
f0 = librosa.yin(y, fmin=50, fmax=400, sr=sr, frame_length=frame_length, hop_length=hop_length)
f0 = f0[np.isfinite(f0)]
except Exception:
f0 = None
        # Speaking-rate proxy: fraction of frames with above-median RMS energy
        voiced = rms > (np.median(rms) * 1.2)
voiced_ratio = float(np.mean(voiced))
features = {
"rms_mean": float(np.mean(rms)),
"rms_std": float(np.std(rms)),
"zcr_mean": float(np.mean(zcr)),
"zcr_std": float(np.std(zcr)),
"centroid_mean": float(np.mean(centroid)),
"centroid_std": float(np.std(centroid)),
"voiced_ratio": voiced_ratio,
"duration_sec": float(len(y) / sr),
}
if f0 is not None and f0.size > 0:
features.update({
"f0_median": float(np.median(f0)),
"f0_iqr": float(np.percentile(f0, 75) - np.percentile(f0, 25)),
})
return features
except Exception:
return {}
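# The summary statistics above are serialized to JSON and passed to the scoring
# agent's prompt (see scoring_agent_infer) as rough acoustic context.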
def detect_explicit_suicidality(text: Optional[str]) -> bool:
if not text:
return False
t = text.lower()
patterns = [
r"\bkill myself\b",
r"\bend my life\b",
r"\bend it all\b",
r"\bcommit suicide\b",
r"\bsuicide\b",
r"\bself[-\s]?harm\b",
r"\bhurt myself\b",
r"\bno reason to live\b",
r"\bwant to die\b",
r"\bending it\b",
]
for pat in patterns:
if re.search(pat, t):
return True
return False
def synthesize_tts(text: Optional[str]) -> Optional[str]:
if not text:
return None
try:
# Save MP3 to tmp and return filepath
ts = int(time.time() * 1000)
out_path = f"/tmp/tts_{ts}.mp3"
tts = gTTS(text=text, lang="en")
tts.save(out_path)
return out_path
except Exception:
return None
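# Standard PHQ-9 severity bands: 0-4 minimal, 5-9 mild, 10-14 moderate,
# 15-19 moderately severe, 20-27 severe.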
def severity_from_total(total_score: int) -> str:
if total_score <= 4:
return "Minimal Depression"
if total_score <= 9:
return "Mild Depression"
if total_score <= 14:
return "Moderate Depression"
if total_score <= 19:
return "Moderately Severe Depression"
return "Severe Depression"
def transcript_to_text(chat_history: List[Tuple[str, str]]) -> str:
"""Convert chatbot history [(user, assistant), ...] to a plain text transcript."""
lines = []
for user, assistant in chat_history:
if user:
lines.append(f"Patient: {user}")
if assistant:
lines.append(f"Clinician: {assistant}")
return "\n".join(lines)
def generate_recording_agent_reply(chat_history: List[Tuple[str, str]]) -> str:
transcript = transcript_to_text(chat_history)
system_prompt = (
"You are a clinician conducting a conversational assessment to infer PHQ-9 symptoms "
"without listing the nine questions explicitly. Keep tone empathetic, natural, and human. "
"Ask one concise, natural follow-up question at a time that helps infer symptoms such as mood, "
"sleep, appetite, energy, concentration, self-worth, psychomotor changes, and suicidal thoughts."
)
user_prompt = (
"Conversation so far (Patient and Clinician turns):\n\n" + transcript +
"\n\nRespond with a single short clinician-style question for the patient."
)
pipe = get_textgen_pipeline()
tokenizer = pipe.tokenizer
combined_prompt = system_prompt + "\n\n" + user_prompt
messages = [
{"role": "user", "content": combined_prompt},
]
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
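    # apply_chat_template renders the messages in the model's own chat format
    # (e.g. Gemma's <start_of_turn> markers) and appends the assistant prefix.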
gen = pipe(
prompt,
max_new_tokens=96,
temperature=0.7,
do_sample=True,
top_p=0.9,
top_k=50,
pad_token_id=tokenizer.eos_token_id,
return_full_text=False,
)
reply = gen[0]["generated_text"].strip()
# Ensure it's a single concise question/sentence
if len(reply) > 300:
reply = reply[:300].rstrip() + "…"
return reply
def scoring_agent_infer(chat_history: List[Tuple[str, str]], features: Dict[str, float]) -> Dict[str, Any]:
"""Ask the LLM to produce PHQ-9 scores and confidences as JSON. Fallback if parsing fails."""
transcript = transcript_to_text(chat_history)
features_json = json.dumps(features, ensure_ascii=False)
system_prompt = (
"You evaluate an on-going clinician-patient conversation to infer a PHQ-9 assessment. "
"Return ONLY a JSON object with: PHQ9_Scores (interest,mood,sleep,energy,appetite,self_worth,concentration,motor,suicidal_thoughts; each 0-3), "
"Confidences (list of 9 floats 0-1 in the same order), Total_Score (0-27), Severity (string), Confidence (min of confidences), "
"and High_Risk (boolean, true if any suicidal risk)."
)
user_prompt = (
"Conversation transcript:"\
f"\n{transcript}\n\n"
f"Acoustic features summary (approximate):\n{features_json}\n\n"
"Instructions: Infer PHQ9_Scores (0-3 per item), estimate Confidences per item, compute Total_Score and overall Severity. "
"Set High_Risk=true if any suicidal ideation or risk is present. Return ONLY JSON, no prose."
)
pipe = get_textgen_pipeline()
tokenizer = pipe.tokenizer
combined_prompt = system_prompt + "\n\n" + user_prompt
messages = [
{"role": "user", "content": combined_prompt},
]
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
# Use deterministic decoding to avoid CUDA sampling edge cases on some models
gen = pipe(
prompt,
max_new_tokens=256,
        do_sample=False,
pad_token_id=tokenizer.eos_token_id,
return_full_text=False,
)
out_text = gen[0]["generated_text"]
parsed = safe_json_extract(out_text)
# Validate and coerce
if parsed is None or "PHQ9_Scores" not in parsed:
# Simple fallback heuristic: neutral scores with low confidence
scores = {
"interest": 1,
"mood": 1,
"sleep": 1,
"energy": 1,
"appetite": 1,
"self_worth": 1,
"concentration": 1,
"motor": 1,
"suicidal_thoughts": 0,
}
confidences = [0.5] * 9
total = int(sum(scores.values()))
return {
"PHQ9_Scores": scores,
"Confidences": confidences,
"Total_Score": total,
"Severity": severity_from_total(total),
"Confidence": float(min(confidences)),
"High_Risk": False,
}
try:
# Coerce types and compute derived values if missing
scores = parsed.get("PHQ9_Scores", {})
# Ensure all keys present
keys = [
"interest","mood","sleep","energy","appetite","self_worth","concentration","motor","suicidal_thoughts"
]
for k in keys:
            scores[k] = max(0, min(3, int(scores.get(k, 0))))
confidences = parsed.get("Confidences", [])
if not isinstance(confidences, list) or len(confidences) != 9:
confidences = [float(parsed.get("Confidence", 0.5))] * 9
confidences = [float(max(0.0, min(1.0, c))) for c in confidences]
total = int(sum(scores.values()))
severity = parsed.get("Severity") or severity_from_total(total)
overall_conf = float(parsed.get("Confidence", min(confidences)))
# Conservative high-risk detection: require explicit language or high suicidal_thoughts score
# Extract last patient message
last_patient = ""
for user_text, assistant_text in reversed(chat_history):
if user_text:
last_patient = user_text
break
explicit_flag = detect_explicit_suicidality(last_patient) or detect_explicit_suicidality(transcript)
high_risk = bool(explicit_flag or (scores.get("suicidal_thoughts", 0) >= 2))
return {
"PHQ9_Scores": scores,
"Confidences": confidences,
"Total_Score": total,
"Severity": severity,
"Confidence": overall_conf,
"High_Risk": high_risk,
}
except Exception:
# Final fallback
scores = parsed.get("PHQ9_Scores", {}) if isinstance(parsed, dict) else {}
if not scores:
scores = {k: 1 for k in [
"interest","mood","sleep","energy","appetite","self_worth","concentration","motor","suicidal_thoughts"
]}
confidences = [0.5] * 9
total = int(sum(scores.values()))
return {
"PHQ9_Scores": scores,
"Confidences": confidences,
"Total_Score": total,
"Severity": severity_from_total(total),
"Confidence": float(min(confidences)),
"High_Risk": False,
}
def transcribe_audio(audio_path: Optional[str]) -> str:
if not audio_path:
return ""
try:
asr = get_asr_pipeline()
result = asr(audio_path)
if isinstance(result, dict) and "text" in result:
return result["text"].strip()
if isinstance(result, list) and len(result) > 0 and "text" in result[0]:
return result[0]["text"].strip()
except Exception:
pass
return ""
# ---------------------------
# Gradio app logic
# ---------------------------
INTRO_MESSAGE = (
"Hello, I'm here to check in on how you've been feeling lately. "
"To start, can you share how your mood has been over the past couple of weeks?"
)
def init_state() -> Tuple[List[Tuple[str, str]], Dict[str, Any], Dict[str, Any], bool, int]:
chat_history: List[Tuple[str, str]] = [("", INTRO_MESSAGE)]
scores = {}
meta = {"Severity": None, "Total_Score": None, "Confidence": 0.0}
finished = False
turns = 0
return chat_history, scores, meta, finished, turns
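# State tuple order used throughout: (chat_history, scores, meta, finished, turns).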
@spaces.GPU
def process_turn(
audio_path: Optional[str],
text_input: Optional[str],
chat_history: List[Tuple[str, str]],
threshold: float,
tts_enabled: bool,
finished: Optional[bool],
turns: Optional[int],
prev_scores: Dict[str, Any],
prev_meta: Dict[str, Any],
):
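    # Every return below is an 8-tuple matching the outputs wired to send_btn:
    # (chatbot, score_json, severity_label, finished_state, turns_state, audio, text, tts_audio)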
# If already finished, do nothing
finished = bool(finished) if finished is not None else False
    turns = turns if isinstance(turns, int) else 0
if finished:
return (
chat_history,
{"info": "Assessment complete."},
prev_meta.get("Severity", ""),
finished,
turns,
None,
None,
None,
)
patient_text = (text_input or "").strip()
audio_features: Dict[str, float] = {}
if audio_path:
# Transcribe first
transcribed = transcribe_audio(audio_path)
if transcribed:
            patient_text = f"{patient_text} {transcribed}".strip() if patient_text else transcribed
# Extract features
audio_features = compute_audio_features(audio_path)
if not patient_text:
# Ask user for input
chat_history.append(("", "I didn't catch that. Could you share a bit about how you've been feeling?"))
return (
chat_history,
prev_scores or {},
            prev_meta.get("Severity", "") if isinstance(prev_meta, dict) else "",
finished,
turns,
None,
None,
None,
)
# Add patient's message
chat_history.append((patient_text, None))
# Scoring agent
scoring = scoring_agent_infer(chat_history, audio_features)
scores = scoring.get("PHQ9_Scores", {})
confidences = scoring.get("Confidences", [])
total = scoring.get("Total_Score", 0)
severity = scoring.get("Severity", severity_from_total(total))
overall_conf = float(scoring.get("Confidence", min(confidences) if confidences else 0.0))
# Override high-risk to reduce false positives: rely on explicit text or high item score
# Extract last patient message
last_patient = ""
for user_text, assistant_text in reversed(chat_history):
if user_text:
last_patient = user_text
break
explicit_flag = detect_explicit_suicidality(last_patient)
high_risk = bool(explicit_flag or (scores.get("suicidal_thoughts", 0) >= 2))
meta = {"Severity": severity, "Total_Score": total, "Confidence": overall_conf}
# Termination conditions
min_conf = float(min(confidences)) if confidences else 0.0
turns += 1
done = high_risk or (min_conf >= threshold) or (turns >= MAX_TURNS)
if high_risk:
closing = (
"I’m concerned about your safety based on what you shared. "
"If you are in danger or need immediate help, please call 988 in the U.S. or your local emergency number. "
"I'll end the assessment now and display emergency resources."
)
chat_history[-1] = (chat_history[-1][0], closing)
finished = True
elif done:
summary = (
f"Thank you for sharing. Based on our conversation, your responses suggest {severity.lower()}. "
"We can stop here."
)
chat_history[-1] = (chat_history[-1][0], summary)
finished = True
else:
# Generate next clinician question
reply = generate_recording_agent_reply(chat_history)
chat_history[-1] = (chat_history[-1][0], reply)
# TTS for the latest clinician message, if enabled
tts_path = synthesize_tts(chat_history[-1][1]) if tts_enabled else None
# Build a compact JSON for display
display_json = {
"PHQ9_Scores": scores,
"Confidences": confidences,
"Total_Score": total,
"Severity": severity,
"Confidence": overall_conf,
"High_Risk": high_risk,
}
# Clear inputs after processing
return (
chat_history,
display_json,
severity,
finished,
turns,
None,
None,
tts_path,
)
def reset_app():
return init_state()
# ---------------------------
# UI
# ---------------------------
def _on_load_init():
return init_state()
def create_demo():
with gr.Blocks(theme=gr.themes.Soft()) as demo:
gr.Markdown(
"""
### Conversational Assessment for Responsive Engagement (CARE) Notes
Engage in a brief conversation. Your audio is transcribed, analyzed, and used to infer PHQ-9 scores.
The system stops when confidence is high enough or any safety risk is detected. It does not provide therapy or emergency counseling.
"""
)
with gr.Row():
chatbot = gr.Chatbot(height=400, type="tuples")
with gr.Column():
score_json = gr.JSON(label="PHQ-9 Assessment (live)")
severity_label = gr.Label(label="Severity")
threshold = gr.Slider(0.5, 1.0, value=CONFIDENCE_THRESHOLD_DEFAULT, step=0.05, label="Confidence Threshold (stop when min ≥ τ)")
tts_enable = gr.Checkbox(label="Speak clinician responses (TTS)", value=USE_TTS_DEFAULT)
tts_audio = gr.Audio(label="Clinician voice", interactive=False)
model_id_tb = gr.Textbox(value=current_model_id, label="Chat Model ID", info="e.g., google/gemma-2-2b-it or google/medgemma-4b-it")
with gr.Row():
apply_model_btn = gr.Button("Apply model (no restart)")
# apply_model_restart_btn = gr.Button("Apply model and restart")
model_status = gr.Markdown(value=f"Current model: `{current_model_id}`")
with gr.Row():
audio = gr.Audio(sources=["microphone"], type="filepath", label="Speak your response (or use text)")
text = gr.Textbox(lines=2, placeholder="Optional: type your response instead of audio")
with gr.Row():
send_btn = gr.Button("Send")
reset_btn = gr.Button("Reset")
# App state
chat_state = gr.State()
scores_state = gr.State()
meta_state = gr.State()
finished_state = gr.State()
turns_state = gr.State()
# Initialize on load (top-level function to be pickle-safe under ZeroGPU)
demo.load(_on_load_init, inputs=None, outputs=[chatbot, scores_state, meta_state, finished_state, turns_state])
# Wire interactions
send_btn.click(
fn=process_turn,
inputs=[audio, text, chatbot, threshold, tts_enable, finished_state, turns_state, scores_state, meta_state],
outputs=[chatbot, score_json, severity_label, finished_state, turns_state, audio, text, tts_audio],
queue=True,
api_name="message",
)
reset_btn.click(fn=reset_app, inputs=None, outputs=[chatbot, scores_state, meta_state, finished_state, turns_state])
# Model switch handlers
def _on_apply_model(mid: str):
msg = set_current_model_id(mid)
return f"Current model: `{current_model_id}`\n\n{msg}"
def _on_apply_model_restart(mid: str):
msg = apply_model_and_restart(mid)
return f"{msg}"
apply_model_btn.click(fn=_on_apply_model, inputs=[model_id_tb], outputs=[model_status])
        # Disabled along with the "Apply model and restart" button above:
        # apply_model_restart_btn.click(fn=_on_apply_model_restart, inputs=[model_id_tb], outputs=[model_status])
return demo
demo = create_demo()
if __name__ == "__main__":
# For local dev
demo.queue(max_size=16).launch(server_name="0.0.0.0", server_port=int(os.getenv("PORT", "7860")))