File size: 9,105 Bytes
ce1057b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
"""Run single-case PanCancerSeg nnUNet CT inference and visualization."""

import argparse
import shutil
import tempfile
from pathlib import Path

import numpy as np
import SimpleITK as sitk
import torch

from visualize import generate_outputs


# Per-cancer model and display settings, keyed by the canonical --cancer_type value.
#   dataset_id   : numeric nnUNet dataset ID of the trained model (not read in
#                  this file; kept for reference).
#   dataset_name : folder looked up under --model_dir to find the trained model.
#   display_name : human-readable label handed to the visualization step.
#   wl / ww      : display window level / width handed to the visualization
#                  (presumably in HU, since inputs are CT - confirm in visualize).
#   color        : RGB overlay color for the predicted mask.
CANCER_CONFIGS = {
    "kidney_cancer": dict(
        dataset_id=102,
        dataset_name="Dataset102_Kidney",
        display_name="Kidney cancer",
        wl=40,
        ww=400,
        color=(255, 0, 0),
    ),
    "liver_cancer": dict(
        dataset_id=103,
        dataset_name="Dataset103_Liver",
        display_name="Liver cancer",
        wl=40,
        ww=400,
        color=(255, 0, 0),
    ),
    "pancreatic_cancer": dict(
        dataset_id=104,
        dataset_name="Dataset104_Pancreas",
        display_name="Pancreatic cancer",
        wl=40,
        ww=400,
        color=(255, 0, 0),
    ),
    # Lung uses a much wider, lower window than the abdominal models.
    "lung_cancer": dict(
        dataset_id=105,
        dataset_name="Dataset105_Lung",
        display_name="Lung cancer",
        wl=-600,
        ww=1500,
        color=(255, 0, 0),
    ),
}

# Short legacy names still accepted on the command line, mapped to canonical keys.
CANCER_TYPE_ALIASES = dict(
    kidney="kidney_cancer",
    liver="liver_cancer",
    pancreas="pancreatic_cancer",
    lung="lung_cancer",
)

# Components of the nnUNet results path:
#   <model_dir>/<dataset_name>/<TRAINER>__<PLANS>__<CONFIGURATION>/fold_0/<CHECKPOINT_NAME>
TRAINER_NAME = "nnUNetTrainerWandb2000"
PLANS_NAME = "nnUNetResEncUNetMPlans"
CONFIGURATION = "3d_fullres"
CHECKPOINT_NAME = "checkpoint_best.pth"


def parse_args():
    """Define the command-line interface and parse ``sys.argv``.

    Returns:
        argparse.Namespace with ``input``, ``cancer_type``, ``model_dir``,
        ``output_dir``, ``fps`` and ``device``. ``cancer_type`` is returned
        exactly as typed; normalization happens later in ``main``.
    """
    canonical_names = ", ".join(sorted(CANCER_CONFIGS))
    legacy_names = ", ".join(sorted(CANCER_TYPE_ALIASES))
    cancer_type_help = (
        "Cancer-specific model to use. "
        f"Canonical values: {canonical_names}. "
        f"Legacy aliases still accepted: {legacy_names}."
    )

    cli = argparse.ArgumentParser(
        description="Run one PanCancerSeg cancer-specific nnUNet model on a single NIfTI image."
    )
    cli.add_argument("--input", required=True, help="Path to a single .nii.gz image")
    # Validated by normalize_cancer_type() rather than argparse `choices`, so
    # legacy aliases and mixed case are accepted.
    cli.add_argument("--cancer_type", required=True, help=cancer_type_help)
    cli.add_argument(
        "--model_dir",
        required=True,
        help="Path to nnUNet results directory containing DatasetXXX_* folders",
    )
    cli.add_argument("--output_dir", default="./output", help="Where to save results")
    cli.add_argument("--fps", type=int, default=10, help="Video frames per second")
    cli.add_argument("--device", choices=["cuda", "cpu"], default="cuda")
    return cli.parse_args()


def main():
    """Run the single-case pipeline from the command line.

    Steps: parse and validate arguments; stage the input image in a temporary
    directory under the ``<case>_0000.nii.gz`` name that nnUNet is given; run
    the selected cancer model; copy the predicted mask to ``--output_dir``;
    render visualizations; and print a summary.

    Raises:
        FileNotFoundError: If the input image, model directory, custom trainer
            file, expected checkpoint, or nnUNet output mask is missing.
        ValueError: If the input is not a ``.nii.gz`` file, ``--fps`` is not
            positive, or ``--cancer_type`` is unsupported.
        RuntimeError: If ``--device cuda`` is requested but CUDA is unavailable.
    """
    args = parse_args()
    args.cancer_type = normalize_cancer_type(args.cancer_type)
    # absolute(), not resolve(): resolve() follows symlinks. An input such as a
    # ``case_0000.nii.gz`` symlink pointing at an extension-less blob would then
    # fail the suffix check below. Even when it passed, the case ID and the image
    # handed to visualization would come from the link target rather than the
    # name the user passed. symlink_or_copy() still resolves the real file
    # when staging it for nnUNet.
    input_path = Path(args.input).expanduser().absolute()
    model_dir = Path(args.model_dir).expanduser().resolve()
    output_dir = Path(args.output_dir).expanduser().resolve()

    if not input_path.exists():
        raise FileNotFoundError(f"Input image not found: {input_path}")
    # The "._" prefix check presumably skips macOS AppleDouble sidecar files
    # that share the real file's extension.
    if input_path.name.startswith("._") or not input_path.name.endswith(".nii.gz"):
        raise ValueError(f"Expected a .nii.gz image, got: {input_path.name}")
    if not model_dir.exists():
        raise FileNotFoundError(f"Model directory not found: {model_dir}")
    if args.device == "cuda" and not torch.cuda.is_available():
        raise RuntimeError(
            "CUDA was requested but torch.cuda.is_available() is False. "
            "Use --device cpu or install CUDA-enabled PyTorch."
        )
    if args.fps <= 0:
        raise ValueError("--fps must be a positive integer")

    output_dir.mkdir(parents=True, exist_ok=True)
    config = CANCER_CONFIGS[args.cancer_type]
    case_id = resolve_case_id(input_path)

    install_custom_trainer()
    model_folder = resolve_model_folder(model_dir, config["dataset_name"])

    with tempfile.TemporaryDirectory(prefix="pancancerseg_") as tmp:
        tmp_path = Path(tmp)
        tmp_input_dir = tmp_path / "input"
        tmp_output_dir = tmp_path / "prediction"
        tmp_input_dir.mkdir()
        tmp_output_dir.mkdir()

        # Stage the image as a single-channel case ("_0000" suffix) so the
        # prediction comes back as "<case_id>.nii.gz".
        nnunet_input = tmp_input_dir / f"{case_id}_0000.nii.gz"
        symlink_or_copy(input_path, nnunet_input)

        run_nnunet_prediction(
            model_folder=model_folder,
            input_dir=tmp_input_dir,
            output_dir=tmp_output_dir,
            device=args.device,
        )

        raw_seg = tmp_output_dir / f"{case_id}.nii.gz"
        if not raw_seg.exists():
            produced = sorted(tmp_output_dir.glob("*.nii.gz"))
            raise FileNotFoundError(
                f"nnUNet did not write the expected segmentation {raw_seg}. "
                f"Found: {[p.name for p in produced]}"
            )

        # Copy out before the temporary directory is deleted.
        seg_path = output_dir / f"{case_id}_seg.nii.gz"
        shutil.copy2(raw_seg, seg_path)

    viz_outputs = generate_outputs(
        image_path=input_path,
        mask_path=seg_path,
        output_dir=output_dir,
        case_name=case_id,
        cancer_type=config["display_name"],
        wl=config["wl"],
        ww=config["ww"],
        color=config["color"],
        alpha=0.5,
        fps=args.fps,
    )

    positive_voxels, tumor_volume_ml = summarize_segmentation(seg_path)
    print_summary(seg_path, viz_outputs, positive_voxels, tumor_volume_ml)


def resolve_case_id(input_path):
    """Derive the nnUNet case identifier from an image filename.

    The ``.nii.gz`` extension is removed, then one trailing ``_0000``
    channel suffix is removed if present.
    For example, ``case1_0000.nii.gz`` and ``case1.nii.gz`` both give ``"case1"``.

    Args:
        input_path: ``pathlib.Path`` of the image; only its ``name`` is used.

    Returns:
        The non-empty case ID string.

    Raises:
        ValueError: If the name does not end in ``.nii.gz`` or nothing remains
            after the suffixes are removed.
    """
    filename = input_path.name
    extension = ".nii.gz"
    if not filename.endswith(extension):
        raise ValueError(f"Expected a .nii.gz image, got: {filename}")

    stem = filename[: len(filename) - len(extension)]
    channel_suffix = "_0000"
    if stem.endswith(channel_suffix):
        stem = stem[: len(stem) - len(channel_suffix)]

    if stem:
        return stem
    raise ValueError(f"Could not resolve a case ID from: {input_path}")


def normalize_cancer_type(cancer_type):
    """Map a user-supplied cancer type to its canonical ``CANCER_CONFIGS`` key.

    Surrounding whitespace and letter case are ignored, and legacy short names
    (``CANCER_TYPE_ALIASES``) are translated.

    Raises:
        ValueError: If the value is neither a canonical key nor an alias.
    """
    key = cancer_type.strip().lower()
    canonical = CANCER_TYPE_ALIASES.get(key, key)
    if canonical in CANCER_CONFIGS:
        return canonical

    accepted = sorted([*CANCER_CONFIGS, *CANCER_TYPE_ALIASES])
    raise ValueError(
        f"Unsupported --cancer_type '{key}'. Valid values: {', '.join(accepted)}"
    )


def install_custom_trainer():
    """Place the bundled custom trainer inside the installed nnunetv2 package.

    ``trainers/<TRAINER_NAME>.py`` (next to this script) is symlinked into
    ``nnunetv2/training/nnUNetTrainer/variants``, or copied there if
    symlinking fails. This modifies the installed nnunetv2 package,
    presumably so nnUNet can find the trainer class by name when loading the
    model. An existing entry that already resolves to the source file is left
    alone; any other entry at that path is replaced.

    Returns:
        Path of the installed trainer file inside nnunetv2.

    Raises:
        FileNotFoundError: If the bundled trainer file is missing.
    """
    import nnunetv2

    trainer_file = Path(__file__).resolve().parent / "trainers" / f"{TRAINER_NAME}.py"
    if not trainer_file.exists():
        raise FileNotFoundError(f"Custom trainer file is missing: {trainer_file}")

    nnunet_root = Path(nnunetv2.__path__[0])
    target_dir = nnunet_root / "training" / "nnUNetTrainer" / "variants"
    target_dir.mkdir(parents=True, exist_ok=True)
    target = target_dir / trainer_file.name

    # is_symlink() also catches dangling links, for which exists() is False.
    if target.exists() or target.is_symlink():
        try:
            up_to_date = target.resolve() == trainer_file.resolve()
        except OSError:
            up_to_date = False
        if up_to_date:
            return target
        target.unlink()

    try:
        target.symlink_to(trainer_file.resolve())
    except (OSError, NotImplementedError):
        # Symlink creation failed; fall back to a physical copy.
        shutil.copy2(trainer_file, target)
    print(f"Installed custom trainer: {target}")
    return target


def resolve_model_folder(model_dir, dataset_name):
    """Locate the trained nnUNet run folder for a dataset and check its weights.

    The expected layout is
    ``<model_dir>/<dataset_name>/<TRAINER>__<PLANS>__<CONFIGURATION>/fold_0/<CHECKPOINT_NAME>``.

    Args:
        model_dir: ``pathlib.Path`` of the nnUNet results root.
        dataset_name: Dataset folder name, e.g. ``"Dataset105_Lung"``.

    Returns:
        The run folder (the parent of ``fold_0``).

    Raises:
        FileNotFoundError: If the fold-0 checkpoint file does not exist.
    """
    run_folder_name = "__".join((TRAINER_NAME, PLANS_NAME, CONFIGURATION))
    model_folder = model_dir / dataset_name / run_folder_name
    checkpoint = model_folder.joinpath("fold_0", CHECKPOINT_NAME)
    if checkpoint.exists():
        return model_folder
    raise FileNotFoundError(
        f"Expected checkpoint not found: {checkpoint}. "
        "Check --model_dir and make sure the trained weights are downloaded."
    )


def symlink_or_copy(src, dst):
    """Make the file at ``src`` available at ``dst``.

    A symlink to the fully resolved source is preferred because it avoids
    duplicating large images. If resolving or linking fails for any OS-level
    reason, ``src`` is copied to ``dst`` with metadata instead.
    """
    try:
        real_source = src.resolve()
        dst.symlink_to(real_source)
    except (NotImplementedError, OSError):
        shutil.copy2(src, dst)


def run_nnunet_prediction(model_folder, input_dir, output_dir, device):
    """Segment every image in ``input_dir`` with fold 0 of the given nnUNet model.

    Settings are fixed: ``tile_step_size=0.5``, Gaussian weighting on, mirroring
    off, one worker each for preprocessing and export, and no probability maps
    saved. Existing outputs are overwritten.

    Args:
        model_folder: Trained run folder (the parent of ``fold_0``).
        input_dir: Folder of nnUNet-named input images.
        output_dir: Folder the predicted masks are written to.
        device: ``"cuda"`` or ``"cpu"``. On CUDA the whole pipeline is kept on
            the device (``perform_everything_on_device``).
    """
    from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor

    on_gpu = device == "cuda"
    predictor = nnUNetPredictor(
        tile_step_size=0.5,
        use_gaussian=True,
        use_mirroring=False,
        perform_everything_on_device=on_gpu,
        device=torch.device(device),
        verbose=False,
        verbose_preprocessing=False,
        allow_tqdm=True,
    )
    predictor.initialize_from_trained_model_folder(
        str(model_folder),
        use_folds=(0,),
        checkpoint_name=CHECKPOINT_NAME,
    )

    prediction_options = dict(
        save_probabilities=False,
        overwrite=True,
        num_processes_preprocessing=1,
        num_processes_segmentation_export=1,
        folder_with_segs_from_prev_stage=None,
        num_parts=1,
        part_id=0,
    )
    predictor.predict_from_files(str(input_dir), str(output_dir), **prediction_options)


def summarize_segmentation(seg_path):
    """Count foreground voxels in a mask and convert the count to a volume.

    Any non-zero label counts as foreground. Each voxel's volume is the
    product of the image spacing; dividing by 1000 gives mL if spacing is in
    millimetres (assumed, as is usual for medical images, but not checked here).

    Returns:
        Tuple ``(positive_voxels, tumor_volume_ml)``.
    """
    mask_image = sitk.ReadImage(str(seg_path))
    n_positive = int(np.count_nonzero(sitk.GetArrayFromImage(mask_image)))
    sx, sy, sz = mask_image.GetSpacing()
    volume_ml = n_positive * sx * sy * sz / 1000.0
    return n_positive, volume_ml


def print_summary(seg_path, viz_outputs, positive_voxels, tumor_volume_ml):
    """Print a human-readable report of the inference outputs to stdout.

    Args:
        seg_path: Location of the saved segmentation mask.
        viz_outputs: Mapping with a ``"slices"`` dict (label -> PNG path) and a
            ``"video"`` entry, as produced by the visualization step.
        positive_voxels: Number of non-zero voxels in the mask.
        tumor_volume_ml: Segmented volume, printed with three decimals.
    """
    print()
    print("PanCancerSeg inference complete")
    print("Segmentation mask :", seg_path)
    print("Slice PNGs        :")
    for view_name, png_path in viz_outputs["slices"].items():
        # Pad the label to nine characters so the paths line up.
        print(f"  {view_name:9s} : {png_path}")
    print("Overlay video     :", viz_outputs["video"])
    print("Positive voxels   :", positive_voxels)
    print("Tumor volume      :", format(tumor_volume_ml, ".3f"), "mL")


# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()