File size: 1,496 Bytes
90f65f7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
# file: agent_graph_factory.py

from typing import Annotated, List, Optional, Sequence, TypedDict

from langchain_core.messages import BaseMessage
from langchain_core.runnables import Runnable
from langgraph.graph import StateGraph, END
from langgraph.prebuilt import ToolNode

from tools.word_counter import count_words

# 1. Define the Agent State
def _append_messages(existing, incoming):
    """Reducer for ``AgentState.messages``: return *incoming* concatenated onto *existing*."""
    return existing + incoming


class AgentState(TypedDict):
    """State carried between the nodes of the agent graph.

    ``messages`` is annotated with ``_append_messages`` as its reducer, so a
    node update of the form ``{"messages": [...]}`` is concatenated onto the
    current message list instead of overwriting it.
    """

    messages: Annotated[List[BaseMessage], _append_messages]

def create_graph_app(llm: Runnable, tools: Optional[Sequence] = None) -> Runnable:
    """
    Factory function to create the LangGraph agent app.

    The compiled graph loops between two nodes:

    * ``"agent"`` invokes the tool-bound model on the accumulated messages.
    * ``"action"`` executes requested tool calls via a ``ToolNode`` and then
      always returns control to ``"agent"``.

    After every ``"agent"`` step the graph routes to ``"action"`` when the
    newest message carries tool calls, and finishes otherwise.

    Args:
        llm: Model that drives the agent. It must provide ``bind_tools``
            (called below), e.g. a LangChain chat model; the ``Runnable``
            annotation alone does not guarantee that method exists.
        tools: Tools to bind to the model and execute in the ``"action"``
            node. When omitted or ``None``, defaults to ``[count_words]``.

    Returns:
        The compiled graph, a runnable invoked with a state dict of the form
        ``{"messages": [...]}``.
    """
    # Build a fresh list on every call: the default is never shared between
    # graphs, and a caller's sequence is copied rather than aliased.
    tools = [count_words] if tools is None else list(tools)
    llm_with_tools = llm.bind_tools(tools)

    # 2. Define the Nodes
    def call_model(state):
        """Invoke the tool-bound model on the message history and emit its reply."""
        response = llm_with_tools.invoke(state["messages"])
        # Returned as a one-element list so the state's ``messages`` reducer
        # (list concatenation) appends it to the history.
        return {"messages": [response]}

    tool_node = ToolNode(tools)

    # 3. Define the Conditional Edge
    def should_continue(state):
        """Return "continue" if the latest message requests tool calls, else "end"."""
        last_message = state["messages"][-1]
        # Read via getattr so a message object without a ``tool_calls``
        # attribute ends the run instead of raising AttributeError.
        if getattr(last_message, "tool_calls", None):
            return "continue"
        return "end"

    # 4. Build the Graph
    workflow = StateGraph(AgentState)
    workflow.add_node("agent", call_model)
    workflow.add_node("action", tool_node)
    workflow.set_entry_point("agent")
    workflow.add_conditional_edges(
        "agent",
        should_continue,
        {"continue": "action", "end": END},
    )
    # Tool results always flow back to the model so it can act on them.
    workflow.add_edge("action", "agent")
    return workflow.compile()