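# NOTE: the following lines are residue from the web file viewer this module
# was copied from; they are kept as comments so the file remains valid Python.
# cjber's picture
# fix(updates to doc format):
# 5767260
# raw
# history blame contribute delete
# 1.37 kB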
from langgraph.constants import START
from langgraph.graph import END, StateGraph
from planning_ai.nodes.hallucination_node import (
check_hallucination,
fix_hallucination,
map_check,
map_fix,
)
from planning_ai.nodes.map_node import generate_summary, map_documents
from planning_ai.nodes.reduce_node import generate_final_report
from planning_ai.states import OverallState
def create_graph():
    """Assemble and compile the planning workflow graph.

    A ``StateGraph`` over ``OverallState`` is populated with the summary,
    hallucination-check, hallucination-fix and final-report nodes, wired
    together, and compiled.

    Returns:
        The object produced by ``StateGraph.compile()``.
    """
    builder = StateGraph(OverallState)

    # The "add_entities" node (and its START edge) is currently disabled:
    # builder.add_node("add_entities", add_entities)
    # builder.add_edge(START, "add_entities")

    # Register every node before any edge refers to it.
    node_table = (
        ("generate_summary", generate_summary),
        ("check_hallucination", check_hallucination),
        ("fix_hallucination", fix_hallucination),
        ("generate_final_report", generate_final_report),
    )
    for node_name, node_action in node_table:
        builder.add_node(node_name, node_action)

    # Conditional edges as (source, routing callable, allowed destinations).
    # Note that check and fix route into each other.
    conditional_routes = (
        (START, map_documents, ["generate_summary"]),
        ("generate_summary", map_check, ["check_hallucination"]),
        ("check_hallucination", map_fix, ["fix_hallucination"]),
        ("fix_hallucination", map_check, ["check_hallucination"]),
    )
    for source, router, destinations in conditional_routes:
        builder.add_conditional_edges(source, router, destinations)

    # Unconditional edges: the check step feeds the final report, which ends
    # the graph.
    direct_edges = (
        ("check_hallucination", "generate_final_report"),
        ("generate_final_report", END),
    )
    for source, target in direct_edges:
        builder.add_edge(source, target)

    compiled = builder.compile()
    return compiled
def plot_mermaid():
    """Print the Mermaid description of the compiled workflow graph.

    Builds the graph with ``create_graph`` and writes the text returned by
    ``draw_mermaid()`` to standard output. Nothing is returned.
    """
    mermaid_source = create_graph().get_graph().draw_mermaid()
    print(mermaid_source)