File size: 11,539 Bytes
fc1b682
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1a5ed5f
 
 
fc1b682
 
1a5ed5f
 
fc1b682
 
1a5ed5f
 
 
fc1b682
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dccc77b
fc1b682
dccc77b
 
 
fc1b682
 
dccc77b
 
fc1b682
 
 
 
dccc77b
 
 
fc1b682
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dccc77b
 
 
 
 
 
fc1b682
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
import gradio as gr

#  Copyright 2022 Diagnostic Image Analysis Group, Radboudumc, Nijmegen, The Netherlands
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.

import json
import os
import pickle
import subprocess
from pathlib import Path
from typing import Union

import numpy as np
import SimpleITK as sitk
from evalutils import SegmentationAlgorithm
from evalutils.validators import (UniqueImagesValidator,
                                  UniquePathIndicesValidator)
from picai_baseline.nnunet.softmax_export import \
    save_softmax_nifti_from_softmax
from picai_prep.data_utils import atomic_image_write
from picai_prep.preprocessing import Sample, crop_or_pad
from report_guided_annotation import extract_lesion_candidates


class MissingSequenceError(Exception):
    """Exception raised when a sequence is missing.

    Args:
        name: human-readable name of the missing sequence.
        folder: folder that was searched for the scan.
    """

    def __init__(self, name, folder):
        # The folder itself may not exist (globbing a missing directory simply
        # yields no files), so listing it must not raise from inside this
        # constructor and mask the real error with a FileNotFoundError.
        try:
            listing = os.listdir(folder)
        except OSError as err:
            listing = f"<unavailable: {err.strerror}>"
        message = f"Could not find scan for {name} in {folder} (files: {listing})"
        super().__init__(message)


class MultipleScansSameSequencesError(Exception):
    """Raised when a sequence folder holds more than one candidate scan.

    Args:
        name: human-readable name of the sequence.
        folder: folder in which several scans were found; its current
            contents are included in the message.
    """

    def __init__(self, name, folder):
        folder_contents = os.listdir(folder)
        super().__init__(
            "Found multiple scans for {} in {} (files: {})".format(name, folder, folder_contents)
        )


def convert_to_original_extent(
    pred: np.ndarray,
    pkl_path: Union[Path, str],
    dst_path: Union[Path, str]
) -> sitk.Image:
    """
    Resample a foreground-probability volume back to the original scan's space.

    Args:
        pred: foreground probability volume, in the geometry of the nnU-Net
            preprocessed scan.
        pkl_path: nnU-Net properties pickle describing the current case.
        dst_path: where the resampled two-class softmax is written.

    Returns:
        The image read back from ``dst_path``.
    """
    # nnU-Net's softmax layout stacks one channel per class: background, foreground
    two_class_softmax = np.array([1-pred, pred])

    # physical properties of the current case (file produced locally by nnU-Net;
    # pickle must never be used on externally supplied files)
    with open(pkl_path, "rb") as pkl_file:
        case_properties = pickle.load(pkl_file)

    out_fname = str(dst_path)
    save_softmax_nifti_from_softmax(
        segmentation_softmax=two_class_softmax,
        out_fname=out_fname,
        properties_dict=case_properties,
    )

    # voxels of the written file now align with the original (T2-weighted) scan
    return sitk.ReadImage(out_fname)


def extract_lesion_candidates_cropped(
    pred: np.ndarray,
    threshold: Union[str, float],
    crop_shape: tuple = (20, 384, 384),
):
    """
    Extract lesion candidates after discarding predictions outside a central window.

    The prediction is cropped to ``crop_shape`` and padded back to its original
    shape, which sets everything outside the central window to the padding value
    (zero, per the caller's note) before candidate extraction.

    Args:
        pred: probability volume (array in ``sitk.GetArrayFromImage`` axis order).
        threshold: forwarded to ``extract_lesion_candidates`` (e.g. ``"dynamic"``
            or a float).
        crop_shape: size in voxels of the central window that is kept. Defaults
            to the previously hard-coded ``(20, 384, 384)``.

    Returns:
        The first element returned by ``extract_lesion_candidates`` (used by the
        caller as the detection map).
    """
    original_shape = pred.shape
    pred = crop_or_pad(pred, crop_shape)
    pred = crop_or_pad(pred, original_shape)
    return extract_lesion_candidates(pred, threshold=threshold)[0]


class csPCaAlgorithm(SegmentationAlgorithm):
    """
    Wrapper to deploy trained baseline nnU-Net model from
    https://github.com/DIAGNijmegen/picai_baseline as a
    grand-challenge.org algorithm.

    On construction, exactly one ``*.mha`` scan is expected in each of the
    T2-weighted, ADC and HBV input folders (in that order). ``process`` then
    preprocesses them, runs nnU-Net inference, and writes a csPCa detection
    map plus a case-level likelihood (JSON) to ``./output``.
    """

    def __init__(self):
        super().__init__(
            validators=dict(
                input_image=(
                    UniqueImagesValidator(),
                    UniquePathIndicesValidator(),
                )
            ),
        )

        # input / output paths for algorithm
        self.image_input_dirs = [
            "./input/images/transverse-t2-prostate-mri",
            "./input/images/transverse-adc-prostate-mri",
            "./input/images/transverse-hbv-prostate-mri",
        ]
        self.scan_paths = []
        self.cspca_detection_map_path = Path("./output/images/cspca-detection-map/cspca_detection_map.mha")
        self.case_confidence_path = Path("./output/cspca-case-level-likelihood.json")

        # input / output paths for nnUNet
        self.nnunet_inp_dir = Path("./nnunet/input")
        self.nnunet_out_dir = Path("./nnunet/output")
        self.nnunet_results = Path("./results")

        # ensure required folders exist
        self.nnunet_inp_dir.mkdir(exist_ok=True, parents=True)
        self.nnunet_out_dir.mkdir(exist_ok=True, parents=True)
        self.cspca_detection_map_path.parent.mkdir(exist_ok=True, parents=True)

        # input validation: every sequence folder must hold exactly one scan;
        # scan_paths ends up ordered T2, ADC, HBV (same order as image_input_dirs)
        scan_glob_format = "*.mha"
        for folder in self.image_input_dirs:
            file_paths = list(Path(folder).glob(scan_glob_format))
            if len(file_paths) == 0:
                raise MissingSequenceError(name=folder.split("/")[-1], folder=folder)
            elif len(file_paths) >= 2:
                raise MultipleScansSameSequencesError(name=folder.split("/")[-1], folder=folder)
            else:
                # append scan path to algorithm input paths
                self.scan_paths += [file_paths[0]]

    def preprocess_input(self):
        """Preprocess input images to nnUNet Raw Data Archive format.

        Writes ``scan_0000.nii.gz``, ``scan_0001.nii.gz``, ... into
        ``self.nnunet_inp_dir``, one per input sequence, in ``scan_paths`` order.
        """
        # set up Sample
        sample = Sample(
            scans=[
                sitk.ReadImage(str(path))
                for path in self.scan_paths
            ],
        )

        # perform preprocessing
        sample.preprocess()

        # write preprocessed scans to nnUNet input directory
        for i, scan in enumerate(sample.scans):
            path = self.nnunet_inp_dir / f"scan_{i:04d}.nii.gz"
            atomic_image_write(scan, path)

    # Note: need to overwrite process because of flexible inputs, which requires custom data loading
    def process(self):
        """
        Load bpMRI scans and generate detection map for clinically significant prostate cancer.

        Side effects: writes the detection map to ``self.cspca_detection_map_path``,
        the maximum of the detection map to ``self.case_confidence_path`` (JSON),
        and the original-extent softmax to ``self.nnunet_out_dir / "softmax.nii.gz"``.
        """
        # perform preprocessing
        self.preprocess_input()

        # perform inference using nnUNet
        pred_ensemble = None
        ensemble_count = 0
        for trainer in [
            "nnUNetTrainerV2_Loss_FL_and_CE_checkpoints",
        ]:
            # predict sample
            self.predict(
                task="Task2203_picai_baseline",
                trainer=trainer,
                checkpoint="model_best",
            )

            # read foreground channel of the softmax prediction; the archive is
            # closed before deleting it (an open NpzFile keeps the file handle)
            pred_path = str(self.nnunet_out_dir / "scan.npz")
            with np.load(pred_path) as npz_file:
                pred = np.array(npz_file['softmax'][1]).astype('float32')
            os.remove(pred_path)
            if pred_ensemble is None:
                pred_ensemble = pred
            else:
                pred_ensemble += pred
            ensemble_count += 1

        # average the accumulated confidence scores
        pred_ensemble /= ensemble_count

        # the prediction is currently at the size and location of the nnU-Net preprocessed
        # scan, so we need to convert it to the original extent before we continue
        pred_ensemble = convert_to_original_extent(
            pred=pred_ensemble,
            pkl_path=self.nnunet_out_dir / "scan.pkl",
            dst_path=self.nnunet_out_dir / "softmax.nii.gz",
        )

        # extract lesion candidates from softmax prediction
        # note: we set predictions outside the central 81 x 192 x 192 mm to zero, as this is far outside the prostate
        detection_map = extract_lesion_candidates_cropped(
            pred=sitk.GetArrayFromImage(pred_ensemble),
            threshold="dynamic"
        )

        # convert detection map to a SimpleITK image and infuse the physical metadata of original T2-weighted scan
        reference_scan_original_path = str(self.scan_paths[0])
        reference_scan_original = sitk.ReadImage(reference_scan_original_path)
        detection_map: sitk.Image = sitk.GetImageFromArray(detection_map)
        detection_map.CopyInformation(reference_scan_original)

        # save prediction to output folder
        atomic_image_write(detection_map, str(self.cspca_detection_map_path))

        # save case-level likelihood (maximum value of the detection map)
        with open(self.case_confidence_path, 'w') as fp:
            json.dump(float(np.max(sitk.GetArrayFromImage(detection_map))), fp)

    def predict(self, task, trainer="nnUNetTrainerV2", network="3d_fullres",
                checkpoint="model_final_checkpoint", folds="0,1,2,3,4", store_probability_maps=True,
                disable_augmentation=False, disable_patch_overlap=False):
        """
        Use trained nnUNet network to generate segmentation masks.

        Runs the ``nnUNet_predict`` CLI on ``self.nnunet_inp_dir`` and writes to
        ``self.nnunet_out_dir``.

        Args:
            task: nnUNet task name passed to ``-t``.
            trainer: trainer class name passed to ``-tr``.
            network: network configuration passed to ``-m``.
            checkpoint: checkpoint name passed to ``-chk`` (skipped if falsy).
            folds: comma-separated fold ids passed to ``-f`` (skipped if falsy).
            store_probability_maps: add ``--save_npz``.
            disable_augmentation: add ``--disable_tta``.
            disable_patch_overlap: add ``--step_size 1``.

        Raises:
            subprocess.CalledProcessError: if ``nnUNet_predict`` exits non-zero
                (its combined stdout/stderr is printed first).
        """

        # Set environment variables (inherited by the nnUNet subprocess)
        os.environ['RESULTS_FOLDER'] = str(self.nnunet_results)

        # Run prediction script
        cmd = [
            'nnUNet_predict',
            '-t', task,
            '-i', str(self.nnunet_inp_dir),
            '-o', str(self.nnunet_out_dir),
            '-m', network,
            '-tr', trainer,
            '--num_threads_preprocessing', '2',
            '--num_threads_nifti_save', '1'
        ]

        if folds:
            cmd.append('-f')
            cmd.extend(folds.split(','))

        if checkpoint:
            cmd.append('-chk')
            cmd.append(checkpoint)

        if store_probability_maps:
            cmd.append('--save_npz')

        if disable_augmentation:
            cmd.append('--disable_tta')

        if disable_patch_overlap:
            cmd.extend(['--step_size', '1'])

        result = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True)
        # print nnUNet's log before raising so that failures remain diagnosable
        # (with check=True the log would be swallowed by the exception)
        print(result.stdout)
        result.check_returncode()


def predict(t2_file, adc_file, hbv_file):
    """
    Gradio callback: run the csPCa algorithm on one uploaded bpMRI case.

    Each upload is re-saved as ``.mha`` into the sequence folder that
    ``csPCaAlgorithm`` reads (T2, ADC and HBV each into their own folder),
    the algorithm is run, and the paths of the files it produced are returned.

    Returns:
        tuple of three paths: the csPCa detection map, the case-level
        likelihood (JSON), and the softmax prediction in the original extent.

    Raises:
        gr.Error: if one of the three scans was not uploaded.
    """
    print("Making prediction")
    uploads = [
        ("T2", t2_file, "./input/images/transverse-t2-prostate-mri/1009_2222_t2w.mha"),
        ("ADC", adc_file, "./input/images/transverse-adc-prostate-mri/1009_2222_adc.mha"),
        ("HBV", hbv_file, "./input/images/transverse-hbv-prostate-mri/1009_2222_hbv.mha"),
    ]
    for label, upload, dst_path in uploads:
        if upload is None:
            raise gr.Error(f"Please upload the {label} scan.")
        # gr.File may deliver a path string or a tempfile-like object with
        # a .name attribute, depending on the Gradio version
        src_path = getattr(upload, "name", upload)
        os.makedirs(os.path.dirname(dst_path), exist_ok=True)
        # re-save as .mha: csPCaAlgorithm only picks up "*.mha" files
        sitk.WriteImage(sitk.ReadImage(src_path), dst_path)

    algorithm = csPCaAlgorithm()
    algorithm.process()

    return (
        str(algorithm.cspca_detection_map_path),
        str(algorithm.case_confidence_path),
        # written by process() via convert_to_original_extent
        str(algorithm.nnunet_out_dir / "softmax.nii.gz"),
    )


print("Starting interface")
demo = gr.Interface(
    title="Hevi.AI prostate inference",
    description="description text",
    article="article text",
    fn=predict,
    inputs=[
        gr.File(label="input T2 image (3d)", file_count="single", file_types=[".mha", ".nii.gz", ".nii"]),
        gr.File(label="input ADC image (3d)", file_count="single", file_types=[".mha", ".nii.gz", ".nii"]),
        gr.File(label="input HBV image (3d)", file_count="single", file_types=[".mha", ".nii.gz", ".nii"]),
    ],
    # one output per path returned by predict(), in the same order
    outputs=[
        gr.File(label="cspca-detection-map/cspca_detection_map"),
        gr.File(label="cspca-case-level-likelihood"),
        gr.File(label="softmax prediction (original extent)"),
    ],
    cache_examples=False,
    allow_flagging="never",
    # the algorithm uses fixed on-disk working folders, so runs must not overlap
    concurrency_limit=1,
)
print("Launching interface")
demo.queue()
demo.launch(server_name="0.0.0.0", server_port=7860)