# wanhanisah's picture
# Update app.py
# 770a7a5 verified
# app.py
import os
import hashlib
import io
import json
import math
import zipfile
import time
import random
import tempfile
import uuid
import streamlit.components.v1 as components
from textwrap import dedent
from pathlib import Path
from typing import Optional, List, Dict, Tuple
import numpy as np
import pandas as pd
import nibabel as nib
import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt
from matplotlib import animation
from PIL import Image, ImageDraw, ImageFont
from skimage.measure import label as cc_label, regionprops
from scipy.ndimage import (
binary_fill_holes,
binary_dilation,
binary_erosion,
binary_closing,
) # use SciPy morphology
from skimage.morphology import disk
import streamlit as st
import base64
from streamlit_drawable_canvas import st_canvas
from utils.layer_util import ResizeAndConcatenate
# =========================
# Config / paths
# =========================
# Enumerate GPUs by PCI bus id so device indices are stable.
# NOTE(review): tensorflow is imported above; this presumably only takes effect
# because no CUDA device has been initialised yet at this point — confirm.
os.environ.update(CUDA_DEVICE_ORDER="PCI_BUS_ID")
# To force CPU-only inference, hide every GPU before any TF op runs:
# os.environ["CUDA_VISIBLE_DEVICES"] = ""
MODEL_PATH = "./models_final/SEG459.h5"  # Keras HDF5 model loaded by get_model()
# ---------------------------------------------------------
# Keras custom_objects safety:
# The saved model may reference training-time losses/metrics by name, and
# Keras may try to resolve them even with compile=False. Harmless dummy
# callables are supplied for them instead of None.
# ---------------------------------------------------------
def _dummy(*_args, **_kwargs):
    """Placeholder loss/metric: ignore every argument and return a scalar 0.0 tensor.

    It exists only so model deserialisation can resolve the name; it is not
    used for inference.
    """
    zero = tf.constant(0.0)
    return zero
# Symbol names the saved model may reference. Training-only losses/metrics map
# to the no-op `_dummy`; the custom layer must be the real class so the graph
# can be rebuilt.
_CUSTOM_OBJECTS = dict(
    focal_tversky_loss=_dummy,
    dice_coef_no_bkg=_dummy,
    ResizeAndConcatenate=ResizeAndConcatenate,  # must stay real
    dice_myo=_dummy,
    dice_blood=_dummy,
    dice=_dummy,
)
@st.cache_resource(show_spinner=False)
def get_model():
    """Load the segmentation model from MODEL_PATH once and reuse it.

    `st.cache_resource` memoises the returned model. `compile=False` skips
    restoring the training configuration, so only the names listed in
    `_CUSTOM_OBJECTS` have to resolve.
    """
    load_kwargs = {"custom_objects": _CUSTOM_OBJECTS, "compile": False}
    return keras.models.load_model(MODEL_PATH, **load_kwargs)
OUTPUT_ROOT = "/tmp/NIFTI_OUTPUTS"
# Inference & metrics settings
SIZE_X, SIZE_Y = 256, 256  # UNet3+ input grid (width, height)
TARGET_HW = (SIZE_Y, SIZE_X)
N_CLASSES = 3  # label values: 0 = background, 1 = myocardium, 2 = blood pool
BATCH_SIZE = 16
MYO_DENSITY = 1.05  # g/mL → mg via mm^3 * g/mL
# Resize mode for input images -> model grid
RESIZE_MODE = "auto"  # "auto", "bilinear", "area", or "fft"
# --- Progress weights (must sum to 1.0) ---
W_LOAD_PREP = 0.15
W_PREDICT = 0.65
W_POST_GIF = 0.20
# Explicit check rather than `assert`, which `python -O` strips silently.
# Written as `not (... < tol)` so a NaN sum is rejected too.
if not abs((W_LOAD_PREP + W_PREDICT + W_POST_GIF) - 1.0) < 1e-6:
    raise ValueError(
        "Progress weights W_LOAD_PREP + W_PREDICT + W_POST_GIF must sum to 1.0"
    )
# Island removal (3D) — thresholds used by clean_predictions_per_frame_3d.
ENABLE_ISLAND_REMOVAL = True
ISLAND_MIN_SLICE_SPAN = 2  # minimum number of slices a secondary island spans
ISLAND_MIN_AREA_PER_SLICE = 10  # minimum median voxels per spanned slice
ISLAND_CENTROID_DIST_THRESH = 40  # max centroid distance (voxel index units) to the dominant component
# Orientation / display
ORIENT_TARGET = None  # None (keep native), "LPS", or "RAS"
DISPLAY_MATCH_DICOM = False
DISPLAY_RULES = {
    "LPS": dict(rot90_cw=True, flip_ud=True, flip_lr=False),
    "RAS": dict(rot90_cw=True, flip_ud=False, flip_lr=True),
    None: dict(rot90_cw=False, flip_ud=False, flip_lr=False),
}
# Orientation that display_xform() falls back to when none is passed.
CURRENT_DISPLAY_ORIENT = ORIENT_TARGET
# ED/ES robust selection (mid-slices subset)
USE_MID_SLICES_FOR_ED_ES = True
MID_K = 4
MID_MIN_VALID_FRAC = 0.7
MID_A_BLOOD_MIN = 30
MID_A_MYO_MIN = 30
GIF_FPS = 1
GIF_DPI = 300
# Bigger drawing canvas for manual corrections (display only; masks still 256x256)
CANVAS_SCALE = 3  # 2x/3x/4x larger canvas than model grid
# Roundel-style display width for the LV editor
DISPLAY_W = 400
# =========================
# Branding / UI
# =========================
# Repository link (presumably the logo's click-through target — confirm at use site).
LOGO_LINK = "https://github.com/whanisa/Segmentation"
LOGO_URL = (
    "https://raw.githubusercontent.com/"
    "whanisa/Segmentation/main/icon/logo.png"
)
LOGO_HEIGHT_PX = 120  # desktop logo height; feeds the CSS --logo-height token
SAFE_INSET_PX = 18
# --- CSS
def _inject_layout_css():
    """
    Inject the app-wide stylesheet into the Streamlit page.

    Desktop-first layout with responsive overrides (``max-width: 1024px`` and
    ``max-width: 640px`` media queries).
    ✅ File uploader styling (HF-proof):
    - Label strip stays WHITE (wrapper remains transparent)
    - Dropzone surface GREY (#f0f2f6)
    - "Browse files" button WHITE

    The stylesheet is an f-string rendered with
    ``st.markdown(..., unsafe_allow_html=True)``: literal CSS braces are doubled
    (``{{`` / ``}}``) and only the Python values below are interpolated.
    Returns None.
    """
    # Desktop measurements, interpolated into the :root design tokens below.
    CONTENT_MEASURE_PX = 920
    LEFT_OFFSET_PX = 40
    UPLOAD_WIDTH_PX = 420
    # Logo height on narrow screens (<=640px): 70% of desktop, never below 48px.
    # NOTE(review): the <=640px rules shrink only `#fixed-edge-logo img`; the
    # --logo-height token (and so --tabs-top-gap / .edge-logo-spacer) keeps
    # LOGO_HEIGHT_PX there — confirm the taller spacer is intended on mobile.
    mobile_logo_h = max(48, int(LOGO_HEIGHT_PX * 0.7))
    st.markdown(
        f"""
<style>
/* =========================
0) Design tokens
========================= */
:root {{
--content-measure: {CONTENT_MEASURE_PX}px;
--left-offset: {LEFT_OFFSET_PX}px;
--upload-width: {UPLOAD_WIDTH_PX}px;
--logo-height: {LOGO_HEIGHT_PX}px;
--edge-x: max(12px, env(safe-area-inset-left));
--header-clear: 40px;
--edge-y: calc(env(safe-area-inset-top) + var(--header-clear));
--tabs-top-gap: calc(var(--logo-height) + 16px);
--tabs-left-shift: 32px;
--accent: #ef4444;
--surface: #ffffff;
--text: #111827;
--muted: #374151;
--body: #333333;
--link: #0b66c3;
--page-pad-x: 18px;
/* Uploader dropzone (match theme secondaryBackgroundColor) */
--uploader-bg: #f0f2f6;
--uploader-bg-hover: #e9ecef;
--uploader-border: rgba(17, 24, 39, 0.16);
--uploader-border-hover: rgba(17, 24, 39, 0.26);
--uploader-radius: 12px;
/* Browse button (WHITE like original) */
--uploader-btn-bg: #ffffff;
--uploader-btn-bg-hover: #f8f9fa;
--uploader-btn-border: rgba(17, 24, 39, 0.18);
}}
/* =========================
1) Force light surfaces
========================= */
html, body {{
background: var(--surface) !important;
color: var(--text) !important;
}}
.stApp {{
background: var(--surface) !important;
color: var(--text) !important;
}}
@media (prefers-color-scheme: dark) {{
html, body, .stApp {{
background: var(--surface) !important;
color: var(--text) !important;
}}
}}
/* Ensure fixed logo isn't clipped */
.stApp, .appview-container, .main {{
overflow: visible !important;
}}
/* =========================
2) Spacing tweaks
========================= */
.appview-container .main .block-container {{
padding-top: 0.75rem;
padding-bottom: 1rem;
}}
div[data-testid="stElementContainer"] {{
margin-bottom: 0.2rem !important;
}}
/* =========================
3) Outer layout wrappers
========================= */
.content-wrap {{
width: min(1300px, 100%);
margin: 0 auto;
padding: 0 var(--page-pad-x);
box-sizing: border-box;
}}
.measure-wrap {{
max-width: var(--content-measure);
margin-left: var(--left-offset);
margin-right: auto;
}}
#upload-wrap {{
max-width: var(--upload-width);
margin-left: var(--left-offset);
margin-right: auto;
}}
/* ==========================================================
✅ FILE UPLOADER (HF-proof, short, stable)
========================================================== */
/* A) Keep uploader wrapper transparent so label strip stays white */
section[data-testid="stFileUploader"],
div[data-testid="stFileUploader"],
.stFileUploader,
section[data-testid="stFileUploader"] > div,
div[data-testid="stFileUploader"] > div,
.stFileUploader > div {{
background: transparent !important;
background-color: transparent !important;
border: 0 !important;
border-radius: 0 !important;
box-shadow: none !important;
}}
/* B) Dropzone surface targets (BaseWeb + testids + common nested nodes) */
div[data-baseweb="file-uploader"],
div[data-baseweb="file-uploader"] > div,
div[data-baseweb="file-uploader"] > div > div,
div[data-baseweb="file-uploader"] [role="button"],
div[data-testid="stFileUploaderDropzone"],
section[data-testid="stFileUploaderDropzone"],
div[data-testid="stFileUploaderDropzone"] > div,
section[data-testid="stFileUploaderDropzone"] > div,
div[data-testid="stFileUploaderDropzone"] > div > div,
section[data-testid="stFileUploaderDropzone"] > div > div {{
background: var(--uploader-bg) !important;
background-color: var(--uploader-bg) !important;
border: 1px dashed var(--uploader-border) !important;
border-radius: var(--uploader-radius) !important;
box-shadow: none !important;
padding: 14px 44px 14px 14px !important;
}}
/* Hover */
div[data-baseweb="file-uploader"] [role="button"]:hover,
div[data-baseweb="file-uploader"] > div:hover,
div[data-testid="stFileUploaderDropzone"]:hover,
section[data-testid="stFileUploaderDropzone"]:hover {{
background: var(--uploader-bg-hover) !important;
background-color: var(--uploader-bg-hover) !important;
border-color: var(--uploader-border-hover) !important;
}}
/* Focus ring */
div[data-baseweb="file-uploader"] [role="button"]:focus-within,
div[data-testid="stFileUploaderDropzone"]:focus-within,
section[data-testid="stFileUploaderDropzone"]:focus-within {{
border-color: rgba(239, 68, 68, 0.45) !important;
box-shadow: 0 0 0 3px rgba(239, 68, 68, 0.15) !important;
}}
/* Text/icon inside uploader */
div[data-baseweb="file-uploader"] *,
div[data-testid="stFileUploaderDropzone"] *,
section[data-testid="stFileUploaderDropzone"] * {{
color: var(--text) !important;
}}
/* C) Browse button = WHITE (like original) */
div[data-baseweb="file-uploader"] button,
div[data-baseweb="file-uploader"] [role="button"] button,
div[data-testid="stFileUploaderDropzone"] button,
section[data-testid="stFileUploaderDropzone"] button,
section[data-testid="stFileUploader"] button,
div[data-testid="stFileUploader"] button {{
background: var(--uploader-btn-bg) !important;
background-color: var(--uploader-btn-bg) !important;
color: var(--text) !important;
border: 1px solid var(--uploader-btn-border) !important;
border-radius: 10px !important;
font-weight: 600 !important;
box-shadow: none !important;
}}
div[data-baseweb="file-uploader"] button:hover,
div[data-baseweb="file-uploader"] [role="button"] button:hover,
div[data-testid="stFileUploaderDropzone"] button:hover,
section[data-testid="stFileUploaderDropzone"] button:hover,
section[data-testid="stFileUploader"] button:hover,
div[data-testid="stFileUploader"] button:hover {{
background: var(--uploader-btn-bg-hover) !important;
background-color: var(--uploader-btn-bg-hover) !important;
border-color: rgba(239, 68, 68, 0.55) !important;
color: var(--accent) !important;
}}
/* =========================
4) Fixed-edge logo
========================= */
#fixed-edge-logo {{
position: fixed;
left: var(--edge-x);
top: var(--edge-y);
z-index: 1000;
pointer-events: none;
}}
#fixed-edge-logo img {{
height: var(--logo-height);
width: auto;
display: block;
}}
.edge-logo-spacer {{
height: var(--tabs-top-gap);
}}
/* =========================
5) Typography
========================= */
.hero-title {{
font-size: 40px;
line-height: 1.25;
font-weight: 800;
margin: 0 0 20px;
text-align: justify;
text-justify: inter-word;
color: var(--text);
}}
.text-wrap p {{
margin: 0 0 14px 0;
font-size: 17px;
line-height: 1.5;
text-align: justify;
text-justify: inter-word;
color: var(--body);
hyphens: auto;
-webkit-hyphens: auto;
-ms-hyphens: auto;
}}
.text-wrap a {{
color: var(--link) !important;
text-decoration: underline;
text-underline-offset: 2px;
text-decoration-thickness: 1.5px;
}}
.note-text {{
font-size: 14px;
color: var(--body);
line-height: 1.4;
margin-top: 4px;
}}
/* =========================
6) Tabs styling
========================= */
div[data-testid="stTabs"] [role="tablist"],
.stTabs [role="tablist"],
div[data-baseweb="tab-list"] {{
margin-left: calc(var(--left-offset) + var(--tabs-left-shift)) !important;
margin-right: var(--page-pad-x) !important;
border-bottom: 0 !important;
padding-bottom: 6px !important;
}}
div[data-testid="stTabs"] [role="tab"],
.stTabs [role="tab"],
div[data-baseweb="tab-list"] button[role="tab"] {{
color: var(--muted) !important;
background: transparent !important;
border: none !important;
outline: none !important;
padding: 6px 14px 10px 14px !important;
margin: 0 4px !important;
font-weight: 600 !important;
border-bottom: 3px solid transparent !important;
}}
div[data-testid="stTabs"] [role="tab"][aria-selected="true"],
.stTabs [role="tab"][aria-selected="true"],
div[data-baseweb="tab-list"] button[aria-selected="true"] {{
color: var(--accent) !important;
border-bottom-color: var(--accent) !important;
}}
div[data-baseweb="tab-highlight"] {{
display: none !important;
}}
/* =========================
7) Responsive overrides
========================= */
@media (max-width: 1024px) {{
:root {{
--content-measure: 92vw;
--left-offset: 16px;
--upload-width: 92vw;
--tabs-left-shift: 12px;
--page-pad-x: 14px;
}}
.hero-title {{ font-size: 34px; }}
}}
@media (max-width: 640px) {{
:root {{
--content-measure: 94vw;
--left-offset: 0px;
--upload-width: 94vw;
--tabs-left-shift: 0px;
--header-clear: 64px;
--page-pad-x: 12px;
}}
#fixed-edge-logo img {{
height: {mobile_logo_h}px;
}}
.edge-logo-spacer {{
height: calc(var(--logo-height) + 8px);
}}
.hero-title {{
font-size: clamp(22px, 6vw, 30px);
line-height: 1.2;
text-align: left;
text-justify: auto;
}}
.text-wrap p {{
font-size: 15px;
text-align: left;
hyphens: none;
}}
}}
</style>
""",
        unsafe_allow_html=True,
    )
# =========================
# Small utilities
# =========================
def log(msg: str):
    """Write *msg* to stdout behind an ``[INFO]`` tag."""
    line = f"[INFO] {msg}"
    print(line)
def _safe_rerun():
    """Trigger a Streamlit rerun on both new and old Streamlit versions.

    Tries ``st.rerun`` first and falls back to ``st.experimental_rerun``.

    Raises:
        RuntimeError: if the installed Streamlit exposes neither API.
    """
    for attr_name in ("rerun", "experimental_rerun"):
        if hasattr(st, attr_name):
            getattr(st, attr_name)()
            return
    # very old / unusual environment
    raise RuntimeError("No rerun method available in this Streamlit version.")
def get_nifti_SF_quick(path: str) -> Tuple[int, int]:
    """Return ``(S, F)``, the sizes of axes 2 and 3 of the NIfTI image at *path*.

    Only ``img.shape`` is inspected; voxel data is not read here. Axes the
    image does not have count as 1, e.g. an ``(X, Y, Z)`` image gives ``(Z, 1)``.
    """
    dims = tuple(nib.load(path).shape) + (1,) * 4  # pad missing axes with 1
    return int(dims[2]), int(dims[3])
def get_session_tmpdir() -> str:
    """Return this session's scratch directory, creating it on first use.

    The path is memoised in ``st.session_state["session_tmpdir"]``. The name
    embeds a random uuid4 hex so sessions do not collide. Nothing here deletes
    the directory.
    """
    state = st.session_state
    if "session_tmpdir" not in state:
        unique = uuid.uuid4().hex
        state["session_tmpdir"] = tempfile.mkdtemp(prefix=f"segapp_{unique}_")
    return state["session_tmpdir"]
def get_session_gifs_dir() -> str:
    """Return ``<session tmpdir>/GIFs`` as a string, creating it if needed."""
    gifs_dir = Path(get_session_tmpdir()) / "GIFs"
    gifs_dir.mkdir(parents=True, exist_ok=True)
    return str(gifs_dir)
def get_session_csv_dir() -> str:
    """Return ``<session tmpdir>/CSV`` as a string, creating it if needed."""
    csv_dir = Path(get_session_tmpdir()) / "CSV"
    csv_dir.mkdir(parents=True, exist_ok=True)
    return str(csv_dir)
def invalidate_download_caches():
    """Forget cached ZIP payloads and names so they are rebuilt on next use.

    Call this after masks or GIFs change. Keys that are not present are ignored.
    """
    stale_keys = (
        "mask_zip_bytes",
        "mask_zip_name",
        "mask_zip_bytes_corrected",
        "mask_zip_name_corrected",
        "gif_zip_bytes",
        "gif_zip_name",
        "gif_zip_bytes_corrected",
        "gif_zip_name_corrected",
    )
    state = st.session_state
    for key in stale_keys:
        state.pop(key, None)
def normalize_images(x):
    """Min-max scale every image of a batch to [0, 1] independently.

    Min and max are taken over axes 1 and 2, i.e. per sample (and per channel)
    for an ``(N, H, W, C)`` batch such as the one built by
    ``nifti_to_model_batches``. Images whose range is not positive become all
    zeros. Returns a float32 NumPy array with the input's shape.
    """
    t = tf.convert_to_tensor(x, dtype=tf.float32)
    lo = tf.reduce_min(t, axis=[1, 2], keepdims=True)
    span = tf.reduce_max(t, axis=[1, 2], keepdims=True) - lo
    has_range = span > 0.0
    scaled = tf.where(has_range, (t - lo) / span, tf.zeros_like(t))
    return scaled.numpy()
def _tf_resize_bilinear(img, *, target_h=SIZE_Y, target_w=SIZE_X):
    """Anti-aliased bilinear resize of a 2-D image to ``(target_h, target_w)``.

    Returns float32. Size-1 axes are removed by ``np.squeeze``.
    """
    batch = img.astype(np.float32)[np.newaxis, ..., np.newaxis]
    resized = tf.image.resize(
        batch, [target_h, target_w], method="bilinear", antialias=True
    )
    return np.squeeze(resized.numpy()).astype(np.float32)
def _tf_resize_auto(img, target_h=SIZE_Y, target_w=SIZE_X):
    """Resize a 2-D image, choosing the method by direction.

    Uses "area" when either axis shrinks and "bilinear" otherwise (both
    anti-aliased). Returns float32; size-1 axes are removed by ``np.squeeze``.
    """
    src_h, src_w = img.shape
    shrinking = src_h > target_h or src_w > target_w
    batch = img.astype(np.float32)[np.newaxis, ..., np.newaxis]
    resized = tf.image.resize(
        batch,
        [target_h, target_w],
        method="area" if shrinking else "bilinear",
        antialias=True,
    )
    return np.squeeze(resized.numpy()).astype(np.float32)
def _resize_nn(img, new_h, new_w):
    """Nearest-neighbour resize of a 2-D array to ``(new_h, new_w)``.

    Suited to label masks because no in-between values are created. The result
    is cast back to ``img.dtype``; size-1 axes are removed by ``np.squeeze``.
    """
    batch = img.astype(np.float32)[np.newaxis, ..., np.newaxis]
    resized = tf.image.resize(batch, [new_h, new_w], method="nearest")
    return np.squeeze(resized.numpy()).astype(img.dtype)
def _fft_kspace_resize(img, target_h=SIZE_Y, target_w=SIZE_X):
    """
    Resize a 2-D image by cropping / zero-filling its centred k-space.

    - Upsample: zero-fill the outer k-space.
    - Downsample: keep only the central k-space.

    After ``np.fft.fftshift`` the DC bin of an ``n``-point axis sits at index
    ``n // 2``. Source and target are aligned on those DC bins, so odd source
    sizes no longer pick up a one-bin offset (which caused a phase ramp and a
    distorted real part). The spectrum is also rescaled by the pixel-count
    ratio. ``ifft2`` divides by the *target* size while ``fft2`` does not
    scale, so without this step intensities would change by
    ``(target_h * target_w) / (H * W)``.

    Returns a float32 ``(target_h, target_w)`` array: the real part of the
    inverse FFT. The input is returned (as float32) when it already has the
    target shape.
    """
    img = np.asarray(img, dtype=np.float32)
    H, W = img.shape
    if H == target_h and W == target_w:
        return img
    spectrum = np.fft.fftshift(np.fft.fft2(img))
    # Target k-space container
    kspace = np.zeros((target_h, target_w), dtype=np.complex64)
    # Offsets that put the source DC bin (H//2, W//2) onto the target DC bin
    # (target_h//2, target_w//2). At most one of dst/src is non-zero per axis.
    dst_y = max(0, target_h // 2 - H // 2)
    dst_x = max(0, target_w // 2 - W // 2)
    src_y = max(0, H // 2 - target_h // 2)
    src_x = max(0, W // 2 - target_w // 2)
    copy_h = min(H, target_h)
    copy_w = min(W, target_w)
    kspace[dst_y:dst_y + copy_h, dst_x:dst_x + copy_w] = \
        spectrum[src_y:src_y + copy_h, src_x:src_x + copy_w]
    # Preserve intensity scale across the change in pixel count.
    kspace *= (target_h * target_w) / float(H * W)
    img_res = np.fft.ifft2(np.fft.ifftshift(kspace))
    return np.real(img_res).astype(np.float32)
def display_xform(img2d, orient_target=None, enable=DISPLAY_MATCH_DICOM):
    """Apply the display rotation/flips configured for an orientation.

    Returns *img2d* untouched when *enable* is falsy. Otherwise the rule from
    ``DISPLAY_RULES`` for *orient_target* is applied. ``CURRENT_DISPLAY_ORIENT``
    is read at call time when *orient_target* is None, and the ``None`` rule is
    used for unknown keys. Steps run in order: 90° clockwise rotation, up/down
    flip, left/right flip.
    """
    if not enable:
        return img2d
    orient = CURRENT_DISPLAY_ORIENT if orient_target is None else orient_target
    rule = DISPLAY_RULES.get(orient, DISPLAY_RULES[None])
    steps = (
        ("rot90_cw", lambda a: np.rot90(a, k=-1)),
        ("flip_ud", np.flipud),
        ("flip_lr", np.fliplr),
    )
    result = img2d
    for flag, op in steps:
        if rule.get(flag):
            result = op(result)
    return result
def _normalize_to_uint8(img2d: np.ndarray) -> np.ndarray:
arr = np.asarray(img2d, dtype=np.float32)
if arr.size == 0:
return np.zeros_like(arr, dtype=np.uint8)
mn = np.min(arr)
mx = np.max(arr)
if mx > mn:
arr = (arr - mn) / (mx - mn)
else:
arr = np.zeros_like(arr)
arr = (arr * 255.0).clip(0, 255).astype(np.uint8)
return arr
def overlay_rgb_alpha(base_u8: np.ndarray, mask_u8: np.ndarray,
                      alpha_myo: float = 0.45, alpha_blood: float = 0.45) -> np.ndarray:
    """
    Alpha-blend label colours over a grayscale image.

    base_u8: (H,W) uint8 grayscale
    mask_u8: (H,W) uint8 labels (0=background, 1=myocardium, 2=blood pool)
    returns: (H,W,3) uint8 RGB; myocardium is tinted blue, blood pool red.
    Blood is blended after myocardium, so it stays red inside the ring.
    """
    rgb = np.repeat(np.asarray(base_u8)[..., None], 3, axis=-1).astype(np.float32)
    layers = (
        (1, np.array([0, 0, 255], dtype=np.float32), alpha_myo),    # myocardium -> blue
        (2, np.array([255, 0, 0], dtype=np.float32), alpha_blood),  # blood pool -> red
    )
    for label, colour, alpha in layers:
        sel = mask_u8 == label
        if sel.any():
            rgb[sel] = (1 - alpha) * rgb[sel] + alpha * colour
    return rgb.clip(0, 255).astype(np.uint8)
def render_overlay_pil_scaled(img2d_256, mask2d_256, *, scale=CANVAS_SCALE):
    """Build an upscaled PIL RGB overlay for the drawing canvas.

    Image and labels go through the same ``display_xform`` so they stay aligned.
    The grayscale image is stretched to uint8 and labels are alpha-blended
    (0.45 each). The result is resized with nearest-neighbour to
    ``(SIZE_X * scale, SIZE_Y * scale)`` pixels (width, height).
    """
    gray_u8 = _normalize_to_uint8(display_xform(img2d_256))
    labels = display_xform(mask2d_256.astype(np.uint8))
    rgb = overlay_rgb_alpha(gray_u8, labels, alpha_myo=0.45, alpha_blood=0.45)
    canvas_size = (SIZE_X * scale, SIZE_Y * scale)  # PIL sizes are (width, height)
    return Image.fromarray(rgb).resize(canvas_size, resample=Image.NEAREST)
def build_zip_bytes(files: list, root_folder: Optional[str] = None):
    """
    Pack in-memory files into a DEFLATE-compressed ZIP and return its bytes.

    files: list of tuples (arc_rel_path, bytes_content); str content is also
        accepted by ZipFile.writestr.
    root_folder: if truthy, every entry is placed under this top-level folder;
        if None or "", entries sit at the ZIP root.
    Backslashes in archive names are converted to "/".
    """
    prefix = f"{root_folder}/" if root_folder else ""
    with io.BytesIO() as buf:
        with zipfile.ZipFile(buf, "w", compression=zipfile.ZIP_DEFLATED) as archive:
            for rel_path, payload in files:
                arcname = (prefix + f"{rel_path}").replace("\\", "/")
                archive.writestr(arcname, payload)
        return buf.getvalue()
def build_corrected_gif_zip_bytes():
    """Zip one GIF per case id and store the archive in session state.

    For each case id in ``gif_paths_original`` ∪ ``gif_paths_corrected``
    (sorted), the corrected GIF is used when present, otherwise the original.
    Missing files are skipped. Entries are named ``<pid>.gif`` under the folder
    ``<last zip stem>_GIF_corrected``. The results go to
    ``st.session_state["gif_zip_bytes_corrected"]`` and
    ``["gif_zip_name_corrected"]``.
    """
    state = st.session_state
    gif_root = Path(state.get("_last_zip_name", "Results")).stem + "_GIF_corrected"
    originals = state.get("gif_paths_original", {}) or {}
    corrected = state.get("gif_paths_corrected", {}) or {}

    gif_files = []
    for pid in sorted(set(originals) | set(corrected)):
        gif_path = corrected.get(pid) or originals.get(pid)  # corrected wins
        if not gif_path or not os.path.exists(gif_path):
            continue
        with open(gif_path, "rb") as fh:
            # filename rule: pid.gif (pid derived from nifti name)
            gif_files.append((f"{pid}.gif", fh.read()))

    state["gif_zip_bytes_corrected"] = build_zip_bytes(gif_files, root_folder=gif_root)
    state["gif_zip_name_corrected"] = f"{gif_root}.zip"
def build_corrected_masks_zip_bytes_per_mouse():
    """
    Export every case's label mask at native resolution into one ZIP in session state.

    For each ``(pid, case)`` in ``st.session_state["cases"]``:
    - the edited ``preds_4d`` is used when ``case["mask_edited"]`` is true,
      otherwise ``preds_4d_orig`` (falling back to ``preds_4d``);
    - the mask is resized to ``case["native_shape"][:2]`` and written with
      ``save_native_mask_nifti_with_labels`` to a temp file in the session dir;
    - the file bytes become the entry ``<pid>/<pid>_mask.nii.gz``.
    The temp file is removed in a ``finally`` clause, even if saving or reading
    fails. The results go to ``st.session_state["mask_zip_bytes_corrected"]``
    and ``["mask_zip_name_corrected"]``.
    """
    zipstem = Path(st.session_state.get("_last_zip_name", "Results")).stem
    root = f"{zipstem}_Masks_corrected"
    zip_name = f"{root}.zip"
    mask_files = []
    cases = st.session_state.get("cases", {}) or {}
    for pid, case in cases.items():
        native_h, native_w, _, _ = case["native_shape"]
        affine = case["affine"]
        # ✅ use corrected only if mask was edited; otherwise keep original mask
        if case.get("mask_edited", False):
            preds_src = case["preds_4d"]
        else:
            preds_src = case.get("preds_4d_orig", case["preds_4d"])
        mask_native = resize_masks_to_native(preds_src, native_h, native_w)
        tmp_mask_path = os.path.join(get_session_tmpdir(), f"{pid}_mask.nii.gz")
        try:
            save_native_mask_nifti_with_labels(mask_native, affine, tmp_mask_path)
            with open(tmp_mask_path, "rb") as f:
                mask_files.append((f"{pid}/{pid}_mask.nii.gz", f.read()))
        finally:
            # Best-effort cleanup; the file may not exist if saving failed early.
            try:
                os.remove(tmp_mask_path)
            except OSError:
                pass
    st.session_state["mask_zip_bytes_corrected"] = build_zip_bytes(mask_files, root_folder=root)
    st.session_state["mask_zip_name_corrected"] = zip_name
def _alpha_overlay_rgb(base_u8, mask_u8):
"""
base_u8: (H,W) uint8 grayscale
mask_u8: (H,W) uint8 labels (0,1,2)
returns RGB uint8
"""
base_rgb = np.stack([base_u8]*3, axis=-1).astype(np.float32)
# myocardium (1) blue tint
myo = (mask_u8 == 1)
base_rgb[myo, 2] = np.clip(base_rgb[myo, 2] * 0.6 + 255 * 0.4, 0, 255)
# blood (2) red tint
blood = (mask_u8 == 2)
base_rgb[blood, 0] = np.clip(base_rgb[blood, 0] * 0.6 + 255 * 0.4, 0, 255)
return base_rgb.astype(np.uint8)
def build_frame_montage(images_4d_256, preds_4d, frame_idx, slice_indices, tile_cols=4, add_overlay=True):
    """
    Return a PIL RGB montage of one frame across the selected slices.

    Each tile is the min-max stretched slice ``images_4d_256[:, :, s, frame_idx]``,
    optionally with the label overlay from ``preds_4d``, and is annotated
    ``S<s+1>`` in the top-left corner. Tiles are laid out row-major,
    ``tile_cols`` per row. ``tile_cols`` is clamped to at least 1; a value of 0
    previously raised ZeroDivisionError. A black 256x256 image is returned when
    ``slice_indices`` is empty.
    """
    tiles = []
    for s in slice_indices:
        img_u8 = _normalize_to_uint8(images_4d_256[:, :, s, frame_idx])
        if add_overlay:
            mask = preds_4d[:, :, s, frame_idx].astype(np.uint8)
            tile = overlay_rgb_alpha(img_u8, mask, alpha_myo=0.45, alpha_blood=0.45)
            tile_pil = Image.fromarray(tile)
        else:
            tile_pil = Image.fromarray(img_u8).convert("RGB")
        # annotate slice number (1-based)
        ImageDraw.Draw(tile_pil).text((5, 5), f"S{s+1}", fill=(255, 255, 255))
        tiles.append(tile_pil)
    if not tiles:
        return Image.new("RGB", (256, 256), (0, 0, 0))
    w, h = tiles[0].size
    n = len(tiles)
    cols = max(1, min(int(tile_cols), n))  # guard tile_cols < 1
    rows = int(math.ceil(n / cols))
    canvas = Image.new("RGB", (cols * w, rows * h), (0, 0, 0))
    for i, tile in enumerate(tiles):
        r, c = divmod(i, cols)
        canvas.paste(tile, (c * w, r * h))
    return canvas
def _render_overlay_png(img2d, mask2d, title=None):
    """
    Render a display-transformed slice with label overlays to an in-memory PNG.

    Myocardium (label 1) is drawn with the "Blues" colormap and blood pool
    (label 2) with "jet", both at alpha 0.45 over the grayscale image. Axes
    are hidden, and *title* is added when truthy. Returns an ``io.BytesIO``
    rewound to 0. The figure is closed in a ``finally`` clause, so a failing
    ``imshow``/``savefig`` no longer leaks it in pyplot's figure registry.
    """
    base = display_xform(img2d)
    overlays = (
        (display_xform((mask2d == 1).astype(np.uint8)).astype(bool), "Blues"),  # myocardium
        (display_xform((mask2d == 2).astype(np.uint8)).astype(bool), "jet"),    # blood pool
    )
    fig, ax = plt.subplots(figsize=(3, 3))
    try:
        ax.imshow(base, cmap="gray", interpolation="none")
        for sel, cmap in overlays:
            if sel.any():
                ax.imshow(
                    np.ma.masked_where(~sel, sel),  # only labelled pixels are drawn
                    cmap=cmap,
                    alpha=0.45,
                    vmin=0,
                    vmax=1,
                    interpolation="none",
                )
        ax.axis("off")
        if title:
            ax.set_title(title, fontsize=10)
        buf = io.BytesIO()
        fig.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
    finally:
        plt.close(fig)
    buf.seek(0)
    return buf
# Colour and label maps for simple painting (kept for future use if needed)
_COLOR_MAP_RGB = {
"Blood pool": np.array([255, 0, 0], dtype=np.int16), # red
"Myocardium": np.array([0, 0, 255], dtype=np.int16), # blue
"Eraser (background)": np.array([0, 0, 0], dtype=np.int16), # black
}
_LABEL_MAP = {
"Blood pool": 2,
"Myocardium": 1,
"Eraser (background)": 0,
}
def _apply_canvas_to_mask_slice(
    mask_slice: np.ndarray,
    canvas_rgba: np.ndarray,
    class_choice: str,
    color_tol: float = 30.0,
) -> np.ndarray:
    """
    Simple label painting: write the class label wherever the canvas colour ≈ the class colour.

    A pixel is painted when its alpha is > 0 and its RGB lies within Euclidean
    distance ``color_tol`` of ``_COLOR_MAP_RGB[class_choice]``. It is returned
    unchanged (the same object) for a None canvas or an unknown class;
    otherwise a modified copy is returned.
    (Currently not exposed in the UI, but kept for possible future use.)
    NOTE(review): boolean indexing requires canvas_rgba to have the same HxW
    as mask_slice — confirm callers downsample the enlarged canvas first.
    """
    if canvas_rgba is None:
        return mask_slice
    if class_choice not in _COLOR_MAP_RGB or class_choice not in _LABEL_MAP:
        return mask_slice
    # int16 avoids uint8 wrap-around in the colour difference
    pixels = canvas_rgba.astype(np.int16)
    distance = np.linalg.norm(
        pixels[:, :, :3] - _COLOR_MAP_RGB[class_choice][None, None, :], axis=-1
    )
    painted = (pixels[:, :, 3] > 0) & (distance < color_tol)
    result = mask_slice.copy()
    result[painted] = _LABEL_MAP[class_choice]
    return result
def _thicken_close_and_fill(
    strokes: np.ndarray,
    dilate_iter: int = 1,
    close_radius: int = 2,
) -> np.ndarray:
    """
    Turn hand-drawn contour strokes into a filled region (LV contour editing).

    Steps:
      1. dilate with a radius-1 disk, ``dilate_iter`` times, so small gaps shrink;
      2. binary-close with a radius-``close_radius`` disk to bridge end-point gaps;
      3. fill enclosed holes.
    None or an all-False input is returned as-is. Input/Output: boolean mask.
    NOTE(review): scipy's binary_closing erodes with border_value=0, so shapes
    touching the image edge may lose edge pixels — confirm acceptable.
    """
    if strokes is None or not strokes.any():
        return strokes
    region = binary_dilation(strokes, structure=disk(1), iterations=dilate_iter)
    region = binary_closing(region, structure=disk(close_radius))
    return binary_fill_holes(region)
def _apply_lv_contour_style(
    mask_slice: np.ndarray,
    canvas_rgba: np.ndarray,
    color_tol: float = 30.0,
) -> np.ndarray:
    """
    Horos-style LV editing for a single 2-D slice.

    Expected strokes on the RGBA canvas (alpha > 0, colour within ``color_tol``):
      • RED  (#FF0000) -> inner border (endocardial / blood pool)
      • BLUE (#0000FF) -> outer border (epicardial / myocardium)
    Each stroke set is thickened, closed and filled. Then, on a copy of the slice:
      - existing labels 1 and 2 are cleared,
      - blood pool (2) = filled red region,
      - myocardium (1) = filled (blue ∪ red) region minus the red region.
    If only blue is drawn, the whole filled blue region becomes myocardium and
    any previous blood-pool label is cleared. With no matching strokes (or a
    None canvas), *mask_slice* is returned unchanged.
    NOTE(review): canvas_rgba must have the same HxW as mask_slice — confirm.
    """
    if canvas_rgba is None:
        return mask_slice
    pixels = canvas_rgba.astype(np.int16)  # int16: no uint8 wrap in differences
    colour = pixels[:, :, :3]
    drawn = pixels[:, :, 3] > 0

    def _strokes_near(ref_rgb):
        ref = np.array(ref_rgb, dtype=np.int16)
        return drawn & (np.linalg.norm(colour - ref[None, None, :], axis=-1) < color_tol)

    endo_strokes = _strokes_near((255, 0, 0))  # red  -> blood pool border
    epi_strokes = _strokes_near((0, 0, 255))   # blue -> myocardium outer border
    if not (endo_strokes.any() or epi_strokes.any()):
        return mask_slice

    def _filled(strokes):
        region = _thicken_close_and_fill(strokes)
        return np.zeros_like(strokes, dtype=bool) if region is None else region

    endo_region = _filled(endo_strokes)
    # The epicardial region always contains the endocardial one.
    epi_region = _filled(epi_strokes) | endo_region

    updated = mask_slice.copy()
    updated[np.isin(updated, (1, 2))] = 0  # drop previous LV labels on this slice
    updated[endo_region] = 2               # blood pool
    updated[epi_region & ~endo_region] = 1  # myocardium
    return updated
# =========================
# NIfTI I/O + spacing
# =========================
def _reorient_nifti(img: nib.Nifti1Image, target: Optional[str]):
    """Reorient *img* to the axis codes *target* ("LPS" or "RAS", case-insensitive).

    Returns ``(image, ornt_transform)``:
      - ``(img, None)`` when *target* is falsy;
      - ``(img, xfm)`` unchanged when the transform is already the identity;
      - otherwise a new Nifti1Image with reoriented data (from ``get_fdata()``),
        the adjusted affine and the original header.

    Raises:
        ValueError: for any other target.
    """
    if not target:
        return img, None
    codes = target.upper()
    if codes not in ("LPS", "RAS"):
        raise ValueError("ORIENT_TARGET must be None, 'LPS', or 'RAS'")
    orientations = nib.orientations
    xfm = orientations.ornt_transform(
        orientations.io_orientation(img.affine),
        orientations.axcodes2ornt(tuple(codes)),
    )
    identity = np.array([[0, 1], [1, 1], [2, 1]])
    if np.allclose(xfm, identity):
        return img, xfm
    reoriented = orientations.apply_orientation(img.get_fdata(), xfm)
    new_affine = img.affine @ orientations.inv_ornt_aff(xfm, img.shape)
    return nib.Nifti1Image(reoriented, new_affine, header=img.header), xfm
def load_nifti_4d(path, orient_target: Optional[str] = ORIENT_TARGET):
    """
    Load a NIfTI image as a float32 ``(H, W, S, F)`` array plus spacing and affine.

    The image is optionally reoriented with ``_reorient_nifti``. The NIfTI
    ``(X, Y, Z, T)`` axes are then transposed to ``(Y, X, Z, T)``, so
    ``H = Y`` and ``W = X``. 2-D inputs gain singleton S and F axes and 3-D
    inputs a singleton F axis. Trailing singleton axes beyond the 4th are
    dropped; previously 2-D and 5-D inputs crashed in the transpose.

    Spacing comes from the header zooms, converted to the units its keys name
    (mm / ms) using the header's ``xyzt_units`` (e.g. dcm2niix writes time in
    seconds). When the unit is "unknown", or the header has no unit field, the
    raw zoom is used as before. Missing spatial zooms default to 1.0.

    Returns:
        data_4d: float32 ``(H, W, S, F)``.
        spacing: dict(row_mm, col_mm, slice_thickness_mm, frame_time_ms);
            ``frame_time_ms`` is None when there is no 4th zoom.
        affine: affine of the (possibly reoriented) image.

    Raises:
        ValueError: for images with fewer than 2 axes or with non-singleton
            axes beyond the 4th.
    """
    img_native = nib.load(path)
    img, _ = _reorient_nifti(img_native, orient_target)
    data = img.get_fdata(dtype=np.float32)  # (X,Y[,Z[,T]])
    if data.ndim < 2 or (data.ndim > 4 and any(d != 1 for d in data.shape[4:])):
        raise ValueError(f"Expected a 2-D to 4-D NIfTI image, got shape {data.shape}")
    if data.ndim == 2:
        data = data[:, :, None, None]
    elif data.ndim == 3:
        data = data[..., None]
    elif data.ndim > 4:
        data = data.reshape(data.shape[:4])  # only singleton axes are dropped
    data_4d = np.transpose(data, (1, 0, 2, 3)).astype(np.float32)  # -> (H,W,S,F)

    # Scale factors from the header's declared units to mm / ms.
    space_to_mm = {"meter": 1000.0, "mm": 1.0, "micron": 1e-3}
    time_to_ms = {"sec": 1000.0, "msec": 1.0, "usec": 1e-3}
    get_units = getattr(img.header, "get_xyzt_units", None)
    space_unit, time_unit = get_units() if get_units is not None else ("unknown", "unknown")
    s_scale = space_to_mm.get(space_unit, 1.0)
    t_scale = time_to_ms.get(time_unit, 1.0)

    zooms = img.header.get_zooms()
    x_mm = float(zooms[0]) * s_scale if len(zooms) > 0 else 1.0
    y_mm = float(zooms[1]) * s_scale if len(zooms) > 1 else 1.0
    z_mm = float(zooms[2]) * s_scale if len(zooms) > 2 else 1.0
    t_ms = float(zooms[3]) * t_scale if len(zooms) > 3 else None
    # data_4d = transpose(data, (1,0,2,3)) => (H,W)=(Y,X)
    spacing = dict(
        row_mm=y_mm,  # H spacing
        col_mm=x_mm,  # W spacing
        slice_thickness_mm=z_mm,
        frame_time_ms=t_ms,
    )
    return data_4d, spacing, img.affine
# =========================
# Inference
# =========================
def nifti_to_model_batches(data_4d):
    """
    Flatten an ``(H, W, S, F)`` volume into a batch on the model grid.

    The "fft" mode resizes image by image with ``_fft_kspace_resize``. Every
    other mode does one vectorised ``tf.image.resize`` call, where "auto" picks
    "area" when either axis shrinks and "bilinear" otherwise.
    Returns:
        x: (S*F, 256,256,1) float32, slice-major / frame-minor
        index: list of (s,f) matching the order of x
        shape4d: (H,W,S,F) native
        resize_method_used: str (for debugging/logging)
    Raises:
        ValueError: if RESIZE_MODE is not a supported mode.
    """
    H, W, S, F = data_4d.shape
    allowed = ("auto", "bilinear", "area", "fft")
    if RESIZE_MODE not in allowed:  # fail fast
        raise ValueError(f"RESIZE_MODE must be one of {sorted(allowed)}")
    index = [(s, f) for s in range(S) for f in range(F)]

    if RESIZE_MODE == "fft":
        resized = [
            _fft_kspace_resize(data_4d[..., s, f], target_h=SIZE_Y, target_w=SIZE_X)[..., None]
            for s, f in index
        ]
        return np.stack(resized, axis=0).astype(np.float32), index, (H, W, S, F), "fft"

    # H and W are fixed within a volume, so the method is chosen once.
    if RESIZE_MODE == "auto":
        method = "area" if (H > SIZE_Y or W > SIZE_X) else "bilinear"
    else:
        method = RESIZE_MODE  # "bilinear" or "area"
    # (H,W,S,F) -> (S,F,H,W) -> (S*F,H,W,1)
    stack = np.moveaxis(data_4d, (2, 3), (0, 1)).reshape(S * F, H, W, 1).astype(np.float32)
    batch = tf.image.resize(stack, [SIZE_Y, SIZE_X], method=method, antialias=True)
    return batch.numpy().astype(np.float32), index, (H, W, S, F), method
def _ensure_logits_last(preds):
if isinstance(preds, (list, tuple)):
preds = preds[-1]
return preds
def _predict_in_batches(model, x_norm, bs, progress_cb):
    """Predict ``x_norm`` in chunks of ``bs`` and call ``progress_cb(images_done)`` after each."""
    n_images = x_norm.shape[0]
    cb_failed = False

    def _report(done):
        nonlocal cb_failed
        try:
            progress_cb(done)
        except Exception as exc:  # progress is best-effort; never abort inference
            if not cb_failed:  # log once instead of once per batch
                cb_failed = True
                log(f"progress_cb raised {exc!r}; further progress errors ignored")

    _report(0)  # lets the UI start at 0 within this file
    outs = []
    for start in range(0, n_images, bs):
        end = min(start + bs, n_images)
        outs.append(_ensure_logits_last(model.predict(x_norm[start:end], verbose=0)))
        _report(end)
    return np.concatenate(outs, axis=0)


def predict_nifti_4d(
    model,
    data_4d,
    batch_size=None,
    *,
    progress_cb=None,
):
    """
    Run inference on a single (H,W,S,F) NIfTI case and return:
      preds_4d: (256,256,S,F) uint8 labels (argmax over the last model output)
      imgs_4d_resized: (256,256,S,F) float32 resized images (before normalisation)

    batch_size: images per forward pass. ``None``/0 falls back to BATCH_SIZE
        (then 16). The same value is now used with and without ``progress_cb``;
        previously the no-callback path used Keras' own default.
    progress_cb (optional): callable(int images_done)
      - Called with the number of 2D images (S*F slices/frames after stacking)
        processed so far within THIS file: first 0, then after every batch.
      - Exceptions it raises are logged once and otherwise ignored.
    """
    x, _index, shape4d, _resize_method = nifti_to_model_batches(data_4d)
    x_norm = normalize_images(x)
    bs = max(1, int(batch_size or BATCH_SIZE or 16))

    if progress_cb is None:
        # Single call; Keras batches internally.
        preds = _ensure_logits_last(model.predict(x_norm, verbose=0, batch_size=bs))
    else:
        preds = _predict_in_batches(model, x_norm, bs, progress_cb)

    labels = np.argmax(preds, axis=-1).astype(np.uint8)  # (S*F, 256,256)
    Hn, Wn = SIZE_Y, SIZE_X
    S, F = shape4d[2], shape4d[3]
    # Stack order is slice-major: (S*F, Hn, Wn) -> (S, F, Hn, Wn) -> (Hn, Wn, S, F)
    imgs_4d_resized = x[..., 0].reshape(S, F, Hn, Wn).transpose(2, 3, 0, 1)
    preds_4d = labels.reshape(S, F, Hn, Wn).transpose(2, 3, 0, 1)
    return preds_4d, imgs_4d_resized
def resize_masks_to_native(preds_4d_256, native_h, native_w):
    """
    Nearest-neighbour resize of a label volume ``(Hm, Wm, S, F)`` to ``(native_h, native_w, S, F)``.

    Nearest-neighbour keeps label values exact, and the output dtype matches
    the input. All S slices of a frame go through one batched
    ``tf.image.resize`` call instead of one call per (slice, frame). Results are
    identical, because each image in a batch is resized independently.
    """
    _, _, S, F = preds_4d_256.shape
    out = np.zeros((native_h, native_w, S, F), dtype=preds_4d_256.dtype)
    if S == 0:
        return out
    for f in range(F):
        # (Hm, Wm, S) -> (S, Hm, Wm, 1): a batch of S single-channel images
        batch = np.transpose(preds_4d_256[..., f], (2, 0, 1))[..., None].astype(np.float32)
        resized = tf.image.resize(batch, [native_h, native_w], method="nearest")
        # (S, native_h, native_w, 1) -> (native_h, native_w, S)
        out[..., f] = np.transpose(resized.numpy()[..., 0], (1, 2, 0)).astype(out.dtype)
    return out
# =========================
# Mask export (native)
# =========================
def save_native_mask_nifti(mask_hwst, affine, out_path):
    """
    Write a label volume to *out_path* as NIfTI and return *out_path*.

    mask_hwst: (H,W,S,F) labels, stored as uint8 in NIfTI (X,Y,Z,T) order
    affine: affine from load_nifti_4d()
    """
    xyzt = np.swapaxes(mask_hwst, 0, 1).astype(np.uint8)  # (H,W,..) -> (W,H,..) = (X,Y,..)
    nib.save(nib.Nifti1Image(xyzt, affine), out_path)
    return out_path
# Label table embedded into exported masks: value -> display name and RGB colour.
DEFAULT_LABELS = {
    value: {"name": name, "color": list(rgb)}
    for value, name, rgb in (
        (0, "Background", (0, 0, 0)),
        (1, "Myocardium", (0, 0, 255)),
        (2, "LV Blood Pool", (255, 0, 0)),
    )
}
def _make_label_json(label_defs: dict) -> str:
payload = {
"type": "segmentation_labels",
"version": 1,
"labels": {str(k): v for k, v in label_defs.items()},
}
return json.dumps(payload, separators=(",", ":"), ensure_ascii=False)
def save_native_mask_nifti_with_labels(mask_hwst, affine, out_path, label_defs=None):
    """
    Write a 4D label mask as a uint8 NIfTI file with embedded label metadata.

    mask_hwst: (H,W,S,F) uint8 labels
    affine: affine from load_nifti_4d()
    out_path: destination file path.
    label_defs: label table {id: {"name": ..., "color": ...}}; None (the
        default) means DEFAULT_LABELS.

    Saves NIfTI with:
      - header['descrip'] short ASCII string (max 80 bytes)
      - a NIfTI header extension of type "comment" (ecode 6) containing the
        label table as JSON
    Returns ``out_path``.
    """
    if label_defs is None:
        # Resolved at call time rather than bound as a default at def time.
        label_defs = DEFAULT_LABELS

    data_xyzt = np.transpose(mask_hwst, (1, 0, 2, 3)).astype(np.uint8)
    img = nib.Nifti1Image(data_xyzt, affine)

    # Short description (widely preserved). Keep the historic text for the
    # default table; for a custom table summarise the labels actually written
    # instead of advertising labels that may not match.
    if label_defs == DEFAULT_LABELS:
        desc = "Labels:0=BG,1=Myo,2=Blood"
    else:
        parts = []
        for label_id, spec in label_defs.items():
            name = spec.get("name", "") if isinstance(spec, dict) else str(spec)
            parts.append(f"{label_id}={name}")
        desc = "Labels:" + ",".join(parts)
    img.header["descrip"] = desc.encode("ascii", "ignore")[:80]

    # Embedded JSON extension (many tools preserve it; some viewers won't
    # display it). Use the named "comment" code (NIFTI_ECODE_COMMENT = 6):
    # the previously hard-coded 40 is the MATLAB extension code, which
    # mislabels this JSON payload for other readers.
    label_json = _make_label_json(label_defs)
    ext = nib.nifti1.Nifti1Extension("comment", label_json.encode("utf-8"))
    exts = img.header.extensions
    exts.clear()
    exts.append(ext)
    nib.save(img, out_path)
    return out_path
# =========================
# Cleaning + ED/ES + metrics
# =========================
def clean_predictions_per_frame_3d(mask_4d):
    """
    Remove spurious 3D islands from each frame of a 4D label mask.

    mask_4d: (H, W, S, F) label array. Labels 1 and 2 are cleaned
    independently within every frame; label 0 is left alone.

    For each label, connected components are found with connectivity=1
    (face-connected). The largest component is always kept. Any other
    component survives only if all three conditions hold:
      - it spans at least ISLAND_MIN_SLICE_SPAN slices;
      - its median per-slice area is at least ISLAND_MIN_AREA_PER_SLICE
        pixels;
      - its centroid is within ISLAND_CENTROID_DIST_THRESH (voxel units,
        all three axes) of the largest component's centroid.
    Dropped voxels are set to 0.

    Returns a cleaned copy; the input array is not modified.
    """
    _, _, _, n_frames = mask_4d.shape
    out = mask_4d.copy()
    for f in range(n_frames):
        # Basic indexing returns a view, so edits to vol_f land in `out`
        # directly (no write-back needed).
        vol_f = out[:, :, :, f]
        for cls in (1, 2):
            m = vol_f == cls
            if not m.any():
                continue
            cc = cc_label(m, connectivity=1)
            props = regionprops(cc)
            if not props:
                continue
            dom = max(props, key=lambda r: r.area)
            dom_centroid = np.array(dom.centroid)
            keep = {dom.label}
            for r in props:
                if r.label == dom.label:
                    continue
                zmin, zmax = r.bbox[2], r.bbox[5]
                slice_span = zmax - zmin
                # r.image is the component cropped to its bounding box, so
                # summing over the in-plane axes yields its area on each slice
                # zmin..zmax-1 without rescanning full H x W planes per slice.
                areas = r.image.sum(axis=(0, 1))
                median_area = np.median(areas) if areas.size else 0
                dist = np.linalg.norm(np.array(r.centroid) - dom_centroid)
                if (
                    slice_span >= ISLAND_MIN_SLICE_SPAN
                    and median_area >= ISLAND_MIN_AREA_PER_SLICE
                    and dist <= ISLAND_CENTROID_DIST_THRESH
                ):
                    keep.add(r.label)
            drop = (cc > 0) & (~np.isin(cc, list(keep)))
            vol_f[drop] = 0
    return out
def compute_per_frame_metrics(
    preds_4d, spacing, labels=None, *, myo_density=None
):
    """
    Per-frame LV blood volume and myocardial mass from a 4D label mask.

    preds_4d: (H, W, S, F) label array.
    spacing: dict with "row_mm", "col_mm" and "slice_thickness_mm".
    labels: mapping with "myo" and "blood" label ids; None (the default)
        means {"myo": 1, "blood": 2}.
    myo_density: factor applied to myocardial volume to get mass; None
        (the default) uses the module-level MYO_DENSITY.

    Returns a DataFrame with columns Frame (0-based), Volume_uL and
    MyocardiumMass_mg. Volumes rely on 1 mm^3 == 1 uL.
    """
    # None sentinels avoid a shared mutable default dict and defer the
    # MYO_DENSITY lookup to call time.
    if labels is None:
        labels = {"myo": 1, "blood": 2}
    if myo_density is None:
        myo_density = MYO_DENSITY
    voxel_mm3 = (
        float(spacing["row_mm"])
        * float(spacing["col_mm"])
        * float(spacing["slice_thickness_mm"])
    )
    _, _, _, n_frames = preds_4d.shape
    blood_counts = (preds_4d == labels["blood"]).sum(axis=(0, 1, 2))
    myo_counts = (preds_4d == labels["myo"]).sum(axis=(0, 1, 2))
    return pd.DataFrame(
        {
            "Frame": np.arange(n_frames, dtype=int),
            "Volume_uL": blood_counts * voxel_mm3,
            "MyocardiumMass_mg": myo_counts * voxel_mm3 * myo_density,
        }
    )
def slice_validity_matrix(preds_4d, A_blood_min=30, A_myo_min=30):
    """
    Per-(slice, frame) pixel areas of label 2 (blood pool) and label 1
    (myocardium), plus flags telling whether each area reaches its minimum.

    preds_4d: (H, W, S, F) label array.

    Returns (has_blood, has_myo, areas_blood, areas_myo), each shaped (S, F).
    """
    _rows, _cols, _slices, _frames = preds_4d.shape  # requires a 4D input
    blood_area = (preds_4d == 2).sum(axis=(0, 1))
    myo_area = (preds_4d == 1).sum(axis=(0, 1))
    return (
        blood_area >= A_blood_min,
        myo_area >= A_myo_min,
        blood_area,
        myo_area,
    )
def choose_mid_slices(has_blood, has_myo, K=4, min_frac=0.7):
    """
    Pick up to K slice indices to drive ED/ES detection.

    has_blood, has_myo: boolean (S, F) arrays.

    Every run of K consecutive slices is scored as its mean fraction of
    frames where both blood and myocardium are present, minus 0.01 times its
    mean distance from the middle slice S // 2. On ties the earliest run wins.
    If the best run's mean valid fraction is below ``min_frac``, the K slices
    with the highest valid fraction are returned instead, in ascending order.
    """
    n_slices, n_frames = has_blood.shape
    frac_valid = (has_blood & has_myo).sum(axis=1) / max(n_frames, 1)
    centre = int(n_slices // 2)

    best_score, best_window = None, None
    for first in range(max(n_slices - K + 1, 1)):
        window = list(range(first, min(first + K, n_slices)))
        penalty = 0.01 * np.mean([abs(s - centre) for s in window])
        candidate = frac_valid[window].mean() - penalty
        if best_score is None or candidate > best_score:
            best_score, best_window = candidate, window

    if np.mean(frac_valid[best_window]) < min_frac:
        ranked = np.argsort(-frac_valid)
        return sorted(ranked[:K].tolist())
    return best_window
def frame_volumes_subset_uL(preds_4d, spacing, slice_indices):
    """
    Blood-pool (label 2) volume per frame, counting only ``slice_indices``.

    preds_4d: (H, W, S, F) label array.
    spacing: dict with "row_mm", "col_mm", "slice_thickness_mm".

    Returns a float32 array of length F (voxel count x voxel volume).
    """
    voxel_mm3 = (
        float(spacing["row_mm"])
        * float(spacing["col_mm"])
        * float(spacing["slice_thickness_mm"])
    )
    per_frame = [
        (preds_4d[:, :, slice_indices, frame] == 2).sum() * voxel_mm3
        for frame in range(preds_4d.shape[3])
    ]
    return np.asarray(per_frame, dtype=np.float32)
def pick_ed_es_from_volumes(
    vols_uL, prefer_frame0=True, rel_tol=0.05, min_sep=1
):
    """
    Choose end-diastolic (ED) and end-systolic (ES) frame indices.

    ED is the frame with the largest volume. If ``prefer_frame0`` is set and
    frame 0 is within ``rel_tol`` (relative) of that maximum, ED is frame 0.
    ES is the smallest-volume frame that is at least ``min_sep`` frames away
    from ED. If no frame qualifies, ES falls back to the global minimum.

    Returns (ed, es) as Python ints.
    """
    ed = int(np.argmax(vols_uL))
    peak = vols_uL[ed]
    if prefer_frame0 and abs(vols_uL[0] - peak) <= rel_tol * max(peak, 1e-6):
        ed = 0
    ascending = np.argsort(vols_uL)
    far_enough = [int(c) for c in ascending if abs(int(c) - ed) >= min_sep]
    es = far_enough[0] if far_enough else int(ascending[0])
    return ed, es
# =========================
# GIF (ED vs ES)
# =========================
def gif_animation_for_patient_pred_only(
    images_4d, preds_4d, patient_id, ed_idx, es_idx, output_dir
):
    """
    Save a side-by-side ED vs ES overlay GIF that steps through every slice.

    images_4d, preds_4d: 4D arrays indexed as [:, :, slice, frame]; the GIF
        has one frame per slice (images_4d.shape[2]).
    patient_id: used in the figure title and in the output file name.
    ed_idx, es_idx: 0-based frame indices (shown 1-based in the titles).
    output_dir: destination directory (created if missing).

    Label 1 is drawn with the "Blues" colormap and label 2 with "jet" over
    the grayscale image, all passed through ``display_xform``.
    Returns ``<output_dir>/<patient_id>_pred.gif``. The matplotlib figure is
    always closed, even if rendering or saving raises.
    """
    os.makedirs(output_dir, exist_ok=True)

    def overlay(ax, img, pred, alpha_myo=0.45, alpha_blood=0.45):
        # Draw one grayscale slice with semi-transparent label overlays.
        base = display_xform(img)
        myo_mask = display_xform((pred == 1).astype(np.uint8)).astype(bool)
        blood_mask = display_xform((pred == 2).astype(np.uint8)).astype(bool)
        ax.imshow(base, cmap="gray", interpolation="none")
        if myo_mask.any():
            ax.imshow(
                np.ma.masked_where(~myo_mask, myo_mask),
                cmap="Blues",
                alpha=alpha_myo,
                vmin=0,
                vmax=1,
                interpolation="none",
            )
        if blood_mask.any():
            ax.imshow(
                np.ma.masked_where(~blood_mask, blood_mask),
                cmap="jet",
                alpha=alpha_blood,
                vmin=0,
                vmax=1,
                interpolation="none",
            )
        ax.axis("off")

    _, _, n_slices, _ = images_4d.shape
    fig, axarr = plt.subplots(1, 2, figsize=(8, 4))
    try:
        plt.tight_layout(rect=[0, 0, 1, 0.92])

        def update(slice_idx):
            # Redraw both panels (ED left, ES right) for one slice.
            axarr[0].clear()
            axarr[1].clear()
            overlay(
                axarr[0],
                images_4d[:, :, slice_idx, ed_idx],
                preds_4d[:, :, slice_idx, ed_idx],
            )
            axarr[0].set_title(f"ED (frame {ed_idx+1}) | Slice {slice_idx+1}")
            overlay(
                axarr[1],
                images_4d[:, :, slice_idx, es_idx],
                preds_4d[:, :, slice_idx, es_idx],
            )
            axarr[1].set_title(f"ES (frame {es_idx+1}) | Slice {slice_idx+1}")
            fig.suptitle(f"Mouse ID: {patient_id}", fontsize=14, y=0.98)

        # `interval` (ms) only affects live playback; when saving, the
        # explicit fps=GIF_FPS determines how long each slice is shown.
        anim = animation.FuncAnimation(fig, update, frames=n_slices, interval=2500)
        out_path = os.path.join(output_dir, f"{patient_id}_pred.gif")
        anim.save(out_path, writer="pillow", fps=GIF_FPS)
    finally:
        # Release the figure even when rendering/saving fails, so repeated
        # failures do not accumulate open figures.
        plt.close(fig)
    return out_path
# =========================
# CSV writer
# =========================
def write_all_in_one_csv(rows, per_frame_rows, csv_dir):
    """
    Merge per-case summaries with per-frame metrics and write
    ``<csv_dir>/Results.csv`` (one row per patient and frame).

    rows: list of per-case summary dicts (as built by process_nifti_case).
    per_frame_rows: list of per-frame DataFrames with Patient_ID, Frame,
        Volume_uL and MyocardiumMass_mg columns.
    csv_dir: output directory, created when missing.

    Frame numbers and ED/ES frame indices are shifted from 0-based to
    1-based. Every measurement and spacing column is written as a string
    with two decimals. The merge is a left join from the per-frame table, so
    summary rows without per-frame rows do not appear. Returns the CSV path.
    """
    index_cols = ["ED_frame_index", "ES_frame_index"]
    cardiac_cols = [
        "EDV_uL",
        "ESV_uL",
        "SV_uL",
        "EF_%",
        "MyocardiumMass_ED_mg",
        "MyocardiumMass_ES_mg",
    ]
    spacing_cols = [
        "Seg_PixelSpacing_row_mm",
        "Seg_PixelSpacing_col_mm",
        "Seg_SliceThickness_mm",
        "Native_PixelSpacing_row_mm",
        "Native_PixelSpacing_col_mm",
        "Native_SliceThickness_mm",
    ]
    frame_value_cols = ["Volume_uL", "MyocardiumMass_mg"]
    summary_cols = ["Patient_ID", *cardiac_cols, *index_cols, *spacing_cols]
    output_cols = [
        "Patient_ID",
        *index_cols,
        *cardiac_cols,
        "Frame",
        *frame_value_cols,
        *spacing_cols,
    ]

    def _as_2dp(series):
        # Render numbers as fixed two-decimal strings.
        return series.astype(float).map(lambda x: f"{x:.2f}")

    summary = pd.DataFrame(rows).reindex(columns=summary_cols)

    if per_frame_rows:
        per_frame = pd.concat(per_frame_rows, ignore_index=True)
    else:
        per_frame = pd.DataFrame(columns=["Patient_ID", "Frame", *frame_value_cols])

    if not per_frame.empty:
        per_frame["Frame"] = per_frame["Frame"].astype(int) + 1

    if not summary.empty:
        for col in index_cols:
            if col in summary.columns:
                summary[col] = summary[col].astype("Int64") + 1

    for col in [*cardiac_cols, *spacing_cols]:
        if col in summary.columns:
            summary[col] = _as_2dp(summary[col])

    if not per_frame.empty:
        for col in frame_value_cols:
            if col in per_frame.columns:
                per_frame[col] = _as_2dp(per_frame[col])

    merged = per_frame.merge(summary[summary_cols], on="Patient_ID", how="left")
    merged = merged[output_cols]

    os.makedirs(csv_dir, exist_ok=True)
    out_csv = os.path.join(csv_dir, "Results.csv")
    merged.to_csv(out_csv, index=False)
    log(f"CSV written: {out_csv}")
    return out_csv
# =========================
# Same-tab download
# =========================
def _download_button_dom_id(key_or_name: str) -> str:
    """
    Build a DOM id for a download anchor that is safe in a CSS ``#id``
    selector, an HTML attribute and a double-quoted JS string literal.

    Characters outside [A-Za-z0-9_-] become "_". The 8-hex sha256 suffix of
    the *original* text keeps ids distinct after that replacement.
    """
    import hashlib
    import re

    safe = re.sub(r"[^A-Za-z0-9_-]", "_", key_or_name)
    digest = hashlib.sha256(key_or_name.encode()).hexdigest()[:8]
    return f"dl_{safe}_{digest}"


def _same_tab_download_button(
    label: str,
    data_bytes: bytes,
    file_name: str,
    mime: str = "text/csv",
    *,
    key: Optional[str] = None,
):
    """
    Render a styled anchor that downloads ``data_bytes`` without leaving the tab.

    The payload is embedded base64-encoded in data-* attributes. A script
    injected via components.html attaches a click handler on the parent
    document that rebuilds a Blob and clicks a temporary <a download> link.
    ``key`` (or ``file_name`` when no key is given) determines the element id.
    """
    import html, base64
    import streamlit as st
    import streamlit.components.v1 as components

    b64 = base64.b64encode(data_bytes).decode("ascii")
    # The id is used in a CSS selector (a#id), an HTML attribute and a JS
    # string. Embedding the raw name broke styling for names with "." (for
    # example "Results.csv" parsed as a class selector) and allowed quotes to
    # break the markup, so sanitize it.
    btn_id = _download_button_dom_id(key or file_name)
    st.markdown(
        f"""
<style>
a#{btn_id} {{
appearance: none;
display: inline-flex; align-items: center; justify-content: center;
padding: 0.5rem 0.75rem;
border-radius: 0.5rem;
border: 1px solid rgba(49,51,63,.2);
background: var(--background-color);
color: var(--text-color);
font-weight: 600; text-decoration: none !important;
box-shadow: 0 1px 2px rgba(0,0,0,0.04);
transition: color .15s ease, border-color .15s ease,
box-shadow .15s ease, transform .05s ease, background-color .15s;
cursor: pointer;
user-select: none;
-webkit-tap-highlight-color: transparent;
}}
a#{btn_id}:hover, a#{btn_id}:focus {{
background: var(--background-color);
color: var(--accent) !important;
border-color: var(--accent) !important;
box-shadow: 0 2px 6px rgba(239,68,68,0.20);
transform: translateY(-1px);
}}
a#{btn_id}:active,
a#{btn_id}.pressed {{
background: var(--accent) !important;
border-color: var(--accent) !important;
color: #fff !important;
box-shadow: 0 3px 10px rgba(239,68,68,0.35);
transform: translateY(0);
}}
a#{btn_id}:focus-visible {{
outline: none;
box-shadow: 0 0 0 0.2rem rgba(239,68,68,0.35);
}}
</style>
""",
        unsafe_allow_html=True,
    )
    st.markdown(
        f'''
<div class="dl-wrap">
<a class="dl-btn" id="{btn_id}" href="#"
data-b64="{b64}"
data-mime="{html.escape(mime)}"
data-fname="{html.escape(file_name)}">{html.escape(label)}</a>
</div>
''',
        unsafe_allow_html=True,
    )
    components.html(
        f"""
<script>
(function () {{
try {{
const doc = window.parent.document;
const a = doc.getElementById("{btn_id}");
if (!a) return;
const pressOn = () => a.classList.add("pressed");
const pressOff = () => a.classList.remove("pressed");
a.addEventListener("mousedown", pressOn, true);
a.addEventListener("mouseup", pressOff, true);
a.addEventListener("mouseleave",pressOff, true);
a.addEventListener("touchstart",pressOn, {{passive:true}});
a.addEventListener("touchend", pressOff, true);
a.addEventListener("touchcancel",pressOff, true);
a.addEventListener("click", function(ev) {{
ev.preventDefault();
ev.stopImmediatePropagation();
const b64 = a.getAttribute("data-b64");
const mime = a.getAttribute("data-mime") || "application/octet-stream";
const fname = a.getAttribute("data-fname") || "download";
const bstr = atob(b64);
const len = bstr.length;
const u8 = new Uint8Array(len);
for (let i = 0; i < len; i++) u8[i] = bstr.charCodeAt(i);
const blob = new Blob([u8], {{ type: mime }});
const url = URL.createObjectURL(blob);
const tmp = doc.createElement("a");
tmp.href = url;
tmp.download = fname;
tmp.style.display = "none";
doc.body.appendChild(tmp);
tmp.click();
setTimeout(() => {{
URL.revokeObjectURL(url);
tmp.remove();
pressOff();
}}, 150);
}}, true);
}} catch (err) {{
console.debug("download handler error:", err);
}}
}})();
</script>
""",
        height=0,
    )
def render_download_group(buttons: List[dict], group_key: str):
    """
    Render several same-tab download buttons and one shared JS click hook.

    buttons: list of dicts with keys: label, bytes, filename, mime, key
    group_key: prefix for the buttons' DOM ids. Characters outside
        [A-Za-z0-9_-] are replaced with "_" because the id is embedded in
        HTML attributes and in single-quoted JS literals (wire('...')), where
        a quote would break the generated script.

    Each anchor stores its payload base64-encoded in data-* attributes. The
    script (run via components.html) looks the anchors up in the parent
    document and, on click, builds a Blob and downloads it through a
    temporary <a download> element instead of navigating.
    """
    import base64, html, hashlib, re

    safe_group = re.sub(r"[^A-Za-z0-9_-]", "_", str(group_key))
    ids = []
    rows_html = []
    for b in buttons:
        b64 = base64.b64encode(b["bytes"]).decode("ascii")
        btn_id = f"dl_{safe_group}_{hashlib.sha256(b['key'].encode()).hexdigest()[:8]}"
        ids.append(btn_id)
        rows_html.append(f"""
<div class="dl-row">
<a class="dl-btn" id="{btn_id}" href="#"
data-b64="{b64}"
data-mime="{html.escape(b['mime'])}"
data-fname="{html.escape(b['filename'])}">
{html.escape(b['label'])}
</a>
</div>
""")
    st.markdown(
        f"""
<div class="dl-group">
{''.join(rows_html)}
</div>
""",
        unsafe_allow_html=True,
    )
    # one JS hook for all buttons in this group
    wire_calls = "".join(f"wire('{i}');" for i in ids)
    components.html(
        f"""
<script>
(function () {{
const doc = window.parent.document;
function wire(id) {{
const a = doc.getElementById(id);
if (!a) return;
const pressOn = () => a.classList.add("pressed");
const pressOff = () => a.classList.remove("pressed");
a.addEventListener("mousedown", pressOn, true);
a.addEventListener("mouseup", pressOff, true);
a.addEventListener("mouseleave", pressOff, true);
a.addEventListener("click", function(ev) {{
ev.preventDefault();
ev.stopImmediatePropagation();
const b64 = a.getAttribute("data-b64");
const mime = a.getAttribute("data-mime") || "application/octet-stream";
const fname = a.getAttribute("data-fname") || "download";
const bstr = atob(b64);
const len = bstr.length;
const u8 = new Uint8Array(len);
for (let i = 0; i < len; i++) u8[i] = bstr.charCodeAt(i);
const blob = new Blob([u8], {{ type: mime }});
const url = URL.createObjectURL(blob);
const tmp = doc.createElement("a");
tmp.href = url;
tmp.download = fname;
tmp.style.display = "none";
doc.body.appendChild(tmp);
tmp.click();
setTimeout(() => {{
URL.revokeObjectURL(url);
tmp.remove();
pressOff();
}}, 150);
}}, true);
}}
{wire_calls}
}})();
</script>
""",
        height=0,
    )
# =========================
# Per-file processing
# =========================
def process_nifti_case(
    nifti_path: str,
    model,
    rows_acc: List[Dict],
    per_frame_rows_acc: List[pd.DataFrame],
    *,
    progress_cb=None,
):
    """
    Run the full per-file pipeline for one NIfTI case and store the results.

    Steps: load the 4D NIfTI, predict labels in batches, optionally remove
    islands, pick ED/ES frames, compute EDV/ESV/SV/EF and myocardial mass,
    then render an ED-vs-ES GIF.

    nifti_path: path to the input file. The patient id is the file name with
        the substrings ".nii.gz" / ".nii" removed.
    model: model passed to predict_nifti_4d().
    rows_acc: receives one summary dict for this case (appended in place).
    per_frame_rows_acc: receives this case's per-frame DataFrame, with a
        leading Patient_ID column (appended in place).
    progress_cb(file_frac: float, msg: Optional[str]) where file_frac in [0,1].
        Exceptions raised by the callback are swallowed.

    Side effects: sets the module global CURRENT_DISPLAY_ORIENT and writes
    st.session_state["gif_paths"][pid] and st.session_state["cases"][pid].
    Returns None.
    """
    def _p(frac: float, msg: Optional[str] = None):
        # Forward clamped progress to the caller. Progress reporting is
        # best-effort, so callback errors are deliberately ignored.
        if progress_cb is None:
            return
        try:
            progress_cb(float(np.clip(frac, 0.0, 1.0)), msg)
        except Exception:
            pass
    name = Path(nifti_path).name
    # NOTE: str.replace removes these substrings anywhere in the name, not
    # only as a suffix.
    pid = name.replace(".nii.gz", "").replace(".nii", "")
    log(f"[CASE] Using NIfTI input: {nifti_path}")
    # ---- Step A: load + prep (0 -> W_LOAD_PREP) ----
    _p(0.00, "Loading NIfTI...")
    imgs_4d, spacing, final_aff = load_nifti_4d(nifti_path, orient_target=ORIENT_TARGET)
    _p(W_LOAD_PREP, "Preparing inference...")
    # Orientation bookkeeping (unchanged): with no explicit target, use "RAS"
    # only if the affine's axis codes are R/A/S; every other case (including
    # errors) falls back to "LPS".
    global CURRENT_DISPLAY_ORIENT
    if ORIENT_TARGET is None:
        try:
            axc = nib.aff2axcodes(final_aff)
            if tuple(axc[:3]) == ("L", "P", "S"):
                CURRENT_DISPLAY_ORIENT = "LPS"
            elif tuple(axc[:3]) == ("R", "A", "S"):
                CURRENT_DISPLAY_ORIENT = "RAS"
            else:
                CURRENT_DISPLAY_ORIENT = "LPS"
        except Exception:
            CURRENT_DISPLAY_ORIENT = "LPS"
    else:
        CURRENT_DISPLAY_ORIENT = ORIENT_TARGET
    native_h, native_w, S, F = imgs_4d.shape
    row_native = float(spacing["row_mm"])
    col_native = float(spacing["col_mm"])
    thk_native = float(spacing["slice_thickness_mm"])
    # Spacing on the model grid (SIZE_Y x SIZE_X): in-plane spacing is scaled
    # by the native/model size ratio. Slice thickness is unchanged because
    # prediction keeps S and F as-is.
    seg_row_mm = row_native * (native_h / float(SIZE_Y))
    seg_col_mm = col_native * (native_w / float(SIZE_X))
    seg_spacing = dict(
        row_mm=seg_row_mm,
        col_mm=seg_col_mm,
        slice_thickness_mm=thk_native,
        frame_time_ms=spacing.get("frame_time_ms", None),
    )
    # ---- Step B: prediction (W_LOAD_PREP -> W_LOAD_PREP + W_PREDICT) ----
    total_images = max(int(S * F), 1)
    def _pred_progress(done_images: int):
        # Map the number of predicted 2D images to this file's progress band.
        frac = float(done_images) / float(total_images)
        frac = float(np.clip(frac, 0.0, 1.0))
        file_frac = W_LOAD_PREP + W_PREDICT * frac
        _p(file_frac, f"Predicting... ({int(round(frac * 100))}%)")
    preds_4d_256, imgs_4d_256 = predict_nifti_4d(
        model,
        imgs_4d,
        batch_size=BATCH_SIZE,
        progress_cb=_pred_progress, # called after every prediction batch
    )
    # Masks stay on the model grid for metrics and the GIF; no native-size
    # resampling happens in this function.
    preds_4d = preds_4d_256
    # ---- Step C: postprocess + metrics + GIF (remaining W_POST_GIF) ----
    post0 = W_LOAD_PREP + W_PREDICT
    # Split W_POST_GIF into real milestones
    # island: 35%, metrics: 25%, gif: 40% (of the post chunk)
    w_island = W_POST_GIF * 0.35
    w_metrics = W_POST_GIF * 0.25
    w_gif = W_POST_GIF * 0.40
    # Island removal
    if ENABLE_ISLAND_REMOVAL:
        _p(post0 + 0.10 * w_island, "Cleaning islands...")
        preds_4d = clean_predictions_per_frame_3d(preds_4d)
        _p(post0 + w_island, "Cleaning islands... done")
    else:
        _p(post0 + w_island, "Skipping island removal")
    preds_4d_orig = preds_4d.copy() # independent copy of the (cleaned) auto masks, kept as preds_4d_orig
    # Metrics
    _p(post0 + w_island + 0.10 * w_metrics, "Computing metrics...")
    voxel_mm3 = (
        seg_spacing["row_mm"]
        * seg_spacing["col_mm"]
        * seg_spacing["slice_thickness_mm"]
    )
    mid_slices = None
    # With USE_MID_SLICES_FOR_ED_ES, only the slices picked by
    # choose_mid_slices() decide WHICH frames are ED/ES. EDV/ESV themselves
    # are always read from whole-volume blood-pool (label 2) counts.
    if USE_MID_SLICES_FOR_ED_ES:
        has_blood, has_myo, _, _ = slice_validity_matrix(
            preds_4d, A_blood_min=MID_A_BLOOD_MIN, A_myo_min=MID_A_MYO_MIN
        )
        mid_slices = choose_mid_slices(
            has_blood,
            has_myo,
            K=min(MID_K, preds_4d.shape[2]),
            min_frac=MID_MIN_VALID_FRAC,
        )
        vols_subset = frame_volumes_subset_uL(preds_4d, seg_spacing, mid_slices)
        ed_idx, es_idx = pick_ed_es_from_volumes(
            vols_subset, prefer_frame0=True, rel_tol=0.05, min_sep=1
        )
        vols_full = np.array(
            [(preds_4d[..., f] == 2).sum() * voxel_mm3 for f in range(F)],
            dtype=np.float32,
        )
        EDV_uL = float(vols_full[ed_idx])
        ESV_uL = float(vols_full[es_idx])
    else:
        vols_full = np.array(
            [(preds_4d[..., f] == 2).sum() * voxel_mm3 for f in range(F)],
            dtype=np.float32,
        )
        ed_idx, es_idx = pick_ed_es_from_volumes(
            vols_full, prefer_frame0=True, rel_tol=0.05, min_sep=1
        )
        EDV_uL = float(vols_full[ed_idx])
        ESV_uL = float(vols_full[es_idx])
    SV_uL = EDV_uL - ESV_uL
    # EF is reported as 0.0 when EDV is 0 (avoids division by zero).
    EF_pct = (SV_uL / EDV_uL * 100.0) if EDV_uL > 0 else 0.0
    per_frame_df = compute_per_frame_metrics(preds_4d, seg_spacing)
    # Safer indexing: "Frame" is 0-based here, matching ed_idx/es_idx. Any
    # lookup failure falls back to 0.0 mg.
    try:
        myo_mass_ED_mg = float(per_frame_df.loc[per_frame_df["Frame"] == ed_idx, "MyocardiumMass_mg"].values[0])
    except Exception:
        myo_mass_ED_mg = 0.0
    try:
        myo_mass_ES_mg = float(per_frame_df.loc[per_frame_df["Frame"] == es_idx, "MyocardiumMass_mg"].values[0])
    except Exception:
        myo_mass_ES_mg = 0.0
    per_frame_df.insert(0, "Patient_ID", pid)
    per_frame_rows_acc.append(per_frame_df)
    rows_acc.append(
        {
            "Patient_ID": pid,
            "EDV_uL": EDV_uL,
            "ESV_uL": ESV_uL,
            "SV_uL": SV_uL,
            "EF_%": EF_pct,
            "MyocardiumMass_ED_mg": myo_mass_ED_mg,
            "MyocardiumMass_ES_mg": myo_mass_ES_mg,
            "ED_frame_index": int(ed_idx),
            "ES_frame_index": int(es_idx),
            "Seg_SliceThickness_mm": seg_spacing["slice_thickness_mm"],
            "Seg_PixelSpacing_row_mm": seg_spacing["row_mm"],
            "Seg_PixelSpacing_col_mm": seg_spacing["col_mm"],
            "Native_SliceThickness_mm": spacing["slice_thickness_mm"],
            "Native_PixelSpacing_row_mm": spacing["row_mm"],
            "Native_PixelSpacing_col_mm": spacing["col_mm"],
        }
    )
    _p(post0 + w_island + w_metrics, "Computing metrics... done")
    # GIF: rendered from the model-grid images and the (cleaned) masks.
    _p(post0 + w_island + w_metrics + 0.10 * w_gif, "Creating GIF...")
    gif_path = gif_animation_for_patient_pred_only(
        imgs_4d_256, preds_4d, pid, ed_idx, es_idx, get_session_gifs_dir()
    )
    log(f"[CASE] GIF saved: {gif_path}")
    if "gif_paths" not in st.session_state:
        st.session_state["gif_paths"] = {}
    st.session_state["gif_paths"][pid] = gif_path
    # State used later by the manual-review UI. *_auto hold the pipeline's
    # ED/ES picks; *_user start equal to them and are overwritten when the
    # user confirms ED/ES (which also sets ed_es_confirmed).
    case_info = dict(
        pid=pid,
        imgs_4d_256=imgs_4d_256,
        preds_4d=preds_4d,
        preds_4d_orig=preds_4d_orig,
        seg_spacing=seg_spacing,
        native_spacing=spacing,
        native_shape=(native_h, native_w, S, F),
        affine=final_aff,
        ed_idx_auto=int(ed_idx),
        es_idx_auto=int(es_idx),
        ed_idx_user=int(ed_idx),
        es_idx_user=int(es_idx),
        ed_es_confirmed=False,
        mid_slices=mid_slices,
    )
    case_info["dirty"] = False
    case_info["mask_edited"] = False
    if "cases" not in st.session_state:
        st.session_state["cases"] = {}
    st.session_state["cases"][pid] = case_info
    # Done for this file
    _p(1.0, "Done")
# =========================
# Recompute metrics from corrected masks
# =========================
def recompute_metrics_from_session_cases_dirty_gif_only():
    """
    Rebuild summary and per-frame metrics for every case in session state.

    - Metrics are recomputed for ALL cases in st.session_state["cases"]
      (cheap; keeps the CSV consistent with any mask edits).
    - A GIF is regenerated ONLY for cases whose "dirty" flag is truthy. The
      new path goes to st.session_state["gif_paths_corrected"][pid] and the
      flag is cleared.

    ED/ES come from ed_idx_user/es_idx_user when "ed_es_confirmed" is set,
    otherwise from ed_idx_auto/es_idx_auto.

    Returns (rows, per_frame_rows, updated_any_gif); ([], [], False) when no
    cases exist.
    """
    cases = st.session_state.get("cases")
    if not cases:
        return [], [], False

    summary_rows: List[Dict] = []
    frame_tables: List[pd.DataFrame] = []
    regenerated = False
    st.session_state.setdefault("gif_paths_corrected", {})

    def _mass_at(table, frame_idx):
        # Myocardial mass for one frame; 0.0 if the frame is absent.
        try:
            return float(
                table.loc[table["Frame"] == frame_idx, "MyocardiumMass_mg"].values[0]
            )
        except Exception:
            return 0.0

    for pid, case in cases.items():
        preds = case["preds_4d"]
        images = case["imgs_4d_256"]
        seg_sp = case["seg_spacing"]
        native_sp = case["native_spacing"]

        source = "user" if case.get("ed_es_confirmed", False) else "auto"
        ed_idx = int(case[f"ed_idx_{source}"])
        es_idx = int(case[f"es_idx_{source}"])

        voxel_mm3 = seg_sp["row_mm"] * seg_sp["col_mm"] * seg_sp["slice_thickness_mm"]
        n_frames = preds.shape[3]
        blood_volumes = np.array(
            [(preds[..., f] == 2).sum() * voxel_mm3 for f in range(n_frames)],
            dtype=np.float32,
        )
        # Out-of-range indices yield 0.0 rather than raising.
        edv = float(blood_volumes[ed_idx]) if 0 <= ed_idx < n_frames else 0.0
        esv = float(blood_volumes[es_idx]) if 0 <= es_idx < n_frames else 0.0
        stroke = edv - esv
        ejection = (stroke / edv * 100.0) if edv > 0 else 0.0

        frame_df = compute_per_frame_metrics(preds, seg_sp)
        mass_ed = _mass_at(frame_df, ed_idx)
        mass_es = _mass_at(frame_df, es_idx)
        frame_df.insert(0, "Patient_ID", pid)
        frame_tables.append(frame_df)
        summary_rows.append(
            {
                "Patient_ID": pid,
                "EDV_uL": edv,
                "ESV_uL": esv,
                "SV_uL": stroke,
                "EF_%": ejection,
                "MyocardiumMass_ED_mg": mass_ed,
                "MyocardiumMass_ES_mg": mass_es,
                "ED_frame_index": int(ed_idx),
                "ES_frame_index": int(es_idx),
                "Seg_SliceThickness_mm": seg_sp["slice_thickness_mm"],
                "Seg_PixelSpacing_row_mm": seg_sp["row_mm"],
                "Seg_PixelSpacing_col_mm": seg_sp["col_mm"],
                "Native_SliceThickness_mm": native_sp["slice_thickness_mm"],
                "Native_PixelSpacing_row_mm": native_sp["row_mm"],
                "Native_PixelSpacing_col_mm": native_sp["col_mm"],
            }
        )

        if not case.get("dirty", False):
            continue
        # Expensive step: only for cases whose masks/ED-ES changed.
        st.session_state["gif_paths_corrected"][pid] = gif_animation_for_patient_pred_only(
            images,
            preds,
            pid,
            ed_idx,
            es_idx,
            get_session_gifs_dir(),
        )
        case["dirty"] = False
        regenerated = True

    return summary_rows, frame_tables, regenerated
# ==========================================================
# Roundel-style LV-only manual correction
# ==========================================================
# This section adapts the in-house Roundel biventricular editor interface
# for the public Hugging Face / Streamlit app.
#
# Important differences from the in-house Roundel app:
# - This app is LV-only: labels are 0=background, 1=LV myocardium, 2=LV blood pool.
# - Images/masks come from st.session_state["cases"] after model inference.
# - No persistent server-side case database is used; notes/status are session-only.
# - The original Roundel visual workflow is preserved: ED/ES Finder → Mask Editor → Final Result.
# ==========================================================
# LV-only label ids used by the editor (0 = background, 1 = LV myocardium,
# 2 = LV blood pool).
LV_MYO_LABEL = 1
LV_BP_LABEL = 2
# Per-label RGBA overlay colors (0-255; the 4th value is alpha), used by
# _roundel_get_overlay(). The background entry is fully transparent.
# NOTE: "ROUNDDEL" is a misspelling of "Roundel"; renaming would require
# updating every reference to these constants.
ROUNDDEL_LV_OVERLAY_COLORS = {
    0: (200, 200, 200, 0),
    LV_MYO_LABEL: (0, 0, 255, 70), # LV myocardium = blue
    LV_BP_LABEL: (255, 10, 10, 90), # LV blood pool = red
}
# Display text for the paint-target radio in _roundel_select_brush().
ROUNDDEL_LV_BRUSH_LABELS = {
    LV_MYO_LABEL: "LV Myocardium 🔵",
    LV_BP_LABEL: "LV Blood Pool 🔴",
}
def _streamlit_segmented_control(label, options, default=None, key=None, label_visibility="visible"):
    """
    Compatibility helper:
      - New Streamlit: st.segmented_control
      - Older Streamlit (or if that call raises TypeError): st.radio
    Important:
      If a key is provided, initialise st.session_state[key] once and then
      do NOT pass default/value/index to the widget. This avoids Streamlit
      warnings/errors when we programmatically switch views.
      Without a key there is no session state to seed, so ``default`` is
      handed to the widget itself (previously it was silently ignored and
      the widget started at its own default).
    Returns the widget's current selection.
    """
    if default is None:
        default = options[0]
    if key is not None:
        if key not in st.session_state or st.session_state[key] not in options:
            st.session_state[key] = default
        # Session state already holds the selection: pass no default.
        seg_extra, radio_extra = {}, {}
    elif default in options:
        seg_extra = {"default": default}
        radio_extra = {"index": list(options).index(default)}
    else:
        seg_extra, radio_extra = {}, {}
    if hasattr(st, "segmented_control"):
        try:
            return st.segmented_control(
                label,
                options=options,
                key=key,
                label_visibility=label_visibility,
                **seg_extra,
            )
        except TypeError:
            pass
    return st.radio(
        label,
        options=options,
        horizontal=True,
        key=key,
        label_visibility=label_visibility,
        **radio_extra,
    )
def _roundel_case_label(pid: str) -> str:
    """
    Display label for a case: review-status icon, the pid, and a note marker
    when the case has a non-blank manual note.
    """
    all_cases = st.session_state.get("cases", {}) or {}
    case = all_cases.get(pid, {})
    prefix = "✅" if case.get("manual_review_done", False) else "⬜"
    has_note = bool((case.get("manual_note", "") or "").strip())
    suffix = " 📝" if has_note else ""
    return f"{prefix} {pid}{suffix}"
def _roundel_normalize(image):
image = np.asarray(image, dtype=np.float32)
return (image - np.min(image)) / (np.max(image) - np.min(image) + 1e-9)
def _roundel_get_overlay(image_slice_u8, mask_slice, channels_to_show):
    """
    Returns an RGB PIL image, similar to the in-house Roundel overlay renderer.

    image_slice_u8: 2D uint8 image at 256x256.
    mask_slice: 2D uint8 label mask at 256x256.
    channels_to_show: label ids to draw. Each label present in the mask is
        alpha-composited, in the given order, with its
        ROUNDDEL_LV_OVERLAY_COLORS entry.
    """
    base = Image.fromarray(np.stack([image_slice_u8] * 3, axis=-1)).convert("RGBA")
    H, W = image_slice_u8.shape
    for ch in channels_to_show:
        label_id = int(ch)
        ch_mask = mask_slice == label_id
        if not np.any(ch_mask):
            continue
        layer = np.zeros((H, W, 4), dtype=np.uint8)
        layer[ch_mask] = np.array(ROUNDDEL_LV_OVERLAY_COLORS[label_id], dtype=np.uint8)
        # An (H, W, 4) uint8 array is inferred as RGBA. The explicit `mode`
        # argument of Image.fromarray is deprecated in recent Pillow releases.
        base = Image.alpha_composite(base, Image.fromarray(layer))
    return base.convert("RGB")
def _roundel_overlay_selector_ui():
    """
    LV-only version of Roundel overlay selection.

    Shows an "Overlapped All" checkbox plus per-structure checkboxes (disabled
    while "Overlapped All" is ticked). Returns the label ids to overlay. Both
    labels are returned when "Overlapped All" is on or when nothing is ticked.
    """
    for state_key in (
        "roundel_overlay_overlapped_all",
        "roundel_overlay_lv_bp",
        "roundel_overlay_lv_myo",
    ):
        st.session_state.setdefault(state_key, True)

    show_all = st.checkbox("Overlapped All", key="roundel_overlay_overlapped_all")
    left, right = st.columns(2)
    with left:
        show_bp = st.checkbox("LV Blood Pool", key="roundel_overlay_lv_bp", disabled=show_all)
    with right:
        show_myo = st.checkbox("LV Myocardium", key="roundel_overlay_lv_myo", disabled=show_all)

    everything = [LV_MYO_LABEL, LV_BP_LABEL]
    if show_all:
        return everything
    picked = [
        label_id
        for label_id, enabled in ((LV_MYO_LABEL, show_myo), (LV_BP_LABEL, show_bp))
        if enabled
    ]
    return picked or everything
def _roundel_select_brush():
    """
    LV-only version of Roundel brush selection.

    Renders a paint/erase toggle, a width preset, a fine width slider
    (6-40 px) and, when painting, the target-label picker. Choosing a
    different preset resets the slider to that preset's width.

    Returns (channel, action, stroke_width). channel is LV_MYO_LABEL or
    LV_BP_LABEL when painting and 0 when erasing.
    """
    modes = ["Paint ✏️", "Erase ✂️"]
    widths = {"thin": 6, "medium": 20, "thick": 40}
    ss = st.session_state

    for state_key, initial in (
        ("roundel_brush_mode", "Paint ✏️"),
        ("roundel_stroke_width_preset", "thin"),
        ("roundel_stroke_width_value", 6),
    ):
        if state_key not in ss:
            ss[state_key] = initial

    action = st.radio(
        "Brush Stroke Selection",
        options=modes,
        index=modes.index(ss["roundel_brush_mode"]),
        horizontal=True,
        key="roundel_brush_mode_radio",
    )
    ss["roundel_brush_mode"] = action

    preset_names = list(widths.keys())
    preset = st.radio(
        "Stroke Width (Preset)",
        options=preset_names,
        index=preset_names.index(ss.get("roundel_stroke_width_preset", "thin")),
        horizontal=True,
        key="roundel_stroke_width_preset_radio",
    )
    ss["roundel_stroke_width_preset"] = preset
    preset_val = widths[preset]

    # A newly selected preset overrides whatever the fine slider held.
    if ss.get("_roundel_last_stroke_preset", None) != preset:
        ss["roundel_stroke_width_value"] = preset_val
        ss["_roundel_last_stroke_preset"] = preset
    # Clamp before the slider is created so its key holds an in-range value.
    ss["roundel_stroke_width_value"] = int(
        np.clip(ss.get("roundel_stroke_width_value", preset_val), 6, 40)
    )
    stroke_val = st.slider(
        "Stroke Width",
        min_value=6,
        max_value=40,
        step=1,
        key="roundel_stroke_width_value",
    )

    if action != "Paint ✏️":
        return 0, action, int(stroke_val)
    channel = st.radio(
        "Mask",
        options=[LV_MYO_LABEL, LV_BP_LABEL],
        format_func=lambda x: ROUNDDEL_LV_BRUSH_LABELS[x],
        index=0,
        horizontal=True,
        key="roundel_brush_label_radio",
    )
    return int(channel), action, int(stroke_val)
def _extract_stroke_mask_from_canvas(canvas_rgba, background_rgb, diff_thresh=10):
"""
Roundel-style stroke extraction by comparing the returned canvas image
against the exact background image passed to st_canvas.
canvas_rgba: numpy (Hc,Wc,4) from canvas_result.image_data
background_rgb: numpy (Hc,Wc,3) from the SAME background passed to st_canvas
Returns:
binary uint8 (Hc,Wc), where 1 == newly drawn stroke pixels.
"""
if canvas_rgba is None or background_rgb is None:
return None
canvas_rgb = canvas_rgba[:, :, :3].astype(np.int16)
bg_rgb = background_rgb[:, :, :3].astype(np.int16)
alpha = canvas_rgba[:, :, 3].astype(np.int16)
diff = np.max(np.abs(canvas_rgb - bg_rgb), axis=-1)
stroke = (diff >= int(diff_thresh)) & (alpha > 0)
return stroke.astype(np.uint8)
def _resize_binary_mask(mask2d, target_h, target_w):
if mask2d is None:
return None
if mask2d.shape[0] == target_h and mask2d.shape[1] == target_w:
return mask2d.astype(np.uint8)
pil = Image.fromarray((mask2d > 0).astype(np.uint8) * 255)
pil = pil.resize((int(target_w), int(target_h)), resample=Image.NEAREST)
return (np.array(pil) > 0).astype(np.uint8)
def build_processed_stroke_and_enclosed_region(strokes, stroke_width, do_fill_holes: bool):
    """
    Adapted from roundel_biventricular_utils.py.
    Turn a raw brush stroke into the mask used for painting/erasing, and
    estimate the region the stroke encloses.

    The stroke is morphologically closed, dilated and (optionally) hole-filled,
    then eroded back. Each step uses ``max(1, round(stroke_width / 8))``
    iterations with SciPy's default structuring element.

    Returns
    -------
    processed_mask : uint8 binary mask
        The mask used for painting/erasing. The input is returned unchanged
        when it is None or has no foreground.
    enclosed_region : uint8 binary mask
        Pixels inside the thickened outline but not on it (None when
        ``strokes`` is None). For myocardium painting this lets an existing
        LV blood pool inside a drawn myocardium ring be preserved.
    """
    if strokes is None:
        return None, None
    if not np.any(strokes):
        return strokes, np.zeros_like(strokes, dtype=np.uint8)

    binary = (strokes > 0).astype(np.uint8)
    n_iter = int(max(1, round(stroke_width / 8)))
    outline = binary_dilation(binary_closing(binary, iterations=n_iter), iterations=n_iter)
    solid = binary_fill_holes(outline).astype(np.uint8)
    inside = ((solid > 0) & (outline == 0)).astype(np.uint8)

    source = solid if do_fill_holes else outline
    processed = binary_erosion(source, iterations=n_iter).astype(np.uint8)
    return processed, inside.astype(np.uint8)
def _roundel_apply_stroke_to_label_slice(mask_slice, stroke_mask, enclosed_region, channel, action):
"""
LV-only label-mask painting/erasing.
mask_slice labels:
0 = background
1 = LV myocardium
2 = LV blood pool
"""
updated = mask_slice.copy()
pm = (stroke_mask > 0)
if not np.any(pm):
return updated, False
if action == "Erase ✂️" or int(channel) == 0:
updated[pm] = 0
return updated, True
channel = int(channel)
if channel == LV_MYO_LABEL:
# Preserve existing blood pool if user draws a myocardium ring around it.
if enclosed_region is not None and np.any(enclosed_region):
protected_bp = (enclosed_region.astype(bool)) & (updated == LV_BP_LABEL)
writable = pm & (~protected_bp)
else:
writable = pm
updated[writable] = LV_MYO_LABEL
return updated, bool(np.any(writable))
if channel == LV_BP_LABEL:
updated[pm] = LV_BP_LABEL
return updated, True
return updated, False
def _roundel_frame_index_slider(T, frames, initial_idx, label, disabled_flag, key):
    """
    LV-only Roundel ED/ES frame slider with montage preview.
    Uses 0-based indices to match the original in-house Roundel app.

    The slider value is kept in st.session_state[key]. It is seeded from
    ``initial_idx`` on first use and clamped to [0, T - 1] on every run; the
    montage for the selected frame is shown below the slider.

    Important:
        Do NOT pass value=... if the same widget key is also initialised
        in st.session_state, otherwise Streamlit shows:
        "The widget with key ... was created with a default value but also had
        its value set via the Session State API."
    Returns the selected frame index.
    """
    last = T - 1
    seed = int(np.clip(initial_idx, 0, last))
    current = st.session_state[key] if key in st.session_state else seed
    st.session_state[key] = int(np.clip(current, 0, last))

    chosen = st.slider(
        f"{label} | *{int(st.session_state[key])}*",
        min_value=0,
        max_value=last,
        key=key,
        disabled=disabled_flag,
    )
    chosen = int(np.clip(chosen, 0, last))
    st.image(frames[chosen], use_column_width=True)
    return chosen
def _roundel_build_all_slice_preview_frames(case):
    """
    Build one montage image per timeframe, showing every slice (4 tiles per
    row, with label overlay) via build_frame_montage().

    case: session case dict providing "imgs_4d_256" and "preds_4d".
    Returns a list with one montage per frame, in frame order.
    """
    images = case["imgs_4d_256"]
    labels = case["preds_4d"]
    _, _, n_slices, n_frames = labels.shape
    every_slice = list(range(n_slices))
    return [
        build_frame_montage(
            images,
            labels,
            frame_idx=frame,
            slice_indices=every_slice,
            tile_cols=4,
            add_overlay=True,
        )
        for frame in range(n_frames)
    ]
def _roundel_confirm_ed_es_frames_callback(selected_pid, ed_idx, es_idx):
    """
    on_click callback: store the chosen ED/ES frames and jump to the Mask Editor.

    For the case ``selected_pid`` this records ``ed_idx_user`` / ``es_idx_user``,
    marks the frames as confirmed and the case as dirty, invalidates the
    download caches, and points the Roundel view selector at "Mask Editor 📝".

    This must run as a button callback. Callbacks execute before Streamlit
    recreates the segmented-control widget on the next rerun. Writing its key
    there avoids the StreamlitAPIException raised when a widget key is
    modified after creation. Unknown case ids are ignored.
    """
    all_cases = st.session_state.get("cases", {})
    if selected_pid in all_cases:
        target = all_cases[selected_pid]
        target.update(
            ed_idx_user=int(ed_idx),
            es_idx_user=int(es_idx),
            ed_es_confirmed=True,
            dirty=True,
        )
        invalidate_download_caches()
        # Automatically open the Mask Editor view after confirmation.
        st.session_state["roundel_view_segmented"] = "Mask Editor 📝"
def _roundel_change_slice_callback(selected_pid, delta, max_slices):
    """
    Previous/Next slice button callback: step the slice index by ``delta``.

    The result is clipped to ``[0, max_slices - 1]``. Because this runs as an
    ``on_click`` callback, before the slider is recreated, writing the
    slider's session-state key here is allowed.
    """
    slice_key = f"roundel_slice_idx_{selected_pid}"
    upper = int(max_slices) - 1
    target = int(st.session_state.get(slice_key, 0)) + int(delta)
    st.session_state[slice_key] = int(np.clip(target, 0, upper))
def _roundel_change_frame_callback(selected_pid, delta, max_frames):
    """
    Previous/Next frame button callback: step the frame index by ``delta``.

    The new index is clipped to ``[0, max_frames - 1]``. The frame mode key
    ``roundel_frame_mode_{pid}`` is set to "Custom"; in the mask editor the
    ED/ES modes pin the frame index to the confirmed ED/ES frame. Running as
    an ``on_click`` callback makes writing these session-state keys legal.
    """
    frame_key = f"roundel_frame_idx_{selected_pid}"
    stepped = int(st.session_state.get(frame_key, 0)) + int(delta)
    st.session_state[frame_key] = int(np.clip(stepped, 0, int(max_frames) - 1))
    st.session_state[f"roundel_frame_mode_{selected_pid}"] = "Custom"
def _roundel_reselect_ed_es_frames_callback(selected_pid):
    """
    on_click callback for "Reselect ED/ES Frames".

    Un-confirms the case's ED/ES frames and points the Roundel view selector
    back at the EDV/ESV Finder. Writing ``roundel_view_segmented`` is only
    legal from a callback. The view selector widget with that key is created
    in ``roundel_lv_editor_app`` before this view renders, so assigning it in
    the script body raises ``StreamlitAPIException``.
    """
    cases = st.session_state.get("cases", {})
    if selected_pid in cases:
        cases[selected_pid]["ed_es_confirmed"] = False
    st.session_state["roundel_view_segmented"] = "EDV/ESV Finder 🔍"


def roundel_lv_edv_esv_view(selected_pid, case):
    """
    Roundel-style EDV/ESV Finder, adapted to LV-only.

    Shows two frame sliders (LV end-diastole / end-systole), each with the
    cached all-slice montage of the selected frame.

    - Before confirmation: a "Confirm ED & ES Frames" button stores the picks
      through ``_roundel_confirm_ed_es_frames_callback``.
    - After confirmation: the sliders are disabled, and a "Reselect ED/ES
      Frames" button re-enables them through a callback.

    The montages are cached in ``st.session_state["roundel_preview_frames_{pid}"]``.
    """
    preds_4d = case["preds_4d"]
    _, _, _, F = preds_4d.shape
    preview_key = f"roundel_preview_frames_{selected_pid}"
    if preview_key not in st.session_state:
        st.session_state[preview_key] = _roundel_build_all_slice_preview_frames(case)
    frames = st.session_state[preview_key]
    disabled_flag = bool(case.get("ed_es_confirmed", False))
    # Prefer the user's picks; fall back to the automatically detected frames.
    display_lv_dia_idx = int(case.get("ed_idx_user", case.get("ed_idx_auto", 0)))
    display_lv_sys_idx = int(case.get("es_idx_user", case.get("es_idx_auto", 0)))
    st.markdown("#### Left Ventricle")
    col_edv, col_esv = st.columns(2)
    with col_edv:
        lv_dia_idx = _roundel_frame_index_slider(
            F,
            frames,
            display_lv_dia_idx,
            "LV End-Diastolic Index",
            disabled_flag,
            key=f"roundel_lv_edv_{selected_pid}",
        )
    with col_esv:
        lv_sys_idx = _roundel_frame_index_slider(
            F,
            frames,
            display_lv_sys_idx,
            "LV End-Systolic Index",
            disabled_flag,
            key=f"roundel_lv_esv_{selected_pid}",
        )
    st.write("")
    if not disabled_flag:
        st.button(
            "Confirm ED & ES Frames",
            type="primary",
            use_container_width=True,
            key=f"roundel_confirm_ed_es_{selected_pid}",
            on_click=_roundel_confirm_ed_es_frames_callback,
            args=(selected_pid, int(lv_dia_idx), int(lv_sys_idx)),
        )
    else:
        st.success("ED & ES frames confirmed!")
        # The callback runs before the next rerun's script body, so no
        # explicit rerun is needed and the view key can be written safely.
        st.button(
            "Reselect ED/ES Frames",
            use_container_width=True,
            key=f"roundel_reselect_ed_es_{selected_pid}",
            on_click=_roundel_reselect_ed_es_frames_callback,
            args=(selected_pid,),
        )
def _roundel_ensure_nav_state(selected_pid):
    """
    Seed the mask editor's navigation keys for ``selected_pid`` if they are missing.

    The keys cover the slice index, frame index, frame mode ("End-Diastole"
    by default) and the canvas nonce. Keys already in session state are
    left untouched.
    """
    defaults = {
        f"roundel_slice_idx_{selected_pid}": 0,
        f"roundel_frame_idx_{selected_pid}": 0,
        f"roundel_frame_mode_{selected_pid}": "End-Diastole",
        f"roundel_canvas_nonce_{selected_pid}": 0,
    }
    for state_key, fallback in defaults.items():
        if state_key in st.session_state:
            continue
        st.session_state[state_key] = fallback
def _roundel_sync_frame_mode_from_radio(selected_pid):
    """
    on_change callback for the mask editor's frame-mode radio.

    Copies the radio widget's new value into the persistent (non-widget) key
    ``roundel_frame_mode_{pid}``, which is the single source of truth for
    the frame mode.
    """
    st.session_state[f"roundel_frame_mode_{selected_pid}"] = st.session_state[
        f"roundel_frame_mode_radio_{selected_pid}"
    ]


def _roundel_set_custom_frame_mode(selected_pid):
    """
    on_change callback for the mask editor's "Frame Index" slider.

    Moving the slider by hand selects an arbitrary frame, so the mode is
    switched to "Custom". Otherwise the ED/ES branch would overwrite the
    slider value on the rerun.
    """
    st.session_state[f"roundel_frame_mode_{selected_pid}"] = "Custom"


def roundel_lv_mask_editor_view(selected_pid, case):
    """
    Roundel-style LV mask editor, laid out in three columns.

    - Left: brush selection, slice navigation and frame navigation.
      The frame mode ("End-Diastole" / "End-Systole" / "Custom") lives in
      ``roundel_frame_mode_{pid}``. ED/ES modes pin the frame index to the
      confirmed ED/ES frame; "Custom" leaves it free.
    - Middle: freehand canvas drawn over the image + label overlay.
      "Save Contour" writes the new strokes into ``case["preds_4d"]`` for the
      current (slice, frame). "Clear Slice" zeroes that slice/frame.
    - Right: corrected-mask preview, either the single current slice or the
      cached all-slice montage of the ED/ES frame.

    Edits mark the case dirty/edited, drop the cached preview montages and
    invalidate the download caches. Requires ``case["ed_es_confirmed"]``.
    """
    if not case.get("ed_es_confirmed", False):
        st.error("Select and confirm ED & ES frames first.")
        return
    imgs_4d_256 = case["imgs_4d_256"]
    preds_4d = case["preds_4d"]
    _, _, S, F = preds_4d.shape
    _roundel_ensure_nav_state(selected_pid)

    slice_key = f"roundel_slice_idx_{selected_pid}"
    frame_key = f"roundel_frame_idx_{selected_pid}"
    mode_key = f"roundel_frame_mode_{selected_pid}"
    radio_key = f"roundel_frame_mode_radio_{selected_pid}"
    nonce_key = f"roundel_canvas_nonce_{selected_pid}"
    preview_key = f"roundel_preview_frames_{selected_pid}"
    frame_modes = ["End-Diastole", "End-Systole", "Custom"]

    col1, col2, col3 = st.columns([1, 1.5, 1.5])
    with col1:
        st.markdown("#### Left Ventricle")
        st.caption("LV-only editor for labels: myocardium + blood pool.")
        channel, action, stroke_width = _roundel_select_brush()
        st.caption("Image Selection")
        lv_dia_idx = int(case["ed_idx_user"])
        lv_sys_idx = int(case["es_idx_user"])
        st.caption(f"LV ED = **{lv_dia_idx}** | LV ES = **{lv_sys_idx}**")
        st.slider("Slice Index", 0, S - 1, key=slice_key)
        col_prev_s, col_next_s = st.columns(2)
        with col_prev_s:
            st.button(
                "Previous",
                use_container_width=True,
                key=f"roundel_slice_prev_btn_{selected_pid}",
                on_click=_roundel_change_slice_callback,
                args=(selected_pid, -1, S),
            )
        with col_next_s:
            st.button(
                "Next",
                use_container_width=True,
                key=f"roundel_slice_next_btn_{selected_pid}",
                on_click=_roundel_change_slice_callback,
                args=(selected_pid, 1, S),
            )
        # The radio keeps its own widget state. Passing only index= would be
        # ignored after first creation, so mode changes made by the
        # Previous/Next frame callbacks (mode -> "Custom") were lost.
        # Push the persistent mode into the radio key before the widget is
        # created (legal at this point), and copy user clicks back via
        # on_change. No index= is passed, which avoids Streamlit's
        # "default value + Session State API" warning.
        current_mode = st.session_state.get(mode_key, "End-Diastole")
        if current_mode not in frame_modes:
            current_mode = "End-Diastole"
        st.session_state[radio_key] = current_mode
        frame_mode = st.radio(
            "Frame",
            options=frame_modes,
            horizontal=False,
            key=radio_key,
            on_change=_roundel_sync_frame_mode_from_radio,
            args=(selected_pid,),
        )
        st.session_state[mode_key] = frame_mode
        if frame_mode == "End-Diastole":
            st.session_state[frame_key] = lv_dia_idx
        elif frame_mode == "End-Systole":
            st.session_state[frame_key] = lv_sys_idx
        else:
            st.session_state[frame_key] = int(np.clip(st.session_state[frame_key], 0, F - 1))
        st.slider(
            "Frame Index",
            0,
            F - 1,
            key=frame_key,
            on_change=_roundel_set_custom_frame_mode,
            args=(selected_pid,),
        )
        col_prev_f, col_next_f = st.columns(2)
        with col_prev_f:
            st.button(
                "Previous",
                use_container_width=True,
                key=f"roundel_frame_prev_btn_{selected_pid}",
                on_click=_roundel_change_frame_callback,
                args=(selected_pid, -1, F),
            )
        with col_next_f:
            st.button(
                "Next",
                use_container_width=True,
                key=f"roundel_frame_next_btn_{selected_pid}",
                on_click=_roundel_change_frame_callback,
                args=(selected_pid, 1, F),
            )

    d = int(np.clip(st.session_state[slice_key], 0, S - 1))
    t = int(np.clip(st.session_state[frame_key], 0, F - 1))
    image_slice = imgs_4d_256[:, :, d, t]
    image_slice_u8 = (_roundel_normalize(image_slice) * 255).astype(np.uint8)
    # astype() returns a copy, so this is a snapshot of the current label slice.
    mask_slice = preds_4d[:, :, d, t].astype(np.uint8)

    with col2:
        edit_mode = st.radio(
            "Segmentation Editor",
            ["Editor", "Viewer"],
            index=0,
            horizontal=True,
            key=f"roundel_seg_editor_mode_{selected_pid}",
        )
        channels_to_show = _roundel_overlay_selector_ui()
        if edit_mode == "Viewer":
            st.image(image_slice_u8, width=DISPLAY_W)
        else:
            # Grey when erasing, blue for myocardium, red otherwise.
            if action == "Erase ✂️":
                stroke_color = "rgba(200, 200, 200, 0.55)"
            elif channel == LV_MYO_LABEL:
                stroke_color = "rgba(0, 0, 255, 0.85)"
            else:
                stroke_color = "rgba(255, 10, 10, 0.85)"
            bg_img = _roundel_get_overlay(image_slice_u8, mask_slice, channels_to_show)
            canvas_w = int(DISPLAY_W)
            canvas_h = int(SIZE_Y * DISPLAY_W / SIZE_X)
            bg_img_for_canvas = bg_img.resize((canvas_w, canvas_h), resample=Image.NEAREST)
            # Exact background pixels; compared against the canvas image by
            # _extract_stroke_mask_from_canvas (diff_thresh=10) to find strokes.
            bg_np = np.array(bg_img_for_canvas, dtype=np.uint8)
            # The nonce is bumped after every save/clear, so the key changes
            # and a new canvas is mounted for the same slice/frame.
            canvas_key = f"roundel_canvas_{selected_pid}_{d}_{t}_{st.session_state.get(nonce_key, 0)}"
            canvas_result = st_canvas(
                stroke_width=stroke_width,
                stroke_color=stroke_color,
                background_image=bg_img_for_canvas,
                update_streamlit=True,
                height=canvas_h,
                width=canvas_w,
                drawing_mode="freedraw",
                key=canvas_key,
            )
            col_save, col_clear = st.columns([1, 0.35])
            with col_save:
                save_contour = st.button(
                    "Save Contour",
                    type="primary",
                    use_container_width=True,
                    key=f"roundel_save_contour_btn_{selected_pid}",
                )
                if save_contour and canvas_result and (canvas_result.image_data is not None):
                    canvas_rgba = np.array(canvas_result.image_data, dtype=np.uint8)
                    stroke_mask_canvas = _extract_stroke_mask_from_canvas(canvas_rgba, bg_np, diff_thresh=10)
                    if stroke_mask_canvas is None or not np.any(stroke_mask_canvas):
                        st.warning("No new strokes detected.")
                    else:
                        # The canvas is display-sized; map strokes back to the
                        # SIZE_Y x SIZE_X label grid.
                        stroke_mask = _resize_binary_mask(stroke_mask_canvas, target_h=SIZE_Y, target_w=SIZE_X)
                        # Match Roundel behaviour: paint strokes can be closed/filled.
                        # For myocardium, enclosed existing blood pool is protected.
                        do_fill = (action == "Paint ✏️") and (channel in (LV_MYO_LABEL, LV_BP_LABEL))
                        stroke_mask, enclosed_region = build_processed_stroke_and_enclosed_region(
                            stroke_mask,
                            stroke_width,
                            do_fill_holes=do_fill,
                        )
                        updated_slice, changed = _roundel_apply_stroke_to_label_slice(
                            mask_slice=mask_slice,
                            stroke_mask=stroke_mask,
                            enclosed_region=enclosed_region,
                            channel=channel,
                            action=action,
                        )
                        if changed:
                            preds_4d[:, :, d, t] = updated_slice
                            case["preds_4d"] = preds_4d
                            case["dirty"] = True
                            case["mask_edited"] = True
                            st.session_state.pop(preview_key, None)
                            invalidate_download_caches()
                            st.success("Contour saved.")
                        else:
                            st.warning("Stroke was detected, but no writable mask pixels were changed.")
                        st.session_state[nonce_key] = int(st.session_state.get(nonce_key, 0)) + 1
                        _safe_rerun()
            with col_clear:
                if st.button("Clear Slice", use_container_width=True, key=f"roundel_clear_slice_btn_{selected_pid}"):
                    preds_4d[:, :, d, t] = 0
                    case["preds_4d"] = preds_4d
                    case["dirty"] = True
                    case["mask_edited"] = True
                    st.session_state.pop(preview_key, None)
                    invalidate_download_caches()
                    st.session_state[nonce_key] = int(st.session_state.get(nonce_key, 0)) + 1
                    st.success("LV mask cleared for this slice/frame.")
                    _safe_rerun()

    with col3:
        view_mode = st.radio(
            "Corrected Mask",
            ["Static", "Viewer"],
            index=0,
            horizontal=True,
            key=f"roundel_corrected_mask_view_mode_{selected_pid}",
        )
        if view_mode == "Viewer":
            corrected = _roundel_get_overlay(
                image_slice_u8,
                preds_4d[:, :, d, t].astype(np.uint8),
                [LV_MYO_LABEL, LV_BP_LABEL],
            )
            st.image(corrected, width=DISPLAY_W)
        else:
            if preview_key not in st.session_state:
                st.session_state[preview_key] = _roundel_build_all_slice_preview_frames(case)
            frames = st.session_state[preview_key]
            ed = int(case["ed_idx_user"])
            es = int(case["es_idx_user"])
            # Show the ES montage when the ES frame is selected, else the ED montage.
            if t == es:
                st.image(frames[es], width=int(DISPLAY_W * 1.5))
            else:
                st.image(frames[ed], width=int(DISPLAY_W * 1.5))
def _roundel_metrics_for_case(case):
    """
    Compute LV EDV/ESV/SV/EF and myocardial mass for one case.

    The ED/ES frames are the user-confirmed ones when ``ed_es_confirmed`` is
    set, otherwise the automatic ones. Both are clipped to the valid frame
    range.

    Volume is computed as the LV_BP_LABEL voxel count times the voxel volume
    (row_mm * col_mm * slice_thickness_mm; 1 mm³ == 1 µL). Mass is the
    LV_MYO_LABEL voxel volume times MYO_DENSITY. The "mg" unit assumes
    MYO_DENSITY is in mg/µL — TODO confirm.

    Returns
    -------
    dict
        Keys: ed_idx, es_idx, edv_ul, esv_ul, sv_ul, ef_percent,
        mass_ed_mg, mass_es_mg. EF is 0.0 when EDV is 0.
    """
    labels = case["preds_4d"]
    spacing = case["seg_spacing"]
    use_user = case.get("ed_es_confirmed", False)
    ed_raw = int(case["ed_idx_user"]) if use_user else int(case["ed_idx_auto"])
    es_raw = int(case["es_idx_user"]) if use_user else int(case["es_idx_auto"])
    voxel_mm3 = (
        float(spacing["row_mm"])
        * float(spacing["col_mm"])
        * float(spacing["slice_thickness_mm"])
    )
    last_frame = labels.shape[3] - 1
    ed_idx = int(np.clip(ed_raw, 0, last_frame))
    es_idx = int(np.clip(es_raw, 0, last_frame))

    # Per-frame voxel counts: sum over (row, col, slice) in one pass each.
    per_frame_axes = (0, 1, 2)
    bp_counts = (labels == LV_BP_LABEL).sum(axis=per_frame_axes)
    myo_counts = (labels == LV_MYO_LABEL).sum(axis=per_frame_axes)
    vols = np.asarray(bp_counts * voxel_mm3, dtype=np.float64)
    masses = np.asarray(myo_counts * voxel_mm3 * MYO_DENSITY, dtype=np.float64)

    edv = float(vols[ed_idx])
    esv = float(vols[es_idx])
    sv = float(edv - esv)
    ef = float(sv / edv * 100.0) if edv > 0 else 0.0
    return {
        "ed_idx": ed_idx,
        "es_idx": es_idx,
        "edv_ul": edv,
        "esv_ul": esv,
        "sv_ul": sv,
        "ef_percent": ef,
        "mass_ed_mg": float(masses[ed_idx]),
        "mass_es_mg": float(masses[es_idx]),
    }
def _render_corrected_downloads_group():
    """
    Render the corrected-output download buttons (CSV, masks ZIP, GIF ZIP).

    Nothing is rendered until a corrected CSV exists in session state. The
    mask and GIF zips are built lazily, only when their bytes are not
    already cached under ``mask_zip_bytes_corrected`` / ``gif_zip_bytes_corrected``.
    """
    state = st.session_state
    if not ("csv_bytes_corrected" in state and "csv_name_corrected" in state):
        return
    st.markdown("#### Updated Results")
    if "mask_zip_bytes_corrected" not in state:
        build_corrected_masks_zip_bytes_per_mouse()
    if "gif_zip_bytes_corrected" not in state:
        build_corrected_gif_zip_bytes()

    # (label, bytes key, filename key, mime, widget key)
    specs = [
        ("Download CSV (corrected)", "csv_bytes_corrected", "csv_name_corrected", "text/csv", "corr_csv"),
        ("Download Masks (corrected)", "mask_zip_bytes_corrected", "mask_zip_name_corrected", "application/zip", "corr_masks"),
        ("Download GIF (corrected)", "gif_zip_bytes_corrected", "gif_zip_name_corrected", "application/zip", "corr_gif"),
    ]
    buttons = [
        {
            "label": label,
            "bytes": state[bytes_key],
            "filename": state[name_key],
            "mime": mime,
            "key": widget_key,
        }
        for label, bytes_key, name_key, mime, widget_key in specs
    ]
    render_download_group(buttons, group_key="corr")
def _roundel_recompute_corrected_outputs():
    """
    Recompute metrics from the session cases and refresh the corrected CSV.

    Wraps ``recompute_metrics_from_session_cases_dirty_gif_only`` and
    ``write_all_in_one_csv``. On success:

    - the CSV bytes and download name are stored in
      ``st.session_state["csv_bytes_corrected"]`` / ``["csv_name_corrected"]``;
    - ``invalidate_download_caches()`` is called (defined elsewhere; presumably
      drops the cached zips so they are rebuilt).

    Shows an error and returns early when there are no rows.
    """
    new_rows, new_per_frame_rows, updated_any_gif = recompute_metrics_from_session_cases_dirty_gif_only()
    if not new_rows:
        st.error("No cases available to recompute metrics.")
        return
    csv_path = write_all_in_one_csv(new_rows, new_per_frame_rows, get_session_csv_dir())
    # Mirror the original naming ("<zip stem>_Results.csv") with a
    # "_corrected" suffix. The fallback keeps the suffix too, so a corrected
    # CSV is never named like the original results file.
    last_zip_name = st.session_state.get("_last_zip_name")
    if last_zip_name:
        base_name = Path(last_zip_name).stem + "_Results_corrected"
    else:
        base_name = "Results_corrected"
    st.session_state["csv_bytes_corrected"] = Path(csv_path).read_bytes()
    st.session_state["csv_name_corrected"] = f"{base_name}.csv"
    invalidate_download_caches()
    st.success(
        "CSV updated. Regenerated results for edited mice."
        if updated_any_gif
        else "CSV updated. No edited mice detected — GIFs unchanged."
    )
def _roundel_mark_review_complete_callback(selected_pid):
    """
    on_click callback for "Mark review complete ✅".

    Sets the case's ``manual_review_done`` flag AND the "Status: Complete"
    checkbox key ``roundel_case_done__{pid}``. The checkbox is rendered by
    ``roundel_lv_editor_app`` before this view, and its value is written back
    into ``case["manual_review_done"]`` on every run. Updating only the case
    dict would therefore be undone on the next rerun. Writing the widget key
    is only legal from a callback, which runs before widgets are recreated.
    """
    cases = st.session_state.get("cases", {})
    if selected_pid in cases:
        cases[selected_pid]["manual_review_done"] = True
    st.session_state[f"roundel_case_done__{selected_pid}"] = True


def roundel_lv_final_result_view(selected_pid, case):
    """
    Roundel-style final result view, LV-only.

    - Shows the LV metrics from ``_roundel_metrics_for_case`` and this case's
      corrected ED/ES GIF.
    - Regenerates the GIF when the case is dirty or has no corrected GIF yet,
      storing it in ``st.session_state["gif_paths_corrected"]``.
    - Offers buttons to recompute all corrected outputs and to mark the
      review complete.

    Requires ``case["ed_es_confirmed"]``.
    """
    if not case.get("ed_es_confirmed", False):
        st.error("Select and confirm ED & ES frames first.")
        return
    metrics = _roundel_metrics_for_case(case)
    ed_idx = int(metrics["ed_idx"])
    es_idx = int(metrics["es_idx"])
    # Generate/update this case's GIF if it was edited/confirmed or is missing.
    if case.get("dirty", False) or selected_pid not in (st.session_state.get("gif_paths_corrected", {}) or {}):
        # Also covers a key present but set to None.
        if st.session_state.get("gif_paths_corrected") is None:
            st.session_state["gif_paths_corrected"] = {}
        gif_path = gif_animation_for_patient_pred_only(
            case["imgs_4d_256"],
            case["preds_4d"],
            selected_pid,
            ed_idx,
            es_idx,
            get_session_gifs_dir(),
        )
        st.session_state["gif_paths_corrected"][selected_pid] = gif_path
        case["dirty"] = False
    col_metrics, col_gif = st.columns([0.30, 0.70])
    with col_metrics:
        st.caption("LV Metrics")
        st.metric("EDV", f"{metrics['edv_ul']:.1f} µL")
        st.metric("ESV", f"{metrics['esv_ul']:.1f} µL")
        st.metric("SV", f"{metrics['sv_ul']:.1f} µL")
        st.metric("EF", f"{metrics['ef_percent']:.1f} %")
        st.metric("Mass (ED)", f"{metrics['mass_ed_mg']:.1f} mg")
        st.metric("Mass (ES)", f"{metrics['mass_es_mg']:.1f} mg")
    with col_gif:
        st.caption("Final LV Mask")
        gif_path = (st.session_state.get("gif_paths_corrected", {}) or {}).get(selected_pid)
        if gif_path and os.path.exists(gif_path):
            with open(gif_path, "rb") as fg:
                st.image(fg.read(), caption=f"{selected_pid} – ED/ES GIF (corrected)", use_column_width=True)
        else:
            st.info("No corrected GIF available yet.")
    st.markdown("---")
    if st.button(
        "Recompute cardiac function & regenerate CSV/GIFs from corrected masks",
        type="primary",
        use_container_width=True,
        key="roundel_recompute_metrics_btn",
    ):
        with st.spinner("Recomputing metrics and regenerating CSV/GIFs from corrected masks..."):
            _roundel_recompute_corrected_outputs()
    _render_corrected_downloads_group()
    if st.button(
        "Mark review complete ✅",
        use_container_width=True,
        key=f"roundel_mark_done_{selected_pid}",
        on_click=_roundel_mark_review_complete_callback,
        args=(selected_pid,),
    ):
        st.success("Review marked as complete.")
def roundel_lv_editor_app():
    """
    Top-level Roundel LV Editor tab.

    Prerequisites:

    - processed cases in ``st.session_state["cases"]``;
    - ``open_roundel_lv_editor`` enabled from the Segmentation App tab.

    Renders:

    - the case selector and spacing summary;
    - the per-case review status and notes (mirrored into
      ``case["manual_review_done"]`` / ``case["manual_note"]``);
    - the view chosen in the segmented control: EDV/ESV Finder, Mask Editor
      or Final Result.
    """
    st.write("# Roundel App (2D LV)")
    if "cases" not in st.session_state or not st.session_state["cases"]:
        st.info("Process a ZIP file first in the **Segmentation App** tab.")
        return
    # No extra enable button: the editor is enabled directly from the
    # Segmentation App tab when the user selects Yes.
    if not st.session_state.get("open_roundel_lv_editor", False):
        st.info(
            "Manual correction is optional. Select **Yes** in the **Segmentation App** tab "
            "to enable the Roundel LV Editor."
        )
        return
    cases = st.session_state["cases"]
    case_ids = list(cases.keys())
    col1, col2 = st.columns([0.35, 0.65])
    with col1:
        selected_pid = st.selectbox(
            "Select NIfTI / mouse ID",
            options=case_ids,
            index=0,
            format_func=_roundel_case_label,
            key="roundel_sax_series_uid_selectbox",
        )
    case = cases[selected_pid]
    with col2:
        seg_spacing = case.get("seg_spacing", {})
        native_spacing = case.get("native_spacing", {})
        frame_interval = native_spacing.get("frame_time_ms", None)
        frame_interval_label = f"{float(frame_interval):.3g} ms" if frame_interval else "not available"
        st.markdown(
            f"**Pixel Size**: {seg_spacing.get('row_mm', np.nan):.4g} x "
            f"{seg_spacing.get('col_mm', np.nan):.4g} mm | "
            f"**Slice Thickness**: {seg_spacing.get('slice_thickness_mm', np.nan):.4g} mm | "
            f"**Frame Interval**: {frame_interval_label}"
        )
        done_key = f"roundel_case_done__{selected_pid}"
        note_key = f"roundel_case_note__{selected_pid}"
        # Seed widget state from the case on first display of this case.
        if done_key not in st.session_state:
            st.session_state[done_key] = bool(case.get("manual_review_done", False))
        if note_key not in st.session_state:
            st.session_state[note_key] = str(case.get("manual_note", ""))
        done_val = st.checkbox("Status: Complete", key=done_key)
        case["manual_review_done"] = bool(done_val)
        note_val = st.text_area(
            "Notes",
            key=note_key,
            height=90,
            placeholder="Type notes for this case…",
        )
        case["manual_note"] = str(note_val)
    view = _streamlit_segmented_control(
        "Tab",
        options=["EDV/ESV Finder 🔍", "Mask Editor 📝", "Final Result ✅"],
        default="EDV/ESV Finder 🔍",
        key="roundel_view_segmented",
        label_visibility="hidden",
    )
    # A segmented control can end up with no selection (e.g. Streamlit's
    # st.segmented_control returns None after the active option is clicked
    # again). Fall back to the default view instead of rendering nothing.
    if view is None:
        view = "EDV/ESV Finder 🔍"
    st.divider()
    if view == "EDV/ESV Finder 🔍":
        roundel_lv_edv_esv_view(selected_pid, case)
    elif view == "Mask Editor 📝":
        roundel_lv_mask_editor_view(selected_pid, case)
    elif view == "Final Result ✅":
        roundel_lv_final_result_view(selected_pid, case)
def switch_to_tab_by_label(tab_label: str):
    """
    Programmatically click a Streamlit tab by its visible label.

    Streamlit does not currently provide a native Python API to select
    st.tabs(). This injects JavaScript (via components.html) that finds the
    ``button[role="tab"]`` in the parent document whose whitespace-normalised
    text equals ``tab_label`` exactly, clicks it and scrolls it into view.

    If the tab is not mounted yet, the click is retried at roughly 100, 250,
    500 and 1000 ms. Retrying stops as soon as one attempt succeeds, so a
    later retry cannot switch tabs again after the user has moved elsewhere.

    Notes:
    - This is a UI workaround.
    - If Streamlit changes its internal DOM, this may need updating.
    """
    components.html(
        f"""
<script>
(function() {{
    const tabLabel = {json.dumps(tab_label)};
    const doc = window.parent.document;
    // Absolute attempt times in ms after the first attempt.
    const attemptAt = [0, 100, 250, 500, 1000];
    function normaliseText(x) {{
        return (x || '').replace(/\\s+/g, ' ').trim();
    }}
    function clickTab() {{
        const tabs = Array.from(doc.querySelectorAll('button[role="tab"]'));
        const target = tabs.find(tab => normaliseText(tab.innerText) === tabLabel);
        if (target) {{
            target.click();
            target.scrollIntoView({{behavior: "smooth", block: "nearest", inline: "center"}});
            return true;
        }}
        return false;
    }}
    function attempt(i) {{
        if (clickTab() || i + 1 >= attemptAt.length) {{
            return;
        }}
        setTimeout(function() {{ attempt(i + 1); }}, attemptAt[i + 1] - attemptAt[i]);
    }}
    attempt(0);
}})();
</script>
""",
        height=0,
    )
def _manual_correction_radio_changed():
    """
    on_change callback for the Yes/No manual-correction radio (Segmentation App tab).

    Reads ``manual_correction_choice`` (default "No"). "Yes" enables the
    Roundel LV Editor and requests a tab switch; any other value disables
    both flags.
    """
    wants_editor = st.session_state.get("manual_correction_choice", "No") == "Yes"
    for flag in ("open_roundel_lv_editor", "_switch_to_roundel_lv_editor"):
        st.session_state[flag] = wants_editor
# =========================
# UI
# =========================
def main():
st.set_page_config(layout="wide", initial_sidebar_state="collapsed")
_inject_layout_css()
st.markdown(
f'''
<div id="fixed-edge-logo" aria-hidden="true" role="presentation">
<img src="{LOGO_URL}" alt="Pre-Clinical Cardiac MRI Segmentation">
</div>
<div class="edge-logo-spacer"></div>
''',
unsafe_allow_html=True,
)
tab1, tab_roundel, tab2, tab3 = st.tabs(
["Segmentation App", "Roundel LV Editor", "Dataset", "NIfTI converter"]
)
# ===== Tab 1: Segmentation App =====
with tab1:
st.markdown(
"""
<style>
[data-testid="stHeading"] a,
h1 a[href^="#"],
h2 a[href^="#"],
h3 a[href^="#"] {
display: none !important;
visibility: hidden !important;
}
</style>
""",
unsafe_allow_html=True,
)
HERO_HTML = dedent(
"""\
<div class="content-wrap">
<div class="measure-wrap">
<div class="text-wrap">
<h1 class="hero-title">
Open-Source Pre-Clinical Image Segmentation:<br>
Mouse cardiac MRI datasets with a deep learning segmentation framework
</h1>
</div>
<div class="text-wrap">
<p>We present the first publicly-available pre-clinical cardiac MRI dataset, along with an open-source DL segmentation model (both available on GitHub:
<a href="https://github.com/mrphys/Open-Source_Pre-Clinical_Segmentation.git" target="_blank" rel="noopener">https://github.com/mrphys/Open-Source_Pre-Clinical_Segmentation.git</a>) and this web-based interface for easy deployment.</p>
<p>The dataset comprises complete cine short-axis cardiac MRI images from 130 mice with diverse phenotypes. It also contains expert manual segmentations of left ventricular (LV) blood pool and myocardium at end-diastole, end-systole, as well as additional timeframes with artefacts to improve robustness.</p>
<p>Using this resource, we developed an open-source DL segmentation model based on the UNet3+ architecture.</p>
<p>This Streamlit application consists of the inference model to provide an easy-to-use interface for our DL segmentation model, without the need for local installation. The application requires the complete SAX cine image data to be uploaded in NIfTI format, as a ZIP file using the simple file browser below.</p>
<p>Pre-processing and inference are then performed on all 2D images. The resulting blood-pool and myocardial volumes are combined across all slices at each timeframe and output to a .csv file. The blood-pool volumes are used to identify ED and ES, and these volumes are displayed as a GIF with the segmentations overlaid.</p>
<p class="note-text">(Note: This Hugging Face model was developed as part of a manuscript submitted to the <em>Journal of Cardiovascular Magnetic Resonance</em>)</p>
</div>
</div>
</div>
"""
)
st.markdown(HERO_HTML, unsafe_allow_html=True)
st.markdown(
'<div class="content-wrap"><div class="measure-wrap" id="upload-wrap">',
unsafe_allow_html=True,
)
st.markdown(
"""
<h2 style='margin-bottom:0.2rem;'>
Data Upload <span style='font-size:33px;'>📤</span>
</h2>
""",
unsafe_allow_html=True,
)
uploaded_zip = st.file_uploader(
"Upload ZIP of NIfTI files 🐭",
type="zip",
label_visibility="visible",
)
st.markdown(
"""
<p style="margin-top:0.3rem; font-size:15px; color:#444;">
Or download our <a href="https://huggingface.co/spaces/mrphys/Pre-clinical_DL_segmentation/tree/main/NIfTI_dataset" target="_blank" rel="noopener">
sample NIfTI dataset</a> to try it out!
</p>
""",
unsafe_allow_html=True,
)
st.markdown("</div></div>", unsafe_allow_html=True)
# Reset state when a new ZIP is uploaded (by CONTENT, not only name)
if uploaded_zip is not None:
zip_bytes_now = uploaded_zip.getvalue()
zip_hash_now = hashlib.sha256(zip_bytes_now).hexdigest()
if st.session_state.get("_last_zip_hash") != zip_hash_now:
# Clear all dynamic Roundel/manual-correction keys from previous uploads
for k in list(st.session_state.keys()):
if (
k.startswith("roundel_")
or k.startswith("_roundel_")
or k.startswith("manual_correction")
or k.startswith("_manual_correction")
):
st.session_state.pop(k, None)
for key in [
"csv_bytes_orig",
"csv_name_orig",
"csv_bytes_corrected",
"csv_name_corrected",
"rows_count",
"cases",
"gif_paths",
"gif_paths_original",
"gif_paths_corrected",
"session_tmpdir",
# cached zips
"mask_zip_bytes",
"mask_zip_name",
"gif_zip_bytes",
"gif_zip_name",
"mask_zip_bytes_corrected",
"mask_zip_name_corrected",
"gif_zip_bytes_corrected",
"gif_zip_name_corrected",
# processing state
"processing_done",
"processing_running",
# manual correction / Roundel state
"open_roundel_lv_editor",
"manual_correction_choice",
"_switch_to_roundel_lv_editor",
"roundel_view_segmented",
"roundel_sax_series_uid_selectbox",
"roundel_overlay_overlapped_all",
"roundel_overlay_lv_bp",
"roundel_overlay_lv_myo",
"roundel_brush_mode",
"roundel_stroke_width_preset",
"roundel_stroke_width_value",
"_roundel_last_stroke_preset",
]:
st.session_state.pop(key, None)
invalidate_download_caches()
st.session_state["_last_zip_hash"] = zip_hash_now
st.session_state["_last_zip_name"] = uploaded_zip.name
st.session_state["_last_zip_bytes"] = zip_bytes_now # reuse without re-reading
def extract_zip(zip_path, extract_to):
os.makedirs(extract_to, exist_ok=True)
with zipfile.ZipFile(zip_path, "r") as zip_ref:
valid_files = [
f
for f in zip_ref.namelist()
if "__MACOSX" not in f and not os.path.basename(f).startswith("._")
]
zip_ref.extractall(extract_to, members=valid_files)
# --- Process button / ZIP pipeline ---
if uploaded_zip and st.button("Process Data"):
# ✅ processing flags
st.session_state["processing_running"] = True
st.session_state["processing_done"] = False
zip_label = uploaded_zip.name or "ZIP"
# ✅ 1) fixed-position spinner placeholder (right after button)
spinner_ph = st.empty()
# ✅ 2) progress + status BELOW spinner
prog = st.progress(0.0)
status = st.empty()
# ✅ 3) GIF gallery BELOW everything (so it doesn't push spinner)
gif_gallery = st.container()
# Use cached bytes (set in the reset block); fallback to getvalue()
zip_bytes = st.session_state.get("_last_zip_bytes", uploaded_zip.getvalue())
try:
# ✅ render spinner INSIDE placeholder so it stays in one fixed slot
with spinner_ph:
with st.spinner(f"Processing {zip_label}..."):
tmpdir = tempfile.mkdtemp()
zpath = os.path.join(tmpdir, uploaded_zip.name)
# write uploaded zip to disk (DON'T use uploaded_zip.read())
with open(zpath, "wb") as f:
f.write(zip_bytes)
extract_zip(zpath, tmpdir)
nii_files: List[str] = []
for root, _, files in os.walk(tmpdir):
for fn in files:
low = fn.lower()
if low.endswith(".nii") or low.endswith(".nii.gz"):
nii_files.append(os.path.join(root, fn))
if not nii_files:
st.error("No NIfTI files (.nii / .nii.gz) found in the uploaded ZIP.")
else:
model = get_model()
log("[MODEL] Loaded segmentation model (cached).")
rows: List[Dict] = []
per_frame_rows: List[pd.DataFrame] = []
st.session_state["cases"] = {}
st.session_state["gif_paths"] = {}
nii_files_sorted = sorted(nii_files)
total = len(nii_files_sorted)
# Start at 0%
prog.progress(0.0)
status.markdown(
f"**ZIP:** `{zip_label}` \n"
f"**Progress:** 0% (0/{total}) \n"
f"**Now:** `-` \n"
f"**Step:** Waiting..."
)
for i, fp in enumerate(nii_files_sorted, start=1):
name = Path(fp).name
pid = name.replace(".nii.gz", "").replace(".nii", "")
# Fraction completed by finished files
base = (i - 1) / max(total, 1)
span = 1.0 / max(total, 1)
# Avoid backwards jitter (per-file)
overall_state = {"last_overall": float(base)}
def file_progress_cb(file_frac: float, msg: Optional[str] = None):
"""
file_frac: 0..1 for THIS file; map to global progress 0..1
msg: step message (optional)
"""
overall = base + span * float(np.clip(file_frac, 0.0, 1.0))
# Never go backwards
overall = max(overall, overall_state["last_overall"])
overall_state["last_overall"] = overall
prog.progress(float(np.clip(overall, 0.0, 1.0)))
pct = int(round(overall * 100))
line_msg = msg or "Working..."
status.markdown(
f"**ZIP:** `{zip_label}` \n"
f"**Progress:** {pct}% ({i}/{total}) \n"
f"**Now:** `{pid}` \n"
f"**Step:** {line_msg}"
)
# Make sure UI shows which file started immediately
file_progress_cb(0.0, "Starting...")
try:
process_nifti_case(
fp,
model,
rows,
per_frame_rows,
progress_cb=file_progress_cb, # ✅ weighted, real progress + step messages
)
# ✅ show GIF immediately (stacking)
gif_path = st.session_state.get("gif_paths", {}).get(pid)
if gif_path and os.path.exists(gif_path):
with gif_gallery:
with open(gif_path, "rb") as fgif:
st.image(
fgif.read(),
caption=f"{pid} – ED/ES GIF (original)",
use_column_width=True,
)
except Exception as e:
st.warning(f"Failed: {name}{e}")
# After each file, force progress to at least the file boundary
file_progress_cb(1.0, "Done")
# Ensure end at 100%
prog.progress(1.0)
status.markdown(f"✅ **Done:** `{zip_label}` — {total}/{total} (100%)")
csv_path = write_all_in_one_csv(rows, per_frame_rows, get_session_csv_dir())
csv_download_name = f"{Path(zip_label).stem}_Results.csv"
with open(csv_path, "rb") as fcsv:
st.session_state["csv_bytes_orig"] = fcsv.read()
st.session_state["csv_name_orig"] = csv_download_name
st.session_state["rows_count"] = len(rows)
st.session_state["gif_paths_original"] = dict(st.session_state.get("gif_paths", {}))
finally:
# ✅ always clear spinner + set flags, even if something errors
spinner_ph.empty()
st.session_state["processing_running"] = False
st.session_state["processing_done"] = True
invalidate_download_caches()
# --- Reserve stable slots so layout doesn't jump ---
downloads_slot = st.container()
downloads_spacer = st.empty()
if not st.session_state.get("processing_done", False):
downloads_spacer.markdown(
"<div style='min-height:120px'></div>",
unsafe_allow_html=True
)
else:
downloads_spacer.empty()
with downloads_slot:
# Show the green success banner once (only if results exist)
if "csv_bytes_orig" in st.session_state and "csv_name_orig" in st.session_state:
st.success(f"Processed {st.session_state.get('rows_count', 0)} NIfTI file(s).")
# =========================
# ORIGINAL downloads group
# =========================
# 1) Ensure ORIGINAL mask zip exists
if "cases" in st.session_state and st.session_state["cases"]:
if "mask_zip_bytes" not in st.session_state:
zipstem = Path(st.session_state.get("_last_zip_name", "Results")).stem
mask_zip_name = f"{zipstem}_Mask.zip"
mask_files = []
for pid, case in st.session_state["cases"].items():
native_h, native_w, S, F = case["native_shape"]
affine = case["affine"]
mask_native = resize_masks_to_native(case["preds_4d_orig"], native_h, native_w)
session_tmp = get_session_tmpdir()
tmp_mask_path = os.path.join(session_tmp, f"{pid}_mask.nii.gz")
save_native_mask_nifti_with_labels(mask_native, affine, tmp_mask_path)
with open(tmp_mask_path, "rb") as f:
mask_files.append((f"{pid}_mask.nii.gz", f.read()))
try:
os.remove(tmp_mask_path)
except OSError:
pass
st.session_state["mask_zip_bytes"] = build_zip_bytes(mask_files, root_folder=None)
st.session_state["mask_zip_name"] = mask_zip_name
# 2) Ensure ORIGINAL gif zip exists
if "gif_paths_original" in st.session_state:
if "gif_zip_bytes" not in st.session_state:
zipstem = Path(st.session_state.get("_last_zip_name", "Results")).stem
gif_root = f"{zipstem}_GIF"
gif_zip_name = f"{gif_root}.zip"
gif_files = []
orig = st.session_state.get("gif_paths_original", {}) or {}
for pid, gif_path in orig.items():
if gif_path and os.path.exists(gif_path):
with open(gif_path, "rb") as f:
gif_files.append((f"{pid}.gif", f.read()))
st.session_state["gif_zip_bytes"] = build_zip_bytes(gif_files, root_folder=gif_root)
st.session_state["gif_zip_name"] = gif_zip_name
# 3) Build ONE button group
buttons = []
if "csv_bytes_orig" in st.session_state and "csv_name_orig" in st.session_state:
buttons.append({
"label": "Download CSV",
"bytes": st.session_state["csv_bytes_orig"],
"filename": st.session_state["csv_name_orig"],
"mime": "text/csv",
"key": "orig_csv",
})
if "mask_zip_bytes" in st.session_state and "mask_zip_name" in st.session_state:
buttons.append({
"label": "Download Masks",
"bytes": st.session_state["mask_zip_bytes"],
"filename": st.session_state["mask_zip_name"],
"mime": "application/zip",
"key": "orig_masks",
})
if "gif_zip_bytes" in st.session_state and "gif_zip_name" in st.session_state:
buttons.append({
"label": "Download GIF",
"bytes": st.session_state["gif_zip_bytes"],
"filename": st.session_state["gif_zip_name"],
"mime": "application/zip",
"key": "orig_gif",
})
if buttons:
render_download_group(buttons, group_key="orig")
# ==========================================================
# Optional manual correction entry point
# ==========================================================
if "cases" in st.session_state and st.session_state["cases"]:
st.markdown("---")
st.markdown("## Optional: Manual correction")
manual_choice = st.radio(
"Would you like to manually correct the LV segmentation masks?",
options=["No", "Yes"],
horizontal=True,
key="manual_correction_choice",
on_change=_manual_correction_radio_changed,
)
# Keep state synced
st.session_state["open_roundel_lv_editor"] = (manual_choice == "Yes")
if manual_choice == "Yes":
st.success("Manual correction enabled. Opening the **Roundel LV Editor** tab...")
# Switch only once after user selects Yes
if st.session_state.get("_switch_to_roundel_lv_editor", False):
st.session_state["_switch_to_roundel_lv_editor"] = False
switch_to_tab_by_label("Roundel LV Editor")
else:
st.session_state["_switch_to_roundel_lv_editor"] = False
st.caption(
"Manual correction is optional. You can use the automatic CSV, mask and GIF downloads above."
)
# ===== Tab 2: Roundel LV Editor =====
with tab_roundel:
roundel_lv_editor_app()
# ===== Tab 2: Dataset =====
with tab2:
st.markdown(
"""
<style>
.ds-hero-section {
background: #082c3a;
padding: 30px 10px;
text-align: center;
margin-left: -100vw;
margin-right: -100vw;
left: 0; right: 0; position: relative;
}
.ds-hero-section-inner { max-width: 1100px; margin: 0 auto; }
.ds-heroimg {
max-width: 1000px; width: 100%; height: auto;
border-radius: 10px; box-shadow: 0 8px 24px rgba(0,0,0,.25);
display: block; margin: 0 auto;
}
.ds-caption {
text-align: center; color: #e0f2f1;
font-size: 18px; line-height: 1.5;
margin: 14px 6px 0; font-style: italic;
}
.ds-hr {
height: 8px;
border: 0; background: #ea580c;
margin: 24px 0 20px;
border-radius: 3px;
}
.ds-wrap {
max-width: var(--content-measure, 920px);
margin-left: var(--left-offset, 40px);
margin-right: auto;
background: #fff; padding: 16px 24px; border-radius: 6px;
}
.ds-section h2 {
font-size: 20px; font-weight: 700;
margin: 0 0 2px;
color: #082c3a;
}
.ds-section p {
font-size: 16px; line-height: 1.6; color: #333;
margin: 0 0 6px;
}
.ds-section ul {
margin: 2px 0 8px 18px;
padding: 0;
}
.ds-section li {
font-size: 16px; line-height: 1.6; color: #333;
margin-bottom: 10px;
}
.ds-section a {
color: #0b66c3 !important; text-decoration: underline !important;
}
h2 a, [data-testid="stHeading"] a { display: none !important; }
</style>
<div class="ds-hero-section">
<div class="ds-hero-section-inner">
<img class="ds-heroimg"
src="https://raw.githubusercontent.com/whanisa/Segmentation/main/icon/open_source.png"
alt="Illustration of mouse with heart representing open-source pre-clinical cardiac MRI dataset" />
<p class="ds-caption">
The first publicly-available pre-clinical cardiac MRI dataset,<br/>
with an open-source segmentation model and an easy-to-use web app.
</p>
</div>
</div>
<hr class="ds-hr"/>
<div class="ds-wrap">
<div class="ds-section">
<h2>Repository & Paper Resources</h2>
<p>GitHub:
<a href="https://github.com/mrphys/Open-Source_Pre-Clinical_Segmentation.git" target="_blank">
Open-Source_Pre-Clinical_Segmentation
</a>
</p>
<h2>📊 Dataset Availability</h2>
<ul>
<li>
<strong>Full dataset (130 mice, HDF5 format):</strong><br/>
Available in our
<a href="https://github.com/mrphys/Open-Source_Pre-Clinical_Segmentation/tree/master/Data" target="_blank">
GitHub repository
</a>.<br/>
Each .h5 file contains the complete cine SAX MRI and expert manual segmentations.
</li>
<li>
<strong>Sample datasets (3 mice, NIfTI format):</strong><br/>
Available here:
<a href="https://huggingface.co/spaces/mrphys/Pre-clinical_DL_segmentation/tree/main/NIfTI_dataset" target="_blank">
NIfTI Sample Dataset
</a>.<br/>
We provide 3 example NIfTI datasets for quick download and direct use within the app.
</li>
</ul>
</div>
<hr class="ds-hr"/>
<div class="ds-section">
<h2>Notes</h2>
<ul>
<li>Complete SAX cine MRI for 130 mice with expert LV blood & myocardium labels (ED/ES).</li>
</ul>
</div>
</div>
""",
unsafe_allow_html=True,
)
# ===== Tab 3: NIfTI converter =====
with tab3:
st.subheader("NIfTI converter")
st.markdown(
"""
**Working with Agilent data?**
Easily convert your fid files to NIfTI using our **fid2niix**.
```bash
fid2niix -z y -o /path/to/out -f "%p_%s" /path/to/fid_folder
```
**💡 Tips**
- Upload a ZIP file that includes both `fid` and `procpar`.
- Conversion outputs **NIfTI-1** format, ready to use with our web app.
"""
)
# Script entry point: main() is invoked only when this file is executed as
# the top-level script; importing the module (e.g. from another file or a
# test) does not call main() here.
if __name__ == "__main__":
    main()