|
|
"""Main Gradio application for stroke-deepisles-demo.""" |
|
|
|
|
|
from __future__ import annotations |
|
|
|
|
|
import shutil |
|
|
from pathlib import Path |
|
|
from typing import Any |
|
|
|
|
|
import gradio as gr |
|
|
from matplotlib.figure import Figure |
|
|
|
|
|
from stroke_deepisles_demo.core.logging import get_logger |
|
|
from stroke_deepisles_demo.data import list_case_ids |
|
|
from stroke_deepisles_demo.metrics import compute_volume_ml |
|
|
from stroke_deepisles_demo.pipeline import run_pipeline_on_case |
|
|
from stroke_deepisles_demo.ui.components import ( |
|
|
create_case_selector, |
|
|
create_results_display, |
|
|
create_settings_accordion, |
|
|
) |
|
|
from stroke_deepisles_demo.ui.viewer import ( |
|
|
nifti_to_gradio_url, |
|
|
render_3panel_view, |
|
|
render_slice_comparison, |
|
|
) |
|
|
|
|
|
# Module logger used by the UI callbacks and the script entry point in this file.
logger = get_logger(__name__)
|
|
|
|
|
|
|
|
def initialize_case_selector() -> gr.Dropdown:
    """
    Populate the case dropdown after the UI has rendered (lazy dataset load).

    Wired to ``demo.load()`` so that app startup does not hang while the
    dataset is being downloaded.

    Returns:
        A ``gr.Dropdown`` update. On success it lists every case ID with the
        first one preselected. Otherwise it has no choices, and its ``info``
        text explains why: either no cases were found, or loading raised.
    """
    try:
        logger.info("Initializing dataset for case selector...")
        available = list_case_ids()

        if available:
            return gr.Dropdown(
                choices=available,
                value=available[0],
                info="Choose a case from isles24-stroke dataset",
                interactive=True,
            )

        return gr.Dropdown(choices=[], info="No cases found in dataset.")
    except Exception as exc:
        # A data-loading failure must not break page load; surface it in the UI instead.
        logger.exception("Failed to initialize dataset")
        return gr.Dropdown(choices=[], info=f"Error loading data: {exc!s}")
|
|
|
|
|
|
|
|
def _cleanup_previous_results(previous_results_dir: str | None) -> None: |
|
|
"""Clean up previous results directory (per-session, thread-safe). |
|
|
|
|
|
Security: Validates path is under allowed results root to prevent |
|
|
arbitrary file deletion via manipulated Gradio state. |
|
|
""" |
|
|
if previous_results_dir is None: |
|
|
return |
|
|
|
|
|
from stroke_deepisles_demo.core.config import get_settings |
|
|
|
|
|
prev_path = Path(previous_results_dir).resolve() |
|
|
allowed_root = get_settings().results_dir.resolve() |
|
|
|
|
|
|
|
|
try: |
|
|
prev_path.relative_to(allowed_root) |
|
|
except ValueError: |
|
|
logger.warning( |
|
|
"Refusing to cleanup path outside allowed root: %s (root: %s)", |
|
|
prev_path, |
|
|
allowed_root, |
|
|
) |
|
|
return |
|
|
|
|
|
if prev_path.exists(): |
|
|
try: |
|
|
shutil.rmtree(prev_path) |
|
|
logger.debug("Cleaned up previous results: %s", prev_path) |
|
|
except OSError as e: |
|
|
|
|
|
logger.warning("Failed to cleanup %s: %s", prev_path, e) |
|
|
|
|
|
|
|
|
def run_segmentation(
    case_id: str,
    fast_mode: bool,
    show_ground_truth: bool,
    previous_results_dir: str | None,
) -> tuple[
    dict[str, str | None] | None,
    Figure | None,
    Figure | None,
    dict[str, Any],
    str | None,
    str,
    str | None,
]:
    """
    Segment one case and package everything the results panel displays.

    Args:
        case_id: Case chosen in the dropdown. A falsy value means nothing is selected.
        fast_mode: Use the fast SEALS model instead of the ensemble.
        show_ground_truth: Overlay the ground-truth mask on the slice comparison.
        previous_results_dir: This session's last results directory, held in
            ``gr.State``. It is cleaned up before the new run starts.

    Returns:
        Seven values, in the order the UI outputs are wired:
        ``(niivue_data, slice_fig, ortho_fig, metrics_dict, download_path,
        status_msg, new_results_dir)``. ``new_results_dir`` feeds back into
        ``gr.State``. When no case is selected, or when the run raises, the
        display outputs are empty and ``previous_results_dir`` is returned
        unchanged.
    """
    if not case_id:
        return None, None, None, {}, None, "Please select a case first.", previous_results_dir

    try:
        _cleanup_previous_results(previous_results_dir)

        logger.info("Running segmentation for %s", case_id)
        result = run_pipeline_on_case(
            case_id,
            fast=fast_mode,
            compute_dice=True,
            cleanup_staging=True,
        )

        dwi_path = result.input_files["dwi"]
        mask_path = result.prediction_mask

        # Interactive viewer payload: DWI background, plus the prediction
        # overlay only if a mask file was actually written.
        viewer_payload = {
            "background_url": nifti_to_gradio_url(dwi_path),
            "overlay_url": nifti_to_gradio_url(mask_path) if mask_path.exists() else None,
        }

        slice_fig = render_slice_comparison(
            dwi_path=dwi_path,
            prediction_path=mask_path,
            ground_truth_path=result.ground_truth if show_ground_truth else None,
            orientation="axial",
        )
        ortho_fig = render_3panel_view(nifti_path=dwi_path, mask_path=mask_path, mask_alpha=0.5)

        # Lesion volume is informational only; failing to compute it must not fail the run.
        volume_ml: float | None
        try:
            volume_ml = round(compute_volume_ml(mask_path, threshold=0.5), 2)
        except Exception:
            logger.warning("Failed to compute volume for %s", case_id, exc_info=True)
            volume_ml = None

        dice = result.dice_score
        metrics = {
            "case_id": result.case_id,
            "dice_score": dice,
            "volume_ml": volume_ml,
            "elapsed_seconds": round(result.elapsed_seconds, 2),
            "model": "SEALS (Fast)" if fast_mode else "Ensemble",
        }

        if dice is None:
            status_msg = "Success!"
        else:
            status_msg = f"Success! Dice: {dice:.3f}"

        return (
            viewer_payload,
            slice_fig,
            ortho_fig,
            metrics,
            str(mask_path),
            status_msg,
            str(result.results_dir),
        )
    except Exception as exc:
        logger.exception("Error running segmentation")
        return None, None, None, {}, None, f"Error: {exc!s}", previous_results_dir
|
|
|
|
|
|
|
|
def create_app() -> gr.Blocks:
    """
    Build the Gradio Blocks UI and wire up its event handlers.

    Layout: a control column on the left (case dropdown, settings, run
    button, status box) and a results column twice as wide on the right.
    The case dropdown is filled after page load via ``demo.load``.

    Returns:
        The configured (not yet launched) ``gr.Blocks`` application.
    """
    with gr.Blocks(title="Stroke Lesion Segmentation Demo") as demo:
        # Per-session record of the last results directory, so the next run can delete it.
        last_results_dir = gr.State(value=None)

        gr.Markdown("""
        # Stroke Lesion Segmentation Demo

        This demo runs [DeepISLES](https://github.com/ezequieldlrosa/DeepIsles)
        stroke segmentation on cases from
        [isles24-stroke](https://huggingface.co/datasets/hugging-science/isles24-stroke).

        **Model:** SEALS (ISLES'22 winner) - Fast, accurate ischemic stroke lesion segmentation.

        **Note:** First run may take a moment to load models and data.
        """)

        with gr.Row():
            with gr.Column(scale=1):
                case_dropdown = create_case_selector()
                settings_controls = create_settings_accordion()
                run_button = gr.Button("Run Segmentation", variant="primary")
                status_box = gr.Textbox(label="Status", interactive=False)

            with gr.Column(scale=2):
                result_views = create_results_display()

        # Both lists must line up with run_segmentation's parameters and its return tuple.
        segmentation_inputs = [
            case_dropdown,
            settings_controls["fast_mode"],
            settings_controls["show_ground_truth"],
            last_results_dir,
        ]
        segmentation_outputs = [
            *(
                result_views[key]
                for key in ("niivue_viewer", "slice_plot", "ortho_plot", "metrics", "download")
            ),
            status_box,
            last_results_dir,
        ]
        run_button.click(
            fn=run_segmentation,
            inputs=segmentation_inputs,
            outputs=segmentation_outputs,
        )

        # Fill the dropdown only after the page renders, so startup is not blocked on data loading.
        demo.load(initialize_case_selector, outputs=[case_dropdown])

    return demo
|
|
|
|
|
|
|
|
|
|
|
# Lazily-built application instance; populated on the first get_demo() call.
_demo: gr.Blocks | None = None


def get_demo() -> gr.Blocks:
    """Return the module-wide demo app, building it with create_app() on first use."""
    global _demo
    if _demo is not None:
        return _demo
    _demo = create_app()
    return _demo
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: configure logging from settings, then launch the app.
    from stroke_deepisles_demo.core.config import get_settings
    from stroke_deepisles_demo.core.logging import setup_logging

    settings = get_settings()
    setup_logging(settings.log_level, format_style=settings.log_format)

    banner = "=" * 60
    for banner_line in (banner, "STARTUP: stroke-deepisles-demo", banner):
        logger.info(banner_line)

    app = get_demo()
    launch_options = {
        "server_name": settings.gradio_server_name,
        "server_port": settings.gradio_server_port,
        "share": settings.gradio_share,
        "theme": gr.themes.Soft(),
        "css": "footer {visibility: hidden}",
        "show_error": settings.gradio_show_error,
    }
    app.launch(**launch_options)
|
|
|