osbm commited on
Commit
fc1b682
1 Parent(s): c3047a5

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +299 -0
main.py CHANGED
@@ -0,0 +1,299 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+
3
+ # Copyright 2022 Diagnostic Image Analysis Group, Radboudumc, Nijmegen, The Netherlands
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import json
18
+ import os
19
+ import pickle
20
+ import subprocess
21
+ from pathlib import Path
22
+ from typing import Union
23
+
24
+ import numpy as np
25
+ import SimpleITK as sitk
26
+ from evalutils import SegmentationAlgorithm
27
+ from evalutils.validators import (UniqueImagesValidator,
28
+ UniquePathIndicesValidator)
29
+ from picai_baseline.nnunet.softmax_export import \
30
+ save_softmax_nifti_from_softmax
31
+ from picai_prep.data_utils import atomic_image_write
32
+ from picai_prep.preprocessing import Sample, crop_or_pad
33
+ from report_guided_annotation import extract_lesion_candidates
34
+
35
+
36
+ class MissingSequenceError(Exception):
37
+ """Exception raised when a sequence is missing."""
38
+
39
+ def __init__(self, name, folder):
40
+ message = f"Could not find scan for {name} in {folder} (files: {os.listdir(folder)})"
41
+ super().__init__(message)
42
+
43
+
44
+ class MultipleScansSameSequencesError(Exception):
45
+ """Exception raised when multiple scans of the same sequences are provided."""
46
+
47
+ def __init__(self, name, folder):
48
+ message = f"Found multiple scans for {name} in {folder} (files: {os.listdir(folder)})"
49
+ super().__init__(message)
50
+
51
+
52
+ def convert_to_original_extent(
53
+ pred: np.ndarray,
54
+ pkl_path: Union[Path, str],
55
+ dst_path: Union[Path, str]
56
+ ) -> sitk.Image:
57
+ # convert to nnUNet's internal softmax format
58
+ pred = np.array([1-pred, pred])
59
+
60
+ # read physical properties of current case
61
+ with open(pkl_path, "rb") as fp:
62
+ properties = pickle.load(fp)
63
+
64
+ # let nnUNet resample to original physical space
65
+ save_softmax_nifti_from_softmax(
66
+ segmentation_softmax=pred,
67
+ out_fname=str(dst_path),
68
+ properties_dict=properties,
69
+ )
70
+
71
+ # now each voxel in softmax.nii.gz corresponds to the same voxel in the original (T2-weighted) scan
72
+ pred_ensemble = sitk.ReadImage(str(dst_path))
73
+
74
+ return pred_ensemble
75
+
76
+
77
+ def extract_lesion_candidates_cropped(pred: np.ndarray, threshold: Union[str, float]):
78
+ size = pred.shape
79
+ pred = crop_or_pad(pred, (20, 384, 384))
80
+ pred = crop_or_pad(pred, size)
81
+ return extract_lesion_candidates(pred, threshold=threshold)[0]
82
+
83
+
84
+ class csPCaAlgorithm(SegmentationAlgorithm):
85
+ """
86
+ Wrapper to deploy trained baseline nnU-Net model from
87
+ https://github.com/DIAGNijmegen/picai_baseline as a
88
+ grand-challenge.org algorithm.
89
+ """
90
+
91
+ def __init__(self):
92
+ super().__init__(
93
+ validators=dict(
94
+ input_image=(
95
+ UniqueImagesValidator(),
96
+ UniquePathIndicesValidator(),
97
+ )
98
+ ),
99
+ )
100
+
101
+ # input / output paths for algorithm
102
+ self.image_input_dirs = [
103
+ "/input/images/transverse-t2-prostate-mri",
104
+ "/input/images/transverse-adc-prostate-mri",
105
+ "/input/images/transverse-hbv-prostate-mri",
106
+ ]
107
+ self.scan_paths = []
108
+ self.cspca_detection_map_path = Path("/output/images/cspca-detection-map/cspca_detection_map.mha")
109
+ self.case_confidence_path = Path("/output/cspca-case-level-likelihood.json")
110
+
111
+ # input / output paths for nnUNet
112
+ self.nnunet_inp_dir = Path("/opt/algorithm/nnunet/input")
113
+ self.nnunet_out_dir = Path("/opt/algorithm/nnunet/output")
114
+ self.nnunet_results = Path("/opt/algorithm/results")
115
+
116
+ # ensure required folders exist
117
+ self.nnunet_inp_dir.mkdir(exist_ok=True, parents=True)
118
+ self.nnunet_out_dir.mkdir(exist_ok=True, parents=True)
119
+ self.cspca_detection_map_path.parent.mkdir(exist_ok=True, parents=True)
120
+
121
+ # input validation for multiple inputs
122
+ scan_glob_format = "*.mha"
123
+ for folder in self.image_input_dirs:
124
+ file_paths = list(Path(folder).glob(scan_glob_format))
125
+ if len(file_paths) == 0:
126
+ raise MissingSequenceError(name=folder.split("/")[-1], folder=folder)
127
+ elif len(file_paths) >= 2:
128
+ raise MultipleScansSameSequencesError(name=folder.split("/")[-1], folder=folder)
129
+ else:
130
+ # append scan path to algorithm input paths
131
+ self.scan_paths += [file_paths[0]]
132
+
133
+ def preprocess_input(self):
134
+ """Preprocess input images to nnUNet Raw Data Archive format"""
135
+ # set up Sample
136
+ sample = Sample(
137
+ scans=[
138
+ sitk.ReadImage(str(path))
139
+ for path in self.scan_paths
140
+ ],
141
+ )
142
+
143
+ # perform preprocessing
144
+ sample.preprocess()
145
+
146
+ # write preprocessed scans to nnUNet input directory
147
+ for i, scan in enumerate(sample.scans):
148
+ path = self.nnunet_inp_dir / f"scan_{i:04d}.nii.gz"
149
+ atomic_image_write(scan, path)
150
+
151
+ # Note: need to overwrite process because of flexible inputs, which requires custom data loading
152
+ def process(self):
153
+ """
154
+ Load bpMRI scans and generate detection map for clinically significant prostate cancer
155
+ """
156
+ # perform preprocessing
157
+ self.preprocess_input()
158
+
159
+ # perform inference using nnUNet
160
+ pred_ensemble = None
161
+ ensemble_count = 0
162
+ for trainer in [
163
+ "nnUNetTrainerV2_Loss_FL_and_CE_checkpoints",
164
+ ]:
165
+ # predict sample
166
+ self.predict(
167
+ task="Task2203_picai_baseline",
168
+ trainer=trainer,
169
+ checkpoint="model_best",
170
+ )
171
+
172
+ # read softmax prediction
173
+ pred_path = str(self.nnunet_out_dir / "scan.npz")
174
+ pred = np.array(np.load(pred_path)['softmax'][1]).astype('float32')
175
+ os.remove(pred_path)
176
+ if pred_ensemble is None:
177
+ pred_ensemble = pred
178
+ else:
179
+ pred_ensemble += pred
180
+ ensemble_count += 1
181
+
182
+ # average the accumulated confidence scores
183
+ pred_ensemble /= ensemble_count
184
+
185
+ # the prediction is currently at the size and location of the nnU-Net preprocessed
186
+ # scan, so we need to convert it to the original extent before we continue
187
+ pred_ensemble = convert_to_original_extent(
188
+ pred=pred_ensemble,
189
+ pkl_path=self.nnunet_out_dir / "scan.pkl",
190
+ dst_path=self.nnunet_out_dir / "softmax.nii.gz",
191
+ )
192
+
193
+ # extract lesion candidates from softmax prediction
194
+ # note: we set predictions outside the central 81 x 192 x 192 mm to zero, as this is far outside the prostate
195
+ detection_map = extract_lesion_candidates_cropped(
196
+ pred=sitk.GetArrayFromImage(pred_ensemble),
197
+ threshold="dynamic"
198
+ )
199
+
200
+ # convert detection map to a SimpleITK image and infuse the physical metadata of original T2-weighted scan
201
+ reference_scan_original_path = str(self.scan_paths[0])
202
+ reference_scan_original = sitk.ReadImage(reference_scan_original_path)
203
+ detection_map: sitk.Image = sitk.GetImageFromArray(detection_map)
204
+ detection_map.CopyInformation(reference_scan_original)
205
+
206
+ # save prediction to output folder
207
+ atomic_image_write(detection_map, str(self.cspca_detection_map_path))
208
+
209
+ # save case-level likelihood
210
+ with open(self.case_confidence_path, 'w') as fp:
211
+ json.dump(float(np.max(sitk.GetArrayFromImage(detection_map))), fp)
212
+
213
+ def predict(self, task, trainer="nnUNetTrainerV2", network="3d_fullres",
214
+ checkpoint="model_final_checkpoint", folds="0,1,2,3,4", store_probability_maps=True,
215
+ disable_augmentation=False, disable_patch_overlap=False):
216
+ """
217
+ Use trained nnUNet network to generate segmentation masks
218
+ """
219
+
220
+ # Set environment variables
221
+ os.environ['RESULTS_FOLDER'] = str(self.nnunet_results)
222
+
223
+ # Run prediction script
224
+ cmd = [
225
+ 'nnUNet_predict',
226
+ '-t', task,
227
+ '-i', str(self.nnunet_inp_dir),
228
+ '-o', str(self.nnunet_out_dir),
229
+ '-m', network,
230
+ '-tr', trainer,
231
+ '--num_threads_preprocessing', '2',
232
+ '--num_threads_nifti_save', '1'
233
+ ]
234
+
235
+ if folds:
236
+ cmd.append('-f')
237
+ cmd.extend(folds.split(','))
238
+
239
+ if checkpoint:
240
+ cmd.append('-chk')
241
+ cmd.append(checkpoint)
242
+
243
+ if store_probability_maps:
244
+ cmd.append('--save_npz')
245
+
246
+ if disable_augmentation:
247
+ cmd.append('--disable_tta')
248
+
249
+ if disable_patch_overlap:
250
+ cmd.extend(['--step_size', '1'])
251
+
252
+ print(subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, check=True).stdout)
253
+
254
+
255
+ def predict(input):
256
+
257
+
258
+
259
+ def predict(input_file):
260
+ print("Making prediction")
261
+ image = sitk.ReadImage(input_file)
262
+
263
+ os.makedirs("./input/images/transverse-t2-prostate-mri/", exist_ok=True)
264
+ os.makedirs("./output/images/softmax-prostate-peripheral-zone-segmentation", exist_ok=True)
265
+ os.makedirs("./output/images/softmax-prostate-central-gland-segmentation", exist_ok=True)
266
+ os.makedirs("./output/images/prostate-zonal-segmentation", exist_ok=True)
267
+
268
+ sitk.WriteImage(image, "./input/images/transverse-t2-prostate-mri/1009_2222_t2w.mha")
269
+
270
+ csPCaAlgorithm().process()
271
+
272
+
273
+ return (
274
+ "./output/images/softmax-prostate-peripheral-zone-segmentation/prostate_gland_sm_pz.mha",
275
+ "./output/images/softmax-prostate-central-gland-segmentation/prostate_gland_sm_tz.mha",
276
+ "./output/images/prostate-zonal-segmentation/prostate_gland.mha",
277
+ )
278
+
279
+ print("Starting interface")
280
+ demo = gr.Interface(
281
+ title="Hevi.AI prostate inference",
282
+ description="description text",
283
+ article="article text",
284
+ fn=predict,
285
+ inputs=gr.File(label="input T2 image (3d)", file_count="single", file_types=[".mha", ".nii.gz", ".nii"]),
286
+ outputs=[
287
+ gr.File(label="softmax-prostate-peripheral-zone-segmentation/prostate_gland_sm_pz"),
288
+ gr.File(label="softmax-prostate-central-gland-segmentation/prostate_gland_sm_tz"),
289
+ gr.File(label="prostate-zonal-segmentation/prostate_gland"),
290
+ ],
291
+ cache_examples=False,
292
+ # outputs=gr.Label(num_top_classes=3),
293
+ allow_flagging="never",
294
+ concurrency_limit=1,
295
+ )
296
+ print("Launching interface")
297
+ demo.queue()
298
+ demo.launch(server_name="0.0.0.0", server_port=7860)
299
+