wt002 committed · verified
Commit e35d057 · 1 Parent(s): b22ce48

Update agent.py

Files changed (1):
  1. agent.py +53 -65
agent.py CHANGED
@@ -53,8 +53,8 @@ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
  from huggingface_hub import login
  from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, BitsAndBytesConfig
  from langchain_huggingface import HuggingFaceEndpoint
- from langchain.agents import initialize_agent
- from langchain.agents import AgentType
+ #from langchain.agents import initialize_agent
+ #from langchain.agents import AgentType
  from typing import Union, List
  from functools import reduce
  import operator
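Note on this hunk: initialize_agent and AgentType are deprecated in recent LangChain releases (the suggested migration target is LangGraph), so the two commented-out imports could be dropped entirely rather than kept as dead lines, leaving the import block as:

from huggingface_hub import login
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, BitsAndBytesConfig
from langchain_huggingface import HuggingFaceEndpoint
from typing import Union, List
from functools import reduce
import operator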
@@ -446,13 +446,15 @@ llm = HuggingFaceEndpoint(
      max_new_tokens=512
  )

+
+ # No longer required as Langgraph is replacing Langchain
  # Initialize LangChain agent
- agent = initialize_agent(
-     tools=tools,
-     llm=llm,
-     agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
-     verbose=True
- )
+ #agent = initialize_agent(
+ # tools=tools,
+ # llm=llm,
+ # agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
+ # verbose=True
+ #)

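For comparison, the removed initialize_agent call maps fairly directly onto LangGraph's prebuilt ReAct agent. The following is a minimal sketch, not part of this commit, assuming llm is a chat model with tool-calling support and tools is the list defined later in build_graph:

from langgraph.prebuilt import create_react_agent

# Rough functional equivalent of AgentType.ZERO_SHOT_REACT_DESCRIPTION:
# the model reasons, decides whether to call a tool, and loops until it answers.
react_agent = create_react_agent(llm, tools)

result = react_agent.invoke({"messages": [("user", "What is 2 + 2?")]})
print(result["messages"][-1].content)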
@@ -963,8 +965,14 @@ model_config = {
      "huggingfacehub_api_token": os.getenv("HF_TOKEN")
  }

+
  def build_graph(provider, model_config):
-     # Step 1: Initialize the LLM
+     from langchain_core.messages import SystemMessage, HumanMessage
+     from langgraph.graph import StateGraph, ToolNode
+     from langchain_core.runnables import RunnableLambda
+     from some_module import vector_store  # Make sure this is defined/imported
+
+     # Step 1: Get LLM
      def get_llm(provider: str, config: dict):
          if provider == "huggingface":
              from langchain_huggingface import HuggingFaceEndpoint
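A few caveats on the new import block inside build_graph: in the langgraph releases I know of, ToolNode is exported from langgraph.prebuilt rather than langgraph.graph, RunnableLambda is imported but never used, and from some_module import vector_store is a placeholder that still needs a real backing store. A hedged sketch of how the imports might look, with a Chroma collection as a purely illustrative stand-in for vector_store:

from langchain_core.messages import SystemMessage, HumanMessage
from langgraph.graph import StateGraph, END       # graph builder and END sentinel
from langgraph.prebuilt import ToolNode           # ToolNode lives in langgraph.prebuilt

# Illustrative placeholder only: any vector store exposing similarity_search() works here.
from langchain_chroma import Chroma
from langchain_huggingface import HuggingFaceEmbeddings

vector_store = Chroma(
    collection_name="agent_docs",
    embedding_function=HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2"),
)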
@@ -979,53 +987,31 @@ def build_graph(provider, model_config):
          raise ValueError(f"Unsupported provider: {provider}")

      llm = get_llm(provider, model_config)
-
-     # -------------------------------
-     # Step 6: Define LangChain Tools
-     # -------------------------------
-     calc_tool = calculator  # Math operations tool
-     web_tool = web_search  # Web search tool
-     wiki_tool = wiki_search  # Wikipedia search tool
-     arxiv_tool = arxiv_search  # Arxiv search tool
-     youtube_tool = get_youtube_transcript  # YouTube transcript extraction
-     video_tool = extract_video_id  # Video ID extraction tool
-     analyze_tool = analyze_attachment  # File analysis tool
-     wikiq_tool = wikidata_query  # Wikidata query tool
-
-     # -------------------------------
-     # Step 7: Create the Planner-Agent Logic
-     # -------------------------------
-     # Define tools list
+
+     # Step 2: Define tools
      tools = [
-         wiki_tool,
-         calc_tool,
-         web_tool,
-         arxiv_tool,
-         youtube_tool,
-         video_tool,
-         analyze_tool,
-         wikiq_tool
+         wiki_search,
+         calculator,
+         web_search,
+         arxiv_search,
+         get_youtube_transcript,
+         extract_video_id,
+         analyze_attachment,
+         wikidata_query
      ]
-
-     # Step 8: Bind tools to the LLM
-     llm_with_tools = llm.bind_tools(tools)
-
-     # Return the LLM with tools bound
-     return llm_with_tools
-

+     # Step 3: Bind tools to LLM
+     llm_with_tools = llm.bind_tools(tools)

-
-     # Initialize system message
+     # Step 4: Build stateful graph logic
      sys_msg = SystemMessage(content="You are a helpful assistant.")
-
-     # Define the retriever function
-     def retriever(state: MessagesState):
+
+     def retriever(state: dict):
          user_query = state["messages"][0].content
          similar_docs = vector_store.similarity_search(user_query)
-
+
          if not similar_docs:
-             wiki_result = wiki_tool.run(user_query)
+             wiki_result = wiki_search.run(user_query)
              return {
                  "messages": [
                      sys_msg,
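One thing to watch in Step 3: HuggingFaceEndpoint is a plain text-completion LLM and, as far as I know, does not implement bind_tools; the tool-calling interface is provided by the ChatHuggingFace wrapper around it. A hedged sketch of the adjustment (the repo_id value is a placeholder, since it is not shown in this diff):

import os

from langchain_huggingface import HuggingFaceEndpoint, ChatHuggingFace

base_llm = HuggingFaceEndpoint(
    repo_id="...",                                 # placeholder; configured elsewhere in agent.py
    max_new_tokens=512,
    huggingfacehub_api_token=os.getenv("HF_TOKEN"),
)
chat_llm = ChatHuggingFace(llm=base_llm)           # chat wrapper that supports .bind_tools()
llm_with_tools = chat_llm.bind_tools(tools)        # tools: the list defined in Step 2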
@@ -1041,34 +1027,36 @@ def build_graph(provider, model_config):
                      HumanMessage(content=f"Reference:\n\n{similar_docs[0].page_content}")
                  ]
              }
-
-     # Define the assistant function
-     def assistant(state: MessagesState):
+
+     def assistant(state: dict):
          return {"messages": [llm_with_tools.invoke(state["messages"])]}
-
-     # Define condition for tools usage
-     def tools_condition(state: MessagesState) -> str:
+
+     def tools_condition(state: dict) -> str:
          if "use tool" in state["messages"][-1].content.lower():
              return "tools"
          else:
              return "END"
-
-     # Initialize the StateGraph
-     builder = StateGraph(MessagesState)
-
-     # Add nodes to the graph
+
+     # Step 5: Define LangGraph StateGraph
+     builder = StateGraph(dict)  # Using dict as state type here
+
      builder.add_node("retriever", retriever)
      builder.add_node("assistant", assistant)
      builder.add_node("tools", ToolNode(tools))
-
-     # Set the entry point
+
      builder.set_entry_point("retriever")
-
-     # Define edges
      builder.add_edge("retriever", "assistant")
      builder.add_conditional_edges("assistant", tools_condition)
      builder.add_edge("tools", "assistant")

-     # Compile graph
-     return builder.compile()
+     graph = builder.compile()
+     return graph
+
+
+ # call build_graph AFTER it’s defined
+ agent = build_graph(provider, model_config)
+
+ # Now you can use the agent like this:
+ result = agent.invoke({"messages": [HumanMessage(content=question)]})
+

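Two notes on this last hunk. StateGraph(dict) may be rejected at construction time, since StateGraph expects a typed state schema; the MessagesState that the removed code used is the usual choice. And the custom tools_condition routes on the literal phrase "use tool" and returns the string "END", whereas the common pattern (also shipped as langgraph.prebuilt.tools_condition) routes on whether the last AI message actually requested a tool call and maps branch names to nodes explicitly. A hedged sketch, reusing the builder and node names from the hunk above:

from langgraph.graph import END

def route_after_assistant(state: dict) -> str:
    # Send the conversation to the tools node only if the model emitted tool calls.
    last_message = state["messages"][-1]
    return "tools" if getattr(last_message, "tool_calls", None) else END

builder.add_conditional_edges(
    "assistant",
    route_after_assistant,
    {"tools": "tools", END: END},   # explicit branch-name -> node mapping
)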
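Finally, the new tail of the file calls build_graph and agent.invoke at import time with a question variable that this diff never defines, so importing agent.py would raise a NameError as written. A hedged sketch of a guarded entry point (the provider string and sample question are placeholders):

if __name__ == "__main__":
    agent = build_graph("huggingface", model_config)

    question = "What is the capital of France?"    # example input; `question` is not defined in the diff
    result = agent.invoke({"messages": [HumanMessage(content=question)]})
    print(result["messages"][-1].content)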