mgbam commited on
Commit
5158329
·
verified ·
1 Parent(s): bf46e9b

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +7 -4
agent.py CHANGED
@@ -23,7 +23,6 @@ logging.basicConfig(level=logging.INFO)
23
  UMLS_API_KEY = os.getenv("UMLS_API_KEY")
24
  GROQ_API_KEY = os.getenv("GROQ_API_KEY")
25
  TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")
26
-
27
  if not all([UMLS_API_KEY, GROQ_API_KEY, TAVILY_API_KEY]):
28
  logger.error("Missing one or more required API keys: UMLS_API_KEY, GROQ_API_KEY, TAVILY_API_KEY")
29
  raise RuntimeError("Missing required API keys")
@@ -43,8 +42,8 @@ class ClinicalPrompts:
43
  def wrap_message(msg: Any) -> AIMessage:
44
  """
45
  Ensures the given message is an AIMessage.
46
- If it is a dict, it extracts the 'content' field (or serializes the dict).
47
- Otherwise, it converts the message to a string.
48
  """
49
  if isinstance(msg, AIMessage):
50
  return msg
@@ -358,6 +357,7 @@ def tool_node(state: AgentState) -> Dict[str, Any]:
358
  new_state = {"messages": []}
359
  return propagate_state(new_state, state)
360
  last = wrap_message(messages_list[-1])
 
361
  tool_calls = last.__dict__.get("tool_calls")
362
  if not (isinstance(last, AIMessage) and tool_calls):
363
  logger.warning("tool_node invoked without pending tool_calls")
@@ -463,6 +463,7 @@ def should_continue(state: AgentState) -> str:
463
  state["done"] = True
464
  return "end_conversation_turn"
465
  state["done"] = False
 
466
  return "start"
467
 
468
  def after_tools_router(state: AgentState) -> str:
@@ -479,15 +480,17 @@ class ClinicalAgent:
479
  wf.add_node("tools", tool_node)
480
  wf.add_node("reflection", reflection_node)
481
  wf.set_entry_point("start")
 
482
  wf.add_conditional_edges("start", should_continue, {
483
  "continue_tools": "tools",
 
484
  "end_conversation_turn": END
485
  })
486
  wf.add_conditional_edges("tools", after_tools_router, {
487
  "reflection": "reflection",
488
  "end_conversation_turn": END
489
  })
490
- # Removed edge from reflection back to start.
491
  self.graph_app = wf.compile()
492
  logger.info("ClinicalAgent ready")
493
 
 
23
  UMLS_API_KEY = os.getenv("UMLS_API_KEY")
24
  GROQ_API_KEY = os.getenv("GROQ_API_KEY")
25
  TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")
 
26
  if not all([UMLS_API_KEY, GROQ_API_KEY, TAVILY_API_KEY]):
27
  logger.error("Missing one or more required API keys: UMLS_API_KEY, GROQ_API_KEY, TAVILY_API_KEY")
28
  raise RuntimeError("Missing required API keys")
 
42
  def wrap_message(msg: Any) -> AIMessage:
43
  """
44
  Ensures the given message is an AIMessage.
45
+ If it is a dict, extracts the 'content' field (or serializes the dict).
46
+ Otherwise, converts the message to a string.
47
  """
48
  if isinstance(msg, AIMessage):
49
  return msg
 
357
  new_state = {"messages": []}
358
  return propagate_state(new_state, state)
359
  last = wrap_message(messages_list[-1])
360
+ # Safely retrieve pending tool_calls from the message's __dict__
361
  tool_calls = last.__dict__.get("tool_calls")
362
  if not (isinstance(last, AIMessage) and tool_calls):
363
  logger.warning("tool_node invoked without pending tool_calls")
 
463
  state["done"] = True
464
  return "end_conversation_turn"
465
  state["done"] = False
466
+ # Return "start" to loop back.
467
  return "start"
468
 
469
  def after_tools_router(state: AgentState) -> str:
 
480
  wf.add_node("tools", tool_node)
481
  wf.add_node("reflection", reflection_node)
482
  wf.set_entry_point("start")
483
+ # Note: Added a "start" branch in the conditional edges.
484
  wf.add_conditional_edges("start", should_continue, {
485
  "continue_tools": "tools",
486
+ "start": "start",
487
  "end_conversation_turn": END
488
  })
489
  wf.add_conditional_edges("tools", after_tools_router, {
490
  "reflection": "reflection",
491
  "end_conversation_turn": END
492
  })
493
+ # Removed edge from reflection back to start to break the cycle.
494
  self.graph_app = wf.compile()
495
  logger.info("ClinicalAgent ready")
496