|
|
"""Neuroimaging visualization for Gradio. |
|
|
|
|
|
This module provides visualization components for neuroimaging data: |
|
|
- Matplotlib-based 2D slice comparisons |
|
|
- NIfTI URL helper for Custom Component |
|
|
|
|
|
See: |
|
|
- docs/specs/07-hf-spaces-deployment.md |
|
|
- docs/specs/19-perf-base64-to-file-urls.md (Issue #19 optimization) |
|
|
""" |
|
|
|
|
|
from __future__ import annotations |
|
|
|
|
|
from typing import TYPE_CHECKING |
|
|
|
|
|
import numpy as np |
|
|
|
|
|
if TYPE_CHECKING: |
|
|
from pathlib import Path |
|
|
|
|
|
from matplotlib.figure import Figure |
|
|
|
|
|
from stroke_deepisles_demo.core.logging import get_logger |
|
|
from stroke_deepisles_demo.metrics import load_nifti_as_array |
|
|
|
|
|
logger = get_logger(__name__) |
|
|
|
|
|
|
|
|
def nifti_to_gradio_url(nifti_path: Path) -> str:
    """Return the Gradio file-serving URL for a NIfTI file.

    The file is referenced through Gradio's ``/gradio_api/file=`` route
    instead of being inlined as a base64 data URL. This keeps the payload
    roughly a third smaller and avoids very large strings in the DOM.

    The path is resolved to an absolute path (``Path.resolve``, which also
    follows symlinks) before it is embedded. It is not URL-encoded.

    Args:
        nifti_path: Location of the NIfTI file. Gradio only serves it if it
            lives under an allowed location:
              - ``tempfile.gettempdir()`` (default for pipeline results)
              - the current working directory
              - any path listed in ``allowed_paths`` passed to ``launch()``

    Returns:
        A URL of the form ``/gradio_api/file=/tmp/.../dwi.nii.gz``.

    Note:
        Supersedes the deprecated ``nifti_to_data_url()``; see Issue #19
        for the performance analysis and benchmarks.

    References:
        - https://www.gradio.app/guides/file-access
        - https://niivue.com/docs/loading/
    """
    absolute = nifti_path.resolve()
    return "/gradio_api/file=" + str(absolute)
|
|
|
|
|
|
|
|
def get_slice_at_max_lesion(
    mask_path: Path,
    orientation: str = "axial",
) -> int:
    """
    Return the index of the slice containing the most lesion voxels.

    A voxel counts as lesion when its mask value is > 0. If the mask has no
    lesion voxels at all, the middle slice along the chosen orientation is
    returned instead. Useful for picking the most informative slice to show.

    Args:
        mask_path: Path to lesion mask NIfTI
        orientation: Slice orientation ("axial", "coronal", "sagittal").
            Any other value is treated as "axial".

    Returns:
        Slice index with maximum lesion area
    """
    mask, _ = load_nifti_as_array(mask_path)

    # Volume axis that indexes slices for each orientation. Unknown
    # orientations fall through to the last axis, i.e. axial.
    slice_axis = {"sagittal": 0, "coronal": 1}.get(orientation, 2)
    summed_axes = tuple(ax for ax in range(3) if ax != slice_axis)
    per_slice = np.sum(mask > 0, axis=summed_axes)

    if np.max(per_slice) == 0:
        # Empty mask: there is no lesion to centre on, so use the middle slice.
        return int(mask.shape[slice_axis] // 2)
    return int(np.argmax(per_slice))
|
|
|
|
|
|
|
|
def render_3panel_view(
    nifti_path: Path,
    mask_path: Path | None = None,
    *,
    mask_alpha: float = 0.5,
) -> Figure:
    """
    Render axial/coronal/sagittal slices with optional mask overlay.

    The three slices pass through the centroid of the mask's non-zero voxels
    when a non-empty mask is supplied. Otherwise they pass through the centre
    of the volume. The mask is thresholded at 0.5 and drawn in red ("Reds"
    colormap) over the grey-scale image.

    Args:
        nifti_path: Path to base NIfTI volume
        mask_path: Optional path to mask for overlay
        mask_alpha: Transparency of mask overlay

    Returns:
        Matplotlib figure with 3-panel view
    """
    # ``Figure`` is imported only under TYPE_CHECKING at module level. The
    # runtime constructor call below therefore needs its own import;
    # without it, calling this function raised NameError.
    from matplotlib.figure import Figure

    data, _ = load_nifti_as_array(nifti_path)
    mask_data = None
    if mask_path:
        mask_data, _ = load_nifti_as_array(mask_path)

    # Default to the volume centre, and re-centre on the lesion if one exists.
    mid_x, mid_y, mid_z = data.shape[0] // 2, data.shape[1] // 2, data.shape[2] // 2
    if mask_data is not None and np.any(mask_data > 0):
        center = np.argwhere(mask_data > 0).mean(axis=0).astype(int)
        mid_x, mid_y, mid_z = int(center[0]), int(center[1]), int(center[2])

    fig = Figure(figsize=(15, 5))
    fig.patch.set_facecolor("black")
    axes = fig.subplots(1, 3)

    # Each entry: (panel title, index tuple selecting the 2-D slice).
    panels = (
        (f"Axial (z={mid_z})", (slice(None), slice(None), mid_z)),
        (f"Coronal (y={mid_y})", (slice(None), mid_y, slice(None))),
        (f"Sagittal (x={mid_x})", (mid_x, slice(None), slice(None))),
    )
    for ax, (title, view) in zip(axes, panels):
        # np.rot90 (default k=1) rotates each slice for display.
        ax.imshow(np.rot90(data[view]), cmap="gray")
        ax.set_title(title, color="white")
        if mask_data is not None:
            binary = (np.rot90(mask_data[view]) > 0.5).astype(float)
            # Mask out background so only lesion pixels receive colour.
            ax.imshow(
                np.ma.masked_where(binary == 0, binary),
                cmap="Reds",
                alpha=mask_alpha,
                vmin=0,
                vmax=1,
            )
        ax.axis("off")

    fig.tight_layout()
    return fig
|
|
|
|
|
|
|
|
def render_slice_comparison(
    dwi_path: Path,
    prediction_path: Path,
    ground_truth_path: Path | None = None,
    *,
    slice_idx: int | None = None,
    orientation: str = "axial",
) -> Figure:
    """
    Render side-by-side comparison of DWI, prediction, and ground truth.

    The panels are:
      1. the DWI slice on its own;
      2. the DWI with the prediction overlaid in red;
      3. the DWI with the ground truth overlaid in green (only when
         ``ground_truth_path`` is given).

    Both overlays are thresholded at 0.5 so the two masks are drawn the same
    way.

    Args:
        dwi_path: Path to DWI NIfTI
        prediction_path: Path to predicted mask NIfTI
        ground_truth_path: Optional path to ground truth mask
        slice_idx: Slice index. Defaults to the slice with the most
            predicted-lesion voxels (> 0), or the middle slice if the
            prediction is empty.
        orientation: One of "axial", "coronal", "sagittal". Any other value
            is treated as "axial".

    Returns:
        Matplotlib figure with comparison view
    """
    # ``Figure`` is imported only under TYPE_CHECKING at module level. The
    # runtime constructor call below therefore needs its own import;
    # without it, calling this function raised NameError.
    from matplotlib.figure import Figure

    dwi_data, _ = load_nifti_as_array(dwi_path)
    pred_data, _ = load_nifti_as_array(prediction_path)
    gt_data = None
    if ground_truth_path:
        gt_data, _ = load_nifti_as_array(ground_truth_path)

    # Volume axis indexed by ``slice_idx`` for each orientation.
    slice_axis = {"sagittal": 0, "coronal": 1}.get(orientation, 2)

    if slice_idx is None:
        # Same rule as get_slice_at_max_lesion(), applied to the prediction
        # already in memory instead of loading the file from disk again.
        other_axes = tuple(ax for ax in range(3) if ax != slice_axis)
        lesion_counts = np.sum(pred_data > 0, axis=other_axes)
        if np.max(lesion_counts) == 0:
            slice_idx = int(pred_data.shape[slice_axis] // 2)
        else:
            slice_idx = int(np.argmax(lesion_counts))

    index = [slice(None)] * 3
    index[slice_axis] = slice_idx
    view = tuple(index)

    # np.rot90 (default k=1) rotates each slice for display.
    d_slice = np.rot90(dwi_data[view])
    p_binary = (np.rot90(pred_data[view]) > 0.5).astype(float)
    g_binary = (
        (np.rot90(gt_data[view]) > 0.5).astype(float) if gt_data is not None else None
    )

    num_plots = 3 if gt_data is not None else 2
    fig = Figure(figsize=(5 * num_plots, 5))
    fig.patch.set_facecolor("black")
    # With more than one column, subplots() already returns an array of Axes.
    axes = fig.subplots(1, num_plots)

    axes[0].imshow(d_slice, cmap="gray")
    axes[0].set_title("DWI Input", color="white")

    overlays = [("Prediction", p_binary, "Reds")]
    if g_binary is not None:
        overlays.append(("Ground Truth", g_binary, "Greens"))
    for ax, (title, mask, cmap) in zip(axes[1:], overlays):
        ax.imshow(d_slice, cmap="gray")
        # Mask out background so only lesion pixels receive colour.
        ax.imshow(
            np.ma.masked_where(mask == 0, mask),
            cmap=cmap,
            alpha=0.5,
            vmin=0,
            vmax=1,
        )
        ax.set_title(title, color="white")

    for ax in axes:
        ax.axis("off")

    fig.tight_layout()
    return fig
|
|
|