Spaces:
Paused
Paused
| import gradio as gr | |
| # create nnunet input types | |
| # run nnunet | |
| # export | |
| import os | |
| import pickle | |
| import subprocess | |
| from pathlib import Path | |
| from typing import Union | |
| import numpy as np | |
| import SimpleITK as sitk | |
| from evalutils import SegmentationAlgorithm | |
| from evalutils.validators import (UniqueImagesValidator, | |
| UniquePathIndicesValidator) | |
| from picai_baseline.nnunet.softmax_export import \ | |
| save_softmax_nifti_from_softmax | |
| from picai_prep import atomic_image_write | |
| from picai_prep.preprocessing import (PreprocessingSettings, Sample, | |
| resample_to_reference_scan) | |
class MissingSequenceError(Exception):
    """Exception raised when a sequence is missing.

    The message lists the folder's contents to ease debugging. Listing is
    best-effort: if the folder itself does not exist (or cannot be read), the
    message says so instead of letting an ``OSError`` replace this exception.
    """

    def __init__(self, name, folder):
        # Without this guard a missing input folder -- the most common reason
        # no scan is found -- would surface as FileNotFoundError from listdir.
        try:
            contents = os.listdir(folder)
        except OSError as err:
            contents = f"<unavailable: {type(err).__name__}>"
        message = f"Could not find scan for {name} in {folder} (files: {contents})"
        super().__init__(message)
class MultipleScansSameSequencesError(Exception):
    """Exception raised when multiple scans of the same sequences are provided.

    The message lists the folder's contents. Listing is best-effort so that an
    unreadable or vanished folder cannot replace this exception with an
    ``OSError``.
    """

    def __init__(self, name, folder):
        try:
            contents = os.listdir(folder)
        except OSError as err:
            contents = f"<unavailable: {type(err).__name__}>"
        message = f"Found multiple scans for {name} in {folder} (files: {contents})"
        super().__init__(message)
def convert_to_original_extent(pred: np.ndarray, pkl_path: Union[Path, str], dst_path: Union[Path, str]):
    """Write ``pred`` back in the original physical space via nnU-Net.

    ``pred`` is paired with its complement into a two-channel array
    (channel 0 = ``1 - pred``, channel 1 = ``pred``). The case properties that
    nnU-Net pickled to ``pkl_path`` are loaded, and both are passed to
    ``save_softmax_nifti_from_softmax``, which writes the file ``dst_path``.

    Args:
        pred: probability map in nnU-Net's preprocessed space.
        pkl_path: nnU-Net properties pickle of the current case.
        dst_path: destination file written by nnU-Net.
    """
    # nnU-Net expects a (background, foreground) softmax stack
    two_channel_softmax = np.array([1 - pred, pred])

    # NOTE: pickle.load can execute arbitrary code; pkl_path is expected to be
    # nnU-Net's own output for this case, never user-supplied.
    with open(pkl_path, "rb") as handle:
        case_properties = pickle.load(handle)

    # nnU-Net performs the resampling to the original physical space
    save_softmax_nifti_from_softmax(
        segmentation_softmax=two_channel_softmax,
        out_fname=str(dst_path),
        properties_dict=case_properties,
    )
def strip_metadata(img: sitk.Image) -> None:
    """Erase every metadata entry of ``img`` in place (returns nothing)."""
    metadata_keys = img.GetMetaDataKeys()
    for meta_key in metadata_keys:
        img.EraseMetaData(meta_key)
def overwrite_affine(fixed_img: sitk.Image, moving_img: sitk.Image) -> sitk.Image:
    """Give ``moving_img`` the origin, direction and spacing of ``fixed_img``.

    ``moving_img`` is modified in place and also returned for convenience;
    only these three geometry attributes are set.
    """
    origin = fixed_img.GetOrigin()
    direction = fixed_img.GetDirection()
    spacing = fixed_img.GetSpacing()

    moving_img.SetOrigin(origin)
    moving_img.SetDirection(direction)
    moving_img.SetSpacing(spacing)
    return moving_img
class ProstateSegmentationAlgorithm(SegmentationAlgorithm):
    """
    Wrapper to deploy trained prostate segmentation nnU-Net model from
    https://github.com/DIAGNijmegen/picai_baseline as a
    grand-challenge.org algorithm.

    Construction validates that every folder in ``input_dirs`` contains exactly
    one ``*.mha`` scan. ``process()`` then preprocesses that scan, runs
    ``nnUNet_predict`` and writes three images: the post-processed nnU-Net
    segmentation (``prostate_segmentation_path``) and softmax channels 1 and 2
    resampled onto the input scan (``prostate_segmentation_path_pz`` /
    ``prostate_segmentation_path_tz``).
    """

    # Per-case files in ``nnunet_out_dir``: "scan.*" are written by nnU-Net,
    # "softmax.nii.gz" by ``process``. nnU-Net v1's nnUNet_predict skips a case
    # whose outputs already exist (unless --overwrite_existing is passed), so in
    # a long-lived process (e.g. the Gradio app) leftovers from a previous run
    # would silently be returned for a new scan. They are deleted before predicting.
    _NNUNET_CASE_OUTPUTS = ("scan.nii.gz", "scan.npz", "scan.pkl", "softmax.nii.gz")

    def __init__(self):
        super().__init__(
            validators=dict(
                input_image=(
                    UniqueImagesValidator(),
                    UniquePathIndicesValidator(),
                )
            ),
        )

        # input / output paths for algorithm
        self.input_dirs = [
            "./input/images/transverse-t2-prostate-mri"
        ]
        self.scan_paths = []
        self.prostate_segmentation_path_pz = Path("./output/images/softmax-prostate-peripheral-zone-segmentation/prostate_gland_sm_pz.mha")
        self.prostate_segmentation_path_tz = Path("./output/images/softmax-prostate-central-gland-segmentation/prostate_gland_sm_tz.mha")
        self.prostate_segmentation_path = Path("./output/images/prostate-zonal-segmentation/prostate_gland.mha")

        # input / output paths for nnUNet
        self.nnunet_inp_dir = Path("./nnunet/input")
        self.nnunet_out_dir = Path("./nnunet/output")
        self.nnunet_results = Path("./results")

        # ensure required folders exist (the other output folders are created
        # on write via atomic_image_write(..., mkdir=True))
        self.nnunet_inp_dir.mkdir(exist_ok=True, parents=True)
        self.nnunet_out_dir.mkdir(exist_ok=True, parents=True)
        self.prostate_segmentation_path_pz.parent.mkdir(exist_ok=True, parents=True)

        # input validation: each input folder must hold exactly one scan
        scan_glob_format = "*.mha"
        for folder in self.input_dirs:
            file_paths = list(Path(folder).glob(scan_glob_format))
            if len(file_paths) == 0:
                raise MissingSequenceError(name=folder.split("/")[-1], folder=folder)
            elif len(file_paths) >= 2:
                raise MultipleScansSameSequencesError(name=folder.split("/")[-1], folder=folder)
            else:
                # append scan path to algorithm input paths
                self.scan_paths.append(file_paths[0])

    def preprocess_input(self):
        """Preprocess input images to nnUNet Raw Data Archive format.

        Only the first scan path is used. It is cropped by picai_prep
        (``crop_only=True``, ``physical_size=[81.0, 192.0, 192.0]`` -- units and
        axis order per picai_prep, presumably mm in z/y/x; TODO confirm) and
        written as ``scan_0000.nii.gz`` (``scan_0001...`` for further scans)
        into ``nnunet_inp_dir``.
        """
        sample = Sample(
            scans=[sitk.ReadImage(str(self.scan_paths[0]))],
            settings=PreprocessingSettings(
                physical_size=[81.0, 192.0, 192.0],
                crop_only=True
            )
        )

        # perform preprocessing
        sample.preprocess()

        # write preprocessed scans to nnUNet input directory
        for i, scan in enumerate(sample.scans):
            path = self.nnunet_inp_dir / f"scan_{i:04d}.nii.gz"
            atomic_image_write(scan, path)

    def _remove_stale_nnunet_outputs(self):
        """Delete per-case outputs a previous run left in ``nnunet_out_dir``."""
        for filename in self._NNUNET_CASE_OUTPUTS:
            (self.nnunet_out_dir / filename).unlink(missing_ok=True)

    # Note: need to overwrite process because of flexible inputs, which requires custom data loading
    def process(self):
        """
        Load bpMRI scans and segment the prostate glands.

        Writes the post-processed nnU-Net segmentation to
        ``prostate_segmentation_path`` and softmax channels 1 and 2, resampled
        onto the input scan, to ``prostate_segmentation_path_pz`` and
        ``prostate_segmentation_path_tz``.
        """
        # without this, nnUNet_predict would reuse results of an earlier scan
        self._remove_stale_nnunet_outputs()

        # perform preprocessing
        self.preprocess_input()

        # perform inference using nnUNet
        self.predict(
            task="Task848_experiment48",
            trainer="nnUNetTrainerV2_MMS",
            checkpoint="model_best",
            folds="0"
        )

        # softmax saved by --save_npz; close the archive once the array is read.
        # Channel 1 / 2 are written to the peripheral-zone / central-gland
        # outputs below (channel meaning defined by the trained task).
        with np.load(str(self.nnunet_out_dir / "scan.npz")) as npz_file:
            sm_arr = npz_file["softmax"]
        pz_arr = sm_arr[1, :, :, :].astype("float32")
        tz_arr = sm_arr[2, :, :, :].astype("float32")

        # read postprocessed prediction
        pred_path = str(self.nnunet_out_dir / "scan.nii.gz")
        pred_postprocessed: sitk.Image = sitk.ReadImage(pred_path)

        # remove metadata to get rid of SimpleITK warning
        strip_metadata(pred_postprocessed)

        # save postprocessed prediction to output
        atomic_image_write(pred_postprocessed, self.prostate_segmentation_path, mkdir=True)

        # the original input scan is the resampling reference for both maps
        reference_scan = sitk.ReadImage(str(self.scan_paths[0]))
        softmax_path = self.nnunet_out_dir / "softmax.nii.gz"

        for pred, save_path in [
            (pz_arr, self.prostate_segmentation_path_pz),
            (tz_arr, self.prostate_segmentation_path_tz),
        ]:
            # the prediction is currently at the size and location of the nnU-Net preprocessed
            # scan, so we need to convert it to the original extent before we continue
            convert_to_original_extent(
                pred=pred,
                pkl_path=self.nnunet_out_dir / "scan.pkl",
                dst_path=softmax_path,
            )

            # now each voxel in softmax.nii.gz corresponds to the same voxel in the reference scan
            pred = sitk.ReadImage(str(softmax_path))

            # infuse the physical metadata of the reference scan
            pred = resample_to_reference_scan(pred, reference_scan_original=reference_scan)

            # clip small values to 0 to save disk space
            arr = sitk.GetArrayFromImage(pred)
            arr[arr < 1e-3] = 0
            pred_clipped = sitk.GetImageFromArray(arr)
            pred_clipped.CopyInformation(pred)

            # remove metadata to get rid of SimpleITK warning
            strip_metadata(pred_clipped)

            # save prediction to output folder
            atomic_image_write(pred_clipped, save_path, mkdir=True)

    def predict(self, task, trainer="nnUNetTrainerV2", network="3d_fullres",
                checkpoint="model_final_checkpoint", folds="0,1,2,3,4", store_probability_maps=True,
                disable_augmentation=False, disable_patch_overlap=False):
        """
        Use trained nnUNet network to generate segmentation masks.

        Runs the ``nnUNet_predict`` CLI on ``nnunet_inp_dir`` -> ``nnunet_out_dir``.
        Sets ``RESULTS_FOLDER`` in this process's environment (inherited by the
        subprocess). ``folds`` is a comma-separated string. Raises
        ``subprocess.CalledProcessError`` if the command exits non-zero.
        """
        # Set environment variables
        os.environ['RESULTS_FOLDER'] = str(self.nnunet_results)

        # Run prediction script
        cmd = [
            'nnUNet_predict',
            '-t', task,
            '-i', str(self.nnunet_inp_dir),
            '-o', str(self.nnunet_out_dir),
            '-m', network,
            '-tr', trainer,
            '--num_threads_preprocessing', '2',
            '--num_threads_nifti_save', '1'
        ]

        if folds:
            cmd.append('-f')
            cmd.extend(folds.split(','))

        if checkpoint:
            cmd.append('-chk')
            cmd.append(checkpoint)

        if store_probability_maps:
            cmd.append('--save_npz')

        if disable_augmentation:
            cmd.append('--disable_tta')

        if disable_patch_overlap:
            cmd.extend(['--step_size', '1'])

        subprocess.check_call(cmd)
def predict(input_file):
    """Gradio callback: segment the prostate in one uploaded T2 image.

    The upload is read with SimpleITK, re-written as the single ``.mha`` scan
    in the algorithm's input folder, and ``ProstateSegmentationAlgorithm`` is
    run on it.

    Args:
        input_file: value of the ``gr.File`` input -- a filesystem path, or
            (older Gradio versions) a tempfile-like object exposing ``.name``.

    Returns:
        Paths of the peripheral-zone softmax, central-gland softmax and zonal
        segmentation images, in the order of the Interface outputs.

    Raises:
        gr.Error: if no file was uploaded.
    """
    if input_file is None:
        # shown to the user instead of an opaque SimpleITK failure
        raise gr.Error("Please upload a 3D T2 image (.mha, .nii or .nii.gz).")

    # Gradio 4 passes a path; older versions pass a tempfile wrapper whose
    # ``.name`` is the path (a Path's ``.name`` would only be the basename).
    if isinstance(input_file, (str, os.PathLike)):
        input_path = input_file
    else:
        input_path = input_file.name

    print("Making prediction")
    image = sitk.ReadImage(str(input_path))

    os.makedirs("./input/images/transverse-t2-prostate-mri/", exist_ok=True)
    os.makedirs("./output/images/softmax-prostate-peripheral-zone-segmentation", exist_ok=True)
    os.makedirs("./output/images/softmax-prostate-central-gland-segmentation", exist_ok=True)
    os.makedirs("./output/images/prostate-zonal-segmentation", exist_ok=True)

    # always the same file name, so the folder keeps exactly one scan as
    # required by the algorithm's input validation
    sitk.WriteImage(image, "./input/images/transverse-t2-prostate-mri/1009_2222_t2w.mha")

    algorithm = ProstateSegmentationAlgorithm()
    algorithm.process()

    # take the paths from the algorithm so they cannot drift from where it writes
    return (
        str(algorithm.prostate_segmentation_path_pz),
        str(algorithm.prostate_segmentation_path_tz),
        str(algorithm.prostate_segmentation_path),
    )
# --- Gradio web UI: one uploaded T2 volume in, three result files out ---
print("Starting interface")

t2_upload = gr.File(
    label="input T2 image (3d)",
    file_count="single",
    file_types=[".mha", ".nii.gz", ".nii"],
)
result_files = [
    gr.File(label="softmax-prostate-peripheral-zone-segmentation/prostate_gland_sm_pz"),
    gr.File(label="softmax-prostate-central-gland-segmentation/prostate_gland_sm_tz"),
    gr.File(label="prostate-zonal-segmentation/prostate_gland"),
]

demo = gr.Interface(
    fn=predict,
    inputs=t2_upload,
    outputs=result_files,
    title="Hevi.AI prostate inference",
    description="description text",
    article="article text",
    cache_examples=False,
    allow_flagging="never",
    # predict() reads and writes fixed paths under ./input, ./output and
    # ./nnunet, so requests must be processed one at a time.
    concurrency_limit=1,
)

print("Launching interface")
demo.queue()
launch_options = {"server_name": "0.0.0.0", "server_port": 7860}
demo.launch(**launch_options)