planning-ai / planning_ai /nodes /hallucination_node.py
cjber's picture
fix: allow for specifying chapters
2d44ef6
raw
history blame contribute delete
4.82 kB
from langgraph.types import Send
from planning_ai.chains.fix_chain import fix_chain
from planning_ai.chains.hallucination_chain import hallucination_chain
from planning_ai.logging import logger
from planning_ai.states import DocumentState, OverallState
# Once a document's `refinement_attempts` reaches this value,
# `check_hallucination` stops re-checking it and marks it as failed.
MAX_ATTEMPTS = 3
def check_hallucination(state: DocumentState):
    """Check a document's summary for hallucinations.

    The summary is scored against its source document with
    `hallucination_chain`. A score of 0 marks the summary as hallucinated;
    any other score marks it as clean and the document as processed. Every
    completed check increments the document's `refinement_attempts`.

    The chain is not invoked when the document:

    * is already marked `processed` -- it is returned unchanged so that a
      document which finished in an earlier round keeps its outcome;
    * has reached `MAX_ATTEMPTS` refinement attempts -- it is marked
      `failed` and `processed`;
    * already has a false `is_hallucinated` flag -- it is marked `processed`.

    If the chain call (or reading its score) raises, the document is marked
    `failed` and `processed`, its summary is blanked and its attempt counter
    is reset to 0.

    Args:
        state (DocumentState): The current state of the document, including
            its filename, source text, summary, attempt count and flags.

    Returns:
        dict: A state update of the form `{"documents": [document]}` holding
            the single updated document.
    """
    logger.info(f"Checking hallucinations for document {state['filename']}")

    if state["processed"]:
        # `map_check` re-sends every document on each round, including those
        # that already finished. Pass them through untouched instead of
        # reporting them as having exceeded the attempt limit.
        return {"documents": [{**state}]}

    if state["refinement_attempts"] >= MAX_ATTEMPTS:
        logger.error(f"Max attempts exceeded for document: {state['filename']}")
        return {"documents": [{**state, "failed": True, "processed": True}]}

    if not state["is_hallucinated"]:
        logger.info(f"Finished processing document: {state['filename']}")
        return {"documents": [{**state, "processed": True}]}

    try:
        response = hallucination_chain.invoke(
            {"document": state["document"], "summary": state["summary"]}
        )
        is_hallucinated = response.score == 0
    except Exception as e:
        # Any failure of the chain ends up here (not only JSON decoding
        # errors); the document is given up on rather than retried.
        logger.error(f"Failed to decode JSON {state['filename']}: {e}")
        return {
            "documents": [
                {
                    **state,
                    "summary": "",
                    "refinement_attempts": 0,
                    "is_hallucinated": True,
                    "failed": True,
                    "processed": True,
                }
            ]
        }

    logger.info(f"Hallucination for {state['filename']}: {is_hallucinated}")
    return {
        "documents": [
            {
                **state,
                "hallucination": response,
                "refinement_attempts": state["refinement_attempts"] + 1,
                "is_hallucinated": is_hallucinated,
                # A clean summary is finished; a hallucinated one stays
                # unprocessed so that `map_fix` picks it up.
                "processed": not is_hallucinated,
            }
        ]
    }
def fix_hallucination(state: DocumentState):
    """Rewrite a hallucinated summary using `fix_chain`.

    The chain receives the source document, the current summary and the
    explanation stored on the document's `hallucination` entry. Its result
    replaces the document's summary; all other fields are left as they were.

    If building the chain input or invoking the chain raises, the document
    is marked `failed` and `processed`, its summary is blanked and its
    attempt counter is reset to 0.

    Args:
        state (DocumentState): The current state of the document, including
            its summary and hallucination details.

    Returns:
        dict: A state update of the form `{"documents": [document]}` holding
            the single updated document.
    """
    logger.warning(f"Fixing hallucinations for document {state['filename']}")

    try:
        chain_input = {
            "context": state["document"],
            "summary": state["summary"],
            "explanation": state["hallucination"].explanation,
        }
        revised_summary = fix_chain.invoke(chain_input)
    except Exception as exc:
        logger.error(f"Failed to decode JSON {state['filename']}: {exc}.")
        failed_document = {**state}
        failed_document.update(
            summary="",
            refinement_attempts=0,
            is_hallucinated=True,
            failed=True,
            processed=True,
        )
        return {"documents": [failed_document]}
    else:
        return {"documents": [{**state, "summary": revised_summary}]}
def map_check(state: OverallState):
    """Fan every document out to the `check_hallucination` node.

    Args:
        state (OverallState): The overall state; its `documents` entry is
            iterated.

    Returns:
        list: One `Send` targeting `check_hallucination` per document, in the
            same order as `state["documents"]`.
    """
    target_node = "check_hallucination"
    documents = state["documents"]
    return [Send(target_node, document) for document in documents]
def map_fix(state: OverallState):
    """Fan the documents that still need fixing out to `fix_hallucination`.

    A document is selected when its `is_hallucinated` flag is true and its
    `processed` flag is false.

    Args:
        state (OverallState): The overall state; its `documents` entry is
            iterated.

    Returns:
        list: One `Send` targeting `fix_hallucination` per selected document,
            in the same order as `state["documents"]`.
    """
    needs_fix = (
        document
        for document in state["documents"]
        if document["is_hallucinated"] and not document["processed"]
    )
    return [Send("fix_hallucination", document) for document in needs_fix]