wt002 commited on
Commit
a28328c
·
verified ·
1 Parent(s): 94dcdd5

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +17 -16
agent.py CHANGED
@@ -685,28 +685,28 @@ def process_question(question: str):
685
 
686
 
687
  # Build graph function
688
- def build_graph():
689
- """Build the graph based on provider"""
 
690
  llm = get_llm(provider, model_config)
691
  llm_with_tools = llm.bind_tools(tools)
692
 
693
- # Node
 
694
  def assistant(state: MessagesState):
695
- """Assistant node"""
696
  return {"messages": [llm_with_tools.invoke(state["messages"])]}
697
-
698
  def retriever(state: MessagesState):
699
  user_query = state["messages"][0].content
700
  similar_docs = vector_store.similarity_search(user_query)
701
-
702
  if not similar_docs:
703
- print("No similar docs found in FAISS. Using wiki_search.")
704
- wiki_result = wiki_search.invoke(user_query)
705
  return {
706
  "messages": [
707
  sys_msg,
708
  state["messages"][0],
709
- HumanMessage(content=f"Using Wikipedia search:\n\n{wiki_result['wiki_results']}")
710
  ]
711
  }
712
  else:
@@ -714,22 +714,23 @@ def build_graph():
714
  "messages": [
715
  sys_msg,
716
  state["messages"][0],
717
- HumanMessage(content=f"Reference question:\n\n{similar_docs[0].page_content}")
718
  ]
719
  }
720
 
 
 
 
 
 
721
 
722
-
723
  builder = StateGraph(MessagesState)
724
  builder.add_node("retriever", retriever)
725
  builder.add_node("assistant", assistant)
726
  builder.add_node("tools", ToolNode(tools))
727
- builder.add_edge(START, "retriever")
728
  builder.add_edge("retriever", "assistant")
729
- builder.add_conditional_edges(
730
- "assistant",
731
- tools_condition,
732
- )
733
  builder.add_edge("tools", "assistant")
734
 
735
  # Compile graph
 
685
 
686
 
687
  # Build graph function
688
+ def build_graph(provider: str, model_config: dict):
689
+ from langgraph.prebuilt.tool_node import ToolNode
690
+
691
  llm = get_llm(provider, model_config)
692
  llm_with_tools = llm.bind_tools(tools)
693
 
694
+ sys_msg = SystemMessage(content="You are a helpful assistant.")
695
+
696
  def assistant(state: MessagesState):
 
697
  return {"messages": [llm_with_tools.invoke(state["messages"])]}
698
+
699
  def retriever(state: MessagesState):
700
  user_query = state["messages"][0].content
701
  similar_docs = vector_store.similarity_search(user_query)
702
+
703
  if not similar_docs:
704
+ wiki_result = wiki_tool.run(user_query)
 
705
  return {
706
  "messages": [
707
  sys_msg,
708
  state["messages"][0],
709
+ HumanMessage(content=f"Using Wikipedia search:\n\n{wiki_result}")
710
  ]
711
  }
712
  else:
 
714
  "messages": [
715
  sys_msg,
716
  state["messages"][0],
717
+ HumanMessage(content=f"Reference:\n\n{similar_docs[0].page_content}")
718
  ]
719
  }
720
 
721
+ def tools_condition(state: MessagesState) -> str:
722
+ if "use tool" in state["messages"][-1].content.lower():
723
+ return "tools"
724
+ else:
725
+ return END
726
 
 
727
  builder = StateGraph(MessagesState)
728
  builder.add_node("retriever", retriever)
729
  builder.add_node("assistant", assistant)
730
  builder.add_node("tools", ToolNode(tools))
731
+ builder.set_entry_point("retriever")
732
  builder.add_edge("retriever", "assistant")
733
+ builder.add_conditional_edges("assistant", tools_condition)
 
 
 
734
  builder.add_edge("tools", "assistant")
735
 
736
  # Compile graph