# Source: repository commit ce1057b (verified), uploaded by WorkTimer:
# "Add PanCancerSeg Gradio inference app"
"""Run single-case PanCancerSeg nnUNet CT inference and visualization."""
import argparse
import shutil
import tempfile
from pathlib import Path
import numpy as np
import SimpleITK as sitk
import torch
from visualize import generate_outputs
# Values shared by several entries below. "wl"/"ww" are forwarded unchanged to
# generate_outputs(); presumably CT window level / window width - confirm
# against visualize.py.
_SHARED_WINDOW = {"wl": 40, "ww": 400}
_LUNG_WINDOW = {"wl": -600, "ww": 1500}
# RGB triple passed to generate_outputs() as the overlay colour.
_OVERLAY_COLOR = (255, 0, 0)

# Per-cancer settings keyed by the canonical --cancer_type value.
# Each entry names the nnUNet dataset folder to load and the display
# parameters handed to the visualisation step.
CANCER_CONFIGS = {
    "kidney_cancer": dict(
        dataset_id=102,
        dataset_name="Dataset102_Kidney",
        display_name="Kidney cancer",
        **_SHARED_WINDOW,
        color=_OVERLAY_COLOR,
    ),
    "liver_cancer": dict(
        dataset_id=103,
        dataset_name="Dataset103_Liver",
        display_name="Liver cancer",
        **_SHARED_WINDOW,
        color=_OVERLAY_COLOR,
    ),
    "pancreatic_cancer": dict(
        dataset_id=104,
        dataset_name="Dataset104_Pancreas",
        display_name="Pancreatic cancer",
        **_SHARED_WINDOW,
        color=_OVERLAY_COLOR,
    ),
    "lung_cancer": dict(
        dataset_id=105,
        dataset_name="Dataset105_Lung",
        display_name="Lung cancer",
        **_LUNG_WINDOW,
        color=_OVERLAY_COLOR,
    ),
}
# Short legacy spellings of --cancer_type mapped to their canonical
# CANCER_CONFIGS key.
CANCER_TYPE_ALIASES = dict(
    kidney="kidney_cancer",
    liver="liver_cancer",
    pancreas="pancreatic_cancer",
    lung="lung_cancer",
)
# Fixed identifiers used for every cancer-specific model.
# TRAINER_NAME is both the stem of the trainer file looked up under
# trainers/ and the first component of the results sub-folder name.
TRAINER_NAME = "nnUNetTrainerWandb2000"
# PLANS_NAME and CONFIGURATION are joined with TRAINER_NAME (separator "__")
# to form the results sub-folder inside each dataset directory.
PLANS_NAME = "nnUNetResEncUNetMPlans"
CONFIGURATION = "3d_fullres"
# Checkpoint file expected under fold_0/ and passed to the predictor.
CHECKPOINT_NAME = "checkpoint_best.pth"
def parse_args():
    """Define the command-line interface and parse the process arguments.

    Returns:
        argparse.Namespace: holds ``input``, ``cancer_type``, ``model_dir``,
        ``output_dir``, ``fps`` and ``device``. ``cancer_type`` is returned as
        typed; no alias resolution or ``fps`` range check is done here.
    """
    canonical_names = ", ".join(sorted(CANCER_CONFIGS))
    alias_names = ", ".join(sorted(CANCER_TYPE_ALIASES))
    cancer_type_help = (
        "Cancer-specific model to use. "
        f"Canonical values: {canonical_names}. "
        f"Legacy aliases still accepted: {alias_names}."
    )

    cli = argparse.ArgumentParser(
        description="Run one PanCancerSeg cancer-specific nnUNet model on a single NIfTI image."
    )

    # Required arguments.
    cli.add_argument("--input", required=True, help="Path to a single .nii.gz image")
    cli.add_argument("--cancer_type", required=True, help=cancer_type_help)
    cli.add_argument(
        "--model_dir",
        required=True,
        help="Path to nnUNet results directory containing DatasetXXX_* folders",
    )

    # Optional arguments with defaults.
    cli.add_argument("--output_dir", default="./output", help="Where to save results")
    cli.add_argument("--fps", type=int, default=10, help="Video frames per second")
    cli.add_argument("--device", choices=["cuda", "cpu"], default="cuda")

    return cli.parse_args()
def main():
    """Run the single-case pipeline end to end.

    Steps: parse and normalise the CLI arguments, validate paths, device and
    fps, install the custom trainer, stage the image as
    ``<case>_0000.nii.gz`` in a temporary directory, run prediction, copy the
    mask to ``<output_dir>/<case>_seg.nii.gz``, render the visualisations and
    print a summary.

    Raises:
        FileNotFoundError: the input image, the model directory, or the
            expected prediction file is missing.
        ValueError: the input is a ``._`` file or does not end in
            ``.nii.gz``, or ``--fps`` is not positive.
        RuntimeError: CUDA was requested but is not available.
    """
    args = parse_args()
    args.cancer_type = normalize_cancer_type(args.cancer_type)

    input_path, model_dir, output_dir = (
        Path(raw).expanduser().resolve()
        for raw in (args.input, args.model_dir, args.output_dir)
    )

    # Validation: the order of these checks decides which error is reported
    # first, so it is kept as is.
    if not input_path.exists():
        raise FileNotFoundError(f"Input image not found: {input_path}")

    filename = input_path.name
    is_nifti_image = filename.endswith(".nii.gz") and not filename.startswith("._")
    if not is_nifti_image:
        raise ValueError(f"Expected a .nii.gz image, got: {input_path.name}")

    if not model_dir.exists():
        raise FileNotFoundError(f"Model directory not found: {model_dir}")

    cuda_requested = args.device == "cuda"
    if cuda_requested and not torch.cuda.is_available():
        raise RuntimeError(
            "CUDA was requested but torch.cuda.is_available() is False. "
            "Use --device cpu or install CUDA-enabled PyTorch."
        )

    if args.fps <= 0:
        raise ValueError("--fps must be a positive integer")

    output_dir.mkdir(parents=True, exist_ok=True)

    config = CANCER_CONFIGS[args.cancer_type]
    case_id = resolve_case_id(input_path)
    install_custom_trainer()
    model_folder = resolve_model_folder(model_dir, config["dataset_name"])

    # Prediction runs in a throwaway directory; only the final mask is kept.
    with tempfile.TemporaryDirectory(prefix="pancancerseg_") as workdir:
        staging_dir = Path(workdir) / "input"
        prediction_dir = Path(workdir) / "prediction"
        staging_dir.mkdir()
        prediction_dir.mkdir()

        staged_image = staging_dir / f"{case_id}_0000.nii.gz"
        symlink_or_copy(input_path, staged_image)

        run_nnunet_prediction(
            model_folder=model_folder,
            input_dir=staging_dir,
            output_dir=prediction_dir,
            device=args.device,
        )

        raw_seg = prediction_dir / f"{case_id}.nii.gz"
        if not raw_seg.exists():
            # List what was actually written to help diagnose the naming mismatch.
            produced = sorted(prediction_dir.glob("*.nii.gz"))
            raise FileNotFoundError(
                f"nnUNet did not write the expected segmentation {raw_seg}. "
                f"Found: {[p.name for p in produced]}"
            )

        # Copy the mask out before the temporary directory is removed.
        seg_path = output_dir / f"{case_id}_seg.nii.gz"
        shutil.copy2(raw_seg, seg_path)

    viz_outputs = generate_outputs(
        image_path=input_path,
        mask_path=seg_path,
        output_dir=output_dir,
        case_name=case_id,
        cancer_type=config["display_name"],
        wl=config["wl"],
        ww=config["ww"],
        color=config["color"],
        alpha=0.5,
        fps=args.fps,
    )

    voxel_count, volume_ml = summarize_segmentation(seg_path)
    print_summary(seg_path, viz_outputs, voxel_count, volume_ml)
def resolve_case_id(input_path):
    """Derive the case identifier from a ``.nii.gz`` file name.

    The ``.nii.gz`` extension is removed, followed by at most one trailing
    ``_0000`` suffix.

    Args:
        input_path: path-like object exposing ``.name`` (e.g. ``pathlib.Path``).

    Returns:
        str: the non-empty case identifier.

    Raises:
        ValueError: the name does not end in ``.nii.gz``, or nothing is left
            once the extension and suffix are removed.
    """
    extension = ".nii.gz"
    channel_suffix = "_0000"

    filename = input_path.name
    if not filename.endswith(extension):
        raise ValueError(f"Expected a .nii.gz image, got: {filename}")

    stem = filename[: len(filename) - len(extension)]
    if stem.endswith(channel_suffix):
        # Only one suffix is stripped; "x_0000_0000" becomes "x_0000".
        stem = stem[: len(stem) - len(channel_suffix)]

    if stem == "":
        raise ValueError(f"Could not resolve a case ID from: {input_path}")
    return stem
def normalize_cancer_type(cancer_type):
    """Map a user-supplied cancer type to its canonical ``CANCER_CONFIGS`` key.

    The value is stripped and lower-cased, then looked up in
    ``CANCER_TYPE_ALIASES``; values that are not aliases pass through as is.

    Args:
        cancer_type: raw ``--cancer_type`` string.

    Returns:
        str: a key of ``CANCER_CONFIGS``.

    Raises:
        ValueError: the value is neither a canonical name nor a known alias.
    """
    key = cancer_type.strip().lower()
    canonical = CANCER_TYPE_ALIASES.get(key, key)
    if canonical in CANCER_CONFIGS:
        return canonical

    accepted = sorted([*CANCER_CONFIGS, *CANCER_TYPE_ALIASES])
    raise ValueError(
        f"Unsupported --cancer_type '{key}'. Valid values: {', '.join(accepted)}"
    )
def install_custom_trainer():
    """Make the bundled trainer module available inside the nnunetv2 package.

    The trainer file ``trainers/<TRAINER_NAME>.py`` next to this script is
    symlinked (or copied, if symlinking fails) into nnunetv2's
    ``training/nnUNetTrainer/variants`` directory. An existing entry that
    already resolves to the bundled file is left untouched; any other
    existing entry is replaced.

    Returns:
        pathlib.Path: location of the trainer inside the nnunetv2 package.

    Raises:
        FileNotFoundError: the bundled trainer file does not exist.
    """
    # Imported lazily so the module can be loaded without nnunetv2 installed.
    import nnunetv2

    here = Path(__file__).resolve().parent
    src = here / "trainers" / f"{TRAINER_NAME}.py"
    if not src.exists():
        raise FileNotFoundError(f"Custom trainer file is missing: {src}")

    package_root = Path(nnunetv2.__path__[0])
    variants_dir = package_root / "training" / "nnUNetTrainer" / "variants"
    variants_dir.mkdir(parents=True, exist_ok=True)
    dst = variants_dir / src.name

    # is_symlink() also catches a dangling link, for which exists() is False.
    if dst.is_symlink() or dst.exists():
        try:
            already_installed = dst.resolve() == src.resolve()
        except OSError:
            # Unresolvable entry: treat it as stale and replace it.
            already_installed = False
        if already_installed:
            return dst
        dst.unlink()

    try:
        dst.symlink_to(src.resolve())
    except (OSError, NotImplementedError):
        # Symlinks unavailable or not permitted here: fall back to a copy.
        shutil.copy2(src, dst)

    print(f"Installed custom trainer: {dst}")
    return dst
def resolve_model_folder(model_dir, dataset_name):
    """Return the trained-model folder for a dataset after checking its weights.

    The folder is
    ``<model_dir>/<dataset_name>/<TRAINER_NAME>__<PLANS_NAME>__<CONFIGURATION>``
    and must contain ``fold_0/<CHECKPOINT_NAME>``.

    Args:
        model_dir: ``pathlib.Path`` of the nnUNet results directory.
        dataset_name: dataset folder name, e.g. ``"Dataset102_Kidney"``.

    Returns:
        pathlib.Path: the model folder (not the checkpoint itself).

    Raises:
        FileNotFoundError: the fold-0 checkpoint file is missing.
    """
    run_name = "__".join((TRAINER_NAME, PLANS_NAME, CONFIGURATION))
    model_folder = model_dir / dataset_name / run_name

    checkpoint = model_folder / "fold_0" / CHECKPOINT_NAME
    if checkpoint.exists():
        return model_folder

    raise FileNotFoundError(
        f"Expected checkpoint not found: {checkpoint}. "
        "Check --model_dir and make sure the trained weights are downloaded."
    )
def symlink_or_copy(src, dst):
    """Make ``dst`` refer to the contents of ``src``, by symlink if possible.

    A symlink to the resolved ``src`` is created; if that fails (for example
    because symlinks are unsupported or not permitted) the file is copied
    with ``shutil.copy2`` instead.

    A pre-existing ``dst`` is handled explicitly. Without this, the symlink
    attempt fails with ``FileExistsError`` and the copy fallback either
    writes *through* a stale symlink into whatever file it points at, or
    raises ``shutil.SameFileError`` when ``dst`` already points at ``src``.

    Args:
        src: ``pathlib.Path`` of the existing source file.
        dst: ``pathlib.Path`` where the link or copy should be created.

    Returns:
        None.
    """
    target = src.resolve()

    # is_symlink() is checked separately so dangling links are caught too.
    # Directories are deliberately left alone.
    if dst.is_symlink() or dst.is_file():
        try:
            if dst.resolve() == target:
                # Already the source itself or a link to it: nothing to do,
                # and unlinking here could delete the source.
                return
        except (OSError, RuntimeError):
            # Unresolvable entry (e.g. a symlink loop): treat it as stale.
            pass
        dst.unlink()

    try:
        dst.symlink_to(target)
    except (OSError, NotImplementedError):
        shutil.copy2(src, dst)
def run_nnunet_prediction(model_folder, input_dir, output_dir, device):
    """Run nnUNet inference on every image in ``input_dir``.

    A predictor is built with fold 0 of the trained model in
    ``model_folder`` (checkpoint ``CHECKPOINT_NAME``), and predictions are
    written to ``output_dir``. Mirroring is disabled, probabilities are not
    saved, existing outputs are overwritten, and preprocessing and export
    each use a single process.

    Args:
        model_folder: path of the trained-model folder.
        input_dir: folder holding the input image(s).
        output_dir: folder that receives the predicted segmentation(s).
        device: ``"cuda"`` or ``"cpu"``; passed to ``torch.device``.
    """
    # Imported lazily so the module can be loaded without nnunetv2 installed.
    from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor

    run_on_gpu = device == "cuda"
    predictor_options = dict(
        tile_step_size=0.5,
        use_gaussian=True,
        use_mirroring=False,
        perform_everything_on_device=run_on_gpu,
        device=torch.device(device),
        verbose=False,
        verbose_preprocessing=False,
        allow_tqdm=True,
    )
    predictor = nnUNetPredictor(**predictor_options)

    predictor.initialize_from_trained_model_folder(
        str(model_folder),
        use_folds=(0,),
        checkpoint_name=CHECKPOINT_NAME,
    )

    source_folder = str(input_dir)
    target_folder = str(output_dir)
    predictor.predict_from_files(
        source_folder,
        target_folder,
        save_probabilities=False,
        overwrite=True,
        num_processes_preprocessing=1,
        num_processes_segmentation_export=1,
        folder_with_segs_from_prev_stage=None,
        num_parts=1,
        part_id=0,
    )
def summarize_segmentation(seg_path):
    """Count the labelled voxels in a segmentation and estimate their volume.

    Args:
        seg_path: path of the segmentation image readable by SimpleITK.

    Returns:
        tuple: ``(positive_voxels, tumor_volume_ml)`` where
        ``positive_voxels`` is the number of non-zero voxels (int) and
        ``tumor_volume_ml`` is that count times the product of the three
        spacing values, divided by 1000.
    """
    mask_image = sitk.ReadImage(str(seg_path))
    mask_array = sitk.GetArrayFromImage(mask_image)

    # Any non-zero label counts as positive.
    voxel_count = int(np.count_nonzero(mask_array))

    # Assumes the spacing is in millimetres, so mm^3 / 1000 gives mL -
    # TODO confirm for non-NIfTI inputs.
    sx, sy, sz = mask_image.GetSpacing()
    volume_ml = voxel_count * sx * sy * sz / 1000.0

    return voxel_count, volume_ml
def print_summary(seg_path, viz_outputs, positive_voxels, tumor_volume_ml):
    """Print a human-readable report of the produced files and tumor size.

    Args:
        seg_path: location of the saved segmentation mask.
        viz_outputs: mapping with a ``"slices"`` mapping (label -> path) and
            a ``"video"`` entry.
        positive_voxels: number of non-zero voxels in the mask.
        tumor_volume_ml: tumor volume; printed with three decimals.
    """

    def report_lines():
        # Lines are produced lazily so each one is printed as soon as it is
        # built, exactly as a sequence of print() calls would do.
        yield "\nPanCancerSeg inference complete"
        yield f"Segmentation mask : {seg_path}"
        yield "Slice PNGs :"
        for label, path in viz_outputs["slices"].items():
            yield f" {label:9s} : {path}"
        yield f"Overlay video : {viz_outputs['video']}"
        yield f"Positive voxels : {positive_voxels}"
        yield f"Tumor volume : {tumor_volume_ml:.3f} mL"

    for line in report_lines():
        print(line)
# Run the pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()