Spaces:
Running
Running
File size: 4,818 Bytes
963aee4 2d44ef6 1af4802 2646ff0 13686ad 1af4802 963aee4 13686ad 82bbfd1 379aa81 1af4802 2646ff0 379aa81 1af4802 379aa81 1af4802 963aee4 2d44ef6 963aee4 1af4802 6aa18a5 379aa81 2646ff0 379aa81 2646ff0 13686ad 82bbfd1 963aee4 2d44ef6 963aee4 13686ad 963aee4 2d44ef6 963aee4 13686ad 963aee4 6aa18a5 2646ff0 1af4802 13686ad 1af4802 a3397bd 1af4802 82bbfd1 1af4802 a3397bd 13686ad 1af4802 13686ad |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 |
from langgraph.types import Send
from planning_ai.chains.fix_chain import fix_chain
from planning_ai.chains.hallucination_chain import hallucination_chain
from planning_ai.logging import logger
from planning_ai.states import DocumentState, OverallState
# Limit on hallucination-check rounds per document: check_hallucination marks a
# document as failed once its refinement_attempts reaches this value (``>=``).
MAX_ATTEMPTS: int = 3
def check_hallucination(state: DocumentState):
    """Check a document's summary for hallucinations.

    The document is handled by the first rule that applies:

    1. Already processed (verified earlier, or failed earlier): passed through
       unchanged.
    2. ``refinement_attempts`` has reached ``MAX_ATTEMPTS``: marked failed and
       processed.
    3. ``is_hallucinated`` is already falsy: marked processed without calling
       the chain.
    4. Otherwise ``hallucination_chain`` scores the summary against the source
       document. A score of 0 means "hallucinated": the document is left
       unprocessed so it can be routed to ``fix_hallucination``. Any other
       score marks it processed. ``refinement_attempts`` is incremented and the
       chain response is stored under ``"hallucination"``.

    If the chain call raises, the document is marked failed and processed with
    an empty summary.

    Args:
        state (DocumentState): The current document state. Reads ``filename``,
            ``document``, ``summary``, ``processed``, ``is_hallucinated`` and
            ``refinement_attempts``.

    Returns:
        dict: ``{"documents": [updated_state]}``, a single-element list holding
        a shallow copy of ``state`` with the updated fields.
    """
    filename = state["filename"]
    logger.info(f"Checking hallucinations for document {filename}")

    # A processed document has reached a final state (either verified or
    # already failed). Pass it through as-is rather than logging "max attempts
    # exceeded" and flagging a successfully verified summary as failed.
    if state["processed"]:
        return {"documents": [{**state}]}

    if state["refinement_attempts"] >= MAX_ATTEMPTS:
        logger.error(f"Max attempts exceeded for document: {filename}")
        return {"documents": [{**state, "failed": True, "processed": True}]}

    if not state["is_hallucinated"]:
        logger.info(f"Finished processing document: {filename}")
        return {"documents": [{**state, "processed": True}]}

    try:
        response = hallucination_chain.invoke(
            {"document": state["document"], "summary": state["summary"]}
        )
        # Kept inside the try: a response without a usable ``score`` is
        # treated the same as a failed chain call.
        is_hallucinated = response.score == 0
    except Exception as e:
        # The log text assumes an output-parsing (JSON) error, but any exception
        # from the chain lands here and is terminal for this document.
        logger.error(f"Failed to decode JSON {filename}: {e}")
        return {
            "documents": [
                {
                    **state,
                    "summary": "",
                    "refinement_attempts": 0,
                    "is_hallucinated": True,
                    "failed": True,
                    "processed": True,
                }
            ]
        }

    out = {
        **state,
        "hallucination": response,
        "refinement_attempts": state["refinement_attempts"] + 1,
        "is_hallucinated": is_hallucinated,
        # A hallucinated summary stays unprocessed so map_fix sends it to the
        # fixer; a clean one is final.
        "processed": not is_hallucinated,
    }
    logger.info(f"Hallucination for {filename}: {is_hallucinated}")
    return {"documents": [out]}
def fix_hallucination(state: DocumentState):
    """Rewrite a hallucinated summary using ``fix_chain``.

    The chain receives the source document as ``context``, the current
    ``summary``, and the ``explanation`` attribute of ``state["hallucination"]``
    (the response stored by ``check_hallucination``). Its output replaces the
    document's summary. ``refinement_attempts`` is not touched here.

    If anything inside the chain call raises, including a missing
    ``hallucination`` entry, the document is returned with an empty summary,
    ``refinement_attempts`` reset to 0, and ``is_hallucinated``, ``failed`` and
    ``processed`` all set to ``True``.

    Args:
        state (DocumentState): The current document state. Reads ``filename``,
            ``document``, ``summary`` and ``hallucination``.

    Returns:
        dict: ``{"documents": [updated_state]}``, a single-element list holding
        a shallow copy of ``state``.
    """
    name = state["filename"]
    logger.warning(f"Fixing hallucinations for document {name}")

    try:
        chain_input = {
            "context": state["document"],
            "summary": state["summary"],
            "explanation": state["hallucination"].explanation,
        }
        fixed_summary = fix_chain.invoke(chain_input)
    except Exception as err:
        logger.error(f"Failed to decode JSON {name}: {err}.")
        failed_doc = dict(state)
        failed_doc.update(
            summary="",
            refinement_attempts=0,
            is_hallucinated=True,
            failed=True,
            processed=True,
        )
        return {"documents": [failed_doc]}

    return {"documents": [{**state, "summary": fixed_summary}]}
def map_check(state: OverallState):
    """Fan out a hallucination check to every document in the graph state.

    No filtering is applied: documents already marked processed are sent too.

    Args:
        state (OverallState): Overall graph state. Its ``documents`` list is read.

    Returns:
        list: One ``Send`` targeting the ``check_hallucination`` node per
        document, in the same order as ``state["documents"]``.
    """
    node = "check_hallucination"
    documents = state["documents"]
    return [Send(node, document) for document in documents]
def map_fix(state: OverallState):
    """Fan out a fix request to each document that still needs one.

    A document qualifies when its ``is_hallucinated`` flag is truthy and its
    ``processed`` flag is falsy. All other documents are skipped.

    Args:
        state (OverallState): Overall graph state. Its ``documents`` list is read.

    Returns:
        list: One ``Send`` targeting the ``fix_hallucination`` node per
        qualifying document, preserving the order of ``state["documents"]``.
    """

    def _needs_fix(document):
        # ``processed`` is only read when the document is flagged hallucinated.
        return document["is_hallucinated"] and not document["processed"]

    pending = [document for document in state["documents"] if _needs_fix(document)]
    return [Send("fix_hallucination", document) for document in pending]
|