"""Revert nnU-Net CT z-score normalisation on predicted images (optionally masking them)."""
import glob
import json
import os
import sys

import SimpleITK as sitk
from tqdm import tqdm
import numpy as np
import torch  # NOTE(review): not referenced in this file; kept in case other code relies on it.


# Inputs used when the script is run without command-line arguments.
DEFAULT_CT_PLAN_PATH = "/datasets/work/hb-synthrad2023/work/synthrad2025/bw_workplace/data/nnunet_struct/preprocessed/Dataset203_synthrad2025_task1_CT/nnUNetPlans.json"
DEFAULT_PRED_PATH = "/datasets/work/hb-synthrad2023/work/synthrad2025/bw_workplace/data/nnunet_struct/results/Dataset202_synthrad2025_task1_MR_mask/nnUNetTrainerMRCT__nnUNetPlans__3d_fullres/fold_0/validation"


def get_ct_normalisation_values(ct_plan_path):
    """Read the CT mean and standard deviation from an nnU-Net plans file.

    Args:
        ct_plan_path: Path to the CT dataset's ``nnUNetPlans.json``.

    Returns:
        Tuple ``(ct_mean, ct_std)`` taken from
        ``foreground_intensity_properties_per_channel["0"]``.
    """
    with open(ct_plan_path, "r", encoding="utf-8") as f:
        ct_plan = json.load(f)

    channel_props = ct_plan['foreground_intensity_properties_per_channel']["0"]
    ct_mean = channel_props['mean']
    ct_std = channel_props['std']
    print(f"CT mean: {ct_mean}, CT std: {ct_std}")
    return ct_mean, ct_std


def _default_save_path(pred_path):
    """Return the output folder ``<pred_path>_revert_norm`` next to ``pred_path``."""
    # normpath drops a trailing separator, so "dir/" maps to "dir_revert_norm"
    # instead of "dir/_revert_norm" (a folder nested inside the inputs).
    return os.path.normpath(pred_path) + '_revert_norm'


def _list_prediction_images(pred_path):
    """Return the ``*.nii.gz`` and ``*.mha`` files directly inside ``pred_path``, sorted."""
    # Sorting makes the processing order deterministic (glob order is not).
    return sorted(
        glob.glob(os.path.join(pred_path, "*.nii.gz"))
        + glob.glob(os.path.join(pred_path, "*.mha"))
    )


def _mask_file_for(img_path, mask_path):
    """Return the file in ``mask_path`` paired with prediction ``img_path``.

    The pairing removes every ``_0000`` substring from the prediction's file name
    (presumably the nnU-Net channel suffix — confirm against the mask naming).
    """
    filename = os.path.basename(img_path).replace('_0000', '')
    return os.path.join(mask_path, filename)


def revert_normalisation_single(pred_sitk, ct_mean, ct_std):
    """Undo z-score normalisation of one image: ``value * ct_std + ct_mean``.

    The pixels go through a NumPy array; the result then copies origin, spacing
    and direction from ``pred_sitk`` so it stays on the same physical grid.
    """
    arr = sitk.GetArrayFromImage(pred_sitk)
    reverted = sitk.GetImageFromArray(arr * ct_std + ct_mean)
    reverted.CopyInformation(pred_sitk)
    return reverted


def revert_normalisation(pred_path, ct_mean, ct_std, save_path=None, mask_path=None, mask_outside_value=-1000):
    """Revert the normalisation of every CT prediction in a folder.

    Equivalent to ``revert_normalisation_modified`` without an in-memory mask.

    Args:
        pred_path: Folder containing ``*.nii.gz`` / ``*.mha`` predictions.
        ct_mean, ct_std: Values returned by ``get_ct_normalisation_values``.
        save_path: Output folder; defaults to ``<pred_path>_revert_norm``.
        mask_path: Optional folder of per-image masks (see ``_mask_file_for``).
        mask_outside_value: Value written outside the mask.
    """
    revert_normalisation_modified(
        pred_path,
        ct_mean,
        ct_std,
        save_path=save_path,
        mask_path=mask_path,
        mask_sitk=None,
        mask_outside_value=mask_outside_value,
    )


def print_sitk_space(img: sitk.Image, name: str = "img") -> None:
    """Print size, spacing, origin, direction matrix and pixel type of ``img``.

    If ``img`` is not a ``SimpleITK.Image``, a notice is printed (in Chinese:
    "not a SimpleITK.Image (got <type>), no spatial info to print") and the
    function returns without raising.
    """
    if not isinstance(img, sitk.Image):
        print(f"[{name}] 不是 SimpleITK.Image(得到 {type(img)}),没有空间信息可打印。")
        return
    size = img.GetSize()          # (x, y, z)
    spacing = img.GetSpacing()    # (x, y, z)
    origin = img.GetOrigin()      # (x, y, z)
    direction = np.array(img.GetDirection())
    dim = img.GetDimension()
    # GetDirection returns a flat tuple; show it as a dim x dim matrix when possible.
    if direction.size == dim * dim:
        direction = direction.reshape(dim, dim)
    print(f"[{name}] size (x,y,z) = {size}")
    print(f"[{name}] spacing (x,y,z) = {spacing}")
    print(f"[{name}] origin (x,y,z) = {origin}")
    print(f"[{name}] direction matrix =\n{direction}")
    print(f"[{name}] pixel type = {img.GetPixelIDTypeAsString()}")


def revert_normalisation_modified(pred_path, ct_mean, ct_std, save_path=None, mask_path=None, mask_sitk=None, mask_outside_value=-1000):
    """Revert the normalisation of every prediction in a folder, optionally masking it.

    Args:
        pred_path: Folder containing ``*.nii.gz`` / ``*.mha`` predictions.
        ct_mean, ct_std: Values returned by ``get_ct_normalisation_values``.
        save_path: Output folder; defaults to ``<pred_path>_revert_norm``.
            Output files keep the input file names.
        mask_path: Optional folder of per-image masks. Takes precedence over
            ``mask_sitk`` when both are given.
        mask_sitk: Optional single ``SimpleITK.Image`` mask applied to every image.
        mask_outside_value: Value written outside the mask.
    """
    if save_path is None:
        save_path = _default_save_path(pred_path)
    os.makedirs(save_path, exist_ok=True)

    imgs = _list_prediction_images(pred_path)

    if mask_path:
        print(f"Applying mask from {mask_path} with outside value {mask_outside_value}")
    elif mask_sitk is not None:
        print(f"Applying provided mask_sitk with outside value {mask_outside_value}")
    else:
        print("No mask provided, normalisation will be applied to all images.")

    for img in tqdm(imgs):
        reverted = revert_normalisation_single(sitk.ReadImage(img), ct_mean, ct_std)

        # Per-file mask from mask_path wins over the shared in-memory mask.
        mask = sitk.ReadImage(_mask_file_for(img, mask_path)) if mask_path else mask_sitk
        if mask is not None:
            reverted = sitk.Mask(reverted, mask, outsideValue=mask_outside_value)

        sitk.WriteImage(reverted, os.path.join(save_path, os.path.basename(img)))


def revert_normalisation_single_modified(pred_sitk, ct_mean, ct_std, mask_sitk=None, mr_sitk=None, outside_value=-1000):
    """Undo CT normalisation on an in-memory image and optionally mask it.

    Args:
        pred_sitk: Normalised prediction as a ``SimpleITK.Image``.
        ct_mean, ct_std: Values returned by ``get_ct_normalisation_values``.
        mask_sitk: Optional mask; pixels outside it are set to ``outside_value``.
            It should occupy the same physical space as ``pred_sitk``.
        mr_sitk: Unused; accepted so existing callers that pass it keep working.
        outside_value: Value written outside the mask.

    Returns:
        The reverted ``SimpleITK.Image`` (masked if ``mask_sitk`` was given).
    """
    # SimpleITK's arithmetic operators return an image on pred_sitk's grid, so no
    # CopyInformation is needed here. NOTE(review): the result keeps pred_sitk's
    # pixel type — confirm predictions are floating point.
    reverted = pred_sitk * float(ct_std) + float(ct_mean)
    if mask_sitk is None:
        # Previously this path raised UnboundLocalError (`out` was only assigned
        # inside the mask branch).
        return reverted
    return sitk.Mask(reverted, mask_sitk, outsideValue=outside_value)


def main(argv=None):
    """Revert normalisation for a folder of predictions.

    Usage: ``python <script> [ct_plan_path] [pred_path]``. Missing arguments
    fall back to ``DEFAULT_CT_PLAN_PATH`` / ``DEFAULT_PRED_PATH``; results are
    written to ``<pred_path>_revert_norm``.
    """
    args = sys.argv[1:] if argv is None else list(argv)
    ct_plan_path = args[0] if len(args) > 0 else DEFAULT_CT_PLAN_PATH
    pred_path = args[1] if len(args) > 1 else DEFAULT_PRED_PATH

    ct_mean, ct_std = get_ct_normalisation_values(ct_plan_path)
    revert_normalisation(pred_path, ct_mean, ct_std)


if __name__ == "__main__":
    main()