import json
import os
import re
import threading
import time
from typing import Any, Dict, List, Optional, Tuple

import gradio as gr
import numpy as np

# Audio processing
import librosa

# Models
import torch
from transformers import pipeline
from gtts import gTTS
import spaces


# ---------------------------
# Configuration
# ---------------------------
DEFAULT_CHAT_MODEL_ID = os.getenv("LLM_MODEL_ID", "google/gemma-2-2b-it")
DEFAULT_ASR_MODEL_ID = os.getenv("ASR_MODEL_ID", "openai/whisper-tiny.en")
CONFIDENCE_THRESHOLD_DEFAULT = float(os.getenv("CONFIDENCE_THRESHOLD", "0.8"))
MAX_TURNS = int(os.getenv("MAX_TURNS", "12"))
USE_TTS_DEFAULT = os.getenv("USE_TTS", "true").strip().lower() == "true"
CONFIG_PATH = os.getenv("MODEL_CONFIG_PATH", "model_config.json")
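
# All of the above can be overridden per deployment without code changes, e.g.
# (illustrative shell invocation, assuming this file runs as app.py):
#   LLM_MODEL_ID=google/medgemma-4b-it CONFIDENCE_THRESHOLD=0.9 python app.py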


def _load_model_id_from_config() -> str:
    try:
        if os.path.exists(CONFIG_PATH):
            with open(CONFIG_PATH, "r") as f:
                data = json.load(f)
                if isinstance(data, dict) and data.get("model_id"):
                    return str(data["model_id"])
    except Exception:
        pass
    return DEFAULT_CHAT_MODEL_ID


current_model_id = _load_model_id_from_config()


# ---------------------------
# Lazy singletons for pipelines
# ---------------------------
_asr_pipe = None
_gen_pipe = None
_tokenizer = None


def _hf_device() -> int:
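    # transformers' pipeline() accepts an int device: a CUDA ordinal, or -1 for CPU.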
    return 0 if torch.cuda.is_available() else -1


def get_asr_pipeline():
    global _asr_pipe
    if _asr_pipe is None:
        _asr_pipe = pipeline(
            "automatic-speech-recognition",
            model=DEFAULT_ASR_MODEL_ID,
            device=_hf_device(),
        )
    return _asr_pipe


def get_textgen_pipeline():
    global _gen_pipe
    if _gen_pipe is None:
        # Use a small default chat model for Spaces CPU; override via LLM_MODEL_ID
        if torch.cuda.is_available() and hasattr(torch.cuda, "is_bf16_supported") and torch.cuda.is_bf16_supported():
            _dtype = torch.bfloat16
        elif torch.cuda.is_available():
            _dtype = torch.float16
        else:
            _dtype = torch.float32
        _gen_pipe = pipeline(
            task="text-generation",
            model=current_model_id,
            tokenizer=current_model_id,
            device=_hf_device(),
            torch_dtype=_dtype,
        )
    return _gen_pipe


def set_current_model_id(new_model_id: str) -> str:
    global current_model_id, _gen_pipe
    new_model_id = (new_model_id or "").strip()
    if not new_model_id:
        return "Model id is empty; keeping current model."
    if new_model_id == current_model_id:
        return f"Model unchanged: `{current_model_id}`"
    current_model_id = new_model_id
    _gen_pipe = None  # force reload on next use
    return f"Model switched to `{current_model_id}` (pipeline will reload on next generation)."


def persist_model_id(new_model_id: str) -> None:
    try:
        with open(CONFIG_PATH, "w") as f:
            json.dump({"model_id": new_model_id}, f)
    except Exception:
        pass


def apply_model_and_restart(new_model_id: str) -> str:
    mid = (new_model_id or "").strip()
    if not mid:
        return "Model id is empty; not restarting."
    persist_model_id(mid)
    set_current_model_id(mid)
    # Graceful delayed exit so response can flush
    def _exit_later():
        time.sleep(0.25)
        os._exit(0)
    threading.Thread(target=_exit_later, daemon=True).start()
    return f"Restarting with model `{mid}`..."


# ---------------------------
# Utilities
# ---------------------------
def safe_json_extract(text: str) -> Optional[Dict[str, Any]]:
    """Extract first JSON object from text."""
    if not text:
        return None
    try:
        return json.loads(text)
    except Exception:
        pass
    # Fallback: find the first {...} block
    match = re.search(r"\{[\s\S]*\}", text)
    if match:
        try:
            return json.loads(match.group(0))
        except Exception:
            return None
    return None
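
# Illustrative behaviour (not executed):
#   safe_json_extract('{"a": 1}')             -> {"a": 1}
#   safe_json_extract('Sure! {"a": 1} done')  -> {"a": 1}   (regex fallback)
#   safe_json_extract('no json here')         -> None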


def compute_audio_features(audio_path: str) -> Dict[str, float]:
    """Compute lightweight prosodic features as a proxy for OpenSMILE.

    Returns a dictionary with summary statistics.
    """
    try:
        y, sr = librosa.load(audio_path, sr=16000, mono=True)
        if len(y) == 0:
            return {}

        # Frame-based features
        hop_length = 512
        frame_length = 1024

        rms = librosa.feature.rms(y=y, frame_length=frame_length, hop_length=hop_length)[0]
        zcr = librosa.feature.zero_crossing_rate(y, frame_length=frame_length, hop_length=hop_length)[0]
        centroid = librosa.feature.spectral_centroid(y=y, sr=sr, n_fft=2048, hop_length=hop_length)[0]

        # Pitch estimation (coarse)
        f0 = None
        try:
            f0 = librosa.yin(y, fmin=50, fmax=400, sr=sr, frame_length=frame_length, hop_length=hop_length)
            f0 = f0[np.isfinite(f0)]
        except Exception:
            f0 = None

        # Speaking-rate rough proxy: fraction of frames above 1.2x the median energy
        # (reuses the RMS envelope computed above)
        voiced = rms > (np.median(rms) * 1.2)
        voiced_ratio = float(np.mean(voiced))

        features = {
            "rms_mean": float(np.mean(rms)),
            "rms_std": float(np.std(rms)),
            "zcr_mean": float(np.mean(zcr)),
            "zcr_std": float(np.std(zcr)),
            "centroid_mean": float(np.mean(centroid)),
            "centroid_std": float(np.std(centroid)),
            "voiced_ratio": voiced_ratio,
            "duration_sec": float(len(y) / sr),
        }
        if f0 is not None and f0.size > 0:
            features.update({
                "f0_median": float(np.median(f0)),
                "f0_iqr": float(np.percentile(f0, 75) - np.percentile(f0, 25)),
            })
        return features
    except Exception:
        return {}
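
# Sketch of the expected output shape (values invented for illustration):
#   compute_audio_features("turn.wav") -> {
#       "rms_mean": 0.041, "rms_std": 0.018, "zcr_mean": 0.11, "zcr_std": 0.05,
#       "centroid_mean": 1650.0, "centroid_std": 420.0, "voiced_ratio": 0.46,
#       "duration_sec": 7.3, "f0_median": 142.0, "f0_iqr": 28.0,
#   }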


def detect_explicit_suicidality(text: Optional[str]) -> bool:
    if not text:
        return False
    t = text.lower()
    patterns = [
        r"\bkill myself\b",
        r"\bend my life\b",
        r"\bend it all\b",
        r"\bcommit suicide\b",
        r"\bsuicide\b",
        r"\bself[-\s]?harm\b",
        r"\bhurt myself\b",
        r"\bno reason to live\b",
        r"\bwant to die\b",
        r"\bending it\b",
    ]
    for pat in patterns:
        if re.search(pat, t):
            return True
    return False
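
# Note: the patterns above are deliberately broad (e.g., any mention of "suicide"
# matches), trading false positives for safety. High risk also triggers when the
# model's suicidal_thoughts item score is >= 2 (see process_turn below).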


def synthesize_tts(text: Optional[str]) -> Optional[str]:
    if not text:
        return None
    try:
        # Save MP3 to tmp and return filepath
        ts = int(time.time() * 1000)
        out_path = f"/tmp/tts_{ts}.mp3"
        tts = gTTS(text=text, lang="en")
        tts.save(out_path)
        return out_path
    except Exception:
        return None


def severity_from_total(total_score: int) -> str:
    if total_score <= 4:
        return "Minimal Depression"
    if total_score <= 9:
        return "Mild Depression"
    if total_score <= 14:
        return "Moderate Depression"
    if total_score <= 19:
        return "Moderately Severe Depression"
    return "Severe Depression"


def transcript_to_text(chat_history: List[Tuple[str, str]]) -> str:
    """Convert chatbot history [(user, assistant), ...] to a plain text transcript."""
    lines = []
    for user, assistant in chat_history:
        if user:
            lines.append(f"Patient: {user}")
        if assistant:
            lines.append(f"Clinician: {assistant}")
    return "\n".join(lines)


def generate_recording_agent_reply(chat_history: List[Tuple[str, str]]) -> str:
    transcript = transcript_to_text(chat_history)
    system_prompt = (
        "You are a clinician conducting a conversational assessment to infer PHQ-9 symptoms "
        "without listing the nine questions explicitly. Keep tone empathetic, natural, and human. "
        "Ask one concise, natural follow-up question at a time that helps infer symptoms such as mood, "
        "sleep, appetite, energy, concentration, self-worth, psychomotor changes, and suicidal thoughts."
    )
    user_prompt = (
        "Conversation so far (Patient and Clinician turns):\n\n" + transcript +
        "\n\nRespond with a single short clinician-style question for the patient."
    )
    pipe = get_textgen_pipeline()
    tokenizer = pipe.tokenizer
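    # Some chat templates (e.g., Gemma's) reject a separate "system" role, so the
    # system prompt is folded into a single user message before templating.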
    combined_prompt = system_prompt + "\n\n" + user_prompt
    messages = [
        {"role": "user", "content": combined_prompt},
    ]
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    gen = pipe(
        prompt,
        max_new_tokens=96,
        temperature=0.7,
        do_sample=True,
        top_p=0.9,
        top_k=50,
        pad_token_id=tokenizer.eos_token_id,
        return_full_text=False,
    )
    reply = gen[0]["generated_text"].strip()
    # Ensure it's a single concise question/sentence
    if len(reply) > 300:
        reply = reply[:300].rstrip() + "…"
    return reply


def scoring_agent_infer(chat_history: List[Tuple[str, str]], features: Dict[str, float]) -> Dict[str, Any]:
    """Ask the LLM to produce PHQ-9 scores and confidences as JSON. Fallback if parsing fails."""
    transcript = transcript_to_text(chat_history)
    features_json = json.dumps(features, ensure_ascii=False)
    system_prompt = (
        "You evaluate an on-going clinician-patient conversation to infer a PHQ-9 assessment. "
        "Return ONLY a JSON object with: PHQ9_Scores (interest,mood,sleep,energy,appetite,self_worth,concentration,motor,suicidal_thoughts; each 0-3), "
        "Confidences (list of 9 floats 0-1 in the same order), Total_Score (0-27), Severity (string), Confidence (min of confidences), "
        "and High_Risk (boolean, true if any suicidal risk)."
    )
    user_prompt = (
        "Conversation transcript:"\
        f"\n{transcript}\n\n"
        f"Acoustic features summary (approximate):\n{features_json}\n\n"
        "Instructions: Infer PHQ9_Scores (0-3 per item), estimate Confidences per item, compute Total_Score and overall Severity. "
        "Set High_Risk=true if any suicidal ideation or risk is present. Return ONLY JSON, no prose."
    )
    pipe = get_textgen_pipeline()
    tokenizer = pipe.tokenizer
    combined_prompt = system_prompt + "\n\n" + user_prompt
    messages = [
        {"role": "user", "content": combined_prompt},
    ]
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    # Use deterministic decoding to avoid CUDA sampling edge cases on some models
    gen = pipe(
        prompt,
        max_new_tokens=256,
        do_sample=False,  # greedy decoding; temperature is ignored when not sampling
        pad_token_id=tokenizer.eos_token_id,
        return_full_text=False,
    )
    out_text = gen[0]["generated_text"]
    parsed = safe_json_extract(out_text)

    # Validate and coerce
    if parsed is None or "PHQ9_Scores" not in parsed:
        # Simple fallback heuristic: neutral scores with low confidence
        scores = {
            "interest": 1,
            "mood": 1,
            "sleep": 1,
            "energy": 1,
            "appetite": 1,
            "self_worth": 1,
            "concentration": 1,
            "motor": 1,
            "suicidal_thoughts": 0,
        }
        confidences = [0.5] * 9
        total = int(sum(scores.values()))
        return {
            "PHQ9_Scores": scores,
            "Confidences": confidences,
            "Total_Score": total,
            "Severity": severity_from_total(total),
            "Confidence": float(min(confidences)),
            "High_Risk": False,
        }

    try:
        # Coerce types and compute derived values if missing
        scores = parsed.get("PHQ9_Scores", {})
        # Ensure all keys present
        keys = [
            "interest","mood","sleep","energy","appetite","self_worth","concentration","motor","suicidal_thoughts"
        ]
        for k in keys:
            scores[k] = int(max(0, min(3, int(scores.get(k, 0)))))
        confidences = parsed.get("Confidences", [])
        if not isinstance(confidences, list) or len(confidences) != 9:
            confidences = [float(parsed.get("Confidence", 0.5))] * 9
        confidences = [float(max(0.0, min(1.0, c))) for c in confidences]
        total = int(sum(scores.values()))
        severity = parsed.get("Severity") or severity_from_total(total)
        overall_conf = float(parsed.get("Confidence", min(confidences)))
        # Conservative high-risk detection: require explicit language or high suicidal_thoughts score
        # Extract last patient message
        last_patient = ""
        for user_text, assistant_text in reversed(chat_history):
            if user_text:
                last_patient = user_text
                break
        explicit_flag = detect_explicit_suicidality(last_patient) or detect_explicit_suicidality(transcript)
        high_risk = bool(explicit_flag or (scores.get("suicidal_thoughts", 0) >= 2))

        return {
            "PHQ9_Scores": scores,
            "Confidences": confidences,
            "Total_Score": total,
            "Severity": severity,
            "Confidence": overall_conf,
            "High_Risk": high_risk,
        }
    except Exception:
        # Final fallback
        scores = parsed.get("PHQ9_Scores", {}) if isinstance(parsed, dict) else {}
        if not scores:
            scores = {k: 1 for k in [
                "interest","mood","sleep","energy","appetite","self_worth","concentration","motor","suicidal_thoughts"
            ]}
        confidences = [0.5] * 9
        total = int(sum(scores.values()))
        return {
            "PHQ9_Scores": scores,
            "Confidences": confidences,
            "Total_Score": total,
            "Severity": severity_from_total(total),
            "Confidence": float(min(confidences)),
            "High_Risk": False,
        }


def transcribe_audio(audio_path: Optional[str]) -> str:
    if not audio_path:
        return ""
    try:
        asr = get_asr_pipeline()
        result = asr(audio_path)
        if isinstance(result, dict) and "text" in result:
            return result["text"].strip()
        if isinstance(result, list) and len(result) > 0 and "text" in result[0]:
            return result[0]["text"].strip()
    except Exception:
        pass
    return ""


# ---------------------------
# Gradio app logic
# ---------------------------
INTRO_MESSAGE = (
    "Hi, I'm an assistant, and I will ask you some questions about how you've been doing. "
    "We'll record our conversation, and we will give you a written copy of it. "
    "We will also send the clinician a written copy, along with a summary of what you are "
    "experiencing based on a questionnaire called the Patient Health Questionnaire (PHQ-9) "
    "and a summary of what your voice sounds like. The clinician will then follow up with you. "
    "To start, how has your mood been over the past couple of weeks?"
)


def init_state() -> Tuple[List[Tuple[str, str]], Dict[str, Any], Dict[str, Any], bool, int]:
    chat_history: List[Tuple[str, str]] = [("", INTRO_MESSAGE)]
    scores = {}
    meta = {"Severity": None, "Total_Score": None, "Confidence": 0.0}
    finished = False
    turns = 0
    return chat_history, scores, meta, finished, turns


@spaces.GPU
def process_turn(
    audio_path: Optional[str],
    text_input: Optional[str],
    chat_history: List[Tuple[str, str]],
    threshold: float,
    tts_enabled: bool,
    finished: Optional[bool],
    turns: Optional[int],
    prev_scores: Dict[str, Any],
    prev_meta: Dict[str, Any],
):
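    """Run one conversational turn: transcribe, extract features, score, and reply.

    Returns a 9-tuple matching the outputs wired in create_demo():
    (chatbot, score_json, severity_label, finished_state, turns_state,
     audio_main, text_main, tts_audio, tts_audio_main).
    """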
    # If already finished, do nothing
    finished = bool(finished) if finished is not None else False
    turns = int(turns) if isinstance(turns, int) else 0
    if finished:
        return (
            chat_history,
            {"info": "Assessment complete."},
            prev_meta.get("Severity", ""),
            finished,
            turns,
            None,
            None,
            None,
            None,
        )

    patient_text = (text_input or "").strip()
    audio_features: Dict[str, float] = {}
    if audio_path:
        # Transcribe first
        transcribed = transcribe_audio(audio_path)
        if transcribed:
            patient_text = f"{patient_text} {transcribed}" if patient_text else transcribed
        # Extract features
        audio_features = compute_audio_features(audio_path)

    if not patient_text:
        # Ask user for input
        chat_history.append(("", "I didn't catch that. Could you share a bit about how you've been feeling?"))
        return (
            chat_history,
            prev_scores or {},
            prev_meta.get("Severity", ""),
            finished,
            turns,
            None,
            None,
            None,
            None,
        )

    # Add patient's message
    chat_history.append((patient_text, None))

    # Scoring agent
    scoring = scoring_agent_infer(chat_history, audio_features)
    scores = scoring.get("PHQ9_Scores", {})
    confidences = scoring.get("Confidences", [])
    total = scoring.get("Total_Score", 0)
    severity = scoring.get("Severity", severity_from_total(total))
    overall_conf = float(scoring.get("Confidence", min(confidences) if confidences else 0.0))
    # Override high-risk to reduce false positives: rely on explicit text or high item score
    # Extract last patient message
    last_patient = ""
    for user_text, assistant_text in reversed(chat_history):
        if user_text:
            last_patient = user_text
            break
    explicit_flag = detect_explicit_suicidality(last_patient)
    high_risk = bool(explicit_flag or (scores.get("suicidal_thoughts", 0) >= 2))

    meta = {"Severity": severity, "Total_Score": total, "Confidence": overall_conf}

    # Termination conditions
    min_conf = float(min(confidences)) if confidences else 0.0
    turns += 1
    done = high_risk or (min_conf >= threshold) or (turns >= MAX_TURNS)

    if high_risk:
        closing = (
            "I’m concerned about your safety based on what you shared. "
            "If you are in danger or need immediate help, please call 988 in the U.S. or your local emergency number. "
            "I'll end the assessment now and display emergency resources."
        )
        chat_history[-1] = (chat_history[-1][0], closing)
        finished = True
    elif done:
        summary = (
            f"Thank you for sharing. Based on our conversation, your responses suggest {severity.lower()}. "
            "We can stop here."
        )
        chat_history[-1] = (chat_history[-1][0], summary)
        finished = True
    else:
        # Generate next clinician question
        reply = generate_recording_agent_reply(chat_history)
        chat_history[-1] = (chat_history[-1][0], reply)

    # TTS for the latest clinician message, if enabled
    tts_path = synthesize_tts(chat_history[-1][1]) if tts_enabled else None

    # Build a compact JSON for display
    display_json = {
        "PHQ9_Scores": scores,
        "Confidences": confidences,
        "Total_Score": total,
        "Severity": severity,
        "Confidence": overall_conf,
        "High_Risk": high_risk,
    }

    # Clear inputs after processing
    return (
        chat_history,
        display_json,
        severity,
        finished,
        turns,
        None,
        None,
        tts_path,
        tts_path,
    )


def reset_app():
    return init_state()


# ---------------------------
# UI
# ---------------------------
def _on_load_init():
    return init_state()


def _on_load_init_with_tts(tts_on: bool):
    chat_history, scores_state, meta_state, finished_state, turns_state = init_state()
    # Play the intro message via TTS if enabled
    tts_path = synthesize_tts(chat_history[-1][1]) if bool(tts_on) else None
    return chat_history, scores_state, meta_state, finished_state, turns_state, tts_path


def _play_intro_tts(tts_on: bool):
    if not bool(tts_on):
        return None
    try:
        return synthesize_tts(INTRO_MESSAGE)
    except Exception:
        return None

def create_demo():
    with gr.Blocks(
        theme=gr.themes.Soft(),
        css='''
        /* Voice bubble styles - clean and centered */
        #voice-bubble { 
            width: 240px; height: 240px; border-radius: 9999px; margin: 40px auto; 
            display: flex; align-items: center; justify-content: center;
            background: linear-gradient(135deg, #6ee7b7 0%, #34d399 100%);
            box-shadow: 0 20px 40px rgba(16,185,129,0.3), 0 0 0 1px rgba(255,255,255,0.1) inset;
            transition: all 250ms cubic-bezier(0.4, 0, 0.2, 1);
            cursor: default;            /* green circle itself is not clickable */
            pointer-events: none;       /* ignore clicks on the green circle */
            position: relative;
        }
        #voice-bubble:hover { 
            transform: translateY(-2px) scale(1.02); 
            box-shadow: 0 25px 50px rgba(16,185,129,0.4), 0 0 0 1px rgba(255,255,255,0.15) inset;
        }
        #voice-bubble:active { transform: translateY(0px) scale(0.98); }
        #voice-bubble.listening { 
            animation: bubble-pulse 1.5s ease-in-out infinite; 
            background: linear-gradient(135deg, #fb7185 0%, #ef4444 100%);
            box-shadow: 0 20px 40px rgba(239,68,68,0.4), 0 0 0 1px rgba(255,255,255,0.1) inset;
        }
        @keyframes bubble-pulse {
            0%, 100% { transform: scale(1.0); box-shadow: 0 20px 40px rgba(239,68,68,0.4), 0 0 0 0 rgba(239,68,68,0.5); }
            50% { transform: scale(1.05); box-shadow: 0 25px 50px rgba(239,68,68,0.5), 0 0 0 15px rgba(239,68,68,0.0); }
        }
        /* Hide microphone dropdown selector only */
        #voice-bubble select { display: none !important; }
        #voice-bubble .source-selection { display: none !important; }
        #voice-bubble label[for] { display: none !important; }
        /* Make the inner button the only clickable target */
        #voice-bubble button { pointer-events: auto; cursor: pointer; }
        /* Hide TTS player UI but keep it in DOM for autoplay */
        #tts-player { width: 0 !important; height: 0 !important; opacity: 0 !important; position: absolute; pointer-events: none; }
        '''
    ) as demo:
        gr.Markdown(
            """
            ### Conversational Assessment for Responsive Engagement (CARE) Notes
            Tap 'Record' to start speaking, then tap 'Stop' when you're done.
            """
        )
        intro_play_btn = gr.Button("▶️ Play Intro", variant="secondary")

        with gr.Tabs():
            with gr.TabItem("Main"):
                with gr.Column():
                    # Microphone component styled as central bubble (tap to record/stop)
                    audio_main = gr.Microphone(type="filepath", label=None, elem_id="voice-bubble", show_label=False)
                    # Hidden text input placeholder for pipeline compatibility
                    text_main = gr.Textbox(value="", visible=False)
                    # Autoplay clinician voice output (player hidden with CSS)
                    tts_audio_main = gr.Audio(label=None, interactive=False, autoplay=True, show_label=False, elem_id="tts-player")

            with gr.TabItem("Advanced"):
                with gr.Column():
                    chatbot = gr.Chatbot(height=360, type="tuples", label="Conversation")
                    score_json = gr.JSON(label="PHQ-9 Assessment (live)")
                    severity_label = gr.Label(label="Severity")
                    threshold = gr.Slider(0.5, 1.0, value=CONFIDENCE_THRESHOLD_DEFAULT, step=0.05, label="Confidence Threshold (stop when min ≥ τ)")
                    tts_enable = gr.Checkbox(label="Speak clinician responses (TTS)", value=USE_TTS_DEFAULT)
                    tts_audio = gr.Audio(label="Clinician voice", interactive=False, autoplay=False, visible=False)
                    model_id_tb = gr.Textbox(value=current_model_id, label="Chat Model ID", info="e.g., google/gemma-2-2b-it or google/medgemma-4b-it")
                    with gr.Row():
                        apply_model_btn = gr.Button("Apply model (no restart)")
                        # apply_model_restart_btn = gr.Button("Apply model and restart")
                    model_status = gr.Markdown(value=f"Current model: `{current_model_id}`")

        # App state
        chat_state = gr.State()
        scores_state = gr.State()
        meta_state = gr.State()
        finished_state = gr.State()
        turns_state = gr.State()

        # Initialize on load (no autoplay due to browser policies)
        demo.load(_on_load_init, inputs=None, outputs=[chatbot, scores_state, meta_state, finished_state, turns_state])

        # Explicit user gesture to play intro TTS (works across browsers)
        intro_play_btn.click(fn=_play_intro_tts, inputs=[tts_enable], outputs=[tts_audio_main])

        # Wire interactions
        audio_main.stop_recording(
            fn=process_turn,
            inputs=[audio_main, text_main, chatbot, threshold, tts_enable, finished_state, turns_state, scores_state, meta_state],
            outputs=[chatbot, score_json, severity_label, finished_state, turns_state, audio_main, text_main, tts_audio, tts_audio_main],
            queue=True,
            api_name="message",
        )

        # Tap bubble to toggle microphone record/stop via JS
        # This JS is no longer needed as the bubble is the mic
        # voice_bubble.click(
        #     None,
        #     inputs=None,
        #     outputs=None,
        #     js="() => {\n                const bubble = document.getElementById('voice-bubble');\n                const root = document.getElementById('hidden-mic');\n                if (!root) return;\n                let didClick = false;\n                const wc = root.querySelector && root.querySelector('gradio-audio');\n                if (wc && wc.shadowRoot) {\n                  const btns = Array.from(wc.shadowRoot.querySelectorAll('button')).filter(b => !b.disabled);\n                  const txt = (b) => ((b.getAttribute('aria-label')||'') + ' ' + (b.textContent||'')).toLowerCase();\n                  const stopBtn = btns.find(b => txt(b).includes('stop'));\n                  const recBtn = btns.find(b => { const t = txt(b); return t.includes('record') || t.includes('start') || t.includes('microphone') || t.includes('mic'); });\n                  if (stopBtn) { stopBtn.click(); didClick = true; } else if (recBtn) { recBtn.click(); didClick = true; } else if (btns[0]) { btns[0].click(); didClick = true; }\n                }\n                if (!didClick) {\n                  const candidates = Array.from(root.querySelectorAll('button, [role=\\'button\\']')).filter(el => !el.disabled);\n                  if (candidates.length) { candidates[0].click(); didClick = true; }\n                }\n                if (bubble && didClick) bubble.classList.toggle('listening');\n            }",
        # )

        # No reset button in Main tab anymore

        # Model switch handlers
        def _on_apply_model(mid: str):
            msg = set_current_model_id(mid)
            return f"Current model: `{current_model_id}`\n\n{msg}"

        def _on_apply_model_restart(mid: str):
            msg = apply_model_and_restart(mid)
            return f"{msg}"

        apply_model_btn.click(fn=_on_apply_model, inputs=[model_id_tb], outputs=[model_status])
        # apply_model_restart_btn.click(fn=_on_apply_model_restart, inputs=[model_id_tb], outputs=[model_status])

    return demo

demo = create_demo()
if __name__ == "__main__":
    # For local dev
    demo.queue(max_size=16).launch(server_name="0.0.0.0", server_port=int(os.getenv("PORT", "7860")))