Spaces:
Runtime error
Runtime error
Update agent.py
Browse files
agent.py
CHANGED
@@ -53,8 +53,8 @@ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
|
|
53 |
from huggingface_hub import login
|
54 |
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, BitsAndBytesConfig
|
55 |
from langchain_huggingface import HuggingFaceEndpoint
|
56 |
-
from langchain.agents import initialize_agent
|
57 |
-
from langchain.agents import AgentType
|
58 |
from typing import Union, List
|
59 |
from functools import reduce
|
60 |
import operator
|
@@ -446,13 +446,15 @@ llm = HuggingFaceEndpoint(
|
|
446 |
max_new_tokens=512
|
447 |
)
|
448 |
|
|
|
|
|
449 |
# Initialize LangChain agent
|
450 |
-
agent = initialize_agent(
|
451 |
-
tools=tools,
|
452 |
-
llm=llm,
|
453 |
-
agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
|
454 |
-
verbose=True
|
455 |
-
)
|
456 |
|
457 |
|
458 |
|
@@ -963,8 +965,14 @@ model_config = {
|
|
963 |
"huggingfacehub_api_token": os.getenv("HF_TOKEN")
|
964 |
}
|
965 |
|
|
|
966 |
def build_graph(provider, model_config):
|
967 |
-
|
|
|
|
|
|
|
|
|
|
|
968 |
def get_llm(provider: str, config: dict):
|
969 |
if provider == "huggingface":
|
970 |
from langchain_huggingface import HuggingFaceEndpoint
|
@@ -979,53 +987,31 @@ def build_graph(provider, model_config):
|
|
979 |
raise ValueError(f"Unsupported provider: {provider}")
|
980 |
|
981 |
llm = get_llm(provider, model_config)
|
982 |
-
|
983 |
-
#
|
984 |
-
# Step 6: Define LangChain Tools
|
985 |
-
# -------------------------------
|
986 |
-
calc_tool = calculator # Math operations tool
|
987 |
-
web_tool = web_search # Web search tool
|
988 |
-
wiki_tool = wiki_search # Wikipedia search tool
|
989 |
-
arxiv_tool = arxiv_search # Arxiv search tool
|
990 |
-
youtube_tool = get_youtube_transcript # YouTube transcript extraction
|
991 |
-
video_tool = extract_video_id # Video ID extraction tool
|
992 |
-
analyze_tool = analyze_attachment # File analysis tool
|
993 |
-
wikiq_tool = wikidata_query # Wikidata query tool
|
994 |
-
|
995 |
-
# -------------------------------
|
996 |
-
# Step 7: Create the Planner-Agent Logic
|
997 |
-
# -------------------------------
|
998 |
-
# Define tools list
|
999 |
tools = [
|
1000 |
-
|
1001 |
-
|
1002 |
-
|
1003 |
-
|
1004 |
-
|
1005 |
-
|
1006 |
-
|
1007 |
-
|
1008 |
]
|
1009 |
-
|
1010 |
-
# Step 8: Bind tools to the LLM
|
1011 |
-
llm_with_tools = llm.bind_tools(tools)
|
1012 |
-
|
1013 |
-
# Return the LLM with tools bound
|
1014 |
-
return llm_with_tools
|
1015 |
-
|
1016 |
|
|
|
|
|
1017 |
|
1018 |
-
|
1019 |
-
# Initialize system message
|
1020 |
sys_msg = SystemMessage(content="You are a helpful assistant.")
|
1021 |
-
|
1022 |
-
|
1023 |
-
def retriever(state: MessagesState):
|
1024 |
user_query = state["messages"][0].content
|
1025 |
similar_docs = vector_store.similarity_search(user_query)
|
1026 |
-
|
1027 |
if not similar_docs:
|
1028 |
-
wiki_result =
|
1029 |
return {
|
1030 |
"messages": [
|
1031 |
sys_msg,
|
@@ -1041,34 +1027,36 @@ def build_graph(provider, model_config):
|
|
1041 |
HumanMessage(content=f"Reference:\n\n{similar_docs[0].page_content}")
|
1042 |
]
|
1043 |
}
|
1044 |
-
|
1045 |
-
|
1046 |
-
def assistant(state: MessagesState):
|
1047 |
return {"messages": [llm_with_tools.invoke(state["messages"])]}
|
1048 |
-
|
1049 |
-
|
1050 |
-
def tools_condition(state: MessagesState) -> str:
|
1051 |
if "use tool" in state["messages"][-1].content.lower():
|
1052 |
return "tools"
|
1053 |
else:
|
1054 |
return "END"
|
1055 |
-
|
1056 |
-
#
|
1057 |
-
builder = StateGraph(
|
1058 |
-
|
1059 |
-
# Add nodes to the graph
|
1060 |
builder.add_node("retriever", retriever)
|
1061 |
builder.add_node("assistant", assistant)
|
1062 |
builder.add_node("tools", ToolNode(tools))
|
1063 |
-
|
1064 |
-
# Set the entry point
|
1065 |
builder.set_entry_point("retriever")
|
1066 |
-
|
1067 |
-
# Define edges
|
1068 |
builder.add_edge("retriever", "assistant")
|
1069 |
builder.add_conditional_edges("assistant", tools_condition)
|
1070 |
builder.add_edge("tools", "assistant")
|
1071 |
|
1072 |
-
|
1073 |
-
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1074 |
|
|
|
53 |
from huggingface_hub import login
|
54 |
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, BitsAndBytesConfig
|
55 |
from langchain_huggingface import HuggingFaceEndpoint
|
56 |
+
#from langchain.agents import initialize_agent
|
57 |
+
#from langchain.agents import AgentType
|
58 |
from typing import Union, List
|
59 |
from functools import reduce
|
60 |
import operator
|
|
|
446 |
max_new_tokens=512
|
447 |
)
|
448 |
|
449 |
+
|
450 |
+
# No longer required as Langgraph is replacing Langchain
|
451 |
# Initialize LangChain agent
|
452 |
+
#agent = initialize_agent(
|
453 |
+
# tools=tools,
|
454 |
+
# llm=llm,
|
455 |
+
# agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
|
456 |
+
# verbose=True
|
457 |
+
#)
|
458 |
|
459 |
|
460 |
|
|
|
965 |
"huggingfacehub_api_token": os.getenv("HF_TOKEN")
|
966 |
}
|
967 |
|
968 |
+
|
969 |
def build_graph(provider, model_config):
|
970 |
+
from langchain_core.messages import SystemMessage, HumanMessage
|
971 |
+
from langgraph.graph import StateGraph, ToolNode
|
972 |
+
from langchain_core.runnables import RunnableLambda
|
973 |
+
from some_module import vector_store # Make sure this is defined/imported
|
974 |
+
|
975 |
+
# Step 1: Get LLM
|
976 |
def get_llm(provider: str, config: dict):
|
977 |
if provider == "huggingface":
|
978 |
from langchain_huggingface import HuggingFaceEndpoint
|
|
|
987 |
raise ValueError(f"Unsupported provider: {provider}")
|
988 |
|
989 |
llm = get_llm(provider, model_config)
|
990 |
+
|
991 |
+
# Step 2: Define tools
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
992 |
tools = [
|
993 |
+
wiki_search,
|
994 |
+
calculator,
|
995 |
+
web_search,
|
996 |
+
arxiv_search,
|
997 |
+
get_youtube_transcript,
|
998 |
+
extract_video_id,
|
999 |
+
analyze_attachment,
|
1000 |
+
wikidata_query
|
1001 |
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1002 |
|
1003 |
+
# Step 3: Bind tools to LLM
|
1004 |
+
llm_with_tools = llm.bind_tools(tools)
|
1005 |
|
1006 |
+
# Step 4: Build stateful graph logic
|
|
|
1007 |
sys_msg = SystemMessage(content="You are a helpful assistant.")
|
1008 |
+
|
1009 |
+
def retriever(state: dict):
|
|
|
1010 |
user_query = state["messages"][0].content
|
1011 |
similar_docs = vector_store.similarity_search(user_query)
|
1012 |
+
|
1013 |
if not similar_docs:
|
1014 |
+
wiki_result = wiki_search.run(user_query)
|
1015 |
return {
|
1016 |
"messages": [
|
1017 |
sys_msg,
|
|
|
1027 |
HumanMessage(content=f"Reference:\n\n{similar_docs[0].page_content}")
|
1028 |
]
|
1029 |
}
|
1030 |
+
|
1031 |
+
def assistant(state: dict):
|
|
|
1032 |
return {"messages": [llm_with_tools.invoke(state["messages"])]}
|
1033 |
+
|
1034 |
+
def tools_condition(state: dict) -> str:
|
|
|
1035 |
if "use tool" in state["messages"][-1].content.lower():
|
1036 |
return "tools"
|
1037 |
else:
|
1038 |
return "END"
|
1039 |
+
|
1040 |
+
# Step 5: Define LangGraph StateGraph
|
1041 |
+
builder = StateGraph(dict) # Using dict as state type here
|
1042 |
+
|
|
|
1043 |
builder.add_node("retriever", retriever)
|
1044 |
builder.add_node("assistant", assistant)
|
1045 |
builder.add_node("tools", ToolNode(tools))
|
1046 |
+
|
|
|
1047 |
builder.set_entry_point("retriever")
|
|
|
|
|
1048 |
builder.add_edge("retriever", "assistant")
|
1049 |
builder.add_conditional_edges("assistant", tools_condition)
|
1050 |
builder.add_edge("tools", "assistant")
|
1051 |
|
1052 |
+
graph = builder.compile()
|
1053 |
+
return graph
|
1054 |
+
|
1055 |
+
|
1056 |
+
# call build_graph AFTER itβs defined
|
1057 |
+
agent = build_graph(provider, model_config)
|
1058 |
+
|
1059 |
+
# Now you can use the agent like this:
|
1060 |
+
result = agent.invoke({"messages": [HumanMessage(content=question)]})
|
1061 |
+
|
1062 |
|