File size: 1,370 Bytes
43845c7
1ccbee8
21b7409
1ccbee8
 
 
1af4802
 
1ccbee8
a3397bd
648f926
43845c7
21b7409
 
 
 
60a2039
8359c66
13686ad
 
1af4802
aa05cc8
60a2039
 
1af4802
 
 
 
 
3dc0350
21b7409
 
76aaf8d
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
from langgraph.constants import START
from langgraph.graph import END, StateGraph

from planning_ai.nodes.hallucination_node import (
    check_hallucination,
    fix_hallucination,
    map_check,
    map_fix,
)
from planning_ai.nodes.map_node import generate_summary, map_documents
from planning_ai.nodes.reduce_node import generate_final_report
from planning_ai.states import OverallState


def create_graph():
    """Build and compile the planning-report workflow graph.

    Wiring registered below:

    * ``START`` is routed by ``map_documents`` to ``generate_summary``.
    * ``generate_summary`` is routed by ``map_check`` to
      ``check_hallucination``.
    * ``check_hallucination`` is routed by ``map_fix`` to
      ``fix_hallucination``.
    * ``fix_hallucination`` is routed by ``map_check`` back to
      ``check_hallucination``, forming a check/fix loop.
    * ``check_hallucination`` also has a static edge to
      ``generate_final_report``, which is followed by ``END``.

    Returns:
        The object produced by ``StateGraph.compile()`` for a graph whose
        state schema is ``OverallState``.
    """
    workflow = StateGraph(OverallState)

    # Registered in a fixed order so node ordering is stable.
    node_table = (
        ("generate_summary", generate_summary),
        ("check_hallucination", check_hallucination),
        ("fix_hallucination", fix_hallucination),
        ("generate_final_report", generate_final_report),
    )
    for node_name, node_fn in node_table:
        workflow.add_node(node_name, node_fn)

    # (source node, routing function, allowed destinations) for each
    # conditional edge, registered in order.
    routing_table = (
        (START, map_documents, ["generate_summary"]),
        ("generate_summary", map_check, ["check_hallucination"]),
        ("check_hallucination", map_fix, ["fix_hallucination"]),
        ("fix_hallucination", map_check, ["check_hallucination"]),
    )
    for source, router, destinations in routing_table:
        workflow.add_conditional_edges(source, router, destinations)

    # NOTE(review): check_hallucination also routes into the fix loop above.
    # If that loop iterates, this static edge may schedule
    # generate_final_report more than once. Confirm against map_fix.
    workflow.add_edge("check_hallucination", "generate_final_report")
    workflow.add_edge("generate_final_report", END)

    return workflow.compile()


def plot_mermaid():
    """Print the Mermaid source for the compiled workflow graph to stdout."""
    compiled_graph = create_graph()
    drawable = compiled_graph.get_graph()
    mermaid_text = drawable.draw_mermaid()
    print(mermaid_text)