import streamlit as st
import requests
import json
import re
import os
import operator
import traceback
from functools import lru_cache
from dotenv import load_dotenv

from langchain_groq import ChatGroq
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_core.messages import HumanMessage, SystemMessage, AIMessage, ToolMessage
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.tools import tool
from langgraph.prebuilt import ToolExecutor, ToolInvocation
from langgraph.graph import StateGraph, END

from typing import Optional, List, Dict, Any, TypedDict, Annotated

# --- Environment Variable Loading & Validation ---
load_dotenv() # Load .env file if present (for local development)

UMLS_API_KEY = os.environ.get("UMLS_API_KEY")
GROQ_API_KEY = os.environ.get("GROQ_API_KEY")
TAVILY_API_KEY = os.environ.get("TAVILY_API_KEY")

# Stop execution if essential keys are missing (crucial for HF Spaces)
missing_keys = []
if not UMLS_API_KEY: missing_keys.append("UMLS_API_KEY")
if not GROQ_API_KEY: missing_keys.append("GROQ_API_KEY")
if not TAVILY_API_KEY: missing_keys.append("TAVILY_API_KEY")

if missing_keys:
    st.error(f"Missing required API Key(s): {', '.join(missing_keys)}. Please set them in Hugging Face Space Secrets or your environment variables.")
    st.stop()

# --- Configuration & Constants ---
class ClinicalAppSettings:
    APP_TITLE = "SynapseAI: Interactive Clinical Decision Support (UMLS/FDA Integrated)"
    PAGE_LAYOUT = "wide"
    MODEL_NAME = "llama3-70b-8192" # Groq Llama3 70b
    TEMPERATURE = 0.1
    MAX_SEARCH_RESULTS = 3

class ClinicalPrompts:
    # System prompt emphasizing structured output, safety checks, guideline searches, and conversational flow.
    SYSTEM_PROMPT = """
    You are SynapseAI, an expert AI clinical assistant engaged in an interactive consultation.
    Your goal is to support healthcare professionals by analyzing patient data, providing differential diagnoses, suggesting evidence-based management plans, and identifying risks according to current standards of care.

    **Core Directives for this Conversation:**
    1.  **Analyze Sequentially:** Process information turn-by-turn. Base your responses on the *entire* conversation history.
    2.  **Seek Clarity:** If the provided information is insufficient or ambiguous for a safe assessment, CLEARLY STATE what specific additional information or clarification is needed. Do NOT guess or make unsafe assumptions.
    3.  **Structured Assessment (When Ready):** When you have sufficient information and have performed necessary checks (like interactions, guideline searches), provide a comprehensive assessment using the following JSON structure. Output this JSON structure as the primary content of your response when you are providing the full analysis. Do NOT output incomplete JSON. If you need to ask a question or perform a tool call first, do that instead of outputting this structure.
        ```json
        {
          "assessment": "Concise summary of the patient's presentation and key findings based on the conversation.",
          "differential_diagnosis": [
            {"diagnosis": "Primary Diagnosis", "likelihood": "High/Medium/Low", "rationale": "Supporting evidence from conversation..."},
            {"diagnosis": "Alternative Diagnosis 1", "likelihood": "Medium/Low", "rationale": "Supporting/Refuting evidence..."},
            {"diagnosis": "Alternative Diagnosis 2", "likelihood": "Low", "rationale": "Why it's less likely but considered..."}
          ],
          "risk_assessment": {
            "identified_red_flags": ["List any triggered red flags based on input and analysis"],
            "immediate_concerns": ["Specific urgent issues requiring attention (e.g., sepsis risk, ACS rule-out)"],
            "potential_complications": ["Possible future issues based on presentation"]
          },
          "recommended_plan": {
            "investigations": ["List specific lab tests or imaging required. Use 'order_lab_test' tool."],
            "therapeutics": ["Suggest specific treatments or prescriptions. Use 'prescribe_medication' tool. MUST check interactions first using 'check_drug_interactions'."],
            "consultations": ["Recommend specialist consultations if needed."],
            "patient_education": ["Key points for patient communication."]
          },
          "rationale_summary": "Justification for assessment/plan. **Crucially, if relevant (e.g., ACS, sepsis, common infections), use 'tavily_search_results' to find and cite current clinical practice guidelines (e.g., 'latest ACC/AHA chest pain guidelines 202X', 'Surviving Sepsis Campaign guidelines') supporting your recommendations.** Include summary of guideline findings here.",
          "interaction_check_summary": "Summary of findings from 'check_drug_interactions' if performed."
        }
        ```
    4.  **Safety First - Interactions:** BEFORE suggesting a new prescription via `prescribe_medication`, you MUST FIRST use `check_drug_interactions` in a preceding or concurrent tool call. Report the findings from the interaction check. If significant interactions exist, modify the plan or state the contraindication clearly.
    5.  **Safety First - Red Flags:** Use the `flag_risk` tool IMMEDIATELY if critical red flags requiring urgent action are identified at any point in the conversation.
    6.  **Tool Use:** Employ tools (`order_lab_test`, `prescribe_medication`, `check_drug_interactions`, `flag_risk`, `tavily_search_results`) logically within the conversational flow. Wait for tool results before proceeding if the result is needed for the next step (e.g., wait for interaction check before confirming prescription in the structured JSON).
    7.  **Evidence & Guidelines:** Actively use `tavily_search_results` not just for general knowledge, but specifically to query for and incorporate **current clinical practice guidelines** relevant to the patient's presentation (e.g., chest pain, shortness of breath, suspected infection). Summarize findings in the `rationale_summary` when providing the structured output.
    8.  **Conciseness & Flow:** Be medically accurate and concise. Use standard terminology. Respond naturally in conversation (asking questions, acknowledging info) until ready for the full structured JSON output.
    """

# --- UMLS/RxNorm & OpenFDA API Helper Functions ---
UMLS_AUTH_ENDPOINT = "https://utslogin.nlm.nih.gov/cas/v1/api-key" # May not be needed if using apiKey directly
RXNORM_API_BASE = "https://rxnav.nlm.nih.gov/REST"
OPENFDA_API_BASE = "https://api.fda.gov/drug/label.json"
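# Lookup flow used by check_drug_interactions below (illustrative sketch; actual values depend on the live NLM/FDA APIs):
#   rxcui = get_rxcui("lisinopril")                                  # drug name -> RxCUI via RxNorm
#   label = get_openfda_label(rxcui=rxcui, drug_name="lisinopril")   # RxCUI/name -> structured FDA label dict
#   label.get("drug_interactions"), label.get("contraindications")   # label sections searched for warning text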

@lru_cache(maxsize=256) # Cache RxCUI lookups
def get_rxcui(drug_name: str) -> Optional[str]:
    """Uses RxNorm API to find the RxCUI for a given drug name."""
    if not drug_name or not isinstance(drug_name, str): return None
    drug_name = drug_name.strip()
    if not drug_name: return None

    print(f"RxNorm Lookup for: '{drug_name}'")
    try:
        params = {"name": drug_name, "search": 1} # Search for concepts related to the name
        response = requests.get(f"{RXNORM_API_BASE}/rxcui.json", params=params, timeout=10)
        response.raise_for_status()
        data = response.json()
        # Extract RxCUI - prioritize exact matches or common types
        if data and "idGroup" in data and "rxnormId" in data["idGroup"]:
            # Select the first one, assuming it's the most relevant by default.
            # More sophisticated logic could check TTYs (Term Types) if needed.
            rxcui = data["idGroup"]["rxnormId"][0]
            print(f"  Found RxCUI: {rxcui} for '{drug_name}'")
            return rxcui
        else:
            # Fallback: Search /drugs endpoint if direct rxcui lookup fails
            params = {"name": drug_name}
            response = requests.get(f"{RXNORM_API_BASE}/drugs.json", params=params, timeout=10)
            response.raise_for_status()
            data = response.json()
            if data and "drugGroup" in data and "conceptGroup" in data["drugGroup"]:
                for group in data["drugGroup"]["conceptGroup"]:
                    # Prioritize Semantic Types like Brand/Clinical Drug/Ingredient
                    if group.get("tty") in ["SBD", "SCD", "GPCK", "BPCK", "IN", "MIN", "PIN"]:
                        if "conceptProperties" in group and group["conceptProperties"]:
                            rxcui = group["conceptProperties"][0].get("rxcui")
                            if rxcui:
                                print(f"  Found RxCUI (via /drugs): {rxcui} for '{drug_name}'")
                                return rxcui
        print(f"  RxCUI not found for '{drug_name}'.")
        return None
    except requests.exceptions.RequestException as e:
        print(f"  Error fetching RxCUI for '{drug_name}': {e}")
        return None
    except json.JSONDecodeError as e:
        print(f"  Error decoding RxNorm JSON response for '{drug_name}': {e}")
        return None
    except Exception as e: # Catch any other unexpected error
        print(f"  Unexpected error in get_rxcui for '{drug_name}': {e}")
        return None

@lru_cache(maxsize=128) # Cache OpenFDA lookups
def get_openfda_label(rxcui: Optional[str] = None, drug_name: Optional[str] = None) -> Optional[dict]:
    """Fetches drug label information from OpenFDA using RxCUI or drug name."""
    if not rxcui and not drug_name: return None
    print(f"OpenFDA Label Lookup for: RXCUI={rxcui}, Name={drug_name}")

    search_terms = []
    # Prioritize RxCUI lookup using multiple potential fields
    if rxcui:
        search_terms.append(f'spl_rxnorm_code:"{rxcui}"')
        search_terms.append(f'openfda.rxcui:"{rxcui}"')
    # Add name search as fallback or supplement
    if drug_name:
        search_terms.append(f'(openfda.brand_name:"{drug_name.lower()}" OR openfda.generic_name:"{drug_name.lower()}")')

    search_query = " OR ".join(search_terms)
    params = {"search": search_query, "limit": 1} # Get only the most relevant label

    try:
        response = requests.get(OPENFDA_API_BASE, params=params, timeout=15)
        response.raise_for_status()
        data = response.json()
        if data and "results" in data and data["results"]:
            print(f"  Found OpenFDA label for query: {search_query}")
            return data["results"][0] # Return the first label found
        print(f"  No OpenFDA label found for query: {search_query}")
        return None
    except requests.exceptions.RequestException as e:
        print(f"  Error fetching OpenFDA label: {e}")
        return None
    except json.JSONDecodeError as e:
        print(f"  Error decoding OpenFDA JSON response: {e}")
        return None
    except Exception as e:
        print(f"  Unexpected error in get_openfda_label: {e}")
        return None

def search_text_list(text_list: Optional[List[str]], search_terms: List[str]) -> List[str]:
    """ Case-insensitive search for any search_term within a list of text strings. Returns snippets. """
    found_snippets = []
    if not text_list or not search_terms: return found_snippets
    # Ensure search terms are lowercased strings
    search_terms_lower = [str(term).lower() for term in search_terms if term]

    for text_item in text_list:
        if not isinstance(text_item, str): continue # Skip non-string items
        text_item_lower = text_item.lower()
        for term in search_terms_lower:
            if term in text_item_lower:
                # Find the start index of the term
                start_index = text_item_lower.find(term)
                # Define snippet boundaries (e.g., 50 chars before, 100 after)
                snippet_start = max(0, start_index - 50)
                snippet_end = min(len(text_item), start_index + len(term) + 100)
                snippet = text_item[snippet_start:snippet_end]
                # Add indication of where the match is
                snippet = snippet.replace(term, f"**{term}**", 1) # Highlight first match
                found_snippets.append(f"...{snippet}...")
                break # Move to the next text item once a match is found
    return found_snippets
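# Example (illustrative): search_text_list(["Avoid concomitant use with warfarin."], ["warfarin"])
# returns roughly ["...Avoid concomitant use with **warfarin**...."] -- the first match is bolded and
# surrounded by up to ~50 chars of leading / ~100 chars of trailing context.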

# --- Other Helper Functions ---
def parse_bp(bp_string: str) -> Optional[tuple[int, int]]:
    """Parses BP string like '120/80' into (systolic, diastolic) integers."""
    if not isinstance(bp_string, str): return None
    match = re.match(r"(\d{1,3})\s*/\s*(\d{1,3})", bp_string.strip())
    if match: return int(match.group(1)), int(match.group(2))
    return None
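# Example: parse_bp("155/90") -> (155, 90); parse_bp("ninety over sixty") -> None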

def check_red_flags(patient_data: dict) -> List[str]:
    """Checks patient data against predefined red flags."""
    flags = []
    if not patient_data: return flags
    symptoms = patient_data.get("hpi", {}).get("symptoms", [])
    vitals = patient_data.get("vitals", {})
    history = patient_data.get("pmh", {}).get("conditions", "")
    symptoms_lower = [str(s).lower() for s in symptoms if isinstance(s, str)]

    if "chest pain" in symptoms_lower: flags.append("Red Flag: Chest Pain reported.")
    if "shortness of breath" in symptoms_lower: flags.append("Red Flag: Shortness of Breath reported.")
    if "severe headache" in symptoms_lower: flags.append("Red Flag: Severe Headache reported.")
    if "sudden vision loss" in symptoms_lower: flags.append("Red Flag: Sudden Vision Loss reported.")
    if "weakness on one side" in symptoms_lower: flags.append("Red Flag: Unilateral Weakness reported (potential stroke).")
    if "hemoptysis" in symptoms_lower: flags.append("Red Flag: Hemoptysis (coughing up blood).")
    if "syncope" in symptoms_lower: flags.append("Red Flag: Syncope (fainting).")

    if vitals:
        temp = vitals.get("temp_c"); hr = vitals.get("hr_bpm"); rr = vitals.get("rr_rpm")
        spo2 = vitals.get("spo2_percent"); bp_str = vitals.get("bp_mmhg")
        if temp is not None and temp >= 38.5: flags.append(f"Red Flag: Fever ({temp}Β°C).")
        if hr is not None and hr >= 120: flags.append(f"Red Flag: Tachycardia ({hr} bpm).")
        if hr is not None and hr <= 50: flags.append(f"Red Flag: Bradycardia ({hr} bpm).")
        if rr is not None and rr >= 24: flags.append(f"Red Flag: Tachypnea ({rr} rpm).")
        if spo2 is not None and spo2 <= 92: flags.append(f"Red Flag: Hypoxia ({spo2}%).")
        if bp_str:
            bp = parse_bp(bp_str)
            if bp:
                if bp[0] >= 180 or bp[1] >= 110: flags.append(f"Red Flag: Hypertensive Urgency/Emergency (BP: {bp_str} mmHg).")
                if bp[0] <= 90 or bp[1] <= 60: flags.append(f"Red Flag: Hypotension (BP: {bp_str} mmHg).")

    if history and isinstance(history, str):
        history_lower = history.lower()
        if "history of mi" in history_lower and "chest pain" in symptoms_lower: flags.append("Red Flag: History of MI with current Chest Pain.")
        if "history of dvt/pe" in history_lower and "shortness of breath" in symptoms_lower: flags.append("Red Flag: History of DVT/PE with current Shortness of Breath.")

    return list(set(flags)) # Unique flags
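# Example (illustrative): patient data with hpi.symptoms=["Chest Pain"] and vitals hr_bpm=130
# yields, in some order: ["Red Flag: Chest Pain reported.", "Red Flag: Tachycardia (130 bpm)."]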


def format_patient_data_for_prompt(data: dict) -> str:
    """Formats the patient dictionary into a readable string for the LLM."""
    if not data: return "No patient data provided."
    prompt_str = ""
    for key, value in data.items():
        section_title = key.replace('_', ' ').title()
        if isinstance(value, dict) and value:
            has_content = any(sub_value for sub_value in value.values())
            if has_content:
                prompt_str += f"**{section_title}:**\n"
                for sub_key, sub_value in value.items():
                     if sub_value: prompt_str += f"  - {sub_key.replace('_', ' ').title()}: {sub_value}\n"
        elif isinstance(value, list) and value:
             prompt_str += f"**{section_title}:** {', '.join(map(str, value))}\n"
        elif value and not isinstance(value, dict):
             prompt_str += f"**{section_title}:** {value}\n"
    return prompt_str.strip()
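# Example (illustrative): {"demographics": {"age": 55, "sex": "Male"}, "allergies": ["penicillin"]}
# renders as:
#   **Demographics:**
#     - Age: 55
#     - Sex: Male
#   **Allergies:** penicillin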


# --- Tool Definitions ---

# Pydantic models for tool inputs
class LabOrderInput(BaseModel):
    test_name: str = Field(..., description="Specific name of the lab test or panel (e.g., 'CBC', 'BMP', 'Troponin I', 'Urinalysis', 'D-dimer').")
    reason: str = Field(..., description="Clinical justification for ordering the test (e.g., 'Rule out infection', 'Assess renal function', 'Evaluate for ACS', 'Assess for PE').")
    priority: str = Field("Routine", description="Priority of the test (e.g., 'STAT', 'Routine').")

class PrescriptionInput(BaseModel):
    medication_name: str = Field(..., description="Name of the medication.")
    dosage: str = Field(..., description="Dosage amount and unit (e.g., '500 mg', '10 mg', '81 mg').")
    route: str = Field(..., description="Route of administration (e.g., 'PO', 'IV', 'IM', 'Topical', 'SL').")
    frequency: str = Field(..., description="How often the medication should be taken (e.g., 'BID', 'QDaily', 'Q4-6H PRN', 'once').")
    duration: str = Field("As directed", description="Duration of treatment (e.g., '7 days', '1 month', 'Ongoing', 'Until follow-up').")
    reason: str = Field(..., description="Clinical indication for the prescription.")

# Updated InteractionCheckInput - Note: current_medications/allergies are Optional here
# because they are populated by the tool_node from state *before* execution.
class InteractionCheckInput(BaseModel):
    potential_prescription: str = Field(..., description="The name of the NEW medication being considered for prescribing.")
    current_medications: Optional[List[str]] = Field(None, description="List of patient's current medication names (populated from state).")
    allergies: Optional[List[str]] = Field(None, description="List of patient's known allergies (populated from state).")

class FlagRiskInput(BaseModel):
    risk_description: str = Field(..., description="Specific critical risk identified (e.g., 'Suspected Sepsis', 'Acute Coronary Syndrome', 'Stroke Alert').")
    urgency: str = Field("High", description="Urgency level (e.g., 'Critical', 'High', 'Moderate').")
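
# When the model emits a tool call, the call's "args" are validated/coerced against the schema bound to
# that tool, e.g. (illustrative) {"name": "order_lab_test",
#   "args": {"test_name": "Troponin I", "reason": "Evaluate for ACS", "priority": "STAT"}}
# is checked against LabOrderInput before order_lab_test() executes.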

# Tool functions
@tool("order_lab_test", args_schema=LabOrderInput)
def order_lab_test(test_name: str, reason: str, priority: str = "Routine") -> str:
    """Orders a specific lab test with clinical justification and priority."""
    print(f"Executing order_lab_test: {test_name}, Reason: {reason}, Priority: {priority}")
    return json.dumps({"status": "success", "message": f"Lab Ordered: {test_name} ({priority})", "details": f"Reason: {reason}"})

@tool("prescribe_medication", args_schema=PrescriptionInput)
def prescribe_medication(medication_name: str, dosage: str, route: str, frequency: str, duration: str, reason: str) -> str:
    """Prescribes a medication with detailed instructions and clinical indication. IMPORTANT: Requires prior interaction check."""
    print(f"Executing prescribe_medication: {medication_name} {dosage}...")
    # Safety check happens in tool_node *before* this is called.
    return json.dumps({"status": "success", "message": f"Prescription Prepared: {medication_name} {dosage} {route} {frequency}", "details": f"Duration: {duration}. Reason: {reason}"})

# --- NEW Interaction Check Tool using UMLS/RxNorm & OpenFDA ---
@tool("check_drug_interactions", args_schema=InteractionCheckInput)
def check_drug_interactions(potential_prescription: str, current_medications: Optional[List[str]] = None, allergies: Optional[List[str]] = None) -> str:
    """
    Checks for potential drug-drug and drug-allergy interactions using RxNorm API for normalization
    and OpenFDA drug labels for interaction/warning text. (The app requires a UMLS_API_KEY at startup, but the RxNorm/OpenFDA endpoints used here are queried without it.)
    """
    print(f"\n--- Executing REAL check_drug_interactions ---")
    print(f"Checking potential prescription: '{potential_prescription}'")
    warnings = []
    potential_med_lower = potential_prescription.lower().strip()

    # Use provided lists or default to empty
    current_meds_list = current_medications or []
    allergies_list = allergies or []
    # Clean and lowercase current med names (basic extraction: first word)
    current_med_names_lower = []
    for med in current_meds_list:
         match = re.match(r"^\s*([a-zA-Z\-]+)", str(med))
         if match: current_med_names_lower.append(match.group(1).lower())
    # Clean and lowercase allergies
    allergies_lower = [str(a).lower().strip() for a in allergies_list if a]

    print(f"  Against Current Meds (names): {current_med_names_lower}")
    print(f"  Against Allergies: {allergies_lower}")

    # --- Step 1: Normalize potential prescription ---
    print(f"  Step 1: Normalizing '{potential_prescription}'...")
    potential_rxcui = get_rxcui(potential_prescription)
    potential_label = get_openfda_label(rxcui=potential_rxcui, drug_name=potential_prescription)
    if not potential_rxcui and not potential_label:
         print(f"  Warning: Could not find RxCUI or OpenFDA label for '{potential_prescription}'. Interaction check will be limited.")
         warnings.append(f"INFO: Could not reliably identify '{potential_prescription}' in standard terminologies/databases. Checks may be incomplete.")

    # --- Step 2: Allergy Check ---
    print("  Step 2: Performing Allergy Check...")
    # Direct name match against patient's allergy list
    for allergy in allergies_lower:
        if allergy == potential_med_lower:
            warnings.append(f"CRITICAL ALLERGY (Name Match): Patient allergic to '{allergy}'. Potential prescription is '{potential_prescription}'.")
        # Basic cross-reactivity check (can be expanded)
        elif allergy in ["penicillin", "pcns"] and potential_med_lower in ["amoxicillin", "ampicillin", "augmentin", "piperacillin"]:
             warnings.append(f"POTENTIAL CROSS-ALLERGY: Patient allergic to Penicillin. High risk with '{potential_prescription}'.")
        elif allergy == "sulfa" and potential_med_lower in ["sulfamethoxazole", "bactrim", "sulfasalazine"]:
             warnings.append(f"POTENTIAL CROSS-ALLERGY: Patient allergic to Sulfa. High risk with '{potential_prescription}'.")
        elif allergy in ["nsaids", "aspirin"] and potential_med_lower in ["ibuprofen", "naproxen", "ketorolac", "diclofenac"]:
              warnings.append(f"POTENTIAL CROSS-ALLERGY: Patient allergic to NSAIDs/Aspirin. Risk with '{potential_prescription}'.")

    # Check OpenFDA Label for Contraindications/Warnings related to ALLERGIES
    if potential_label:
        contraindications = potential_label.get("contraindications")
        warnings_section = potential_label.get("warnings_and_cautions") or potential_label.get("warnings")

        if contraindications:
            allergy_mentions_ci = search_text_list(contraindications, allergies_lower)
            if allergy_mentions_ci:
                warnings.append(f"ALLERGY RISK (Contraindication Found): Label for '{potential_prescription}' mentions contraindication potentially related to patient allergies: {'; '.join(allergy_mentions_ci)}")

        if warnings_section:
            allergy_mentions_warn = search_text_list(warnings_section, allergies_lower)
            if allergy_mentions_warn:
                 warnings.append(f"ALLERGY RISK (Warning Found): Label for '{potential_prescription}' mentions warnings potentially related to patient allergies: {'; '.join(allergy_mentions_warn)}")

    # --- Step 3: Drug-Drug Interaction Check ---
    print("  Step 3: Performing Drug-Drug Interaction Check...")
    if potential_rxcui or potential_label: # Proceed only if we have info on the potential drug
        for current_med_name in current_med_names_lower:
            if not current_med_name or current_med_name == potential_med_lower: continue # Skip empty or self-interaction

            print(f"    Checking interaction between '{potential_prescription}' and '{current_med_name}'...")
            current_rxcui = get_rxcui(current_med_name)
            current_label = get_openfda_label(rxcui=current_rxcui, drug_name=current_med_name)

            # Terms to search for in interaction text
            search_terms_for_current = [current_med_name]
            if current_rxcui: search_terms_for_current.append(current_rxcui) # Add RxCUI if found

            search_terms_for_potential = [potential_med_lower]
            if potential_rxcui: search_terms_for_potential.append(potential_rxcui) # Add RxCUI if found

            interaction_found_flag = False
            # Check Potential Drug's Label ('drug_interactions' section) for mentions of Current Drug
            if potential_label and potential_label.get("drug_interactions"):
                interaction_mentions = search_text_list(potential_label.get("drug_interactions"), search_terms_for_current)
                if interaction_mentions:
                    warnings.append(f"Potential Interaction ({potential_prescription.capitalize()} Label): Mentions '{current_med_name.capitalize()}'. Snippets: {'; '.join(interaction_mentions)}")
                    interaction_found_flag = True

            # Check Current Drug's Label ('drug_interactions' section) for mentions of Potential Drug
            if current_label and current_label.get("drug_interactions") and not interaction_found_flag: # Avoid duplicate warnings if already found
                interaction_mentions = search_text_list(current_label.get("drug_interactions"), search_terms_for_potential)
                if interaction_mentions:
                     warnings.append(f"Potential Interaction ({current_med_name.capitalize()} Label): Mentions '{potential_prescription.capitalize()}'. Snippets: {'; '.join(interaction_mentions)}")

    else: # Case where potential drug wasn't identified
         warnings.append(f"INFO: Drug-drug interaction check skipped for '{potential_prescription}' as it could not be identified via RxNorm/OpenFDA.")


    # --- Step 4: Format Output ---
    final_warnings = list(set(warnings)) # Remove duplicates
    status = "warning" if any("CRITICAL" in w or "Interaction" in w or "RISK" in w for w in final_warnings) else "clear"
    if not final_warnings: status = "clear" # Ensure clear if no warnings remain

    message = f"Interaction/Allergy check for '{potential_prescription}': {len(final_warnings)} potential issue(s) identified using RxNorm/OpenFDA." if final_warnings else f"No major interactions or allergy issues identified for '{potential_prescription}' based on RxNorm/OpenFDA lookup."
    print(f"--- Interaction Check Complete for '{potential_prescription}' ---")

    return json.dumps({"status": status, "message": message, "warnings": final_warnings})
# --- End of NEW Interaction Check Tool ---
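# Example return payload shape (illustrative; actual contents depend on live RxNorm/OpenFDA data):
#   {"status": "warning" | "clear",
#    "message": "Interaction/Allergy check for '<drug>': N potential issue(s) identified using RxNorm/OpenFDA.",
#    "warnings": ["POTENTIAL CROSS-ALLERGY: Patient allergic to NSAIDs/Aspirin. Risk with '<drug>'.", ...]}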

@tool("flag_risk", args_schema=FlagRiskInput)
def flag_risk(risk_description: str, urgency: str) -> str:
    """Flags a critical risk identified during analysis for immediate attention."""
    print(f"Executing flag_risk: {risk_description}, Urgency: {urgency}")
    st.error(f"🚨 **{urgency.upper()} RISK FLAGGED by AI:** {risk_description}", icon="🚨")
    return json.dumps({"status": "flagged", "message": f"Risk '{risk_description}' flagged with {urgency} urgency."})

# Initialize Search Tool
search_tool = TavilySearchResults(
    max_results=ClinicalAppSettings.MAX_SEARCH_RESULTS,
    name="tavily_search_results"
    )

# --- LangGraph Setup ---

# Define the state structure
class AgentState(TypedDict):
    messages: Annotated[list[Any], operator.add]
    patient_data: Optional[dict]
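    # Annotated[..., operator.add] is LangGraph's reducer annotation: node outputs for "messages" are
    # concatenated onto the running history (e.g. {"messages": [msg]} appends msg rather than replacing
    # the list), while "patient_data" has no reducer, so the most recently written value wins.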

# Define Tools and Tool Executor
tools = [
    order_lab_test,
    prescribe_medication,
    check_drug_interactions, # Using the new implementation
    flag_risk,
    search_tool
]
tool_executor = ToolExecutor(tools)

# Define the Agent Model
model = ChatGroq(
    temperature=ClinicalAppSettings.TEMPERATURE,
    model=ClinicalAppSettings.MODEL_NAME,
)
model_with_tools = model.bind_tools(tools)

# --- Graph Nodes (agent_node, tool_node remain mostly the same structurally) ---

# 1. Agent Node: Calls the LLM (No change needed from previous version)
def agent_node(state: AgentState):
    """Invokes the LLM to decide the next action or response."""
    print("\n---AGENT NODE---")
    current_messages = state['messages']
    if not current_messages or not isinstance(current_messages[0], SystemMessage):
        print("Prepending System Prompt.")
        current_messages = [SystemMessage(content=ClinicalPrompts.SYSTEM_PROMPT)] + current_messages
    print(f"Invoking LLM with {len(current_messages)} messages.")
    try:
        response = model_with_tools.invoke(current_messages)
        print(f"Agent Raw Response Type: {type(response)}")
        if hasattr(response, 'tool_calls') and response.tool_calls: print(f"Agent Response Tool Calls: {response.tool_calls}")
        else: print("Agent Response: No tool calls.")
    except Exception as e:
        print(f"ERROR in agent_node during LLM invocation: {type(e).__name__} - {e}")
        traceback.print_exc()
        error_message = AIMessage(content=f"Sorry, an internal error occurred while processing the request: {type(e).__name__}")
        return {"messages": [error_message]}
    return {"messages": [response]}

# 2. Tool Node: Executes tools (Mostly the same, ensures context injection)
def tool_node(state: AgentState):
    """Executes tools called by the LLM and returns results."""
    print("\n---TOOL NODE---")
    tool_messages = []
    last_message = state['messages'][-1]

    if not isinstance(last_message, AIMessage) or not getattr(last_message, 'tool_calls', None):
        print("Warning: Tool node called unexpectedly without tool calls.")
        return {"messages": []}

    tool_calls = last_message.tool_calls
    print(f"Tool calls received: {json.dumps(tool_calls, indent=2)}")

    # Safety Check Logic (No change needed from previous version)
    prescriptions_requested = {}
    interaction_checks_requested = {}
    for call in tool_calls:
        tool_name = call.get('name'); tool_args = call.get('args', {})
        if tool_name == 'prescribe_medication':
            med_name = tool_args.get('medication_name', '').lower();
            if med_name: prescriptions_requested[med_name] = call
        elif tool_name == 'check_drug_interactions':
            potential_med = tool_args.get('potential_prescription', '').lower()
            if potential_med: interaction_checks_requested[potential_med] = call

    valid_tool_calls_for_execution = []
    blocked_ids = set()
    for med_name, prescribe_call in prescriptions_requested.items():
        if med_name not in interaction_checks_requested:
            st.error(f"**Safety Violation:** AI attempted to prescribe '{med_name}' without requesting `check_drug_interactions` in the *same turn*. Prescription blocked.")
            error_msg = ToolMessage(content=json.dumps({"status": "error", "message": f"Interaction check for '{med_name}' must be requested *before or alongside* the prescription call."}), tool_call_id=prescribe_call['id'], name=prescribe_call['name'])
            tool_messages.append(error_msg)
            blocked_ids.add(prescribe_call['id'])

    valid_tool_calls_for_execution = [call for call in tool_calls if call['id'] not in blocked_ids]

    # Augment interaction checks with patient data (Crucial part - no change needed here)
    patient_data = state.get("patient_data", {})
    patient_meds_full = patient_data.get("medications", {}).get("current", []) # Pass full med list if needed by tool
    patient_allergies = patient_data.get("allergies", [])

    for call in valid_tool_calls_for_execution:
         if call['name'] == 'check_drug_interactions':
             if 'args' not in call: call['args'] = {}
             # Pass the necessary context from patient_data to the tool arguments
             # The tool function expects 'current_medications' (list of names) and 'allergies'
             call['args']['current_medications'] = patient_meds_full # Pass the full strings
             call['args']['allergies'] = patient_allergies
             print(f"Augmented interaction check args for call ID {call['id']}") # Removed args content for brevity

    # Execute valid tool calls (No change needed from previous version)
    if valid_tool_calls_for_execution:
        print(f"Attempting to execute {len(valid_tool_calls_for_execution)} tools: {[c['name'] for c in valid_tool_calls_for_execution]}")
        try:
            # ToolExecutor expects ToolInvocation objects rather than the raw tool-call dicts from the AIMessage.
            invocations = [ToolInvocation(tool=call['name'], tool_input=call.get('args', {})) for call in valid_tool_calls_for_execution]
            responses = tool_executor.batch(invocations, return_exceptions=True)
            for call, resp in zip(valid_tool_calls_for_execution, responses):
                tool_call_id = call['id']; tool_name = call['name']
                if isinstance(resp, Exception):
                    error_type = type(resp).__name__; error_str = str(resp)
                    print(f"ERROR executing tool '{tool_name}' (ID: {tool_call_id}): {error_type} - {error_str}")
                    traceback.print_exc()
                    st.error(f"Error executing action '{tool_name}': {error_type}")
                    error_content = json.dumps({"status": "error", "message": f"Failed to execute '{tool_name}': {error_type} - {error_str}"})
                    tool_messages.append(ToolMessage(content=error_content, tool_call_id=tool_call_id, name=tool_name))
                    if isinstance(resp, AttributeError) and "'dict' object has no attribute 'tool'" in error_str:
                         print("\n *** DETECTED SPECIFIC ATTRIBUTE ERROR ('dict' object has no attribute 'tool') *** \n")
                else:
                    print(f"Tool '{tool_name}' (ID: {tool_call_id}) executed successfully.")
                    content_str = str(resp)
                    tool_messages.append(ToolMessage(content=content_str, tool_call_id=tool_call_id, name=tool_name))
        except Exception as e:
            print(f"CRITICAL UNEXPECTED ERROR within tool_node logic: {type(e).__name__} - {e}")
            traceback.print_exc(); st.error(f"Critical internal error processing actions: {e}")
            error_content = json.dumps({"status": "error", "message": f"Internal error processing tools: {e}"})
            processed_ids = {msg.tool_call_id for msg in tool_messages}
            for call in valid_tool_calls_for_execution:
                 if call['id'] not in processed_ids: tool_messages.append(ToolMessage(content=error_content, tool_call_id=call['id'], name=call['name']))

    print(f"Returning {len(tool_messages)} tool messages.")
    return {"messages": tool_messages}


# --- Graph Edges (Routing Logic) --- (No change needed)
def should_continue(state: AgentState) -> str:
    """Determines whether to call tools, end the conversation turn, or handle errors."""
    print("\n---ROUTING DECISION---")
    last_message = state['messages'][-1] if state['messages'] else None
    if not isinstance(last_message, AIMessage): return "end_conversation_turn"
    if "Sorry, an internal error occurred" in last_message.content: return "end_conversation_turn"
    if getattr(last_message, 'tool_calls', None): return "continue_tools"
    else: return "end_conversation_turn"

# --- Graph Definition & Compilation --- (No change needed)
workflow = StateGraph(AgentState)
workflow.add_node("agent", agent_node)
workflow.add_node("tools", tool_node)
workflow.set_entry_point("agent")
workflow.add_conditional_edges("agent", should_continue, {"continue_tools": "tools", "end_conversation_turn": END})
workflow.add_edge("tools", "agent")
app = workflow.compile()
print("LangGraph compiled successfully.")

# --- Streamlit UI ---
def main():
    st.set_page_config(page_title=ClinicalAppSettings.APP_TITLE, layout=ClinicalAppSettings.PAGE_LAYOUT)
    st.title(f"🩺 {ClinicalAppSettings.APP_TITLE}")
    st.caption(f"Interactive Assistant | LangGraph/Groq/Tavily/UMLS/OpenFDA | Model: {ClinicalAppSettings.MODEL_NAME}")

    # Initialize session state (No change needed)
    if "messages" not in st.session_state: st.session_state.messages = []
    if "patient_data" not in st.session_state: st.session_state.patient_data = None
    if "graph_app" not in st.session_state: st.session_state.graph_app = app

    # --- Patient Data Input Sidebar --- (Adjusted allergy/med extraction slightly)
    with st.sidebar:
        st.header("πŸ“„ Patient Intake Form")
        # Demographics, HPI, History, Medications/Allergies, Social/Family, and Vitals/Exam input fields
        st.subheader("Demographics")
        age = st.number_input("Age", min_value=0, max_value=120, value=55, key="age_input")
        sex = st.selectbox("Biological Sex", ["Male", "Female", "Other/Prefer not to say"], key="sex_input")
        st.subheader("History of Present Illness (HPI)")
        chief_complaint = st.text_input("Chief Complaint", "Chest pain", key="cc_input")
        hpi_details = st.text_area("Detailed HPI", "55 y/o male presents with substernal chest pain started 2 hours ago...", key="hpi_input", height=150)
        symptoms = st.multiselect("Associated Symptoms", ["Nausea", "Diaphoresis", "Shortness of Breath", "Dizziness", "Palpitations", "Fever", "Cough", "Severe Headache", "Syncope", "Hemoptysis"], default=["Nausea", "Diaphoresis"], key="sym_input")
        st.subheader("Past History")
        pmh = st.text_area("Past Medical History (PMH)", "Hypertension (HTN), Hyperlipidemia (HLD), Type 2 Diabetes Mellitus (DM2), History of MI", key="pmh_input")
        psh = st.text_area("Past Surgical History (PSH)", "Appendectomy (2005)", key="psh_input")
        st.subheader("Medications & Allergies")
        current_meds_str = st.text_area("Current Medications (name, dose, freq)", "Lisinopril 10mg daily\nMetformin 1000mg BID\nAtorvastatin 40mg daily\nAspirin 81mg daily", key="meds_input")
        allergies_str = st.text_area("Allergies (comma separated, specify reaction if known)", "Penicillin (rash), Sulfa (hives)", key="allergy_input")
        st.subheader("Social/Family History")
        social_history = st.text_area("Social History (SH)", "Smoker (1 ppd x 30 years), occasional alcohol.", key="sh_input")
        family_history = st.text_area("Family History (FHx)", "Father had MI at age 60. Mother has HTN.", key="fhx_input")
        st.subheader("Vitals & Exam Findings")
        col1, col2 = st.columns(2)
        with col1:
            temp_c = st.number_input("Temp (Β°C)", 35.0, 42.0, 36.8, format="%.1f", key="temp_input")
            hr_bpm = st.number_input("HR (bpm)", 30, 250, 95, key="hr_input")
            rr_rpm = st.number_input("RR (rpm)", 5, 50, 18, key="rr_input")
        with col2:
            bp_mmhg = st.text_input("BP (SYS/DIA)", "155/90", key="bp_input")
            spo2_percent = st.number_input("SpO2 (%)", 70, 100, 96, key="spo2_input")
            pain_scale = st.slider("Pain (0-10)", 0, 10, 8, key="pain_input")
        exam_notes = st.text_area("Brief Physical Exam Notes", "Awake, alert, oriented x3. Mild distress. Lungs clear bilaterally...", key="exam_input", height=100)


        # Compile Patient Data Dictionary (Refined Extraction for Tool Use)
        if st.button("Start/Update Consultation", key="start_button"):
            # Store full medication strings for display/context
            current_meds_list = [med.strip() for med in current_meds_str.split('\n') if med.strip()]
            # Extract just the names (simplified) for the interaction check tool's state population
            current_med_names_only = []
            for med in current_meds_list:
                match = re.match(r"^\s*([a-zA-Z\-]+)", med)
                if match: current_med_names_only.append(match.group(1).lower())

            # Extract allergy names (simplified, before parenthesis)
            allergies_list = []
            for a in allergies_str.split(','):
                cleaned_allergy = a.strip()
                if cleaned_allergy:
                     match = re.match(r"^\s*([a-zA-Z\-\s/]+)(?:\s*\(.*\))?", cleaned_allergy)
                     name_part = match.group(1).strip().lower() if match else cleaned_allergy.lower()
                     allergies_list.append(name_part)

            st.session_state.patient_data = {
                "demographics": {"age": age, "sex": sex},
                "hpi": {"chief_complaint": chief_complaint, "details": hpi_details, "symptoms": symptoms},
                "pmh": {"conditions": pmh}, "psh": {"procedures": psh},
                # Store both full list and names_only list
                "medications": {"current": current_meds_list, "names_only": current_med_names_only},
                "allergies": allergies_list, # Store cleaned list
                "social_history": {"details": social_history}, "family_history": {"details": family_history},
                "vitals": { "temp_c": temp_c, "hr_bpm": hr_bpm, "bp_mmhg": bp_mmhg, "rr_rpm": rr_rpm, "spo2_percent": spo2_percent, "pain_scale": pain_scale},
                "exam_findings": {"notes": exam_notes}
            }

            # Initial Red Flag Check
            red_flags = check_red_flags(st.session_state.patient_data)
            st.sidebar.markdown("---")
            if red_flags:
                st.sidebar.warning("**Initial Red Flags Detected:**")
                for flag in red_flags: st.sidebar.warning(f"- {flag.replace('Red Flag: ','')}")
            else: st.sidebar.success("No immediate red flags detected.")

            # Prepare initial message & reset history. Include the formatted intake data in the first
            # message so the model can actually see it (patient_data in graph state is only used to
            # augment tool calls, not sent to the LLM directly).
            initial_prompt = (
                "Initiate consultation for the patient described below. Review the data and begin analysis.\n\n"
                + format_patient_data_for_prompt(st.session_state.patient_data)
            )
            st.session_state.messages = [HumanMessage(content=initial_prompt)]
            st.success("Patient data loaded/updated. Ready for analysis.")


    # --- Main Chat Interface Area --- (No change needed in display logic)
    st.header("πŸ’¬ Clinical Consultation")

    # Display chat messages from history
    for msg in st.session_state.messages:
        if isinstance(msg, HumanMessage):
            with st.chat_message("user"): st.markdown(msg.content)
        elif isinstance(msg, AIMessage):
            with st.chat_message("assistant"):
                ai_content = msg.content; structured_output = None
                try:
                    json_match = re.search(r"```json\s*(\{.*?\})\s*```", ai_content, re.DOTALL | re.IGNORECASE)
                    if json_match:
                        json_str = json_match.group(1); prefix = ai_content[:json_match.start()].strip(); suffix = ai_content[json_match.end():].strip()
                        if prefix: st.markdown(prefix)
                        structured_output = json.loads(json_str)
                        if suffix: st.markdown(suffix)
                    elif ai_content.strip().startswith("{") and ai_content.strip().endswith("}"):
                         structured_output = json.loads(ai_content); ai_content = ""
                    else: st.markdown(ai_content)
                except Exception as e: st.markdown(ai_content); print(f"Error parsing/displaying AI JSON: {e}")

                if structured_output and isinstance(structured_output, dict):
                    # Render the structured assessment JSON
                    st.divider(); st.subheader("πŸ“Š AI Analysis & Recommendations")
                    cols = st.columns(2)
                    with cols[0]:
                        st.markdown(f"**Assessment:**"); st.markdown(f"> {structured_output.get('assessment', 'N/A')}")
                        st.markdown(f"**Differential Diagnosis:**")
                        ddx = structured_output.get('differential_diagnosis', []);
                        if ddx:
                            for item in ddx:
                                likelihood = item.get('likelihood', '?').capitalize(); icon = "πŸ₯‡" if likelihood=="High" else ("πŸ₯ˆ" if likelihood=="Medium" else "πŸ₯‰")
                                with st.expander(f"{icon} {item.get('diagnosis', 'Unknown')} ({likelihood})"): st.write(f"**Rationale:** {item.get('rationale', 'N/A')}")
                        else: st.info("No DDx provided.")
                        st.markdown(f"**Risk Assessment:**"); risk = structured_output.get('risk_assessment', {})
                        flags = risk.get('identified_red_flags', []); concerns = risk.get("immediate_concerns", []); comps = risk.get("potential_complications", [])
                        if flags: st.warning(f"**Flags:** {', '.join(flags)}")
                        if concerns: st.warning(f"**Concerns:** {', '.join(concerns)}")
                        if comps: st.info(f"**Potential Complications:** {', '.join(comps)}")
                        if not flags and not concerns: st.success("No major risks highlighted.")
                    with cols[1]:
                        st.markdown(f"**Recommended Plan:**"); plan = structured_output.get('recommended_plan', {})
                        for section in ["investigations", "therapeutics", "consultations", "patient_education"]:
                             st.markdown(f"_{section.replace('_',' ').capitalize()}:_"); items = plan.get(section)
                             if items: [st.markdown(f"- {item}") for item in items] if isinstance(items, list) else st.markdown(f"- {items}")
                             else: st.markdown("_None suggested._")
                             st.markdown("") # Space
                    st.markdown(f"**Rationale & Guideline Check:**"); st.markdown(f"> {structured_output.get('rationale_summary', 'N/A')}")
                    interaction_summary = structured_output.get("interaction_check_summary", "")
                    if interaction_summary: st.markdown(f"**Interaction Check Summary:**"); st.markdown(f"> {interaction_summary}")
                    st.divider()

                if getattr(msg, 'tool_calls', None):
                     with st.expander("πŸ› οΈ AI requested actions", expanded=False):
                         for tc in msg.tool_calls:
                             try: st.code(f"Action: {tc.get('name', 'Unknown')}\nArgs: {json.dumps(tc.get('args', {}), indent=2)}", language="json")
                             except Exception as display_e: st.error(f"Could not display tool call: {display_e}"); st.code(str(tc))

        elif isinstance(msg, ToolMessage):
            tool_name_display = getattr(msg, 'name', 'tool_execution')
            with st.chat_message(tool_name_display, avatar="πŸ› οΈ"):
                # Render the tool result JSON (status, message, warnings)
                try:
                    tool_data = json.loads(msg.content); status = tool_data.get("status", "info"); message = tool_data.get("message", msg.content)
                    details = tool_data.get("details"); warnings = tool_data.get("warnings")
                    if status == "success" or status == "clear" or status == "flagged": st.success(f"{message}", icon="βœ…" if status != "flagged" else "🚨")
                    elif status == "warning":
                         st.warning(f"{message}", icon="⚠️")
                         if warnings and isinstance(warnings, list):
                             st.caption("Details:")
                             for warn in warnings: st.caption(f"- {warn}") # Display warnings from the tool output JSON
                    else: st.error(f"{message}", icon="❌")
                    if details: st.caption(f"Details: {details}")
                except json.JSONDecodeError: st.info(f"{msg.content}") # Display raw if not JSON
                except Exception as e: st.error(f"Error displaying tool message: {e}", icon="❌"); st.caption(f"Raw content: {msg.content}")


    # --- Chat Input Logic --- (No change needed)
    if prompt := st.chat_input("Your message or follow-up query..."):
        if not st.session_state.patient_data:
            st.warning("Please load patient data using the sidebar first."); st.stop()

        user_message = HumanMessage(content=prompt)
        st.session_state.messages.append(user_message)
        with st.chat_message("user"): st.markdown(prompt)

        current_state = AgentState(messages=st.session_state.messages, patient_data=st.session_state.patient_data)

        with st.spinner("SynapseAI is thinking..."):
            try:
                final_state = st.session_state.graph_app.invoke(current_state, {"recursion_limit": 15})
                st.session_state.messages = final_state['messages']
            except Exception as e:
                print(f"CRITICAL ERROR during graph invocation: {type(e).__name__} - {e}"); traceback.print_exc()
                st.error(f"An error occurred during the conversation turn: {e}", icon="❌")
                # Optionally add error to history for user visibility
                # error_ai_msg = AIMessage(content=f"Sorry, a critical error occurred: {type(e).__name__}. Please check logs or try again.")
                # st.session_state.messages.append(error_ai_msg)

        st.rerun() # Refresh display

    # Disclaimer (No change needed)
    st.markdown("---")
    st.warning("""**Disclaimer:** SynapseAI is an AI assistant... (Verify all outputs)""")

if __name__ == "__main__":
    main()