wt002 committed (verified)
Commit dae11a5 · Parent(s): f75a052

Update agent.py

Files changed (1)
  1. agent.py +18 -21
agent.py CHANGED
@@ -1094,24 +1094,29 @@ def tools_condition(state: dict) -> str:
 from langgraph.graph import StateGraph
 from langchain_core.messages import SystemMessage
 from langchain_core.runnables import RunnableLambda
-
 def build_graph(vector_store, provider: str, model_config: dict) -> StateGraph:
+    # Get LLM
     llm = get_llm(provider, model_config)
 
-    tools = [wiki_search, calculator, web_search, arxiv_search,
-             get_youtube_transcript, extract_video_id, analyze_attachment, wikidata_query]
+    # Define available tools
+    tools = [
+        wiki_search, calculator, web_search, arxiv_search,
+        get_youtube_transcript, extract_video_id, analyze_attachment, wikidata_query
+    ]
 
+    # Tool mapping (global if needed elsewhere)
     global tool_map
     tool_map = {t.name: t for t in tools}
 
+    # Bind tools only if LLM supports it
     if hasattr(llm, "bind_tools"):
         llm_with_tools = llm.bind_tools(tools)
     else:
-        llm_with_tools = llm  # fallback: no tool binding
+        llm_with_tools = llm  # fallback for non-tool-aware models
 
-
     sys_msg = SystemMessage(content="You are a helpful assistant.")
 
+    # Define nodes as runnables
     retriever = RunnableLambda(lambda state: {
         **state,
         "retrieved_docs": vector_store.similarity_search(state["input"])
@@ -1122,35 +1127,27 @@ def build_graph(vector_store, provider: str, model_config: dict) -> StateGraph:
         "messages": [sys_msg] + state["messages"]
     })
 
-    call_llm = llm_with_tools
+    call_llm = llm_with_tools  # already configured
 
+    # Start building the graph
     builder = StateGraph(AgentState)
     builder.add_node("retriever", retriever)
    builder.add_node("assistant", assistant)
     builder.add_node("call_llm", call_llm)
     builder.add_node("call_tool", tool_dispatcher)
+    builder.add_node("end", lambda state: state)  # Add explicit end node
 
+    # Define graph flow
     builder.set_entry_point("retriever")
     builder.add_edge("retriever", "assistant")
     builder.add_edge("assistant", "call_llm")
-    builder.add_node("end", lambda state: state)  # Add a terminal node
-
+
     builder.add_conditional_edges("call_llm", should_call_tool, {
         "call_tool": "call_tool",
-        "end": "end"
+        "end": "end"  # ✅ fixed: must point to actual "end" node
     })
 
-    builder.add_edge("call_tool", "call_llm")
+    builder.add_edge("call_tool", "call_llm")  # loop back after tool call
 
     return builder.compile()
-
-
-
-# call build_graph AFTER it's defined
-graph = build_graph(vector_store, provider, model_config)
-agent = graph
-
-# Now you can use the agent like this:
-result = agent.invoke({"messages": [HumanMessage(content=question)]})
-
-
+
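
Note: the deleted module-level lines (old 1149-1154) were the only example in this hunk of driving the compiled graph. For reference, a minimal sketch of the same invocation after this change, reusing the names from the removed code (vector_store, provider, model_config, question, HumanMessage); passing an "input" key alongside "messages" is an assumption based on the retriever node reading state["input"], not something this commit shows.

    from langchain_core.messages import HumanMessage

    # Build and compile the graph once its dependencies exist
    graph = build_graph(vector_store, provider, model_config)

    # Assumption: the "retriever" entry node calls similarity_search(state["input"]),
    # so the initial state carries both "input" and "messages"
    result = graph.invoke({
        "input": question,
        "messages": [HumanMessage(content=question)],
    })

    # Assumes AgentState keeps an accumulated "messages" list;
    # the last entry is the final assistant reply
    final_answer = result["messages"][-1].content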