from langgraph.graph import END, START, StateGraph

from planning_ai.nodes.hallucination_node import (
    check_hallucination,
    fix_hallucination,
    map_check,
    map_fix,
)
from planning_ai.nodes.map_node import generate_summary, map_documents
from planning_ai.nodes.reduce_node import generate_final_report
from planning_ai.states import OverallState


def create_graph():
    """Build and compile the summarise / hallucination-check / report graph.

    Wiring:

    * ``START`` -> ``generate_summary``, routed by ``map_documents``.
    * ``generate_summary`` -> ``check_hallucination``, routed by ``map_check``.
    * ``check_hallucination`` -> ``fix_hallucination``, routed by ``map_fix``,
      and also unconditionally -> ``generate_final_report``.
    * ``fix_hallucination`` -> ``check_hallucination``, routed by
      ``map_check`` (the re-check loop).
    * ``generate_final_report`` -> ``END``.

    Returns:
        The compiled graph produced by ``StateGraph.compile()``.
    """
    graph = StateGraph(OverallState)

    graph.add_node("generate_summary", generate_summary)
    graph.add_node("check_hallucination", check_hallucination)
    graph.add_node("fix_hallucination", fix_hallucination)
    graph.add_node("generate_final_report", generate_final_report)

    # The list passed to add_conditional_edges declares each router's
    # possible destinations, so the graph structure is known (e.g. for
    # plot_mermaid) without calling the router. The map_* routers presumably
    # return Send objects to fan work out -- confirm in their modules.
    graph.add_conditional_edges(START, map_documents, ["generate_summary"])
    graph.add_conditional_edges("generate_summary", map_check, ["check_hallucination"])
    graph.add_conditional_edges("check_hallucination", map_fix, ["fix_hallucination"])
    graph.add_conditional_edges("fix_hallucination", map_check, ["check_hallucination"])

    # NOTE(review): check_hallucination has both the map_fix conditional edge
    # above and this unconditional edge. LangGraph follows every outgoing
    # edge, so generate_final_report likely runs after each check pass,
    # including passes that also dispatch fixes. Confirm this is intended.
    graph.add_edge("check_hallucination", "generate_final_report")
    graph.add_edge("generate_final_report", END)

    return graph.compile()


def plot_mermaid():
    """Build the graph and print its Mermaid diagram source to stdout."""
    graph = create_graph()
    print(graph.get_graph().draw_mermaid())