Spaces:
Sleeping
Sleeping
File size: 6,169 Bytes
0f107cb 867f0d3 0f107cb 867f0d3 0f107cb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 |
import os

# Cap OpenBLAS at one thread. The variable is read when the BLAS library is
# loaded, so it has to be exported before numpy (or anything that imports
# numpy, such as torch and nnunetv2) is imported for the first time.
os.environ["OPENBLAS_NUM_THREADS"] = "1"

# nnU-Net folder locations. These are set ahead of the nnunetv2 import below
# because nnunetv2 looks the variables up from the environment
# (NOTE(review): it appears to do so at import time - keep this ordering).
# os.environ["nnUNet_raw"] = "/home/head_neck/algorithm-template_2/nnunet_raw"
# os.environ["nnUNet_preprocessed"] = "/home/head_neck/algorithm-template_2/nnunet_preprocessed"
# os.environ["nnUNet_results"] = "/home/head_neck/algorithm-template/nnunet_results_task_1"
os.environ["nnUNet_raw"] = "./nnunet_raw"
os.environ["nnUNet_preprocessed"] = "./nnunet_preprocessed"
# "nnUNet_results" is intentionally not set here; the prediction code reads it
# from the environment, so it must be provided by the process that starts this module.
# os.environ["nnUNet_results"] = "./nnunet_results"

import shutil
import subprocess
import tempfile
from typing import Dict

import SimpleITK as sitk
from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor
from batchgenerators.utilities.file_and_folder_operations import load_json, join, isfile, maybe_mkdir_p, isdir, subdirs, \
    save_json
import numpy as np
import torch

from base_algorithm import BaseSynthradAlgorithm
from revert_normalisation import get_ct_normalisation_values, revert_normalisation_single_modified

# Setting FORCE_CPU=1 in the environment disables CUDA even when a GPU is visible.
force_cpu = os.getenv("FORCE_CPU", "0") == "1"
device = torch.device("cuda:0" if torch.cuda.is_available() and not force_cpu else "cpu")
class SynthradAlgorithm1(BaseSynthradAlgorithm):
    """
    Synthetic CT generation from an MR image with region-specific nnU-Net models.

    The ``region`` entry of the input selects a trained model folder below
    ``$nnUNet_results`` and the matching plans file. The network prediction is
    read back as an image and handed to ``revert_normalisation_single_modified``
    together with the CT mean/std taken from the plans file and the mask.

    Author: Suraj Pai (b.pai@maastrichtuniversity.nl)
    """

    # Model sub-folder below the dataset folder; identical for every region.
    _RESULT_FOLDER = "nnUNetTrainerMRCT_loss_masked_perception_masked__nnUNetResEncUNetLPlans__3d_fullres"

    # region name -> (dataset folder below $nnUNet_results, plans file path)
    _REGION_MODELS = {
        "Head and Neck": ("Dataset262", "./262_gt_nnUNetResEncUNetLPlans.json"),
        "Abdomen": ("Dataset260", "./260_gt_nnUNetResEncUNetLPlans.json"),
        "Thorax": ("Dataset264", "./264_gt_nnUNetResEncUNetLPlans.json"),
    }

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def predict(self, input_dict: Dict[str, sitk.Image]) -> sitk.Image:
        """
        Generate a synthetic CT for one case.

        Args:
            input_dict: mapping with exactly the keys ``"image"`` (MR image),
                ``"mask"`` (mask image) and ``"region"`` (one of the keys of
                ``_REGION_MODELS``), in that order.

        Returns:
            The image returned by ``revert_normalisation_single_modified``
            (presumably the prediction mapped back to CT intensities - confirm
            against that helper).

        Raises:
            AssertionError: if the keys of ``input_dict`` differ from the
                expected ones (the check is order-sensitive).
            ValueError: if ``region`` is not a supported region name.
            KeyError: if the ``nnUNet_results`` environment variable is unset.
        """
        assert list(input_dict.keys()) == ["image", "mask", "region"]
        region = input_dict["region"]
        mr_sitk = input_dict["image"]
        mask_sitk = input_dict["mask"]

        # Fail early, before any model is loaded, when the region is unknown.
        model_spec = self._REGION_MODELS.get(region)
        if model_spec is None:
            supported = ", ".join(repr(name) for name in self._REGION_MODELS)
            raise ValueError(f"Unsupported region {region!r}; expected one of: {supported}")
        dataset_name, plans_path = model_spec
        result_folder = self._RESULT_FOLDER

        # NOTE(review): earlier revisions built a copy of the MR with voxels
        # outside the mask zeroed, but never passed it to the network. The
        # unused computation was dropped; the network input below is still the
        # unmasked MR, so the prediction is unchanged. Confirm whether masking
        # the input was actually intended.

        # predictor = nnUNetPredictor(
        #     tile_step_size=0.5,
        #     use_gaussian=True,
        #     use_mirroring=True,
        #     perform_everything_on_device=True,
        #     device=torch.device('cuda', 0),
        #     verbose=True,
        #     verbose_preprocessing=True,
        #     allow_tqdm=True
        # )
        predictor = nnUNetPredictor(
            tile_step_size=0.50,
            use_gaussian=True,
            use_mirroring=False,
            perform_everything_on_device=(device.type == "cuda"),
            device=device,
            verbose=True,
        )
        predictor.initialize_from_trained_model_folder(
            join(os.environ["nnUNet_results"], f'{dataset_name}/{result_folder}'),
            # use_folds=(0, 1, 2, 3, 4),
            use_folds=(0, ),
            checkpoint_name='checkpoint_final.pth',
        )

        sitk_spacing = mr_sitk.GetSpacing()
        sitk_origin = mr_sitk.GetOrigin()
        sitk_dir = mr_sitk.GetDirection()
        props = {
            'sitk_stuff': {
                'spacing': tuple(sitk_spacing),
                'origin': tuple(sitk_origin),
                'direction': tuple(sitk_dir),
            },
            # Reversed relative to SimpleITK's spacing tuple so that it lines
            # up with the axis order of the array produced below.
            'spacing': [sitk_spacing[2], sitk_spacing[1], sitk_spacing[0]]
        }
        img = sitk.GetArrayFromImage(mr_sitk).astype(np.float32)
        img = np.expand_dims(img, 0)  # add a leading axis of length 1

        # The predictor writes its output to "<truncated name>.nii.gz". A
        # private temporary directory avoids clashes between concurrent runs
        # sharing a working directory and guarantees the file is removed even
        # when a later step raises.
        with tempfile.TemporaryDirectory() as work_dir:
            output_truncated = join(work_dir, 'TRUNCATED')
            predictor.predict_single_npy_array(img, props, None, output_truncated, False)
            pred_sitk = sitk.ReadImage(output_truncated + ".nii.gz")

        ct_mean, ct_std = get_ct_normalisation_values(plans_path)
        mask_sitk = sitk.Cast(mask_sitk, sitk.sitkUInt8)
        pred_sitk = revert_normalisation_single_modified(pred_sitk, ct_mean, ct_std, mask_sitk=mask_sitk)

        # Best-effort removal of scratch folders that may be left in the working directory.
        shutil.rmtree("./imagesTs", ignore_errors=True)
        shutil.rmtree("./predictions", ignore_errors=True)
        return pred_sitk
# if __name__ == '__main__':
# # Run the algorithm on the default input and output paths specified in BaseSynthradAlgorithm.
# SynthradAlgorithm1().process()
# if __name__ == '__main__':
# # # test Brain
# # #start_time = time.time()
# # # img_s_path = '//datasets/work/hb-synthrad2023/source/raw_data/Task1/brain/1BA001/mr.mha'
# # # img_m_path = '//datasets/work/hb-synthrad2023/source/raw_data/Task1/brain/1BA001/mask.mha'
# # # img_fakeB_path = '//datasets/work/hb-synthrad2023/work/bw_workplace/output/task1_brain/p2p3D/ensemble_t1_brain_final_e2/test_predictions/1BA001/ct_fakeB.mha'
# # # region = 'Head and Neck'
# # # test Abdomen
# img_s_path = '/home/head_neck/algorithm-template_2/task2/2ABA033_0000.mha'
# img_m_path = '/home/head_neck/algorithm-template_2/task2/2ABA033_mask.mha'
# # #img_fakeB_path = '//datasets/work/hb-synthrad2023/work/bw_workplace/output/task1_pelvis/p2p3D/exp6_data_7_size_256_256_56_batch_3_lr_0.0002_aug3d_fold0_resumed_multiscale_2dsample_5/epoch_best/epoch75/test_predictions/1PA005/ct_fakeB.mha'
# region = 'Abdomen'
# # start test
# input_dict = {
# "image": sitk.ReadImage(img_s_path),
# "mask": sitk.ReadImage(img_m_path),
# "region": region
# }
# algorithm = SynthradAlgorithm1()
# img_t = algorithm.predict(input_dict)