"""Gradio UI for PanCancerSeg single-case CT tumour segmentation."""

import functools
import shutil
import tempfile
from pathlib import Path

import gradio as gr

from predict import (
    CANCER_CONFIGS,
    install_custom_trainer,
    resolve_case_id,
    resolve_model_folder,
    run_nnunet_prediction,
    summarize_segmentation,
)
from visualize import generate_outputs

# ── Constants ──────────────────────────────────────────────────────────────────

# Dropdown label -> key into predict.CANCER_CONFIGS.
CANCER_TYPE_CHOICES = {
    "Kidney Cancer": "kidney_cancer",
    "Liver Cancer": "liver_cancer",
    "Pancreatic Cancer": "pancreatic_cancer",
    "Lung Cancer": "lung_cancer",
}

DEFAULT_MODEL_DIR = str(Path(__file__).parent / "PanCancerSeg-Specialized-weights")
DEFAULT_DEVICE = "cuda"

# Hugging Face Hub repo that hosts the trained nnUNet weights. On Spaces (where the
# local weights folder is absent) we download them on first use.
MODEL_REPO_ID = "KS987/PanCancerSeg-Specialized-weights"


@functools.lru_cache(maxsize=1)
def resolve_weights_dir() -> Path:
    """Return a directory containing the DatasetXXX_* model folders.

    Prefer a local checkout (fast local dev); otherwise download the weights from
    the Hugging Face Hub and cache them.

    The result is memoised for the life of the process so repeated inferences do
    not call ``snapshot_download`` (and hit the network) every time. A failed
    download raises and is *not* cached, so the next call retries.
    """
    local_dir = Path(DEFAULT_MODEL_DIR).expanduser().resolve()
    if local_dir.exists() and any(local_dir.glob("Dataset*")):
        return local_dir

    # Imported lazily: only needed when no local weights are present.
    from huggingface_hub import snapshot_download

    downloaded = snapshot_download(
        repo_id=MODEL_REPO_ID,
        repo_type="model",
        allow_patterns=["Dataset*/**"],
    )
    return Path(downloaded)


_SAMPLE_DIR = Path(__file__).parent / "sample_input"

# Dropdown label -> sub-folder of sample_input/ holding that cancer type's examples.
_CANCER_TYPE_TO_FOLDER = {
    "Kidney Cancer": "kidney",
    "Liver Cancer": "liver",
    "Pancreatic Cancer": "pancreas",
    "Lung Cancer": "lung",
}


def load_example(cancer_type_label: str, index: int) -> str:
    """Return the index-th (1-based) example ``*_0000.nii.gz`` for a cancer type.

    Files are taken in sorted filename order from the cancer type's sample folder.

    Raises:
        gr.Error: if ``index`` is below 1 or there are fewer than ``index``
            example files in the folder.
    """
    folder = _SAMPLE_DIR / _CANCER_TYPE_TO_FOLDER[cancer_type_label]
    files = sorted(folder.glob("*_0000.nii.gz"))
    # Guard index < 1 too: files[index - 1] would otherwise wrap to the end.
    if not 1 <= index <= len(files):
        raise gr.Error(f"Example {index} not found for {cancer_type_label} in {folder}")
    return str(files[index - 1])


# ── Inference ──────────────────────────────────────────────────────────────────


def _optional_str(path) -> str | None:
    """Return ``str(path)``, or ``None`` when *path* is ``None``.

    Gradio components treat ``None`` as "no value", whereas ``str(None)`` would be
    the bogus file path ``"None"``.
    """
    return None if path is None else str(path)


def _predict_segmentation(input_path: Path, case_id, model_folder, device, output_dir: Path, progress) -> Path:
    """Run nnUNet on one case in a scratch directory; return the copied mask path.

    The upload is staged as ``<case_id>_0000.nii.gz`` (symlinked, or copied where
    symlinks are unavailable) in a temporary input folder. The prediction is then
    copied to ``output_dir/<case_id>_seg.nii.gz`` before the scratch folder is
    deleted.

    Raises:
        RuntimeError: if nnUNet does not write ``<case_id>.nii.gz``.
    """
    with tempfile.TemporaryDirectory(prefix="pancancerseg_in_") as tmp:
        tmp_path = Path(tmp)
        tmp_input_dir = tmp_path / "input"
        tmp_output_dir = tmp_path / "prediction"
        tmp_input_dir.mkdir()
        tmp_output_dir.mkdir()

        nnunet_input = tmp_input_dir / f"{case_id}_0000.nii.gz"
        try:
            nnunet_input.symlink_to(input_path.resolve())
        except (OSError, NotImplementedError):
            shutil.copy2(input_path, nnunet_input)

        progress(0.20, desc="Running nnUNet inference (this may take a few minutes)...")
        run_nnunet_prediction(
            model_folder=model_folder,
            input_dir=tmp_input_dir,
            output_dir=tmp_output_dir,
            device=device,
        )

        raw_seg = tmp_output_dir / f"{case_id}.nii.gz"
        if not raw_seg.exists():
            produced = [p.name for p in tmp_output_dir.glob("*.nii.gz")]
            raise RuntimeError(
                f"nnUNet did not produce the expected segmentation. Found: {produced}"
            )

        # Copy out before the TemporaryDirectory (and raw_seg with it) is removed.
        seg_path = output_dir / f"{case_id}_seg.nii.gz"
        shutil.copy2(raw_seg, seg_path)
    return seg_path


def run_inference(
    input_file,
    cancer_type_label,
    fps,
    progress=gr.Progress(track_tqdm=True),
):
    """Segment one uploaded CT volume and build every UI output.

    Args:
        input_file: Uploaded file from ``gr.File``. It is passed to ``Path(...)``,
            so it is presumably a filepath string; ``None`` means nothing was
            uploaded.
        cancer_type_label: A key of ``CANCER_TYPE_CHOICES`` (the dropdown value).
        fps: Overlay video frame rate; cast to ``int``.
        progress: Gradio progress tracker, injected by Gradio.

    Returns:
        A 7-tuple matching the ``run_btn.click`` outputs:
        (summary text, segmentation path, centroid / max-area / extent-25% /
        extent-75% slice image paths, video path). A slice path is ``None`` if
        ``generate_outputs`` did not report it. The video path is ``None`` unless
        a non-empty video file exists.

    Raises:
        gr.Error: for invalid input, weight download failure, or any failure
            during inference/visualisation. On the latter, the per-run output
            directory is removed.
    """
    import torch  # deferred: only needed once an inference is actually requested

    if input_file is None:
        raise gr.Error("Please upload a .nii.gz CT image first.")

    input_path = Path(input_file)
    if not input_path.name.endswith(".nii.gz"):
        raise gr.Error(f"File must be .nii.gz format. Got: {input_path.name}")

    device = DEFAULT_DEVICE if torch.cuda.is_available() else "cpu"

    progress(0.02, desc="Resolving model weights...")
    try:
        model_dir_path = resolve_weights_dir()
    except Exception as e:
        raise gr.Error(f"Failed to obtain model weights from '{MODEL_REPO_ID}': {e}") from e

    cancer_key = CANCER_TYPE_CHOICES[cancer_type_label]
    config = CANCER_CONFIGS[cancer_key]

    # Outputs must outlive this call so Gradio can serve them; only removed on failure.
    output_dir = Path(tempfile.mkdtemp(prefix="pancancerseg_out_"))
    try:
        # These helpers run inside the try so their failures also reach the user
        # as a gr.Error rather than a generic server error.
        case_id = resolve_case_id(input_path)

        progress(0.05, desc="Installing custom trainer...")
        install_custom_trainer()

        progress(0.10, desc="Loading model weights...")
        model_folder = resolve_model_folder(model_dir_path, config["dataset_name"])

        seg_path = _predict_segmentation(
            input_path, case_id, model_folder, device, output_dir, progress
        )

        progress(0.80, desc="Generating slice images and overlay video...")
        viz = generate_outputs(
            image_path=input_path,
            mask_path=seg_path,
            output_dir=output_dir,
            case_name=case_id,
            cancer_type=config["display_name"],
            wl=config["wl"],
            ww=config["ww"],
            color=config["color"],
            alpha=0.5,
            fps=int(fps),
        )

        progress(0.95, desc="Computing tumour volume...")
        positive_voxels, tumor_volume_ml = summarize_segmentation(seg_path)

        stats = (
            f"Case ID : {case_id}\n"
            f"Cancer type : {config['display_name']}\n"
            f"Positive voxels: {positive_voxels:,}\n"
            f"Tumour volume : {tumor_volume_ml:.3f} mL"
        )

        slices = viz["slices"]
        video_path = viz["video"]
        video_out = (
            str(video_path)
            if video_path.exists() and video_path.stat().st_size > 0
            else None
        )

        progress(1.0, desc="Done!")
        return (
            stats,
            str(seg_path),
            _optional_str(slices.get("centroid")),
            _optional_str(slices.get("max_area")),
            _optional_str(slices.get("extent25")),
            _optional_str(slices.get("extent75")),
            video_out,
        )
    except gr.Error:
        # Already user-facing; re-wrapping would mangle its message.
        shutil.rmtree(output_dir, ignore_errors=True)
        raise
    except Exception as e:
        shutil.rmtree(output_dir, ignore_errors=True)
        # Fall back to the exception type when its message is empty.
        raise gr.Error(str(e) or type(e).__name__) from e


# ── UI ─────────────────────────────────────────────────────────────────────────


def build_ui():
    """Build and return the Gradio Blocks app (inputs left, outputs right)."""
    with gr.Blocks(title="PanCancerSeg Inference") as demo:
        gr.Markdown(
            """
# PanCancerSeg — Specialist CT Tumour Segmentation

Upload a `.nii.gz` CT image, select the cancer type, and click **Run Inference**
to obtain a segmentation mask and visualisations.
"""
        )

        with gr.Row():
            # ── Left panel: inputs ─────────────────────────────────────────────
            with gr.Column(scale=1, min_width=300):
                input_file = gr.File(
                    label="CT Image (.nii.gz)",
                    file_types=[".gz"],
                )
                cancer_type = gr.Dropdown(
                    choices=list(CANCER_TYPE_CHOICES.keys()),
                    value="Kidney Cancer",
                    label="Cancer Type",
                )
                fps = gr.Slider(
                    minimum=1,
                    maximum=30,
                    value=10,
                    step=1,
                    label="Video FPS",
                )
                with gr.Row():
                    load_btn_1 = gr.Button("Load Example 1", size="lg")
                    load_btn_2 = gr.Button("Load Example 2", size="lg")
                run_btn = gr.Button("Run Inference", variant="primary", size="lg")
                video_out = gr.Video(label="Overlay Video")

            # ── Right panel: outputs ───────────────────────────────────────────
            with gr.Column(scale=2):
                with gr.Row():
                    stats_box = gr.Textbox(
                        label="Inference Summary",
                        lines=4,
                        interactive=False,
                    )
                    seg_file = gr.File(label="Download Segmentation Mask (.nii.gz)")
                with gr.Row():
                    img_centroid = gr.Image(label="Centroid Slice", type="filepath")
                    img_max_area = gr.Image(label="Max Area Slice", type="filepath")
                with gr.Row():
                    img_ext25 = gr.Image(label="Extent 25% Slice", type="filepath")
                    img_ext75 = gr.Image(label="Extent 75% Slice", type="filepath")

        load_btn_1.click(fn=lambda ct: load_example(ct, 1), inputs=[cancer_type], outputs=[input_file])
        load_btn_2.click(fn=lambda ct: load_example(ct, 2), inputs=[cancer_type], outputs=[input_file])

        # Output order must match the tuple returned by run_inference.
        run_btn.click(
            fn=run_inference,
            inputs=[input_file, cancer_type, fps],
            outputs=[
                stats_box,
                seg_file,
                img_centroid,
                img_max_area,
                img_ext25,
                img_ext75,
                video_out,
            ],
        )

    return demo


if __name__ == "__main__":
    import os

    demo = build_ui()
    # Hugging Face Spaces expect the app on port 7860 (set via GRADIO_SERVER_PORT).
    # Locally this falls back to 7860 unless overridden.
    port = int(os.environ.get("GRADIO_SERVER_PORT", 7860))
    demo.launch(
        server_name="0.0.0.0",
        server_port=port,
        share=False,
        theme=gr.themes.Soft(),
        ssr_mode=False,
    )