WorkTimer's picture
Move theme to launch(), disable SSR to silence asyncio noise
937fcbb verified
"""Gradio UI for PanCancerSeg single-case CT tumour segmentation."""
import shutil
import tempfile
from pathlib import Path
import gradio as gr
from predict import (
CANCER_CONFIGS,
install_custom_trainer,
resolve_case_id,
resolve_model_folder,
run_nnunet_prediction,
summarize_segmentation,
)
from visualize import generate_outputs
# ── Constants ──────────────────────────────────────────────────────────────────
# UI label shown in the dropdown -> key used to index CANCER_CONFIGS.
CANCER_TYPE_CHOICES: dict[str, str] = {
    "Kidney Cancer": "kidney_cancer",
    "Liver Cancer": "liver_cancer",
    "Pancreatic Cancer": "pancreatic_cancer",
    "Lung Cancer": "lung_cancer",
}

# Weights folder expected to sit next to this file when running from a local checkout.
DEFAULT_MODEL_DIR: str = str(Path(__file__).with_name("PanCancerSeg-Specialized-weights"))

# Preferred torch device; run_inference falls back to "cpu" when CUDA is unavailable.
DEFAULT_DEVICE: str = "cuda"

# Hugging Face Hub repository hosting the trained nnUNet weights. When the local
# weights folder is absent (e.g. on Spaces) they are downloaded from here on first use.
MODEL_REPO_ID: str = "KS987/PanCancerSeg-Specialized-weights"
def resolve_weights_dir() -> Path:
    """Locate a directory that holds the ``Dataset*`` model folders.

    The folder named by ``DEFAULT_MODEL_DIR`` is used when it exists and
    contains at least one entry matching ``Dataset*``. Otherwise the matching
    folders are fetched from the ``MODEL_REPO_ID`` model repository on the
    Hugging Face Hub via ``snapshot_download``.

    Returns:
        Path to the local weights folder, or to the downloaded snapshot.
    """
    candidate = Path(DEFAULT_MODEL_DIR).expanduser().resolve()
    has_local_weights = (
        candidate.exists() and next(candidate.glob("Dataset*"), None) is not None
    )
    if has_local_weights:
        return candidate

    # Imported only on the download path, so a local checkout never needs it.
    from huggingface_hub import snapshot_download

    snapshot_root = snapshot_download(
        repo_id=MODEL_REPO_ID,
        repo_type="model",
        allow_patterns=["Dataset*/**"],
    )
    return Path(snapshot_root)
# Root folder of the bundled example scans, located next to this file.
_SAMPLE_DIR = Path(__file__).parent / "sample_input"

# UI label shown in the dropdown -> sub-folder of _SAMPLE_DIR holding its examples.
_CANCER_TYPE_TO_FOLDER = {
    "Kidney Cancer": "kidney",
    "Liver Cancer": "liver",
    "Pancreatic Cancer": "pancreas",
    "Lung Cancer": "lung",
}


def load_example(cancer_type_label: str, index: int) -> str:
    """Return the index-th (1-based) example ``*_0000.nii.gz`` for a cancer type.

    Examples are the files matching ``*_0000.nii.gz`` in the cancer type's
    sub-folder of ``_SAMPLE_DIR``, ordered by sorted path.

    Args:
        cancer_type_label: A key of ``_CANCER_TYPE_TO_FOLDER`` (the UI label).
        index: 1-based position of the wanted example.

    Returns:
        The path of the selected example file as a string.

    Raises:
        gr.Error: If the label has no example folder mapping, or if ``index``
            is outside ``1..number of examples``.
    """
    folder_name = _CANCER_TYPE_TO_FOLDER.get(cancer_type_label)
    if folder_name is None:
        raise gr.Error(f"No examples available for cancer type: {cancer_type_label!r}")

    folder = _SAMPLE_DIR / folder_name
    files = sorted(folder.glob("*_0000.nii.gz"))

    # Reject index <= 0 as well: files[index - 1] would otherwise wrap around
    # and silently return an example counted from the end of the list.
    if not 1 <= index <= len(files):
        raise gr.Error(f"Example {index} not found for {cancer_type_label} in {folder}")
    return str(files[index - 1])
# ── Inference ──────────────────────────────────────────────────────────────────
def _path_or_none(path):
    """Return ``str(path)``, or ``None`` when no path was produced."""
    return None if path is None else str(path)


def run_inference(
    input_file,
    cancer_type_label,
    fps,
    progress=gr.Progress(track_tqdm=True),
):
    """Segment one uploaded CT volume and build everything the UI displays.

    Steps: validate the inputs, resolve the model weights, run nnUNet on a
    temporary copy/symlink of the input, copy the resulting mask to a fresh
    output directory, render the visualisations, and summarise the mask.

    Args:
        input_file: Path of the uploaded ``.nii.gz`` file, or ``None`` when
            nothing has been uploaded.
        cancer_type_label: UI label; must be a key of ``CANCER_TYPE_CHOICES``.
        fps: Frame rate for the overlay video; converted with ``int()``.
        progress: Gradio progress tracker used to report each stage.

    Returns:
        A 7-tuple in the order expected by the ``run_btn.click`` outputs:
        ``(stats_text, seg_path, centroid_img, max_area_img, extent25_img,
        extent75_img, video_path)``. Any image or video entry is ``None``
        when the corresponding file was not produced.

    Raises:
        gr.Error: On invalid input, when the weights cannot be obtained, or
            when inference / visualisation fails. Exceptions raised by
            ``resolve_case_id``, ``install_custom_trainer`` and
            ``resolve_model_folder`` propagate unchanged.
    """
    # Imported at call time rather than at module import time.
    import torch

    if input_file is None:
        raise gr.Error("Please upload a .nii.gz CT image first.")

    input_path = Path(input_file)
    if not input_path.name.endswith(".nii.gz"):
        raise gr.Error(f"File must be .nii.gz format. Got: {input_path.name}")

    # Validate the label up front so a bad value yields a readable UI error
    # (instead of a raw KeyError) before any weights are resolved/downloaded.
    if cancer_type_label not in CANCER_TYPE_CHOICES:
        raise gr.Error(f"Unknown cancer type: {cancer_type_label!r}")

    device = DEFAULT_DEVICE if torch.cuda.is_available() else "cpu"

    progress(0.02, desc="Resolving model weights...")
    try:
        model_dir_path = resolve_weights_dir()
    except Exception as e:
        raise gr.Error(
            f"Failed to obtain model weights from '{MODEL_REPO_ID}': {e}"
        ) from e

    cancer_key = CANCER_TYPE_CHOICES[cancer_type_label]
    config = CANCER_CONFIGS[cancer_key]
    case_id = resolve_case_id(input_path)

    progress(0.05, desc="Installing custom trainer...")
    install_custom_trainer()

    progress(0.10, desc="Loading model weights...")
    model_folder = resolve_model_folder(model_dir_path, config["dataset_name"])

    # Not removed on success: the returned segmentation path points into it.
    # It is only deleted when something below fails.
    output_dir = Path(tempfile.mkdtemp(prefix="pancancerseg_out_"))
    try:
        with tempfile.TemporaryDirectory(prefix="pancancerseg_in_") as tmp:
            tmp_path = Path(tmp)
            tmp_input_dir = tmp_path / "input"
            tmp_output_dir = tmp_path / "prediction"
            tmp_input_dir.mkdir()
            tmp_output_dir.mkdir()

            # Expose the upload under the "<case_id>_0000.nii.gz" name inside
            # the input folder; fall back to a copy where symlinks fail.
            nnunet_input = tmp_input_dir / f"{case_id}_0000.nii.gz"
            try:
                nnunet_input.symlink_to(input_path.resolve())
            except (OSError, NotImplementedError):
                shutil.copy2(input_path, nnunet_input)

            progress(0.20, desc="Running nnUNet inference (this may take a few minutes)...")
            run_nnunet_prediction(
                model_folder=model_folder,
                input_dir=tmp_input_dir,
                output_dir=tmp_output_dir,
                device=device,
            )

            raw_seg = tmp_output_dir / f"{case_id}.nii.gz"
            if not raw_seg.exists():
                produced = [p.name for p in tmp_output_dir.glob("*.nii.gz")]
                raise RuntimeError(
                    f"nnUNet did not produce the expected segmentation. Found: {produced}"
                )

            # Copy the mask out before the temporary directory is deleted.
            seg_path = output_dir / f"{case_id}_seg.nii.gz"
            shutil.copy2(raw_seg, seg_path)

        progress(0.80, desc="Generating slice images and overlay video...")
        viz = generate_outputs(
            image_path=input_path,
            mask_path=seg_path,
            output_dir=output_dir,
            case_name=case_id,
            cancer_type=config["display_name"],
            wl=config["wl"],
            ww=config["ww"],
            color=config["color"],
            alpha=0.5,
            fps=int(fps),
        )

        progress(0.95, desc="Computing tumour volume...")
        positive_voxels, tumor_volume_ml = summarize_segmentation(seg_path)
        stats = (
            f"Case ID : {case_id}\n"
            f"Cancer type : {config['display_name']}\n"
            f"Positive voxels: {positive_voxels:,}\n"
            f"Tumour volume : {tumor_volume_ml:.3f} mL"
        )

        slices = viz["slices"]

        # Only hand the video to the UI when a non-empty file actually exists.
        video_out = None
        video_path = viz["video"]
        if video_path is not None:
            video_path = Path(video_path)
            if video_path.exists() and video_path.stat().st_size > 0:
                video_out = str(video_path)

        progress(1.0, desc="Done!")
        # A missing slice must be returned as None; str(None) would give the
        # image component the bogus file path "None".
        return (
            stats,
            str(seg_path),
            _path_or_none(slices.get("centroid")),
            _path_or_none(slices.get("max_area")),
            _path_or_none(slices.get("extent25")),
            _path_or_none(slices.get("extent75")),
            video_out,
        )
    except Exception as e:
        shutil.rmtree(output_dir, ignore_errors=True)
        # Keep the original exception chained for server-side tracebacks and
        # never show an empty message in the UI.
        raise gr.Error(str(e) or type(e).__name__) from e
# ── UI ─────────────────────────────────────────────────────────────────────────
def build_ui():
    """Build the Gradio Blocks interface and wire up its event handlers.

    The left column holds the inputs (file upload, cancer type, video FPS,
    example loaders, run button) and the overlay video; the right column
    holds the text summary, the downloadable mask and four slice images.

    NOTE(review): the component nesting below was reconstructed from a copy
    of the source whose indentation was lost — confirm the placement of the
    run button and the video against the intended layout.

    Returns:
        The configured ``gr.Blocks`` instance (not yet launched).
    """
    with gr.Blocks(title="PanCancerSeg Inference") as demo:
        # Built from adjacent literals so code indentation can never leak into
        # the rendered Markdown. The title uses a real em dash; the previous
        # text contained a mis-decoded byte sequence in its place.
        gr.Markdown(
            "\n"
            "# PanCancerSeg — Specialist CT Tumour Segmentation\n"
            "Upload a `.nii.gz` CT image, select the cancer type, and click **Run Inference** to obtain\n"
            "a segmentation mask and visualisations.\n"
        )

        with gr.Row():
            # Left panel: inputs
            with gr.Column(scale=1, min_width=300):
                input_file = gr.File(
                    label="CT Image (.nii.gz)",
                    file_types=[".gz"],
                )
                cancer_type = gr.Dropdown(
                    choices=list(CANCER_TYPE_CHOICES.keys()),
                    value="Kidney Cancer",
                    label="Cancer Type",
                )
                fps = gr.Slider(
                    minimum=1,
                    maximum=30,
                    value=10,
                    step=1,
                    label="Video FPS",
                )
                with gr.Row():
                    load_btn_1 = gr.Button("Load Example 1", size="lg")
                    load_btn_2 = gr.Button("Load Example 2", size="lg")
                run_btn = gr.Button("Run Inference", variant="primary", size="lg")
                video_out = gr.Video(label="Overlay Video")

            # Right panel: outputs
            with gr.Column(scale=2):
                with gr.Row():
                    stats_box = gr.Textbox(
                        label="Inference Summary",
                        lines=4,
                        interactive=False,
                    )
                    seg_file = gr.File(label="Download Segmentation Mask (.nii.gz)")
                with gr.Row():
                    img_centroid = gr.Image(label="Centroid Slice", type="filepath")
                    img_max_area = gr.Image(label="Max Area Slice", type="filepath")
                with gr.Row():
                    img_ext25 = gr.Image(label="Extent 25% Slice", type="filepath")
                    img_ext75 = gr.Image(label="Extent 75% Slice", type="filepath")

        # Each loader fills the upload box with a bundled example for the
        # currently selected cancer type.
        load_btn_1.click(
            fn=lambda ct: load_example(ct, 1),
            inputs=[cancer_type],
            outputs=[input_file],
        )
        load_btn_2.click(
            fn=lambda ct: load_example(ct, 2),
            inputs=[cancer_type],
            outputs=[input_file],
        )

        # Output order must match the tuple returned by run_inference.
        run_btn.click(
            fn=run_inference,
            inputs=[input_file, cancer_type, fps],
            outputs=[
                stats_box,
                seg_file,
                img_centroid,
                img_max_area,
                img_ext25,
                img_ext75,
                video_out,
            ],
        )
    return demo
if __name__ == "__main__":
    import os

    demo = build_ui()

    # Hugging Face Spaces expect the app on port 7860 (set via GRADIO_SERVER_PORT);
    # without that variable the same port is used as the local default.
    server_port = int(os.getenv("GRADIO_SERVER_PORT", 7860))

    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": server_port,
        "share": False,
        "theme": gr.themes.Soft(),
        "ssr_mode": False,
    }
    demo.launch(**launch_options)