| """Run single-case PanCancerSeg nnUNet CT inference and visualization.""" |
|
|
| import argparse |
| import shutil |
| import tempfile |
| from pathlib import Path |
|
|
| import numpy as np |
| import SimpleITK as sitk |
| import torch |
|
|
| from visualize import generate_outputs |
|
|
|
|
# Per-cancer model and visualization settings, keyed by canonical cancer type.
#   dataset_id / dataset_name : identify the nnUNet dataset folder under --model_dir
#   display_name              : human-readable label handed to the visualizer
#   wl / ww                   : forwarded to generate_outputs as wl / ww
#                               (presumably CT window level / width in HU -- confirm)
#   color                     : overlay color forwarded to generate_outputs
#                               (presumably RGB -- confirm against visualize.py)
CANCER_CONFIGS = {
    "kidney_cancer": dict(
        dataset_id=102,
        dataset_name="Dataset102_Kidney",
        display_name="Kidney cancer",
        wl=40,
        ww=400,
        color=(255, 0, 0),
    ),
    "liver_cancer": dict(
        dataset_id=103,
        dataset_name="Dataset103_Liver",
        display_name="Liver cancer",
        wl=40,
        ww=400,
        color=(255, 0, 0),
    ),
    "pancreatic_cancer": dict(
        dataset_id=104,
        dataset_name="Dataset104_Pancreas",
        display_name="Pancreatic cancer",
        wl=40,
        ww=400,
        color=(255, 0, 0),
    ),
    # Lung is the only entry with a different wl / ww pair.
    "lung_cancer": dict(
        dataset_id=105,
        dataset_name="Dataset105_Lung",
        display_name="Lung cancer",
        wl=-600,
        ww=1500,
        color=(255, 0, 0),
    ),
}
|
|
# Legacy short names still accepted on the command line, mapped to the
# canonical CANCER_CONFIGS key they stand for.
CANCER_TYPE_ALIASES = dict(
    kidney="kidney_cancer",
    liver="liver_cancer",
    pancreas="pancreatic_cancer",
    lung="lung_cancer",
)
|
|
# Names that locate the trained weights inside each dataset folder:
#   <model_dir>/<dataset_name>/<TRAINER>__<PLANS>__<CONFIGURATION>/fold_0/<CHECKPOINT>
# TRAINER_NAME is also the stem of the custom trainer file under ./trainers.
TRAINER_NAME = "nnUNetTrainerWandb2000"
PLANS_NAME = "nnUNetResEncUNetMPlans"
CONFIGURATION = "3d_fullres"
CHECKPOINT_NAME = "checkpoint_best.pth"
|
|
|
|
def parse_args():
    """Build the command-line interface and parse ``sys.argv``.

    Returns:
        argparse.Namespace with ``input``, ``cancer_type``, ``model_dir``,
        ``output_dir``, ``fps`` (int) and ``device`` ("cuda" or "cpu").
        ``cancer_type`` is returned exactly as typed; it is not validated here.
    """
    cli = argparse.ArgumentParser(
        description="Run one PanCancerSeg cancer-specific nnUNet model on a single NIfTI image."
    )
    add = cli.add_argument

    # The help text lists whatever keys the module-level tables currently hold.
    canonical_names = ", ".join(sorted(CANCER_CONFIGS))
    alias_names = ", ".join(sorted(CANCER_TYPE_ALIASES))
    cancer_type_help = (
        "Cancer-specific model to use. "
        f"Canonical values: {canonical_names}. "
        f"Legacy aliases still accepted: {alias_names}."
    )

    add("--input", required=True, help="Path to a single .nii.gz image")
    add("--cancer_type", required=True, help=cancer_type_help)
    add(
        "--model_dir",
        required=True,
        help="Path to nnUNet results directory containing DatasetXXX_* folders",
    )
    add("--output_dir", default="./output", help="Where to save results")
    add("--fps", type=int, default=10, help="Video frames per second")
    add("--device", choices=["cuda", "cpu"], default="cuda")

    return cli.parse_args()
|
|
|
|
def main():
    """Run the full single-case pipeline: validate, predict, visualize, report.

    Steps, in order:
      1. Parse the CLI and map the cancer type to its canonical key.
      2. Validate the input image, model directory, device and fps.
      3. Install the custom trainer and locate the trained model folder.
      4. Stage the image in a temporary folder under the nnUNet ``_0000``
         channel naming, run prediction, and copy the mask to the output
         directory as ``<case_id>_seg.nii.gz``.
      5. Generate the overlay outputs and print a summary.

    Raises:
        FileNotFoundError: input image, model directory, checkpoint or the
            expected prediction file is missing.
        ValueError: input is not a ``.nii.gz`` file, cancer type is unknown,
            or ``--fps`` is not positive.
        RuntimeError: CUDA was requested but is unavailable.
    """
    args = parse_args()
    args.cancer_type = normalize_cancer_type(args.cancer_type)

    # All three user-supplied locations get the same normalization.
    image_file, weights_root, results_root = (
        Path(raw).expanduser().resolve()
        for raw in (args.input, args.model_dir, args.output_dir)
    )

    # --- validation (order determines which error the user sees first) ---
    if not image_file.exists():
        raise FileNotFoundError(f"Input image not found: {image_file}")

    file_name = image_file.name
    # Names starting with "._" are rejected even when they end in .nii.gz.
    if file_name.startswith("._") or not file_name.endswith(".nii.gz"):
        raise ValueError(f"Expected a .nii.gz image, got: {file_name}")

    if not weights_root.exists():
        raise FileNotFoundError(f"Model directory not found: {weights_root}")

    wants_cuda = args.device == "cuda"
    if wants_cuda and not torch.cuda.is_available():
        raise RuntimeError(
            "CUDA was requested but torch.cuda.is_available() is False. "
            "Use --device cpu or install CUDA-enabled PyTorch."
        )

    if args.fps <= 0:
        raise ValueError("--fps must be a positive integer")

    # --- setup ---
    results_root.mkdir(parents=True, exist_ok=True)
    config = CANCER_CONFIGS[args.cancer_type]
    case_id = resolve_case_id(image_file)

    install_custom_trainer()
    model_folder = resolve_model_folder(weights_root, config["dataset_name"])

    # --- prediction in a throwaway working directory ---
    with tempfile.TemporaryDirectory(prefix="pancancerseg_") as scratch:
        staging_dir = Path(scratch) / "input"
        prediction_dir = Path(scratch) / "prediction"
        staging_dir.mkdir()
        prediction_dir.mkdir()

        # The staged file carries the "_0000" suffix; the prediction is
        # expected back as "<case_id>.nii.gz".
        symlink_or_copy(image_file, staging_dir / f"{case_id}_0000.nii.gz")

        run_nnunet_prediction(
            model_folder=model_folder,
            input_dir=staging_dir,
            output_dir=prediction_dir,
            device=args.device,
        )

        predicted_mask = prediction_dir / f"{case_id}.nii.gz"
        if not predicted_mask.exists():
            found_names = [p.name for p in sorted(prediction_dir.glob("*.nii.gz"))]
            raise FileNotFoundError(
                f"nnUNet did not write the expected segmentation {predicted_mask}. "
                f"Found: {found_names}"
            )

        # Persist the mask before the temporary directory is removed.
        seg_path = results_root / f"{case_id}_seg.nii.gz"
        shutil.copy2(predicted_mask, seg_path)

    # --- visualization and report ---
    viz_outputs = generate_outputs(
        image_path=image_file,
        mask_path=seg_path,
        output_dir=results_root,
        case_name=case_id,
        cancer_type=config["display_name"],
        wl=config["wl"],
        ww=config["ww"],
        color=config["color"],
        alpha=0.5,
        fps=args.fps,
    )

    voxel_count, volume_ml = summarize_segmentation(seg_path)
    print_summary(seg_path, viz_outputs, voxel_count, volume_ml)
|
|
|
|
def resolve_case_id(input_path):
    """Derive the case identifier from a NIfTI file name.

    The ``.nii.gz`` extension is removed, followed by one trailing ``_0000``
    tag if present.

    Args:
        input_path: path-like object exposing ``.name`` (e.g. ``pathlib.Path``).

    Returns:
        The non-empty case identifier string.

    Raises:
        ValueError: the name does not end in ``.nii.gz``, or nothing is left
            once the extension and the ``_0000`` tag are removed.
    """
    extension = ".nii.gz"
    channel_tag = "_0000"

    file_name = input_path.name
    if not file_name.endswith(extension):
        raise ValueError(f"Expected a .nii.gz image, got: {file_name}")

    stem = file_name[: len(file_name) - len(extension)]
    # Only a single trailing tag is removed; "a_0000_0000" becomes "a_0000".
    if stem.endswith(channel_tag):
        stem = stem[: len(stem) - len(channel_tag)]

    if stem == "":
        raise ValueError(f"Could not resolve a case ID from: {input_path}")
    return stem
|
|
|
|
def normalize_cancer_type(cancer_type):
    """Map a user-supplied cancer type to its canonical ``CANCER_CONFIGS`` key.

    The value is stripped and lower-cased, then translated through
    ``CANCER_TYPE_ALIASES`` when it is a legacy alias.

    Args:
        cancer_type: raw string from the command line.

    Returns:
        A key of ``CANCER_CONFIGS``.

    Raises:
        ValueError: the cleaned value is neither a canonical key nor an alias.
            The message quotes the cleaned value, not the raw input.
    """
    cleaned = cancer_type.strip().lower()
    canonical = CANCER_TYPE_ALIASES.get(cleaned, cleaned)
    if canonical in CANCER_CONFIGS:
        return canonical

    accepted = sorted([*CANCER_CONFIGS, *CANCER_TYPE_ALIASES])
    raise ValueError(
        f"Unsupported --cancer_type '{cleaned}'. Valid values: {', '.join(accepted)}"
    )
|
|
|
|
def install_custom_trainer():
    """Make the bundled custom trainer visible inside the nnunetv2 package.

    Links (or, where symlinking fails, copies) ``trainers/<TRAINER_NAME>.py``
    from next to this script into nnunetv2's
    ``training/nnUNetTrainer/variants`` directory. An existing entry that
    already resolves to the bundled file is left alone; anything else at that
    path is replaced.

    Returns:
        Path of the installed trainer file inside the nnunetv2 package.

    Raises:
        FileNotFoundError: the bundled trainer file does not exist.
    """
    # Imported lazily so the module can be loaded without nnunetv2 installed.
    import nnunetv2

    bundled = Path(__file__).resolve().parent / "trainers" / f"{TRAINER_NAME}.py"
    if not bundled.exists():
        raise FileNotFoundError(f"Custom trainer file is missing: {bundled}")

    package_root = Path(nnunetv2.__path__[0])
    variants = package_root / "training" / "nnUNetTrainer" / "variants"
    variants.mkdir(parents=True, exist_ok=True)
    installed = variants / bundled.name

    # is_symlink() also catches a dangling link, for which exists() is False.
    if installed.exists() or installed.is_symlink():
        up_to_date = False
        try:
            up_to_date = installed.resolve() == bundled.resolve()
        except OSError:
            # Unresolvable entry: treat as stale and replace it below.
            pass
        if up_to_date:
            return installed
        installed.unlink()

    try:
        installed.symlink_to(bundled.resolve())
    except (OSError, NotImplementedError):
        # Symlinks unavailable or not permitted here; fall back to a copy.
        shutil.copy2(bundled, installed)

    print(f"Installed custom trainer: {installed}")
    return installed
|
|
|
|
def resolve_model_folder(model_dir, dataset_name):
    """Locate the trained model folder for a dataset and verify its weights.

    Args:
        model_dir: ``pathlib.Path`` of the nnUNet results directory.
        dataset_name: dataset folder name, e.g. ``"Dataset102_Kidney"``.

    Returns:
        ``model_dir / dataset_name / <TRAINER>__<PLANS>__<CONFIGURATION>``.

    Raises:
        FileNotFoundError: ``fold_0/<CHECKPOINT_NAME>`` is missing in that folder.
    """
    run_name = "__".join((TRAINER_NAME, PLANS_NAME, CONFIGURATION))
    folder = model_dir / dataset_name / run_name

    # Only fold 0 is checked.
    weights = folder.joinpath("fold_0", CHECKPOINT_NAME)
    if weights.exists():
        return folder

    raise FileNotFoundError(
        f"Expected checkpoint not found: {weights}. "
        "Check --model_dir and make sure the trained weights are downloaded."
    )
|
|
|
|
def symlink_or_copy(src, dst):
    """Expose ``src`` at ``dst``, preferring a symlink over a copy.

    Args:
        src: ``pathlib.Path`` of the existing file.
        dst: ``pathlib.Path`` where the link (or copy) is created.

    A symlink to the resolved ``src`` is attempted first; if that raises
    ``OSError`` or ``NotImplementedError`` the file is copied with
    ``shutil.copy2`` instead.
    """
    try:
        link_target = src.resolve()
        dst.symlink_to(link_target)
    except (OSError, NotImplementedError):
        shutil.copy2(src, dst)
|
|
|
|
def run_nnunet_prediction(model_folder, input_dir, output_dir, device):
    """Run nnUNet inference on every image in ``input_dir``.

    Args:
        model_folder: trained model folder (the one containing ``fold_0``).
        input_dir: folder holding the staged input image(s).
        output_dir: folder that receives the predicted segmentation(s).
        device: ``"cuda"`` or ``"cpu"``; passed to ``torch.device``.

    Only fold 0 with ``CHECKPOINT_NAME`` is loaded. Prediction runs with
    mirroring disabled, no probability export, and a single worker each for
    preprocessing and segmentation export.
    """
    # Imported lazily so the module can be loaded without nnunetv2 installed.
    from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor

    predictor_options = dict(
        tile_step_size=0.5,
        use_gaussian=True,
        use_mirroring=False,
        # Keep intermediate work on the device only when running on the GPU.
        perform_everything_on_device=(device == "cuda"),
        device=torch.device(device),
        verbose=False,
        verbose_preprocessing=False,
        allow_tqdm=True,
    )
    export_options = dict(
        save_probabilities=False,
        overwrite=True,
        num_processes_preprocessing=1,
        num_processes_segmentation_export=1,
        folder_with_segs_from_prev_stage=None,
        num_parts=1,
        part_id=0,
    )

    predictor = nnUNetPredictor(**predictor_options)
    predictor.initialize_from_trained_model_folder(
        str(model_folder),
        use_folds=(0,),
        checkpoint_name=CHECKPOINT_NAME,
    )
    predictor.predict_from_files(str(input_dir), str(output_dir), **export_options)
|
|
|
|
def summarize_segmentation(seg_path):
    """Count foreground voxels in a mask and convert the count to a volume.

    Args:
        seg_path: path to the segmentation image, readable by SimpleITK.

    Returns:
        ``(positive_voxels, tumor_volume_ml)``: the number of nonzero voxels
        (int) and that count times the three spacing values, divided by 1000.

    Every nonzero label counts as foreground.
    NOTE(review): the "/ 1000" conversion to mL assumes the image spacing is
    in millimetres -- confirm for the data in use.
    """
    mask_image = sitk.ReadImage(str(seg_path))
    foreground_count = int(np.count_nonzero(sitk.GetArrayFromImage(mask_image)))

    # Unpacking into exactly three names requires a 3-D image.
    dx, dy, dz = mask_image.GetSpacing()
    volume_mm3 = foreground_count * dx * dy * dz

    return foreground_count, volume_mm3 / 1000.0
|
|
|
|
def print_summary(seg_path, viz_outputs, positive_voxels, tumor_volume_ml):
    """Print a human-readable report of the pipeline outputs to stdout.

    Args:
        seg_path: location of the saved segmentation mask.
        viz_outputs: mapping with a ``"slices"`` entry (label -> image path,
            labels must be strings) and a ``"video"`` entry.
        positive_voxels: number of foreground voxels.
        tumor_volume_ml: volume, printed with three decimals.
    """

    def report_lines():
        # Lines are produced lazily so each one is printed before the next
        # is formatted, as with a plain sequence of print() calls.
        yield "\nPanCancerSeg inference complete"
        yield f"Segmentation mask : {seg_path}"
        yield "Slice PNGs :"
        for label, path in viz_outputs["slices"].items():
            yield f" {label:9s} : {path}"
        yield f"Overlay video : {viz_outputs['video']}"
        yield f"Positive voxels : {positive_voxels}"
        yield f"Tumor volume : {tumor_volume_ml:.3f} mL"

    for line in report_lines():
        print(line)
|
|
|
|
# Run the pipeline only when this file is executed as a script, so importing
# the module has no side effects.
if __name__ == "__main__":
    main()
|
|