# sheami / plot_helper.py
# Uploaded by vikramvasudevan using huggingface_hub (commit ec4bc45, verified).
import gradio as gr
from typing import Any, Dict, List
from plotly.graph_objs import Figure, Scatter
import pandas as pd
import datetime as dt
import numpy as np
import pandas as pd
from common import get_db
import plotly.express as px
# Number of trend-plot slots per page: update_trends / reset_trends emit one
# gr.update per slot, and next_page uses it as the page size.
MAX_CHARTS_IN_PAGE: int = 40
# Default for update_trends' num_cols argument (not read anywhere in this
# module) — presumably the UI grid column count; TODO confirm with the layout.
NUM_COLS: int = 4
def coerce_to_number(val):
    """Best-effort conversion of *val* to an ``int`` or ``float``.

    - ``None`` is returned unchanged.
    - Integer inputs (including ``bool`` and numpy integers) become ``int``.
    - Float inputs (including numpy floats) are returned unchanged. They are
      deliberately never passed through ``int()``, which would truncate
      ``5.7`` to ``5`` and raise ``OverflowError`` for ``inf``.
    - Anything else (typically a string) is parsed as ``int`` first, then as
      ``float``. If both fail, the original value is returned (e.g. a unit
      string or a qualitative result such as "Positive").
    """
    if val is None:
        return None
    if isinstance(val, (int, np.integer)):
        return int(val)
    if isinstance(val, (float, np.floating)):
        return val
    try:
        return int(val)
    except (ValueError, TypeError, OverflowError):
        pass
    try:
        return float(val)
    except (ValueError, TypeError, OverflowError):
        return val  # fallback to original (string, unit, etc.)
def build_trend_figure(trend_doc: Dict[str, Any]) -> Figure:
    """Build a Plotly line chart for one test's ``trend_data``.

    Keys read from ``trend_doc``:
      - ``trend_data``: list of point dicts with ``"date"`` and ``"value"``;
      - ``test_reference_range``: optional dict with ``"min"`` / ``"max"``;
      - ``test_name``: used as the trace name and the figure title.

    Points whose date cannot be parsed, or whose value is ``None``, are
    dropped. The remaining points are plotted in date order.

    A reference range is drawn over the span from the first to the last
    plotted date:
      - both bounds: a shaded band;
      - one bound: a dotted line.
    When no point survives parsing there is no date span, so the reference
    markers are skipped rather than calling ``min()`` on an empty sequence.

    Returns the figure after ``sanitize_plotly_figure``. The exception is an
    empty ``trend_data``, which returns an unsanitized "No trend data"
    placeholder.
    """
    points = trend_doc.get("trend_data", [])
    ref = trend_doc.get("test_reference_range") or {}  # safe default {}

    if not points:
        fig = Figure()
        fig.update_layout(
            title="No trend data", xaxis_title="Date", yaxis_title="Value"
        )
        return fig

    # Parse dates and values; drop points with an unparseable date or no value.
    date_value_pairs = []
    for p in points:
        date = pd.to_datetime(p.get("date"), errors="coerce")
        value = coerce_to_number(p.get("value"))
        if pd.notna(date) and value is not None:
            date_value_pairs.append((date, value))

    # Sort chronologically so dates[0] / dates[-1] are the span endpoints.
    date_value_pairs.sort(key=lambda pair: pair[0])
    dates, values = zip(*date_value_pairs) if date_value_pairs else ([], [])

    fig = Figure()

    # === Reference Range Logic (only if present) ===
    ref_min = coerce_to_number(ref.get("min")) if ref else None
    ref_max = coerce_to_number(ref.get("max")) if ref else None
    # Reference markers span the plotted dates. With no valid points there is
    # nothing to span (and min()/max() of an empty sequence would raise).
    if dates:
        first, last = dates[0], dates[-1]
        if ref_min is not None and ref_max is not None:
            fig.add_shape(
                type="rect",
                x0=first,
                x1=last,
                y0=ref_min,
                y1=ref_max,
                fillcolor="rgba(0,200,0,0.1)",  # light green
                line=dict(width=0),
                layer="below",
            )
        elif ref_min is not None:
            fig.add_trace(
                Scatter(
                    x=[first, last],
                    y=[ref_min, ref_min],
                    mode="lines",
                    name="Min Ref",
                    line=dict(color="green", dash="dot"),
                )
            )
        elif ref_max is not None:
            fig.add_trace(
                Scatter(
                    x=[first, last],
                    y=[ref_max, ref_max],
                    mode="lines",
                    name="Max Ref",
                    line=dict(color="red", dash="dot"),
                )
            )

    # === Actual Trend Data ===
    fig.add_trace(
        Scatter(
            x=dates,
            y=values,
            mode="lines+markers",
            name=trend_doc.get("test_name", "Trend"),
        )
    )

    fig.update_layout(
        margin=dict(l=30, r=20, t=40, b=30),
        xaxis_title="Date",
        yaxis_title="Value",
        title=f"{trend_doc.get('test_name','')}",
    )
    fig.update_yaxes(autorange=True)
    fig.update_xaxes(
        autorange=True, tickformat="%Y-%m-%d", tickangle=-45, tickmode="auto"
    )
    return sanitize_plotly_figure(fig)
async def load_all_trend_figures(patient_id: str):
    """Fetch all trend documents for *patient_id* and build one figure per doc.

    Args:
        patient_id: the patient's id, converted with ``bson.ObjectId`` for
            the ``patient_id`` field of the ``trends`` collection. A falsy
            value short-circuits to ``[]`` without touching the database.

    Returns:
        A list of figures from ``build_trend_figure``, one per non-empty
        document, in the order the cursor yields them.

    Raises:
        Whatever ``bson.ObjectId`` raises for an invalid id
        (``bson.errors.InvalidId``).
    """
    if not patient_id:
        return []
    # Explicit local import instead of ``__import__("bson")``. It stays local
    # because this is the only function here that needs bson.
    from bson import ObjectId

    db = get_db()
    cursor = db.trends.find({"patient_id": ObjectId(patient_id)})
    docs = await cursor.to_list(length=None)
    return [build_trend_figure(doc) for doc in docs if doc]
async def update_trends(patient_id, page=0, num_cols=NUM_COLS):
    """Render one page of trend charts for *patient_id*.

    Args:
        patient_id: forwarded to ``load_all_trend_figures``.
        page: zero-based page index. Coerced to ``int`` (``None`` -> 0) and
            clamped to ``[0, total_pages - 1]``, so a float value from a
            number widget, or a stale out-of-range index, still yields a
            valid page.
        num_cols: accepted for API compatibility; not used here.

    Returns:
        A tuple laid out as:
          - ``MAX_CHARTS_IN_PAGE`` plot updates (visible figures first, then
            hidden, cleared slots);
          - the clamped page index;
          - the ``"Page X / Y"`` label;
          - ``interactive`` updates for the Prev and Next buttons.
    """
    figures = await load_all_trend_figures(patient_id)

    # Always at least one page: an empty result shows "Page 1 / 1" instead of
    # "Page 1 / 0", and the clamp below can never produce a negative index.
    total_pages = max((len(figures) - 1) // MAX_CHARTS_IN_PAGE + 1, 1)
    page = min(max(int(page or 0), 0), total_pages - 1)

    start = page * MAX_CHARTS_IN_PAGE
    page_figures = figures[start:start + MAX_CHARTS_IN_PAGE]

    outputs = []
    for fig in page_figures:
        # Move the figure's title into the component label and clear it from
        # the figure itself.
        title = fig.layout.title.text
        fig.update_layout(title="")
        outputs.append(gr.update(value=fig, visible=True, label=title))
    # Hide and clear the unused plot slots on this page.
    outputs.extend(
        gr.update(visible=False, value=None, label="")
        for _ in range(MAX_CHARTS_IN_PAGE - len(page_figures))
    )

    return (
        *outputs,  # plots
        page,  # page number
        f"Page {page + 1} / {total_pages}",  # page info
        gr.update(interactive=page > 0),  # Prev button
        gr.update(interactive=page < total_pages - 1),  # Next button
    )
async def reset_trends():
    """Blank the trend page: hide every plot slot and disable paging.

    The returned tuple mirrors the layout of ``update_trends``'s result:
      - ``MAX_CHARTS_IN_PAGE`` updates that hide and clear each plot;
      - page ``0`` and the ``"Page 0 / 0"`` label;
      - two updates disabling the Prev and Next buttons.
    """
    hidden_plots = [
        gr.update(visible=False, value=None, label="")
        for _ in range(MAX_CHARTS_IN_PAGE)
    ]
    button_state = dict(interactive=False)
    prev_button = gr.update(**button_state)
    next_button = gr.update(**button_state)
    return (*hidden_plots, 0, "Page 0 / 0", prev_button, next_button)
def reset_vitals_plots(num_plots: int = 20):
    """Hide and clear every vitals plot slot.

    Args:
        num_plots: how many plot components to reset. The default of 20
            matches the number of slots ``render_vitals_plot_layout`` pads
            or truncates to.

    Returns:
        A tuple of ``num_plots`` ``gr.update`` objects. Each one sets
        ``visible=False``, ``value=None`` and ``label=""``. No page-info
        outputs are included.
    """
    return tuple(
        gr.update(visible=False, value=None, label="") for _ in range(num_plots)
    )
def reset_latest_vitals_labels():
    """Hide and clear the 20 latest-vitals label components.

    Returns a tuple of 20 ``gr.update`` objects. Each one sets
    ``visible=False``, ``value=None`` and ``label=""``.
    """
    cleared = [gr.update(visible=False, value=None, label="") for _ in range(20)]
    return tuple(cleared)
def _to_jsonable_dt(x):
    """Convert a pandas ``Timestamp`` or numpy ``datetime64`` to ``datetime``.

    A ``datetime64`` is first normalized through ``pd.to_datetime``. Any
    resulting ``Timestamp`` is converted with ``to_pydatetime()``. All other
    values pass through unchanged, so callers can feed arbitrary axis values.
    """
    converted = pd.to_datetime(x) if isinstance(x, np.datetime64) else x
    if isinstance(converted, pd.Timestamp):
        return converted.to_pydatetime()
    return converted
def sanitize_plotly_figure(fig):
    """Replace pandas/numpy datetimes in *fig* with plain ``datetime`` values.

    Values are passed through ``_to_jsonable_dt`` in these places:
      - each trace's ``x`` (when the trace has one);
      - each layout shape's ``x0`` / ``x1``;
      - each annotation's ``x``;
      - ``layout.xaxis.range``.
    The figure is modified in place and also returned for chaining.
    """
    # Trace x-values: normally a sequence; fall back to converting the value
    # itself if it cannot be iterated.
    for trace in fig.data:
        if not hasattr(trace, "x") or trace.x is None:
            continue
        try:
            trace.x = [_to_jsonable_dt(item) for item in trace.x]
        except TypeError:
            trace.x = _to_jsonable_dt(trace.x)

    # Shape anchors (x0/x1).
    for shape in list(fig.layout.shapes or ()):
        for attr in ("x0", "x1"):
            anchor = getattr(shape, attr, None)
            if anchor is not None:
                setattr(shape, attr, _to_jsonable_dt(anchor))

    # Annotation positions (x).
    for note in list(fig.layout.annotations or ()):
        position = getattr(note, "x", None)
        if position is not None:
            note.x = _to_jsonable_dt(position)

    # Explicit x-axis range, if one has been set.
    xaxis = getattr(fig.layout, "xaxis", None)
    if xaxis and getattr(xaxis, "range", None):
        xaxis.range = [_to_jsonable_dt(bound) for bound in xaxis.range]

    return fig
def next_page(page, figures_len, per_page=None):
    """Return the index of the page after *page*, clamped to the last page.

    Args:
        page: current zero-based page index.
        figures_len: total number of figures being paged.
        per_page: page size; ``None`` (default) means ``MAX_CHARTS_IN_PAGE``.

    Returns:
        ``page + 1`` capped at the last page index and never below 0. With
        no figures there is still one (empty) page, so the result is 0
        rather than -1.
    """
    size = MAX_CHARTS_IN_PAGE if per_page is None else per_page
    total_pages = max((figures_len - 1) // size + 1, 1)
    return max(min(page + 1, total_pages - 1), 0)
def prev_page(page):
    """Step back one page, stopping at the first page (index 0)."""
    candidate = page - 1
    return max(candidate, 0)
async def render_vitals_plot_layout(patient_id, max_plots=20):
    """Build exactly *max_plots* ``gr.Plot`` components for a patient's vitals.

    Fetches documents via ``get_db().get_vitals_trends_by_patient`` and
    renders each non-empty one with ``build_trend_figure``. A ``None``
    result is treated as no documents. Figures beyond *max_plots* are
    dropped; missing slots are filled with an axis-less "No Data"
    placeholder, so the layout always has the same number of components.

    Each figure's title becomes the component label and is cleared from the
    figure before the ``gr.Plot`` is constructed. This matches the order
    ``update_trends`` uses (clear, then build the update).

    Args:
        patient_id: forwarded unchanged to ``get_vitals_trends_by_patient``.
        max_plots: number of plot components to return (default 20).

    Returns:
        A list of ``max_plots`` ``gr.Plot`` components.
    """
    docs = await get_db().get_vitals_trends_by_patient(patient_id)
    figures = [build_trend_figure(doc) for doc in (docs or []) if doc][:max_plots]

    # Pad with placeholders so the caller always receives max_plots slots.
    while len(figures) < max_plots:
        placeholder = Figure()
        placeholder.update_layout(
            title="No Data",
            xaxis=dict(visible=False),
            yaxis=dict(visible=False),
            margin=dict(l=30, r=20, t=40, b=30),
        )
        figures.append(placeholder)

    plots = []
    for fig in figures:
        label = fig.layout.title.text
        # Clear the in-figure title *before* handing the figure to gr.Plot,
        # so the title is shown only as the component label.
        fig.update_layout(title=None)
        plots.append(gr.Plot(value=fig, label=label))
    return plots