Spaces:
Running
Running
import os | |
from typing import Dict, TypedDict, List, Annotated, Literal, Union, Any | |
from .tools import get_tools | |
from langchain_openai import ChatOpenAI | |
from langchain.prompts import ChatPromptTemplate | |
from langgraph.graph import StateGraph, START, END | |
import operator | |
from langchain_core.messages import ( | |
AIMessage, | |
HumanMessage, | |
SystemMessage, | |
ToolMessage, | |
FunctionMessage, | |
) | |
from langchain_core.tools import BaseTool, tool | |
from langgraph.graph.message import add_messages | |
from langgraph.prebuilt import ToolNode, tools_condition | |
import json | |
# State definition
class AgentState(TypedDict):
    """State carried through the agent graph.

    Passed to ``StateGraph`` as the graph's state schema in
    ``create_agent_graph``.
    """

    # Conversation messages. The ``add_messages`` annotation is LangGraph's
    # reducer for this key (per the LangGraph docs it merges node updates into
    # the existing list instead of overwriting it).
    messages: Annotated[list, add_messages]
# Initialize tools
# NOTE: get_tools() runs at import time. The resulting list is both bound to
# the chat model and handed to the ToolNode in create_agent_graph().
tools = get_tools()

# System prompt
# Template for the system message; every ``{domain}`` placeholder is filled in
# with ``str.format`` by create_system_message().
# NOTE(review): the four tools described below are assumed to match what
# get_tools() actually returns -- confirm against the tools module.
system_prompt = """You are an AI research assistant specialized in {domain}.
Your goal is to help users find accurate information about {domain} topics.
You have access to the following tools:
1. Web Search - For general queries and recent information
2. Research Paper Search - For academic and scientific information
3. Wikipedia Search - For comprehensive background information and factual summaries
4. Data Analysis - For analyzing data provided by the user
Choose the most appropriate tool(s) based on the user's question:
- Use Web Search for current events, recent developments, or general information
- Use Research Paper Search for academic knowledge, scientific findings, or technical details
- Use Wikipedia Search for conceptual explanations, definitions, historical context, or general facts
- Use Data Analysis when the user provides data to be analyzed
Always try to provide the most accurate and helpful information.
When responding, cite your sources appropriately."""
# Function to create the system message
def create_system_message(domain):
    """Build the system message for a given research domain.

    Args:
        domain: Text substituted for each ``{domain}`` placeholder in the
            module-level ``system_prompt`` template.

    Returns:
        A ``SystemMessage`` whose content is the formatted prompt.
    """
    prompt_text = system_prompt.format(domain=domain)
    return SystemMessage(content=prompt_text)
# Create the graph
def create_agent_graph(domain="general research"):
    """Create a LangGraph for the research agent using prebuilt components.

    The graph has two nodes:

    * ``"agent"`` -- calls the ``gpt-4o`` chat model (temperature 0) with the
      module-level ``tools`` bound to it.
    * ``"tools"`` -- a prebuilt ``ToolNode`` over the same ``tools``.

    Execution starts at ``"agent"``. The prebuilt ``tools_condition`` then
    routes either to ``"tools"`` or to ``END``, and ``"tools"`` always hands
    control back to ``"agent"``.

    Args:
        domain: Research domain inserted into the system prompt.

    Returns:
        The compiled graph, as returned by ``StateGraph.compile()``.
    """
    # Initialize the graph with the state
    workflow = StateGraph(AgentState)

    # Add system message with domain context
    system_prompt_message = create_system_message(domain)

    # The tool-bound model is created on the first agent step and then reused.
    # The agent node runs once per turn of the agent <-> tools loop, so
    # rebuilding the client and re-binding the tools on every step is wasted
    # work. Creating it lazily (rather than right here) keeps graph
    # construction free of any model/client setup, as it was before.
    model_with_tools = None

    def get_model_with_tools():
        nonlocal model_with_tools
        if model_with_tools is None:
            model = ChatOpenAI(model="gpt-4o", temperature=0)
            model_with_tools = model.bind_tools(tools)
        return model_with_tools

    # Agent node function
    def agent_node(state: AgentState):
        messages = state["messages"]
        # Prepend the system prompt for this model call only; it is not part
        # of the returned update, so it is never written back into the state.
        if not messages or not isinstance(messages[0], SystemMessage):
            messages = [system_prompt_message, *messages]
        # Generate response with tools
        return {"messages": [get_model_with_tools().invoke(messages)]}

    # Add nodes
    workflow.add_node("agent", agent_node)
    # Use prebuilt ToolNode
    workflow.add_node("tools", ToolNode(tools=tools))

    # Add conditional edges using prebuilt tools_condition
    workflow.add_conditional_edges(
        "agent",
        tools_condition,
        {
            "tools": "tools",
            END: END,
        },
    )
    # Add edge back to agent after tools execution
    workflow.add_edge("tools", "agent")
    # Set the entry point
    workflow.add_edge(START, "agent")

    # Compile the graph
    return workflow.compile()
# Function to run the agent
def run_agent(user_input, domain="general research", messages=None):
    """Run the agent on a single user input.

    A fresh graph is built for ``domain`` on every call, the user input is
    added to the conversation as a ``HumanMessage``, and the graph is invoked.

    Args:
        user_input: Text of the new user message.
        domain: Research domain passed to ``create_agent_graph``.
        messages: Optional prior conversation messages. The sequence is
            copied, so the caller's object is not modified.

    Returns:
        The ``"messages"`` entry of the state returned by ``graph.invoke``.
    """
    # Create the graph
    graph = create_agent_graph(domain)

    # Work on a copy: appending to the caller's list in place would leak the
    # new HumanMessage into whatever else holds a reference to that list.
    conversation = [] if messages is None else list(messages)
    conversation.append(HumanMessage(content=user_input))

    # Run the graph
    result = graph.invoke({"messages": conversation})
    return result["messages"]
if __name__ == "__main__":
    # Smoke test: run one query and print the resulting transcript.
    domain = "artificial intelligence"
    query = "What are the latest developments in natural language processing?"
    messages = run_agent(query, domain)

    for message in messages:
        if isinstance(message, AIMessage):
            print("AI:", message.content)
            continue
        if isinstance(message, HumanMessage):
            print("Human:", message.content)
            continue
        if isinstance(message, ToolMessage):
            # Tool output can be long; show only the first 100 characters.
            preview = message.content
            if len(preview) > 100:
                preview = preview[:100] + "..."
            print(f"Tool ({message.name}):", preview)