File size: 4,818 Bytes
963aee4
 
2d44ef6
1af4802
2646ff0
13686ad
 
1af4802
963aee4
13686ad
 
82bbfd1
 
 
 
 
 
 
 
 
 
 
 
 
 
379aa81
1af4802
2646ff0
379aa81
1af4802
 
379aa81
1af4802
963aee4
 
 
2d44ef6
963aee4
1af4802
 
6aa18a5
379aa81
2646ff0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
379aa81
2646ff0
 
 
 
 
13686ad
 
 
82bbfd1
 
 
 
 
 
 
 
 
 
 
 
 
963aee4
2d44ef6
963aee4
 
13686ad
963aee4
2d44ef6
963aee4
13686ad
963aee4
6aa18a5
2646ff0
 
 
 
 
 
 
 
 
 
 
 
 
1af4802
13686ad
 
1af4802
a3397bd
 
 
 
 
 
 
 
 
1af4802
82bbfd1
 
1af4802
a3397bd
 
 
 
 
 
 
 
 
13686ad
1af4802
 
 
13686ad
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
from langgraph.types import Send

from planning_ai.chains.fix_chain import fix_chain
from planning_ai.chains.hallucination_chain import hallucination_chain
from planning_ai.logging import logger
from planning_ai.states import DocumentState, OverallState

# Cap on a document's `refinement_attempts` counter: once the counter reaches
# this value, `check_hallucination` stops checking the document and marks it
# as failed.
MAX_ATTEMPTS = 3


def check_hallucination(state: DocumentState):
    """Check a document's summary for hallucinations and update its state.

    The checks run in this order:

    1. A document that is already processed is passed through without being
       re-checked and without having its ``failed`` flag changed.
    2. A document whose ``refinement_attempts`` has reached ``MAX_ATTEMPTS`` is
       marked as failed and processed.
    3. A document whose state says it is not hallucinated is marked as processed.
    4. Otherwise the `hallucination_chain` is invoked with the document and its
       summary. A score of 0 means the summary is hallucinated and the document
       stays unprocessed; any other score marks the document as processed.

    Args:
        state (DocumentState): The current state of the document, including its
            summary, refinement attempt count and processing flags.

    Returns:
        dict: ``{"documents": [updated_state]}`` holding the single updated
        document state. If invoking the chain (or reading its score) raises, the
        state is returned with an empty summary and flagged as failed and
        processed.
    """
    logger.info(f"Checking hallucinations for document {state['filename']}")

    if state["processed"]:
        # Already finalised in an earlier round. Every document may be routed
        # through this node again, so keep the existing outcome instead of
        # relabelling a finished document as failed.
        logger.info(f"Document already processed: {state['filename']}")
        return {
            "documents": [
                {**state, "failed": state.get("failed", False), "processed": True}
            ]
        }

    if state["refinement_attempts"] >= MAX_ATTEMPTS:
        logger.error(f"Max attempts exceeded for document: {state['filename']}")
        return {"documents": [{**state, "failed": True, "processed": True}]}

    if not state["is_hallucinated"]:
        logger.info(f"Finished processing document: {state['filename']}")
        return {"documents": [{**state, "processed": True}]}

    try:
        response = hallucination_chain.invoke(
            {"document": state["document"], "summary": state["summary"]}
        )
        # A score of 0 is the chain's signal that the summary is hallucinated.
        is_hallucinated = response.score == 0
    except Exception as e:
        # Any failure while invoking the chain ends processing for this document.
        logger.error(f"Failed to decode JSON {state['filename']}: {e}")
        return {
            "documents": [
                {
                    **state,
                    "summary": "",
                    "refinement_attempts": 0,
                    "is_hallucinated": True,
                    "failed": True,
                    "processed": True,
                }
            ]
        }

    logger.info(f"Hallucination for {state['filename']}: {is_hallucinated}")
    return {
        "documents": [
            {
                **state,
                "hallucination": response,
                "refinement_attempts": state["refinement_attempts"] + 1,
                "is_hallucinated": is_hallucinated,
                # Only summaries still flagged as hallucinated stay unprocessed.
                "processed": not is_hallucinated,
            }
        ]
    }


def fix_hallucination(state: DocumentState):
    """Rewrite a hallucinated summary using the `fix_chain`.

    The chain receives the source document as context, the current summary and
    the explanation stored on the state's hallucination result. Its response
    replaces the summary in the returned document state.

    Args:
        state (DocumentState): Document state holding the document, its current
            summary and the hallucination result to correct.

    Returns:
        dict: ``{"documents": [updated_state]}``. On success the state carries
        the chain's response as its summary. If building the inputs or invoking
        the chain raises, the state is returned with an empty summary and
        flagged as failed and processed.
    """
    logger.warning(f"Fixing hallucinations for document {state['filename']}")

    try:
        chain_inputs = {
            "context": state["document"],
            "summary": state["summary"],
            "explanation": state["hallucination"].explanation,
        }
        corrected_summary = fix_chain.invoke(chain_inputs)
    except Exception as e:
        logger.error(f"Failed to decode JSON {state['filename']}: {e}.")
        failed_document = {
            **state,
            "summary": "",
            "refinement_attempts": 0,
            "is_hallucinated": True,
            "failed": True,
            "processed": True,
        }
        return {"documents": [failed_document]}

    updated_document = {**state, "summary": corrected_summary}
    return {"documents": [updated_document]}


def map_check(state: OverallState):
    """Fan every document in the overall state out to `check_hallucination`.

    Args:
        state (OverallState): The overall state containing multiple documents.

    Returns:
        list: One Send object per document, each targeting the
        "check_hallucination" node with that document as its payload.
    """
    check_requests = []
    for document in state["documents"]:
        check_requests.append(Send("check_hallucination", document))
    return check_requests


def map_fix(state: OverallState):
    """Fan hallucinated, unprocessed documents out to `fix_hallucination`.

    Args:
        state (OverallState): The overall state containing multiple documents.

    Returns:
        list: One Send object targeting the "fix_hallucination" node for each
        document that is flagged as hallucinated and not yet processed.
    """
    fix_requests = []
    for document in state["documents"]:
        # Documents that are not hallucinated need no fix.
        if not document["is_hallucinated"]:
            continue
        # Documents that are already processed are final.
        if document["processed"]:
            continue
        fix_requests.append(Send("fix_hallucination", document))
    return fix_requests