Spaces:
Running
Running
Update agent.py
Browse files
agent.py
CHANGED
@@ -4,7 +4,7 @@ import os
|
|
4 |
from dotenv import load_dotenv
|
5 |
from langgraph.graph import START, StateGraph, MessagesState
|
6 |
from langgraph.prebuilt import tools_condition
|
7 |
-
from langgraph.prebuilt import ToolNode
|
8 |
from langchain_google_genai import ChatGoogleGenerativeAI
|
9 |
from langchain_groq import ChatGroq
|
10 |
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
|
@@ -839,44 +839,74 @@ def process_question(question: str):
|
|
839 |
|
840 |
|
841 |
|
|
|
842 |
def call_llm(state):
|
843 |
messages = state["messages"]
|
844 |
response = llm.invoke(messages)
|
845 |
return {"messages": messages + [response]}
|
846 |
|
847 |
-
|
|
|
|
|
|
|
|
|
848 |
|
|
|
|
|
849 |
|
850 |
-
from langgraph.graph import StateGraph
|
851 |
-
from typing import TypedDict
|
852 |
|
853 |
-
# Define the state schema
|
854 |
-
class AgentState(TypedDict):
|
855 |
-
input: str
|
856 |
-
result: str
|
857 |
|
858 |
-
|
859 |
-
builder = StateGraph(AgentState)
|
860 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
861 |
|
862 |
-
builder.add_node("call_llm", call_llm)
|
863 |
-
builder.add_node("call_tool", tool_node)
|
864 |
|
865 |
# Decide what to do next: if tool call → call_tool, else → end
|
866 |
-
def
|
867 |
last_msg = state["messages"][-1]
|
868 |
if isinstance(last_msg, AIMessage) and last_msg.tool_calls:
|
869 |
return "call_tool"
|
870 |
return "end"
|
871 |
|
872 |
-
|
873 |
-
|
874 |
-
|
875 |
-
|
876 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
877 |
|
878 |
-
# After tool runs, go back to the LLM
|
879 |
-
builder.add_edge("call_tool", "call_llm")
|
880 |
|
881 |
|
882 |
|
@@ -1009,7 +1039,7 @@ model_config = {
|
|
1009 |
|
1010 |
def build_graph(provider, model_config):
|
1011 |
from langchain_core.messages import SystemMessage, HumanMessage
|
1012 |
-
from langgraph.graph import StateGraph, ToolNode
|
1013 |
from langchain_core.runnables import RunnableLambda
|
1014 |
from some_module import vector_store # Make sure this is defined/imported
|
1015 |
|
@@ -1078,20 +1108,39 @@ def build_graph(provider, model_config):
|
|
1078 |
else:
|
1079 |
return "END"
|
1080 |
|
1081 |
-
# Step 5: Define LangGraph StateGraph
|
1082 |
-
builder = StateGraph(dict) # Using dict as state type here
|
1083 |
|
1084 |
-
builder.add_node("retriever", retriever)
|
1085 |
-
builder.add_node("assistant", assistant)
|
1086 |
-
builder.add_node("tools", ToolNode(tools))
|
1087 |
|
1088 |
-
|
1089 |
-
|
1090 |
-
|
1091 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1092 |
|
1093 |
-
graph = builder.compile()
|
1094 |
-
return graph
|
1095 |
|
1096 |
|
1097 |
# call build_graph AFTER it’s defined
|
|
|
4 |
from dotenv import load_dotenv
|
5 |
from langgraph.graph import START, StateGraph, MessagesState
|
6 |
from langgraph.prebuilt import tools_condition
|
7 |
+
#from langgraph.prebuilt import ToolNode
|
8 |
from langchain_google_genai import ChatGoogleGenerativeAI
|
9 |
from langchain_groq import ChatGroq
|
10 |
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
|
|
|
839 |
|
840 |
|
841 |
|
842 |
+
|
843 |
def call_llm(state):
|
844 |
messages = state["messages"]
|
845 |
response = llm.invoke(messages)
|
846 |
return {"messages": messages + [response]}
|
847 |
|
848 |
+
builder.set_entry_point("call_llm")
|
849 |
+
builder.add_conditional_edges("call_llm", should_call_tool, {
|
850 |
+
"call_tool": "call_tool",
|
851 |
+
"end": None
|
852 |
+
})
|
853 |
|
854 |
+
# After tool runs, go back to the LLM
|
855 |
+
builder.add_edge("call_tool", "call_llm")
|
856 |
|
|
|
|
|
857 |
|
|
|
|
|
|
|
|
|
858 |
|
859 |
+
from langchain.schema import AIMessage
|
|
|
860 |
|
861 |
+
def tool_dispatcher(state: AgentState) -> AgentState:
|
862 |
+
last_msg = state["messages"][-1]
|
863 |
+
|
864 |
+
# Make sure it's an AI message with tool_calls
|
865 |
+
if isinstance(last_msg, AIMessage) and last_msg.tool_calls:
|
866 |
+
tool_call = last_msg.tool_calls[0]
|
867 |
+
tool_name = tool_call["name"]
|
868 |
+
tool_input = tool_call["args"] # Adjust based on your actual schema
|
869 |
+
|
870 |
+
tool_func = tool_map.get(tool_name, default_tool)
|
871 |
+
|
872 |
+
# If args is a dict and your tool expects unpacked values:
|
873 |
+
if isinstance(tool_input, dict):
|
874 |
+
result = tool_func.invoke(tool_input) if hasattr(tool_func, "invoke") else tool_func(**tool_input)
|
875 |
+
else:
|
876 |
+
result = tool_func.invoke(tool_input) if hasattr(tool_func, "invoke") else tool_func(tool_input)
|
877 |
+
|
878 |
+
# You can choose to append this to messages, or just save result
|
879 |
+
return {
|
880 |
+
**state,
|
881 |
+
"result": result,
|
882 |
+
# Optionally add: "messages": state["messages"] + [ToolMessage(...)]
|
883 |
+
}
|
884 |
+
|
885 |
+
# No tool call detected, return state unchanged
|
886 |
+
return state
|
887 |
+
|
888 |
+
|
889 |
|
|
|
|
|
890 |
|
891 |
# Decide what to do next: if tool call → call_tool, else → end
|
892 |
+
def call_tool(state):
|
893 |
last_msg = state["messages"][-1]
|
894 |
if isinstance(last_msg, AIMessage) and last_msg.tool_calls:
|
895 |
return "call_tool"
|
896 |
return "end"
|
897 |
|
898 |
+
from typing import TypedDict, List, Optional, Union
|
899 |
+
from langchain.schema import BaseMessage
|
900 |
+
|
901 |
+
class AgentState(TypedDict):
|
902 |
+
messages: List[BaseMessage] # chat history
|
903 |
+
input: str # original input
|
904 |
+
intent: str # derived or predicted intent
|
905 |
+
result: Optional[str] # tool output, if any
|
906 |
+
|
907 |
+
builder.add_node("call_tool", tool_dispatcher)
|
908 |
+
|
909 |
|
|
|
|
|
910 |
|
911 |
|
912 |
|
|
|
1039 |
|
1040 |
def build_graph(provider, model_config):
|
1041 |
from langchain_core.messages import SystemMessage, HumanMessage
|
1042 |
+
#from langgraph.graph import StateGraph, ToolNode
|
1043 |
from langchain_core.runnables import RunnableLambda
|
1044 |
from some_module import vector_store # Make sure this is defined/imported
|
1045 |
|
|
|
1108 |
else:
|
1109 |
return "END"
|
1110 |
|
|
|
|
|
1111 |
|
|
|
|
|
|
|
1112 |
|
1113 |
+
from langgraph.graph import StateGraph
|
1114 |
+
|
1115 |
+
# Build graph using AgentState as the shared schema
|
1116 |
+
builder = StateGraph(AgentState)
|
1117 |
+
|
1118 |
+
# Add nodes
|
1119 |
+
builder.add_node("retriever", retriever)
|
1120 |
+
builder.add_node("assistant", assistant)
|
1121 |
+
builder.add_node("call_llm", call_llm)
|
1122 |
+
builder.add_node("call_tool", tool_dispatcher) # one name is enough
|
1123 |
+
|
1124 |
+
# Entry point
|
1125 |
+
builder.set_entry_point("retriever")
|
1126 |
+
|
1127 |
+
# Define the flow
|
1128 |
+
builder.add_edge("retriever", "assistant")
|
1129 |
+
builder.add_edge("assistant", "call_llm")
|
1130 |
+
|
1131 |
+
# Conditional edge from LLM to tool or end
|
1132 |
+
builder.add_conditional_edges("call_llm", should_call_tool, {
|
1133 |
+
"call_tool": "call_tool",
|
1134 |
+
"end": None
|
1135 |
+
})
|
1136 |
+
|
1137 |
+
# Loop back after tool execution
|
1138 |
+
builder.add_edge("call_tool", "call_llm")
|
1139 |
+
|
1140 |
+
# Compile
|
1141 |
+
graph = builder.compile()
|
1142 |
+
return graph
|
1143 |
|
|
|
|
|
1144 |
|
1145 |
|
1146 |
# call build_graph AFTER it’s defined
|