Akis Giannoukos committed on
Commit
497441d
·
1 Parent(s): 8991737

Added code

Files changed (3)
  1. README.md +263 -0
  2. app.py +512 -0
  3. requirements.txt +12 -0
README.md CHANGED
@@ -11,3 +11,266 @@ short_description: MedGemma clinician chatbot demo (research prototype)
---

Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference

# Technical Design Document: MedGemma-Based PHQ-9 Conversational Assessment Agent

## 1. Overview

### 1.1 Project Goal

The goal of this project is to develop an AI-driven clinician-simulation agent that conducts natural conversations with patients to assess depression severity on the PHQ-9 (Patient Health Questionnaire-9) scale. Unlike a simple questionnaire bot, the system aims to infer a patient's scores implicitly from conversation and speech cues, mirroring a clinician's behavior in real-world interviews.

### 1.2 Core Concept

The system will:

- Engage the user in a realistic, adaptive dialogue (clinician-style questioning).
- Continuously analyze textual and vocal features to estimate PHQ-9 item scores.
- Stop automatically when confidence in all PHQ-9 items is sufficiently high.
- Produce a final PHQ-9 severity report.

The system will use MedGemma-4B-IT (an instruction-tuned medical LLM) as the base model for both:

- A Recording Agent (conversational component)
- A Scoring Agent (PHQ-9 inference component)

## 2. System Architecture

### 2.1 High-Level Components

| Component | Description |
| --- | --- |
| Frontend Client | Handles user interaction, voice input/output, and UI display. |
| Speech I/O Module | Converts speech to text (ASR) and text to speech (TTS). |
| Feature Extraction Module | Extracts acoustic and prosodic features via OpenSMILE for emotional/speech analysis. |
| Recording Agent (Chatbot) | Conducts a clinician-like conversation with adaptive questioning. |
| Scoring Agent | Evaluates PHQ-9 symptom probabilities after each exchange and determines confidence in the final assessment. |
| Controller / Orchestrator | Manages communication between agents and triggers scoring cycles. |
| Model Backend | Hosts MedGemma-4B-IT, fine-tuned or prompted for clinician reasoning. |

### 2.2 Architecture Diagram (Text Description)

    ┌───────────────────────┐
    │    Frontend Client    │
    │  (Web / Desktop App)  │
    │ - Voice Input/Output  │
    │ - Text Display        │
    └─────────┬─────────────┘
              │ (audio stream)
              ▼
    ┌───────────────────────┐
    │   Speech I/O Module   │
    │ - ASR (Whisper)       │
    │ - TTS (e.g., Coqui)   │
    └─────────┬─────────────┘
              │
              ▼
    ┌─────────────────────────────┐
    │ Feature Extraction Module   │
    │ - OpenSMILE (prosody, pitch)│
    └─────────┬───────────────────┘
              │
              ▼
    ┌───────────────────────────────┐
    │  Recording Agent (MedGemma)   │
    │ - Generates next question     │
    │ - Conversational context      │
    └─────────┬─────────────────────┘
              │
              ▼
    ┌───────────────────────────────┐
    │   Scoring Agent (MedGemma)    │
    │ - Maps text + voice features  │
    │   → PHQ-9 item confidences    │
    │ - Decides when assessment is  │
    │   complete                    │
    └─────────┬─────────────────────┘
              │
              ▼
    ┌───────────────────────────────┐
    │  Controller / Orchestrator    │
    │ - Loop until confidence ≥ τ   │
    │ - Output PHQ-9 report         │
    └───────────────────────────────┘

## 3. Agent Design

### 3.1 Recording Agent

Role: Simulates a clinician conducting an empathetic, open-ended dialogue to elicit responses relevant to the nine PHQ-9 items (loss of interest, mood, sleep, energy, appetite, self-worth, concentration, psychomotor changes, and suicidal ideation).

Key responsibilities:

- Maintain conversational context.
- Adapt follow-up questions based on the inferred patient state.
- Produce text responses using MedGemma-4B-IT with a clinician-style prompt template (see the sketch after the prompt skeleton below).
- After each user response, trigger the Scoring Agent to reassess.

Prompt skeleton example:

    System: You are a clinician conducting a conversational assessment to infer
            PHQ-9 symptoms without listing the questions. Keep the tone
            empathetic, natural, and human.
    User: [transcribed patient input]
    Assistant: [clinician-style response / next question]

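The snippet below is a minimal sketch of how this prompt skeleton could be wired to a Hugging Face chat pipeline. It is not the shipped implementation: app.py builds the prompt manually for a small default model configured via `LLM_MODEL_ID` (MedGemma-4B-IT is gated), the message-list return format assumes a recent `transformers` release, and `next_clinician_question` is an illustrative helper name.

```python
from transformers import pipeline

# Assumption: a small open chat checkpoint stands in for MedGemma-4B-IT here.
clinician = pipeline("text-generation", model="TinyLlama/TinyLlama-1.1B-Chat-v1.0")

SYSTEM_PROMPT = (
    "You are a clinician conducting a conversational assessment to infer PHQ-9 "
    "symptoms without listing the questions. Keep the tone empathetic, natural, "
    "and human. Ask one concise follow-up question at a time."
)

def next_clinician_question(transcript: str) -> str:
    """Illustrative helper: generate the next clinician-style question."""
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": "Conversation so far:\n" + transcript +
                                    "\n\nRespond with a single short clinician-style question."},
    ]
    out = clinician(messages, max_new_tokens=128, do_sample=True, temperature=0.7)
    # With chat-style input, recent transformers versions return the updated
    # message list; the final entry is the assistant's reply.
    return out[0]["generated_text"][-1]["content"].strip()
```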
### 3.2 Scoring Agent

Role: Evaluates the ongoing conversation to infer a PHQ-9 score distribution and a confidence value for each symptom.

Input:

- Conversation transcript (all turns)
- OpenSMILE features (prosody, energy, speech rate)
- Optional: timestamped emotional embeddings (via a pretrained affect model)

Output:

- Vector of nine PHQ-9 item scores (0–3)
- Confidence score per item
- Overall depression severity classification (Minimal, Mild, Moderate, Moderately Severe, Severe)

Operation flow:

1. Parse the full transcript and extract statements relevant to each PHQ-9 item.
2. Combine textual cues with acoustic cues.
3. Use MedGemma's reasoning chain to map these features to PHQ-9 scores.
4. When the confidence for every item reaches the threshold τ (e.g., 0.8), finalize the results and signal termination.

## 4. Data Flow

1. The user speaks → audio is captured.
2. ASR transcribes the speech to text.
3. OpenSMILE extracts voice features.
4. The Recording Agent uses the transcript (and optionally summarized features) to produce the next conversational message.
5. The Scoring Agent evaluates the cumulative context → PHQ-9 score vector + confidences.
6. If confidence < τ → continue the conversation; otherwise → output the final PHQ-9 report.
7. The TTS module vocalizes the clinician output.

A condensed version of this loop is sketched below.

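This is a minimal orchestration sketch that reuses the helpers defined in app.py (`transcribe_audio`, `compute_audio_features`, `scoring_agent_infer`, `generate_recording_agent_reply`). The `get_audio_path` and `speak` callables are hypothetical stand-ins for the Speech I/O module; the actual demo drives the same cycle through Gradio UI callbacks instead.

```python
from app import (transcribe_audio, compute_audio_features,
                 scoring_agent_infer, generate_recording_agent_reply)

def run_assessment(get_audio_path, speak=print, threshold: float = 0.8, max_turns: int = 12) -> dict:
    """One end-to-end assessment loop (sketch). `get_audio_path` supplies the next
    recorded clip; `speak` voices or prints the clinician message."""
    history = [("", "How has your mood been over the past couple of weeks?")]
    speak(history[-1][1])
    result: dict = {}
    for _ in range(max_turns):
        audio_path = get_audio_path()                    # capture patient speech
        patient_text = transcribe_audio(audio_path)      # ASR (Whisper)
        features = compute_audio_features(audio_path)    # prosodic feature summary
        history.append((patient_text, None))
        result = scoring_agent_infer(history, features)  # PHQ-9 scores + confidences
        if result["High_Risk"] or min(result["Confidences"]) >= threshold:
            break                                        # stop: safety flag or high confidence
        reply = generate_recording_agent_reply(history)  # next clinician question
        history[-1] = (patient_text, reply)
        speak(reply)
    return result
```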
## 5. Implementation Details

### 5.1 Models and Libraries

| Function | Tool / Library |
| --- | --- |
| Base LLM | MedGemma-4B-IT (from Hugging Face) |
| ASR | Whisper |
| TTS | gTTS (preferred); alternatively Coqui TTS or Bark |
| Audio features | OpenSMILE (IS09 / ComParE configs) |
| Backend | Python / FastAPI server |
| Frontend | Gradio |
| Communication | WebSocket or REST APIs |

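app.py currently approximates the acoustic features with librosa; the sketch below shows what the intended OpenSMILE extraction could look like using the `opensmile` Python package (an extra dependency that is not in requirements.txt).

```python
import opensmile

# ComParE functionals, one feature vector per audio file.
smile = opensmile.Smile(
    feature_set=opensmile.FeatureSet.ComParE_2016,
    feature_level=opensmile.FeatureLevel.Functionals,
)

def extract_opensmile_features(audio_path: str) -> dict:
    """Return one row of ComParE functionals as a {feature_name: value} dict."""
    df = smile.process_file(audio_path)  # pandas DataFrame with a single row
    return df.iloc[0].to_dict()
```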
### 5.2 Confidence Computation

- Each PHQ-9 item i has a confidence score c_i ∈ [0, 1].
- c_i is estimated via secondary LLM reasoning (e.g., "How confident are you about this inference?").
- Global confidence: C = min_i c_i.
- Stop condition: C ≥ τ, e.g., τ = 0.8 (see the sketch below).

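A minimal sketch of this stop rule: the overall confidence is the minimum per-item confidence, and the session ends once it reaches τ.

```python
def should_stop(confidences: list, tau: float = 0.8) -> bool:
    """Stop when every PHQ-9 item confidence c_i satisfies c_i >= tau."""
    overall = min(confidences) if confidences else 0.0  # C = min_i c_i
    return overall >= tau
```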
### 5.3 Example API Workflow

Request:

    POST /api/message
    {
      "audio": "<base64 encoded>",
      "transcript": "...",
      "features": {...}
    }

Response:

    {
      "agent_response": "...",
      "phq9_scores": [1, 0, 2, ...],
      "confidences": [0.9, 0.85, ...],
      "finished": false
    }

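The endpoint below is a sketch of how this workflow could be exposed on the FastAPI backend named in Section 5.1. It is an assumption, not current behavior: the demo exposes the same logic through Gradio (`api_name="message"`), and the request model here is inferred from the payload above.

```python
from typing import Dict, Optional

from fastapi import FastAPI
from pydantic import BaseModel

from app import generate_recording_agent_reply, scoring_agent_infer  # reuse demo helpers

api = FastAPI()

class MessageIn(BaseModel):
    audio: Optional[str] = None        # base64-encoded audio (optional)
    transcript: str = ""
    features: Dict[str, float] = {}

@api.post("/api/message")
def message(body: MessageIn) -> dict:
    history = [(body.transcript, None)]
    result = scoring_agent_infer(history, body.features)
    return {
        "agent_response": generate_recording_agent_reply(history),
        "phq9_scores": list(result["PHQ9_Scores"].values()),
        "confidences": result["Confidences"],
        "finished": min(result["Confidences"]) >= 0.8,
    }
```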
## 6. Training and Fine-Tuning (Future Work)

Not implemented in this version, because labeled conversational data is not yet available. Planned directions:

- Supervised fine-tuning (SFT) on synthetic dialogues labeled with PHQ-9 scores.
- Speech–text alignment: fuse OpenSMILE embeddings with conversation text embeddings before feeding the scoring prompts.
- Possible multimodal fusion via:
  - feature concatenation → token embedding (see the sketch below), or
  - a cross-attention adapter (if fine-tuning is allowed).

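A conceptual sketch of the feature-concatenation option: project the acoustic feature vector to the LLM's hidden size and prepend it to the token embeddings as one extra "soft token". The dimensions are illustrative assumptions; this is future work, not part of the current app.

```python
import torch
import torch.nn as nn

class AcousticPrefix(nn.Module):
    """Turn an acoustic feature vector into one soft token prepended to the text embeddings."""

    def __init__(self, n_features: int, hidden_size: int):
        super().__init__()
        self.proj = nn.Linear(n_features, hidden_size)

    def forward(self, token_embeds: torch.Tensor, features: torch.Tensor) -> torch.Tensor:
        # token_embeds: (batch, seq_len, hidden); features: (batch, n_features)
        prefix = self.proj(features).unsqueeze(1)        # (batch, 1, hidden)
        return torch.cat([prefix, token_embeds], dim=1)  # prepend the acoustic token
```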
## 7. Output Specification

Final output:

    {
      "PHQ9_Scores": {
        "interest": 2,
        "mood": 3,
        "sleep": 2,
        "energy": 2,
        "appetite": 1,
        "self_worth": 2,
        "concentration": 1,
        "motor": 1,
        "suicidal_thoughts": 0
      },
      "Total_Score": 14,
      "Severity": "Moderate Depression",
      "Confidence": 0.86
    }

This is displayed alongside a clinician-style summary:

> "Based on our discussion, your responses suggest moderate depressive symptoms, with difficulties in mood and sleep being most prominent."

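The severity label follows the standard PHQ-9 cut-offs, mirroring `severity_from_total` in app.py; a short check of the report above:

```python
def severity_from_total(total_score: int) -> str:
    """Standard PHQ-9 severity bands (same mapping as severity_from_total in app.py)."""
    if total_score <= 4:
        return "Minimal Depression"
    if total_score <= 9:
        return "Mild Depression"
    if total_score <= 14:
        return "Moderate Depression"
    if total_score <= 19:
        return "Moderately Severe Depression"
    return "Severe Depression"

scores = {"interest": 2, "mood": 3, "sleep": 2, "energy": 2, "appetite": 1,
          "self_worth": 2, "concentration": 1, "motor": 1, "suicidal_thoughts": 0}
total = sum(scores.values())
print(total, severity_from_total(total))  # 14 Moderate Depression
```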
## 8. Termination and Safety

The system will not offer therapy advice or emergency counseling.

If the patient mentions suicidal thoughts (item 9), the system:

- flags high risk,
- terminates the chat, and
- displays emergency contact information (e.g., "If you are in danger or need immediate help, call 988 in the U.S."), as sketched below.

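A minimal sketch of the safety gate, mirroring the `High_Risk` handling in app.py's `process_turn`: any sign of suicidal ideation ends the session with crisis resources instead of a follow-up question.

```python
from typing import Optional

CRISIS_MESSAGE = (
    "If you are in danger or need immediate help, call 988 in the U.S. "
    "or your local emergency number."
)

def safety_gate(scoring_result: dict) -> Optional[str]:
    """Return a terminating crisis message if the conversation must stop, else None."""
    high_risk = scoring_result.get("High_Risk", False)
    item9 = scoring_result.get("PHQ9_Scores", {}).get("suicidal_thoughts", 0)
    if high_risk or item9 >= 1:
        return CRISIS_MESSAGE
    return None
```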
## 9. Future Extensions (Not Implemented Now)

- A fine-tuned model jointly trained on PHQ-9-labeled conversations.
- Multilingual support (via multilingual Whisper and TTS).
- Confidence calibration using Bayesian reasoning or uncertainty quantification.
- Integration with EHR systems for clinician verification.

## 10. Summary

This project creates an intelligent, conversational PHQ-9 assessment agent that blends:

- the MedGemma-4B-IT medical LLM,
- audio emotion analysis with OpenSMILE,
- a dual-agent architecture for conversation and scoring, and
- multimodal reasoning to deliver clinician-like mental health assessments.

The modular design enables local deployment on GPU servers, privacy-preserving operation, and future research extensions into multimodal diagnostic reasoning.
app.py ADDED
@@ -0,0 +1,512 @@
import os
import json
import re
import time
from typing import Any, Dict, List, Optional, Tuple

import gradio as gr
import numpy as np

# Audio processing
import librosa

# Models
import torch
from transformers import pipeline
from gtts import gTTS


# ---------------------------
# Configuration
# ---------------------------
DEFAULT_CHAT_MODEL_ID = os.getenv("LLM_MODEL_ID", "TinyLlama/TinyLlama-1.1B-Chat-v1.0")
DEFAULT_ASR_MODEL_ID = os.getenv("ASR_MODEL_ID", "openai/whisper-tiny.en")
CONFIDENCE_THRESHOLD_DEFAULT = float(os.getenv("CONFIDENCE_THRESHOLD", "0.8"))
MAX_TURNS = int(os.getenv("MAX_TURNS", "12"))
USE_TTS_DEFAULT = os.getenv("USE_TTS", "false").strip().lower() == "true"


# ---------------------------
# Lazy singletons for pipelines
# ---------------------------
_asr_pipe = None
_gen_pipe = None


def get_asr_pipeline():
    """Load the ASR pipeline once and reuse it across turns."""
    global _asr_pipe
    if _asr_pipe is None:
        _asr_pipe = pipeline(
            "automatic-speech-recognition",
            model=DEFAULT_ASR_MODEL_ID,
            device=-1,
        )
    return _asr_pipe


def get_textgen_pipeline():
    """Load the chat/text-generation pipeline once (CPU-friendly default)."""
    global _gen_pipe
    if _gen_pipe is None:
        # Use a small default chat model for Spaces CPU; override via LLM_MODEL_ID
        _gen_pipe = pipeline(
            task="text-generation",
            model=DEFAULT_CHAT_MODEL_ID,
            tokenizer=DEFAULT_CHAT_MODEL_ID,
            device=-1,
            torch_dtype=torch.float32,
        )
    return _gen_pipe
+
66
+
67
+ # ---------------------------
68
+ # Utilities
69
+ # ---------------------------
70
+ def safe_json_extract(text: str) -> Optional[Dict[str, Any]]:
71
+ """Extract first JSON object from text."""
72
+ if not text:
73
+ return None
74
+ try:
75
+ return json.loads(text)
76
+ except Exception:
77
+ pass
78
+ # Fallback: find the first {...} block
79
+ match = re.search(r"\{[\s\S]*\}", text)
80
+ if match:
81
+ try:
82
+ return json.loads(match.group(0))
83
+ except Exception:
84
+ return None
85
+ return None
86
+
87
+
88
+ def compute_audio_features(audio_path: str) -> Dict[str, float]:
89
+ """Compute lightweight prosodic features as a proxy for OpenSMILE.
90
+
91
+ Returns a dictionary with summary statistics.
92
+ """
93
+ try:
94
+ y, sr = librosa.load(audio_path, sr=16000, mono=True)
95
+ if len(y) == 0:
96
+ return {}
97
+
98
+ # Frame-based features
99
+ hop_length = 512
100
+ frame_length = 1024
101
+
102
+ rms = librosa.feature.rms(y=y, frame_length=frame_length, hop_length=hop_length)[0]
103
+ zcr = librosa.feature.zero_crossing_rate(y, frame_length=frame_length, hop_length=hop_length)[0]
104
+ centroid = librosa.feature.spectral_centroid(y=y, sr=sr, n_fft=2048, hop_length=hop_length)[0]
105
+
106
+ # Pitch estimation (coarse)
107
+ f0 = None
108
+ try:
109
+ f0 = librosa.yin(y, fmin=50, fmax=400, sr=sr, frame_length=frame_length, hop_length=hop_length)
110
+ f0 = f0[np.isfinite(f0)]
111
+ except Exception:
112
+ f0 = None
113
+
114
+ # Speaking rate rough proxy: voiced ratio per second
115
+ energy = librosa.feature.rms(y=y, frame_length=frame_length, hop_length=hop_length)[0]
116
+ voiced = energy > (np.median(energy) * 1.2)
117
+ voiced_ratio = float(np.mean(voiced))
118
+
119
+ features = {
120
+ "rms_mean": float(np.mean(rms)),
121
+ "rms_std": float(np.std(rms)),
122
+ "zcr_mean": float(np.mean(zcr)),
123
+ "zcr_std": float(np.std(zcr)),
124
+ "centroid_mean": float(np.mean(centroid)),
125
+ "centroid_std": float(np.std(centroid)),
126
+ "voiced_ratio": voiced_ratio,
127
+ "duration_sec": float(len(y) / sr),
128
+ }
129
+ if f0 is not None and f0.size > 0:
130
+ features.update({
131
+ "f0_median": float(np.median(f0)),
132
+ "f0_iqr": float(np.percentile(f0, 75) - np.percentile(f0, 25)),
133
+ })
134
+ return features
135
+ except Exception:
136
+ return {}
137
+
138
+
139
+ def synthesize_tts(text: Optional[str]) -> Optional[str]:
140
+ if not text:
141
+ return None
142
+ try:
143
+ # Save MP3 to tmp and return filepath
144
+ ts = int(time.time() * 1000)
145
+ out_path = f"/tmp/tts_{ts}.mp3"
146
+ tts = gTTS(text=text, lang="en")
147
+ tts.save(out_path)
148
+ return out_path
149
+ except Exception:
150
+ return None
151
+
152
+
153
+ def severity_from_total(total_score: int) -> str:
154
+ if total_score <= 4:
155
+ return "Minimal Depression"
156
+ if total_score <= 9:
157
+ return "Mild Depression"
158
+ if total_score <= 14:
159
+ return "Moderate Depression"
160
+ if total_score <= 19:
161
+ return "Moderately Severe Depression"
162
+ return "Severe Depression"
163
+
164
+
165
+ def transcript_to_text(chat_history: List[Tuple[str, str]]) -> str:
166
+ """Convert chatbot history [(user, assistant), ...] to a plain text transcript."""
167
+ lines = []
168
+ for user, assistant in chat_history:
169
+ if user:
170
+ lines.append(f"Patient: {user}")
171
+ if assistant:
172
+ lines.append(f"Clinician: {assistant}")
173
+ return "\n".join(lines)
174
+
175
+
176
+ def generate_recording_agent_reply(chat_history: List[Tuple[str, str]]) -> str:
177
+ transcript = transcript_to_text(chat_history)
178
+ system_prompt = (
179
+ "You are a clinician conducting a conversational assessment to infer PHQ-9 symptoms "
180
+ "without listing the nine questions explicitly. Keep tone empathetic, natural, and human. "
181
+ "Ask one concise, natural follow-up question at a time that helps infer symptoms such as mood, "
182
+ "sleep, appetite, energy, concentration, self-worth, psychomotor changes, and suicidal thoughts."
183
+ )
184
+ user_prompt = (
185
+ "Conversation so far (Patient and Clinician turns):\n\n" + transcript +
186
+ "\n\nRespond with a single short clinician-style question for the patient."
187
+ )
188
+ pipe = get_textgen_pipeline()
189
+ out = pipe(
190
+ f"<|system|>\n{system_prompt}\n<|user|>\n{user_prompt}\n<|assistant|>",
191
+ max_new_tokens=128,
192
+ temperature=0.7,
193
+ do_sample=True,
194
+ pad_token_id=pipe.tokenizer.eos_token_id,
195
+ )[0]["generated_text"]
196
+
197
+ # Extract assistant content after the last assistant tag if present
198
+ reply = out.split("<|assistant|>")[-1].strip()
199
+ # Post-process to avoid trailing special tokens
200
+ reply = re.split(r"</s>|<\|endoftext\|>", reply)[0].strip()
201
+ # Ensure it's a single concise question/sentence
202
+ if len(reply) > 300:
203
+ reply = reply[:300].rstrip() + "…"
204
+ return reply
205
+
206
+
207
+ def scoring_agent_infer(chat_history: List[Tuple[str, str]], features: Dict[str, float]) -> Dict[str, Any]:
208
+ """Ask the LLM to produce PHQ-9 scores and confidences as JSON. Fallback if parsing fails."""
209
+ transcript = transcript_to_text(chat_history)
210
+ features_json = json.dumps(features, ensure_ascii=False)
211
+ system_prompt = (
212
+ "You evaluate an on-going clinician-patient conversation to infer a PHQ-9 assessment. "
213
+ "Return ONLY a JSON object with: PHQ9_Scores (interest,mood,sleep,energy,appetite,self_worth,concentration,motor,suicidal_thoughts; each 0-3), "
214
+ "Confidences (list of 9 floats 0-1 in the same order), Total_Score (0-27), Severity (string), Confidence (min of confidences), "
215
+ "and High_Risk (boolean, true if any suicidal risk)."
216
+ )
217
+ user_prompt = (
218
+ "Conversation transcript:"\
219
+ f"\n{transcript}\n\n"
220
+ f"Acoustic features summary (approximate):\n{features_json}\n\n"
221
+ "Instructions: Infer PHQ9_Scores (0-3 per item), estimate Confidences per item, compute Total_Score and overall Severity. "
222
+ "Set High_Risk=true if any suicidal ideation or risk is present. Return ONLY JSON, no prose."
223
+ )
224
+ pipe = get_textgen_pipeline()
225
+ out = pipe(
226
+ f"<|system|>\n{system_prompt}\n<|user|>\n{user_prompt}\n<|assistant|>",
227
+ max_new_tokens=256,
228
+ temperature=0.2,
229
+ do_sample=True,
230
+ pad_token_id=pipe.tokenizer.eos_token_id,
231
+ )[0]["generated_text"]
232
+ parsed = safe_json_extract(out)
233
+
234
+ # Validate and coerce
235
+ if parsed is None or "PHQ9_Scores" not in parsed:
236
+ # Simple fallback heuristic: neutral scores with low confidence
237
+ scores = {
238
+ "interest": 1,
239
+ "mood": 1,
240
+ "sleep": 1,
241
+ "energy": 1,
242
+ "appetite": 1,
243
+ "self_worth": 1,
244
+ "concentration": 1,
245
+ "motor": 1,
246
+ "suicidal_thoughts": 0,
247
+ }
248
+ confidences = [0.5] * 9
249
+ total = int(sum(scores.values()))
250
+ return {
251
+ "PHQ9_Scores": scores,
252
+ "Confidences": confidences,
253
+ "Total_Score": total,
254
+ "Severity": severity_from_total(total),
255
+ "Confidence": float(min(confidences)),
256
+ "High_Risk": False,
257
+ }
258
+
259
+ try:
260
+ # Coerce types and compute derived values if missing
261
+ scores = parsed.get("PHQ9_Scores", {})
262
+ # Ensure all keys present
263
+ keys = [
264
+ "interest","mood","sleep","energy","appetite","self_worth","concentration","motor","suicidal_thoughts"
265
+ ]
266
+ for k in keys:
267
+ scores[k] = int(max(0, min(3, int(scores.get(k, 0)))))
268
+ confidences = parsed.get("Confidences", [])
269
+ if not isinstance(confidences, list) or len(confidences) != 9:
270
+ confidences = [float(parsed.get("Confidence", 0.5))] * 9
271
+ confidences = [float(max(0.0, min(1.0, c))) for c in confidences]
272
+ total = int(sum(scores.values()))
273
+ severity = parsed.get("Severity") or severity_from_total(total)
274
+ overall_conf = float(parsed.get("Confidence", min(confidences)))
275
+ high_risk = bool(parsed.get("High_Risk", False)) or (scores.get("suicidal_thoughts", 0) >= 1)
276
+
277
+ return {
278
+ "PHQ9_Scores": scores,
279
+ "Confidences": confidences,
280
+ "Total_Score": total,
281
+ "Severity": severity,
282
+ "Confidence": overall_conf,
283
+ "High_Risk": high_risk,
284
+ }
285
+ except Exception:
286
+ # Final fallback
287
+ scores = parsed.get("PHQ9_Scores", {}) if isinstance(parsed, dict) else {}
288
+ if not scores:
289
+ scores = {k: 1 for k in [
290
+ "interest","mood","sleep","energy","appetite","self_worth","concentration","motor","suicidal_thoughts"
291
+ ]}
292
+ confidences = [0.5] * 9
293
+ total = int(sum(scores.values()))
294
+ return {
295
+ "PHQ9_Scores": scores,
296
+ "Confidences": confidences,
297
+ "Total_Score": total,
298
+ "Severity": severity_from_total(total),
299
+ "Confidence": float(min(confidences)),
300
+ "High_Risk": False,
301
+ }
302
+
303
+
304
+ def transcribe_audio(audio_path: Optional[str]) -> str:
305
+ if not audio_path:
306
+ return ""
307
+ try:
308
+ asr = get_asr_pipeline()
309
+ result = asr(audio_path)
310
+ if isinstance(result, dict) and "text" in result:
311
+ return result["text"].strip()
312
+ if isinstance(result, list) and len(result) > 0 and "text" in result[0]:
313
+ return result[0]["text"].strip()
314
+ except Exception:
315
+ pass
316
+ return ""
317
+
318
+
319
+ # ---------------------------
320
+ # Gradio app logic
321
+ # ---------------------------
322
+ INTRO_MESSAGE = (
323
+ "Hello, I'm here to check in on how you've been feeling lately. "
324
+ "To start, can you share how your mood has been over the past couple of weeks?"
325
+ )
326
+
327
+
328
+ def init_state() -> Tuple[List[Tuple[str, str]], Dict[str, Any], Dict[str, Any], bool, int]:
329
+ chat_history: List[Tuple[str, str]] = [("", INTRO_MESSAGE)]
330
+ scores = {}
331
+ meta = {"Severity": None, "Total_Score": None, "Confidence": 0.0}
332
+ finished = False
333
+ turns = 0
334
+ return chat_history, scores, meta, finished, turns
335
+
336
+
337
+ def process_turn(
338
+ audio_path: Optional[str],
339
+ text_input: Optional[str],
340
+ chat_history: List[Tuple[str, str]],
341
+ threshold: float,
342
+ tts_enabled: bool,
343
+ finished: bool,
344
+ turns: int,
345
+ prev_scores: Dict[str, Any],
346
+ prev_meta: Dict[str, Any],
347
+ ):
348
+ # If already finished, do nothing
349
+ if finished:
350
+ return (
351
+ chat_history,
352
+ {"info": "Assessment complete."},
353
+ prev_meta.get("Severity", ""),
354
+ finished,
355
+ turns,
356
+ None,
357
+ None,
358
+ None,
359
+ )
360
+
361
+ patient_text = (text_input or "").strip()
362
+ audio_features: Dict[str, float] = {}
363
+ if audio_path:
364
+ # Transcribe first
365
+ transcribed = transcribe_audio(audio_path)
366
+ if transcribed:
367
+ patient_text = (patient_text + " ").strip() + transcribed if patient_text else transcribed
368
+ # Extract features
369
+ audio_features = compute_audio_features(audio_path)
370
+
371
+ if not patient_text:
372
+ # Ask user for input
373
+ chat_history.append(("", "I didn't catch that. Could you share a bit about how you've been feeling?"))
374
+ return (
375
+ chat_history,
376
+ prev_scores or {},
377
+ prev_meta,
378
+ finished,
379
+ turns,
380
+ None,
381
+ None,
382
+ None,
383
+ )
384
+
385
+ # Add patient's message
386
+ chat_history.append((patient_text, None))
387
+
388
+ # Scoring agent
389
+ scoring = scoring_agent_infer(chat_history, audio_features)
390
+ scores = scoring.get("PHQ9_Scores", {})
391
+ confidences = scoring.get("Confidences", [])
392
+ total = scoring.get("Total_Score", 0)
393
+ severity = scoring.get("Severity", severity_from_total(total))
394
+ overall_conf = float(scoring.get("Confidence", min(confidences) if confidences else 0.0))
395
+ high_risk = bool(scoring.get("High_Risk", False))
396
+
397
+ meta = {"Severity": severity, "Total_Score": total, "Confidence": overall_conf}
398
+
399
+ # Termination conditions
400
+ min_conf = float(min(confidences)) if confidences else 0.0
401
+ turns += 1
402
+ done = high_risk or (min_conf >= threshold) or (turns >= MAX_TURNS)
403
+
404
+ if high_risk:
405
+ closing = (
406
+ "I’m concerned about your safety based on what you shared. "
407
+ "If you are in danger or need immediate help, please call 988 in the U.S. or your local emergency number. "
408
+ "I'll end the assessment now and display emergency resources."
409
+ )
410
+ chat_history[-1] = (chat_history[-1][0], closing)
411
+ finished = True
412
+ elif done:
413
+ summary = (
414
+ f"Thank you for sharing. Based on our conversation, your responses suggest {severity.lower()}. "
415
+ "We can stop here."
416
+ )
417
+ chat_history[-1] = (chat_history[-1][0], summary)
418
+ finished = True
419
+ else:
420
+ # Generate next clinician question
421
+ reply = generate_recording_agent_reply(chat_history)
422
+ chat_history[-1] = (chat_history[-1][0], reply)
423
+
424
+ # TTS for the latest clinician message, if enabled
425
+ tts_path = synthesize_tts(chat_history[-1][1]) if tts_enabled else None
426
+
427
+ # Build a compact JSON for display
428
+ display_json = {
429
+ "PHQ9_Scores": scores,
430
+ "Confidences": confidences,
431
+ "Total_Score": total,
432
+ "Severity": severity,
433
+ "Confidence": overall_conf,
434
+ "High_Risk": high_risk,
435
+ }
436
+
437
+ # Clear inputs after processing
438
+ return (
439
+ chat_history,
440
+ display_json,
441
+ severity,
442
+ finished,
443
+ turns,
444
+ None,
445
+ None,
446
+ tts_path,
447
+ )
448
+
449
+
450
+ def reset_app():
451
+ return init_state()
452
+
453
+
454
+ # ---------------------------
455
+ # UI
456
+ # ---------------------------
457
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
458
+ gr.Markdown(
459
+ """
460
+ ### PHQ-9 Conversational Clinician Agent
461
+ Engage in a brief, empathetic conversation. Your audio is transcribed, analyzed, and used to infer PHQ-9 scores.
462
+ The system stops when confidence is high enough or any safety risk is detected. It does not provide therapy or emergency counseling.
463
+ """
464
+ )
465
+
466
+ with gr.Row():
467
+ chatbot = gr.Chatbot(height=400, type="tuples")
468
+ with gr.Column():
469
+ score_json = gr.JSON(label="PHQ-9 Assessment (live)")
470
+ severity_label = gr.Label(label="Severity")
471
+ threshold = gr.Slider(0.5, 1.0, value=CONFIDENCE_THRESHOLD_DEFAULT, step=0.05, label="Confidence Threshold (stop when min ≥ τ)")
472
+ tts_enable = gr.Checkbox(label="Speak clinician responses (TTS)", value=USE_TTS_DEFAULT)
473
+ tts_audio = gr.Audio(label="Clinician voice", interactive=False)
474
+
475
+ with gr.Row():
476
+ audio = gr.Audio(sources=["microphone"], type="filepath", label="Speak your response (or use text)")
477
+ text = gr.Textbox(lines=2, placeholder="Optional: type your response instead of audio")
478
+
479
+ with gr.Row():
480
+ send_btn = gr.Button("Send")
481
+ reset_btn = gr.Button("Reset")
482
+
483
+ # App state
484
+ chat_state = gr.State()
485
+ scores_state = gr.State()
486
+ meta_state = gr.State()
487
+ finished_state = gr.State()
488
+ turns_state = gr.State()
489
+
490
+ # Initialize on load
491
+ def _on_load():
492
+ return init_state()
493
+
494
+ demo.load(_on_load, inputs=None, outputs=[chatbot, scores_state, meta_state, finished_state, turns_state])
495
+
496
+ # Wire interactions
497
+ send_btn.click(
498
+ fn=process_turn,
499
+ inputs=[audio, text, chatbot, threshold, tts_enable, finished_state, turns_state, scores_state, meta_state],
500
+ outputs=[chatbot, score_json, severity_label, finished_state, turns_state, audio, text, tts_audio],
501
+ queue=True,
502
+ api_name="message",
503
+ )
504
+
505
+ reset_btn.click(fn=reset_app, inputs=None, outputs=[chatbot, scores_state, meta_state, finished_state, turns_state])
506
+
507
+
508
+ if __name__ == "__main__":
509
+ # For local dev
510
+ demo.queue(max_size=16).launch(server_name="0.0.0.0", server_port=int(os.getenv("PORT", "7860")))
511
+
512
+
requirements.txt ADDED
@@ -0,0 +1,12 @@
gradio>=4.44.0
transformers>=4.44.2
torch>=2.1.0
accelerate>=0.34.2
sentencepiece>=0.2.0
soundfile>=0.12.1
librosa>=0.10.2
numpy>=1.26.4
scipy>=1.11.4
protobuf>=4.25.3
gTTS>=2.5.3