Data_Engineering / MnM2_clean /dataclean_MnM2.py
maxmo2009's picture
Initial upload: data cleanup pipeline for 12 medical imaging datasets
da9fb1e verified
#coding:utf-8
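'''
Written by ygq
Created on 2025-08-26
Update: MnMs2 data cleaning
Processing logic for the MnM2 dataset (personal understanding; the script is currently written along these lines):
1. LA and SA must be stored and processed separately;
2. As I understand it, ED/ES are the images of the diastolic/systolic state and simply correspond to one frame of the CINE (LA or SA); since no header information was found, the exact frame is unknown;
3. This dataset is probably not the original MnM2 dataset but appears to have been pre-processed in some way; no corresponding header information was found either;
4. Files with "gt" in the name are label annotation files containing the values 0, 1, 2, 3 [0: background, 1: left ventricular cavity (LV), 2: right ventricular cavity (RV), 3: left ventricular myocardium (Myo)] -- needs to be confirmed
a. The re-processed LA-CINE and SA-CINE files must be saved separately;
b. The re-processed LA-ED, LA-ES, SA-ED and SA-ES files [spacing and size same as CINE] must also be handled separately, together with the label annotation files;
## For now LA-ED/ES are kept separate; the ED/ES frame could be determined by computing the mean image intensity of each time frame of the cine [experimentally feasible]; -- 20250825
Segmentation labels: NIfTI format, label values:
0: background
1: left ventricular cavity (LV)
2: right ventricular cavity (RV)
3: left ventricular myocardium (Myo)
The current version has no metadata file information
'''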
import os
import glob
import pandas as pd
import SimpleITK as sitk
import argparse
import json
from tqdm import tqdm
from util import meta_data
import util
import numpy as np
# from bert_helper import *
# Task name; also used as the sub-directory that receives processed label files.
TASK_VALUE="segmentation"
# Intensity window for CT data; not referenced by the code visible in this file.
CLAMP_RANGE_CT = [-300,300]
CLAMP_RANGE_MRI = None # MRI images threshold placeholder TBC...
# Target voxel spacing for resampling; None means each image keeps its own spacing.
TARGET_VOXEL_SPACING=None
# Label value (as a string) -> structure name; written into each labelled
# image's metadata record.
# NOTE(review): the module notes mark the 1/2/3 assignment as "needs to be
# confirmed" -- verify it against the dataset documentation.
# NOTE(review): "backgroud" looks like a typo for "background"; it is a runtime
# value, so check downstream consumers before changing it.
LABEL_DICT={
    "0":"backgroud",
    "1":"LV",# left ventricle blood pool
    "3":"MYO",# left ventricular myocardium
    "2":"RV"# right ventricle blood pool
}
# def find_metadata_files(path):
# # for Cancer Image Archive (TCIA) dataset
# search_pattern = os.path.join(path, '**', 'metadata.csv')
# return glob.glob(search_pattern, recursive=True)
def find_metadata_files(path):
    """Return the ``*.csv`` files that sit directly inside *path*.

    Originally written for Cancer Imaging Archive (TCIA) style datasets.

    Args:
        path: Directory to search.

    Returns:
        A list of matching file paths, in the order ``glob`` reports them.
    """
    # The pattern has no "**" component, so only the top level is matched
    # even though the recursive flag is passed.
    csv_glob = os.path.join(path, '*.csv')
    csv_matches = glob.glob(csv_glob, recursive=True)
    return csv_matches
##added by yanguoqing on 20250527
def find_image_dirs(path):
    """List the names of all entries (files and directories) directly in *path*.

    Args:
        path: Directory to list.

    Returns:
        The entry names exactly as reported by ``os.listdir`` (no filtering,
        no sorting, names only -- not full paths).
    """
    entry_names = os.listdir(path)
    return entry_names
##modify by yanguoqing on 20250527
def load_dicom_images(folder_path):
    """Read the DICOM series found in *folder_path* as a single image.

    Args:
        folder_path: Directory passed to the GDCM series file lookup.

    Returns:
        A ``(file_names, image)`` tuple: the series file names used for the
        read and the image produced by the series reader.
    """
    series_reader = sitk.ImageSeriesReader()
    slice_files = series_reader.GetGDCMSeriesFileNames(folder_path)
    series_reader.SetFileNames(slice_files)
    return slice_files, series_reader.Execute()
##added by yanguoqing on 20250527
def load_dicom_tag(imgs):
    """Read a single image file and return the object produced by the reader.

    Args:
        imgs: Path of the file to read.

    Returns:
        The result of ``ImageFileReader.Execute()`` for that file.
    """
    file_reader = sitk.ImageFileReader()
    file_reader.SetFileName(imgs)
    # Header-only read first ...
    file_reader.ReadImageInformation()
    # ... but Execute() is still called and its result is what gets returned.
    # NOTE(review): presumably this loads the pixel data as well -- confirm
    # whether callers only need the metadata.
    return file_reader.Execute()
def load_nrrd(fp):
    """Read the image file at *fp* with ``sitk.ReadImage`` and return it."""
    loaded_image = sitk.ReadImage(fp)
    return loaded_image
def save_nifti(image, output_path, folder_path):
    """Write *image* to *output_path*, recording where it came from.

    The parent directory of ``output_path`` is created when missing. The
    source location is stored on the image under the ``"FolderPath"``
    metadata key before writing; note that this mutates *image*.

    Args:
        image: SimpleITK image to write.
        output_path: Destination file path.
        folder_path: Source path stored in the ``"FolderPath"`` metadata entry.
    """
    output_dirpath = os.path.dirname(output_path)
    # dirname is '' for a bare filename (current directory); os.makedirs('')
    # would raise, so only create a directory when there is one to create.
    if output_dirpath and not os.path.exists(output_dirpath):
        print(f"Creating directory {output_dirpath}")
        # exist_ok guards against the directory appearing between the check
        # above and this call.
        os.makedirs(output_dirpath, exist_ok=True)
    # Set metadata in the NIfTI file's header
    image.SetMetaData("FolderPath", folder_path)
    sitk.WriteImage(image, output_path)
##modify by yanguoqing on 20250527
def convert_windows_to_linux_path(windows_path):
    """Convert a Windows-style path into a Linux-style one.

    Some meta files contain Windows paths although the data lives on a Linux
    server. Backslashes become forward slashes and a leading drive specifier
    (``C:``) is removed. Colons anywhere else are kept, because they are
    legal characters in Linux path components.

    Args:
        windows_path: Path string, possibly using backslashes and a drive letter.

    Returns:
        The path with forward slashes and without a leading drive specifier.
    """
    linux_path = windows_path.replace('\\', '/')
    # Only "<ASCII letter>:" at the very start counts as a drive specifier.
    has_drive_prefix = (
        len(linux_path) >= 2
        and linux_path[1] == ':'
        and ('a' <= linux_path[0] <= 'z' or 'A' <= linux_path[0] <= 'Z')
    )
    if has_drive_prefix:
        linux_path = linux_path[2:]
    return linux_path
def main(target_path, output_dir):
    """Clean the MnM2 NIfTI files under *target_path* and write them to *output_dir*.

    Every sub-directory of ``target_path`` is treated as one case. Inside a
    case only files whose name contains ``"CINE"`` or ``"ES.nii.gz"`` are
    processed; all other files are skipped. Each selected image is loaded,
    passed through ``util.get_unisize_resampler`` (frame by frame for 4D
    images) and saved to ``output_dir/<case>/<file>``. When an ``ES`` file has
    a sibling ``*_gt.nii.gz`` label, the label is resampled onto the processed
    image with nearest-neighbour interpolation and saved to
    ``output_dir/<case>/segmentation/<file>``.

    After each image, its metadata record is merged into
    ``output_dir/nifti_mappings.json`` keyed by the output image path. An
    existing mapping file is kept, so re-runs add to it. Paths collected in
    the failure list are written to ``output_dir/failed_files.json`` at the
    end.

    Args:
        target_path: Dataset root; one sub-directory per case.
        output_dir: Destination for processed files and the JSON reports;
            created when it does not exist.
    """
    # NOTE(review): metadata_files is collected but never used below.
    metadata_files = find_metadata_files(target_path)
    # All entries of target_path; non-directories are skipped inside the loop.
    pid_dirs=find_image_dirs(target_path)
    # pid_dirs=["Training","Testing","Validation"]
    failed_files = []
    if not os.path.isdir(output_dir):
        os.makedirs(output_dir)
    json_output_path = os.path.join(output_dir, 'nifti_mappings.json')
    failed_files_path = os.path.join(output_dir, 'failed_files.json')
    # NOTE(review): one meta object is shared by every file. Whether keys set
    # for a labelled image (Task / Label_path / Label_Dict) persist into the
    # records of later unlabelled images depends on meta_data -- verify.
    meta = meta_data()
    # Initialize the JSON file
    if not os.path.exists(json_output_path):
        with open(json_output_path, 'w') as json_file:
            json.dump({}, json_file)
    meta_file=os.path.join(target_path,'211230_M&Ms_Dataset_information_diagnosis_opendataset.csv')
    if os.path.isfile(meta_file):
        mf_flag=True
        # NOTE(review): df_meta is loaded but not used below; only the file
        # path is recorded in the per-image metadata.
        df_meta=pd.read_csv(meta_file,sep=',')
    else:
        mf_flag=False
    if pid_dirs:
        for pid_dir in tqdm(pid_dirs, desc="Processing pid dirs"):
            # Only directories are cases; loose files in the root are ignored.
            if not os.path.isdir(os.path.join(target_path,pid_dir)):
                continue
            meta_image_id=pid_dir
            modality="MRI"
            study='MnM2'##Dataset_name
            full_dir=os.path.join(target_path,pid_dir)
            dfs=find_image_dirs(full_dir)##list all nii.gz files
            if len(dfs)>0:
                for df in dfs:
                    ## Loop over the files looking for the SA/LA CINE, the ES/ED frames and the matching gt files
                    # NOTE(review): la_flag is set in both branches but never read.
                    if "CINE" in df:
                        ## normal processing (CINE files carry no label)
                        label_flag=False
                        if "_LA_" in df:
                            la_flag=True
                        else:
                            la_flag=False
                    elif "ES.nii.gz" in df:
                        if "_LA_" in df:
                            la_flag=True
                        else:
                            la_flag=False
                        # A label is expected next to the image as <name>_gt.nii.gz.
                        if os.path.isfile(os.path.join(full_dir,df.replace(".nii.gz","_gt.nii.gz"))):
                            label_flag=True
                        else:
                            label_flag=False
                    else:
                        # Everything else is skipped, including the *_gt files themselves.
                        # NOTE(review): names ending in "ED.nii.gz" also land here, so ED
                        # frames are not processed -- confirm this is intended.
                        continue
                    try:
                        ## process the data
                        full_path_image=os.path.join(full_dir,df)
                        sitk_img_original = util.load_nifti(full_path_image)
                        if sitk_img_original is None:
                            # NOTE(review): this failure is printed but not added to failed_files.
                            print(f" Failed to load image: {full_path_image}")
                            continue
                        original_spacing = list(sitk_img_original.GetSpacing())
                        original_size = list(sitk_img_original.GetSize())
                        sitk_img_processed = sitk_img_original
                        # is_4d_image = msd_dataset_info.get("tensorImageSize", "3D").upper() == "4D" or sitk_img_original.GetDimension() == 4
                        is_4d_image = sitk_img_original.GetDimension() == 4
                        # NOTE(review): frame_flag is assigned but never read.
                        frame_flag=False
                        # --- Resampling Logic (Revised for 4D) ---
                        if is_4d_image:
                            # Always process 4D images channel-wise for resampling
                            # logging.info(f" Processing 4D image channel-wise: {original_img_full_path}") # Keep log for errors only
                            channels = []
                            num_channels = original_size[3] if len(original_size) == 4 and sitk_img_original.GetDimension() == 4 else 1
                            channel_target_spacing = TARGET_VOXEL_SPACING if TARGET_VOXEL_SPACING else original_spacing[:3] # Use 3D spacing
                            for i in range(num_channels):
                                extractor = sitk.ExtractImageFilter()
                                current_3d_channel_size = original_size[:3]
                                # is_4d_image already implies GetDimension() == 4, so the
                                # else branch below looks unreachable from here.
                                if sitk_img_original.GetDimension() == 4:
                                    # Size 0 on the 4th axis collapses it: extract frame i as a 3D volume.
                                    extractor.SetSize([current_3d_channel_size[0], current_3d_channel_size[1], current_3d_channel_size[2], 0])
                                    extractor.SetIndex([0,0,0,i])
                                    channel_3d_img = extractor.Execute(sitk_img_original)
                                else:
                                    channel_3d_img = sitk_img_original
                                    if i > 0: break
                                channel_resampler = util.get_unisize_resampler(
                                    channel_3d_img, 'linear',
                                    spacing=channel_target_spacing, size=current_3d_channel_size
                                )
                                # A falsy resampler means the frame is kept unresampled.
                                if channel_resampler:
                                    channels.append(channel_resampler.Execute(channel_3d_img))
                                else:
                                    channels.append(channel_3d_img)
                            if channels:
                                if len(channels) > 1: # Only join if there are multiple channels
                                    sitk_img_processed = sitk.JoinSeriesImageFilter().Execute(channels)
                                    ## added by yanguoqing on 2025-08-11
                                    frame_flag=True
                                    # imgDict={}
                                    # for kf_idx in range(num_channels):
                                    # imgDict[str(kf_idx)]='none'
                                    # if str(meta_ed):imgDict[str(meta_ed)]='ed'
                                    # if str(meta_es):imgDict[str(meta_es)]='es'
                                    # meta.add_keyvalue('ImgDict',imgDict)
                                elif len(channels) == 1: # If only one channel resulted (e.g. original was 3D misidentified as 4D by tensorImageSize)
                                    sitk_img_processed = channels[0]
                        elif TARGET_VOXEL_SPACING: # 3D image with target spacing
                            img_resampler_obj = util.get_unisize_resampler(sitk_img_original, 'linear',
                                spacing=TARGET_VOXEL_SPACING, size=original_size)
                            if img_resampler_obj: sitk_img_processed = img_resampler_obj.Execute(sitk_img_original)
                        else: # 3D image, no TARGET_VOXEL_SPACING
                            img_resampler_obj = util.get_unisize_resampler(sitk_img_original, 'linear',
                                spacing=original_spacing, size=original_size)
                            if img_resampler_obj: sitk_img_processed = img_resampler_obj.Execute(sitk_img_original)
                        # Extra per-image information stored under the 'Metadata' key.
                        CIA_other_info = {
                            'metadata_file':''
                            # 'Series_Description':serise_desc
                        }
                        # Split is hard-coded; every image is recorded as "train".
                        CIA_other_info['split'] = "train"
                        CIA_other_info['Image_id']=meta_image_id
                        if mf_flag:
                            CIA_other_info['metadata_file']=meta_file
                        is_processed_4d = sitk_img_processed.GetDimension() == 4
                        # Clamping is disabled: with None here neither branch below runs.
                        clamp_range_to_use=None
                        if clamp_range_to_use and is_processed_4d:
                            clamped_channels_final = []
                            num_channels_final = sitk_img_processed.GetSize()[3] if len(sitk_img_processed.GetSize()) == 4 else 1
                            for i in range(num_channels_final):
                                extractor = sitk.ExtractImageFilter()
                                proc_size_final = sitk_img_processed.GetSize()
                                extractor.SetSize([proc_size_final[0], proc_size_final[1], proc_size_final[2], 0])
                                extractor.SetIndex([0,0,0,i])
                                channel_3d_img_to_clamp = extractor.Execute(sitk_img_processed)
                                clamped_channels_final.append(util.clamp_image(channel_3d_img_to_clamp, clamp_range_to_use))
                            if clamped_channels_final:
                                if len(clamped_channels_final) > 1:
                                    sitk_img_processed = sitk.JoinSeriesImageFilter().Execute(clamped_channels_final)
                                elif len(clamped_channels_final) == 1:
                                    sitk_img_processed = clamped_channels_final[0]
                        elif clamp_range_to_use: # 3D image
                            sitk_img_processed = util.clamp_image(sitk_img_processed, clamp_range_to_use)
                        # The processed image keeps its original file name under <output_dir>/<case>/.
                        output_path = os.path.join(output_dir,pid_dir, f"{df}")
                        # output_path=convert_windows_to_linux_path(output_path)
                        save_nifti(sitk_img_processed, output_path, full_path_image)
                        print(f"Saved NIfTI file to {output_path}")
                        label_path_dict = {}
                        if label_flag:
                            # The processed label is stored under the image's file name
                            # (without "_gt") in the task sub-directory.
                            processed_lbl_full_path = os.path.join(output_dir, pid_dir, TASK_VALUE, f"{df}")
                            full_path_label=os.path.join(full_dir,df.replace(".nii.gz","_gt.nii.gz"))
                            sitk_lbl_original = util.load_nifti(full_path_label)
                            if not sitk_lbl_original:
                                # NOTE(review): the image was already saved above, but this
                                # `continue` skips its mapping entry and the failure list.
                                print(f" Failed to load label: {full_path_label}")
                                processed_lbl_full_path = None
                                continue
                            if sitk_lbl_original:
                                label_resampler = sitk.ResampleImageFilter()
                                reference_for_label = sitk_img_processed # Default to processed image
                                if sitk_img_processed.GetDimension() == 4:
                                    # A 4D image cannot serve as the reference directly; use its first frame.
                                    num_comp_proc = sitk_img_processed.GetSize()[3] if len(sitk_img_processed.GetSize()) == 4 else 1
                                    if num_comp_proc > 0:
                                        extractor = sitk.ExtractImageFilter()
                                        proc_img_size_for_lbl_ref = sitk_img_processed.GetSize()
                                        extractor.SetSize([proc_img_size_for_lbl_ref[0], proc_img_size_for_lbl_ref[1], proc_img_size_for_lbl_ref[2], 0])
                                        extractor.SetIndex([0,0,0,0])
                                        try:
                                            reference_for_label = extractor.Execute(sitk_img_processed)
                                        except Exception as ref_err:
                                            print(f" Failed to extract 3D reference from 4D image: {output_path} for label alignment.")
                                            # print(traceback.format_exc())
                                            reference_for_label = None
                                    else: # Fallback if extraction fails
                                        print(f" Could not extract 3D reference for label from 4D image {output_path}. Label may not be correctly resampled.")
                                        reference_for_label = None # This will cause an issue below if not handled
                                sitk_lbl_processed = None
                                if reference_for_label and reference_for_label.GetDimension() > 0:
                                    # Nearest neighbour keeps label values discrete; the
                                    # original label pixel type is preserved.
                                    label_resampler.SetInterpolator(sitk.sitkNearestNeighbor)
                                    label_resampler.SetOutputPixelType(sitk_lbl_original.GetPixelID())
                                    if sitk_lbl_original.GetDimension() == 4:
                                        # 4D label: resample each frame separately, then re-join.
                                        lbl_channels = []
                                        lbl_size = list(sitk_lbl_original.GetSize())
                                        for i in range(lbl_size[3]):
                                            extractor = sitk.ExtractImageFilter()
                                            extractor.SetSize([lbl_size[0], lbl_size[1], lbl_size[2], 0])
                                            extractor.SetIndex([0, 0, 0, i])
                                            single_channel = extractor.Execute(sitk_lbl_original)
                                            label_resampler.SetReferenceImage(reference_for_label)
                                            resampled_channel = label_resampler.Execute(single_channel)
                                            lbl_channels.append(resampled_channel)
                                        if len(lbl_channels) > 1:
                                            sitk_lbl_processed = sitk.JoinSeriesImageFilter().Execute(lbl_channels)
                                        elif len(lbl_channels) == 1:
                                            sitk_lbl_processed = lbl_channels[0]
                                    else:
                                        label_resampler.SetReferenceImage(reference_for_label)
                                        sitk_lbl_processed = label_resampler.Execute(sitk_lbl_original)
                                    if processed_lbl_full_path:
                                        # A spatial size mismatch is only reported; the label is saved regardless.
                                        if sitk_img_processed.GetSize()[:3] != sitk_lbl_processed.GetSize()[:3]:
                                            print(f" Mismatch between image and label size (ignoring channels):")
                                            print(f" Image size: {sitk_img_processed.GetSize()}")
                                            print(f" Label size: {sitk_lbl_processed.GetSize()}")
                                        util.save_nifti(sitk_lbl_processed, processed_lbl_full_path, full_path_label)
                                else:
                                    print(f" Failed to set reference image for label resampling for {full_path_label}. Saving original label.")
                                    util.save_nifti(sitk_lbl_original, processed_lbl_full_path, full_path_label) # Save original
                                    # processed_lbl_full_path should still point to this saved original label
                                    sitk_lbl_processed=sitk_lbl_original
                            else:
                                processed_lbl_full_path = None
                        else:
                            processed_lbl_full_path = None
                        if processed_lbl_full_path:
                            label_path_dict['heart'] = processed_lbl_full_path
                            print('compare image and label size',sitk_img_original.GetSize(),sitk_lbl_original.GetSize())
                            print('compare image and label size',sitk_img_processed.GetSize(),sitk_lbl_processed.GetSize())
                            # Image and label must have identical sizes; otherwise the label is
                            # recorded as failed and no mapping entry is written for this image.
                            # NOTE(review): assert statements are removed when Python runs
                            # with -O, which would disable this check.
                            try:
                                assert sitk_img_processed.GetSize() == sitk_lbl_processed.GetSize()
                            except Exception as e:
                                failed_files.append(full_path_label)
                                continue
                    except RuntimeError:
                        # Only RuntimeError is handled here; any other exception type propagates.
                        failed_files.append(full_path_image)
                        print(f"Failed to load MnMs images from {full_path_image}")
                        continue
                    size_processed = list(sitk_img_processed.GetSize())
                    print('size_processed',size_processed,original_size)
                    # meta.add_keyvalue('Image_id',meta_image_id)
                    meta.add_keyvalue('Spacing_mm',min(original_spacing[:3]))## keep the smallest of the first three (x, y, z) spacings
                    meta.add_keyvalue('OriImg_path',full_path_image)
                    meta.add_keyvalue('Size',size_processed) # use the processed size here -- YH Jachin
                    meta.add_keyvalue('Modality',modality)
                    meta.add_keyvalue('Dataset_name',study)
                    meta.add_keyvalue('ROI','chest')
                    # Label-related keys are only set for images that have a saved label.
                    if processed_lbl_full_path:
                        print(label_path_dict.keys())
                        meta.add_keyvalue('Task',TASK_VALUE)
                        # meta.add_keyvalue('Label_tissue',list(label_path_dict.keys()))
                        meta.add_keyvalue('Label_path',{TASK_VALUE:label_path_dict})
                        meta.add_keyvalue('Label_Dict',LABEL_DICT)
                    meta.add_extra_keyvalue('Metadata',CIA_other_info)
                    # Write the mapping to the JSON file on the fly
                    # The whole mapping is re-read, updated and rewritten in place for every image.
                    with open(json_output_path, 'r+') as json_file:
                        existing_mappings = json.load(json_file)
                        existing_mappings[output_path] = meta.get_meta_data()
                        json_file.seek(0)
                        print(existing_mappings)
                        json.dump(existing_mappings, json_file, indent=4)
                        # Drop leftover bytes in case the new content is shorter than the old.
                        json_file.truncate()
            else:
                continue
    # Persist the list of image/label paths that could not be processed.
    with open(failed_files_path, "w") as json_file:
        json.dump(failed_files, json_file)
    print(f"The list has been written to {failed_files_path}")
    print(f"Saved NIfTI mappings to {json_output_path}")
if __name__ == "__main__":
    # Command-line entry point: read the dataset root and the output
    # directory (both have defaults), echo them, then run the pipeline.
    cli = argparse.ArgumentParser(description="Process DICOM files and save as NIfTI.")
    cli.add_argument(
        "--target_path",
        default="/home/data/Github/data/data_gen_def/DATASETS/MnM2/MnM2/dataset/",
        type=str,
        help="Path to the target directory containing metadata files.",
    )
    cli.add_argument(
        "--output_dir",
        default="/home/data/Github/data/data_gen_def/DATASETS_processed/MnM2/",
        type=str,
        help="Directory to save the NIfTI files.",
    )
    options = cli.parse_args()
    source_root, destination_root = options.target_path, options.output_dir
    print(source_root, destination_root)
    main(source_root, destination_root)