Spaces:
Sleeping
Sleeping
Update agent.py
Browse files
agent.py
CHANGED
|
@@ -3,15 +3,16 @@ from langchain_core.messages import BaseMessage, HumanMessage
|
|
| 3 |
from langchain_core.tools import tool
|
| 4 |
from langchain_openai import ChatOpenAI
|
| 5 |
from langgraph.graph import END, StateGraph
|
| 6 |
-
from langgraph.prebuilt import
|
| 7 |
from langchain.tools import DuckDuckGoSearchResults
|
| 8 |
from langchain_community.utilities import WikipediaAPIWrapper
|
| 9 |
from langchain.agents import create_tool_calling_agent
|
| 10 |
from langchain.agents import AgentExecutor
|
| 11 |
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
| 12 |
import operator
|
|
|
|
| 13 |
|
| 14 |
-
|
| 15 |
|
| 16 |
# Define the agent state
|
| 17 |
class AgentState(TypedDict):
|
|
@@ -44,7 +45,6 @@ class AdvancedAIAgent:
|
|
| 44 |
|
| 45 |
# Create the agent
|
| 46 |
self.agent = self._create_agent()
|
| 47 |
-
self.tool_executor = ToolExecutor(self.tools)
|
| 48 |
|
| 49 |
# Build the graph workflow
|
| 50 |
self.workflow = self._build_graph()
|
|
@@ -66,7 +66,7 @@ class AdvancedAIAgent:
|
|
| 66 |
|
| 67 |
# Define nodes
|
| 68 |
workflow.add_node("agent", self._call_agent)
|
| 69 |
-
workflow.add_node("tools", self.
|
| 70 |
|
| 71 |
# Define edges
|
| 72 |
workflow.set_entry_point("agent")
|
|
@@ -87,31 +87,6 @@ class AdvancedAIAgent:
|
|
| 87 |
response = self.agent.invoke({"messages": state["messages"]})
|
| 88 |
return {"messages": [response["output"]]}
|
| 89 |
|
| 90 |
-
def _call_tools(self, state: AgentState):
|
| 91 |
-
"""Execute tools"""
|
| 92 |
-
last_message = state["messages"][-1]
|
| 93 |
-
|
| 94 |
-
# Find the tool calls
|
| 95 |
-
tool_calls = last_message.additional_kwargs.get("tool_calls", [])
|
| 96 |
-
|
| 97 |
-
# Execute each tool
|
| 98 |
-
for tool_call in tool_calls:
|
| 99 |
-
action = ToolInvocation(
|
| 100 |
-
tool=tool_call["function"]["name"],
|
| 101 |
-
tool_input=json.loads(tool_call["function"]["arguments"]),
|
| 102 |
-
)
|
| 103 |
-
output = self.tool_executor.invoke(action)
|
| 104 |
-
|
| 105 |
-
# Create tool message
|
| 106 |
-
tool_message = ToolMessage(
|
| 107 |
-
content=str(output),
|
| 108 |
-
name=action.tool,
|
| 109 |
-
tool_call_id=tool_call["id"],
|
| 110 |
-
)
|
| 111 |
-
state["messages"].append(tool_message)
|
| 112 |
-
|
| 113 |
-
return {"messages": state["messages"]}
|
| 114 |
-
|
| 115 |
def _should_continue(self, state: AgentState):
|
| 116 |
"""Determine if the workflow should continue"""
|
| 117 |
last_message = state["messages"][-1]
|
|
@@ -153,27 +128,4 @@ class AdvancedAIAgent:
|
|
| 153 |
if hasattr(msg, 'additional_kwargs') and 'tool_calls' in msg.additional_kwargs:
|
| 154 |
for call in msg.additional_kwargs['tool_calls']:
|
| 155 |
steps.append(f"Used {call['function']['name']}: {call['function']['arguments']}")
|
| 156 |
-
return steps
|
| 157 |
-
|
| 158 |
-
# Example usage
|
| 159 |
-
if __name__ == "__main__":
|
| 160 |
-
agent = AdvancedAIAgent()
|
| 161 |
-
|
| 162 |
-
queries = [
|
| 163 |
-
"What is the capital of France?",
|
| 164 |
-
"Calculate 15% of 200",
|
| 165 |
-
"Tell me about the latest developments in quantum computing"
|
| 166 |
-
]
|
| 167 |
-
|
| 168 |
-
for query in queries:
|
| 169 |
-
print(f"\nQuestion: {query}")
|
| 170 |
-
response = agent(query)
|
| 171 |
-
print(f"Answer: {response['response']}")
|
| 172 |
-
if response['sources']:
|
| 173 |
-
print("Sources:")
|
| 174 |
-
for source in response['sources']:
|
| 175 |
-
print(f"- {source}")
|
| 176 |
-
if response['steps']:
|
| 177 |
-
print("Steps taken:")
|
| 178 |
-
for step in response['steps']:
|
| 179 |
-
print(f"- {step}")
|
|
|
|
| 3 |
from langchain_core.tools import tool
|
| 4 |
from langchain_openai import ChatOpenAI
|
| 5 |
from langgraph.graph import END, StateGraph
|
| 6 |
+
from langgraph.prebuilt import ToolNode
|
| 7 |
from langchain.tools import DuckDuckGoSearchResults
|
| 8 |
from langchain_community.utilities import WikipediaAPIWrapper
|
| 9 |
from langchain.agents import create_tool_calling_agent
|
| 10 |
from langchain.agents import AgentExecutor
|
| 11 |
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
| 12 |
import operator
|
| 13 |
+
import json
|
| 14 |
|
| 15 |
+
load dotenv()
|
| 16 |
|
| 17 |
# Define the agent state
|
| 18 |
class AgentState(TypedDict):
|
|
|
|
| 45 |
|
| 46 |
# Create the agent
|
| 47 |
self.agent = self._create_agent()
|
|
|
|
| 48 |
|
| 49 |
# Build the graph workflow
|
| 50 |
self.workflow = self._build_graph()
|
|
|
|
| 66 |
|
| 67 |
# Define nodes
|
| 68 |
workflow.add_node("agent", self._call_agent)
|
| 69 |
+
workflow.add_node("tools", ToolNode(self.tools)) # Using ToolNode instead of ToolExecutor
|
| 70 |
|
| 71 |
# Define edges
|
| 72 |
workflow.set_entry_point("agent")
|
|
|
|
| 87 |
response = self.agent.invoke({"messages": state["messages"]})
|
| 88 |
return {"messages": [response["output"]]}
|
| 89 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 90 |
def _should_continue(self, state: AgentState):
|
| 91 |
"""Determine if the workflow should continue"""
|
| 92 |
last_message = state["messages"][-1]
|
|
|
|
| 128 |
if hasattr(msg, 'additional_kwargs') and 'tool_calls' in msg.additional_kwargs:
|
| 129 |
for call in msg.additional_kwargs['tool_calls']:
|
| 130 |
steps.append(f"Used {call['function']['name']}: {call['function']['arguments']}")
|
| 131 |
+
return steps
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|