|
|
import os

# These variables must be set before the heavy imports below:
# - nnunetv2 resolves its raw/preprocessed/results folders from the nnUNet_*
#   variables when it is imported (predict() also reads nnUNet_results).
# - OpenBLAS reads OPENBLAS_NUM_THREADS once, when the library is first loaded
#   (by numpy / torch / nnunetv2). Setting it after those imports has no effect.
os.environ["nnUNet_raw"] = "./nnunet_raw"
os.environ["nnUNet_preprocessed"] = "./nnunet_preprocessed"
os.environ["nnUNet_results"] = "./nnunet_results"
os.environ["OPENBLAS_NUM_THREADS"] = "1"

from typing import Dict
import tempfile
import subprocess

import SimpleITK as sitk
from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor
from batchgenerators.utilities.file_and_folder_operations import load_json, join, isfile, maybe_mkdir_p, isdir, subdirs, \
    save_json

import numpy as np

from base_algorithm import BaseSynthradAlgorithm
from revert_normalisation import get_ct_normalisation_values, revert_normalisation_single_modified

import torch
import shutil
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SynthradAlgorithm(BaseSynthradAlgorithm):
    """
    Synthetic CT generation from MR with region-specific nnU-Net ensembles.

    For every case the MR is masked with the supplied mask and fed through the
    five-fold nnU-Net model selected by the case's anatomical region. The
    network output is then mapped back to CT intensities using the
    normalisation values read from that region's nnU-Net plans file.

    Author: Suraj Pai (b.pai@maastrichtuniversity.nl)
    """

    # Trainer/plans/configuration folder shared by all regional models; it is
    # looked up under $nnUNet_results/<dataset_name>/.
    _RESULT_FOLDER = "nnUNetTrainerMRCT_loss_masked_perception_masked__nnUNetResEncUNetLPlans__3d_fullres"

    # region -> (nnU-Net dataset folder, plans JSON passed to
    # get_ct_normalisation_values to obtain the CT mean/std).
    _REGION_MODELS = {
        "Head and Neck": ("Dataset262", "./262_gt_nnUNetResEncUNetLPlans.json"),
        "Abdomen": ("Dataset260", "./260_gt_nnUNetResEncUNetLPlans.json"),
        "Thorax": ("Dataset264", "./264_gt_nnUNetResEncUNetLPlans.json"),
    }

    # Exactly these keys are required in predict()'s input_dict.
    _EXPECTED_KEYS = frozenset({"image", "mask", "region"})

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def predict(self, input_dict: Dict[str, sitk.Image]) -> sitk.Image:
        """Synthesise a CT image for a single case.

        Args:
            input_dict: Must contain exactly the keys
                "image" -- the MR as a SimpleITK image;
                "mask"  -- a SimpleITK image; MR voxels where it equals 0 are
                           zeroed before inference and it is forwarded to the
                           de-normalisation step;
                "region" -- a string, one of "Head and Neck", "Abdomen",
                           "Thorax" (despite the Dict[str, sitk.Image] hint).

        Returns:
            The network prediction de-normalised by
            revert_normalisation_single_modified with the CT mean/std taken
            from the region's plans file.

        Raises:
            ValueError: if input_dict does not have exactly the expected keys,
                or if the region has no configured model.
        """
        if set(input_dict) != self._EXPECTED_KEYS:
            raise ValueError(
                f"input_dict must have exactly the keys {sorted(self._EXPECTED_KEYS)}, "
                f"got {list(input_dict)}"
            )

        region = input_dict["region"]
        try:
            dataset_name, plans_path = self._REGION_MODELS[region]
        except (KeyError, TypeError):
            raise ValueError(
                f"Unsupported region {region!r}; expected one of {sorted(self._REGION_MODELS)}"
            ) from None

        mr_sitk = input_dict["image"]
        mask_sitk = input_dict["mask"]

        # NOTE(review): earlier code built this masked array but then fed the
        # *unmasked* MR to the network. The masked one is used here, as the
        # preprocessing clearly intended. Confirm the models were trained on
        # masked MR.
        img = self._masked_mr_array(mr_sitk, mask_sitk)[np.newaxis]  # (1, z, y, x)
        props = self._nnunet_properties(mr_sitk)

        predictor = self._build_predictor(dataset_name)

        # nnU-Net writes the prediction to <output_file_truncated> + file ending.
        # A private temp dir avoids clobbering/leaking a fixed file in the CWD,
        # and it is cleaned up even if prediction or reading fails.
        with tempfile.TemporaryDirectory() as tmp_dir:
            output_truncated = join(tmp_dir, "TRUNCATED")
            predictor.predict_single_npy_array(img, props, None, output_truncated, False)
            # Assumes the model's dataset.json uses ".nii.gz" as file ending
            # (the previous hard-coded "./TRUNCATED.nii.gz" relied on the same).
            pred_sitk = sitk.ReadImage(output_truncated + ".nii.gz")

        ct_mean, ct_std = get_ct_normalisation_values(plans_path)
        mask_sitk = sitk.Cast(mask_sitk, sitk.sitkUInt8)
        pred_sitk = revert_normalisation_single_modified(pred_sitk, ct_mean, ct_std, mask_sitk=mask_sitk)

        # Nothing in this method creates these folders. They are presumably
        # left over from a file-based prediction flow; removal kept for safety.
        shutil.rmtree("./imagesTs", ignore_errors=True)
        shutil.rmtree("./predictions", ignore_errors=True)

        return pred_sitk

    @staticmethod
    def _masked_mr_array(mr_sitk: sitk.Image, mask_sitk: sitk.Image) -> np.ndarray:
        """Return the MR voxels as float32 with every voxel where the mask is 0 set to 0."""
        mr_np = sitk.GetArrayFromImage(mr_sitk).astype(np.float32)
        mask_np = sitk.GetArrayFromImage(mask_sitk)
        mr_np[mask_np == 0] = 0
        return mr_np

    @staticmethod
    def _nnunet_properties(mr_sitk: sitk.Image) -> dict:
        """Build the image-properties dict (nnU-Net SimpleITK-reader layout) for the MR geometry."""
        spacing = mr_sitk.GetSpacing()
        return {
            "sitk_stuff": {
                "spacing": tuple(spacing),
                "origin": tuple(mr_sitk.GetOrigin()),
                "direction": tuple(mr_sitk.GetDirection()),
            },
            # SimpleITK reports spacing as (x, y, z); the numpy array is (z, y, x).
            "spacing": [spacing[2], spacing[1], spacing[0]],
        }

    @classmethod
    def _build_predictor(cls, dataset_name: str) -> nnUNetPredictor:
        """Create a GPU nnU-Net predictor loaded with folds 0-4 of the given dataset's model."""
        predictor = nnUNetPredictor(
            tile_step_size=0.5,
            use_gaussian=True,
            use_mirroring=True,
            perform_everything_on_device=True,
            device=torch.device('cuda', 0),
            verbose=True,
            verbose_preprocessing=True,
            allow_tqdm=True,
        )
        predictor.initialize_from_trained_model_folder(
            join(os.environ["nnUNet_results"], f'{dataset_name}/{cls._RESULT_FOLDER}'),
            use_folds=(0, 1, 2, 3, 4),
            checkpoint_name='checkpoint_final.pth',
        )
        return predictor
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: build the algorithm and hand control to process(),
    # which SynthradAlgorithm inherits from BaseSynthradAlgorithm.
    algorithm = SynthradAlgorithm()
    algorithm.process()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|