wt002 commited on
Commit
ef3d9bb
·
verified ·
1 Parent(s): 2d14e5a

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +56 -36
agent.py CHANGED
@@ -145,7 +145,7 @@ def wiki_search(query: str) -> str:
145
  f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
146
  for doc in search_docs
147
  ])
148
- return {"wiki_results": formatted_search_docs}
149
 
150
 
151
 
@@ -175,7 +175,7 @@ def web_search(query: str) -> str:
175
  f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
176
  for doc in search_docs
177
  ])
178
- return {"web_results": formatted_search_docs}
179
 
180
  @tool
181
  def arvix_search(query: str) -> str:
@@ -189,7 +189,7 @@ def arvix_search(query: str) -> str:
189
  f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content[:1000]}\n</Document>'
190
  for doc in search_docs
191
  ])
192
- return {"arvix_results": formatted_search_docs}
193
 
194
 
195
 
@@ -414,26 +414,6 @@ question_retriever_tool = create_retriever_tool(
414
  )
415
 
416
 
417
- # -------------------------------
418
- # Step 6: Create LangChain Tools
419
- # -------------------------------
420
- calc_tool = calculator
421
- file_tool = analyze_attachment
422
- web_tool = web_search
423
- wiki_tool = wiki_search
424
- arvix_tool = arvix_search
425
- youtube_tool = get_youtube_transcript
426
- video_tool = extract_video_id
427
- analyze_tool = analyze_attachment
428
- wikiq_tool = wikidata_query
429
-
430
-
431
- # -------------------------------
432
- # Step 7: Create the Planner-Agent Logic
433
- # -------------------------------
434
-
435
- # Define the tools (as you've already done)
436
- tools = [wiki_tool, calc_tool, file_tool, web_tool, arvix_tool, youtube_tool, video_tool, analyze_tool, wikiq_tool]
437
 
438
  # Define the LLM before using it
439
  #llm = ChatOpenAI(temperature=0, model="gpt-3.5-turbo") # or "gpt-3.5-turbo" "gpt-4"
@@ -693,8 +673,7 @@ model_config = {
693
  }
694
 
695
  def build_graph(provider, model_config):
696
- # from langgraph.prebuilt.tool_node import ToolNode
697
-
698
  def get_llm(provider: str, config: dict):
699
  if provider == "huggingface":
700
  from langchain_huggingface import HuggingFaceEndpoint
@@ -707,24 +686,53 @@ def build_graph(provider, model_config):
707
  )
708
  else:
709
  raise ValueError(f"Unsupported provider: {provider}")
710
-
711
-
712
  llm = get_llm(provider, model_config)
713
- return llm
714
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
715
  llm_with_tools = llm.bind_tools(tools)
716
 
 
 
717
 
718
 
719
- sys_msg = SystemMessage(content="You are a helpful assistant.")
720
-
721
- def assistant(state: MessagesState):
722
- return {"messages": [llm_with_tools.invoke(state["messages"])]}
723
 
 
 
 
 
 
724
  def retriever(state: MessagesState):
725
  user_query = state["messages"][0].content
726
  similar_docs = vector_store.similarity_search(user_query)
727
-
728
  if not similar_docs:
729
  wiki_result = wiki_tool.run(user_query)
730
  return {
@@ -742,18 +750,30 @@ def build_graph(provider, model_config):
742
  HumanMessage(content=f"Reference:\n\n{similar_docs[0].page_content}")
743
  ]
744
  }
745
-
 
 
 
 
 
746
  def tools_condition(state: MessagesState) -> str:
747
  if "use tool" in state["messages"][-1].content.lower():
748
  return "tools"
749
  else:
750
- return END
751
-
 
752
  builder = StateGraph(MessagesState)
 
 
753
  builder.add_node("retriever", retriever)
754
  builder.add_node("assistant", assistant)
755
  builder.add_node("tools", ToolNode(tools))
 
 
756
  builder.set_entry_point("retriever")
 
 
757
  builder.add_edge("retriever", "assistant")
758
  builder.add_conditional_edges("assistant", tools_condition)
759
  builder.add_edge("tools", "assistant")
 
145
  f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
146
  for doc in search_docs
147
  ])
148
+ return formatted_search_docs
149
 
150
 
151
 
 
175
  f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
176
  for doc in search_docs
177
  ])
178
+ return formatted_search_docs
179
 
180
  @tool
181
  def arvix_search(query: str) -> str:
 
189
  f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content[:1000]}\n</Document>'
190
  for doc in search_docs
191
  ])
192
+ return formatted_search_docs
193
 
194
 
195
 
 
414
  )
415
 
416
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
417
 
418
  # Define the LLM before using it
419
  #llm = ChatOpenAI(temperature=0, model="gpt-3.5-turbo") # or "gpt-3.5-turbo" "gpt-4"
 
673
  }
674
 
675
  def build_graph(provider, model_config):
676
+ # Step 1: Initialize the LLM
 
677
  def get_llm(provider: str, config: dict):
678
  if provider == "huggingface":
679
  from langchain_huggingface import HuggingFaceEndpoint
 
686
  )
687
  else:
688
  raise ValueError(f"Unsupported provider: {provider}")
689
+
 
690
  llm = get_llm(provider, model_config)
 
691
 
692
+ # -------------------------------
693
+ # Step 6: Define LangChain Tools
694
+ # -------------------------------
695
+ calc_tool = calculator # Math operations tool
696
+ web_tool = web_search # Web search tool
697
+ wiki_tool = wiki_search # Wikipedia search tool
698
+ arvix_tool = arvix_search # Arxiv search tool
699
+ youtube_tool = get_youtube_transcript # YouTube transcript extraction
700
+ video_tool = extract_video_id # Video ID extraction tool
701
+ analyze_tool = analyze_attachment # File analysis tool
702
+ wikiq_tool = wikidata_query # Wikidata query tool
703
+
704
+ # -------------------------------
705
+ # Step 7: Create the Planner-Agent Logic
706
+ # -------------------------------
707
+ # Define tools list
708
+ tools = [
709
+ wiki_tool,
710
+ calc_tool,
711
+ web_tool,
712
+ arvix_tool,
713
+ youtube_tool,
714
+ video_tool,
715
+ analyze_tool,
716
+ wikiq_tool
717
+ ]
718
+
719
+ # Step 8: Bind tools to the LLM
720
  llm_with_tools = llm.bind_tools(tools)
721
 
722
+ # Return the LLM with tools bound
723
+ return llm_with_tools
724
 
725
 
 
 
 
 
726
 
727
+
728
+ # Initialize system message
729
+ sys_msg = SystemMessage(content="You are a helpful assistant.")
730
+
731
+ # Define the retriever function
732
  def retriever(state: MessagesState):
733
  user_query = state["messages"][0].content
734
  similar_docs = vector_store.similarity_search(user_query)
735
+
736
  if not similar_docs:
737
  wiki_result = wiki_tool.run(user_query)
738
  return {
 
750
  HumanMessage(content=f"Reference:\n\n{similar_docs[0].page_content}")
751
  ]
752
  }
753
+
754
+ # Define the assistant function
755
+ def assistant(state: MessagesState):
756
+ return {"messages": [llm_with_tools.invoke(state["messages"])]}
757
+
758
+ # Define condition for tools usage
759
  def tools_condition(state: MessagesState) -> str:
760
  if "use tool" in state["messages"][-1].content.lower():
761
  return "tools"
762
  else:
763
+ return "END"
764
+
765
+ # Initialize the StateGraph
766
  builder = StateGraph(MessagesState)
767
+
768
+ # Add nodes to the graph
769
  builder.add_node("retriever", retriever)
770
  builder.add_node("assistant", assistant)
771
  builder.add_node("tools", ToolNode(tools))
772
+
773
+ # Set the entry point
774
  builder.set_entry_point("retriever")
775
+
776
+ # Define edges
777
  builder.add_edge("retriever", "assistant")
778
  builder.add_conditional_edges("assistant", tools_condition)
779
  builder.add_edge("tools", "assistant")