# Spaces: Runtime error
# File size: 1,496 Bytes
# 90f65f7
# file: agent_graph_factory.py
from typing import TypedDict, Annotated, List
from langchain_core.messages import BaseMessage
from langchain_core.runnables import Runnable
from langgraph.graph import StateGraph, END
from langgraph.prebuilt import ToolNode
from tools.word_counter import count_words
# 1. Define the Agent State
def _concat_messages(existing, incoming):
    """Reducer for ``AgentState.messages``: return ``existing + incoming``."""
    return existing + incoming


class AgentState(TypedDict):
    """State dictionary handed to ``StateGraph`` as the graph's schema.

    It has a single key, ``messages``, typed as a list of ``BaseMessage``.
    The ``Annotated`` metadata attaches a two-argument function that joins
    two lists with ``+``.
    """

    # Annotated metadata: function that combines two message lists via ``+``.
    messages: Annotated[List[BaseMessage], _concat_messages]
def create_graph_app(llm: Runnable) -> Runnable:
    """Build and compile the tool-calling agent graph.

    The graph has two nodes. ``"agent"`` invokes the model with the
    ``count_words`` tool bound; ``"action"`` is a ``ToolNode`` over the same
    tool list. ``"agent"`` is the entry point. After each model turn the
    graph goes to ``"action"`` when the last message carries tool calls and
    to ``END`` otherwise; ``"action"`` always leads back to ``"agent"``.

    Args:
        llm: The language model. It must expose ``bind_tools``, which is
            called here with the tool list.

    Returns:
        The object returned by ``StateGraph.compile()``.
    """
    tools = [count_words]
    llm_with_tools = llm.bind_tools(tools)

    # 2. Define the Nodes
    def call_model(state):
        """Invoke the tool-bound model on the messages held in ``state``."""
        response = llm_with_tools.invoke(state["messages"])
        # Returned as a one-element list under the "messages" key.
        return {"messages": [response]}

    tool_node = ToolNode(tools)

    # 3. Define the Conditional Edge
    def should_continue(state):
        """Return ``"continue"`` if the last message has tool calls, else ``"end"``."""
        last_message = state["messages"][-1]
        # A message object may not define ``tool_calls`` at all; treat a
        # missing attribute like an empty one instead of raising
        # AttributeError in the middle of a graph run.
        if getattr(last_message, "tool_calls", None):
            return "continue"
        return "end"

    # 4. Build the Graph
    workflow = StateGraph(AgentState)
    workflow.add_node("agent", call_model)
    workflow.add_node("action", tool_node)
    workflow.set_entry_point("agent")
    workflow.add_conditional_edges(
        "agent",
        should_continue,
        {"continue": "action", "end": END},
    )
    workflow.add_edge("action", "agent")

    return workflow.compile()