from typing import Any, Callable, Sequence

from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langgraph.checkpoint.memory import MemorySaver  # kept for optional checkpointing (see _compile_single_node_graph)
from langgraph.graph import START, StateGraph
from typing_extensions import Annotated, TypedDict


class State(TypedDict):
    """State mapping passed through the single-node graphs built in this module."""

    input: str  # user message text for the current turn
    # NOTE(review): the Annotated metadata here is the *string* "add_messages",
    # presumably intended to be langgraph's `add_messages` reducer. Reducers are
    # callables, so as written this channel likely just keeps the last value
    # written (i.e. only the current turn's two messages). Confirm what callers
    # expect before switching to `langgraph.graph.message.add_messages`.
    chat_history: Annotated[Sequence[BaseMessage], "add_messages"]
    context: str  # "context" value returned by the wrapped chain, "" if absent
    answer: str  # answer text produced for the current turn


def _compile_single_node_graph(call_model: Callable[[State], dict]) -> Any:
    """Compile a graph whose only node, "model", is entered directly from START.

    The graph is compiled without a checkpointer, so nothing is persisted
    between invocations; the design relies on caller-side session memory.
    To persist state inside LangGraph instead, compile with
    ``workflow.compile(checkpointer=MemorySaver())``.
    """
    workflow = StateGraph(state_schema=State)
    workflow.add_node("model", call_model)
    workflow.add_edge(START, "model")
    return workflow.compile()


def build_graph(rag_chain):
    """Build a stateless single-node graph around a direct RAG chain.

    Args:
        rag_chain: object exposing ``invoke``; it receives the full ``State``
            mapping and must return a mapping. Its ``"answer"`` and optional
            ``"context"`` entries are copied into the graph state.

    Returns:
        The compiled LangGraph graph.
    """

    def call_model(state: State) -> dict:
        response = rag_chain.invoke(state)
        # Read the answer once so the AIMessage and the "answer" field always
        # agree, and a chain that omits "answer" yields "" instead of KeyError.
        answer = response.get("answer", "")
        return {
            "chat_history": [
                HumanMessage(state["input"]),
                AIMessage(answer),
            ],
            "context": response.get("context", ""),
            "answer": answer,
        }

    return _compile_single_node_graph(call_model)


def build_graph_with_callable(call_fn: Callable[[dict], dict]):
    """Build a stateless single-node graph around a plain callable (e.g. ``agent_chain.invoke``).

    Args:
        call_fn: called with ``{"input": <user text>}`` only — the rest of the
            state, including ``chat_history``, is not forwarded. Its result is
            read for ``"output"`` (falling back to ``"answer"``, then ``""``)
            and an optional ``"context"``.

    Returns:
        The compiled LangGraph graph.
    """

    def call_model(state: State) -> dict:
        response = call_fn({"input": state["input"]})
        # Agent-style results use "output"; chain-style results use "answer".
        answer = response.get("output", response.get("answer", ""))
        return {
            "chat_history": [
                HumanMessage(state["input"]),
                AIMessage(answer),
            ],
            "context": response.get("context", ""),
            "answer": answer,
        }

    return _compile_single_node_graph(call_model)