| | from datetime import datetime |
| | import threading |
| | import time |
| | from bson import ObjectId |
| | import pandas as pd |
| | from langchain_core.prompts import ChatPromptTemplate |
| | import matplotlib.pyplot as plt |
| | from dataclasses import dataclass |
| | from typing import Dict, List, Literal, Optional, TypedDict, Union |
| | import os, json |
| | from pydantic import BaseModel |
| | from langchain_core.messages import HumanMessage, SystemMessage |
| | from langgraph.checkpoint.memory import InMemorySaver |
| | from langgraph.graph.message import StateGraph |
| | from langgraph.graph.state import START, END |
| | from langchain_openai import ChatOpenAI |
| | from dotenv import load_dotenv |
| | from common import get_db |
| | from config import SheamiConfig |
| | import logging |
| |
|
| | from modules.models import ( |
| | HealthReport, |
| | SheamiMilestone, |
| | SheamiState, |
| | StandardizedReport, |
| | TestResultReferenceRange, |
| | ) |
| | from pdf_reader import pdf_bytes_to_text_ocr, pdf_to_text_ocr |
| | from pdf_helper import generate_pdf |
| |
|
# Install default logging handlers and create this module's logger at INFO level.
logging.basicConfig()
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# Load variables from a .env file before the model is configured.
# NOTE(review): override=True presumably lets .env values win over variables
# already set in the process environment - confirm against python-dotenv docs.
load_dotenv(override=True)
# Module-wide chat model shared by every node below; the model name is read
# from the MODEL environment variable.
llm = ChatOpenAI(model=os.getenv("MODEL"), temperature=0.3)
| |
|
| | |
| | |
| | |
| |
|
| |
|
| | from typing import Optional, List |
| | from pydantic import BaseModel, Field |
| | import re |
| |
|
| |
|
def safe_filename(name: str) -> str:
    """Turn an arbitrary label into a string usable as a file name.

    Spaces and every character other than ASCII letters, digits, underscore
    and hyphen become underscores; runs of underscores are collapsed to one
    and underscores at either end are trimmed. The result may be empty when
    the input contains no allowed characters.
    """
    underscored = name.replace(" ", "_")
    sanitized = re.sub(r"[^A-Za-z0-9_\-]", "_", underscored)
    collapsed = re.sub(r"_+", "_", sanitized)
    return collapsed.strip("_")
| |
|
| |
|
| | import dateutil.parser |
| |
|
| |
|
def parse_any_date(date_str):
    """Best-effort conversion of a loosely formatted date value.

    Returns ``pd.NaT`` for falsy or NA input and for anything the parser
    rejects; otherwise returns whatever ``dateutil.parser.parse`` produces
    for the stringified input (month-first, fuzzy matching enabled).
    """
    is_missing = (not date_str) or pd.isna(date_str)
    if is_missing:
        return pd.NaT

    text = str(date_str)
    try:
        parsed = dateutil.parser.parse(text, dayfirst=False, fuzzy=True)
    except Exception:
        # Any parsing failure is treated as "no date".
        parsed = pd.NaT
    return parsed
| |
|
| |
|
| | |
# Prompt asking the model to map raw lab test names to standardized,
# title-cased names and to answer with a bare JSON object.
# The system text is assembled from adjacent string literals, so each fragment
# ends with a space to keep the sentences from running together.
testname_standardizer_prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "You are a medical assistant. Normalize lab test names. "
            "All outputs must use **title case** (e.g., 'Hemoglobin', 'Blood Glucose'). "
            "Return ONLY valid JSON where keys are original names and values are standardized names. DO NOT return markdown formatting like backquotes etc.",
        ),
        (
            "human",
            """Normalize the following lab test names to their standard medical equivalents.
Test names: {test_names}
""",
        ),
    ]
)

# Runnable pipeline: fill the prompt with `test_names`, then call the shared model.
testname_standardizer_chain = testname_standardizer_prompt | llm
| |
|
| | |
| | |
| | |
| |
|
| |
|
def send_message(state: SheamiState, msg: str, append: bool = True):
    """Record a progress message in ``state["messages"]``.

    Args:
        state: Workflow state; a ``"messages"`` list is created if missing.
        msg: Text to record.
        append: When True, add ``msg`` as a new entry. When False, overwrite
            the most recent entry instead.

    With ``append=False`` and no previous message there is nothing to
    overwrite, so the message is appended rather than raising ``IndexError``.
    """
    messages = state.setdefault("messages", [])
    if append or not messages:
        messages.append(msg)
    else:
        messages[-1] = msg
| |
|
| |
|
async def fn_init_node(state: SheamiState):
    """Prepare the workflow state and register a new run in the database.

    Creates the per-thread output directory, announces the uploaded files,
    resets all progress/result fields to their starting values and stores the
    run id returned by ``start_run`` under ``state["run_id"]``.
    """
    output_dir = SheamiConfig.get_output_dir(state["thread_id"])
    os.makedirs(output_dir, exist_ok=True)

    state.setdefault("messages", [])
    send_message(state=state, msg="Initializing ...")
    # Overwrites the "Initializing ..." entry that was just added.
    send_message(state=state, msg="Files received for processing ...", append=False)
    for position, uploaded in enumerate(state["uploaded_reports"], start=1):
        send_message(
            state=state,
            msg=f"{position}. <span class='highlighted-text'>{uploaded.report_file_name}</span>",
        )

    # Starting values for everything the later nodes fill in.
    state.update(
        {
            "standardized_reports": [],
            "trends_json": {},
            "pdf_path": "",
            "current_index": -1,
            "units_processed": 0,
            "units_total": 0,
            "process_desc": "",
            "overall_units_processed": 0,
            "overall_units_total": 6,
            "milestones": [],
        }
    )

    file_names = [uploaded.report_file_name for uploaded in state["uploaded_reports"]]
    file_contents = [uploaded.report_contents for uploaded in state["uploaded_reports"]]
    run_id = await get_db().start_run(
        user_email=state["user_email"],
        patient_id=state["patient_id"],
        source_file_names=file_names,
        source_file_contents=file_contents,
    )
    state["run_id"] = run_id
    send_message(
        state=state,
        msg=f"Initialized run [<span class='highlighted-text'>{run_id}</span>]",
    )
    return state
| |
|
| |
|
async def reset_process_desc(state: SheamiState, process_desc: str):
    """Close the current milestone (if any) and open a new one.

    The last milestone in ``state["milestones"]`` is marked completed both in
    memory and in the database, then a milestone named ``process_desc`` is
    started, the per-step unit counters are zeroed and the new milestone is
    recorded in the database. Returns the same ``state`` object.
    """
    milestones = state["milestones"]
    if milestones:
        previous = milestones[-1]
        previous.status = "completed"
        previous.end_time = datetime.now()
        await get_db().add_or_update_milestone(
            run_id=state["run_id"],
            milestone=previous.step_name,
            status="completed",
            end=True,
        )

    state["process_desc"] = process_desc
    started = SheamiMilestone(
        step_name=state["process_desc"],
        status="started",
        start_time=datetime.now(),
    )
    milestones.append(started)

    state["units_processed"] = 0
    state["units_total"] = 0
    await get_db().add_or_update_milestone(
        run_id=state["run_id"],
        milestone=state["process_desc"],
    )
    return state
| |
|
| |
|
async def fn_increment_index_node(state: SheamiState):
    """Advance ``state["current_index"]`` to the next uploaded report.

    When the new index points at an existing report, ``state["process_desc"]``
    is updated with a "Standardizing i of n" description. Once the index runs
    past the last report (the loop's normal exit), the description is left
    unchanged.

    Returns:
        The same ``state`` object.
    """
    state["current_index"] += 1
    index = state["current_index"]
    reports = state["uploaded_reports"]
    total_reports = len(reports)

    # Explicit bounds check instead of a bare ``except: pass``, so genuine
    # errors (e.g. a malformed report object) are no longer hidden.
    if 0 <= index < total_reports:
        report_file_name = reports[index].report_file_name
        # Values are bound to locals first: reusing the enclosing quote
        # character inside an f-string is only valid on Python 3.12+.
        state["process_desc"] = (
            f"Standardizing {index + 1} of {total_reports} reports - {report_file_name} ..."
        )
    return state
| |
|
| |
|
async def call_llm(report: HealthReport, ocr: bool):
    """Ask the model to convert one raw report into a ``StandardizedReport``.

    Args:
        report: Report whose ``report_file_name`` and ``report_contents`` are
            placed in the human message.
        ocr: When True, extra instructions for OCR-produced text are added to
            the system prompt.

    Returns:
        The structured-output result of the model call.
    """
    ocr_instructions = """
The input is pre-parsed structured text from an OCR engine (output.STRING).
- Each line corresponds to one recognized piece of text.
- Do NOT merge unrelated lines together.
- Use each line to reconstruct tests faithfully without skipping.
- Do not hallucinate results or ranges; only use what is explicitly present.
"""

    system_msg = f"""
You are a medical report parser.
Your job is to convert the raw lab report text into the given schema.

Important:
- Do not omit any test mentioned in the report.
- Every test name in the input must appear in the output schema exactly once.
- If a test panel has multiple sub-tests, ensure ALL are included.
- If unsure about a value, still include the test with result = null.
{ocr_instructions if ocr else ""}
- If the report contains a test panel (e.g., 'CUE - COMPLETE URINE ANALYSIS'),
break it down into its component sub-tests (e.g., pH, Specific Gravity, Protein, Glucose, Ketones, etc).
- Each sub-test must appear as an individual entry in the schema, with its own name, result, unit, and reference range.
- Do not summarize a panel as just 'positive/negative'. Capture all sub-results explicitly.
- Preserve the hierarchy but ensure sub-tests are separate objects.
"""

    report_text = f"""Original report file name: {report.report_file_name}
--- BEGIN REPORT ---
{report.report_contents}
--- END REPORT ---"""

    conversation = [
        SystemMessage(content=system_msg),
        HumanMessage(content=report_text),
    ]
    # Bind the output schema so the reply is parsed into StandardizedReport.
    structured_llm = llm.with_structured_output(StandardizedReport)
    parsed: StandardizedReport = await structured_llm.ainvoke(conversation)
    return parsed
| |
|
| |
|
async def fn_standardize_current_report_node(state: SheamiState):
    """Standardize the report at ``state["current_index"]``.

    The report text is sent to the model as-is first. If no lab results come
    back, the PDF is re-read through OCR, the stored source contents of the
    run are updated with the OCR text, and the model is called once more with
    OCR instructions. The (possibly empty) result is appended to
    ``state["standardized_reports"]`` and dumped to ``report_<idx>.json`` in
    the thread's output directory.
    """
    idx = state["current_index"]
    report = state["uploaded_reports"][idx]
    file_name = report.report_file_name

    logger.info("%s| Standardizing report %s", state["thread_id"], file_name)
    send_message(
        state=state,
        msg=f"Standardizing report: {file_name}",
        append=False,
    )

    result = await call_llm(report=report, ocr=False)
    if result.lab_results:
        send_message(
            state=state,
            msg=f"✅ Extracted <span class='highlighted-text'>{len(result.lab_results)}</span> lab results from : <span class='highlighted-text'>{file_name}</span>.",
            append=False,
        )
    else:
        send_message(
            state=state,
            msg=f"⛔ Could not extract any data from PDF : {file_name}. Trying OCR ... might take a while",
            append=False,
        )
        report.report_contents = pdf_to_text_ocr(
            pdf_path=report.report_file_name_with_path
        )

        # Replace the stored contents for this file with the OCR text,
        # turning literal backslash-n sequences into real newlines.
        run_stats_details = await get_db().get_run_stats_by_id(id=state["run_id"])
        stored_contents = run_stats_details["source_file_contents"]
        stored_contents[idx] = report.report_contents.replace("\\n", "\n")
        await get_db().update_run_stats(
            run_id=state["run_id"],
            source_file_contents=stored_contents,
        )

        result = await call_llm(report=report, ocr=True)
        if result.lab_results:
            outcome = f"✅ Extracted <span class='highlighted-text'>{len(result.lab_results)}</span> lab results using OCR for report : <span class='highlighted-text'>{file_name}</span>."
        else:
            outcome = f"⛔ OCR couldn't extract : {file_name}."
        send_message(state=state, msg=outcome, append=False)

    state["standardized_reports"].append(result)

    dump_path = os.path.join(
        SheamiConfig.get_output_dir(state["thread_id"]), f"report_{idx}.json"
    )
    with open(dump_path, "w", encoding="utf-8") as dump_file:
        dump_file.write(result.model_dump_json(indent=2))

    state["units_processed"] = idx + 1
    return state
| |
|
| |
|
| | |
def fn_is_report_available_to_process(state: SheamiState) -> str:
    """Routing function for the standardization loop.

    Returns ``"continue"`` while ``state["current_index"]`` still points at an
    uploaded report and ``"done"`` once it has moved past the last one. A
    progress message is recorded in both cases.
    """
    index = state["current_index"]
    reports = state["uploaded_reports"]

    if index >= len(reports):
        send_message(state=state, msg="Standardizing reports: finished")
        return "done"

    current = reports[index]
    # For the first report the previous message is overwritten; later
    # reports get a new entry.
    send_message(
        state=state,
        msg=f"⏳ Initiating report standardization for: <span class='highlighted-text'>{current.report_file_name}</span>",
        append=index > 0,
    )
    return "continue"
| |
|
| |
|
def get_unique_test_names(state: SheamiState):
    """Collect the distinct test names found in all standardized reports.

    Both a result's own ``test_name`` and the ``test_name`` of each entry in
    its ``sub_results`` are collected, mirroring how
    ``fn_testname_standardizer_node`` later rewrites them. Missing or empty
    names are skipped.

    Returns:
        A list of unique names, sorted so the prompt built from it is
        deterministic.
    """
    test_names = set()

    for report in state["standardized_reports"]:
        for result in report.lab_results or []:
            name = getattr(result, "test_name", None)
            if name:
                test_names.add(name)
            # Not an ``elif``: a result may carry a name attribute (possibly
            # None) and sub-results at the same time.
            for sub in getattr(result, "sub_results", None) or []:
                sub_name = getattr(sub, "test_name", None)
                if sub_name:
                    test_names.add(sub_name)

    return sorted(test_names, key=str)
| |
|
| |
|
def _parse_normalization_map(raw_text) -> Dict[str, str]:
    """Parse the model's reply into an ``{original: standardized}`` mapping.

    Tolerates a reply wrapped in a Markdown code fence. Returns an empty
    mapping (meaning "keep every name as it is") when the reply is not valid
    JSON or is not a JSON object.
    """
    text = str(raw_text).strip()
    fenced = re.fullmatch(r"```(?:json)?\s*(.*?)\s*```", text, flags=re.DOTALL | re.IGNORECASE)
    if fenced:
        text = fenced.group(1)

    try:
        parsed = json.loads(text)
    except (TypeError, ValueError) as exc:
        logger.warning("Could not parse test-name normalization response: %s", exc)
        return {}

    if not isinstance(parsed, dict):
        logger.warning(
            "Test-name normalization response is not a JSON object: %s",
            type(parsed).__name__,
        )
        return {}

    # Keep only usable string-to-string entries.
    return {
        original: standardized
        for original, standardized in parsed.items()
        if isinstance(original, str)
        and isinstance(standardized, str)
        and standardized.strip()
    }


async def fn_testname_standardizer_node(state: SheamiState):
    """Rewrite test names in all standardized reports to normalized names.

    The unique names are sent to ``testname_standardizer_chain`` and the
    returned mapping is applied to every result and sub-result. Names missing
    from the mapping, and all names when the reply cannot be parsed, are left
    unchanged. The model is not called when there are no names to normalize.

    Returns:
        The same ``state`` object.
    """
    logger.info("%s| Standardizing Test Names: started", state["thread_id"])
    send_message(state=state, msg="Standardizing Test Names: started", append=False)

    unique_names = get_unique_test_names(state)

    normalization_map: Dict[str, str] = {}
    if unique_names:
        response = await testname_standardizer_chain.ainvoke(
            {"test_names": unique_names}
        )
        normalization_map = _parse_normalization_map(response.content)

    for report in state["standardized_reports"]:
        for comp_result in report.lab_results:
            if getattr(comp_result, "test_name", None):
                comp_result.test_name = normalization_map.get(
                    comp_result.test_name, comp_result.test_name
                )
            for sub in getattr(comp_result, "sub_results", None) or []:
                if getattr(sub, "test_name", None):
                    sub.test_name = normalization_map.get(
                        sub.test_name, sub.test_name
                    )

    logger.info("%s| Standardizing Test Names: finished", state["thread_id"])
    send_message(
        state=state,
        msg=f"Identified <span class='highlighted-text'>{len(unique_names)}</span> unique tests",
        append=False,
    )
    return state
| |
|
| |
|
async def fn_unit_normalizer_node(state: SheamiState):
    """Normalize units for lab test values across all standardized reports.

    Example: 'gms/dL', 'gm%', 'G/DL' → 'g/dL'

    Units are matched case-insensitively with spaces removed; units that are
    not in the lookup table are left exactly as they were.

    Returns:
        The same ``state`` object.
    """
    logger.info("%s| Standardizing Units : started", state["thread_id"])
    send_message(state=state, msg="Standardizing Units: started", append=False)

    unit_map = {
        "g/dl": "g/dL",
        "gms/dl": "g/dL",
        "gm%": "g/dL",
        "g/dl.": "g/dL",
    }

    def canonical(unit):
        """Return the canonical spelling of ``unit``, or ``unit`` if unknown."""
        return unit_map.get(unit.lower().replace(" ", ""), unit)

    for report in state["standardized_reports"]:
        for lr in report.lab_results:
            if getattr(lr, "test_unit", None):
                lr.test_unit = canonical(lr.test_unit)

            # Sub-results are handled the same way as top-level results.
            for sub in getattr(lr, "sub_results", None) or []:
                if getattr(sub, "test_unit", None):
                    sub.test_unit = canonical(sub.test_unit)

    logger.info("%s| Standardizing Units : finished", state["thread_id"])
    send_message(state=state, msg="Standardizing Units: finished", append=False)
    return state
| |
|
| |
|
async def fn_db_update_node(state: SheamiState):
    """Persist the standardized reports and refresh stored trends.

    The value returned by ``add_report_v2`` (handled here as a comma-separated
    string of report ids) is stored under ``state["report_id_list"]``, and
    ``aggregate_trends_from_report`` is called once per id.

    Returns:
        The same ``state`` object.
    """
    report_id_list = await get_db().add_report_v2(
        patient_id=state["patient_id"],
        reports=state["standardized_reports"],
        run_id=state["run_id"],
    )
    state["report_id_list"] = report_id_list

    logger.info("report_id_list = %s", report_id_list)
    for report_id in (report_id_list or "").split(","):
        report_id = report_id.strip()
        # "".split(",") yields [""]; never aggregate for a blank id.
        if not report_id:
            continue
        await get_db().aggregate_trends_from_report(state["patient_id"], report_id)

    return state
| |
|
| |
|
async def fn_trends_aggregator_node(state: SheamiState):
    """Load the patient's aggregated trends and save them to ``trends.json``.

    Trends are read from the database via ``get_trends_by_patient`` and stored
    under ``state["trends_json"]``; a copy is written to the thread's output
    directory.

    The former in-memory aggregation over ``state["standardized_reports"]``
    was removed: its result was never read, because ``state["trends_json"]``
    is taken from the database. The function-scope re-imports of ``re``,
    ``os`` and ``json`` went with it.

    Returns:
        The same ``state`` object.
    """
    logger.info("%s| Aggregating Trends : started", state["thread_id"])
    send_message(state=state, msg="Aggregating Trends : started", append=False)

    state["trends_json"] = await get_db().get_trends_by_patient(
        patient_id=state["patient_id"],
        fields=["test_name", "trend_data", "test_reference_range", "inferred_range"],
        serializable=True,
    )

    output_dir = SheamiConfig.get_output_dir(state["thread_id"])
    os.makedirs(output_dir, exist_ok=True)
    with open(os.path.join(output_dir, "trends.json"), "w", encoding="utf-8") as f:
        json.dump(state["trends_json"], f, indent=1, ensure_ascii=False)

    logger.info("%s| Aggregating Trends : finished", state["thread_id"])
    send_message(state=state, msg="Aggregating Trends : finished", append=False)
    return state
| |
|
| |
|
def _build_interpreter_system_prompt(report_date: str) -> str:
    """Return the system prompt for the HTML trend-interpretation report.

    The numbered sections are newline-separated, and the trend-table section
    appears once, using the "Most Recent / Previous / Older Value" column
    names throughout so the instructions do not contradict each other.
    """
    header = (
        "Interpret the following medical trends and produce a clean, structured **HTML** report without any markdown formatting like backquotes etc. "
        "The report should have:\n"
        f"1. A header that says report generated on : {report_date}.\n"
        "2. The names of the reports used to summarize this information.\n"
        "3. Patient summary (patient id, name, age, sex if available)\n"
        "4. Test window (mention the from and to dates)\n"
    )
    # Plain (non-f) string: the CSS braces below must stay literal.
    trend_section = """5. Trend summaries
Generate HTML tables with the following structure and formatting rules:

Columns:
- Test Name
- Most Recent Value, Previous Value, Older Value (use a hyphen "–" if a value is missing). Use these exact column names (do not call them latest value 1,2,or 3)
- Unit
- Reference Range
- Inference (latest value only): ✅ if within normal range, ▲ if above normal (high), ▼ if below normal (low)
- Trend Direction (across last 3 values): ⬆️ if values are rising, ⬇️ if values are falling, ➖ (or ✅) if stable/normal

Formatting requirements:
- The HTML will be shown in a UI (`gr.HTML`) and also rendered to PDF via WeasyPrint.
- The table must ALWAYS fit within 100% of the container width. Do not allow horizontal scrolling, clipping, or overlapping columns.
- Use `table-layout: fixed;` and `<colgroup>` with percentage widths that sum to 100%.
- Allow text wrapping inside cells so narrow columns still display all content.
- Example CSS to embed at the top of the HTML:

<style>
table { width: 100%; border-collapse: collapse; table-layout: fixed; }
col { }
th, td {
font-size: 11px;
padding: 4px 6px;
white-space: normal;
word-break: break-word;
}
</style>

- Example `<colgroup>` (adjust if needed):
<colgroup>
<col style="width:20%"> <!-- Test Name -->
<col style="width:8%"> <!-- Most Recent Value -->
<col style="width:8%"> <!-- Previous Value -->
<col style="width:8%"> <!-- Older Value -->
<col style="width:8%"> <!-- Unit -->
<col style="width:16%"> <!-- Reference Range -->
<col style="width:16%"> <!-- Inference -->
<col style="width:16%"> <!-- Trend Direction -->
</colgroup>
"""
    closing = (
        "6. Clinical insights.\n"
        "\nImportant Rules:\n"
        "- Format tables in proper <table> with <tr>, <th>, <td>.\n"
        "- Do not include charts, they will be programmatically added.\n"
    )
    return header + trend_section + closing


def _plot_trend(param: dict, plots_dir: str) -> Optional[str]:
    """Save a line chart for one test's trend and return the PNG path.

    Returns ``None`` (nothing is written) when any value in the series is not
    numeric or when no data point has a parseable date.
    """
    test_name = param["test_name"]
    values = param["trend_data"]

    try:
        y = [float(v["value"]) for v in values]
    except (TypeError, ValueError):
        # TypeError covers None values; ValueError covers text such as "Negative".
        return None

    x = pd.to_datetime([parse_any_date(v["date"]) for v in values], errors="coerce")
    df_plot = pd.DataFrame({"x": x, "y": y}).dropna(subset=["x"]).sort_values("x")
    if df_plot.empty:
        return None

    filepath = os.path.join(plots_dir, f"{safe_filename(test_name)}_trend.png")

    plt.figure(figsize=(6, 4))
    try:
        plt.plot(
            df_plot["x"].to_numpy(),
            df_plot["y"].to_numpy(),
            marker="o",
            linestyle="-",
            label="Observed values",
        )

        ref = param.get("test_reference_range")
        if ref:
            ymin, ymax = ref.get("min"), ref.get("max")
            if ymin is not None and ymax is not None:
                plt.axhspan(
                    ymin, ymax, color="green", alpha=0.2, label="Reference range"
                )
            elif ymax is not None:
                plt.axhline(
                    y=ymax, color="red", linestyle="--", label="Upper threshold"
                )
            elif ymin is not None:
                plt.axhline(
                    y=ymin, color="blue", linestyle="--", label="Lower threshold"
                )

        plt.xlabel("Date")
        plt.ylabel(values[0].get("unit", "") or "")
        plt.grid(True)
        plt.xticks(rotation=45)
        plt.legend()
        plt.tight_layout()
        plt.savefig(filepath)
    finally:
        # Always release the figure, even if drawing or saving fails.
        plt.close()
    return filepath


async def fn_interpreter_node(state: SheamiState):
    """Produce the HTML interpretation, the trend plots and the final PDF.

    The patient's info, report file names and ``state["trends_json"]`` are
    sent to the model, which returns an HTML report. One chart is drawn per
    test with fully numeric values, and ``generate_pdf`` combines the HTML
    and charts into ``final_report.pdf`` in the thread's output directory.
    Sets ``state["pdf_path"]`` and ``state["interpretation_html"]``.

    Returns:
        The same ``state`` object.
    """
    logger.info("%s| Interpreting Trends : started", state["thread_id"])
    send_message(state=state, msg="Interpreting Trends : started", append=False)

    uploaded_reports = await get_db().get_reports_by_patient(
        patient_id=state["patient_id"]
    )
    llm_input = json.dumps(
        {
            "patient_id": state["patient_id"],
            "patient_info": await get_db().get_patient_by_id(
                patient_id=state["patient_id"],
                fields=["name", "dob", "gender"],
                serializable=True,
            ),
            "uploaded_reports": [report["file_name"] for report in uploaded_reports],
            "trends_json": state["trends_json"],
        },
        indent=1,
    )

    report_date = datetime.now().strftime("%d %B %Y")
    messages = [
        SystemMessage(content=_build_interpreter_system_prompt(report_date)),
        HumanMessage(content=llm_input),
    ]
    response = await llm.ainvoke(messages)
    interpretation_html = response.content

    output_dir = SheamiConfig.get_output_dir(state["thread_id"])
    plots_dir = os.path.join(output_dir, "plots")
    os.makedirs(plots_dir, exist_ok=True)

    plot_files = []
    for param in sorted(state["trends_json"], key=lambda x: x["test_name"]):
        filepath = _plot_trend(param, plots_dir)
        if filepath is not None:
            plot_files.append((param["test_name"], filepath))

    pdf_path = os.path.join(output_dir, "final_report.pdf")
    generate_pdf(
        pdf_path=pdf_path,
        interpretation_html=interpretation_html,
        plot_files=plot_files,
    )

    state["pdf_path"] = pdf_path
    state["interpretation_html"] = interpretation_html
    logger.info("%s| Interpreting Trends : finished", state["thread_id"])
    send_message(state=state, msg="Interpreting Trends : finished", append=False)

    return state
| |
|
| |
|
async def fn_final_cleanup_node(state: SheamiState):
    """Close the run, store the final PDF report and schedule file cleanup.

    Marks the last milestone and the run as completed, reads the generated
    PDF from ``state["pdf_path"]`` and saves it with the interpretation HTML
    via ``add_final_report_v2``. Deletion of the thread's output directory is
    scheduled only after the PDF has been handled, and in a ``finally`` so it
    still happens when an earlier step fails.

    Returns:
        The same ``state`` object, like the other nodes in the graph.
    """
    pdf_path = state["pdf_path"]
    patient_id = state["patient_id"]
    try:
        if state["milestones"]:
            last_milestone = state["milestones"][-1]
            last_milestone.status = "completed"
            last_milestone.end_time = datetime.now()
        await get_db().add_or_update_milestone(
            run_id=state["run_id"],
            milestone=state["process_desc"],
            status="completed",
            end=True,
        )

        await get_db().update_run_stats(run_id=state["run_id"], status="completed")

        with open(pdf_path, "rb") as f:
            pdf_bytes = f.read()
        final_report_id = await get_db().add_final_report_v2(
            patient_id=patient_id,
            summary=state["interpretation_html"],
            pdf_bytes=pdf_bytes,
            # Local variable avoids quote reuse inside the f-string, which is
            # only valid on Python 3.12+.
            file_name=f"health_trends_report_{patient_id}.pdf",
        )
        logger.info("final_report_id = %s", final_report_id)
    finally:
        schedule_cleanup(file_path=SheamiConfig.get_output_dir(state["thread_id"]))

    return state
| |
|
| |
|
def schedule_cleanup(file_path, delay=300):
    """Delete ``file_path`` in the background after ``delay`` seconds.

    A daemon thread sleeps for ``delay`` seconds, then removes the path if it
    still exists (recursively for a directory). Success and failure are
    reported with ``print``; exceptions never reach the caller.
    """

    def _remove_later():
        import shutil

        time.sleep(delay)
        if not os.path.exists(file_path):
            return
        try:
            if os.path.isdir(file_path):
                shutil.rmtree(file_path)
            else:
                os.remove(file_path)
            print(f"Cleaned up: {file_path}")
        except Exception as e:
            print(f"Cleanup failed for {file_path}: {e}")

    worker = threading.Thread(target=_remove_later, daemon=True)
    worker.start()
| |
|
| |
|
| | |
| | |
| | |
| |
|
| |
|
async def fn_standardizer_node_notifier(state: SheamiState):
    """Open the "Standardizing reports" milestone and announce it.

    Sets ``units_total`` to the number of uploaded reports and increments the
    overall progress counter by one.
    """
    updated = await reset_process_desc(state, process_desc="Standardizing reports ...")
    report_count = len(updated["uploaded_reports"])
    updated["units_total"] = report_count
    send_message(
        state=updated,
        msg="Standardizing reports now ... this might take a while ...",
    )
    updated["overall_units_processed"] = updated["overall_units_processed"] + 1
    return updated
| |
|
| |
|
async def fn_testname_standardizer_node_notifier(state: SheamiState):
    """Open the "Standardizing test names" milestone and announce it.

    Increments the overall progress counter by one.
    """
    updated = await reset_process_desc(
        state, process_desc="Standardizing test names ..."
    )
    send_message(state=updated, msg="Standardizing test names now ...")
    updated["overall_units_processed"] = updated["overall_units_processed"] + 1
    return updated
| |
|
| |
|
async def fn_unit_normalizer_node_notifier(state: SheamiState):
    """Open the "Standardizing units" milestone and announce it.

    Increments the overall progress counter by one.
    """
    updated = await reset_process_desc(
        state, process_desc="Standardizing units ..."
    )
    send_message(state=updated, msg="Standardizing measurement units now ...")
    updated["overall_units_processed"] = updated["overall_units_processed"] + 1
    return updated
| |
|
| |
|
async def fn_trends_aggregator_node_notifier(state: SheamiState):
    """Open the "Aggregating trends" milestone and announce it.

    Increments the overall progress counter by one.
    """
    updated = await reset_process_desc(
        state, process_desc="Aggregating trends ..."
    )
    send_message(state=updated, msg="Aggregating trends now ...")
    updated["overall_units_processed"] = updated["overall_units_processed"] + 1
    return updated
| |
|
| |
|
async def fn_interpreter_node_notifier(state: SheamiState):
    """Open the "Plotting trends" milestone and announce it.

    Increments the overall progress counter by one.
    """
    updated = await reset_process_desc(
        state, process_desc="Plotting trends ..."
    )
    send_message(state=updated, msg="Interpreting and plotting trends now ...")
    updated["overall_units_processed"] = updated["overall_units_processed"] + 1
    return updated
| |
|
| |
|
def create_graph(user_email: str, patient_id: str, thread_id: str):
    """Build and compile the report-processing workflow graph.

    The graph runs init, then loops increment_index / standardize_current_report
    until the router answers "done", then proceeds linearly through test-name
    standardization, unit normalization, the DB update, trend aggregation,
    interpretation and final cleanup. ``user_email``, ``patient_id`` and
    ``thread_id`` are used for logging only.

    Returns:
        The compiled graph, using an in-memory checkpointer.
    """
    logger.info(
        "%s| Creating Graph : started for user:%s | patient:%s",
        thread_id,
        user_email,
        patient_id,
    )
    checkpointer = InMemorySaver()
    workflow = StateGraph(SheamiState)

    # Node name -> handler, registered in this order.
    node_table = (
        ("init", fn_init_node),
        ("standardize_current_report", fn_standardize_current_report_node),
        ("increment_index", fn_increment_index_node),
        ("testname_standardizer", fn_testname_standardizer_node),
        ("unit_normalizer", fn_unit_normalizer_node),
        ("db_update_node", fn_db_update_node),
        ("trends", fn_trends_aggregator_node),
        ("interpreter", fn_interpreter_node),
        ("standardizer_notifier", fn_standardizer_node_notifier),
        ("testname_standardizer_notifier", fn_testname_standardizer_node_notifier),
        ("unit_normalizer_notifier", fn_unit_normalizer_node_notifier),
        ("trends_notifier", fn_trends_aggregator_node_notifier),
        ("interpreter_notifier", fn_interpreter_node_notifier),
        ("final_cleanup_node", fn_final_cleanup_node),
    )
    for node_name, handler in node_table:
        workflow.add_node(node_name, handler)

    # Entry path into the per-report loop.
    for source, target in (
        (START, "init"),
        ("init", "standardizer_notifier"),
        ("standardizer_notifier", "increment_index"),
    ):
        workflow.add_edge(source, target)

    # Loop: keep standardizing while the router reports another report.
    workflow.add_conditional_edges(
        "increment_index",
        fn_is_report_available_to_process,
        {
            "continue": "standardize_current_report",
            "done": "testname_standardizer_notifier",
        },
    )

    # Loop back-edge followed by the linear tail of the pipeline.
    for source, target in (
        ("standardize_current_report", "increment_index"),
        ("testname_standardizer_notifier", "testname_standardizer"),
        ("testname_standardizer", "unit_normalizer_notifier"),
        ("unit_normalizer_notifier", "unit_normalizer"),
        ("unit_normalizer", "db_update_node"),
        ("db_update_node", "trends_notifier"),
        ("trends_notifier", "trends"),
        ("trends", "interpreter_notifier"),
        ("interpreter_notifier", "interpreter"),
        ("interpreter", "final_cleanup_node"),
        ("final_cleanup_node", END),
    ):
        workflow.add_edge(source, target)

    logger.info("%s| Creating Graph : finished", thread_id)
    return workflow.compile(checkpointer=checkpointer)
| |
|