| """Gradio UI for PanCancerSeg single-case CT tumour segmentation.""" |
|
|
| import shutil |
| import tempfile |
| from pathlib import Path |
|
|
| import gradio as gr |
|
|
| from predict import ( |
| CANCER_CONFIGS, |
| install_custom_trainer, |
| resolve_case_id, |
| resolve_model_folder, |
| run_nnunet_prediction, |
| summarize_segmentation, |
| ) |
| from visualize import generate_outputs |
|
|
| |
|
|
# Maps the label shown in the UI dropdown to the key used in
# predict.CANCER_CONFIGS. Insertion order is the dropdown order.
CANCER_TYPE_CHOICES = dict(
    [
        ("Kidney Cancer", "kidney_cancer"),
        ("Liver Cancer", "liver_cancer"),
        ("Pancreatic Cancer", "pancreatic_cancer"),
        ("Lung Cancer", "lung_cancer"),
    ]
)

# Optional local checkout of the weights, looked up next to this file.
DEFAULT_MODEL_DIR = str(Path(__file__).parent.joinpath("PanCancerSeg-Specialized-weights"))

# Preferred torch device; the caller falls back to "cpu" when CUDA is absent.
DEFAULT_DEVICE = "cuda"

# Hugging Face Hub repository used when no local checkout is present.
MODEL_REPO_ID = "KS987/PanCancerSeg-Specialized-weights"
|
|
|
|
# Memoized result of resolve_weights_dir(); stays None until the first
# successful resolution so that a failed download is retried next time.
_RESOLVED_WEIGHTS_DIR = None


def resolve_weights_dir() -> Path:
    """Return a directory containing the DatasetXXX_* model folders.

    Prefer a local checkout at ``DEFAULT_MODEL_DIR`` (fast local dev) when it
    contains at least one ``Dataset*`` entry. Otherwise download only the
    ``Dataset*/**`` files of ``MODEL_REPO_ID`` from the Hugging Face Hub
    (``snapshot_download`` stores them in the local HF cache).

    The resolved directory is memoized for the lifetime of the process. As a
    result, repeated inference calls do not make another Hub round-trip each
    time.

    Returns:
        Path to the directory holding the model folders.

    Raises:
        Whatever ``snapshot_download`` raises when the download fails (the
        result is not memoized in that case).
    """
    global _RESOLVED_WEIGHTS_DIR
    if _RESOLVED_WEIGHTS_DIR is not None:
        return _RESOLVED_WEIGHTS_DIR

    local_dir = Path(DEFAULT_MODEL_DIR).expanduser().resolve()
    if local_dir.is_dir() and any(local_dir.glob("Dataset*")):
        resolved = local_dir
    else:
        # Imported lazily so local-only runs do not need huggingface_hub.
        from huggingface_hub import snapshot_download

        downloaded = snapshot_download(
            repo_id=MODEL_REPO_ID,
            repo_type="model",
            allow_patterns=["Dataset*/**"],
        )
        resolved = Path(downloaded)

    _RESOLVED_WEIGHTS_DIR = resolved
    return resolved
|
|
# Root folder of the bundled example CT volumes (one sub-folder per organ).
_SAMPLE_DIR = Path(__file__).parent.joinpath("sample_input")

# Maps a UI cancer-type label to its sub-folder name under _SAMPLE_DIR.
_CANCER_TYPE_TO_FOLDER = dict(
    [
        ("Kidney Cancer", "kidney"),
        ("Liver Cancer", "liver"),
        ("Pancreatic Cancer", "pancreas"),
        ("Lung Cancer", "lung"),
    ]
)
|
|
def load_example(cancer_type_label: str, index: int) -> str:
    """Return the path of the ``index``-th (1-based) example for a cancer type.

    Examples are the ``*_0000.nii.gz`` files in the cancer type's sub-folder of
    ``_SAMPLE_DIR``, taken in sorted filename order.

    Args:
        cancer_type_label: UI label, a key of ``_CANCER_TYPE_TO_FOLDER``.
        index: 1-based position of the example to load.

    Returns:
        The example file path as a string.

    Raises:
        gr.Error: If the label is unknown, ``index`` is < 1, or fewer than
            ``index`` examples exist.
    """
    folder_name = _CANCER_TYPE_TO_FOLDER.get(cancer_type_label)
    if folder_name is None:
        raise gr.Error(f"Unknown cancer type: {cancer_type_label!r}")

    folder = _SAMPLE_DIR / folder_name
    files = sorted(folder.glob("*_0000.nii.gz"))
    # Bound-check both ends: without the lower bound, index 0 would make
    # files[index - 1] silently wrap around to the last example.
    if not 1 <= index <= len(files):
        raise gr.Error(f"Example {index} not found for {cancer_type_label} in {folder}")
    return str(files[index - 1])
|
|
|
|
| |
|
|
def _path_or_none(path):
    """Return ``str(path)``, or ``None`` when ``path`` is ``None``.

    Gradio output components treat ``None`` as "no value". By contrast,
    ``str(None)`` is the literal string ``"None"``, which an image component
    would try to open as a file.
    """
    return None if path is None else str(path)


def run_inference(
    input_file,
    cancer_type_label,
    fps,
    progress=gr.Progress(track_tqdm=True),
):
    """Segment the tumour in one CT volume and build the UI outputs.

    Args:
        input_file: Filesystem path of the uploaded/selected ``.nii.gz`` CT
            image, or ``None`` when nothing has been provided.
        cancer_type_label: UI label; a key of ``CANCER_TYPE_CHOICES``.
        fps: Overlay-video frame rate (converted with ``int()``).
        progress: Gradio progress tracker, injected by Gradio at call time.

    Returns:
        A 7-tuple ``(stats, seg_path, centroid, max_area, extent25, extent75,
        video)``. Path entries are strings. A slice image missing from the
        visualisation result, or a missing/empty video, is returned as ``None``.

    Raises:
        gr.Error: For missing/invalid input, or any failure while obtaining
            weights, predicting, or visualising. When the failure happens after
            the per-call output directory is created, that directory is deleted.
    """
    import torch

    if input_file is None:
        raise gr.Error("Please upload a .nii.gz CT image first.")

    input_path = Path(input_file)
    if not input_path.name.endswith(".nii.gz"):
        raise gr.Error(f"File must be .nii.gz format. Got: {input_path.name}")

    device = DEFAULT_DEVICE if torch.cuda.is_available() else "cpu"

    progress(0.02, desc="Resolving model weights...")
    try:
        model_dir_path = resolve_weights_dir()
    except Exception as e:
        raise gr.Error(f"Failed to obtain model weights from '{MODEL_REPO_ID}': {e}") from e

    # Per-call output directory. On success it is deliberately kept: the
    # returned paths point into it and Gradio serves those files afterwards.
    output_dir = Path(tempfile.mkdtemp(prefix="pancancerseg_out_"))

    try:
        cancer_key = CANCER_TYPE_CHOICES[cancer_type_label]
        config = CANCER_CONFIGS[cancer_key]
        case_id = resolve_case_id(input_path)

        progress(0.05, desc="Installing custom trainer...")
        install_custom_trainer()

        progress(0.10, desc="Loading model weights...")
        model_folder = resolve_model_folder(model_dir_path, config["dataset_name"])

        with tempfile.TemporaryDirectory(prefix="pancancerseg_in_") as tmp:
            tmp_path = Path(tmp)
            tmp_input_dir = tmp_path / "input"
            tmp_output_dir = tmp_path / "prediction"
            tmp_input_dir.mkdir()
            tmp_output_dir.mkdir()

            # The input file is placed as "<case_id>_0000.nii.gz". A symlink is
            # preferred to avoid copying a large volume; the code falls back to
            # a copy where symlinks are unavailable.
            nnunet_input = tmp_input_dir / f"{case_id}_0000.nii.gz"
            try:
                nnunet_input.symlink_to(input_path.resolve())
            except (OSError, NotImplementedError):
                shutil.copy2(input_path, nnunet_input)

            progress(0.20, desc="Running nnUNet inference (this may take a few minutes)...")
            run_nnunet_prediction(
                model_folder=model_folder,
                input_dir=tmp_input_dir,
                output_dir=tmp_output_dir,
                device=device,
            )

            raw_seg = tmp_output_dir / f"{case_id}.nii.gz"
            if not raw_seg.exists():
                produced = [p.name for p in tmp_output_dir.glob("*.nii.gz")]
                raise RuntimeError(
                    f"nnUNet did not produce the expected segmentation. Found: {produced}"
                )

            # Copy the mask out of the temporary directory, which is removed on exit.
            seg_path = output_dir / f"{case_id}_seg.nii.gz"
            shutil.copy2(raw_seg, seg_path)

            progress(0.80, desc="Generating slice images and overlay video...")
            viz = generate_outputs(
                image_path=input_path,
                mask_path=seg_path,
                output_dir=output_dir,
                case_name=case_id,
                cancer_type=config["display_name"],
                wl=config["wl"],
                ww=config["ww"],
                color=config["color"],
                alpha=0.5,
                fps=int(fps),
            )

            progress(0.95, desc="Computing tumour volume...")
            positive_voxels, tumor_volume_ml = summarize_segmentation(seg_path)

            stats = (
                f"Case ID : {case_id}\n"
                f"Cancer type : {config['display_name']}\n"
                f"Positive voxels: {positive_voxels:,}\n"
                f"Tumour volume : {tumor_volume_ml:.3f} mL"
            )

            slices = viz["slices"]
            # Only hand a video to the UI when a non-empty file was written.
            video_path = viz["video"]
            video_file = Path(video_path) if video_path is not None else None
            video_out = (
                str(video_file)
                if video_file is not None
                and video_file.is_file()
                and video_file.stat().st_size > 0
                else None
            )

            progress(1.0, desc="Done!")
            return (
                stats,
                str(seg_path),
                _path_or_none(slices.get("centroid")),
                _path_or_none(slices.get("max_area")),
                _path_or_none(slices.get("extent25")),
                _path_or_none(slices.get("extent75")),
                video_out,
            )

    except gr.Error:
        # Already user-facing; just clean up and propagate unchanged.
        shutil.rmtree(output_dir, ignore_errors=True)
        raise
    except Exception as e:
        shutil.rmtree(output_dir, ignore_errors=True)
        # Some exceptions have an empty message; fall back to the type name so
        # the UI never shows a blank error.
        raise gr.Error(str(e) or type(e).__name__) from e
|
|
|
|
| |
|
|
def build_ui():
    """Build the Gradio Blocks app for single-case CT tumour segmentation.

    Layout:
        Left column: the input file, cancer-type dropdown, video-FPS slider,
        two "Load Example" buttons, the run button and the overlay video.
        Right column: the text summary, the segmentation download and the four
        representative slice images.

    Returns:
        The ``gr.Blocks`` app, not yet launched.
    """
    with gr.Blocks(title="PanCancerSeg Inference") as demo:
        # Title text: the separator is an em dash (it was previously mojibake
        # rendered as a Greek beta).
        gr.Markdown(
            """
# PanCancerSeg — Specialist CT Tumour Segmentation
Upload a `.nii.gz` CT image, select the cancer type, and click **Run Inference** to obtain
a segmentation mask and visualisations.
"""
        )

        with gr.Row():
            # Left column: inputs and controls.
            with gr.Column(scale=1, min_width=300):
                input_file = gr.File(
                    label="CT Image (.nii.gz)",
                    file_types=[".gz"],
                )
                cancer_type = gr.Dropdown(
                    choices=list(CANCER_TYPE_CHOICES.keys()),
                    value="Kidney Cancer",
                    label="Cancer Type",
                )
                fps = gr.Slider(
                    minimum=1,
                    maximum=30,
                    value=10,
                    step=1,
                    label="Video FPS",
                )
                with gr.Row():
                    load_btn_1 = gr.Button("Load Example 1", size="lg")
                    load_btn_2 = gr.Button("Load Example 2", size="lg")
                run_btn = gr.Button("Run Inference", variant="primary", size="lg")
                video_out = gr.Video(label="Overlay Video")

            # Right column: results.
            with gr.Column(scale=2):
                with gr.Row():
                    stats_box = gr.Textbox(
                        label="Inference Summary",
                        lines=4,
                        interactive=False,
                    )
                    seg_file = gr.File(label="Download Segmentation Mask (.nii.gz)")
                with gr.Row():
                    img_centroid = gr.Image(label="Centroid Slice", type="filepath")
                    img_max_area = gr.Image(label="Max Area Slice", type="filepath")
                with gr.Row():
                    img_ext25 = gr.Image(label="Extent 25% Slice", type="filepath")
                    img_ext75 = gr.Image(label="Extent 75% Slice", type="filepath")

        # Each example button fills the file input with the N-th bundled sample
        # for the currently selected cancer type.
        load_btn_1.click(fn=lambda ct: load_example(ct, 1), inputs=[cancer_type], outputs=[input_file])
        load_btn_2.click(fn=lambda ct: load_example(ct, 2), inputs=[cancer_type], outputs=[input_file])

        # Output order must match the tuple returned by run_inference.
        run_btn.click(
            fn=run_inference,
            inputs=[input_file, cancer_type, fps],
            outputs=[
                stats_box,
                seg_file,
                img_centroid,
                img_max_area,
                img_ext25,
                img_ext75,
                video_out,
            ],
        )

    return demo
|
|
|
|
if __name__ == "__main__":
    import os

    demo = build_ui()

    # Bind address and port can be overridden through Gradio's standard
    # environment variables. The defaults (all interfaces, port 7860) match
    # the previous hard-coded behaviour.
    host = os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0")
    port = int(os.environ.get("GRADIO_SERVER_PORT", 7860))
    demo.launch(
        server_name=host,
        server_port=port,
        share=False,
        theme=gr.themes.Soft(),
        ssr_mode=False,
    )
|
|