Spaces:
Sleeping
Sleeping
| # app.py | |
| import os | |
| import hashlib | |
| import io | |
| import json | |
| import math | |
| import zipfile | |
| import time | |
| import random | |
| import tempfile | |
| import uuid | |
| import streamlit.components.v1 as components | |
| from textwrap import dedent | |
| from pathlib import Path | |
| from typing import Optional, List, Dict, Tuple | |
| import numpy as np | |
| import pandas as pd | |
| import nibabel as nib | |
| import tensorflow as tf | |
| from tensorflow import keras | |
| import matplotlib.pyplot as plt | |
| from matplotlib import animation | |
| from PIL import Image, ImageDraw, ImageFont | |
| from skimage.measure import label as cc_label, regionprops | |
| from scipy.ndimage import ( | |
| binary_fill_holes, | |
| binary_dilation, | |
| binary_erosion, | |
| binary_closing, | |
| ) # use SciPy morphology | |
| from skimage.morphology import disk | |
| import streamlit as st | |
| import base64 | |
| from streamlit_drawable_canvas import st_canvas | |
| from utils.layer_util import ResizeAndConcatenate | |
# =========================
# Config / paths
# =========================
# Order CUDA devices by PCI bus ID.
os.environ.update({"CUDA_DEVICE_ORDER": "PCI_BUS_ID"})
# Uncomment the next line to force CPU-only execution (hides all GPUs):
# os.environ["CUDA_VISIBLE_DEVICES"] = ""
# Saved Keras segmentation model (HDF5), relative to the working directory.
MODEL_PATH = "./models_final/SEG459.h5"
# ---------------------------------------------------------
# Keras custom_objects safety:
# If the saved model references these symbols by name, Keras
# may try to resolve them even with compile=False.
# Provide harmless dummy callables instead of None.
# ---------------------------------------------------------
def _dummy(*args, **kwargs):
    """Stand-in for training-only callables (losses / metrics / custom functions).

    Accepts any arguments and returns the scalar tensor 0.0. It exists only so
    name resolution succeeds when the model is loaded; it is never used for
    inference.
    """
    del args, kwargs  # deliberately ignored
    return tf.constant(0.0)


# Name -> object map handed to keras.models.load_model. Only the custom layer
# must be the real implementation; every other entry is a training-time name.
_CUSTOM_OBJECTS = dict(
    focal_tversky_loss=_dummy,
    dice_coef_no_bkg=_dummy,
    ResizeAndConcatenate=ResizeAndConcatenate,  # must stay real
    dice_myo=_dummy,
    dice_blood=_dummy,
    dice=_dummy,
)
def get_model():
    """Load the segmentation model stored at MODEL_PATH.

    The model is loaded with compile=False (inference only). _CUSTOM_OBJECTS
    lets Keras resolve the custom names referenced by the saved file. Each call
    loads the file again; nothing is cached here.
    """
    # Optional: avoid re-tracing / keep memory stable in some environments
    # tf.keras.backend.clear_session()
    load_options = dict(custom_objects=_CUSTOM_OBJECTS, compile=False)
    return keras.models.load_model(MODEL_PATH, **load_options)
OUTPUT_ROOT = "/tmp/NIFTI_OUTPUTS"
# Inference & metrics settings
SIZE_X, SIZE_Y = 256, 256  # UNet3+ input grid
TARGET_HW = (SIZE_Y, SIZE_X)  # (height, width)
N_CLASSES = 3  # labels 0/1/2: background, myocardium, blood pool
BATCH_SIZE = 16
MYO_DENSITY = 1.05  # g/mL → mg via mm^3 * g/mL
# Resize mode for input images -> model grid
RESIZE_MODE = "auto"  # "auto", "bilinear", "area", or "fft"
# --- Progress weights (must sum to 1.0) ---
W_LOAD_PREP = 0.15
W_PREDICT = 0.65
W_POST_GIF = 0.20
# Explicit check rather than `assert`: assertions are stripped when Python runs
# with -O, which would silently accept an inconsistent configuration.
if not abs((W_LOAD_PREP + W_PREDICT + W_POST_GIF) - 1.0) < 1e-6:
    raise ValueError(
        "Progress weights W_LOAD_PREP + W_PREDICT + W_POST_GIF must sum to 1.0"
    )
# Island removal (3D)
ENABLE_ISLAND_REMOVAL = True
ISLAND_MIN_SLICE_SPAN = 2
ISLAND_MIN_AREA_PER_SLICE = 10
ISLAND_CENTROID_DIST_THRESH = 40
# Orientation / display
ORIENT_TARGET = None  # None (keep native), "LPS", or "RAS"
DISPLAY_MATCH_DICOM = False
# Per-orientation display transforms applied by display_xform (rotate, then flips).
DISPLAY_RULES = {
    "LPS": dict(rot90_cw=True, flip_ud=True, flip_lr=False),
    "RAS": dict(rot90_cw=True, flip_ud=False, flip_lr=True),
    None: dict(rot90_cw=False, flip_ud=False, flip_lr=False),
}
CURRENT_DISPLAY_ORIENT = ORIENT_TARGET
# ED/ES robust selection (mid-slices subset)
USE_MID_SLICES_FOR_ED_ES = True
MID_K = 4
MID_MIN_VALID_FRAC = 0.7
MID_A_BLOOD_MIN = 30
MID_A_MYO_MIN = 30
GIF_FPS = 1
GIF_DPI = 300
# Bigger drawing canvas for manual corrections (display only; masks still 256x256)
CANVAS_SCALE = 3  # 2x/3x/4x larger canvas than model grid
# Roundel-style display width for the LV editor
DISPLAY_W = 400
# =========================
# Branding / UI
# =========================
LOGO_URL = (
    "https://raw.githubusercontent.com/whanisa/Segmentation/main/icon/logo.png"
)
LOGO_LINK = "https://github.com/whanisa/Segmentation"
LOGO_HEIGHT_PX = 120
SAFE_INSET_PX = 18
| # --- CSS | |
def _inject_layout_css():
    """
    Inject the app-wide stylesheet into the current Streamlit page.

    Emits one <style> block via st.markdown(..., unsafe_allow_html=True) and
    returns None. Covers design tokens, forced light surfaces, spacing, layout
    wrappers, file-uploader styling, the fixed-edge logo, typography, tabs, and
    responsive overrides at max-width 1024px and 640px (desktop-first).

    File uploader styling (HF-proof):
    - Label strip stays WHITE (wrapper remains transparent)
    - Dropzone surface GREY (#f0f2f6)
    - "Browse files" button WHITE
    """
    # Layout widths fed into the CSS custom properties below (desktop values).
    CONTENT_MEASURE_PX = 920
    LEFT_OFFSET_PX = 40
    UPLOAD_WIDTH_PX = 420
    # Logo height used inside the <=640px media query: 70% of the desktop
    # height, never below 48px.
    # NOTE(review): in that media query .edge-logo-spacer still uses
    # var(--logo-height), which is not overridden there, so the spacer keeps the
    # desktop logo height while the image shrinks — confirm this is intended.
    mobile_logo_h = max(48, int(LOGO_HEIGHT_PX * 0.7))
    # The string is an f-string, so literal CSS braces are doubled ({{ }}).
    st.markdown(
        f"""
<style>
/* =========================
0) Design tokens
========================= */
:root {{
--content-measure: {CONTENT_MEASURE_PX}px;
--left-offset: {LEFT_OFFSET_PX}px;
--upload-width: {UPLOAD_WIDTH_PX}px;
--logo-height: {LOGO_HEIGHT_PX}px;
--edge-x: max(12px, env(safe-area-inset-left));
--header-clear: 40px;
--edge-y: calc(env(safe-area-inset-top) + var(--header-clear));
--tabs-top-gap: calc(var(--logo-height) + 16px);
--tabs-left-shift: 32px;
--accent: #ef4444;
--surface: #ffffff;
--text: #111827;
--muted: #374151;
--body: #333333;
--link: #0b66c3;
--page-pad-x: 18px;
/* Uploader dropzone (match theme secondaryBackgroundColor) */
--uploader-bg: #f0f2f6;
--uploader-bg-hover: #e9ecef;
--uploader-border: rgba(17, 24, 39, 0.16);
--uploader-border-hover: rgba(17, 24, 39, 0.26);
--uploader-radius: 12px;
/* Browse button (WHITE like original) */
--uploader-btn-bg: #ffffff;
--uploader-btn-bg-hover: #f8f9fa;
--uploader-btn-border: rgba(17, 24, 39, 0.18);
}}
/* =========================
1) Force light surfaces
========================= */
html, body {{
background: var(--surface) !important;
color: var(--text) !important;
}}
.stApp {{
background: var(--surface) !important;
color: var(--text) !important;
}}
@media (prefers-color-scheme: dark) {{
html, body, .stApp {{
background: var(--surface) !important;
color: var(--text) !important;
}}
}}
/* Ensure fixed logo isn't clipped */
.stApp, .appview-container, .main {{
overflow: visible !important;
}}
/* =========================
2) Spacing tweaks
========================= */
.appview-container .main .block-container {{
padding-top: 0.75rem;
padding-bottom: 1rem;
}}
div[data-testid="stElementContainer"] {{
margin-bottom: 0.2rem !important;
}}
/* =========================
3) Outer layout wrappers
========================= */
.content-wrap {{
width: min(1300px, 100%);
margin: 0 auto;
padding: 0 var(--page-pad-x);
box-sizing: border-box;
}}
.measure-wrap {{
max-width: var(--content-measure);
margin-left: var(--left-offset);
margin-right: auto;
}}
#upload-wrap {{
max-width: var(--upload-width);
margin-left: var(--left-offset);
margin-right: auto;
}}
/* ==========================================================
✅ FILE UPLOADER (HF-proof, short, stable)
========================================================== */
/* A) Keep uploader wrapper transparent so label strip stays white */
section[data-testid="stFileUploader"],
div[data-testid="stFileUploader"],
.stFileUploader,
section[data-testid="stFileUploader"] > div,
div[data-testid="stFileUploader"] > div,
.stFileUploader > div {{
background: transparent !important;
background-color: transparent !important;
border: 0 !important;
border-radius: 0 !important;
box-shadow: none !important;
}}
/* B) Dropzone surface targets (BaseWeb + testids + common nested nodes) */
div[data-baseweb="file-uploader"],
div[data-baseweb="file-uploader"] > div,
div[data-baseweb="file-uploader"] > div > div,
div[data-baseweb="file-uploader"] [role="button"],
div[data-testid="stFileUploaderDropzone"],
section[data-testid="stFileUploaderDropzone"],
div[data-testid="stFileUploaderDropzone"] > div,
section[data-testid="stFileUploaderDropzone"] > div,
div[data-testid="stFileUploaderDropzone"] > div > div,
section[data-testid="stFileUploaderDropzone"] > div > div {{
background: var(--uploader-bg) !important;
background-color: var(--uploader-bg) !important;
border: 1px dashed var(--uploader-border) !important;
border-radius: var(--uploader-radius) !important;
box-shadow: none !important;
padding: 14px 44px 14px 14px !important;
}}
/* Hover */
div[data-baseweb="file-uploader"] [role="button"]:hover,
div[data-baseweb="file-uploader"] > div:hover,
div[data-testid="stFileUploaderDropzone"]:hover,
section[data-testid="stFileUploaderDropzone"]:hover {{
background: var(--uploader-bg-hover) !important;
background-color: var(--uploader-bg-hover) !important;
border-color: var(--uploader-border-hover) !important;
}}
/* Focus ring */
div[data-baseweb="file-uploader"] [role="button"]:focus-within,
div[data-testid="stFileUploaderDropzone"]:focus-within,
section[data-testid="stFileUploaderDropzone"]:focus-within {{
border-color: rgba(239, 68, 68, 0.45) !important;
box-shadow: 0 0 0 3px rgba(239, 68, 68, 0.15) !important;
}}
/* Text/icon inside uploader */
div[data-baseweb="file-uploader"] *,
div[data-testid="stFileUploaderDropzone"] *,
section[data-testid="stFileUploaderDropzone"] * {{
color: var(--text) !important;
}}
/* C) Browse button = WHITE (like original) */
div[data-baseweb="file-uploader"] button,
div[data-baseweb="file-uploader"] [role="button"] button,
div[data-testid="stFileUploaderDropzone"] button,
section[data-testid="stFileUploaderDropzone"] button,
section[data-testid="stFileUploader"] button,
div[data-testid="stFileUploader"] button {{
background: var(--uploader-btn-bg) !important;
background-color: var(--uploader-btn-bg) !important;
color: var(--text) !important;
border: 1px solid var(--uploader-btn-border) !important;
border-radius: 10px !important;
font-weight: 600 !important;
box-shadow: none !important;
}}
div[data-baseweb="file-uploader"] button:hover,
div[data-baseweb="file-uploader"] [role="button"] button:hover,
div[data-testid="stFileUploaderDropzone"] button:hover,
section[data-testid="stFileUploaderDropzone"] button:hover,
section[data-testid="stFileUploader"] button:hover,
div[data-testid="stFileUploader"] button:hover {{
background: var(--uploader-btn-bg-hover) !important;
background-color: var(--uploader-btn-bg-hover) !important;
border-color: rgba(239, 68, 68, 0.55) !important;
color: var(--accent) !important;
}}
/* =========================
4) Fixed-edge logo
========================= */
#fixed-edge-logo {{
position: fixed;
left: var(--edge-x);
top: var(--edge-y);
z-index: 1000;
pointer-events: none;
}}
#fixed-edge-logo img {{
height: var(--logo-height);
width: auto;
display: block;
}}
.edge-logo-spacer {{
height: var(--tabs-top-gap);
}}
/* =========================
5) Typography
========================= */
.hero-title {{
font-size: 40px;
line-height: 1.25;
font-weight: 800;
margin: 0 0 20px;
text-align: justify;
text-justify: inter-word;
color: var(--text);
}}
.text-wrap p {{
margin: 0 0 14px 0;
font-size: 17px;
line-height: 1.5;
text-align: justify;
text-justify: inter-word;
color: var(--body);
hyphens: auto;
-webkit-hyphens: auto;
-ms-hyphens: auto;
}}
.text-wrap a {{
color: var(--link) !important;
text-decoration: underline;
text-underline-offset: 2px;
text-decoration-thickness: 1.5px;
}}
.note-text {{
font-size: 14px;
color: var(--body);
line-height: 1.4;
margin-top: 4px;
}}
/* =========================
6) Tabs styling
========================= */
div[data-testid="stTabs"] [role="tablist"],
.stTabs [role="tablist"],
div[data-baseweb="tab-list"] {{
margin-left: calc(var(--left-offset) + var(--tabs-left-shift)) !important;
margin-right: var(--page-pad-x) !important;
border-bottom: 0 !important;
padding-bottom: 6px !important;
}}
div[data-testid="stTabs"] [role="tab"],
.stTabs [role="tab"],
div[data-baseweb="tab-list"] button[role="tab"] {{
color: var(--muted) !important;
background: transparent !important;
border: none !important;
outline: none !important;
padding: 6px 14px 10px 14px !important;
margin: 0 4px !important;
font-weight: 600 !important;
border-bottom: 3px solid transparent !important;
}}
div[data-testid="stTabs"] [role="tab"][aria-selected="true"],
.stTabs [role="tab"][aria-selected="true"],
div[data-baseweb="tab-list"] button[aria-selected="true"] {{
color: var(--accent) !important;
border-bottom-color: var(--accent) !important;
}}
div[data-baseweb="tab-highlight"] {{
display: none !important;
}}
/* =========================
7) Responsive overrides
========================= */
@media (max-width: 1024px) {{
:root {{
--content-measure: 92vw;
--left-offset: 16px;
--upload-width: 92vw;
--tabs-left-shift: 12px;
--page-pad-x: 14px;
}}
.hero-title {{ font-size: 34px; }}
}}
@media (max-width: 640px) {{
:root {{
--content-measure: 94vw;
--left-offset: 0px;
--upload-width: 94vw;
--tabs-left-shift: 0px;
--header-clear: 64px;
--page-pad-x: 12px;
}}
#fixed-edge-logo img {{
height: {mobile_logo_h}px;
}}
.edge-logo-spacer {{
height: calc(var(--logo-height) + 8px);
}}
.hero-title {{
font-size: clamp(22px, 6vw, 30px);
line-height: 1.2;
text-align: left;
text-justify: auto;
}}
.text-wrap p {{
font-size: 15px;
text-align: left;
hyphens: none;
}}
}}
</style>
""",
        unsafe_allow_html=True,
    )
| # ========================= | |
| # Small utilities | |
| # ========================= | |
def log(msg: str):
    """Print *msg* to stdout with an "[INFO] " prefix."""
    print("[INFO] {}".format(msg))
def _safe_rerun():
    """Rerun the Streamlit script on both new and old Streamlit releases.

    Uses ``st.rerun`` when present, otherwise ``st.experimental_rerun``.

    Raises:
        RuntimeError: if the installed Streamlit offers neither API.
    """
    for api_name in ("rerun", "experimental_rerun"):
        if hasattr(st, api_name):
            getattr(st, api_name)()
            return
    # very old / unusual environment
    raise RuntimeError("No rerun method available in this Streamlit version.")
def get_nifti_SF_quick(path: str) -> Tuple[int, int]:
    """Return ``(n_slices, n_frames)`` for the NIfTI file at *path*.

    Only the image ``shape`` is inspected. Axis 2 is taken as slices and axis 3
    as frames. A missing axis counts as 1, so a 3-D volume yields ``F == 1``.
    """
    dims = nib.load(path).shape  # (X,Y,Z) or (X,Y,Z,T)
    n_slices = int(dims[2]) if len(dims) > 2 else 1
    n_frames = int(dims[3]) if len(dims) > 3 else 1
    return n_slices, n_frames
def get_session_tmpdir() -> str:
    """Return this Streamlit session's private temp directory, creating it on first use.

    The path is memoised in ``st.session_state["session_tmpdir"]``. The directory
    is made with ``tempfile.mkdtemp`` under a ``segapp_<uuid4-hex>_`` prefix.
    """
    state_key = "session_tmpdir"
    if state_key in st.session_state:
        return st.session_state[state_key]
    st.session_state[state_key] = tempfile.mkdtemp(
        prefix=f"segapp_{uuid.uuid4().hex}_"
    )
    return st.session_state[state_key]
def _session_subdir(name: str) -> str:
    """Return ``<session tmpdir>/<name>``, creating the directory if it is missing."""
    path = os.path.join(get_session_tmpdir(), name)
    os.makedirs(path, exist_ok=True)
    return path


def get_session_gifs_dir() -> str:
    """Return (and create if needed) the session's ``GIFs`` working directory."""
    return _session_subdir("GIFs")


def get_session_csv_dir() -> str:
    """Return (and create if needed) the session's ``CSV`` working directory."""
    return _session_subdir("CSV")
def invalidate_download_caches():
    """Forget every prepared download archive so it is rebuilt after edits.

    Removes the mask/GIF zip bytes and names (original and corrected variants)
    from ``st.session_state``. Keys that are not present are ignored.
    """
    stale_keys = (
        "mask_zip_bytes",
        "mask_zip_name",
        "mask_zip_bytes_corrected",
        "mask_zip_name_corrected",
        "gif_zip_bytes",
        "gif_zip_name",
        "gif_zip_bytes_corrected",
        "gif_zip_name_corrected",
    )
    for stale in stale_keys:
        st.session_state.pop(stale, None)
def normalize_images(x):
    """Min-max scale *x* to [0, 1] independently per leading index (and any trailing axes).

    Min and max are reduced over axes 1 and 2 with ``keepdims`` (presumably the
    H and W axes of an (N, H, W, C) batch — TODO confirm with callers). Wherever
    the range is zero the output is 0. Computation is float32; returns a NumPy array.
    """
    t = tf.convert_to_tensor(x, dtype=tf.float32)
    lo = tf.reduce_min(t, axis=[1, 2], keepdims=True)
    span = tf.reduce_max(t, axis=[1, 2], keepdims=True) - lo
    scaled = (t - lo) / span
    return tf.where(span > 0.0, scaled, tf.zeros_like(t)).numpy()
def _tf_resize_bilinear(img, *, target_h=SIZE_Y, target_w=SIZE_X):
    """Resize a 2-D image to (target_h, target_w) with antialiased bilinear interpolation.

    The image is wrapped as a (1, H, W, 1) float32 batch for ``tf.image.resize``.
    The result is squeezed (all singleton axes dropped) and returned as float32.
    """
    batch = img[None, ..., None].astype(np.float32)
    resized = tf.image.resize(
        batch, [target_h, target_w], method="bilinear", antialias=True
    )
    return resized.numpy().squeeze().astype(np.float32)
def _tf_resize_auto(img, target_h=SIZE_Y, target_w=SIZE_X):
    """Resize a 2-D image, picking the interpolation method from the size change.

    Uses "area" when either source dimension exceeds the target (any downscale),
    otherwise "bilinear"; antialias is requested in both cases. Returns a squeezed
    float32 array.
    """
    src_h, src_w = img.shape
    shrinking = src_h > target_h or src_w > target_w
    batch = img[None, ..., None].astype(np.float32)
    resized = tf.image.resize(
        batch,
        [target_h, target_w],
        method="area" if shrinking else "bilinear",
        antialias=True,
    )
    return resized.numpy().squeeze().astype(np.float32)
def _resize_nn(img, new_h, new_w):
    """Nearest-neighbour resize of a 2-D array to (new_h, new_w), keeping ``img.dtype``.

    Values are only copied, never interpolated (method="nearest"). The result is
    squeezed and cast back to the input dtype.
    """
    source_dtype = img.dtype
    batch = img[np.newaxis, ..., np.newaxis].astype(np.float32)
    resized = tf.image.resize(batch, [new_h, new_w], method="nearest")
    return resized.numpy().squeeze().astype(source_dtype)
def _fft_kspace_resize(img, target_h=SIZE_Y, target_w=SIZE_X):
    """
    Resize a 2-D image via its centred k-space (Fourier) representation.

    - Upsample: zero-fill around the centred spectrum.
    - Downsample: keep only the centre of the spectrum.

    The spectrum is placed so that the DC bin (index ``n // 2`` after
    ``fftshift``) lands on the target's DC bin for even AND odd sizes. The
    result is rescaled by ``(target_h * target_w) / (H * W)`` so intensities are
    preserved; ``np.fft.ifft2`` normalises by the output size, which would
    otherwise shrink or inflate values.

    Args:
        img: 2-D array-like of shape (H, W).
        target_h, target_w: output size.

    Returns:
        float32 array of shape (target_h, target_w): the real part of the inverse
        FFT. When the size already matches, the float32 input is returned.
    """
    img = np.asarray(img, dtype=np.float32)
    H, W = img.shape
    if H == target_h and W == target_w:
        return img

    spectrum = np.fft.fftshift(np.fft.fft2(img))
    kspace = np.zeros((target_h, target_w), dtype=np.complex128)

    def _axis_copy(src_n, dst_n):
        """Return (dst_start, src_start, count) aligning the fftshift DC bins on one axis."""
        count = min(src_n, dst_n)
        shift = dst_n // 2 - src_n // 2  # >0 when zero-filling, <=0 when cropping
        return max(0, shift), max(0, -shift), count

    y_dst, y_src, copy_h = _axis_copy(H, target_h)
    x_dst, x_src, copy_w = _axis_copy(W, target_w)
    kspace[y_dst:y_dst + copy_h, x_dst:x_dst + copy_w] = spectrum[
        y_src:y_src + copy_h, x_src:x_src + copy_w
    ]

    img_res = np.fft.ifft2(np.fft.ifftshift(kspace))
    # Undo ifft2's 1/(target_h*target_w) vs fft2's implicit H*W sum.
    scale = (target_h * target_w) / float(H * W)
    return (np.real(img_res) * scale).astype(np.float32)
def display_xform(img2d, orient_target=None, enable=DISPLAY_MATCH_DICOM):
    """Optionally rotate/flip a 2-D slice for display according to DISPLAY_RULES.

    ``orient_target=None`` means "use CURRENT_DISPLAY_ORIENT". When *enable* is
    false the input is returned unchanged. Unknown orientations fall back to the
    ``None`` rule. The steps run in a fixed order: rotate 90° clockwise, flip
    up/down, flip left/right — each only if its rule flag is truthy.
    """
    if orient_target is None:
        orient_target = CURRENT_DISPLAY_ORIENT
    if not enable:
        return img2d
    rule = DISPLAY_RULES.get(orient_target, DISPLAY_RULES[None])
    steps = (
        ("rot90_cw", lambda a: np.rot90(a, k=-1)),  # k=-1 rotates clockwise
        ("flip_ud", np.flipud),
        ("flip_lr", np.fliplr),
    )
    shown = img2d
    for flag, op in steps:
        if rule.get(flag):
            shown = op(shown)
    return shown
| def _normalize_to_uint8(img2d: np.ndarray) -> np.ndarray: | |
| arr = np.asarray(img2d, dtype=np.float32) | |
| if arr.size == 0: | |
| return np.zeros_like(arr, dtype=np.uint8) | |
| mn = np.min(arr) | |
| mx = np.max(arr) | |
| if mx > mn: | |
| arr = (arr - mn) / (mx - mn) | |
| else: | |
| arr = np.zeros_like(arr) | |
| arr = (arr * 255.0).clip(0, 255).astype(np.uint8) | |
| return arr | |
def overlay_rgb_alpha(base_u8: np.ndarray, mask_u8: np.ndarray,
                      alpha_myo: float = 0.45, alpha_blood: float = 0.45) -> np.ndarray:
    """
    Alpha-blend label colours over a grayscale slice.

    base_u8: (H,W) uint8 grayscale
    mask_u8: (H,W) uint8 labels (0,1,2)
    returns: (H,W,3) uint8 RGB

    Label 1 (myocardium) is blended toward blue with *alpha_myo*. Label 2 (blood
    pool) is blended toward red with *alpha_blood*. Other labels are untouched.
    """
    rgb = np.stack((base_u8,) * 3, axis=-1).astype(np.float32)
    # Myocardium first, then blood pool, so blood stays red inside the ring.
    layers = (
        (1, alpha_myo, (0, 0, 255)),   # myocardium = blue
        (2, alpha_blood, (255, 0, 0)),  # blood pool = red
    )
    for label, alpha, colour in layers:
        hit = mask_u8 == label
        if hit.any():
            tint = np.array(colour, dtype=np.float32)
            rgb[hit] = (1 - alpha) * rgb[hit] + alpha * tint
    return rgb.clip(0, 255).astype(np.uint8)
def render_overlay_pil_scaled(img2d_256, mask2d_256, *, scale=CANVAS_SCALE):
    """Render a slice plus its label overlay as an upscaled PIL image for the canvas.

    Image and mask go through the same display_xform. The image is scaled to
    uint8 and the labels are blended in (alpha 0.45 for both classes). The
    result is resized with nearest-neighbour to ``(SIZE_X * scale, SIZE_Y * scale)``.
    NOTE(review): the output size ignores the input shape; inputs are presumably
    already SIZE_Y x SIZE_X — confirm with callers.
    """
    base_u8 = _normalize_to_uint8(display_xform(img2d_256))
    labels = display_xform(mask2d_256.astype(np.uint8))
    blended = overlay_rgb_alpha(base_u8, labels, alpha_myo=0.45, alpha_blood=0.45)
    canvas_size = (SIZE_X * scale, SIZE_Y * scale)  # PIL expects (width, height)
    return Image.fromarray(blended).resize(canvas_size, resample=Image.NEAREST)
def build_zip_bytes(files: list, root_folder: Optional[str] = None):
    """
    Build an in-memory DEFLATE zip and return its bytes.

    files: list of tuples (arc_rel_path, bytes_content); content may also be str.
    root_folder: if truthy, every entry is placed under this top-level folder;
        if None or "", files are placed at the zip root.

    Backslashes in the resulting archive names are converted to forward slashes.
    """
    prefix = f"{root_folder}/" if root_folder else ""
    buffer = io.BytesIO()
    with zipfile.ZipFile(buffer, "w", compression=zipfile.ZIP_DEFLATED) as archive:
        for rel_path, content in files:
            arcname = (prefix + f"{rel_path}").replace("\\", "/")
            archive.writestr(arcname, content)
    return buffer.getvalue()
def build_corrected_gif_zip_bytes():
    """Zip one GIF per case (corrected preferred over original) into session state.

    Reads from ``st.session_state``:
        "_last_zip_name": names the archive (its stem; "Results" when unset/empty).
        "gif_paths_original" / "gif_paths_corrected": pid -> GIF file path.

    Writes ``"gif_zip_bytes_corrected"`` and ``"gif_zip_name_corrected"``.

    Entries are ``<zipstem>_GIF_corrected/<pid>.gif``. A corrected path whose
    file no longer exists falls back to the original GIF. Cases with no GIF on
    disk are skipped.
    """
    zipstem = Path(st.session_state.get("_last_zip_name") or "Results").stem
    gif_root = f"{zipstem}_GIF_corrected"
    gif_zip_name = f"{gif_root}.zip"

    orig = st.session_state.get("gif_paths_original", {}) or {}
    corr = st.session_state.get("gif_paths_corrected", {}) or {}

    gif_files = []
    for pid in sorted(set(orig.keys()) | set(corr.keys())):
        # Corrected overrides original, but only when that file actually exists.
        gif_path = next(
            (p for p in (corr.get(pid), orig.get(pid)) if p and os.path.exists(p)),
            None,
        )
        if gif_path is None:
            continue
        with open(gif_path, "rb") as f:
            # filename rule: pid.gif (pid derived from nifti name)
            gif_files.append((f"{pid}.gif", f.read()))

    st.session_state["gif_zip_bytes_corrected"] = build_zip_bytes(
        gif_files,
        root_folder=gif_root,
    )
    st.session_state["gif_zip_name_corrected"] = gif_zip_name
def build_corrected_masks_zip_bytes_per_mouse():
    """Zip one native-resolution label NIfTI per case into session state.

    For every entry in ``st.session_state["cases"]`` (pid -> case dict), the
    mask source is picked as follows:
      - ``case["preds_4d"]`` when ``case["mask_edited"]`` is truthy;
      - otherwise ``case["preds_4d_orig"]``, falling back to ``case["preds_4d"]``.
    The mask is resized to the case's native H/W, written to a temporary
    ``<pid>_mask.nii.gz`` in the session temp dir, and added to the archive as
    ``<zipstem>_Masks_corrected/<pid>/<pid>_mask.nii.gz``.

    The temporary file is removed on a best-effort basis even when saving or
    reading it fails.

    Writes ``"mask_zip_bytes_corrected"`` and ``"mask_zip_name_corrected"``.
    """
    zipstem = Path(st.session_state.get("_last_zip_name") or "Results").stem
    root = f"{zipstem}_Masks_corrected"
    zip_name = f"{root}.zip"
    mask_files = []
    cases = st.session_state.get("cases", {}) or {}
    for pid, case in cases.items():
        native_h, native_w, _n_slices, _n_frames = case["native_shape"]
        affine = case["affine"]
        # use corrected only if mask was edited; otherwise keep original mask
        if case.get("mask_edited", False):
            preds_src = case["preds_4d"]
        else:
            preds_src = case.get("preds_4d_orig", case["preds_4d"])
        mask_native = resize_masks_to_native(preds_src, native_h, native_w)
        tmp_mask_path = os.path.join(get_session_tmpdir(), f"{pid}_mask.nii.gz")
        try:
            save_native_mask_nifti_with_labels(mask_native, affine, tmp_mask_path)
            with open(tmp_mask_path, "rb") as f:
                mask_files.append((f"{pid}/{pid}_mask.nii.gz", f.read()))
        finally:
            # Best-effort cleanup: a missing/locked temp file must not abort the export.
            try:
                os.remove(tmp_mask_path)
            except OSError:
                pass
    st.session_state["mask_zip_bytes_corrected"] = build_zip_bytes(mask_files, root_folder=root)
    st.session_state["mask_zip_name_corrected"] = zip_name
| def _alpha_overlay_rgb(base_u8, mask_u8): | |
| """ | |
| base_u8: (H,W) uint8 grayscale | |
| mask_u8: (H,W) uint8 labels (0,1,2) | |
| returns RGB uint8 | |
| """ | |
| base_rgb = np.stack([base_u8]*3, axis=-1).astype(np.float32) | |
| # myocardium (1) blue tint | |
| myo = (mask_u8 == 1) | |
| base_rgb[myo, 2] = np.clip(base_rgb[myo, 2] * 0.6 + 255 * 0.4, 0, 255) | |
| # blood (2) red tint | |
| blood = (mask_u8 == 2) | |
| base_rgb[blood, 0] = np.clip(base_rgb[blood, 0] * 0.6 + 255 * 0.4, 0, 255) | |
| return base_rgb.astype(np.uint8) | |
def build_frame_montage(images_4d_256, preds_4d, frame_idx, slice_indices, tile_cols=4, add_overlay=True):
    """
    Return a PIL RGB montage of the selected slices for one frame.

    Args:
        images_4d_256: image volume indexed as [row, col, slice, frame].
        preds_4d: label volume indexed the same way (0=background,
            1=myocardium, 2=blood pool).
        frame_idx: frame to render.
        slice_indices: slices to include, laid out row-major in this order.
        tile_cols: maximum tiles per row; values below 1 are treated as 1.
        add_overlay: blend the labels over the grayscale image when True.

    Returns:
        PIL.Image.Image in RGB mode. Each tile is labelled "S<slice+1>" at the
        top-left. An empty ``slice_indices`` yields a black 256x256 image.
    """
    tiles = []
    for s in slice_indices:
        img_u8 = _normalize_to_uint8(images_4d_256[:, :, s, frame_idx])
        if add_overlay:
            mask = preds_4d[:, :, s, frame_idx].astype(np.uint8)
            tile = overlay_rgb_alpha(img_u8, mask, alpha_myo=0.45, alpha_blood=0.45)
            tile_pil = Image.fromarray(tile)
        else:
            tile_pil = Image.fromarray(img_u8).convert("RGB")
        # annotate slice number (1-based)
        ImageDraw.Draw(tile_pil).text((5, 5), f"S{s+1}", fill=(255, 255, 255))
        tiles.append(tile_pil)

    if not tiles:
        return Image.new("RGB", (256, 256), (0, 0, 0))

    w, h = tiles[0].size
    n = len(tiles)
    # Clamp so tile_cols <= 0 cannot cause a division by zero / zero-width canvas.
    cols = max(1, min(tile_cols, n))
    rows = int(math.ceil(n / cols))
    canvas = Image.new("RGB", (cols * w, rows * h), (0, 0, 0))
    for i, tile in enumerate(tiles):
        r, c = divmod(i, cols)
        canvas.paste(tile, (c * w, r * h))
    return canvas
def _render_overlay_png(img2d, mask2d, title=None):
    """Render a slice with label overlays to an in-memory PNG.

    The image and both masks pass through display_xform. Label 1 (myocardium)
    is drawn with the "Blues" colormap and label 2 (blood pool) with "jet", each
    at alpha 0.45. Axes are hidden; *title*, if given, is drawn at fontsize 10.

    Returns:
        io.BytesIO holding the PNG, positioned at offset 0.

    The figure is closed in a ``finally`` block, so it is released even if
    drawing or saving raises.
    """
    base = display_xform(img2d)
    myo_mask = display_xform((mask2d == 1).astype(np.uint8)).astype(bool)
    blood_mask = display_xform((mask2d == 2).astype(np.uint8)).astype(bool)

    fig, ax = plt.subplots(figsize=(3, 3))
    try:
        ax.imshow(base, cmap="gray", interpolation="none")
        for layer, cmap in ((myo_mask, "Blues"), (blood_mask, "jet")):
            if layer.any():
                ax.imshow(
                    np.ma.masked_where(~layer, layer),
                    cmap=cmap,
                    alpha=0.45,
                    vmin=0,
                    vmax=1,
                    interpolation="none",
                )
        ax.axis("off")
        if title:
            ax.set_title(title, fontsize=10)
        buf = io.BytesIO()
        fig.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
    finally:
        plt.close(fig)
    buf.seek(0)
    return buf
| # Colour and label maps for simple painting (kept for future use if needed) | |
| _COLOR_MAP_RGB = { | |
| "Blood pool": np.array([255, 0, 0], dtype=np.int16), # red | |
| "Myocardium": np.array([0, 0, 255], dtype=np.int16), # blue | |
| "Eraser (background)": np.array([0, 0, 0], dtype=np.int16), # black | |
| } | |
| _LABEL_MAP = { | |
| "Blood pool": 2, | |
| "Myocardium": 1, | |
| "Eraser (background)": 0, | |
| } | |
def _apply_canvas_to_mask_slice(
    mask_slice: np.ndarray,
    canvas_rgba: np.ndarray,
    class_choice: str,
    color_tol: float = 30.0,
) -> np.ndarray:
    """
    Paint one label class from a drawing-canvas RGBA image into a mask slice.

    A pixel is painted with ``_LABEL_MAP[class_choice]`` when its alpha is > 0
    and its RGB lies within Euclidean distance *color_tol* of
    ``_COLOR_MAP_RGB[class_choice]``. Returns a modified copy. The input slice
    itself is returned unchanged when the canvas is None or the class is unknown.
    (Currently not exposed in the UI, but kept for possible future use.)
    """
    if canvas_rgba is None:
        return mask_slice
    if not (class_choice in _COLOR_MAP_RGB and class_choice in _LABEL_MAP):
        return mask_slice
    pixels = canvas_rgba.astype(np.int16)
    distance = np.linalg.norm(pixels[..., :3] - _COLOR_MAP_RGB[class_choice], axis=-1)
    painted = (pixels[..., 3] > 0) & (distance < color_tol)
    result = mask_slice.copy()
    result[painted] = _LABEL_MAP[class_choice]
    return result
def _thicken_close_and_fill(
    strokes: np.ndarray,
    dilate_iter: int = 1,
    close_radius: int = 2,
) -> np.ndarray:
    """
    Turn hand-drawn contour strokes into a filled region (LV contour editing helper).

    Steps:
      1. Dilate with a radius-1 disk, *dilate_iter* times, so small gaps shrink.
      2. Binary closing with a radius-*close_radius* disk to bridge breaks.
      3. Fill enclosed holes.

    Input/Output: boolean mask. ``None`` or an all-False mask is returned as-is
    (the same object).
    """
    if strokes is None or not strokes.any():
        return strokes
    thickened = binary_dilation(strokes, structure=disk(1), iterations=dilate_iter)
    bridged = binary_closing(thickened, structure=disk(close_radius))
    return binary_fill_holes(bridged)
def _apply_lv_contour_style(
    mask_slice: np.ndarray,
    canvas_rgba: np.ndarray,
    color_tol: float = 30.0,
) -> np.ndarray:
    """
    Horos-style LV editing for a single 2D slice.

    Expected strokes on the RGBA canvas:
      • RED  (#FF0000) -> inner border (endocardial / blood pool)
      • BLUE (#0000FF) -> outer border (epicardial / myocardium)

    A stroke pixel has alpha > 0 and RGB within *color_tol* (Euclidean) of the
    reference colour. Each stroke set is thickened, closed and hole-filled.
    Then, on a copy of the slice:
      - every existing label 1/2 is cleared to 0,
      - the filled red region becomes blood pool (2),
      - (filled blue ∪ filled red) minus filled red becomes myocardium (1).

    Returns the slice unchanged when the canvas is None or no red/blue stroke is found.
    """
    if canvas_rgba is None:
        return mask_slice

    pixels = canvas_rgba.astype(np.int16)
    opaque = pixels[..., 3] > 0
    colour = pixels[..., :3]

    def strokes_of(reference_rgb):
        ref = np.array(reference_rgb, dtype=np.int16)
        return opaque & (np.linalg.norm(colour - ref, axis=-1) < color_tol)

    endo_strokes = strokes_of((255, 0, 0))  # red (blood pool)
    epi_strokes = strokes_of((0, 0, 255))   # blue (myocardium)
    if not (endo_strokes.any() or epi_strokes.any()):
        return mask_slice

    def filled_region(strokes):
        region = _thicken_close_and_fill(strokes)
        return np.zeros_like(strokes, dtype=bool) if region is None else region

    blood = filled_region(endo_strokes)
    # The epicardial region must always contain the endocardial one.
    epi = filled_region(epi_strokes) | blood
    myo = epi & ~blood

    result = mask_slice.copy()
    result[np.isin(result, (1, 2))] = 0  # drop previous LV labels on this slice
    result[blood] = 2
    result[myo] = 1
    return result
| # ========================= | |
| # NIfTI I/O + spacing | |
| # ========================= | |
def _reorient_nifti(img: nib.Nifti1Image, target: Optional[str]):
    """Reorient *img* to the "LPS" or "RAS" axis codes.

    Returns ``(image, transform)``:
      - falsy *target*: ``(img, None)``, untouched;
      - already in the target orientation (identity transform): ``(img, xfm)``;
      - otherwise the voxel data (``get_fdata()``) is reoriented and wrapped in a
        new Nifti1Image with the adjusted affine and the original header.

    Raises:
        ValueError: for any target other than None/"LPS"/"RAS" (case-insensitive).
    """
    if not target:
        return img, None
    axcodes = target.upper()
    if axcodes not in ("LPS", "RAS"):
        raise ValueError("ORIENT_TARGET must be None, 'LPS', or 'RAS'")
    ornt = nib.orientations
    xfm = ornt.ornt_transform(
        ornt.io_orientation(img.affine),
        ornt.axcodes2ornt(tuple(axcodes)),
    )
    identity = np.array([[0, 1], [1, 1], [2, 1]])
    if np.allclose(xfm, identity):
        return img, xfm
    reoriented = ornt.apply_orientation(img.get_fdata(), xfm)
    new_affine = img.affine @ ornt.inv_ornt_aff(xfm, img.shape)
    return nib.Nifti1Image(reoriented, new_affine, header=img.header), xfm
def load_nifti_4d(path, orient_target: Optional[str] = ORIENT_TARGET):
    """Load a 3-D or 4-D NIfTI file as a float32 ``(H, W, S, F)`` array.

    The image is optionally reoriented (see ``_reorient_nifti``). A 3-D volume
    gets a singleton frame axis. X and Y are swapped so that rows (H) follow the
    file's Y axis and columns (W) follow its X axis.

    Returns:
        (data_4d, spacing, affine). ``spacing`` has keys ``row_mm`` (Y zoom),
        ``col_mm`` (X zoom) and ``slice_thickness_mm`` (Z zoom), each defaulting
        to 1.0 when absent. ``frame_time_ms`` is the raw 4th zoom or None.
        NOTE(review): that value is not converted using the header's time
        units — confirm it really is in ms for your data.
        ``affine`` is that of the (possibly reoriented) image.
    """
    img, _ = _reorient_nifti(nib.load(path), orient_target)
    vol = img.get_fdata(dtype=np.float32)  # (X,Y,Z[,T])
    if vol.ndim == 3:
        vol = vol[..., np.newaxis]
    data_4d = vol.transpose(1, 0, 2, 3).astype(np.float32)  # -> (H,W,S,F)

    zooms = img.header.get_zooms()

    def zoom_or(axis, default):
        return float(zooms[axis]) if len(zooms) > axis else default

    spacing = {
        "row_mm": zoom_or(1, 1.0),  # H spacing (Y)
        "col_mm": zoom_or(0, 1.0),  # W spacing (X)
        "slice_thickness_mm": zoom_or(2, 1.0),
        "frame_time_ms": zoom_or(3, None),
    }
    return data_4d, spacing, img.affine
| # ========================= | |
| # Inference | |
| # ========================= | |
def nifti_to_model_batches(data_4d):
    """
    Stack every (slice, frame) image of a case and resize it to the model grid.

    The tf.image modes ("auto", "bilinear", "area") resize the whole stack in
    one vectorized call; "fft" resizes each 2D image separately through
    ``_fft_kspace_resize``. RESIZE_MODE is validated first (ValueError).

    Returns:
        x: (S*F, SIZE_Y, SIZE_X, 1) float32, slice-major / frame-minor
        index: list of (s, f) tuples in the same order as ``x``
        shape4d: native (H, W, S, F)
        resize_method_used: name of the resize method applied (for logging)
    """
    H, W, S, F = data_4d.shape
    valid = {"auto", "bilinear", "area", "fft"}
    if RESIZE_MODE not in valid:
        raise ValueError(f"RESIZE_MODE must be one of {sorted(valid)}")

    native_shape = (H, W, S, F)
    # Slice-major, frame-minor: the same order as both stacking paths below.
    index = [(s, f) for s in range(S) for f in range(F)]

    if RESIZE_MODE == "fft":
        resized = [
            _fft_kspace_resize(data_4d[..., s, f], target_h=SIZE_Y, target_w=SIZE_X)[..., None]
            for s, f in index
        ]
        return np.stack(resized, axis=0).astype(np.float32), index, native_shape, "fft"

    # (H, W, S, F) -> (S, F, H, W) -> (S*F, H, W, 1)
    stack = np.transpose(data_4d, (2, 3, 0, 1)).reshape(S * F, H, W, 1).astype(np.float32)

    # H, W are fixed within one NIfTI, so the method is chosen once per case.
    if RESIZE_MODE != "auto":
        method = RESIZE_MODE  # "bilinear" or "area"
    elif H > SIZE_Y or W > SIZE_X:
        method = "area"
    else:
        method = "bilinear"

    resized = tf.image.resize(stack, [SIZE_Y, SIZE_X], method=method, antialias=True)
    return resized.numpy().astype(np.float32), index, native_shape, method
| def _ensure_logits_last(preds): | |
| if isinstance(preds, (list, tuple)): | |
| preds = preds[-1] | |
| return preds | |
def predict_nifti_4d(
    model,
    data_4d,
    batch_size=None,
    *,
    progress_cb=None,
):
    """
    Run inference on a single (H, W, S, F) NIfTI case.

    Returns:
        preds_4d: (SIZE_Y, SIZE_X, S, F) uint8 argmax labels
        imgs_4d_resized: (SIZE_Y, SIZE_X, S, F) float32 resized images

    progress_cb (optional): callable(int images_done)
        - Called with the number of 2D images processed so far within THIS
          file ("2D images" = the S*F stacked slices/frames): once with 0,
          then after every batch, so the UI advances smoothly within a file.
        - Exceptions raised by the callback are ignored.

    Without progress_cb the whole stack goes through a single model.predict
    call; with it, prediction runs in chunks of batch_size (falling back to
    BATCH_SIZE, then 16).
    """
    x, _index, shape4d, _resize_method = nifti_to_model_batches(data_4d)
    x_norm = normalize_images(x)

    if progress_cb is None:
        preds = _ensure_logits_last(
            model.predict(x_norm, verbose=0, batch_size=batch_size)
        )
        labels = np.argmax(preds, axis=-1).astype(np.uint8)  # (S*F, Hn, Wn)
    else:
        def _report(done):
            try:
                progress_cb(done)
            except Exception:
                # Best-effort UI hook: a progress-bar failure must never
                # abort inference.
                pass

        n_images = x_norm.shape[0]  # = S*F
        bs = max(1, int(batch_size or BATCH_SIZE or 16))
        _report(0)
        label_chunks = []
        for start in range(0, n_images, bs):
            end = min(start + bs, n_images)
            pred_b = _ensure_logits_last(model.predict(x_norm[start:end], verbose=0))
            # Reduce to uint8 labels per batch so the full float probability
            # stack (S*F x Hn x Wn x classes) is never held in memory at once.
            label_chunks.append(np.argmax(pred_b, axis=-1).astype(np.uint8))
            _report(end)
        labels = (
            np.concatenate(label_chunks, axis=0)
            if label_chunks
            else np.zeros((0, SIZE_Y, SIZE_X), dtype=np.uint8)
        )

    Hn, Wn, S, F = SIZE_Y, SIZE_X, shape4d[2], shape4d[3]
    # x: (S*F, Hn, Wn, 1) -> (S, F, Hn, Wn) -> (Hn, Wn, S, F)
    imgs_4d_resized = x[..., 0].reshape(S, F, Hn, Wn).transpose(2, 3, 0, 1)
    # labels: (S*F, Hn, Wn) -> (S, F, Hn, Wn) -> (Hn, Wn, S, F)
    preds_4d = labels.reshape(S, F, Hn, Wn).transpose(2, 3, 0, 1)
    return preds_4d, imgs_4d_resized
def resize_masks_to_native(preds_4d_256, native_h, native_w):
    """
    Resize model-grid label volumes back to the native in-plane size.

    Every (slice, frame) 2D label map is passed independently through
    ``_resize_nn`` (presumably nearest-neighbour, going by its name — it is
    defined elsewhere). Output: (native_h, native_w, S, F) with the input dtype.
    """
    _, _, n_slices, n_frames = preds_4d_256.shape
    native = np.zeros((native_h, native_w, n_slices, n_frames), dtype=preds_4d_256.dtype)
    for f, s in np.ndindex(n_frames, n_slices):
        native[..., s, f] = _resize_nn(preds_4d_256[..., s, f], native_h, native_w)
    return native
| # ========================= | |
| # Mask export (native) | |
| # ========================= | |
def save_native_mask_nifti(mask_hwst, affine, out_path):
    """
    Save a native-resolution label mask as a NIfTI file.

    mask_hwst: (H, W, S, F) labels; a 3D (H, W, S) mask is also accepted.
        Rows/cols are swapped back to NIfTI (X, Y, ...) order and the data
        is stored as uint8.
    affine: affine from load_nifti_4d().
    Returns out_path.
    """
    # Swapping only the first two axes undoes load_nifti_4d's (1, 0, 2, 3)
    # transpose for 4D input and also works for 3D masks.
    data_xyzt = np.swapaxes(np.asarray(mask_hwst), 0, 1).astype(np.uint8)
    img = nib.Nifti1Image(data_xyzt, affine)
    nib.save(img, out_path)
    return out_path
# Default label table embedded in exported masks:
# label value -> display name and RGB colour.
DEFAULT_LABELS = {
    code: {"name": name, "color": list(rgb)}
    for code, name, rgb in (
        (0, "Background", (0, 0, 0)),
        (1, "Myocardium", (0, 0, 255)),
        (2, "LV Blood Pool", (255, 0, 0)),
    )
}
| def _make_label_json(label_defs: dict) -> str: | |
| payload = { | |
| "type": "segmentation_labels", | |
| "version": 1, | |
| "labels": {str(k): v for k, v in label_defs.items()}, | |
| } | |
| return json.dumps(payload, separators=(",", ":"), ensure_ascii=False) | |
def save_native_mask_nifti_with_labels(mask_hwst, affine, out_path, label_defs=DEFAULT_LABELS):
    """
    Save a label mask as NIfTI with the label table embedded in the header.

    mask_hwst: (H, W, S, F) labels (a 3D (H, W, S) mask is also accepted);
        rows/cols are swapped back to NIfTI (X, Y, ...) order, stored as uint8.
    affine: affine from load_nifti_4d().
    label_defs: mapping label value -> {"name": ..., "color": [...]}.

    Saves NIfTI with:
      - header['descrip']: short ASCII summary (max 80 bytes)
      - a NIfTI header extension with ecode 6 (comment) holding the label JSON
    Returns out_path.
    """
    data_xyzt = np.swapaxes(np.asarray(mask_hwst), 0, 1).astype(np.uint8)
    img = nib.Nifti1Image(data_xyzt, affine)
    # Short description (widely preserved by tools).
    # NOTE(review): this string is fixed and does not reflect a custom label_defs.
    desc = "Labels:0=BG,1=Myo,2=Blood"
    img.header["descrip"] = desc.encode("ascii", "ignore")[:80]
    # Embedded JSON extension (many tools preserve it; some viewers won't show it).
    # ecode 6 is NIFTI_ECODE_COMMENT. Do not use 40: that is NIFTI_ECODE_MATLAB
    # and would tag this JSON as MATLAB data for tools that dispatch on the code.
    label_json = _make_label_json(label_defs)
    ext = nib.nifti1.Nifti1Extension(6, label_json.encode("utf-8"))
    exts = img.header.extensions
    exts.clear()
    exts.append(ext)
    nib.save(img, out_path)
    return out_path
| # ========================= | |
| # Cleaning + ED/ES + metrics | |
| # ========================= | |
def clean_predictions_per_frame_3d(mask_4d):
    """
    Remove spurious 3D islands of each LV label, frame by frame.

    For every frame and each label (1, 2): find its 3D connected components
    (skimage ``label`` with connectivity=1, i.e. face neighbours), always keep
    the largest one, and also keep another component only if it
      - spans >= ISLAND_MIN_SLICE_SPAN slices,
      - has a median per-slice area >= ISLAND_MIN_AREA_PER_SLICE voxels, and
      - has its centroid within ISLAND_CENTROID_DIST_THRESH of the largest
        component's centroid (Euclidean distance in voxel-index units).
    All other components of that label are set to 0 (background).

    mask_4d: (H, W, S, F) label array; it is not modified (a copy is returned).
    """
    _, _, _, n_frames = mask_4d.shape
    out = mask_4d.copy()
    for f in range(n_frames):
        vol_f = out[:, :, :, f]  # view: edits write straight into `out`
        for cls in (1, 2):
            m = vol_f == cls
            if not m.any():
                continue
            cc = cc_label(m, connectivity=1)
            props = regionprops(cc)
            if not props:
                continue
            dom = max(props, key=lambda r: r.area)
            dom_centroid = np.array(dom.centroid)
            keep = {dom.label}
            for r in props:
                if r.label == dom.label:
                    continue
                zmin, zmax = r.bbox[2], r.bbox[5]
                slice_span = zmax - zmin
                # r.image is the region's mask cropped to its bbox, so summing
                # rows/cols gives the per-slice area for z in [zmin, zmax)
                # without rescanning whole slices of the label volume.
                areas = r.image.sum(axis=(0, 1))
                median_area = np.median(areas) if areas.size else 0
                dist = np.linalg.norm(np.array(r.centroid) - dom_centroid)
                if (
                    slice_span >= ISLAND_MIN_SLICE_SPAN
                    and median_area >= ISLAND_MIN_AREA_PER_SLICE
                    and dist <= ISLAND_CENTROID_DIST_THRESH
                ):
                    keep.add(r.label)
            drop = (cc > 0) & ~np.isin(cc, list(keep))
            vol_f[drop] = 0
    return out
def compute_per_frame_metrics(preds_4d, spacing, labels=None, myo_density=None):
    """
    Per-frame LV blood volume and myocardial mass from a 4D label array.

    preds_4d: (H, W, S, F) labels.
    spacing: dict with row_mm, col_mm, slice_thickness_mm (voxel size in mm).
    labels: {"myo": <int>, "blood": <int>}; defaults to {"myo": 1, "blood": 2}.
    myo_density: factor applied to myocardial volume; defaults to the module
        constant MYO_DENSITY (presumably mg per uL — confirm).

    Returns a DataFrame with columns Frame (0-based), Volume_uL
    (blood voxels * voxel volume; 1 mm^3 == 1 uL) and MyocardiumMass_mg.
    """
    # None sentinel instead of a mutable dict default shared between calls.
    if labels is None:
        labels = {"myo": 1, "blood": 2}
    if myo_density is None:
        myo_density = MYO_DENSITY
    row_mm = float(spacing["row_mm"])
    col_mm = float(spacing["col_mm"])
    thk = float(spacing["slice_thickness_mm"])
    voxel_mm3 = row_mm * col_mm * thk
    _, _, _, n_frames = preds_4d.shape
    blood_counts = (preds_4d == labels["blood"]).sum(axis=(0, 1, 2))
    myo_counts = (preds_4d == labels["myo"]).sum(axis=(0, 1, 2))
    volume_uL = blood_counts * voxel_mm3
    myo_mass_mg = myo_counts * voxel_mm3 * myo_density
    return pd.DataFrame(
        {
            "Frame": np.arange(n_frames, dtype=int),
            "Volume_uL": volume_uL,
            "MyocardiumMass_mg": myo_mass_mg,
        }
    )
def slice_validity_matrix(preds_4d, A_blood_min=30, A_myo_min=30):
    """
    Per-(slice, frame) pixel areas of blood (label 2) and myocardium (label 1).

    Returns (has_blood, has_myo, areas_blood, areas_myo), each shaped (S, F);
    the boolean maps mark areas at or above the given minimum pixel counts.
    """
    height, width, n_slices, n_frames = preds_4d.shape
    flat = preds_4d.reshape(height * width, n_slices, n_frames)
    areas_blood = (flat == 2).sum(axis=0)
    areas_myo = (flat == 1).sum(axis=0)
    return (
        areas_blood >= A_blood_min,
        areas_myo >= A_myo_min,
        areas_blood,
        areas_myo,
    )
def choose_mid_slices(has_blood, has_myo, K=4, min_frac=0.7):
    """
    Pick up to K slice indices for ED/ES detection.

    Each slice gets a validity fraction: the share of frames in which it has
    both blood and myocardium. Every contiguous window of K slices is scored
    by its mean validity minus 0.01 * its mean distance from the middle slice
    (S // 2); the best window wins (first one on ties). If that window's mean
    validity is below ``min_frac``, the K individually most valid slices are
    returned instead, sorted ascending.

    has_blood, has_myo: (S, F) boolean arrays.
    """
    n_slices, n_frames = has_blood.shape
    valid_frac = (has_blood & has_myo).sum(axis=1) / max(n_frames, 1)
    centre = n_slices // 2

    best_score, chosen = None, None
    for start in range(max(n_slices - K + 1, 1)):
        window = list(range(start, min(start + K, n_slices)))
        offset_penalty = 0.01 * np.mean([abs(s - centre) for s in window])
        score = valid_frac[window].mean() - offset_penalty
        if best_score is None or score > best_score:
            best_score, chosen = score, window

    if np.mean(valid_frac[chosen]) < min_frac:
        ranked = np.argsort(-valid_frac)
        chosen = sorted(ranked[:K].tolist())
    return chosen
def frame_volumes_subset_uL(preds_4d, spacing, slice_indices):
    """
    Blood-pool (label 2) volume per frame, counting only ``slice_indices``.

    Volume = number of label-2 voxels in the selected slices * voxel volume
    (row_mm * col_mm * slice_thickness_mm; 1 mm^3 == 1 uL).
    Returns a float32 array of length F.
    """
    voxel = (
        float(spacing["row_mm"])
        * float(spacing["col_mm"])
        * float(spacing["slice_thickness_mm"])
    )
    n_frames = preds_4d.shape[3]
    counts = [(preds_4d[:, :, slice_indices, f] == 2).sum() for f in range(n_frames)]
    return np.asarray([c * voxel for c in counts], dtype=np.float32)
def pick_ed_es_from_volumes(
    vols_uL, prefer_frame0=True, rel_tol=0.05, min_sep=1
):
    """
    Choose (ED, ES) frame indices from a per-frame volume curve.

    ED is the frame of maximum volume, except that frame 0 is preferred when
    ``prefer_frame0`` and its volume is within ``rel_tol`` (relative) of that
    maximum. ES is the smallest-volume frame at least ``min_sep`` frames away
    from ED; if none qualifies, the overall smallest-volume frame is used.
    """
    ed = int(np.argmax(vols_uL))
    peak = vols_uL[ed]
    if prefer_frame0 and abs(vols_uL[0] - peak) <= rel_tol * max(peak, 1e-6):
        ed = 0

    ascending = np.argsort(vols_uL)
    es = next(
        (int(c) for c in ascending if abs(int(c) - ed) >= min_sep),
        int(ascending[0]),
    )
    return ed, es
| # ========================= | |
| # GIF (ED vs ES) | |
| # ========================= | |
def gif_animation_for_patient_pred_only(
    images_4d, preds_4d, patient_id, ed_idx, es_idx, output_dir
):
    """
    Write an animated GIF stepping through slices with ED and ES side by side.

    images_4d, preds_4d: (H, W, S, F) images and labels (1 = myocardium,
        2 = blood pool), both passed through ``display_xform`` before drawing.
    ed_idx, es_idx: frame indices shown in the left / right panel.
    output_dir: created if missing; the file is "<patient_id>_pred.gif".

    One GIF frame per slice; playback speed of the saved file is set by
    GIF_FPS. Returns the output path.
    """
    os.makedirs(output_dir, exist_ok=True)

    def overlay(ax, img, pred, alpha_myo=0.45, alpha_blood=0.45):
        base = display_xform(img)
        myo_mask = display_xform((pred == 1).astype(np.uint8)).astype(bool)
        blood_mask = display_xform((pred == 2).astype(np.uint8)).astype(bool)
        ax.imshow(base, cmap="gray", interpolation="none")
        if myo_mask.any():
            ax.imshow(
                np.ma.masked_where(~myo_mask, myo_mask),
                cmap="Blues",
                alpha=alpha_myo,
                vmin=0,
                vmax=1,
                interpolation="none",
            )
        if blood_mask.any():
            ax.imshow(
                np.ma.masked_where(~blood_mask, blood_mask),
                cmap="jet",
                alpha=alpha_blood,
                vmin=0,
                vmax=1,
                interpolation="none",
            )
        ax.axis("off")

    _, _, n_slices, _ = images_4d.shape
    fig, axarr = plt.subplots(1, 2, figsize=(8, 4))
    try:
        # Explicit figure (not pyplot's "current" one) so concurrent Streamlit
        # sessions cannot lay out each other's figures.
        fig.tight_layout(rect=[0, 0, 1, 0.92])

        def update(slice_idx):
            axarr[0].clear()
            axarr[1].clear()
            overlay(
                axarr[0],
                images_4d[:, :, slice_idx, ed_idx],
                preds_4d[:, :, slice_idx, ed_idx],
            )
            axarr[0].set_title(f"ED (frame {ed_idx+1}) | Slice {slice_idx+1}")
            overlay(
                axarr[1],
                images_4d[:, :, slice_idx, es_idx],
                preds_4d[:, :, slice_idx, es_idx],
            )
            axarr[1].set_title(f"ES (frame {es_idx+1}) | Slice {slice_idx+1}")
            fig.suptitle(f"Mouse ID: {patient_id}", fontsize=14, y=0.98)

        # interval only affects interactive playback; the saved GIF's frame
        # duration comes from fps=GIF_FPS.
        anim = animation.FuncAnimation(fig, update, frames=n_slices, interval=2500)
        out_path = os.path.join(output_dir, f"{patient_id}_pred.gif")
        anim.save(out_path, writer="pillow", fps=GIF_FPS)
    finally:
        # Always release the figure, even if rendering/saving fails, so
        # repeated runs do not accumulate open matplotlib figures.
        plt.close(fig)
    return out_path
| # ========================= | |
| # CSV writer | |
| # ========================= | |
def write_all_in_one_csv(rows, per_frame_rows, csv_dir):
    """
    Write one Results.csv combining per-frame metrics with each case's summary.

    rows: list of per-case summary dicts (Patient_ID, EDV_uL, ...).
    per_frame_rows: list of per-frame DataFrames
        (Patient_ID, Frame, Volume_uL, MyocardiumMass_mg).
    csv_dir: output directory, created if missing.

    Frame and ED/ES indices are shifted to 1-based, numeric values are
    rendered as strings with two decimals, and each output row is one frame
    with its case's summary columns alongside (left merge on Patient_ID).
    Returns the CSV path.
    """
    summary_cols = [
        "Patient_ID",
        "EDV_uL",
        "ESV_uL",
        "SV_uL",
        "EF_%",
        "MyocardiumMass_ED_mg",
        "MyocardiumMass_ES_mg",
        "ED_frame_index",
        "ES_frame_index",
        "Seg_PixelSpacing_row_mm",
        "Seg_PixelSpacing_col_mm",
        "Seg_SliceThickness_mm",
        "Native_PixelSpacing_row_mm",
        "Native_PixelSpacing_col_mm",
        "Native_SliceThickness_mm",
    ]
    output_cols = [
        "Patient_ID",
        "ED_frame_index",
        "ES_frame_index",
        "EDV_uL",
        "ESV_uL",
        "SV_uL",
        "EF_%",
        "MyocardiumMass_ED_mg",
        "MyocardiumMass_ES_mg",
        "Frame",
        "Volume_uL",
        "MyocardiumMass_mg",
        "Seg_PixelSpacing_row_mm",
        "Seg_PixelSpacing_col_mm",
        "Seg_SliceThickness_mm",
        "Native_PixelSpacing_row_mm",
        "Native_PixelSpacing_col_mm",
        "Native_SliceThickness_mm",
    ]
    index_cols = ("ED_frame_index", "ES_frame_index")
    # Every summary column except the ID and the two frame indices is a number.
    formatted_summary_cols = [
        c for c in summary_cols if c != "Patient_ID" and c not in index_cols
    ]

    def two_decimals(series):
        return series.astype(float).map(lambda v: f"{v:.2f}")

    summary = pd.DataFrame(rows).reindex(columns=summary_cols)
    per_frame = (
        pd.concat(per_frame_rows, ignore_index=True)
        if per_frame_rows
        else pd.DataFrame(
            columns=["Patient_ID", "Frame", "Volume_uL", "MyocardiumMass_mg"]
        )
    )

    if not per_frame.empty:
        # Present frame numbers 1-based to users.
        per_frame["Frame"] = per_frame["Frame"].astype(int) + 1
        for col in ("Volume_uL", "MyocardiumMass_mg"):
            if col in per_frame.columns:
                per_frame[col] = two_decimals(per_frame[col])

    if not summary.empty:
        for col in index_cols:
            if col in summary.columns:
                summary[col] = summary[col].astype("Int64") + 1
    for col in formatted_summary_cols:
        if col in summary.columns:
            summary[col] = two_decimals(summary[col])

    merged = per_frame.merge(summary[summary_cols], on="Patient_ID", how="left")
    all_in_one = merged[output_cols]

    os.makedirs(csv_dir, exist_ok=True)
    out_csv = os.path.join(csv_dir, "Results.csv")
    all_in_one.to_csv(out_csv, index=False)
    log(f"CSV written: {out_csv}")
    return out_csv
| # ========================= | |
| # Same-tab download | |
| # ========================= | |
def _same_tab_download_button(
    label: str,
    data_bytes: bytes,
    file_name: str,
    mime: str = "text/csv",
    *,
    key: Optional[str] = None,
):
    """
    Render a styled download link that saves ``data_bytes`` without leaving the page.

    The payload is embedded base64-encoded in the link's data attributes; a
    zero-height component script attaches handlers in the parent document
    that build a Blob and click a temporary <a download> element.

    label: button text. file_name / mime: suggested download name and type.
    key: optional stable identifier for the element id (defaults to file_name).
    """
    import base64
    import hashlib
    import html
    import re

    ident = key or file_name
    digest = hashlib.sha256(ident.encode()).hexdigest()[:8]
    # The id is used verbatim in `a#<id>` CSS selectors below. Characters such
    # as '.', ' ' or '#' (e.g. "Results.csv") would be parsed as selector
    # syntax and silently disable the styling, so reduce it to [A-Za-z0-9_-].
    # The digest of the unsanitized name keeps distinct names distinct.
    safe_ident = re.sub(r"[^A-Za-z0-9_-]", "_", ident)
    btn_id = f"dl_{safe_ident}_{digest}"
    b64 = base64.b64encode(data_bytes).decode("ascii")

    st.markdown(
        f"""
<style>
a#{btn_id} {{
appearance: none;
display: inline-flex; align-items: center; justify-content: center;
padding: 0.5rem 0.75rem;
border-radius: 0.5rem;
border: 1px solid rgba(49,51,63,.2);
background: var(--background-color);
color: var(--text-color);
font-weight: 600; text-decoration: none !important;
box-shadow: 0 1px 2px rgba(0,0,0,0.04);
transition: color .15s ease, border-color .15s ease,
box-shadow .15s ease, transform .05s ease, background-color .15s;
cursor: pointer;
user-select: none;
-webkit-tap-highlight-color: transparent;
}}
a#{btn_id}:hover, a#{btn_id}:focus {{
background: var(--background-color);
color: var(--accent) !important;
border-color: var(--accent) !important;
box-shadow: 0 2px 6px rgba(239,68,68,0.20);
transform: translateY(-1px);
}}
a#{btn_id}:active,
a#{btn_id}.pressed {{
background: var(--accent) !important;
border-color: var(--accent) !important;
color: #fff !important;
box-shadow: 0 3px 10px rgba(239,68,68,0.35);
transform: translateY(0);
}}
a#{btn_id}:focus-visible {{
outline: none;
box-shadow: 0 0 0 0.2rem rgba(239,68,68,0.35);
}}
</style>
""",
        unsafe_allow_html=True,
    )
    st.markdown(
        f'''
<div class="dl-wrap">
<a class="dl-btn" id="{btn_id}" href="#"
data-b64="{b64}"
data-mime="{html.escape(mime)}"
data-fname="{html.escape(file_name)}">{html.escape(label)}</a>
</div>
''',
        unsafe_allow_html=True,
    )
    components.html(
        f"""
<script>
(function () {{
try {{
const doc = window.parent.document;
const a = doc.getElementById("{btn_id}");
if (!a) return;
const pressOn = () => a.classList.add("pressed");
const pressOff = () => a.classList.remove("pressed");
a.addEventListener("mousedown", pressOn, true);
a.addEventListener("mouseup", pressOff, true);
a.addEventListener("mouseleave",pressOff, true);
a.addEventListener("touchstart",pressOn, {{passive:true}});
a.addEventListener("touchend", pressOff, true);
a.addEventListener("touchcancel",pressOff, true);
a.addEventListener("click", function(ev) {{
ev.preventDefault();
ev.stopImmediatePropagation();
const b64 = a.getAttribute("data-b64");
const mime = a.getAttribute("data-mime") || "application/octet-stream";
const fname = a.getAttribute("data-fname") || "download";
const bstr = atob(b64);
const len = bstr.length;
const u8 = new Uint8Array(len);
for (let i = 0; i < len; i++) u8[i] = bstr.charCodeAt(i);
const blob = new Blob([u8], {{ type: mime }});
const url = URL.createObjectURL(blob);
const tmp = doc.createElement("a");
tmp.href = url;
tmp.download = fname;
tmp.style.display = "none";
doc.body.appendChild(tmp);
tmp.click();
setTimeout(() => {{
URL.revokeObjectURL(url);
tmp.remove();
pressOff();
}}, 150);
}}, true);
}} catch (err) {{
console.debug("download handler error:", err);
}}
}})();
</script>
""",
        height=0,
    )
def render_download_group(buttons: List[dict], group_key: str):
    """
    Render several same-tab download links with one shared JS click hook.

    buttons: list of dicts with keys: label, bytes, filename, mime, key
    group_key: prefix for the generated element ids (restricted to
        [A-Za-z0-9_-] before use).

    Each link carries its payload base64-encoded in data attributes; a
    zero-height component script wires handlers in the parent document that
    build a Blob and trigger the download without navigating away.
    """
    import base64
    import hashlib
    import html
    import re

    # group_key is interpolated into an HTML id attribute and into a
    # single-quoted JS string literal (wire('...')); a quote or similar
    # character would break the markup or the script, so keep a safe set.
    safe_group = re.sub(r"[^A-Za-z0-9_-]", "_", str(group_key))
    ids = []
    rows_html = []
    for b in buttons:
        b64 = base64.b64encode(b["bytes"]).decode("ascii")
        btn_id = f"dl_{safe_group}_{hashlib.sha256(b['key'].encode()).hexdigest()[:8]}"
        ids.append(btn_id)
        rows_html.append(f"""
<div class="dl-row">
<a class="dl-btn" id="{btn_id}" href="#"
data-b64="{b64}"
data-mime="{html.escape(b['mime'])}"
data-fname="{html.escape(b['filename'])}">
{html.escape(b['label'])}
</a>
</div>
""")
    rows_joined = "".join(rows_html)
    st.markdown(
        f"""
<div class="dl-group">
{rows_joined}
</div>
""",
        unsafe_allow_html=True,
    )
    # One JS hook for all buttons in this group.
    wire_calls = "".join(f"wire('{i}');" for i in ids)
    components.html(
        f"""
<script>
(function () {{
const doc = window.parent.document;
function wire(id) {{
const a = doc.getElementById(id);
if (!a) return;
const pressOn = () => a.classList.add("pressed");
const pressOff = () => a.classList.remove("pressed");
a.addEventListener("mousedown", pressOn, true);
a.addEventListener("mouseup", pressOff, true);
a.addEventListener("mouseleave", pressOff, true);
a.addEventListener("click", function(ev) {{
ev.preventDefault();
ev.stopImmediatePropagation();
const b64 = a.getAttribute("data-b64");
const mime = a.getAttribute("data-mime") || "application/octet-stream";
const fname = a.getAttribute("data-fname") || "download";
const bstr = atob(b64);
const len = bstr.length;
const u8 = new Uint8Array(len);
for (let i = 0; i < len; i++) u8[i] = bstr.charCodeAt(i);
const blob = new Blob([u8], {{ type: mime }});
const url = URL.createObjectURL(blob);
const tmp = doc.createElement("a");
tmp.href = url;
tmp.download = fname;
tmp.style.display = "none";
doc.body.appendChild(tmp);
tmp.click();
setTimeout(() => {{
URL.revokeObjectURL(url);
tmp.remove();
pressOff();
}}, 150);
}}, true);
}}
{wire_calls}
}})();
</script>
""",
        height=0,
    )
| # ========================= | |
| # Per-file processing | |
| # ========================= | |
def process_nifti_case(
    nifti_path: str,
    model,
    rows_acc: List[Dict],
    per_frame_rows_acc: List[pd.DataFrame],
    *,
    progress_cb=None,
):
    """
    Segment one NIfTI case, compute LV metrics, build its ED/ES GIF and
    register the case in the Streamlit session.

    Steps: load (+ optional reorientation) -> model prediction on the
    SIZE_Y x SIZE_X grid -> optional island removal -> ED/ES selection and
    metrics -> GIF.

    Side effects:
      - appends one summary dict to ``rows_acc`` and one per-frame DataFrame
        (with a leading Patient_ID column) to ``per_frame_rows_acc``;
      - sets the module-global CURRENT_DISPLAY_ORIENT;
      - stores the GIF path in st.session_state["gif_paths"][pid] and the
        case data in st.session_state["cases"][pid].

    progress_cb(file_frac: float, msg: Optional[str]) where file_frac in [0,1]
    (exceptions raised by the callback are ignored).
    """
    def _p(frac: float, msg: Optional[str] = None):
        # Clamp to [0, 1] and forward; progress reporting is best-effort.
        if progress_cb is None:
            return
        try:
            progress_cb(float(np.clip(frac, 0.0, 1.0)), msg)
        except Exception:
            pass
    name = Path(nifti_path).name
    # NOTE(review): str.replace removes ".nii.gz"/".nii" anywhere in the name,
    # not only as a suffix (e.g. "a.nii_b.nii.gz" -> "a_b"); consider a suffix strip.
    pid = name.replace(".nii.gz", "").replace(".nii", "")
    log(f"[CASE] Using NIfTI input: {nifti_path}")
    # ---- Step A: load + prep (0 -> W_LOAD_PREP) ----
    _p(0.00, "Loading NIfTI...")
    imgs_4d, spacing, final_aff = load_nifti_4d(nifti_path, orient_target=ORIENT_TARGET)
    _p(W_LOAD_PREP, "Preparing inference...")
    # Orientation bookkeeping (unchanged): without a forced target, derive the
    # display orientation from the affine's axis codes, defaulting to "LPS".
    global CURRENT_DISPLAY_ORIENT
    if ORIENT_TARGET is None:
        try:
            axc = nib.aff2axcodes(final_aff)
            if tuple(axc[:3]) == ("L", "P", "S"):
                CURRENT_DISPLAY_ORIENT = "LPS"
            elif tuple(axc[:3]) == ("R", "A", "S"):
                CURRENT_DISPLAY_ORIENT = "RAS"
            else:
                CURRENT_DISPLAY_ORIENT = "LPS"
        except Exception:
            CURRENT_DISPLAY_ORIENT = "LPS"
    else:
        CURRENT_DISPLAY_ORIENT = ORIENT_TARGET
    native_h, native_w, S, F = imgs_4d.shape
    row_native = float(spacing["row_mm"])
    col_native = float(spacing["col_mm"])
    thk_native = float(spacing["slice_thickness_mm"])
    # Pixel size on the model grid: native spacing scaled by native / model
    # size, so voxel volumes on the resized masks stay in mm^3.
    seg_row_mm = row_native * (native_h / float(SIZE_Y))
    seg_col_mm = col_native * (native_w / float(SIZE_X))
    seg_spacing = dict(
        row_mm=seg_row_mm,
        col_mm=seg_col_mm,
        slice_thickness_mm=thk_native,
        frame_time_ms=spacing.get("frame_time_ms", None),
    )
    # ---- Step B: prediction (W_LOAD_PREP -> W_LOAD_PREP + W_PREDICT) ----
    total_images = max(int(S * F), 1)
    def _pred_progress(done_images: int):
        # Map "2D images predicted" onto this file's prediction progress band.
        frac = float(done_images) / float(total_images)
        frac = float(np.clip(frac, 0.0, 1.0))
        file_frac = W_LOAD_PREP + W_PREDICT * frac
        _p(file_frac, f"Predicting... ({int(round(frac * 100))}%)")
    preds_4d_256, imgs_4d_256 = predict_nifti_4d(
        model,
        imgs_4d,
        batch_size=BATCH_SIZE,
        progress_cb=_pred_progress,  # ✅ smooth & real
    )
    # All metrics below use the model-grid masks together with seg_spacing.
    preds_4d = preds_4d_256
    # ---- Step C: postprocess + metrics + GIF (remaining W_POST_GIF) ----
    post0 = W_LOAD_PREP + W_PREDICT
    # Split W_POST_GIF into real milestones
    # island: 35%, metrics: 25%, gif: 40% (of the post chunk)
    w_island = W_POST_GIF * 0.35
    w_metrics = W_POST_GIF * 0.25
    w_gif = W_POST_GIF * 0.40
    # Island removal
    if ENABLE_ISLAND_REMOVAL:
        _p(post0 + 0.10 * w_island, "Cleaning islands...")
        preds_4d = clean_predictions_per_frame_3d(preds_4d)
        _p(post0 + w_island, "Cleaning islands... done")
    else:
        _p(post0 + w_island, "Skipping island removal")
    # Copy of the (cleaned) prediction, kept alongside the editable preds_4d.
    preds_4d_orig = preds_4d.copy()  # ✅ immutable snapshot
    # Metrics
    _p(post0 + w_island + 0.10 * w_metrics, "Computing metrics...")
    voxel_mm3 = (
        seg_spacing["row_mm"]
        * seg_spacing["col_mm"]
        * seg_spacing["slice_thickness_mm"]
    )
    mid_slices = None
    if USE_MID_SLICES_FOR_ED_ES:
        # ED/ES frames are picked from volumes over a subset of mid slices,
        # but EDV/ESV are reported from the full-stack volumes.
        has_blood, has_myo, _, _ = slice_validity_matrix(
            preds_4d, A_blood_min=MID_A_BLOOD_MIN, A_myo_min=MID_A_MYO_MIN
        )
        mid_slices = choose_mid_slices(
            has_blood,
            has_myo,
            K=min(MID_K, preds_4d.shape[2]),
            min_frac=MID_MIN_VALID_FRAC,
        )
        vols_subset = frame_volumes_subset_uL(preds_4d, seg_spacing, mid_slices)
        ed_idx, es_idx = pick_ed_es_from_volumes(
            vols_subset, prefer_frame0=True, rel_tol=0.05, min_sep=1
        )
        vols_full = np.array(
            [(preds_4d[..., f] == 2).sum() * voxel_mm3 for f in range(F)],
            dtype=np.float32,
        )
        EDV_uL = float(vols_full[ed_idx])
        ESV_uL = float(vols_full[es_idx])
    else:
        vols_full = np.array(
            [(preds_4d[..., f] == 2).sum() * voxel_mm3 for f in range(F)],
            dtype=np.float32,
        )
        ed_idx, es_idx = pick_ed_es_from_volumes(
            vols_full, prefer_frame0=True, rel_tol=0.05, min_sep=1
        )
        EDV_uL = float(vols_full[ed_idx])
        ESV_uL = float(vols_full[es_idx])
    SV_uL = EDV_uL - ESV_uL
    EF_pct = (SV_uL / EDV_uL * 100.0) if EDV_uL > 0 else 0.0
    per_frame_df = compute_per_frame_metrics(preds_4d, seg_spacing)
    # Safer indexing (falls back to 0.0 if the frame row is missing)
    try:
        myo_mass_ED_mg = float(per_frame_df.loc[per_frame_df["Frame"] == ed_idx, "MyocardiumMass_mg"].values[0])
    except Exception:
        myo_mass_ED_mg = 0.0
    try:
        myo_mass_ES_mg = float(per_frame_df.loc[per_frame_df["Frame"] == es_idx, "MyocardiumMass_mg"].values[0])
    except Exception:
        myo_mass_ES_mg = 0.0
    per_frame_df.insert(0, "Patient_ID", pid)
    per_frame_rows_acc.append(per_frame_df)
    rows_acc.append(
        {
            "Patient_ID": pid,
            "EDV_uL": EDV_uL,
            "ESV_uL": ESV_uL,
            "SV_uL": SV_uL,
            "EF_%": EF_pct,
            "MyocardiumMass_ED_mg": myo_mass_ED_mg,
            "MyocardiumMass_ES_mg": myo_mass_ES_mg,
            "ED_frame_index": int(ed_idx),
            "ES_frame_index": int(es_idx),
            "Seg_SliceThickness_mm": seg_spacing["slice_thickness_mm"],
            "Seg_PixelSpacing_row_mm": seg_spacing["row_mm"],
            "Seg_PixelSpacing_col_mm": seg_spacing["col_mm"],
            "Native_SliceThickness_mm": spacing["slice_thickness_mm"],
            "Native_PixelSpacing_row_mm": spacing["row_mm"],
            "Native_PixelSpacing_col_mm": spacing["col_mm"],
        }
    )
    _p(post0 + w_island + w_metrics, "Computing metrics... done")
    # GIF
    _p(post0 + w_island + w_metrics + 0.10 * w_gif, "Creating GIF...")
    gif_path = gif_animation_for_patient_pred_only(
        imgs_4d_256, preds_4d, pid, ed_idx, es_idx, get_session_gifs_dir()
    )
    log(f"[CASE] GIF saved: {gif_path}")
    if "gif_paths" not in st.session_state:
        st.session_state["gif_paths"] = {}
    st.session_state["gif_paths"][pid] = gif_path
    # Everything later views need for this case; the auto ED/ES pair also
    # seeds the user-editable pair, which is not confirmed yet.
    case_info = dict(
        pid=pid,
        imgs_4d_256=imgs_4d_256,
        preds_4d=preds_4d,
        preds_4d_orig=preds_4d_orig,
        seg_spacing=seg_spacing,
        native_spacing=spacing,
        native_shape=(native_h, native_w, S, F),
        affine=final_aff,
        ed_idx_auto=int(ed_idx),
        es_idx_auto=int(es_idx),
        ed_idx_user=int(ed_idx),
        es_idx_user=int(es_idx),
        ed_es_confirmed=False,
        mid_slices=mid_slices,
    )
    case_info["dirty"] = False
    case_info["mask_edited"] = False
    if "cases" not in st.session_state:
        st.session_state["cases"] = {}
    st.session_state["cases"][pid] = case_info
    # ✅ Done for this file
    _p(1.0, "Done")
| # ========================= | |
| # Recompute metrics from corrected masks | |
| # ========================= | |
def recompute_metrics_from_session_cases_dirty_gif_only():
    """
    Rebuild summary and per-frame metric rows from the (possibly edited)
    masks in st.session_state["cases"].

    - Metrics are recomputed for every case (cheap; keeps the CSV consistent).
    - The ED/ES GIF is regenerated only for cases with case["dirty"] truthy;
      the new path goes to st.session_state["gif_paths_corrected"][pid] and
      the dirty flag is cleared.
    - ED/ES indices: ed_idx_user / es_idx_user when case["ed_es_confirmed"]
      is truthy, else ed_idx_auto / es_idx_auto.

    Returns (rows, per_frame_rows, updated_any_gif); ([], [], False) when the
    session holds no cases.
    """
    if not ("cases" in st.session_state and st.session_state["cases"]):
        return [], [], False

    def _myo_mass_at(table, frame):
        # 0.0 when the frame row is missing (e.g. an out-of-range index).
        try:
            return float(
                table.loc[table["Frame"] == frame, "MyocardiumMass_mg"].values[0]
            )
        except Exception:
            return 0.0

    summary_rows: List[Dict] = []
    frame_tables: List[pd.DataFrame] = []
    regenerated_gif = False
    if "gif_paths_corrected" not in st.session_state:
        st.session_state["gif_paths_corrected"] = {}

    for pid, case in st.session_state["cases"].items():
        preds_4d = case["preds_4d"]
        imgs_4d_256 = case["imgs_4d_256"]
        seg_spacing = case["seg_spacing"]
        native_spacing = case["native_spacing"]

        confirmed = case.get("ed_es_confirmed", False)
        ed_idx = int(case["ed_idx_user"] if confirmed else case["ed_idx_auto"])
        es_idx = int(case["es_idx_user"] if confirmed else case["es_idx_auto"])

        voxel_mm3 = (
            seg_spacing["row_mm"]
            * seg_spacing["col_mm"]
            * seg_spacing["slice_thickness_mm"]
        )
        n_frames = preds_4d.shape[3]
        blood_volumes = np.array(
            [(preds_4d[..., f] == 2).sum() * voxel_mm3 for f in range(n_frames)],
            dtype=np.float32,
        )
        EDV_uL = float(blood_volumes[ed_idx]) if 0 <= ed_idx < n_frames else 0.0
        ESV_uL = float(blood_volumes[es_idx]) if 0 <= es_idx < n_frames else 0.0
        SV_uL = EDV_uL - ESV_uL
        EF_pct = (SV_uL / EDV_uL * 100.0) if EDV_uL > 0 else 0.0

        per_frame_df = compute_per_frame_metrics(preds_4d, seg_spacing)
        myo_mass_ED_mg = _myo_mass_at(per_frame_df, ed_idx)
        myo_mass_ES_mg = _myo_mass_at(per_frame_df, es_idx)
        per_frame_df.insert(0, "Patient_ID", pid)
        frame_tables.append(per_frame_df)
        summary_rows.append(
            {
                "Patient_ID": pid,
                "EDV_uL": EDV_uL,
                "ESV_uL": ESV_uL,
                "SV_uL": SV_uL,
                "EF_%": EF_pct,
                "MyocardiumMass_ED_mg": myo_mass_ED_mg,
                "MyocardiumMass_ES_mg": myo_mass_ES_mg,
                "ED_frame_index": int(ed_idx),
                "ES_frame_index": int(es_idx),
                "Seg_SliceThickness_mm": seg_spacing["slice_thickness_mm"],
                "Seg_PixelSpacing_row_mm": seg_spacing["row_mm"],
                "Seg_PixelSpacing_col_mm": seg_spacing["col_mm"],
                "Native_SliceThickness_mm": native_spacing["slice_thickness_mm"],
                "Native_PixelSpacing_row_mm": native_spacing["row_mm"],
                "Native_PixelSpacing_col_mm": native_spacing["col_mm"],
            }
        )

        if not case.get("dirty", False):
            continue
        # Expensive step: only for cases whose masks changed since the last GIF.
        gif_path = gif_animation_for_patient_pred_only(
            imgs_4d_256,
            preds_4d,
            pid,
            ed_idx,
            es_idx,
            get_session_gifs_dir(),
        )
        st.session_state["gif_paths_corrected"][pid] = gif_path
        case["dirty"] = False
        regenerated_gif = True

    return summary_rows, frame_tables, regenerated_gif
| # ========================================================== | |
| # Roundel-style LV-only manual correction | |
| # ========================================================== | |
| # This section adapts the in-house Roundel biventricular editor interface | |
| # for the public Hugging Face / Streamlit app. | |
| # | |
| # Important differences from the in-house Roundel app: | |
| # - This app is LV-only: labels are 0=background, 1=LV myocardium, 2=LV blood pool. | |
| # - Images/masks come from st.session_state["cases"] after model inference. | |
| # - No persistent server-side case database is used; notes/status are session-only. | |
| # - The original Roundel visual workflow is preserved: ED/ES Finder → Mask Editor → Final Result. | |
| # ========================================================== | |
# Label values used in the LV-only masks (0 is background).
LV_MYO_LABEL = 1
LV_BP_LABEL = 2
# RGBA overlay colour per label; background (0) is fully transparent (alpha 0).
# The "ROUNDDEL" spelling is kept because the overlay and brush helpers use these names.
ROUNDDEL_LV_OVERLAY_COLORS = {
    0: (200, 200, 200, 0),
    LV_MYO_LABEL: (0, 0, 255, 70),  # LV myocardium = blue
    LV_BP_LABEL: (255, 10, 10, 90),  # LV blood pool = red
}
# Display names for the paint-brush label selector.
ROUNDDEL_LV_BRUSH_LABELS = {
    LV_MYO_LABEL: "LV Myocardium 🔵",
    LV_BP_LABEL: "LV Blood Pool 🔴",
}
def _streamlit_segmented_control(label, options, default=None, key=None, label_visibility="visible"):
    """
    Render a single-choice horizontal selector across Streamlit versions.

    - New Streamlit: st.segmented_control
    - Older Streamlit (or if segmented_control rejects the kwargs): st.radio

    Parameters
    ----------
    label : str
        Widget label.
    options : sequence
        Selectable options.
    default : optional
        Initially selected option. Falls back to ``options[0]`` when None or
        not one of ``options``.
    key : str, optional
        Session-state key. When given, st.session_state[key] is initialised to
        ``default`` once (or reset when it holds a value outside ``options``),
        and NO default/index is passed to the widget. This avoids Streamlit
        warnings/errors when other code switches the view programmatically.
        When no key is given, ``default`` is passed to the widget directly.
    label_visibility : str
        Forwarded to the widget.

    Returns
    -------
    The selected option, as returned by the widget.
    """
    options = list(options)
    if default is None or default not in options:
        default = options[0]
    if key is not None:
        if key not in st.session_state:
            st.session_state[key] = default
        elif st.session_state[key] not in options:
            st.session_state[key] = default
    # Without a session-state key the widget itself must carry the default,
    # otherwise it would ignore the caller's ``default`` argument.
    seg_default = {} if key is not None else {"default": default}
    radio_default = {} if key is not None else {"index": options.index(default)}
    if hasattr(st, "segmented_control"):
        try:
            return st.segmented_control(
                label,
                options=options,
                key=key,
                label_visibility=label_visibility,
                **seg_default,
            )
        except TypeError:
            pass
    return st.radio(
        label,
        options=options,
        horizontal=True,
        key=key,
        label_visibility=label_visibility,
        **radio_default,
    )
def _roundel_case_label(pid: str) -> str:
    """
    Build the case-selector label for ``pid``.

    The ID is prefixed with a review-status icon (complete / not complete)
    and followed by a note icon when the case has a non-blank manual note.
    """
    all_cases = st.session_state.get("cases", {}) or {}
    info = all_cases.get(pid, {})
    prefix = "✅" if bool(info.get("manual_review_done", False)) else "⬜"
    has_note = bool((info.get("manual_note", "") or "").strip())
    suffix = " 📝" if has_note else ""
    return f"{prefix} {pid}{suffix}"
| def _roundel_normalize(image): | |
| image = np.asarray(image, dtype=np.float32) | |
| return (image - np.min(image)) / (np.max(image) - np.min(image) + 1e-9) | |
def _roundel_get_overlay(image_slice_u8, mask_slice, channels_to_show):
    """
    Draw coloured label overlays on a grayscale slice (Roundel-style).

    Parameters
    ----------
    image_slice_u8 : 2D uint8 array (256x256 in this app)
        Grayscale image slice.
    mask_slice : 2D label array, same shape as ``image_slice_u8``
        Label mask (0 = background, LV_MYO_LABEL, LV_BP_LABEL).
    channels_to_show : iterable of int
        Labels to draw; each must have an entry in ROUNDDEL_LV_OVERLAY_COLORS.

    Returns
    -------
    PIL.Image.Image
        RGB image of the slice with the requested labels alpha-composited on top.
    """
    base = Image.fromarray(np.stack([image_slice_u8] * 3, axis=-1)).convert("RGBA")
    H, W = image_slice_u8.shape
    # Each pixel holds exactly one label, so all requested labels can share a
    # single RGBA layer and be composited once. Duplicate channels are drawn once.
    overlay = np.zeros((H, W, 4), dtype=np.uint8)
    for ch in dict.fromkeys(int(c) for c in channels_to_show):
        ch_mask = mask_slice == ch
        if np.any(ch_mask):
            overlay[ch_mask] = np.array(ROUNDDEL_LV_OVERLAY_COLORS[ch], dtype=np.uint8)
    if overlay[..., 3].any():
        # RGBA mode is inferred from the (H, W, 4) uint8 array; passing
        # ``mode=`` to Image.fromarray is deprecated in recent Pillow.
        base = Image.alpha_composite(base, Image.fromarray(overlay))
    return base.convert("RGB")
def _roundel_overlay_selector_ui():
    """
    Render the LV-only overlay checkboxes and return the labels to draw.

    "Overlapped All" (on by default) forces both labels and disables the
    per-label checkboxes. If neither per-label box is ticked, both labels are
    still returned, so the overlay list is never empty.
    """
    for state_key in (
        "roundel_overlay_overlapped_all",
        "roundel_overlay_lv_bp",
        "roundel_overlay_lv_myo",
    ):
        st.session_state.setdefault(state_key, True)
    show_all = st.checkbox("Overlapped All", key="roundel_overlay_overlapped_all")
    left, right = st.columns(2)
    with left:
        want_bp = st.checkbox("LV Blood Pool", key="roundel_overlay_lv_bp", disabled=show_all)
    with right:
        want_myo = st.checkbox("LV Myocardium", key="roundel_overlay_lv_myo", disabled=show_all)
    both = [LV_MYO_LABEL, LV_BP_LABEL]
    if show_all:
        return both
    picked = [
        label
        for label, wanted in ((LV_MYO_LABEL, want_myo), (LV_BP_LABEL, want_bp))
        if wanted
    ]
    return picked or both
def _roundel_select_brush():
    """
    Render the LV-only brush controls and return the current brush settings.

    Widgets, in order: paint/erase mode, width preset, free width slider
    (6-40 px) and, in paint mode only, the label to paint. Changing the preset
    snaps the slider to that preset's width.

    Returns
    -------
    (channel, action, stroke_width) : (int, str, int)
        ``channel`` is the label to paint (0 when erasing), ``action`` is the
        selected mode string and ``stroke_width`` is the slider value.
    """
    modes = ["Paint ✏️", "Erase ✂️"]
    widths = {"thin": 6, "medium": 20, "thick": 40}
    presets = list(widths.keys())

    for state_key, initial in (
        ("roundel_brush_mode", "Paint ✏️"),
        ("roundel_stroke_width_preset", "thin"),
        ("roundel_stroke_width_value", 6),
    ):
        if state_key not in st.session_state:
            st.session_state[state_key] = initial

    action = st.radio(
        "Brush Stroke Selection",
        options=modes,
        index=modes.index(st.session_state["roundel_brush_mode"]),
        horizontal=True,
        key="roundel_brush_mode_radio",
    )
    st.session_state["roundel_brush_mode"] = action

    preset = st.radio(
        "Stroke Width (Preset)",
        options=presets,
        index=presets.index(st.session_state.get("roundel_stroke_width_preset", "thin")),
        horizontal=True,
        key="roundel_stroke_width_preset_radio",
    )
    st.session_state["roundel_stroke_width_preset"] = preset
    preset_width = widths[preset]

    # Snap the free slider to the preset only when the preset has just changed.
    if st.session_state.get("_roundel_last_stroke_preset", None) != preset:
        st.session_state["roundel_stroke_width_value"] = preset_width
        st.session_state["_roundel_last_stroke_preset"] = preset

    current_width = st.session_state.get("roundel_stroke_width_value", preset_width)
    st.session_state["roundel_stroke_width_value"] = int(np.clip(current_width, 6, 40))
    width = st.slider(
        "Stroke Width",
        min_value=6,
        max_value=40,
        step=1,
        key="roundel_stroke_width_value",
    )

    if action != "Paint ✏️":
        return 0, action, int(width)
    label = st.radio(
        "Mask",
        options=[LV_MYO_LABEL, LV_BP_LABEL],
        format_func=lambda x: ROUNDDEL_LV_BRUSH_LABELS[x],
        index=0,
        horizontal=True,
        key="roundel_brush_label_radio",
    )
    return int(label), action, int(width)
| def _extract_stroke_mask_from_canvas(canvas_rgba, background_rgb, diff_thresh=10): | |
| """ | |
| Roundel-style stroke extraction by comparing the returned canvas image | |
| against the exact background image passed to st_canvas. | |
| canvas_rgba: numpy (Hc,Wc,4) from canvas_result.image_data | |
| background_rgb: numpy (Hc,Wc,3) from the SAME background passed to st_canvas | |
| Returns: | |
| binary uint8 (Hc,Wc), where 1 == newly drawn stroke pixels. | |
| """ | |
| if canvas_rgba is None or background_rgb is None: | |
| return None | |
| canvas_rgb = canvas_rgba[:, :, :3].astype(np.int16) | |
| bg_rgb = background_rgb[:, :, :3].astype(np.int16) | |
| alpha = canvas_rgba[:, :, 3].astype(np.int16) | |
| diff = np.max(np.abs(canvas_rgb - bg_rgb), axis=-1) | |
| stroke = (diff >= int(diff_thresh)) & (alpha > 0) | |
| return stroke.astype(np.uint8) | |
| def _resize_binary_mask(mask2d, target_h, target_w): | |
| if mask2d is None: | |
| return None | |
| if mask2d.shape[0] == target_h and mask2d.shape[1] == target_w: | |
| return mask2d.astype(np.uint8) | |
| pil = Image.fromarray((mask2d > 0).astype(np.uint8) * 255) | |
| pil = pil.resize((int(target_w), int(target_h)), resample=Image.NEAREST) | |
| return (np.array(pil) > 0).astype(np.uint8) | |
def build_processed_stroke_and_enclosed_region(strokes, stroke_width, do_fill_holes: bool):
    """
    Clean up a drawn stroke and find the region it encloses.

    Adapted from roundel_biventricular_utils.py. The stroke is binarised,
    morphologically closed and dilated (``max(1, round(stroke_width / 8))``
    iterations each), hole-filled, and finally eroded by the same amount.

    Returns
    -------
    processed_mask : uint8 binary mask
        Mask used for painting/erasing: the eroded filled shape when
        ``do_fill_holes`` is True, otherwise the eroded thickened outline.
    enclosed_region : uint8 binary mask
        Pixels inside the (thickened) contour but not on it. For myocardium
        painting this lets an existing LV blood pool inside a drawn ring be kept.

    If ``strokes`` is None, (None, None) is returned. If it has no set pixels,
    it is returned unchanged together with an all-zero uint8 mask.
    """
    if strokes is None:
        return None, None
    if not np.any(strokes):
        return strokes, np.zeros_like(strokes, dtype=np.uint8)

    binary = (strokes > 0).astype(np.uint8)
    n_iter = max(1, int(round(stroke_width / 8)))

    closed = binary_closing(binary, iterations=n_iter)
    thick = binary_dilation(closed, iterations=n_iter)
    filled = binary_fill_holes(thick).astype(np.uint8)
    enclosed = np.logical_and(filled > 0, thick == 0).astype(np.uint8)

    base = filled if do_fill_holes else thick
    processed = binary_erosion(base, iterations=n_iter).astype(np.uint8)
    return processed, enclosed
| def _roundel_apply_stroke_to_label_slice(mask_slice, stroke_mask, enclosed_region, channel, action): | |
| """ | |
| LV-only label-mask painting/erasing. | |
| mask_slice labels: | |
| 0 = background | |
| 1 = LV myocardium | |
| 2 = LV blood pool | |
| """ | |
| updated = mask_slice.copy() | |
| pm = (stroke_mask > 0) | |
| if not np.any(pm): | |
| return updated, False | |
| if action == "Erase ✂️" or int(channel) == 0: | |
| updated[pm] = 0 | |
| return updated, True | |
| channel = int(channel) | |
| if channel == LV_MYO_LABEL: | |
| # Preserve existing blood pool if user draws a myocardium ring around it. | |
| if enclosed_region is not None and np.any(enclosed_region): | |
| protected_bp = (enclosed_region.astype(bool)) & (updated == LV_BP_LABEL) | |
| writable = pm & (~protected_bp) | |
| else: | |
| writable = pm | |
| updated[writable] = LV_MYO_LABEL | |
| return updated, bool(np.any(writable)) | |
| if channel == LV_BP_LABEL: | |
| updated[pm] = LV_BP_LABEL | |
| return updated, True | |
| return updated, False | |
def _roundel_frame_index_slider(T, frames, initial_idx, label, disabled_flag, key):
    """
    LV-only Roundel ED/ES frame slider with a montage preview of the chosen frame.

    Indices are 0-based, as in the in-house Roundel app. The slider state
    lives in st.session_state[key]: it is seeded with ``initial_idx`` the first
    time, and clipped to [0, T - 1] on every call.

    Important:
        No value=... is passed to the slider because its key is initialised via
        st.session_state. Passing both makes Streamlit warn: "The widget with key
        ... was created with a default value but also had its value set via the
        Session State API."

    Returns
    -------
    int
        The selected frame index.
    """
    last = T - 1
    start = int(np.clip(initial_idx, 0, last))
    if key in st.session_state:
        st.session_state[key] = int(np.clip(st.session_state[key], 0, last))
    else:
        st.session_state[key] = start
    shown = int(st.session_state[key])
    chosen = st.slider(
        f"{label} | *{shown}*",
        min_value=0,
        max_value=last,
        key=key,
        disabled=disabled_flag,
    )
    chosen = int(np.clip(chosen, 0, last))
    st.image(frames[chosen], use_column_width=True)
    return chosen
def _roundel_build_all_slice_preview_frames(case):
    """
    Build one all-slice montage per timeframe for a case.

    Reads ``case["imgs_4d_256"]`` and ``case["preds_4d"]``. ``preds_4d`` must be
    4D; its last two axes are taken as slice and frame. Each montage shows every
    slice, 4 tiles per row, with the label overlay on.

    Returns
    -------
    list
        One ``build_frame_montage`` result per frame, in frame order.
    """
    images = case["imgs_4d_256"]
    labels = case["preds_4d"]
    _, _, n_slices, n_frames = labels.shape
    every_slice = list(range(n_slices))
    return [
        build_frame_montage(
            images,
            labels,
            frame_idx=frame,
            slice_indices=every_slice,
            tile_cols=4,
            add_overlay=True,
        )
        for frame in range(n_frames)
    ]
def _roundel_confirm_ed_es_frames_callback(selected_pid, ed_idx, es_idx):
    """
    Button callback: store the chosen ED/ES frames and open the Mask Editor.

    Marks the case as confirmed and dirty, then invalidates download caches.

    This runs as an on_click callback, i.e. before Streamlit recreates the
    segmented-control widget on the next rerun. That is why it may write the
    "roundel_view_segmented" key without triggering the StreamlitAPIException
    raised when a widget key is modified after creation.
    """
    cases = st.session_state.get("cases", {})
    if selected_pid not in cases:
        return
    case = cases[selected_pid]
    case["ed_idx_user"] = int(ed_idx)
    case["es_idx_user"] = int(es_idx)
    case.update(ed_es_confirmed=True, dirty=True)
    invalidate_download_caches()
    # Jump straight to the editor once the frames are confirmed.
    st.session_state["roundel_view_segmented"] = "Mask Editor 📝"
def _roundel_change_slice_callback(selected_pid, delta, max_slices):
    """
    Button callback for the slice Previous/Next buttons.

    Shifts ``roundel_slice_idx_{pid}`` by ``delta``, clipped to
    [0, max_slices - 1]. Running before the slider is recreated is what makes
    writing the slider's session-state key allowed.
    """
    slice_key = f"roundel_slice_idx_{selected_pid}"
    stepped = int(st.session_state.get(slice_key, 0)) + int(delta)
    st.session_state[slice_key] = int(np.clip(stepped, 0, int(max_slices) - 1))
def _roundel_change_frame_callback(selected_pid, delta, max_frames):
    """
    Button callback for the frame Previous/Next buttons.

    Shifts ``roundel_frame_idx_{pid}`` by ``delta`` (clipped to
    [0, max_frames - 1]) and sets ``roundel_frame_mode_{pid}`` to "Custom".
    """
    idx_key = f"roundel_frame_idx_{selected_pid}"
    stepped = int(st.session_state.get(idx_key, 0)) + int(delta)
    st.session_state[idx_key] = int(np.clip(stepped, 0, int(max_frames) - 1))
    st.session_state[f"roundel_frame_mode_{selected_pid}"] = "Custom"
def _roundel_reselect_ed_es_frames_callback(selected_pid):
    """
    Button callback: unlock ED/ES frame selection for ``selected_pid``.

    Runs before the widgets of the next rerun are created, so it may set the
    "roundel_view_segmented" widget key. Doing that in the button's body would
    raise StreamlitAPIException, because the segmented control has already been
    instantiated in that run.
    """
    case = (st.session_state.get("cases", {}) or {}).get(selected_pid)
    if case is None:
        return
    case["ed_es_confirmed"] = False
    st.session_state["roundel_view_segmented"] = "EDV/ESV Finder 🔍"


def roundel_lv_edv_esv_view(selected_pid, case):
    """
    Roundel-style EDV/ESV Finder, adapted to LV-only.

    Shows an ED and an ES frame slider, each with an all-slice montage preview.
    The montages are built once per case and cached in
    st.session_state["roundel_preview_frames_{pid}"]. After confirmation the
    sliders are disabled until "Reselect ED/ES Frames" is pressed.
    """
    preds_4d = case["preds_4d"]
    _, _, _, F = preds_4d.shape
    preview_key = f"roundel_preview_frames_{selected_pid}"
    if preview_key not in st.session_state:
        st.session_state[preview_key] = _roundel_build_all_slice_preview_frames(case)
    frames = st.session_state[preview_key]
    disabled_flag = bool(case.get("ed_es_confirmed", False))
    display_lv_dia_idx = int(case.get("ed_idx_user", case.get("ed_idx_auto", 0)))
    display_lv_sys_idx = int(case.get("es_idx_user", case.get("es_idx_auto", 0)))
    st.markdown("#### Left Ventricle")
    col_edv, col_esv = st.columns(2)
    with col_edv:
        lv_dia_idx = _roundel_frame_index_slider(
            F,
            frames,
            display_lv_dia_idx,
            "LV End-Diastolic Index",
            disabled_flag,
            key=f"roundel_lv_edv_{selected_pid}",
        )
    with col_esv:
        lv_sys_idx = _roundel_frame_index_slider(
            F,
            frames,
            display_lv_sys_idx,
            "LV End-Systolic Index",
            disabled_flag,
            key=f"roundel_lv_esv_{selected_pid}",
        )
    st.write("")
    if not disabled_flag:
        st.button(
            "Confirm ED & ES Frames",
            type="primary",
            use_container_width=True,
            key=f"roundel_confirm_ed_es_{selected_pid}",
            on_click=_roundel_confirm_ed_es_frames_callback,
            args=(selected_pid, int(lv_dia_idx), int(lv_sys_idx)),
        )
    else:
        st.success("ED & ES frames confirmed!")
        # on_click (not an `if st.button(...)` body): the callback runs before
        # the view switcher is recreated, so it can safely write its key.
        st.button(
            "Reselect ED/ES Frames",
            use_container_width=True,
            key=f"roundel_reselect_ed_es_{selected_pid}",
            on_click=_roundel_reselect_ed_es_frames_callback,
            args=(selected_pid,),
        )
def _roundel_ensure_nav_state(selected_pid):
    """
    Seed the mask editor's per-case navigation keys, if missing.

    Sets slice index 0, frame index 0, frame mode "End-Diastole" and canvas
    nonce 0. Existing values are left untouched.
    """
    defaults = {
        f"roundel_slice_idx_{selected_pid}": 0,
        f"roundel_frame_idx_{selected_pid}": 0,
        f"roundel_frame_mode_{selected_pid}": "End-Diastole",
        f"roundel_canvas_nonce_{selected_pid}": 0,
    }
    for state_key, value in defaults.items():
        st.session_state.setdefault(state_key, value)
def roundel_lv_mask_editor_view(selected_pid, case):
    """
    Roundel-style LV mask editor for one case.

    Layout:
      - col1: brush controls and slice/frame navigation.
      - col2: drawing canvas ("Editor") or the raw slice ("Viewer").
      - col3: the corrected overlay of the current slice ("Viewer") or the
        cached all-slice montage of the ED or ES frame ("Static").

    Saved strokes and "Clear Slice" write directly into ``case["preds_4d"]``.
    They mark the case dirty/edited, drop the cached preview montages and
    invalidate download caches.

    Frame navigation:
      - The "Frame" radio is bound to the ``roundel_frame_mode_{pid}`` key
        itself, without index=..., so the Previous/Next frame callbacks (which
        set that key to "Custom") really switch the radio.
      - In End-Diastole/End-Systole mode the frame index is forced to the
        confirmed ED/ES frame on every rerun, so the Frame Index slider is
        read-only there.
    """
    if not case.get("ed_es_confirmed", False):
        st.error("Select and confirm ED & ES frames first.")
        return
    imgs_4d_256 = case["imgs_4d_256"]
    preds_4d = case["preds_4d"]
    _, _, S, F = preds_4d.shape
    _roundel_ensure_nav_state(selected_pid)
    frame_mode_options = ["End-Diastole", "End-Systole", "Custom"]
    nonce_key = f"roundel_canvas_nonce_{selected_pid}"
    col1, col2, col3 = st.columns([1, 1.5, 1.5])
    with col1:
        st.markdown("#### Left Ventricle")
        st.caption("LV-only editor for labels: myocardium + blood pool.")
        channel, action, stroke_width = _roundel_select_brush()
        st.caption("Image Selection")
        lv_dia_idx = int(case["ed_idx_user"])
        lv_sys_idx = int(case["es_idx_user"])
        st.caption(f"LV ED = **{lv_dia_idx}** | LV ES = **{lv_sys_idx}**")
        slice_key = f"roundel_slice_idx_{selected_pid}"
        frame_key = f"roundel_frame_idx_{selected_pid}"
        mode_key = f"roundel_frame_mode_{selected_pid}"
        st.slider("Slice Index", 0, S - 1, key=slice_key)
        col_prev_s, col_next_s = st.columns(2)
        with col_prev_s:
            st.button(
                "Previous",
                use_container_width=True,
                key=f"roundel_slice_prev_btn_{selected_pid}",
                on_click=_roundel_change_slice_callback,
                args=(selected_pid, -1, S),
            )
        with col_next_s:
            st.button(
                "Next",
                use_container_width=True,
                key=f"roundel_slice_next_btn_{selected_pid}",
                on_click=_roundel_change_slice_callback,
                args=(selected_pid, 1, S),
            )
        # The radio owns mode_key, so callbacks that set it before this run's
        # widgets exist take effect. Repair a stale value before creation (a
        # write after creation would raise StreamlitAPIException).
        if st.session_state.get(mode_key) not in frame_mode_options:
            st.session_state[mode_key] = "End-Diastole"
        frame_mode = st.radio(
            "Frame",
            options=frame_mode_options,
            horizontal=False,
            key=mode_key,
        )
        if frame_mode == "End-Diastole":
            st.session_state[frame_key] = lv_dia_idx
        elif frame_mode == "End-Systole":
            st.session_state[frame_key] = lv_sys_idx
        else:
            st.session_state[frame_key] = int(np.clip(st.session_state[frame_key], 0, F - 1))
        # In ED/ES mode the index was just forced above, so dragging would be
        # undone on the next rerun: keep the slider read-only outside "Custom".
        st.slider("Frame Index", 0, F - 1, key=frame_key, disabled=(frame_mode != "Custom"))
        col_prev_f, col_next_f = st.columns(2)
        with col_prev_f:
            st.button(
                "Previous",
                use_container_width=True,
                key=f"roundel_frame_prev_btn_{selected_pid}",
                on_click=_roundel_change_frame_callback,
                args=(selected_pid, -1, F),
            )
        with col_next_f:
            st.button(
                "Next",
                use_container_width=True,
                key=f"roundel_frame_next_btn_{selected_pid}",
                on_click=_roundel_change_frame_callback,
                args=(selected_pid, 1, F),
            )
    d = int(np.clip(st.session_state[f"roundel_slice_idx_{selected_pid}"], 0, S - 1))
    t = int(np.clip(st.session_state[f"roundel_frame_idx_{selected_pid}"], 0, F - 1))
    image_slice = imgs_4d_256[:, :, d, t]
    image_slice_u8 = (_roundel_normalize(image_slice) * 255).astype(np.uint8)
    mask_slice = preds_4d[:, :, d, t].astype(np.uint8)
    with col2:
        edit_mode = st.radio(
            "Segmentation Editor",
            ["Editor", "Viewer"],
            index=0,
            horizontal=True,
            key=f"roundel_seg_editor_mode_{selected_pid}",
        )
        channels_to_show = _roundel_overlay_selector_ui()
        if edit_mode == "Viewer":
            st.image(image_slice_u8, width=DISPLAY_W)
        else:
            stroke_color = (
                "rgba(200, 200, 200, 0.55)"
                if action == "Erase ✂️"
                else "rgba(0, 0, 255, 0.85)" if channel == LV_MYO_LABEL
                else "rgba(255, 10, 10, 0.85)"
            )
            bg_img = _roundel_get_overlay(image_slice_u8, mask_slice, channels_to_show)
            canvas_w = int(DISPLAY_W)
            canvas_h = int(SIZE_Y * DISPLAY_W / SIZE_X)
            bg_img_for_canvas = bg_img.resize((canvas_w, canvas_h), resample=Image.NEAREST)
            bg_np = np.array(bg_img_for_canvas, dtype=np.uint8)
            # The nonce changes after every save/clear, which gives st_canvas a
            # fresh key and so an empty drawing layer.
            canvas_key = (
                f"roundel_canvas_{selected_pid}_{d}_{t}_"
                f"{st.session_state.get(nonce_key, 0)}"
            )
            canvas_result = st_canvas(
                stroke_width=stroke_width,
                stroke_color=stroke_color,
                background_image=bg_img_for_canvas,
                update_streamlit=True,
                height=canvas_h,
                width=canvas_w,
                drawing_mode="freedraw",
                key=canvas_key,
            )
            col_save, col_clear = st.columns([1, 0.35])
            with col_save:
                save_contour = st.button(
                    "Save Contour",
                    type="primary",
                    use_container_width=True,
                    key=f"roundel_save_contour_btn_{selected_pid}",
                )
                if save_contour and canvas_result and (canvas_result.image_data is not None):
                    canvas_rgba = np.array(canvas_result.image_data, dtype=np.uint8)
                    stroke_mask_canvas = _extract_stroke_mask_from_canvas(canvas_rgba, bg_np, diff_thresh=10)
                    if stroke_mask_canvas is None or not np.any(stroke_mask_canvas):
                        st.warning("No new strokes detected.")
                    else:
                        stroke_mask = _resize_binary_mask(stroke_mask_canvas, target_h=SIZE_Y, target_w=SIZE_X)
                        # Match Roundel behaviour: paint strokes can be closed/filled.
                        # For myocardium, enclosed existing blood pool is protected below.
                        do_fill = (action == "Paint ✏️") and (channel in (LV_MYO_LABEL, LV_BP_LABEL))
                        stroke_mask, enclosed_region = build_processed_stroke_and_enclosed_region(
                            stroke_mask,
                            stroke_width,
                            do_fill_holes=do_fill,
                        )
                        updated_slice, changed = _roundel_apply_stroke_to_label_slice(
                            mask_slice=preds_4d[:, :, d, t].astype(np.uint8),
                            stroke_mask=stroke_mask,
                            enclosed_region=enclosed_region,
                            channel=channel,
                            action=action,
                        )
                        if changed:
                            preds_4d[:, :, d, t] = updated_slice
                            case["preds_4d"] = preds_4d
                            case["dirty"] = True
                            case["mask_edited"] = True
                            st.session_state.pop(f"roundel_preview_frames_{selected_pid}", None)
                            invalidate_download_caches()
                            st.success("Contour saved.")
                        else:
                            st.warning("Stroke was detected, but no writable mask pixels were changed.")
                        st.session_state[nonce_key] = int(st.session_state.get(nonce_key, 0)) + 1
                        _safe_rerun()
            with col_clear:
                if st.button("Clear Slice", use_container_width=True, key=f"roundel_clear_slice_btn_{selected_pid}"):
                    preds_4d[:, :, d, t] = 0
                    case["preds_4d"] = preds_4d
                    case["dirty"] = True
                    case["mask_edited"] = True
                    st.session_state.pop(f"roundel_preview_frames_{selected_pid}", None)
                    invalidate_download_caches()
                    st.session_state[nonce_key] = int(st.session_state.get(nonce_key, 0)) + 1
                    st.success("LV mask cleared for this slice/frame.")
                    _safe_rerun()
    with col3:
        view_mode = st.radio(
            "Corrected Mask",
            ["Static", "Viewer"],
            index=0,
            horizontal=True,
            key=f"roundel_corrected_mask_view_mode_{selected_pid}",
        )
        if view_mode == "Viewer":
            corrected = _roundel_get_overlay(
                image_slice_u8,
                preds_4d[:, :, d, t].astype(np.uint8),
                [LV_MYO_LABEL, LV_BP_LABEL],
            )
            st.image(corrected, width=DISPLAY_W)
        else:
            if f"roundel_preview_frames_{selected_pid}" not in st.session_state:
                st.session_state[f"roundel_preview_frames_{selected_pid}"] = _roundel_build_all_slice_preview_frames(case)
            frames = st.session_state[f"roundel_preview_frames_{selected_pid}"]
            ed = int(case["ed_idx_user"])
            es = int(case["es_idx_user"])
            # Static view shows the ES montage when on the ES frame, else the ED one.
            if t == es:
                st.image(frames[es], width=int(DISPLAY_W * 1.5))
            else:
                st.image(frames[ed], width=int(DISPLAY_W * 1.5))
def _roundel_metrics_for_case(case):
    """
    Compute LV function metrics for one case from its current label masks.

    ED/ES indices come from the user when ``case["ed_es_confirmed"]`` is set,
    otherwise from the automatic detection. They are clipped to the frame range.

    - Volume of a frame = blood-pool voxel count * voxel volume, where voxel
      volume is row_mm * col_mm * slice_thickness_mm (mm³, i.e. µL).
    - Mass of a frame = myocardium voxel count * voxel volume * MYO_DENSITY.

    Returns
    -------
    dict with keys ed_idx, es_idx, edv_ul, esv_ul, sv_ul, ef_percent,
    mass_ed_mg, mass_es_mg (EF is 0.0 when EDV is 0).
    """
    labels = case["preds_4d"]
    spacing = case["seg_spacing"]
    confirmed = case.get("ed_es_confirmed", False)
    ed_idx = int(case["ed_idx_user"]) if confirmed else int(case["ed_idx_auto"])
    es_idx = int(case["es_idx_user"]) if confirmed else int(case["es_idx_auto"])
    voxel_mm3 = (
        float(spacing["row_mm"])
        * float(spacing["col_mm"])
        * float(spacing["slice_thickness_mm"])
    )
    n_frames = labels.shape[3]
    ed_idx = int(np.clip(ed_idx, 0, n_frames - 1))
    es_idx = int(np.clip(es_idx, 0, n_frames - 1))
    # Per-frame voxel counts over (row, col, slice), one entry per frame.
    bp_counts = (labels == LV_BP_LABEL).sum(axis=(0, 1, 2))
    myo_counts = (labels == LV_MYO_LABEL).sum(axis=(0, 1, 2))
    vols = np.asarray(bp_counts * voxel_mm3, dtype=np.float64)
    masses = np.asarray(myo_counts * voxel_mm3 * MYO_DENSITY, dtype=np.float64)
    edv = float(vols[ed_idx])
    esv = float(vols[es_idx])
    sv = float(edv - esv)
    ef = float((sv / edv * 100.0) if edv > 0 else 0.0)
    return {
        "ed_idx": ed_idx,
        "es_idx": es_idx,
        "edv_ul": edv,
        "esv_ul": esv,
        "sv_ul": sv,
        "ef_percent": ef,
        "mass_ed_mg": float(masses[ed_idx]),
        "mass_es_mg": float(masses[es_idx]),
    }
def _render_corrected_downloads_group():
    """
    Render the "Updated Results" download buttons: corrected CSV, masks, GIFs.

    Does nothing until a corrected CSV (bytes and name) is in session state.
    The mask and GIF ZIPs are built on demand when their cached bytes are missing.
    """
    state = st.session_state
    if not all(k in state for k in ("csv_bytes_corrected", "csv_name_corrected")):
        return
    st.markdown("#### Updated Results")
    if "mask_zip_bytes_corrected" not in state:
        build_corrected_masks_zip_bytes_per_mouse()
    if "gif_zip_bytes_corrected" not in state:
        build_corrected_gif_zip_bytes()
    # (label, bytes key, filename key, mime type, widget key)
    specs = [
        ("Download CSV (corrected)", "csv_bytes_corrected", "csv_name_corrected", "text/csv", "corr_csv"),
        ("Download Masks (corrected)", "mask_zip_bytes_corrected", "mask_zip_name_corrected", "application/zip", "corr_masks"),
        ("Download GIF (corrected)", "gif_zip_bytes_corrected", "gif_zip_name_corrected", "application/zip", "corr_gif"),
    ]
    buttons = [
        {
            "label": label,
            "bytes": state[bytes_key],
            "filename": state[name_key],
            "mime": mime,
            "key": widget_key,
        }
        for label, bytes_key, name_key, mime, widget_key in specs
    ]
    render_download_group(buttons, group_key="corr")
def _roundel_recompute_corrected_outputs():
    """
    Recompute metrics from the session cases and refresh the corrected CSV.

    Wraps the existing corrected-output pipeline:
      1. recompute_metrics_from_session_cases_dirty_gif_only() builds the
         metric rows and regenerates GIFs only for cases flagged dirty.
      2. write_all_in_one_csv() writes them into the session CSV directory.

    The CSV bytes and download name are then stored in
    st.session_state["csv_bytes_corrected"] and ["csv_name_corrected"].
    Shows an error and returns early when there are no rows.
    """
    new_rows, new_per_frame_rows, updated_any_gif = recompute_metrics_from_session_cases_dirty_gif_only()
    if not new_rows:
        st.error("No cases available to recompute metrics.")
        return
    csv_path = write_all_in_one_csv(new_rows, new_per_frame_rows, get_session_csv_dir())
    last_zip_name = st.session_state.get("_last_zip_name")
    base_name = (Path(last_zip_name).stem + "_Results_corrected") if last_zip_name else "Results"
    csv_bytes = Path(csv_path).read_bytes()
    # Invalidate stale download caches BEFORE storing the fresh CSV, so the new
    # CSV survives even if invalidate_download_caches() also clears the
    # corrected-CSV keys. NOTE(review): confirm which keys it clears.
    invalidate_download_caches()
    st.session_state["csv_bytes_corrected"] = csv_bytes
    st.session_state["csv_name_corrected"] = f"{base_name}.csv"
    st.success(
        "CSV updated. Regenerated results for edited mice."
        if updated_any_gif
        else "CSV updated. No edited mice detected — GIFs unchanged."
    )
def _roundel_mark_review_done_callback(selected_pid):
    """
    Button callback: mark ``selected_pid`` as reviewed.

    Also sets the "Status: Complete" checkbox key (``roundel_case_done__{pid}``).
    The editor tab renders that checkbox before this view and copies its value
    into case["manual_review_done"] on every rerun. Setting only the case flag
    would therefore be undone on the next rerun. Running as a callback (before
    widgets are created) is what allows writing the checkbox key.
    """
    case = (st.session_state.get("cases", {}) or {}).get(selected_pid)
    if case is None:
        return
    case["manual_review_done"] = True
    st.session_state[f"roundel_case_done__{selected_pid}"] = True


def roundel_lv_final_result_view(selected_pid, case):
    """
    Roundel-style final result view, LV-only.

    Shows the case's LV metrics (from the current masks and ED/ES frames) next
    to its corrected ED/ES GIF. The GIF is regenerated when the case is dirty or
    has none yet, and the dirty flag is then cleared. Also offers recomputing
    the corrected CSV/GIFs, the corrected downloads, and marking the review
    complete.
    """
    if not case.get("ed_es_confirmed", False):
        st.error("Select and confirm ED & ES frames first.")
        return
    metrics = _roundel_metrics_for_case(case)
    ed_idx = int(metrics["ed_idx"])
    es_idx = int(metrics["es_idx"])
    # Generate/update this case's GIF if needed.
    if case.get("dirty", False) or selected_pid not in (st.session_state.get("gif_paths_corrected", {}) or {}):
        if "gif_paths_corrected" not in st.session_state:
            st.session_state["gif_paths_corrected"] = {}
        gif_path = gif_animation_for_patient_pred_only(
            case["imgs_4d_256"],
            case["preds_4d"],
            selected_pid,
            ed_idx,
            es_idx,
            get_session_gifs_dir(),
        )
        st.session_state["gif_paths_corrected"][selected_pid] = gif_path
        case["dirty"] = False
    col_metrics, col_gif = st.columns([0.30, 0.70])
    with col_metrics:
        st.caption("LV Metrics")
        st.metric("EDV", f"{metrics['edv_ul']:.1f} µL")
        st.metric("ESV", f"{metrics['esv_ul']:.1f} µL")
        st.metric("SV", f"{metrics['sv_ul']:.1f} µL")
        st.metric("EF", f"{metrics['ef_percent']:.1f} %")
        st.metric("Mass (ED)", f"{metrics['mass_ed_mg']:.1f} mg")
        st.metric("Mass (ES)", f"{metrics['mass_es_mg']:.1f} mg")
    with col_gif:
        st.caption("Final LV Mask")
        gif_path = (st.session_state.get("gif_paths_corrected", {}) or {}).get(selected_pid)
        if gif_path and os.path.exists(gif_path):
            with open(gif_path, "rb") as fg:
                st.image(fg.read(), caption=f"{selected_pid} – ED/ES GIF (corrected)", use_column_width=True)
        else:
            st.info("No corrected GIF available yet.")
    st.markdown("---")
    if st.button("Recompute cardiac function & regenerate CSV/GIFs from corrected masks", type="primary", use_container_width=True, key="roundel_recompute_metrics_btn"):
        with st.spinner("Recomputing metrics and regenerating CSV/GIFs from corrected masks..."):
            _roundel_recompute_corrected_outputs()
    _render_corrected_downloads_group()
    # The callback persists the flag (case + checkbox key); the button's
    # return value is only used to show the confirmation on that rerun.
    if st.button(
        "Mark review complete ✅",
        use_container_width=True,
        key=f"roundel_mark_done_{selected_pid}",
        on_click=_roundel_mark_review_done_callback,
        args=(selected_pid,),
    ):
        st.success("Review marked as complete.")
def roundel_lv_editor_app():
    """
    Top-level Roundel LV Editor tab.

    Requires processed cases in st.session_state["cases"] and the
    "open_roundel_lv_editor" flag, which is set from the Segmentation App tab;
    otherwise only an info message is shown. Renders the case selector, a
    spacing summary, per-case review status and notes (mirrored into the case
    dict), and the three-step view switcher.
    """
    st.write("# Roundel App (2D LV)")
    if not st.session_state.get("cases"):
        st.info("Process a ZIP file first in the **Segmentation App** tab.")
        return
    # There is no separate enable button here: the editor is switched on from
    # the Segmentation App tab when the user selects Yes.
    if not st.session_state.get("open_roundel_lv_editor", False):
        st.info(
            "Manual correction is optional. Select **Yes** in the **Segmentation App** tab "
            "to enable the Roundel LV Editor."
        )
        return
    cases = st.session_state["cases"]
    picker_col, info_col = st.columns([0.35, 0.65])
    with picker_col:
        selected_pid = st.selectbox(
            "Select NIfTI / mouse ID",
            options=list(cases.keys()),
            index=0,
            format_func=_roundel_case_label,
            key="roundel_sax_series_uid_selectbox",
        )
        case = cases[selected_pid]
    with info_col:
        seg = case.get("seg_spacing", {})
        native = case.get("native_spacing", {})
        interval = native.get("frame_time_ms", None)
        interval_text = f"{float(interval):.3g} ms" if interval else "not available"
        st.markdown(
            f"**Pixel Size**: {seg.get('row_mm', np.nan):.4g} x "
            f"{seg.get('col_mm', np.nan):.4g} mm | "
            f"**Slice Thickness**: {seg.get('slice_thickness_mm', np.nan):.4g} mm | "
            f"**Frame Interval**: {interval_text}"
        )
        done_key = f"roundel_case_done__{selected_pid}"
        note_key = f"roundel_case_note__{selected_pid}"
        if done_key not in st.session_state:
            st.session_state[done_key] = bool(case.get("manual_review_done", False))
        if note_key not in st.session_state:
            st.session_state[note_key] = str(case.get("manual_note", ""))
        case["manual_review_done"] = bool(st.checkbox("Status: Complete", key=done_key))
        case["manual_note"] = str(
            st.text_area(
                "Notes",
                key=note_key,
                height=90,
                placeholder="Type notes for this case…",
            )
        )
    view = _streamlit_segmented_control(
        "Tab",
        options=["EDV/ESV Finder 🔍", "Mask Editor 📝", "Final Result ✅"],
        default="EDV/ESV Finder 🔍",
        key="roundel_view_segmented",
        label_visibility="hidden",
    )
    st.divider()
    renderers = {
        "EDV/ESV Finder 🔍": roundel_lv_edv_esv_view,
        "Mask Editor 📝": roundel_lv_mask_editor_view,
        "Final Result ✅": roundel_lv_final_result_view,
    }
    render = renderers.get(view)
    if render is not None:
        render(selected_pid, case)
def build_tab_switch_script(tab_label: str) -> str:
    """Return the HTML/JS snippet that clicks the Streamlit tab labelled *tab_label*.

    The label is embedded as a JSON string literal. Any ``</`` sequence is
    escaped as ``<\\/`` so a label can never terminate the surrounding
    ``<script>`` element early. Both the label and each tab's visible text are
    whitespace-normalised before comparison.

    The script tries immediately and then retries at ~100, 250, 500 and
    1000 ms (the parent DOM may still be rendering). It stops after the first
    successful click, so a tab the user picks in the meantime is not overridden.
    """
    label_js = json.dumps(tab_label).replace("</", "<\\/")
    return f"""
<script>
(function() {{
  const doc = window.parent.document;
  function normaliseText(x) {{
    return (x || '').replace(/\\s+/g, ' ').trim();
  }}
  const tabLabel = normaliseText({label_js});
  // Absolute retry times (ms) measured from the first attempt.
  const retryTimesMs = [0, 100, 250, 500, 1000];
  function clickTab() {{
    const tabs = Array.from(doc.querySelectorAll('button[role="tab"]'));
    const target = tabs.find(tab => normaliseText(tab.innerText) === tabLabel);
    if (!target) {{
      return false;
    }}
    target.click();
    target.scrollIntoView({{behavior: "smooth", block: "nearest", inline: "center"}});
    return true;
  }}
  function attempt(index) {{
    if (clickTab() || index + 1 >= retryTimesMs.length) {{
      return;
    }}
    setTimeout(() => attempt(index + 1), retryTimesMs[index + 1] - retryTimesMs[index]);
  }}
  attempt(0);
}})();
</script>
"""


def switch_to_tab_by_label(tab_label: str):
    """
    Programmatically click a Streamlit tab by its visible label.

    Streamlit does not currently provide a native Python API to select st.tabs().
    This injects a zero-height component whose JavaScript clicks the matching
    tab button in the parent document (see ``build_tab_switch_script``).

    Notes:
    - This is a UI workaround.
    - It matches the visible tab text after collapsing whitespace.
    - If Streamlit changes its internal DOM, this may need updating.
    """
    components.html(build_tab_switch_script(tab_label), height=0)
def _manual_correction_radio_changed():
    """
    on_change callback for the "manual_correction_choice" radio (Segmentation App tab).

    Picking "Yes" enables the Roundel LV Editor and raises the flag that asks the
    Segmentation App tab to switch to the editor tab. Any other value (a missing
    value counts as "No") disables the editor and clears that pending switch.
    """
    wants_editor = st.session_state.get("manual_correction_choice", "No") == "Yes"
    # Both flags always move together: editor enabled <=> switch requested.
    for flag_key in ("open_roundel_lv_editor", "_switch_to_roundel_lv_editor"):
        st.session_state[flag_key] = wants_editor
| # ========================= | |
| # UI | |
| # ========================= | |
# Session-state keys (exact names) tied to a previous upload. They are dropped
# whenever a ZIP with different content is uploaded.
_APP_ZIP_RESET_KEYS = (
    "csv_bytes_orig",
    "csv_name_orig",
    "csv_bytes_corrected",
    "csv_name_corrected",
    "rows_count",
    "cases",
    "gif_paths",
    "gif_paths_original",
    "gif_paths_corrected",
    "session_tmpdir",
    # cached zips
    "mask_zip_bytes",
    "mask_zip_name",
    "gif_zip_bytes",
    "gif_zip_name",
    "mask_zip_bytes_corrected",
    "mask_zip_name_corrected",
    "gif_zip_bytes_corrected",
    "gif_zip_name_corrected",
    # processing state
    "processing_done",
    "processing_running",
    # manual correction / Roundel state
    "open_roundel_lv_editor",
    "manual_correction_choice",
    "_switch_to_roundel_lv_editor",
    "roundel_view_segmented",
    "roundel_sax_series_uid_selectbox",
    "roundel_overlay_overlapped_all",
    "roundel_overlay_lv_bp",
    "roundel_overlay_lv_myo",
    "roundel_brush_mode",
    "roundel_stroke_width_preset",
    "roundel_stroke_width_value",
    "_roundel_last_stroke_preset",
)

# Prefixes of dynamically named Roundel / manual-correction keys
# (e.g. per-case keys such as "roundel_case_done__<pid>").
_APP_ZIP_RESET_PREFIXES = (
    "roundel_",
    "_roundel_",
    "manual_correction",
    "_manual_correction",
)

# Hides the auto-generated anchor links next to headings in the Segmentation App tab.
_APP_HIDE_HEADING_ANCHORS_CSS = dedent(
    """\
    <style>
    [data-testid="stHeading"] a,
    h1 a[href^="#"],
    h2 a[href^="#"],
    h3 a[href^="#"] {
    display: none !important;
    visibility: hidden !important;
    }
    </style>
    """
)

_APP_HERO_HTML = dedent(
    """\
    <div class="content-wrap">
    <div class="measure-wrap">
    <div class="text-wrap">
    <h1 class="hero-title">
    Open-Source Pre-Clinical Image Segmentation:<br>
    Mouse cardiac MRI datasets with a deep learning segmentation framework
    </h1>
    </div>
    <div class="text-wrap">
    <p>We present the first publicly-available pre-clinical cardiac MRI dataset, along with an open-source DL segmentation model (both available on GitHub:
    <a href="https://github.com/mrphys/Open-Source_Pre-Clinical_Segmentation.git" target="_blank" rel="noopener">https://github.com/mrphys/Open-Source_Pre-Clinical_Segmentation.git</a>) and this web-based interface for easy deployment.</p>
    <p>The dataset comprises complete cine short-axis cardiac MRI images from 130 mice with diverse phenotypes. It also contains expert manual segmentations of left ventricular (LV) blood pool and myocardium at end-diastole, end-systole, as well as additional timeframes with artefacts to improve robustness.</p>
    <p>Using this resource, we developed an open-source DL segmentation model based on the UNet3+ architecture.</p>
    <p>This Streamlit application consists of the inference model to provide an easy-to-use interface for our DL segmentation model, without the need for local installation. The application requires the complete SAX cine image data to be uploaded in NIfTI format, as a ZIP file using the simple file browser below.</p>
    <p>Pre-processing and inference are then performed on all 2D images. The resulting blood-pool and myocardial volumes are combined across all slices at each timeframe and output to a .csv file. The blood-pool volumes are used to identify ED and ES, and these volumes are displayed as a GIF with the segmentations overlaid.</p>
    <p class="note-text">(Note: This Hugging Face model was developed as part of a manuscript submitted to the <em>Journal of Cardiovascular Magnetic Resonance</em>)</p>
    </div>
    </div>
    </div>
    """
)

_APP_DATASET_TAB_HTML = dedent(
    """\
    <style>
    .ds-hero-section {
    background: #082c3a;
    padding: 30px 10px;
    text-align: center;
    margin-left: -100vw;
    margin-right: -100vw;
    left: 0; right: 0; position: relative;
    }
    .ds-hero-section-inner { max-width: 1100px; margin: 0 auto; }
    .ds-heroimg {
    max-width: 1000px; width: 100%; height: auto;
    border-radius: 10px; box-shadow: 0 8px 24px rgba(0,0,0,.25);
    display: block; margin: 0 auto;
    }
    .ds-caption {
    text-align: center; color: #e0f2f1;
    font-size: 18px; line-height: 1.5;
    margin: 14px 6px 0; font-style: italic;
    }
    .ds-hr {
    height: 8px;
    border: 0; background: #ea580c;
    margin: 24px 0 20px;
    border-radius: 3px;
    }
    .ds-wrap {
    max-width: var(--content-measure, 920px);
    margin-left: var(--left-offset, 40px);
    margin-right: auto;
    background: #fff; padding: 16px 24px; border-radius: 6px;
    }
    .ds-section h2 {
    font-size: 20px; font-weight: 700;
    margin: 0 0 2px;
    color: #082c3a;
    }
    .ds-section p {
    font-size: 16px; line-height: 1.6; color: #333;
    margin: 0 0 6px;
    }
    .ds-section ul {
    margin: 2px 0 8px 18px;
    padding: 0;
    }
    .ds-section li {
    font-size: 16px; line-height: 1.6; color: #333;
    margin-bottom: 10px;
    }
    .ds-section a {
    color: #0b66c3 !important; text-decoration: underline !important;
    }
    h2 a, [data-testid="stHeading"] a { display: none !important; }
    </style>
    <div class="ds-hero-section">
    <div class="ds-hero-section-inner">
    <img class="ds-heroimg"
    src="https://raw.githubusercontent.com/whanisa/Segmentation/main/icon/open_source.png"
    alt="Illustration of mouse with heart representing open-source pre-clinical cardiac MRI dataset" />
    <p class="ds-caption">
    The first publicly-available pre-clinical cardiac MRI dataset,<br/>
    with an open-source segmentation model and an easy-to-use web app.
    </p>
    </div>
    </div>
    <hr class="ds-hr"/>
    <div class="ds-wrap">
    <div class="ds-section">
    <h2>Repository & Paper Resources</h2>
    <p>GitHub:
    <a href="https://github.com/mrphys/Open-Source_Pre-Clinical_Segmentation.git" target="_blank">
    Open-Source_Pre-Clinical_Segmentation
    </a>
    </p>
    <h2>📊 Dataset Availability</h2>
    <ul>
    <li>
    <strong>Full dataset (130 mice, HDF5 format):</strong><br/>
    Available in our
    <a href="https://github.com/mrphys/Open-Source_Pre-Clinical_Segmentation/tree/master/Data" target="_blank">
    GitHub repository
    </a>.<br/>
    Each .h5 file contains the complete cine SAX MRI and expert manual segmentations.
    </li>
    <li>
    <strong>Sample datasets (3 mice, NIfTI format):</strong><br/>
    Available here:
    <a href="https://huggingface.co/spaces/mrphys/Pre-clinical_DL_segmentation/tree/main/NIfTI_dataset" target="_blank">
    NIfTI Sample Dataset
    </a>.<br/>
    We provide 3 example NIfTI datasets for quick download and direct use within the app.
    </li>
    </ul>
    </div>
    <hr class="ds-hr"/>
    <div class="ds-section">
    <h2>Notes</h2>
    <ul>
    <li>Complete SAX cine MRI for 130 mice with expert LV blood & myocardium labels (ED/ES).</li>
    </ul>
    </div>
    </div>
    """
)

# Blank lines separate the Markdown blocks so the heading line, the code fence
# and the tips list render as distinct paragraphs.
_APP_CONVERTER_MARKDOWN = dedent(
    """\
    **Working with Agilent data?**

    Easily convert your fid files to NIfTI using our **fid2niix**.

    ```bash
    fid2niix -z y -o /path/to/out -f "%p_%s" /path/to/fid_folder
    ```

    **💡 Tips**

    - Upload a ZIP file that includes both `fid` and `procpar`.
    - Conversion outputs **NIfTI-1** format, ready to use with our web app.
    """
)


def _app_render_upload_section():
    """Render the "Data Upload" heading, the ZIP uploader and the sample-data link.

    Returns:
        The ``st.file_uploader`` result (``None`` until a file is uploaded).
    """
    st.markdown(
        '<div class="content-wrap"><div class="measure-wrap" id="upload-wrap">',
        unsafe_allow_html=True,
    )
    st.markdown(
        """
        <h2 style='margin-bottom:0.2rem;'>
        Data Upload <span style='font-size:33px;'>📤</span>
        </h2>
        """,
        unsafe_allow_html=True,
    )
    uploaded_zip = st.file_uploader(
        "Upload ZIP of NIfTI files 🐭",
        type="zip",
        label_visibility="visible",
    )
    st.markdown(
        """
        <p style="margin-top:0.3rem; font-size:15px; color:#444;">
        Or download our <a href="https://huggingface.co/spaces/mrphys/Pre-clinical_DL_segmentation/tree/main/NIfTI_dataset" target="_blank" rel="noopener">
        sample NIfTI dataset</a> to try it out!
        </p>
        """,
        unsafe_allow_html=True,
    )
    st.markdown("</div></div>", unsafe_allow_html=True)
    return uploaded_zip


def _app_reset_state_if_new_zip(uploaded_zip):
    """Drop results and editor state from a previous upload when the ZIP content changes.

    Uploads are compared by the SHA-256 of their bytes, not only by file name.
    The current bytes are cached in ``_last_zip_bytes`` so the processing step
    can reuse them without re-reading the upload.
    """
    zip_bytes_now = uploaded_zip.getvalue()
    zip_hash_now = hashlib.sha256(zip_bytes_now).hexdigest()
    if st.session_state.get("_last_zip_hash") == zip_hash_now:
        return
    for key in list(st.session_state.keys()):
        if key.startswith(_APP_ZIP_RESET_PREFIXES):
            st.session_state.pop(key, None)
    for key in _APP_ZIP_RESET_KEYS:
        st.session_state.pop(key, None)
    invalidate_download_caches()
    st.session_state["_last_zip_hash"] = zip_hash_now
    st.session_state["_last_zip_name"] = uploaded_zip.name
    st.session_state["_last_zip_bytes"] = zip_bytes_now  # reuse without re-reading


def _app_extract_zip(zip_source, extract_to):
    """Extract a ZIP archive into *extract_to*, skipping macOS metadata entries.

    Args:
        zip_source: A path or binary file-like object accepted by ``zipfile.ZipFile``.
        extract_to: Destination directory (created if missing).

    Raises:
        zipfile.BadZipFile: If *zip_source* is not a valid ZIP archive.
    """
    os.makedirs(extract_to, exist_ok=True)
    with zipfile.ZipFile(zip_source, "r") as zip_ref:
        # Skip "__MACOSX" folders and AppleDouble "._*" files created by macOS Finder.
        valid_files = [
            member
            for member in zip_ref.namelist()
            if "__MACOSX" not in member and not os.path.basename(member).startswith("._")
        ]
        zip_ref.extractall(extract_to, members=valid_files)


def _app_find_nifti_files(root_dir) -> List[str]:
    """Return the sorted paths of all ``.nii`` / ``.nii.gz`` files under *root_dir* (case-insensitive)."""
    found: List[str] = []
    for root, _, files in os.walk(root_dir):
        for fn in files:
            if fn.lower().endswith((".nii", ".nii.gz")):
                found.append(os.path.join(root, fn))
    return sorted(found)


def _app_status_text(zip_label, pct, index, total, now, step) -> str:
    """Build the multi-line progress status shown under the progress bar."""
    # Two trailing spaces before "\n" force Markdown hard line breaks.
    return (
        f"**ZIP:** `{zip_label}`  \n"
        f"**Progress:** {pct}% ({index}/{total})  \n"
        f"**Now:** `{now}`  \n"
        f"**Step:** {step}"
    )


def _app_make_file_progress_cb(prog, status, zip_label, file_index, total, pid):
    """Build the progress callback passed to ``process_nifti_case`` for one file.

    The callback maps a per-file fraction (0..1) onto the global progress bar.
    It never moves the bar backwards and updates the status text.
    """
    base = (file_index - 1) / max(total, 1)
    span = 1.0 / max(total, 1)
    last_overall = float(base)

    def file_progress_cb(file_frac: float, msg: Optional[str] = None):
        """file_frac: 0..1 for THIS file; msg: optional step message."""
        nonlocal last_overall
        overall = base + span * float(np.clip(file_frac, 0.0, 1.0))
        overall = max(overall, last_overall)  # never go backwards
        last_overall = overall
        prog.progress(float(np.clip(overall, 0.0, 1.0)))
        pct = int(round(overall * 100))
        status.markdown(
            _app_status_text(zip_label, pct, file_index, total, pid, msg or "Working...")
        )

    return file_progress_cb


def _app_show_case_gif(gif_gallery, pid):
    """Append the ED/ES GIF recorded in ``gif_paths`` for *pid* (if present on disk) to the gallery."""
    gif_path = st.session_state.get("gif_paths", {}).get(pid)
    if gif_path and os.path.exists(gif_path):
        with gif_gallery:
            with open(gif_path, "rb") as fgif:
                st.image(
                    fgif.read(),
                    caption=f"{pid} – ED/ES GIF (original)",
                    use_column_width=True,
                )


def _app_segment_nifti_files(nii_files, zip_label, prog, status, messages, gif_gallery):
    """Segment every NIfTI file, stream its GIF to the gallery and cache the combined CSV.

    Per-file failures are reported in *messages* and do not stop the batch.
    Sets ``cases``, ``gif_paths``, ``csv_bytes_orig``, ``csv_name_orig``,
    ``rows_count`` and ``gif_paths_original`` in ``st.session_state``.
    """
    model = get_model()
    log("[MODEL] Loaded segmentation model (cached).")
    rows: List[Dict] = []
    per_frame_rows: List[pd.DataFrame] = []
    st.session_state["cases"] = {}
    st.session_state["gif_paths"] = {}
    total = len(nii_files)
    failed: List[str] = []
    prog.progress(0.0)
    status.markdown(_app_status_text(zip_label, 0, 0, total, "-", "Waiting..."))
    for i, fp in enumerate(nii_files, start=1):
        name = Path(fp).name
        # Used as the gif_paths lookup key below; presumably matches the case id
        # process_nifti_case derives from the file name — confirm if that changes.
        pid = name.replace(".nii.gz", "").replace(".nii", "")
        file_progress_cb = _app_make_file_progress_cb(prog, status, zip_label, i, total, pid)
        file_progress_cb(0.0, "Starting...")  # show which file started immediately
        try:
            process_nifti_case(
                fp,
                model,
                rows,
                per_frame_rows,
                progress_cb=file_progress_cb,
            )
            _app_show_case_gif(gif_gallery, pid)
        except Exception as e:
            # Best-effort batch: report and continue with the next file.
            failed.append(name)
            with messages:
                st.warning(f"Failed: {name} — {e}")
        file_progress_cb(1.0, "Done")  # force progress to the file boundary
    prog.progress(1.0)
    if failed:
        status.markdown(
            f"⚠️ **Done with errors:** `{zip_label}` — "
            f"{total - len(failed)}/{total} succeeded, {len(failed)} failed"
        )
    else:
        status.markdown(f"✅ **Done:** `{zip_label}` — {total}/{total} (100%)")
    csv_path = write_all_in_one_csv(rows, per_frame_rows, get_session_csv_dir())
    with open(csv_path, "rb") as fcsv:
        st.session_state["csv_bytes_orig"] = fcsv.read()
    st.session_state["csv_name_orig"] = f"{Path(zip_label).stem}_Results.csv"
    st.session_state["rows_count"] = len(rows)
    st.session_state["gif_paths_original"] = dict(st.session_state.get("gif_paths", {}))


def _app_run_zip_pipeline(uploaded_zip):
    """Extract the uploaded ZIP and run segmentation on every NIfTI file it contains.

    ``processing_running`` / ``processing_done`` are maintained in session state.
    Download caches are invalidated at the end, even if processing raised.
    """
    st.session_state["processing_running"] = True
    st.session_state["processing_done"] = False
    zip_label = uploaded_zip.name or "ZIP"
    # Fixed layout slots: spinner, progress bar, status, messages, then the
    # growing GIF gallery (so new GIFs never push the spinner around).
    spinner_ph = st.empty()
    prog = st.progress(0.0)
    status = st.empty()
    # Errors/warnings must NOT be written inside spinner_ph: it is cleared in `finally`.
    messages = st.container()
    gif_gallery = st.container()
    zip_bytes = st.session_state.get("_last_zip_bytes")
    if zip_bytes is None:
        zip_bytes = uploaded_zip.getvalue()
    try:
        with spinner_ph:
            with st.spinner(f"Processing {zip_label}..."):
                tmpdir = tempfile.mkdtemp()
                try:
                    # Extract straight from memory; no copy of the upload is written to disk.
                    _app_extract_zip(io.BytesIO(zip_bytes), tmpdir)
                except zipfile.BadZipFile as exc:
                    with messages:
                        st.error(f"Could not read `{zip_label}` as a ZIP archive: {exc}")
                    return
                nii_files = _app_find_nifti_files(tmpdir)
                if not nii_files:
                    with messages:
                        st.error("No NIfTI files (.nii / .nii.gz) found in the uploaded ZIP.")
                    return
                _app_segment_nifti_files(
                    nii_files, zip_label, prog, status, messages, gif_gallery
                )
    finally:
        # Always clear the spinner and reset flags, even if something errors.
        spinner_ph.empty()
        st.session_state["processing_running"] = False
        st.session_state["processing_done"] = True
        invalidate_download_caches()


def _app_ensure_original_mask_zip():
    """Build and cache the ZIP of native-resolution original masks once cases exist."""
    cases = st.session_state.get("cases")
    if not cases or "mask_zip_bytes" in st.session_state:
        return
    zipstem = Path(st.session_state.get("_last_zip_name", "Results")).stem
    session_tmp = get_session_tmpdir()
    mask_files = []
    for pid, case in cases.items():
        native_h, native_w, _, _ = case["native_shape"]
        mask_native = resize_masks_to_native(case["preds_4d_orig"], native_h, native_w)
        tmp_mask_path = os.path.join(session_tmp, f"{pid}_mask.nii.gz")
        save_native_mask_nifti_with_labels(mask_native, case["affine"], tmp_mask_path)
        try:
            with open(tmp_mask_path, "rb") as f:
                mask_files.append((f"{pid}_mask.nii.gz", f.read()))
        finally:
            # The bytes are kept in memory; the temporary file is no longer needed.
            try:
                os.remove(tmp_mask_path)
            except OSError:
                pass
    st.session_state["mask_zip_bytes"] = build_zip_bytes(mask_files, root_folder=None)
    st.session_state["mask_zip_name"] = f"{zipstem}_Mask.zip"


def _app_ensure_original_gif_zip():
    """Build and cache the ZIP of original ED/ES GIFs once processing has recorded them."""
    if "gif_paths_original" not in st.session_state or "gif_zip_bytes" in st.session_state:
        return
    zipstem = Path(st.session_state.get("_last_zip_name", "Results")).stem
    gif_root = f"{zipstem}_GIF"
    gif_files = []
    orig = st.session_state.get("gif_paths_original", {}) or {}
    for pid, gif_path in orig.items():
        if gif_path and os.path.exists(gif_path):
            with open(gif_path, "rb") as f:
                gif_files.append((f"{pid}.gif", f.read()))
    st.session_state["gif_zip_bytes"] = build_zip_bytes(gif_files, root_folder=gif_root)
    st.session_state["gif_zip_name"] = f"{gif_root}.zip"


def _app_original_download_buttons() -> List[Dict]:
    """Collect button specs for whichever original downloads (CSV, masks, GIFs) are cached."""
    specs = (
        ("csv_bytes_orig", "csv_name_orig", "Download CSV", "text/csv", "orig_csv"),
        ("mask_zip_bytes", "mask_zip_name", "Download Masks", "application/zip", "orig_masks"),
        ("gif_zip_bytes", "gif_zip_name", "Download GIF", "application/zip", "orig_gif"),
    )
    buttons: List[Dict] = []
    for bytes_key, name_key, label, mime, key in specs:
        if bytes_key in st.session_state and name_key in st.session_state:
            buttons.append({
                "label": label,
                "bytes": st.session_state[bytes_key],
                "filename": st.session_state[name_key],
                "mime": mime,
                "key": key,
            })
    return buttons


def _app_render_original_downloads():
    """Render the success banner and the original CSV / mask / GIF download group."""
    # Reserve stable slots so the layout doesn't jump.
    downloads_slot = st.container()
    downloads_spacer = st.empty()
    if not st.session_state.get("processing_done", False):
        downloads_spacer.markdown(
            "<div style='min-height:120px'></div>",
            unsafe_allow_html=True,
        )
    else:
        downloads_spacer.empty()
    with downloads_slot:
        if "csv_bytes_orig" in st.session_state and "csv_name_orig" in st.session_state:
            st.success(f"Processed {st.session_state.get('rows_count', 0)} NIfTI file(s).")
        _app_ensure_original_mask_zip()
        _app_ensure_original_gif_zip()
        buttons = _app_original_download_buttons()
        if buttons:
            render_download_group(buttons, group_key="orig")


def _app_render_manual_correction_entry():
    """Render the optional Yes/No manual-correction prompt and trigger the one-off tab switch."""
    if not st.session_state.get("cases"):
        return
    st.markdown("---")
    st.markdown("## Optional: Manual correction")
    manual_choice = st.radio(
        "Would you like to manually correct the LV segmentation masks?",
        options=["No", "Yes"],
        horizontal=True,
        key="manual_correction_choice",
        on_change=_manual_correction_radio_changed,
    )
    # Keep state synced (the Roundel LV Editor tab reads this flag).
    st.session_state["open_roundel_lv_editor"] = manual_choice == "Yes"
    if manual_choice == "Yes":
        st.success("Manual correction enabled. Opening the **Roundel LV Editor** tab...")
        # Switch only once after the user selects Yes.
        if st.session_state.get("_switch_to_roundel_lv_editor", False):
            st.session_state["_switch_to_roundel_lv_editor"] = False
            switch_to_tab_by_label("Roundel LV Editor")
    else:
        st.session_state["_switch_to_roundel_lv_editor"] = False
        st.caption(
            "Manual correction is optional. You can use the automatic CSV, mask and GIF downloads above."
        )


def _app_render_segmentation_tab():
    """Render the "Segmentation App" tab: intro, upload, processing, downloads, correction prompt."""
    st.markdown(_APP_HIDE_HEADING_ANCHORS_CSS, unsafe_allow_html=True)
    st.markdown(_APP_HERO_HTML, unsafe_allow_html=True)
    uploaded_zip = _app_render_upload_section()
    # Reset state when a new ZIP is uploaded (by CONTENT, not only name).
    if uploaded_zip is not None:
        _app_reset_state_if_new_zip(uploaded_zip)
    # The button is only rendered once a file has been uploaded.
    if uploaded_zip and st.button("Process Data"):
        _app_run_zip_pipeline(uploaded_zip)
    _app_render_original_downloads()
    _app_render_manual_correction_entry()


def main():
    """Render the app: page config, fixed logo and the four top-level tabs.

    Tabs, in order: "Segmentation App", "Roundel LV Editor", "Dataset",
    "NIfTI converter". The "Roundel LV Editor" label is also the target of
    the manual-correction tab switch.
    """
    st.set_page_config(layout="wide", initial_sidebar_state="collapsed")
    _inject_layout_css()
    st.markdown(
        f'''
        <div id="fixed-edge-logo" aria-hidden="true" role="presentation">
        <img src="{LOGO_URL}" alt="Pre-Clinical Cardiac MRI Segmentation">
        </div>
        <div class="edge-logo-spacer"></div>
        ''',
        unsafe_allow_html=True,
    )
    tab_seg, tab_roundel, tab_dataset, tab_converter = st.tabs(
        ["Segmentation App", "Roundel LV Editor", "Dataset", "NIfTI converter"]
    )
    # ===== Tab 1: Segmentation App =====
    with tab_seg:
        _app_render_segmentation_tab()
    # ===== Tab 2: Roundel LV Editor =====
    with tab_roundel:
        roundel_lv_editor_app()
    # ===== Tab 3: Dataset =====
    with tab_dataset:
        st.markdown(_APP_DATASET_TAB_HTML, unsafe_allow_html=True)
    # ===== Tab 4: NIfTI converter =====
    with tab_converter:
        st.subheader("NIfTI converter")
        st.markdown(_APP_CONVERTER_MARKDOWN)
# Script entry point: build the UI when this file is executed as the main module
# (presumably launched via `streamlit run app.py` — confirm deployment command).
if __name__ == "__main__":
    main()