mgbam commited on
Commit
e22139a
Β·
verified Β·
1 Parent(s): a1dbeb2

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +461 -234
agent.py CHANGED
@@ -1,250 +1,477 @@
1
- """
2
- SynapseAI Clinical Decision Support System
3
- Expert-Level Implementation with Safety-Centric Architecture
4
- """
5
-
6
  import os
7
  import re
8
  import json
9
  import logging
10
- from typing import (Any, Dict, List, Optional, TypedDict,
11
- Callable, Sequence, Tuple, Union)
12
  from functools import lru_cache
13
- from enum import Enum
14
 
15
  import requests
16
- from pydantic import BaseModel, Field, ValidationError
17
  from langchain_groq import ChatGroq
18
- from langchain_core.messages import (HumanMessage, SystemMessage,
19
- AIMessage, ToolMessage)
20
- from langchain_core.tools import BaseTool
21
- from langgraph.graph import StateGraph, END
22
  from langgraph.prebuilt import ToolExecutor
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
- # ── Type Definitions ──────────────────────────────────────────────────
25
- class ClinicalPriority(str, Enum):
26
- STAT = "STAT"
27
- URGENT = "Urgent"
28
- ROUTINE = "Routine"
29
-
30
- class ClinicalState(TypedDict):
31
- messages: List[Union[HumanMessage, SystemMessage, AIMessage, ToolMessage]]
32
- patient_data: Dict[str, Any]
33
- safety_warnings: List[Dict[str, str]]
34
- workflow_metadata: Dict[str, Union[int, float, bool]]
35
- execution_log: List[Dict[str, str]]
36
-
37
- # ── Configuration ─────────────────────────────────────────────────────
38
- class ClinicalConfig:
39
- MAX_ITERATIONS = 6 # Evidence-based conversation turn limit
40
- RECURSION_BUFFER = 2 # Safety margin for LangGraph execution
41
- DRUG_CHECK_REQUIRED = True # Hard enforcement for interaction checks
42
-
43
- SAFETY_PARAMETERS = {
44
- 'max_bp_systolic': 180,
45
- 'min_bp_systolic': 90,
46
- 'max_hr': 120,
47
- 'min_spo2': 92
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
 
50
- # ── Core Clinical Tools ──────────────────────────────────────────────
51
- class ClinicalToolkit:
52
- @staticmethod
53
- def get_essential_tools() -> List[BaseTool]:
54
- """Return validated clinical tools with safety wrappers"""
55
- return [
56
- ClinicalToolkit.order_lab_test,
57
- ClinicalToolkit.prescribe_medication,
58
- ClinicalToolkit.check_drug_interactions,
59
- ClinicalToolkit.flag_clinical_risk
60
- ]
61
-
62
- class LabOrderInput(BaseModel):
63
- test_name: str = Field(..., pattern=r"^[A-Za-z0-9\s-]+$")
64
- rationale: str = Field(..., min_length=20)
65
- priority: ClinicalPriority = ClinicalPriority.ROUTINE
66
-
67
- @tool("order_lab_test", args_schema=LabOrderInput)
68
- def order_lab_test(test_name: str, rationale: str,
69
- priority: ClinicalPriority) -> Dict[str, Any]:
70
- """Standardized lab ordering with clinical validation"""
71
- # Implementation details...
72
- return {"status": "ordered", "details": {...}}
73
-
74
- class PrescriptionSafetyCheck(BaseModel):
75
- medication: str
76
- rxcui: Optional[str]
77
- contraindications: List[str]
78
- # Additional safety fields...
79
-
80
- @classmethod
81
- def validate_prescription(cls, rx_data: Dict) -> PrescriptionSafetyCheck:
82
- """Pharmaceutical safety validation pipeline"""
83
- # Comprehensive validation logic...
84
- return PrescriptionSafetyCheck(...)
85
-
86
- # ── State Management Engine ─────────────��────────────────────────────
87
- class ClinicalStateManager:
88
- @staticmethod
89
- def initialize_state(patient_data: Dict) -> ClinicalState:
90
- """Create validated initial state with clinical context"""
91
- return {
92
- "messages": [
93
- SystemMessage(content=ClinicalPrompts.SYSTEM_PROMPT),
94
- HumanMessage(content="Initiate clinical consultation")
95
- ],
96
- "patient_data": ClinicalValidator.sanitize_patient_data(patient_data),
97
- "safety_warnings": [],
98
- "workflow_metadata": {
99
- "iterations": 0,
100
- "active_alerts": 0,
101
- "safety_override": False
102
- },
103
- "execution_log": []
104
- }
105
-
106
- @staticmethod
107
- def propagate_state(previous: ClinicalState,
108
- updates: Dict) -> ClinicalState:
109
- """State transition with clinical context preservation"""
110
- preserved_fields = {
111
- 'patient_data': previous['patient_data'],
112
- 'workflow_metadata': {
113
- **previous['workflow_metadata'],
114
- **updates.get('workflow_metadata', {})
115
- }
116
- }
117
- return ClinicalValidator.validate_state({
118
- **preserved_fields,
119
- **updates
120
- })
121
 
122
- # ── Clinical Workflow Nodes ─────────────────────────────────────────
123
- class ClinicalWorkflowNodes:
124
- @staticmethod
125
- def agent_node(state: ClinicalState) -> ClinicalState:
126
- """FDA-compliant clinical reasoning engine"""
127
- ClinicalValidator.check_iteration_limit(state)
128
-
129
- try:
130
- response = ClinicalLLM.invoke(state)
131
- new_state = ClinicalStateManager.propagate_state(state, {
132
- "messages": [response],
133
- "workflow_metadata": {
134
- "iterations": state["workflow_metadata"]["iterations"] + 1
135
- }
136
- })
137
-
138
- if ClinicalTerminationCriteria.should_terminate(new_state):
139
- return ClinicalWorkflowNodes.apply_termination_protocol(new_state)
140
-
141
- return new_state
142
- except CriticalClinicalError as e:
143
- return ClinicalErrorHandler.handle_critical_error(state, e)
144
-
145
- @staticmethod
146
- def tool_node(state: ClinicalState) -> ClinicalState:
147
- """HIPAA-compliant tool execution with safety audit"""
148
- ClinicalSafetyEngine.pre_execution_checks(state)
149
-
150
- tool_results = []
151
- for tool_call in state["messages"][-1].tool_calls:
152
- result = ClinicalToolExecutor.execute_with_audit(tool_call)
153
- tool_results.append(result)
154
-
155
- if result['category'] == "DRUG_ORDER":
156
- ClinicalSafetyEngine.post_drug_order_checks(result)
157
-
158
- return ClinicalStateManager.propagate_state(state, {
159
- "messages": [ToolMessage(...)],
160
- "safety_warnings": ClinicalSafetyEngine.aggregate_warnings(tool_results)
161
- })
162
 
163
- # ── Safety Subsystems ───────────────────────────────────────────────
164
- class ClinicalSafetyEngine:
165
- @staticmethod
166
- def enforce_prescription_rules(tool_calls: List) -> None:
167
- """Hard requirements for medication orders"""
168
- drug_orders = [tc for tc in tool_calls if tc.name == "prescribe_medication"]
169
- interaction_checks = [tc for tc in tool_calls
170
- if tc.name == "check_drug_interactions"]
171
-
172
- if ClinicalConfig.DRUG_CHECK_REQUIRED:
173
- for rx in drug_orders:
174
- if not any(ic.args['medication'] == rx.args['medication']
175
- for ic in interaction_checks):
176
- raise CriticalSafetyViolation(
177
- f"Missing interaction check for {rx.args['medication']}"
178
- )
179
-
180
- class ClinicalTerminationCriteria:
181
- @staticmethod
182
- def should_terminate(state: ClinicalState) -> bool:
183
- """Multi-factor clinical conversation termination"""
184
- metadata = state["workflow_metadata"]
185
- return any([
186
- metadata["iterations"] >= ClinicalConfig.MAX_ITERATIONS,
187
- metadata["active_alerts"] > 3,
188
- "terminate consultation" in state["messages"][-1].content.lower()
189
- ])
190
-
191
- # ── Execution Framework ─────────────────────────────────────────────
192
- class ClinicalWorkflow:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
193
  def __init__(self):
194
- self.workflow = self._build_workflow()
195
- self.toolkit = ClinicalToolkit.get_essential_tools()
196
- self.llm = ChatGroq(model_name="llama3-70b-8192", temperature=0.1)
197
-
198
- def _build_workflow(self) -> StateGraph:
199
- """Construct ISO 13485-compliant clinical workflow"""
200
- workflow = StateGraph(ClinicalState)
201
-
202
- workflow.add_node("clinical_reasoning", ClinicalWorkflowNodes.agent_node)
203
- workflow.add_node("tool_execution", ClinicalWorkflowNodes.tool_node)
204
- workflow.add_node("safety_review", ClinicalSafetyProtocols.review_node)
205
-
206
- workflow.set_entry_point("clinical_reasoning")
207
-
208
- workflow.add_conditional_edges(
209
- "clinical_reasoning",
210
- ClinicalDecisionRouter.route_agent_output,
211
- {
212
- "require_tools": "tool_execution",
213
- "need_safety_review": "safety_review",
214
- "final_output": END
 
 
 
 
 
 
 
 
 
 
 
215
  }
216
- )
217
-
218
- workflow.add_edge("tool_execution", "clinical_reasoning")
219
- workflow.add_edge("safety_review", "clinical_reasoning")
220
-
221
- return workflow.compile()
222
-
223
- def execute_consultation(self, patient_data: Dict) -> ClinicalState:
224
- """Execute full clinical workflow with safety audits"""
225
- initial_state = ClinicalStateManager.initialize_state(patient_data)
226
- return self.workflow.invoke(
227
- initial_state,
228
- config={"recursion_limit": ClinicalConfig.MAX_ITERATIONS + ClinicalConfig.RECURSION_BUFFER}
229
- )
230
-
231
- # ── Usage Example ───────────────────────────────────────────────────
232
- if __name__ == "__main__":
233
- # Initialize clinical environment
234
- ClinicalValidator.validate_environment()
235
-
236
- # Sample patient scenario
237
- complex_case = {
238
- "demographics": {"age": 68, "sex": "F", "weight_kg": 82},
239
- "presenting_complaint": "Chest pain radiating to left arm",
240
- "medical_history": ["HTN", "Type 2 DM", "HLD"],
241
- "current_meds": ["Atenolol 50mg daily", "Simvastatin 40mg HS"]
242
- }
243
-
244
- # Execute clinical workflow
245
- workflow = ClinicalWorkflow()
246
- result = workflow.execute_consultation(complex_case)
247
-
248
- # Generate clinical summary
249
- final_report = ClinicalDocumentation.generate_report(result)
250
- print(json.dumps(final_report, indent=2))
 
 
 
 
 
 
1
  import os
2
  import re
3
  import json
4
  import logging
5
+ import traceback
 
6
  from functools import lru_cache
7
+ from typing import List, Dict, Any, Optional, TypedDict
8
 
9
  import requests
 
10
  from langchain_groq import ChatGroq
11
+ from langchain_community.tools.tavily_search import TavilySearchResults
12
+ from langchain_core.messages import HumanMessage, SystemMessage, AIMessage, ToolMessage
13
+ from langchain_core.pydantic_v1 import BaseModel, Field
14
+ from langchain_core.tools import tool
15
  from langgraph.prebuilt import ToolExecutor
16
+ from langgraph.graph import StateGraph, END
17
+
18
+ # ── Logging Configuration ──────────────────────────────────────────────
19
+ logger = logging.getLogger(__name__)
20
+ logging.basicConfig(level=logging.INFO)
21
+
22
+ # ── Environment Variables ──────────────────────────────────────────────
23
+ UMLS_API_KEY = os.getenv("UMLS_API_KEY")
24
+ GROQ_API_KEY = os.getenv("GROQ_API_KEY")
25
+ TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")
26
+
27
+ if not all([UMLS_API_KEY, GROQ_API_KEY, TAVILY_API_KEY]):
28
+ logger.error("Missing one or more required API keys: UMLS_API_KEY, GROQ_API_KEY, TAVILY_API_KEY")
29
+ raise RuntimeError("Missing required API keys")
30
+
31
+ # ── Agent Configuration ──────────────────────────────────────────────
32
+ AGENT_MODEL_NAME = "llama3-70b-8192"
33
+ AGENT_TEMPERATURE = 0.1
34
+ MAX_SEARCH_RESULTS = 3
35
+
36
+ class ClinicalPrompts:
37
+ SYSTEM_PROMPT = """
38
+ You are SynapseAI, an expert AI clinical assistant engaged in an interactive consultation...
39
+ [SYSTEM PROMPT CONTENT HERE]
40
+ """
41
+
42
+ # ── Helper Functions ─────────────────────────────────────────────────────
43
+ UMLS_AUTH_ENDPOINT = "https://utslogin.nlm.nih.gov/cas/v1/api-key"
44
+ RXNORM_API_BASE = "https://rxnav.nlm.nih.gov/REST"
45
+ OPENFDA_API_BASE = "https://api.fda.gov/drug/label.json"
46
+
47
+ @lru_cache(maxsize=256)
48
+ def get_rxcui(drug_name: str) -> Optional[str]:
49
+ """Lookup RxNorm CUI for a given drug name."""
50
+ drug_name = (drug_name or "").strip()
51
+ if not drug_name:
52
+ return None
53
+ logger.info(f"Looking up RxCUI for '{drug_name}'")
54
+ try:
55
+ params = {"name": drug_name, "search": 1}
56
+ r = requests.get(f"{RXNORM_API_BASE}/rxcui.json", params=params, timeout=10)
57
+ r.raise_for_status()
58
+ ids = r.json().get("idGroup", {}).get("rxnormId")
59
+ if ids:
60
+ logger.info(f"Found RxCUI {ids[0]} for '{drug_name}'")
61
+ return ids[0]
62
+ r = requests.get(f"{RXNORM_API_BASE}/drugs.json", params={"name": drug_name}, timeout=10)
63
+ r.raise_for_status()
64
+ for grp in r.json().get("drugGroup", {}).get("conceptGroup", []):
65
+ props = grp.get("conceptProperties")
66
+ if props:
67
+ logger.info(f"Found RxCUI {props[0]['rxcui']} via /drugs for '{drug_name}'")
68
+ return props[0]["rxcui"]
69
+ except Exception:
70
+ logger.exception(f"Error fetching RxCUI for '{drug_name}'")
71
+ return None
72
 
73
+ @lru_cache(maxsize=128)
74
+ def get_openfda_label(rxcui: Optional[str] = None, drug_name: Optional[str] = None) -> Optional[Dict[str, Any]]:
75
+ """Fetch the OpenFDA label for a drug by RxCUI or name."""
76
+ if not (rxcui or drug_name):
77
+ return None
78
+ terms = []
79
+ if rxcui:
80
+ terms.append(f'spl_rxnorm_code:"{rxcui}" OR openfda.rxcui:"{rxcui}"')
81
+ if drug_name:
82
+ dn = drug_name.lower()
83
+ terms.append(f'(openfda.brand_name:"{dn}" OR openfda.generic_name:"{dn}")')
84
+ query = " OR ".join(terms)
85
+ logger.info(f"Looking up OpenFDA label with query: {query}")
86
+ try:
87
+ r = requests.get(OPENFDA_API_BASE, params={"search": query, "limit": 1}, timeout=15)
88
+ r.raise_for_status()
89
+ results = r.json().get("results", [])
90
+ if results:
91
+ return results[0]
92
+ except Exception:
93
+ logger.exception("Error fetching OpenFDA label")
94
+ return None
95
+
96
+ def search_text_list(texts: List[str], terms: List[str]) -> List[str]:
97
+ """Return highlighted snippets from a list of texts containing any of the search terms."""
98
+ snippets = []
99
+ lowers = [t.lower() for t in terms if t]
100
+ for text in texts or []:
101
+ tl = text.lower()
102
+ for term in lowers:
103
+ if term in tl:
104
+ i = tl.find(term)
105
+ start = max(0, i - 50)
106
+ end = min(len(text), i + len(term) + 100)
107
+ snippet = text[start:end]
108
+ snippet = re.sub(f"({re.escape(term)})", r"**\1**", snippet, flags=re.IGNORECASE)
109
+ snippets.append(f"...{snippet}...")
110
+ break
111
+ return snippets
112
+
113
+ def parse_bp(bp: str) -> Optional[tuple[int, int]]:
114
+ """Parse 'SYS/DIA' blood pressure string into a (sys, dia) tuple."""
115
+ if m := re.match(r"(\d{1,3})\s*/\s*(\d{1,3})", (bp or "").strip()):
116
+ return int(m.group(1)), int(m.group(2))
117
+ return None
118
+
119
+ def check_red_flags(patient_data: Dict[str, Any]) -> List[str]:
120
+ """Identify immediate red flags from patient_data."""
121
+ flags: List[str] = []
122
+ hpi = patient_data.get("hpi", {})
123
+ vitals = patient_data.get("vitals", {})
124
+ syms = [s.lower() for s in hpi.get("symptoms", []) if isinstance(s, str)]
125
+ mapping = {
126
+ "chest pain": "Chest pain reported",
127
+ "shortness of breath": "Shortness of breath reported",
128
+ "severe headache": "Severe headache reported",
129
+ "syncope": "Syncope reported",
130
+ "hemoptysis": "Hemoptysis reported"
131
  }
132
+ for term, desc in mapping.items():
133
+ if term in syms:
134
+ flags.append(f"Red Flag: {desc}.")
135
+ temp = vitals.get("temp_c")
136
+ hr = vitals.get("hr_bpm")
137
+ rr = vitals.get("rr_rpm")
138
+ spo2 = vitals.get("spo2_percent")
139
+ bp = parse_bp(vitals.get("bp_mmhg", ""))
140
+ if temp is not None and temp >= 38.5:
141
+ flags.append(f"Red Flag: Fever ({temp}Β°C).")
142
+ if hr is not None:
143
+ if hr >= 120:
144
+ flags.append(f"Red Flag: Tachycardia ({hr} bpm).")
145
+ if hr <= 50:
146
+ flags.append(f"Red Flag: Bradycardia ({hr} bpm).")
147
+ if rr is not None and rr >= 24:
148
+ flags.append(f"Red Flag: Tachypnea ({rr} rpm).")
149
+ if spo2 is not None and spo2 <= 92:
150
+ flags.append(f"Red Flag: Hypoxia ({spo2}%).")
151
+ if bp:
152
+ sys, dia = bp
153
+ if sys >= 180 or dia >= 110:
154
+ flags.append(f"Red Flag: Hypertensive urgency/emergency ({sys}/{dia} mmHg).")
155
+ if sys <= 90 or dia <= 60:
156
+ flags.append(f"Red Flag: Hypotension ({sys}/{dia} mmHg).")
157
+ return list(dict.fromkeys(flags))
158
 
159
+ def format_patient_data_for_prompt(data: Dict[str, Any]) -> str:
160
+ """Format patient_data dict into a markdown-like prompt section."""
161
+ if not data:
162
+ return "No patient data provided."
163
+ lines: List[str] = []
164
+ for section, value in data.items():
165
+ title = section.replace("_", " ").title()
166
+ if isinstance(value, dict) and any(value.values()):
167
+ lines.append(f"**{title}:**")
168
+ for k, v in value.items():
169
+ if v:
170
+ lines.append(f"- {k.replace('_',' ').title()}: {v}")
171
+ elif isinstance(value, list) and value:
172
+ lines.append(f"**{title}:** {', '.join(map(str, value))}")
173
+ elif value:
174
+ lines.append(f"**{title}:** {value}")
175
+ return "\n".join(lines)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
176
 
177
+ # ── Tool Input Schemas ─────────────────────────────────────────────────────
178
+ class LabOrderInput(BaseModel):
179
+ test_name: str = Field(...)
180
+ reason: str = Field(...)
181
+ priority: str = Field("Routine")
182
+
183
+ class PrescriptionInput(BaseModel):
184
+ medication_name: str = Field(...)
185
+ dosage: str = Field(...)
186
+ route: str = Field(...)
187
+ frequency: str = Field(...)
188
+ duration: str = Field("As directed")
189
+ reason: str = Field(...)
190
+
191
+ class InteractionCheckInput(BaseModel):
192
+ potential_prescription: str
193
+ current_medications: Optional[List[str]] = Field(None)
194
+ allergies: Optional[List[str]] = Field(None)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
195
 
196
+ class FlagRiskInput(BaseModel):
197
+ risk_description: str = Field(...)
198
+ urgency: str = Field("High")
199
+
200
+ # ── Tool Implementations ───────────────────────────────────────────────────
201
+ @tool("order_lab_test", args_schema=LabOrderInput)
202
+ def order_lab_test(test_name: str, reason: str, priority: str = "Routine") -> str:
203
+ """
204
+ Place an order for a laboratory test.
205
+ """
206
+ logger.info(f"Ordering lab test: {test_name}, reason: {reason}, priority: {priority}")
207
+ return json.dumps({
208
+ "status": "success",
209
+ "message": f"Lab Ordered: {test_name} ({priority})",
210
+ "details": f"Reason: {reason}"
211
+ })
212
+
213
+ @tool("prescribe_medication", args_schema=PrescriptionInput)
214
+ def prescribe_medication(
215
+ medication_name: str,
216
+ dosage: str,
217
+ route: str,
218
+ frequency: str,
219
+ duration: str,
220
+ reason: str
221
+ ) -> str:
222
+ """
223
+ Prepare a medication prescription.
224
+ """
225
+ logger.info(f"Preparing prescription: {medication_name} {dosage}, route: {route}, freq: {frequency}")
226
+ return json.dumps({
227
+ "status": "success",
228
+ "message": f"Prescription Prepared: {medication_name} {dosage} {route} {frequency}",
229
+ "details": f"Duration: {duration}. Reason: {reason}"
230
+ })
231
+
232
+ @tool("check_drug_interactions", args_schema=InteractionCheckInput)
233
+ def check_drug_interactions(
234
+ potential_prescription: str,
235
+ current_medications: Optional[List[str]] = None,
236
+ allergies: Optional[List[str]] = None
237
+ ) -> str:
238
+ """
239
+ Check for drug–drug interactions and allergy risks.
240
+ """
241
+ logger.info(f"Checking interactions for: {potential_prescription}")
242
+ warnings: List[str] = []
243
+ pm = [m.lower().strip() for m in (current_medications or []) if m]
244
+ al = [a.lower().strip() for a in (allergies or []) if a]
245
+ if potential_prescription.lower().strip() in al:
246
+ warnings.append(f"CRITICAL ALLERGY: Patient allergic to '{potential_prescription}'.")
247
+ rxcui = get_rxcui(potential_prescription)
248
+ label = get_openfda_label(rxcui=rxcui, drug_name=potential_prescription)
249
+ if not (rxcui or label):
250
+ warnings.append(f"INFO: Could not identify '{potential_prescription}'. Checks may be incomplete.")
251
+ for section in ("contraindications", "warnings_and_cautions", "warnings"):
252
+ items = label.get(section) if label else None
253
+ if isinstance(items, list):
254
+ snippets = search_text_list(items, al)
255
+ if snippets:
256
+ warnings.append(f"ALLERGY RISK ({section}): {'; '.join(snippets)}")
257
+ for med in pm:
258
+ mrxcui = get_rxcui(med)
259
+ mlabel = get_openfda_label(rxcui=mrxcui, drug_name=med)
260
+ for sec in ("drug_interactions",):
261
+ for src_label, src_name in ((label, potential_prescription), (mlabel, med)):
262
+ items = src_label.get(sec) if src_label else None
263
+ if isinstance(items, list):
264
+ snippets = search_text_list(items, [med if src_name == potential_prescription else potential_prescription])
265
+ if snippets:
266
+ warnings.append(f"Interaction ({src_name} label): {'; '.join(snippets)}")
267
+ status = "warning" if warnings else "clear"
268
+ message = (
269
+ f"{len(warnings)} issue(s) found for '{potential_prescription}'."
270
+ if warnings else
271
+ f"No major interactions or allergy issues identified for '{potential_prescription}'."
272
+ )
273
+ return json.dumps({"status": status, "message": message, "warnings": warnings})
274
+
275
+ @tool("flag_risk", args_schema=FlagRiskInput)
276
+ def flag_risk(risk_description: str, urgency: str = "High") -> str:
277
+ """
278
+ Flag a clinical risk with given urgency.
279
+ """
280
+ logger.info(f"Flagging risk: {risk_description} (urgency={urgency})")
281
+ return json.dumps({
282
+ "status": "flagged",
283
+ "message": f"Risk '{risk_description}' flagged with {urgency} urgency."
284
+ })
285
+
286
+ # Include the Tavily search tool
287
+ search_tool = TavilySearchResults(max_results=MAX_SEARCH_RESULTS, name="tavily_search_results")
288
+ all_tools = [order_lab_test, prescribe_medication, check_drug_interactions, flag_risk, search_tool]
289
+
290
+ # ── LLM & Tool Executor ───────────────────────────────────────────────────
291
+ llm = ChatGroq(temperature=AGENT_TEMPERATURE, model=AGENT_MODEL_NAME)
292
+ model_with_tools = llm.bind_tools(all_tools)
293
+ tool_executor = ToolExecutor(all_tools)
294
+
295
+ # ── State Definition ─────────────────────────────────────────────────────
296
+ class AgentState(TypedDict):
297
+ messages: List[Any]
298
+ patient_data: Optional[Dict[str, Any]]
299
+ summary: Optional[str]
300
+ interaction_warnings: Optional[List[str]]
301
+ done: Optional[bool]
302
+ iterations: Optional[int]
303
+
304
+ # Helper to propagate state fields between nodes
305
+ def propagate_state(new: Dict[str, Any], old: Dict[str, Any]) -> Dict[str, Any]:
306
+ for key in ["iterations", "done", "patient_data", "summary", "interaction_warnings"]:
307
+ if key in old and key not in new:
308
+ new[key] = old[key]
309
+ return new
310
+
311
+ # ── Graph Nodes ─────────────────────────────────────────────────────────
312
+ def agent_node(state: AgentState) -> Dict[str, Any]:
313
+ if state.get("done", False):
314
+ return state
315
+ msgs = state.get("messages", [])
316
+ if not msgs or not isinstance(msgs[0], SystemMessage):
317
+ msgs = [SystemMessage(content=ClinicalPrompts.SYSTEM_PROMPT)] + msgs
318
+ logger.info(f"Invoking LLM with {len(msgs)} messages")
319
+ try:
320
+ response = model_with_tools.invoke(msgs)
321
+ new_state = {"messages": [response]}
322
+ return propagate_state(new_state, state)
323
+ except Exception as e:
324
+ logger.exception("Error in agent_node")
325
+ new_state = {"messages": [AIMessage(content=f"Error: {e}")]}
326
+ return propagate_state(new_state, state)
327
+
328
+ def tool_node(state: AgentState) -> Dict[str, Any]:
329
+ if state.get("done", False):
330
+ return state
331
+ last = state.get("messages", [])[-1]
332
+ if not isinstance(last, AIMessage) or not getattr(last, "tool_calls", None):
333
+ logger.warning("tool_node invoked without pending tool_calls")
334
+ new_state = {"messages": []}
335
+ return propagate_state(new_state, state)
336
+ calls = last.tool_calls
337
+ blocked_ids = set()
338
+ for call in calls:
339
+ if call["name"] == "prescribe_medication":
340
+ med = call["args"].get("medication_name", "").lower()
341
+ if not any(
342
+ c["name"] == "check_drug_interactions" and
343
+ c["args"].get("potential_prescription", "").lower() == med
344
+ for c in calls
345
+ ):
346
+ logger.warning(f"Blocking prescribe_medication for '{med}' without interaction check")
347
+ blocked_ids.add(call["id"])
348
+ to_execute = [c for c in calls if c["id"] not in blocked_ids]
349
+ pd = state.get("patient_data", {})
350
+ for call in to_execute:
351
+ if call["name"] == "check_drug_interactions":
352
+ call["args"].setdefault("current_medications", pd.get("medications", {}).get("current", []))
353
+ call["args"].setdefault("allergies", pd.get("allergies", []))
354
+ messages: List[ToolMessage] = []
355
+ warnings: List[str] = []
356
+ try:
357
+ responses = tool_executor.batch(to_execute, return_exceptions=True)
358
+ for call, resp in zip(to_execute, responses):
359
+ if isinstance(resp, Exception):
360
+ logger.exception(f"Error executing tool {call['name']}")
361
+ content = json.dumps({"status": "error", "message": str(resp)})
362
+ else:
363
+ content = str(resp)
364
+ if call["name"] == "check_drug_interactions":
365
+ data = json.loads(content)
366
+ if data.get("status") == "warning":
367
+ warnings.extend(data.get("warnings", []))
368
+ messages.append(ToolMessage(content=content, tool_call_id=call["id"], name=call["name"]))
369
+ except Exception as e:
370
+ logger.exception("Critical error in tool_node")
371
+ for call in to_execute:
372
+ messages.append(ToolMessage(
373
+ content=json.dumps({"status": "error", "message": str(e)}),
374
+ tool_call_id=call["id"],
375
+ name=call["name"]
376
+ ))
377
+ new_state = {"messages": messages, "interaction_warnings": warnings or None}
378
+ return propagate_state(new_state, state)
379
+
380
+ def reflection_node(state: AgentState) -> Dict[str, Any]:
381
+ if state.get("done", False):
382
+ return state
383
+ warns = state.get("interaction_warnings")
384
+ if not warns:
385
+ logger.warning("reflection_node called without warnings")
386
+ new_state = {"messages": []}
387
+ return propagate_state(new_state, state)
388
+ triggering = None
389
+ for msg in reversed(state.get("messages", [])):
390
+ if isinstance(msg, AIMessage) and getattr(msg, "tool_calls", None):
391
+ triggering = msg
392
+ break
393
+ if not triggering:
394
+ new_state = {"messages": [AIMessage(content="Internal Error: reflection context missing.")]}
395
+ return propagate_state(new_state, state)
396
+ prompt = (
397
+ "You are SynapseAI, performing a focused safety review of the following plan:\n\n"
398
+ f"{triggering.content}\n\n"
399
+ "Highlight any issues based on these warnings:\n" +
400
+ "\n".join(f"- {w}" for w in warns)
401
+ )
402
+ try:
403
+ resp = llm.invoke([SystemMessage(content="Safety reflection"), HumanMessage(content=prompt)])
404
+ new_state = {"messages": [AIMessage(content=resp.content)]}
405
+ return propagate_state(new_state, state)
406
+ except Exception as e:
407
+ logger.exception("Error during reflection")
408
+ new_state = {"messages": [AIMessage(content=f"Error during reflection: {e}")]}
409
+ return propagate_state(new_state, state)
410
+
411
+ # ── Routing Functions ────────────────────────────────────────────────────
412
+ def should_continue(state: AgentState) -> str:
413
+ state.setdefault("iterations", 0)
414
+ state["iterations"] += 1
415
+ logger.info(f"Iteration count: {state['iterations']}")
416
+ # When iterations exceed threshold, force final output and terminate.
417
+ if state["iterations"] >= 4:
418
+ state.setdefault("messages", []).append(AIMessage(content="Final output: consultation complete."))
419
+ state["done"] = True
420
+ return "end_conversation_turn"
421
+ if not state.get("messages"):
422
+ state["done"] = True
423
+ return "end_conversation_turn"
424
+ last = state["messages"][-1]
425
+ if not isinstance(last, AIMessage):
426
+ state["done"] = True
427
+ return "end_conversation_turn"
428
+ if getattr(last, "tool_calls", None):
429
+ return "continue_tools"
430
+ if "consultation complete" in last.content.lower():
431
+ state["done"] = True
432
+ return "end_conversation_turn"
433
+ state["done"] = False
434
+ return "agent"
435
+
436
+ def after_tools_router(state: AgentState) -> str:
437
+ # Instead of routing back to agent, route reflection to END to break the cycle.
438
+ if state.get("interaction_warnings"):
439
+ return "reflection"
440
+ return "end_conversation_turn"
441
+
442
+ # ── ClinicalAgent ─────────────────────────────────────────────────────────
443
+ class ClinicalAgent:
444
  def __init__(self):
445
+ logger.info("Building ClinicalAgent workflow")
446
+ wf = StateGraph(AgentState)
447
+ wf.add_node("agent", agent_node)
448
+ wf.add_node("tools", tool_node)
449
+ wf.add_node("reflection", reflection_node)
450
+ wf.set_entry_point("agent")
451
+ wf.add_conditional_edges("agent", should_continue, {
452
+ "continue_tools": "tools",
453
+ "end_conversation_turn": END
454
+ })
455
+ wf.add_conditional_edges("tools", after_tools_router, {
456
+ "reflection": "reflection",
457
+ "end_conversation_turn": END
458
+ })
459
+ # Removed the edge from reflection back to agent to break the cycle.
460
+ self.graph_app = wf.compile()
461
+ logger.info("ClinicalAgent ready")
462
+
463
+ def invoke_turn(self, state: Dict[str, Any]) -> Dict[str, Any]:
464
+ try:
465
+ # Increase recursion limit if needed.
466
+ result = self.graph_app.invoke(state, {"recursion_limit": 100})
467
+ result.setdefault("summary", state.get("summary"))
468
+ result.setdefault("interaction_warnings", None)
469
+ return result
470
+ except Exception as e:
471
+ logger.exception("Error during graph invocation")
472
+ return {
473
+ "messages": state.get("messages", []) + [AIMessage(content=f"Error: {e}")],
474
+ "patient_data": state.get("patient_data"),
475
+ "summary": state.get("summary"),
476
+ "interaction_warnings": None
477
  }