"""Region-specific nnU-Net inference that synthesises a CT image from an MR image.

Environment variables are configured at the very top of this module, before any
heavy import:

- ``nnUNet_raw`` / ``nnUNet_preprocessed`` are set before ``nnunetv2`` is imported.
- ``OPENBLAS_NUM_THREADS`` only takes effect if it is set before OpenBLAS is loaded
  (numpy / torch load it on import). The previous version set it after those imports,
  so it had no effect on the running process.

``nnUNet_results`` is NOT set here. It must be provided by the runtime environment
before :meth:`SynthradAlgorithm2.predict` is called.
"""
import os

os.environ["nnUNet_raw"] = "./nnunet_raw"
os.environ["nnUNet_preprocessed"] = "./nnunet_preprocessed"
os.environ["OPENBLAS_NUM_THREADS"] = "1"

import shutil  # noqa: E402
import subprocess  # noqa: E402,F401  (unused here; kept to preserve the module namespace)
import tempfile  # noqa: E402
from typing import Dict  # noqa: E402

import SimpleITK as sitk  # noqa: E402
from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor  # noqa: E402
from batchgenerators.utilities.file_and_folder_operations import (  # noqa: E402,F401
    isdir,
    isfile,
    join,
    load_json,
    maybe_mkdir_p,
    save_json,
    subdirs,
)
import numpy as np  # noqa: E402
import torch  # noqa: E402

from base_algorithm import BaseSynthradAlgorithm  # noqa: E402
from revert_normalisation import (  # noqa: E402
    get_ct_normalisation_values,
    revert_normalisation_single_modified,
)

# Set FORCE_CPU=1 to run on CPU even when CUDA is available.
force_cpu = os.getenv("FORCE_CPU", "0") == "1"
device = torch.device("cuda:0" if torch.cuda.is_available() and not force_cpu else "cpu")

# Trainer/plans/configuration folder shared by every region model.
_RESULT_FOLDER = "nnUNetTrainerMRCT_loss_masked_perception_masked__nnUNetResEncUNetLPlans__3d_fullres"

# region -> (dataset folder under $nnUNet_results, plans file passed to get_ct_normalisation_values)
_REGION_MODELS = {
    "Head and Neck": ("Dataset542", "./542_gt_nnUNetResEncUNetLPlans.json"),
    "Abdomen": ("Dataset540", "./540_gt_nnUNetResEncUNetLPlans.json"),
    "Thorax": ("Dataset544", "./544_gt_nnUNetResEncUNetLPlans.json"),
}

# The predictor writes its output to "<truncated name>" + this suffix.
# The previous code read 'TRUNCATED' back as './TRUNCATED.nii.gz'.
_PRED_FILE_ENDING = ".nii.gz"


class SynthradAlgorithm2(BaseSynthradAlgorithm):
    """
    Synthesise a CT image from an MR image with a region-specific nnU-Net model.

    For each case, the model matching ``input_dict["region"]`` is loaded from
    ``$nnUNet_results`` and run on the MR volume. Its normalised output is then
    mapped back to CT intensities (inside the mask) using the normalisation
    values read from that model's plans file.

    Author: Suraj Pai (b.pai@maastrichtuniversity.nl)
    """

    def predict(self, input_dict: Dict[str, sitk.Image]) -> sitk.Image:
        """
        Predict a synthetic CT for one case.

        Args:
            input_dict: must contain exactly the keys ``"image"`` (the MR
                ``sitk.Image``), ``"mask"`` (a ``sitk.Image``) and ``"region"``
                (one of ``"Head and Neck"``, ``"Abdomen"``, ``"Thorax"``).

        Returns:
            The synthetic CT as a ``sitk.Image``, after
            ``revert_normalisation_single_modified`` has been applied with the
            mask cast to ``sitkUInt8``.

        Raises:
            ValueError: if the keys are not exactly the expected ones, or the
                region is not supported.
            KeyError: if the ``nnUNet_results`` environment variable is unset.
        """
        expected_keys = {"image", "mask", "region"}
        if set(input_dict) != expected_keys:
            raise ValueError(
                f"input_dict must have exactly the keys {sorted(expected_keys)}, got {list(input_dict)}"
            )

        region = input_dict["region"]
        try:
            dataset_name, plans_path = _REGION_MODELS[region]
        except KeyError:
            # Previously an unknown region surfaced as an UnboundLocalError.
            raise ValueError(
                f"Unsupported region {region!r}; expected one of {list(_REGION_MODELS)}"
            ) from None

        mr_sitk = input_dict["image"]
        mask_sitk = input_dict["mask"]

        # NOTE(review): the previous version built a masked copy of the MR
        # (voxels outside the mask set to 0) but never used it. The network
        # was, and still is, fed the unmasked MR. The dead computation was
        # removed; confirm which input the models were trained on.
        predictor = self._load_predictor(dataset_name)
        img = sitk.GetArrayFromImage(mr_sitk).astype(np.float32)[np.newaxis]  # add channel axis
        props = self._image_properties(mr_sitk)

        # Write into a private temp dir rather than a fixed file in the CWD.
        # Concurrent runs then cannot clobber each other, and the file is
        # removed even if a later step raises. sitk.ReadImage loads the image
        # into memory, so it stays valid after the directory is deleted.
        with tempfile.TemporaryDirectory() as tmp_dir:
            truncated = join(tmp_dir, "TRUNCATED")
            predictor.predict_single_npy_array(img, props, None, truncated, False)
            pred_sitk = sitk.ReadImage(truncated + _PRED_FILE_ENDING)

        ct_mean, ct_std = get_ct_normalisation_values(plans_path)
        mask_uint8 = sitk.Cast(mask_sitk, sitk.sitkUInt8)
        pred_sitk = revert_normalisation_single_modified(pred_sitk, ct_mean, ct_std, mask_sitk=mask_uint8)

        # These folders are not created in this file. They are removed (if
        # present) to keep the previous behaviour.
        shutil.rmtree("./imagesTs", ignore_errors=True)
        shutil.rmtree("./predictions", ignore_errors=True)
        return pred_sitk

    @staticmethod
    def _load_predictor(dataset_name: str) -> nnUNetPredictor:
        """Build a predictor on ``device`` and load fold 0 of ``dataset_name`` from ``$nnUNet_results``."""
        predictor = nnUNetPredictor(
            tile_step_size=0.50,
            use_gaussian=True,
            use_mirroring=False,
            perform_everything_on_device=(device.type == "cuda"),
            device=device,
            verbose=True,
        )
        predictor.initialize_from_trained_model_folder(
            join(os.environ["nnUNet_results"], f"{dataset_name}/{_RESULT_FOLDER}"),
            use_folds=(0,),  # only fold 0 is used
            checkpoint_name="checkpoint_final.pth",
        )
        return predictor

    @staticmethod
    def _image_properties(image: sitk.Image) -> dict:
        """
        Build the properties dict passed to ``predict_single_npy_array``.

        ``sitk_stuff`` keeps SimpleITK's geometry unchanged. ``spacing`` is
        reversed to match the axis order of ``sitk.GetArrayFromImage`` arrays.
        """
        spacing = tuple(image.GetSpacing())
        return {
            "sitk_stuff": {
                "spacing": spacing,
                "origin": tuple(image.GetOrigin()),
                "direction": tuple(image.GetDirection()),
            },
            "spacing": list(spacing[::-1]),
        }