# Synthrad2025 / process.py
# FelixzeroSun
# update and debug
# 0f107cb
import os

# nnunetv2 reads its path configuration from the environment when it is first
# imported (nnunetv2.paths), so these must be set before that import below.
os.environ["nnUNet_raw"] = "./nnunet_raw"
os.environ["nnUNet_preprocessed"] = "./nnunet_preprocessed"
# NOTE(review): "nnUNet_results" is NOT set in this file, yet
# SynthradAlgorithm2.predict() reads os.environ["nnUNet_results"]; it must be
# provided by the runtime environment. Values used during development included
# "./nnunet_results_task_2" and
# "/home/head_neck/algorithm-template/nnunet_results_task_2".

# OpenBLAS sizes its thread pool when the shared library is loaded (i.e. on the
# first numpy/torch import), so the limit has to be in place before those
# imports -- setting it afterwards has no effect.
os.environ["OPENBLAS_NUM_THREADS"] = "1"

import shutil
import subprocess
import tempfile
from typing import Dict

import numpy as np
import SimpleITK as sitk
import torch
from batchgenerators.utilities.file_and_folder_operations import load_json, join, isfile, maybe_mkdir_p, isdir, subdirs, \
    save_json
from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor

from base_algorithm import BaseSynthradAlgorithm
from revert_normalisation import get_ct_normalisation_values, revert_normalisation_single_modified

# Set FORCE_CPU=1 to run inference on CPU even when a CUDA GPU is available.
force_cpu = os.getenv("FORCE_CPU", "0") == "1"
device = torch.device("cuda:0" if torch.cuda.is_available() and not force_cpu else "cpu")
class SynthradAlgorithm2(BaseSynthradAlgorithm):
    """
    Synthetic CT generation with one nnU-Net model per anatomical region.

    For each case the model trained for the case's region is loaded and run on
    the input image. Its output is then passed through
    ``revert_normalisation_single_modified`` using the CT mean/std obtained from
    that model's plans JSON (via ``get_ct_normalisation_values``).

    Adapted from a template by Suraj Pai (b.pai@maastrichtuniversity.nl).
    """

    # region -> (dataset folder, "<trainer>__<plans>__<configuration>" folder,
    #            plans JSON used to obtain the CT normalisation values)
    _REGION_MODELS = {
        "Head and Neck": (
            "Dataset542",
            "nnUNetTrainerMRCT_loss_masked_perception_masked__nnUNetResEncUNetLPlans__3d_fullres",
            "./542_gt_nnUNetResEncUNetLPlans.json",
        ),
        "Abdomen": (
            "Dataset540",
            "nnUNetTrainerMRCT_loss_masked_perception_masked__nnUNetResEncUNetLPlans__3d_fullres",
            "./540_gt_nnUNetResEncUNetLPlans.json",
        ),
        "Thorax": (
            "Dataset544",
            "nnUNetTrainerMRCT_loss_masked_perception_masked__nnUNetResEncUNetLPlans__3d_fullres",
            "./544_gt_nnUNetResEncUNetLPlans.json",
        ),
    }

    # Checkpoint folds to load; only fold 0 is used (not the 5-fold ensemble).
    _FOLDS = (0,)

    # When True, voxels outside the mask are zeroed before inference.
    # Default False reproduces what the network has been fed so far: the raw,
    # unmasked image. NOTE(review): an earlier revision computed a masked image
    # but never passed it to the model -- confirm which input the models expect.
    mask_input = False

    def predict(self, input_dict: Dict[str, sitk.Image]) -> sitk.Image:
        """
        Generate a synthetic CT for a single case.

        Args:
            input_dict: exactly the keys "image" (input scan), "mask" (mask
                image) and "region" (a region name present in
                ``_REGION_MODELS``; a string, despite the annotation).

        Returns:
            The network prediction after reverting the CT normalisation,
            as a SimpleITK image.

        Raises:
            ValueError: if the keys are not exactly image/mask/region, or if
                there is no trained model for the given region.
            KeyError: if the ``nnUNet_results`` environment variable is unset.
        """
        expected_keys = {"image", "mask", "region"}
        if set(input_dict) != expected_keys:
            raise ValueError(
                f"input_dict must have keys {sorted(expected_keys)}, got {list(input_dict)}"
            )
        region = input_dict["region"]
        mr_sitk = input_dict["image"]
        mask_sitk = input_dict["mask"]

        dataset_name, result_folder, plans_path = self._model_for_region(region)
        predictor = self._build_predictor(dataset_name, result_folder)

        img = sitk.GetArrayFromImage(mr_sitk).astype(np.float32)
        if self.mask_input:
            img[sitk.GetArrayFromImage(mask_sitk) == 0] = 0
        img = np.expand_dims(img, 0)  # prepend a channel axis for nnU-Net

        # Export to a private temp dir instead of the CWD: the file is always
        # cleaned up (even if reading fails) and parallel runs cannot collide.
        with tempfile.TemporaryDirectory() as tmp_dir:
            out_truncated = join(tmp_dir, "TRUNCATED")
            predictor.predict_single_npy_array(
                img, self._nnunet_properties(mr_sitk), None, out_truncated, False
            )
            # The exported file name is the truncated path plus ".nii.gz".
            pred_sitk = sitk.ReadImage(out_truncated + ".nii.gz")

        ct_mean, ct_std = get_ct_normalisation_values(plans_path)
        mask_u8 = sitk.Cast(mask_sitk, sitk.sitkUInt8)
        pred_sitk = revert_normalisation_single_modified(pred_sitk, ct_mean, ct_std, mask_sitk=mask_u8)

        # Remove working directories that may be left in the CWD.
        shutil.rmtree("./imagesTs", ignore_errors=True)
        shutil.rmtree("./predictions", ignore_errors=True)
        return pred_sitk

    @classmethod
    def _model_for_region(cls, region):
        """Return (dataset_name, result_folder, plans_path) for *region*."""
        try:
            return cls._REGION_MODELS[region]
        except KeyError:
            raise ValueError(
                f"No trained model for region {region!r}; "
                f"expected one of {sorted(cls._REGION_MODELS)}"
            ) from None

    @classmethod
    def _build_predictor(cls, dataset_name, result_folder):
        """Create an nnUNetPredictor on the module-level ``device`` and load its weights."""
        predictor = nnUNetPredictor(
            tile_step_size=0.50,
            use_gaussian=True,
            use_mirroring=False,
            # nnU-Net only supports keeping everything on the device for CUDA.
            perform_everything_on_device=(device.type == "cuda"),
            device=device,
            verbose=True,
        )
        predictor.initialize_from_trained_model_folder(
            join(os.environ["nnUNet_results"], f"{dataset_name}/{result_folder}"),
            use_folds=cls._FOLDS,
            checkpoint_name="checkpoint_final.pth",
        )
        return predictor

    @staticmethod
    def _nnunet_properties(image):
        """Build the image-properties dict passed to nnU-Net for *image*."""
        spacing = image.GetSpacing()
        return {
            "sitk_stuff": {
                "spacing": tuple(spacing),
                "origin": tuple(image.GetOrigin()),
                "direction": tuple(image.GetDirection()),
            },
            # Reverse SimpleITK's (x, y, z) spacing to match the numpy array
            # axis order returned by GetArrayFromImage.
            "spacing": [spacing[2], spacing[1], spacing[0]],
        }
# if __name__ == '__main__':
# # Run the algorithm on the default input and output paths specified in BaseSynthradAlgorithm.
# SynthradAlgorithm().process()
# if __name__ == '__main__':
# # # test Brain
# # #start_time = time.time()
# # # img_s_path = '//datasets/work/hb-synthrad2023/source/raw_data/Task1/brain/1BA001/mr.mha'
# # # img_m_path = '//datasets/work/hb-synthrad2023/source/raw_data/Task1/brain/1BA001/mask.mha'
# # # img_fakeB_path = '//datasets/work/hb-synthrad2023/work/bw_workplace/output/task1_brain/p2p3D/ensemble_t1_brain_final_e2/test_predictions/1BA001/ct_fakeB.mha'
# # # region = 'Head and Neck'
# # # test Pelvis
# img_s_path = '/home/head_neck/algorithm-template_2/task2/2ABA033_0000.mha'
# img_m_path = '/home/head_neck/algorithm-template_2/task2/2ABA033_mask.mha'
# # #img_fakeB_path = '//datasets/work/hb-synthrad2023/work/bw_workplace/output/task1_pelvis/p2p3D/exp6_data_7_size_256_256_56_batch_3_lr_0.0002_aug3d_fold0_resumed_multiscale_2dsample_5/epoch_best/epoch75/test_predictions/1PA005/ct_fakeB.mha'
# region = 'Abdomen'
# # start test
# input_dict = {
# "image": sitk.ReadImage(img_s_path),
# "mask": sitk.ReadImage(img_m_path),
# "region": region
# }
# algorithm = SynthradAlgorithm()
# img_t = algorithm.predict(input_dict)