from typing import TypedDict, Annotated, Sequence, Literal
from functools import lru_cache

from langchain_core.messages import BaseMessage
from langchain_anthropic import ChatAnthropic
from langchain_openai import ChatOpenAI
from langchain_community.tools.tavily_search import TavilySearchResults
from langgraph.prebuilt import ToolNode
from langgraph.graph import StateGraph, END, add_messages

tools = [TavilySearchResults(max_results=1)]
# Cache the bound models so repeated calls with the same model name reuse
# one instance instead of re-instantiating the client on every turn
@lru_cache(maxsize=4)
def _get_model(model_name: str):
    if model_name == "openai":
        model = ChatOpenAI(temperature=0, model_name="gpt-4o")
    elif model_name == "anthropic":
        model = ChatAnthropic(temperature=0, model_name="claude-3-sonnet-20240229")
    else:
        raise ValueError(f"Unsupported model type: {model_name}")
    model = model.bind_tools(tools)
    return model

class AgentState(TypedDict):
    messages: Annotated[Sequence[BaseMessage], add_messages]
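
# Note: `add_messages` is a reducer. When a node returns {"messages": [response]},
# the new messages are appended to the existing list (messages with a matching
# `id` are replaced) rather than overwriting the whole key.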

# Define the function that determines whether to continue or not
def should_continue(state):
    messages = state["messages"]
    last_message = messages[-1]
    # If the last model response made no tool calls, we finish
    if not last_message.tool_calls:
        return "end"
    # Otherwise, we run the tools and continue
    else:
        return "continue"

system_prompt = """Be a helpful assistant"""

# Define the function that calls the model
def call_model(state, config):
    messages = state["messages"]
    messages = [{"role": "system", "content": system_prompt}] + messages
    model_name = config.get("configurable", {}).get("model_name", "anthropic")
    model = _get_model(model_name)
    response = model.invoke(messages)
    # We return a list, because this will get added to the existing list
    return {"messages": [response]}

# Define the node that executes tools: ToolNode runs the tool calls from the
# last AIMessage and returns the results as ToolMessages
tool_node = ToolNode(tools)

# Define the config schema; `model_name` can be set per invocation via
# config={"configurable": {"model_name": ...}}
class GraphConfig(TypedDict):
    model_name: Literal["anthropic", "openai"]

# Define a new graph
workflow = StateGraph(AgentState, config_schema=GraphConfig)

# Define the two nodes we will cycle between
workflow.add_node("agent", call_model)
workflow.add_node("action", tool_node)

# Set the entrypoint as `agent`
# This means that this node is the first one called
workflow.set_entry_point("agent")

# We now add a conditional edge
workflow.add_conditional_edges(
    # First, we define the start node. We use `agent`.
    # This means these are the edges taken after the `agent` node is called.
    "agent",
    # Next, we pass in the function that will determine which node is called next.
    should_continue,
    # Finally, we pass in a mapping.
    # The keys are strings, and the values are other nodes.
    # END is a special node marking that the graph should finish.
    # We will call `should_continue`, and its output will be matched against
    # the keys in this mapping.
    # Based on which one it matches, that node will then be called.
    {
        # If `continue`, then we call the tool node.
        "continue": "action",
        # Otherwise we finish.
        "end": END,
    },
)

# We now add a normal edge from `action` back to `agent`.
# This means that after the tools run, the `agent` node is called next.
workflow.add_edge("action", "agent")

# Finally, we compile it!
# This compiles it into a LangChain Runnable,
# meaning you can use it as you would any other runnable
graph = workflow.compile()
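
# A minimal usage sketch (not part of the graph definition above). It assumes
# TAVILY_API_KEY and ANTHROPIC_API_KEY (or OPENAI_API_KEY, if you pass
# model_name="openai") are set in the environment; the question is illustrative.
if __name__ == "__main__":
    from langchain_core.messages import HumanMessage

    result = graph.invoke(
        {"messages": [HumanMessage(content="What is the weather in SF?")]},
        config={"configurable": {"model_name": "anthropic"}},
    )
    # The final state holds the full message history; the last entry is the
    # model's answer after any tool calls have been resolved.
    print(result["messages"][-1].content)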