# sheami/graph.py
# Uploaded via huggingface_hub by vikramvasudevan (commit c4b6dca, verified).
from datetime import datetime
import threading
import time
from bson import ObjectId
import pandas as pd
from langchain_core.prompts import ChatPromptTemplate
import matplotlib.pyplot as plt
from dataclasses import dataclass
from typing import Dict, List, Literal, Optional, TypedDict, Union
import os, json
from pydantic import BaseModel
from langchain_core.messages import HumanMessage, SystemMessage
from langgraph.checkpoint.memory import InMemorySaver
from langgraph.graph.message import StateGraph
from langgraph.graph.state import START, END
from langchain_openai import ChatOpenAI
from dotenv import load_dotenv
from common import get_db
from config import SheamiConfig
import logging
from modules.models import (
HealthReport,
SheamiMilestone,
SheamiState,
StandardizedReport,
TestResultReferenceRange,
)
from pdf_reader import pdf_bytes_to_text_ocr, pdf_to_text_ocr
from pdf_helper import generate_pdf
# Root logger configuration; this module logs pipeline progress at INFO level.
logging.basicConfig()
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
# Load .env (overriding any pre-set env vars) before MODEL is read below.
load_dotenv(override=True)
# Shared chat model used by all graph nodes; low temperature for determinism.
# NOTE(review): os.getenv("MODEL") is None when unset — presumably the
# deployment always defines MODEL; confirm.
llm = ChatOpenAI(model=os.getenv("MODEL"), temperature=0.3)
# -----------------------------
# SCHEMA DEFINITIONS
# -----------------------------
from typing import Optional, List
from pydantic import BaseModel, Field
import re
def safe_filename(name: str) -> str:
    """Sanitize *name* into a filesystem-safe file name.

    Every character outside [A-Za-z0-9_-] (including spaces) becomes an
    underscore, runs of underscores are collapsed, and leading/trailing
    underscores are stripped.
    """
    # A single substitution pass handles spaces too, so the previous
    # separate str.replace(" ", "_") step was redundant.
    name = re.sub(r"[^A-Za-z0-9_\-]", "_", name)
    # Collapse multiple underscores
    name = re.sub(r"_+", "_", name)
    return name.strip("_")
import dateutil.parser
def parse_any_date(date_str):
    """Best-effort conversion of an arbitrary date string to a datetime.

    Returns ``pd.NaT`` for empty/NaN input or anything dateutil cannot
    interpret. Parsing is fuzzy and month-first (``dayfirst=False``).
    """
    if not date_str or pd.isna(date_str):
        return pd.NaT
    try:
        parsed = dateutil.parser.parse(str(date_str), dayfirst=False, fuzzy=True)
    except Exception:
        return pd.NaT
    return parsed
# Prompt for normalizing raw lab test names to canonical, title-case medical
# names. The model must return a bare JSON object mapping each original name
# to its standardized name (no markdown fences).
testname_standardizer_prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "You are a medical assistant. Normalize lab test names."
            "All outputs must use **title case** (e.g., 'Hemoglobin', 'Blood Glucose')."
            "Return ONLY valid JSON where keys are original names and values are standardized names. DO NOT return markdown formatting like backquotes etc.",
        ),
        (
            "human",
            """Normalize the following lab test names to their standard medical equivalents.
Test names: {test_names}
""",
        ),
    ]
)
# chain = prompt → LLM; callers read the raw JSON text from response.content.
testname_standardizer_chain = testname_standardizer_prompt | llm
# -----------------------------
# GRAPH NODES
# -----------------------------
def send_message(state: "SheamiState", msg: str, append: bool = True):
    """Append *msg* to the state's message list, or replace the last entry.

    Args:
        state: Graph state holding a mutable ``messages`` list.
        msg: Message text (may contain HTML markup for the UI).
        append: When True, add a new message; when False, overwrite the most
            recent one (used for in-place progress updates).
    """
    if append or not state["messages"]:
        # Appending also covers replace-mode on an empty list, which
        # previously raised IndexError on state["messages"][-1].
        state["messages"].append(msg)
    else:
        # replace last message
        state["messages"][-1] = msg
async def fn_init_node(state: SheamiState):
    """First graph node: prepare the output dir, seed progress state, start a DB run.

    Resets all progress counters and messages on *state*, lists the uploaded
    files for the UI, and records a new run (names + raw contents) in the DB,
    storing the resulting ``run_id`` back into the state.
    """
    os.makedirs(SheamiConfig.get_output_dir(state["thread_id"]), exist_ok=True)
    if "messages" not in state:
        state["messages"] = []
    send_message(state=state, msg="Initializing ...")
    # Replace the "Initializing ..." entry with the file-list header.
    send_message(state=state, msg="Files received for processing ...", append=False)
    for idx, report in enumerate(state["uploaded_reports"]):
        send_message(
            state=state,
            msg=f"{idx+1}. <span class='highlighted-text'>{report.report_file_name}</span>",
        )
    # Progress bookkeeping consumed by the UI.
    state["standardized_reports"] = []
    state["trends_json"] = {}
    state["pdf_path"] = ""
    state["current_index"] = -1  # incremented to 0 by the first increment node
    state["units_processed"] = 0
    state["units_total"] = 0
    state["process_desc"] = ""
    state["overall_units_processed"] = 0
    state["overall_units_total"] = 6  # 6 steps totally
    state["milestones"] = []
    # Persist a new run so milestones/stats can be tracked against it.
    run_id = await get_db().start_run(
        user_email=state["user_email"],
        patient_id=state["patient_id"],
        source_file_names=[
            report.report_file_name for report in state["uploaded_reports"]
        ],
        source_file_contents=[
            report.report_contents for report in state["uploaded_reports"]
        ],
    )
    state["run_id"] = run_id
    send_message(state=state, msg=f"Initialized run [<span class='highlighted-text'>{run_id}</span>]")
    return state
async def reset_process_desc(state: SheamiState, process_desc: str):
    """Close the current milestone (if any) and open a new one named *process_desc*.

    Marks the previous milestone completed both in state and in the DB, then
    appends a fresh 'started' milestone and resets the per-step unit counters.
    """
    # close previous milestone
    if len(state["milestones"]) > 0:
        state["milestones"][-1].status = "completed"
        state["milestones"][-1].end_time = datetime.now()
        await get_db().add_or_update_milestone(
            run_id=state["run_id"],
            milestone=state["milestones"][-1].step_name,
            status="completed",
            end=True,
        )
    state["process_desc"] = process_desc
    state["milestones"].append(
        SheamiMilestone(
            step_name=state["process_desc"], status="started", start_time=datetime.now()
        )
    )
    # Reset per-step progress counters for the new milestone.
    state["units_processed"] = 0
    state["units_total"] = 0
    await get_db().add_or_update_milestone(
        run_id=state["run_id"], milestone=state["process_desc"]
    )
    return state
async def fn_increment_index_node(state: "SheamiState"):
    """Advance ``current_index`` to the next uploaded report.

    Updates ``process_desc`` with a per-report progress label while reports
    remain; once the index runs past the last report, the description is left
    unchanged (the conditional edge then routes to "done").
    """
    state["current_index"] += 1
    total_reports = len(state["uploaded_reports"])
    idx = state["current_index"]
    # Explicit bounds check replaces the previous bare `except: pass`,
    # which silently swallowed *every* exception (including typos/KeyErrors).
    if idx < total_reports:
        report_file_name = state["uploaded_reports"][idx].report_file_name
        state["process_desc"] = (
            f"Standardizing {idx + 1} of {total_reports} reports"
            f" - {report_file_name} ..."
        )
    return state
async def call_llm(report: HealthReport, ocr: bool):
    """Parse one report's raw text into a StandardizedReport via structured output.

    Args:
        report: Uploaded report carrying the file name and raw text contents.
        ocr: When True, extra instructions tell the model the contents are
            line-oriented OCR output (no merging lines, no hallucination).

    Returns:
        StandardizedReport: the structured parse produced by the LLM.
    """
    llm_structured = llm.with_structured_output(StandardizedReport)
    ocr_instructions = """
The input is pre-parsed structured text from an OCR engine (output.STRING).
- Each line corresponds to one recognized piece of text.
- Do NOT merge unrelated lines together.
- Use each line to reconstruct tests faithfully without skipping.
- Do not hallucinate results or ranges; only use what is explicitly present.
"""
    system_msg = f"""
You are a medical report parser.
Your job is to convert the raw lab report text into the given schema.
Important:
- Do not omit any test mentioned in the report.
- Every test name in the input must appear in the output schema exactly once.
- If a test panel has multiple sub-tests, ensure ALL are included.
- If unsure about a value, still include the test with result = null.
{ocr_instructions if ocr else ""}
- If the report contains a test panel (e.g., 'CUE - COMPLETE URINE ANALYSIS'),
break it down into its component sub-tests (e.g., pH, Specific Gravity, Protein, Glucose, Ketones, etc).
- Each sub-test must appear as an individual entry in the schema, with its own name, result, unit, and reference range.
- Do not summarize a panel as just 'positive/negative'. Capture all sub-results explicitly.
- Preserve the hierarchy but ensure sub-tests are separate objects.
"""
    messages = [
        SystemMessage(content=system_msg),
        HumanMessage(
            content=f"""Original report file name: {report.report_file_name}
--- BEGIN REPORT ---
{report.report_contents}
--- END REPORT ---"""
        ),
    ]
    result: StandardizedReport = await llm_structured.ainvoke(messages)
    return result
async def fn_standardize_current_report_node(state: SheamiState):
    """Standardize the report at ``current_index``, falling back to OCR.

    First attempts a direct LLM parse of the report text; if no lab results
    come back, re-extracts the text with OCR, syncs the stored run contents,
    and retries. The parsed result is appended to ``standardized_reports``
    and also written to a per-report JSON file for debugging.
    """
    idx = state["current_index"]
    report = state["uploaded_reports"][idx]
    logger.info(
        "%s| Standardizing report %s", state["thread_id"], report.report_file_name
    )
    send_message(
        state=state,
        msg=f"Standardizing report: {report.report_file_name}",
        append=False,
    )
    result = await call_llm(report=report, ocr=False)
    if not result.lab_results:
        # Direct parse produced nothing — fall back to OCR extraction.
        send_message(
            state=state,
            msg=f"⛔ Could not extract any data from PDF : {report.report_file_name}. Trying OCR ... might take a while",
            append=False,
        )
        report.report_contents = pdf_to_text_ocr(
            pdf_path=report.report_file_name_with_path
        )
        # logger.info("Parsed text using OCR: %s", report.report_contents)
        # Keep the stored run contents in sync with the OCR-extracted text.
        run_stats_details = await get_db().get_run_stats_by_id(id=state["run_id"])
        run_stats_details["source_file_contents"][state["current_index"]] = (
            report.report_contents.replace("\\n", "\n")
        )
        await get_db().update_run_stats(
            run_id=state["run_id"],
            source_file_contents=run_stats_details["source_file_contents"],
        )
        result = await call_llm(report=report, ocr=True)
        if not result.lab_results:
            send_message(
                state=state,
                msg=f"⛔ OCR couldn't extract : {report.report_file_name}.",
                append=False,
            )
        else:
            send_message(
                state=state,
                msg=f"✅ Extracted <span class='highlighted-text'>{len(result.lab_results)}</span> lab results using OCR for report : <span class='highlighted-text'>{report.report_file_name}</span>.",
                append=False,
            )
    else:
        send_message(
            state=state,
            msg=f"✅ Extracted <span class='highlighted-text'>{len(result.lab_results)}</span> lab results from : <span class='highlighted-text'>{report.report_file_name}</span>.",
            append=False,
        )
    state["standardized_reports"].append(result)
    # Persist the parsed report for inspection/debugging.
    with open(
        os.path.join(
            SheamiConfig.get_output_dir(state["thread_id"]), f"report_{idx}.json"
        ),
        "w",
        encoding="utf-8",
    ) as f:
        f.write(result.model_dump_json(indent=2))
    state["units_processed"] = idx + 1
    return state
# Conditional edge: decide whether another report still needs standardizing.
def fn_is_report_available_to_process(state: SheamiState) -> str:
    """Return "continue" while reports remain at current_index, else "done"."""
    idx = state["current_index"]
    reports = state["uploaded_reports"]
    if idx >= len(reports):
        send_message(state=state, msg="Standardizing reports: finished")
        return "done"
    # Announce the next report; replace the previous line for the first one.
    send_message(
        state=state,
        msg=f"⏳ Initiating report standardization for: <span class='highlighted-text'>{reports[idx].report_file_name}</span>",
        append=idx > 0,
    )
    return "continue"
def get_unique_test_names(state: SheamiState):
    """Collect the distinct test names across all standardized reports.

    Handles both simple LabResult entries (having ``test_name``) and
    composite panels (having ``sub_results`` with their own ``test_name``s).
    """
    seen = set()
    for parsed_report in state["standardized_reports"]:
        for entry in parsed_report.lab_results:
            if hasattr(entry, "test_name"):  # Normal LabResult
                seen.add(entry.test_name)
            elif hasattr(entry, "sub_results"):  # CompositeLabResult
                seen.update(
                    sub.test_name
                    for sub in entry.sub_results
                    if hasattr(sub, "test_name")
                )
    return list(seen)
async def fn_testname_standardizer_node(state: SheamiState):
    """Normalize lab test names across all standardized reports via the LLM.

    Collects unique raw test names, asks the standardizer chain for an
    {original -> standardized} JSON map, and applies it to both top-level
    results and composite sub-results. Falls back to an identity map when
    the LLM response is not valid JSON.
    """
    logger.info("%s| Standardizing Test Names: started", state["thread_id"])
    send_message(state=state, msg="Standardizing Test Names: started", append=False)
    # collect unique names
    unique_names = get_unique_test_names(state)
    # run through LLM
    response = await testname_standardizer_chain.ainvoke({"test_names": unique_names})
    raw_text = response.content
    try:
        normalization_map: Dict[str, str] = json.loads(raw_text)
    except Exception as e:
        # Log via the module logger (previously a bare print) and degrade
        # gracefully: keep every name as-is rather than failing the run.
        logger.warning("Exception in normalization: %s", e)
        normalization_map = {name: name for name in unique_names}  # fallback
    # apply mapping back
    for report in state["standardized_reports"]:
        for comp_result in report.lab_results:
            # normalize composite-level name if present
            if getattr(comp_result, "test_name", None):
                comp_result.test_name = normalization_map.get(
                    comp_result.test_name, comp_result.test_name
                )
            # normalize sub_results
            if getattr(comp_result, "sub_results", None):
                for sub in comp_result.sub_results:
                    if getattr(sub, "test_name", None):
                        sub.test_name = normalization_map.get(
                            sub.test_name, sub.test_name
                        )
    logger.info("%s| Standardizing Test Names: finished", state["thread_id"])
    send_message(
        state=state,
        msg=f"Identified <span class='highlighted-text'>{len(unique_names)}</span> unique tests",
        append=False,
    )
    return state
async def fn_unit_normalizer_node(state: SheamiState):
    """Normalize units for lab test values across all standardized reports.

    Example: 'gms/dL', 'gm%', 'G/DL' → 'g/dL'. Applies to both simple
    results and composite sub-results; unknown units are left unchanged.

    (The docstring previously sat *after* the first statements, making it a
    stray no-op string expression rather than a docstring — moved to the top.)
    """
    logger.info("%s| Standardizing Units : started", state["thread_id"])
    send_message(state=state, msg="Standardizing Units: started", append=False)
    # Keys are the lowercase, space-stripped lookup form of each known unit.
    unit_map = {
        "g/dl": "g/dL",
        "gms/dl": "g/dL",
        "gm%": "g/dL",
        "g/dl.": "g/dL",
    }
    for report in state["standardized_reports"]:
        for lr in report.lab_results:
            # case 1: simple result
            if hasattr(lr, "test_unit") and lr.test_unit:
                normalized = lr.test_unit.lower().replace(" ", "")
                lr.test_unit = unit_map.get(normalized, lr.test_unit)
            # case 2: composite result with sub_results
            if hasattr(lr, "sub_results") and lr.sub_results:
                for sub in lr.sub_results:
                    if sub.test_unit:
                        normalized = sub.test_unit.lower().replace(" ", "")
                        sub.test_unit = unit_map.get(normalized, sub.test_unit)
    logger.info("%s| Standardizing Units : finished", state["thread_id"])
    send_message(state=state, msg="Standardizing Units: finished", append=False)
    return state
async def fn_db_update_node(state: SheamiState):
    """Persist the standardized reports and refresh per-report trend aggregates."""
    ## add parsed reports
    report_id_list = await get_db().add_report_v2(
        patient_id=state["patient_id"],
        reports=state["standardized_reports"],
        run_id=state["run_id"],
    )
    state["report_id_list"] = report_id_list
    logger.info("report_id_list = %s", report_id_list)
    # NOTE(review): add_report_v2 appears to return a comma-separated id
    # string; an empty string would yield [""] here — confirm it can't be empty.
    for report_id in report_id_list.split(","):
        await get_db().aggregate_trends_from_report(state["patient_id"], report_id)
    return state
async def fn_trends_aggregator_node(state: SheamiState):
    """Aggregate test results into time series and persist trends.json.

    Walks every standardized report, flattening composite panels into
    "<panel> · <sub-test>" keys, then loads the per-patient trends from the
    DB into ``state["trends_json"]`` and writes them to <output_dir>/trends.json.
    """
    logger.info("%s| Aggregating Trends : started", state["thread_id"])
    send_message(state=state, msg="Aggregating Trends : started", append=False)
    import re
    import os
    import json
    # Aggregation buckets
    # NOTE(review): these local buckets are filled by add_point below but
    # never read afterwards — the persisted trends come from
    # get_db().get_trends_by_patient(). Confirm whether this local
    # aggregation is intentionally unused.
    trends: dict[str, list[dict]] = {}
    ref_ranges: dict[str, dict] = {}
    def try_parse_numeric(value) -> float | None:
        """
        Return a float only for clean numeric strings like '75', '75.2', or '12%'.
        Avoids picking '0' out of '0-2 /hpf' etc.
        """
        if value is None:
            return None
        s = str(value).strip()
        # pure number
        if re.fullmatch(r"[-+]?\d+(?:\.\d+)?", s):
            try:
                return float(s)
            except ValueError:
                return None
        # percent like "12%"
        m = re.fullmatch(r"([-+]?\d+(?:\.\d+)?)\s*%", s)
        if m:
            try:
                return float(m.group(1))
            except ValueError:
                return None
        return None
    def add_point(
        key: str,
        date: str | None,
        value: str,
        unit: str | None,
        rr: TestResultReferenceRange | None,
        original_report_file_name: str,
    ):
        # Record one observation; the raw string is kept when non-numeric.
        num = try_parse_numeric(value)
        trends.setdefault(key, []).append(
            {
                "date": date or "unknown",
                "value": num if num is not None else value,
                "is_numeric": num is not None,
                "unit": unit or "",
                "orig_report": original_report_file_name,
            }
        )
        # First-seen reference range wins for a given test key.
        if rr and key not in ref_ranges:
            ref_ranges[key] = {"min": rr.min, "max": rr.max}
    total_reports = len(state["standardized_reports"])
    for idx, report in enumerate(state["standardized_reports"]):
        logger.info("%s| Aggregating Trends for report-%d", state["thread_id"], idx)
        send_message(
            state=state,
            msg=f"Aggregating {idx+1}/{total_reports} trends : report-{idx+1}...",
            append=False,
        )
        for item in report.lab_results:
            # Case A: CompositeLabResult (e.g., CUE, LFT, etc.)
            if hasattr(item, "sub_results") and item.sub_results:
                panel = getattr(item, "section_name", "Panel")
                for sub in item.sub_results:
                    key = f"{panel} · {sub.test_name}"
                    add_point(
                        key=key,
                        date=sub.test_date,
                        value=sub.result_value,
                        unit=sub.test_unit,
                        rr=sub.test_reference_range,
                        original_report_file_name=report.original_report_file_name,
                    )
            # Case B: Simple LabResult
            else:
                key = item.test_name
                add_point(
                    key=key,
                    date=item.test_date,
                    value=item.result_value,
                    unit=item.test_unit,
                    rr=item.test_reference_range,
                    original_report_file_name=report.original_report_file_name,
                )
    # Build trends JSON from the DB (aggregated by fn_db_update_node earlier).
    state["trends_json"] = await get_db().get_trends_by_patient(
        patient_id=state["patient_id"],
        fields=["test_name", "trend_data", "test_reference_range", "inferred_range"],
        serializable=True,
    )
    # Persist
    output_dir = SheamiConfig.get_output_dir(state["thread_id"])
    os.makedirs(output_dir, exist_ok=True)
    with open(os.path.join(output_dir, "trends.json"), "w", encoding="utf-8") as f:
        json.dump(state["trends_json"], f, indent=1, ensure_ascii=False)
    logger.info("%s| Aggregating Trends : finished", state["thread_id"])
    send_message(state=state, msg="Aggregating Trends : finished", append=False)
    return state
async def fn_interpreter_node(state: SheamiState):
    """Generate the narrative HTML report, per-test trend plots, and final PDF.

    1. Asks the LLM to turn the aggregated trends into a structured HTML report.
    2. Renders a matplotlib trend plot per numeric test (with reference bands).
    3. Combines the HTML narrative and plots into a PDF via generate_pdf.
    """
    logger.info("%s| Interpreting Trends : started", state["thread_id"])
    send_message(state=state, msg="Interpreting Trends : started", append=False)
    uploaded_reports = await get_db().get_reports_by_patient(
        patient_id=state["patient_id"]
    )
    llm_input = json.dumps(
        {
            "patient_id": state["patient_id"],
            "patient_info": await get_db().get_patient_by_id(
                patient_id=state["patient_id"],
                fields=["name", "dob", "gender"],
                serializable=True,
            ),
            "uploaded_reports": [report["file_name"] for report in uploaded_reports],
            "trends_json": state["trends_json"],
        },
        indent=1,
    )
    report_date = datetime.now().strftime("%d %B %Y")  # e.g., "22 August 2025"
    # 1. LLM narrative.
    # NOTE: the prompt previously contained TWO contradictory "5. Trend
    # summaries" sections (one forbidding "Latest Value 1/2/3" column names,
    # one mandating them). They are merged into a single section using the
    # explicitly requested "Most Recent/Previous/Older Value" names while
    # keeping the detailed layout/CSS requirements.
    messages = [
        SystemMessage(
            content=(
                "Interpret the following medical trends and produce a clean, structured **HTML** report without any markdown formatting like backquotes etc. "
                "The report should have: "
                f"1. A header that says report generated on : {report_date}."
                "2. The names of the reports used to summarize this information."
                "3. Patient summary (patient id, name, age, sex if available)"
                "4. Test window (mention the from and to dates)"
                """
5. Trend summaries
Generate HTML tables with the following structure and formatting rules:
Columns:
- Test Name
- Most Recent Value, Previous Value, Older Value (use a hyphen "–" if a value is missing). Use these exact column names (do not call them latest value 1, 2, or 3).
- Unit
- Reference Range
- Inference (latest value only): ✅ if within normal range, ▲ if above normal (high), ▼ if below normal (low)
- Trend Direction (across last 3 values): ⬆️ if values are rising, ⬇️ if values are falling, ➖ (or ✅) if stable/normal
Formatting requirements:
- The HTML will be shown in a UI (`gr.HTML`) and also rendered to PDF via WeasyPrint.
- The table must ALWAYS fit within 100% of the container width. Do not allow horizontal scrolling, clipping, or overlapping columns.
- Use `table-layout: fixed;` and `<colgroup>` with percentage widths that sum to 100%.
- Allow text wrapping inside cells so narrow columns still display all content.
- Example CSS to embed at the top of the HTML:
<style>
table { width: 100%; border-collapse: collapse; table-layout: fixed; }
col { }
th, td {
font-size: 11px;
padding: 4px 6px;
white-space: normal;
word-break: break-word;
}
</style>
- Example `<colgroup>` (adjust if needed):
<colgroup>
<col style="width:20%"> <!-- Test Name -->
<col style="width:8%"> <!-- Most Recent Value -->
<col style="width:8%"> <!-- Previous Value -->
<col style="width:8%"> <!-- Older Value -->
<col style="width:8%"> <!-- Unit -->
<col style="width:16%"> <!-- Reference Range -->
<col style="width:16%"> <!-- Inference -->
<col style="width:16%"> <!-- Trend Direction -->
</colgroup>
"""
                "6. Clinical insights. \n"
                "\nImportant Rules:\n"
                "- Format tables in proper <table> with <tr>, <th>, <td>. "
                "- Do not include charts, they will be programmatically added."
            )
        ),
        HumanMessage(content=llm_input),
    ]
    response = await llm.ainvoke(messages)
    interpretation_html = response.content  # already HTML
    # 2. Generate a trend plot for each test parameter.
    plots_dir = os.path.join(SheamiConfig.get_output_dir(state["thread_id"]), "plots")
    os.makedirs(plots_dir, exist_ok=True)
    plot_files = []
    for param in sorted(state["trends_json"], key=lambda x: x["test_name"]):
        test_name = param["test_name"]
        values = param["trend_data"]
        x = [parse_any_date(v["date"]) for v in values]
        x = pd.to_datetime(x, errors="coerce")
        try:
            y = [float(v["value"]) for v in values]
        except (TypeError, ValueError):
            # Skip non-numeric series. TypeError covers None values, which the
            # previous ValueError-only handler let crash the whole node.
            continue
        df_plot = pd.DataFrame({"x": x, "y": y}).dropna(subset=["x"]).sort_values("x")
        x, y = df_plot["x"].to_numpy(), df_plot["y"].to_numpy()
        plt.figure(figsize=(6, 4))
        plt.plot(x, y, marker="o", linestyle="-", label="Observed values")
        ref = param.get("test_reference_range")
        if ref:
            ymin, ymax = ref.get("min"), ref.get("max")
            if ymin is not None and ymax is not None:
                plt.axhspan(
                    ymin, ymax, color="green", alpha=0.2, label="Reference range"
                )
            elif ymax is not None:
                plt.axhline(
                    y=ymax, color="red", linestyle="--", label="Upper threshold"
                )
            elif ymin is not None:
                plt.axhline(
                    y=ymin, color="blue", linestyle="--", label="Lower threshold"
                )
        plt.xlabel("Date")
        plt.ylabel(values[0].get("unit", "") if values else "")
        plt.grid(True)
        plt.xticks(rotation=45)
        plt.legend()
        plt.tight_layout()
        # safe_filename already replaces spaces, so the extra .replace was redundant.
        filename = f"{safe_filename(test_name)}_trend.png"
        filepath = os.path.join(plots_dir, filename)
        plt.savefig(filepath)
        plt.close()
        plot_files.append((test_name, filepath))
    # 3. Build PDF from the narrative HTML plus the generated plots.
    pdf_path = os.path.join(
        SheamiConfig.get_output_dir(state["thread_id"]), "final_report.pdf"
    )
    generate_pdf(
        pdf_path=pdf_path,
        interpretation_html=interpretation_html,
        plot_files=plot_files,
    )
    # Save results back onto the state for the cleanup node / UI.
    state["pdf_path"] = pdf_path
    state["interpretation_html"] = interpretation_html
    logger.info("%s| Interpreting Trends : finished", state["thread_id"])
    send_message(state=state, msg="Interpreting Trends : finished", append=False)
    return state
async def fn_final_cleanup_node(state: SheamiState):
    """Finalize the run: close milestones, store the final PDF, schedule cleanup.

    Schedules deferred deletion of the thread's output directory (the delay
    lets the UI still serve the PDF), marks the last milestone and the run
    completed in the DB, then saves the final PDF plus its HTML summary as
    the patient's final report.
    """
    pdf_path = state["pdf_path"]
    schedule_cleanup(file_path=SheamiConfig.get_output_dir(state["thread_id"]))
    state["milestones"][-1].status = "completed"
    state["milestones"][-1].end_time = datetime.now()
    await get_db().add_or_update_milestone(
        run_id=state["run_id"],
        milestone=state["process_desc"],
        status="completed",
        end=True,
    )
    await get_db().update_run_stats(run_id=state["run_id"], status="completed")
    # add final report
    # Save PDF along with metadata
    with open(pdf_path, "rb") as f:
        pdf_bytes = f.read()
    final_report_id = await get_db().add_final_report_v2(
        patient_id=state["patient_id"],
        summary=state["interpretation_html"],
        pdf_bytes=pdf_bytes,
        file_name=f"health_trends_report_{state["patient_id"]}.pdf",
    )
    logger.info("final_report_id = %s", final_report_id)
def schedule_cleanup(file_path, delay=300):  # 300 sec = 5 min
    """Delete *file_path* (file or directory) after *delay* seconds.

    The deletion runs on a daemon thread so it never blocks shutdown;
    failures are reported to stdout rather than raised.
    """
    def _remove_later():
        time.sleep(delay)
        if not os.path.exists(file_path):
            return
        try:
            if os.path.isdir(file_path):
                import shutil
                shutil.rmtree(file_path)
            else:
                os.remove(file_path)
        except Exception as e:
            print(f"Cleanup failed for {file_path}: {e}")
        else:
            print(f"Cleaned up: {file_path}")
    threading.Thread(target=_remove_later, daemon=True).start()
# -----------------------------
# GRAPH CREATION
# -----------------------------
async def fn_standardizer_node_notifier(state: SheamiState):
    """Open the 'Standardizing reports' milestone and bump overall progress."""
    state = await reset_process_desc(state, process_desc="Standardizing reports ...")
    state["units_total"] = len(state["uploaded_reports"])
    state["overall_units_processed"] += 1
    send_message(
        state=state, msg="Standardizing reports now ... this might take a while ..."
    )
    return state
async def fn_testname_standardizer_node_notifier(state: SheamiState):
    """Open the 'Standardizing test names' milestone and bump overall progress."""
    state = await reset_process_desc(state, process_desc="Standardizing test names ...")
    state["overall_units_processed"] += 1
    send_message(state=state, msg="Standardizing test names now ...")
    return state
async def fn_unit_normalizer_node_notifier(state: SheamiState):
    """Open the 'Standardizing units' milestone and bump overall progress."""
    state = await reset_process_desc(state, process_desc="Standardizing units ...")
    state["overall_units_processed"] += 1
    send_message(state=state, msg="Standardizing measurement units now ...")
    return state
async def fn_trends_aggregator_node_notifier(state: SheamiState):
    """Open the 'Aggregating trends' milestone and bump overall progress."""
    state = await reset_process_desc(state, process_desc="Aggregating trends ...")
    state["overall_units_processed"] += 1
    send_message(state=state, msg="Aggregating trends now ...")
    return state
async def fn_interpreter_node_notifier(state: SheamiState):
    """Open the 'Plotting trends' milestone and bump overall progress."""
    state = await reset_process_desc(state, process_desc="Plotting trends ...")
    state["overall_units_processed"] += 1
    send_message(state=state, msg="Interpreting and plotting trends now ...")
    return state
def create_graph(user_email: str, patient_id: str, thread_id: str):
    """Build and compile the report-processing LangGraph state machine.

    Pipeline: init → (loop: standardize each uploaded report) → test-name
    standardization → unit normalization → DB update → trend aggregation →
    interpretation/plotting → final cleanup. Each processing stage is
    preceded by a *notifier* node that opens its progress milestone.

    Returns the compiled graph with an in-memory checkpointer.
    """
    logger.info(
        "%s| Creating Graph : started for user:%s | patient:%s",
        thread_id,
        user_email,
        patient_id,
    )
    memory = InMemorySaver()
    workflow = StateGraph(SheamiState)
    workflow.add_node("init", fn_init_node)
    workflow.add_node("standardize_current_report", fn_standardize_current_report_node)
    workflow.add_node("increment_index", fn_increment_index_node)
    workflow.add_node("testname_standardizer", fn_testname_standardizer_node)
    workflow.add_node("unit_normalizer", fn_unit_normalizer_node)
    workflow.add_node("db_update_node", fn_db_update_node)
    workflow.add_node("trends", fn_trends_aggregator_node)
    workflow.add_node("interpreter", fn_interpreter_node)
    workflow.add_node("standardizer_notifier", fn_standardizer_node_notifier)
    workflow.add_node(
        "testname_standardizer_notifier", fn_testname_standardizer_node_notifier
    )
    workflow.add_node("unit_normalizer_notifier", fn_unit_normalizer_node_notifier)
    workflow.add_node("trends_notifier", fn_trends_aggregator_node_notifier)
    workflow.add_node("interpreter_notifier", fn_interpreter_node_notifier)
    workflow.add_node("final_cleanup_node", fn_final_cleanup_node)
    workflow.add_edge(START, "init")
    workflow.add_edge("init", "standardizer_notifier")
    workflow.add_edge("standardizer_notifier", "increment_index")
    # loop back if continue
    workflow.add_conditional_edges(
        "increment_index",
        fn_is_report_available_to_process,
        {
            "continue": "standardize_current_report",
            "done": "testname_standardizer_notifier",
        },
    )
    workflow.add_edge("standardize_current_report", "increment_index")
    workflow.add_edge("testname_standardizer_notifier", "testname_standardizer")
    workflow.add_edge("testname_standardizer", "unit_normalizer_notifier")
    workflow.add_edge("unit_normalizer_notifier", "unit_normalizer")
    workflow.add_edge("unit_normalizer", "db_update_node")
    workflow.add_edge("db_update_node", "trends_notifier")
    workflow.add_edge("trends_notifier", "trends")
    workflow.add_edge("trends", "interpreter_notifier")
    workflow.add_edge("interpreter_notifier", "interpreter")
    workflow.add_edge("interpreter", "final_cleanup_node")
    workflow.add_edge("final_cleanup_node", END)
    logger.info("%s| Creating Graph : finished", thread_id)
    return workflow.compile(checkpointer=memory)