Spaces:
Running
Running
import os

# Environment configuration. This must run before numpy / torch / SimpleITK /
# nnunetv2 are imported:
#   * OpenBLAS reads OPENBLAS_NUM_THREADS once, when the library is loaded
#     (previously this was set after `import numpy`, so it had no effect);
#   * nnunetv2 reads its nnUNet_* path variables when it is first imported.
os.environ["OPENBLAS_NUM_THREADS"] = "1"
os.environ["nnUNet_raw"] = "./nnunet_raw"
os.environ["nnUNet_preprocessed"] = "./nnunet_preprocessed"
# nnUNet_results is intentionally NOT set here: it must be provided by the
# runtime environment and points at the folder holding the trained models
# (it is read in SynthradAlgorithm2.predict).

import shutil
import subprocess
import tempfile
from typing import Dict

import numpy as np
import SimpleITK as sitk
import torch
from batchgenerators.utilities.file_and_folder_operations import (
    isdir,
    isfile,
    join,
    load_json,
    maybe_mkdir_p,
    save_json,
    subdirs,
)
from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor

from base_algorithm import BaseSynthradAlgorithm
from revert_normalisation import get_ct_normalisation_values, revert_normalisation_single_modified

# Inference device: the first CUDA GPU when available, unless FORCE_CPU=1 is
# set in the environment, in which case the CPU is always used.
force_cpu = os.getenv("FORCE_CPU", "0") == "1"
device = torch.device("cuda:0" if torch.cuda.is_available() and not force_cpu else "cpu")
class SynthradAlgorithm2(BaseSynthradAlgorithm):
    """
    Synthetic CT generation from MR using region-specific nnU-Net models.

    For each case, ``predict`` selects the trained ResEnc-L nnU-Net model for the
    requested anatomical region and runs sliding-window inference (fold 0 only,
    no test-time mirroring) on the MR volume. It then maps the network output
    back to CT intensities with the normalisation values stored in that region's
    plans file.
    """

    # Region name (input_dict["region"]) -> (dataset folder under $nnUNet_results,
    # plans JSON whose CT normalisation values are used to revert the prediction).
    _REGION_MODELS = {
        "Head and Neck": ("Dataset542", "./542_gt_nnUNetResEncUNetLPlans.json"),
        "Abdomen": ("Dataset540", "./540_gt_nnUNetResEncUNetLPlans.json"),
        "Thorax": ("Dataset544", "./544_gt_nnUNetResEncUNetLPlans.json"),
    }
    # Trainer / plans / configuration folder; identical for every region.
    _RESULT_FOLDER = (
        "nnUNetTrainerMRCT_loss_masked_perception_masked__nnUNetResEncUNetLPlans__3d_fullres"
    )
    # Base name handed to nnU-Net as `output_file_truncated`. The exporter appends
    # the dataset file ending, which this code expects to be ".nii.gz".
    _PRED_NAME = "TRUNCATED"
    _PRED_FILE_ENDING = ".nii.gz"
    # When True, MR voxels outside the body mask are zeroed before inference.
    # The default (False) matches the historical behaviour, where the network
    # always received the unmasked MR.
    # NOTE(review): the original code built a masked copy but never used it.
    # Confirm which input the models were trained on.
    mask_input = False

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def predict(self, input_dict: Dict[str, sitk.Image]) -> sitk.Image:
        """
        Generate a synthetic CT for a single MR volume.

        Args:
            input_dict: must contain exactly the keys "image", "mask" and
                "region", in that order. "image" is the MR volume and "mask"
                the body mask (cast to UInt8 before reverting the
                normalisation). "region" is a plain string (despite the
                annotation), one of "Head and Neck", "Abdomen" or "Thorax".

        Returns:
            The prediction read back from nnU-Net's exported file, after
            revert_normalisation_single_modified has been applied with the
            region's CT mean/std and the body mask.

        Raises:
            AssertionError: if the input keys differ from the expected list.
            ValueError: if "region" is not a supported region.
            KeyError: if the nnUNet_results environment variable is not set.
        """
        assert list(input_dict.keys()) == ["image", "mask", "region"], (
            f"unexpected input keys: {list(input_dict.keys())}"
        )
        region = input_dict["region"]
        mr_sitk = input_dict["image"]
        mask_sitk = input_dict["mask"]

        dataset_name, plans_path = self._resolve_region(region)
        predictor = self._build_predictor(dataset_name)

        img = sitk.GetArrayFromImage(mr_sitk).astype(np.float32)
        if self.mask_input:
            img[sitk.GetArrayFromImage(mask_sitk) == 0] = 0
        # Prepend the channel axis: nnU-Net takes (channels, *spatial) arrays.
        img = np.expand_dims(img, 0)

        # Export into a private temporary directory so concurrent runs cannot
        # collide and the file is removed even if inference or reading fails.
        # sitk.ReadImage loads the volume into memory, so the directory can be
        # deleted as soon as the block exits.
        with tempfile.TemporaryDirectory() as tmp_dir:
            output_truncated = join(tmp_dir, self._PRED_NAME)
            predictor.predict_single_npy_array(
                img, self._image_properties(mr_sitk), None, output_truncated, False
            )
            pred_sitk = sitk.ReadImage(output_truncated + self._PRED_FILE_ENDING)

        ct_mean, ct_std = get_ct_normalisation_values(plans_path)
        mask_sitk = sitk.Cast(mask_sitk, sitk.sitkUInt8)
        pred_sitk = revert_normalisation_single_modified(pred_sitk, ct_mean, ct_std, mask_sitk=mask_sitk)

        # Best-effort removal of working folders that may exist in the CWD.
        shutil.rmtree("./imagesTs", ignore_errors=True)
        shutil.rmtree("./predictions", ignore_errors=True)
        return pred_sitk

    @classmethod
    def _resolve_region(cls, region):
        """Return (dataset_name, plans_path) for `region`, or raise ValueError."""
        try:
            return cls._REGION_MODELS[region]
        except KeyError:
            raise ValueError(
                f"Unsupported region {region!r}; expected one of {sorted(cls._REGION_MODELS)}"
            ) from None

    @classmethod
    def _build_predictor(cls, dataset_name):
        """Create an nnUNetPredictor on the module-level `device` and load fold 0 of `dataset_name`."""
        results_root = os.environ.get("nnUNet_results")
        if results_root is None:
            raise KeyError(
                "nnUNet_results environment variable is not set; it must point at "
                "the folder containing the trained nnU-Net models"
            )
        predictor = nnUNetPredictor(
            tile_step_size=0.50,
            use_gaussian=True,
            use_mirroring=False,
            # Keep everything on the GPU only when a GPU is actually used.
            perform_everything_on_device=(device.type == "cuda"),
            device=device,
            verbose=True,
        )
        predictor.initialize_from_trained_model_folder(
            join(results_root, f"{dataset_name}/{cls._RESULT_FOLDER}"),
            use_folds=(0,),
            checkpoint_name="checkpoint_final.pth",
        )
        return predictor

    @staticmethod
    def _image_properties(image):
        """Build the nnU-Net properties dict (geometry for export, spacing in array order) for `image`."""
        spacing = tuple(image.GetSpacing())
        return {
            "sitk_stuff": {
                "spacing": spacing,
                "origin": tuple(image.GetOrigin()),
                "direction": tuple(image.GetDirection()),
            },
            # GetArrayFromImage yields (z, y, x), so spacing is reversed to match.
            "spacing": [spacing[2], spacing[1], spacing[0]],
        }
# if __name__ == '__main__':
#     # Run the algorithm on the default input and output paths specified in BaseSynthradAlgorithm.
#     SynthradAlgorithm2().process()
| # if __name__ == '__main__': | |
| # # # test Brain | |
| # # #start_time = time.time() | |
| # # # img_s_path = '//datasets/work/hb-synthrad2023/source/raw_data/Task1/brain/1BA001/mr.mha' | |
| # # # img_m_path = '//datasets/work/hb-synthrad2023/source/raw_data/Task1/brain/1BA001/mask.mha' | |
| # # # img_fakeB_path = '//datasets/work/hb-synthrad2023/work/bw_workplace/output/task1_brain/p2p3D/ensemble_t1_brain_final_e2/test_predictions/1BA001/ct_fakeB.mha' | |
| # # # region = 'Head and Neck' | |
| # # # test Pelvis | |
| # img_s_path = '/home/head_neck/algorithm-template_2/task2/2ABA033_0000.mha' | |
| # img_m_path = '/home/head_neck/algorithm-template_2/task2/2ABA033_mask.mha' | |
| # # #img_fakeB_path = '//datasets/work/hb-synthrad2023/work/bw_workplace/output/task1_pelvis/p2p3D/exp6_data_7_size_256_256_56_batch_3_lr_0.0002_aug3d_fold0_resumed_multiscale_2dsample_5/epoch_best/epoch75/test_predictions/1PA005/ct_fakeB.mha' | |
| # region = 'Abdomen' | |
| # # start test | |
| # input_dict = { | |
| # "image": sitk.ReadImage(img_s_path), | |
| # "mask": sitk.ReadImage(img_m_path), | |
| # "region": region | |
| # } | |
# algorithm = SynthradAlgorithm2()
| # img_t = algorithm.predict(input_dict) |