# LVM-Med / datasets_split /MMWHS_MR_Heart_split.py
# duynhm's picture
# Initial commit
# be2715b
# raw
# history blame contribute delete
# No virus
# 5.3 kB
import math
import os
import random
import numpy as np
from torch.utils.data import Dataset
import nibabel
from scipy import ndimage
import glob
from skimage.io import imsave
from utils.dataset_prepare import split_data, save_fileLabel_3D
# input image is the volume
def __itensity_normalize_one_volume__(image):
    """Normalize the intensities of one volume into the range [0, 1].

    Steps, in order:
      1. clip to the 0.5th / 99.5th intensity percentiles of the volume,
      2. standardize to zero mean and unit standard deviation,
      3. divide by 10,
      4. clip to [0, 1] (every voxel at or below the mean ends up as 0).

    The input is never modified: all work happens on a private copy, so an
    array that the caller still holds (or that a loader keeps cached) is left
    intact.

    Args:
        image: array-like of voxel intensities.

    Returns:
        A new floating-point ``numpy.ndarray`` with the same shape as
        ``image`` and values in [0, 1]. A constant volume (zero standard
        deviation) yields all zeros instead of NaNs.
    """
    volume = np.array(image, copy=True)
    if not np.issubdtype(volume.dtype, np.floating):
        # Integer volumes would truncate the fractional percentile bounds.
        volume = volume.astype(np.float64)

    # normalization following Med3D
    top_per = np.percentile(volume, 99.5)
    bot_per = np.percentile(volume, 0.5)
    np.clip(volume, bot_per, top_per, out=volume)

    spread = np.std(volume)
    if spread == 0:
        # Every voxel equals the mean; dividing by zero would produce NaNs.
        return np.zeros_like(volume)

    volume = (volume - np.mean(volume)) / spread
    volume = volume / 10.0
    return np.clip(volume, 0.0, 1.0)
def __training_data_process__(data, label):
    """Extract arrays from a loaded image/label pair and prepare them.

    The image intensities are passed through
    ``__itensity_normalize_one_volume__``; the label values are rewritten
    in place according to a fixed lookup table.

    Args:
        data: image object exposing ``get_fdata()``.
        label: label object exposing ``get_fdata()``.

    Returns:
        Tuple ``(image_array, label_array)``.
    """
    image_array = data.get_fdata()
    label_array = label.get_fdata()

    # normalize the image intensities
    image_array = __itensity_normalize_one_volume__(image_array)

    # (original label value, replacement value); no replacement value is
    # also an original value, so the rewrites cannot chain into each other.
    relabel_table = (
        (205, 30),
        (420, 60),
        (500, 90),
        (550, 120),
        (600, 150),
        (820, 180),
        (850, 210),
    )
    for old_value, new_value in relabel_table:
        label_array[label_array == old_value] = new_value

    return image_array, label_array
def preprocess_vol(img_name, label_name):
    """Load an image volume and its label volume and return them as uint8.

    Args:
        img_name: path of the image volume file.
        label_name: path of the label volume file.

    Returns:
        Tuple ``(image, mask)`` of ``uint8`` arrays with identical shapes.
        The image is the processed intensity array multiplied by 255.

    Raises:
        AssertionError: if either path is not an existing file, or if the
            processed image and label shapes differ.
    """
    for required_path in (img_name, label_name):
        assert os.path.isfile(required_path)

    volume = nibabel.load(img_name)
    assert volume is not None
    segmentation = nibabel.load(label_name)
    assert segmentation is not None

    intensities, labels = __training_data_process__(volume, segmentation)
    assert intensities.shape == labels.shape, \
        "img shape:{} is not equal to mask shape:{}".format(intensities.shape, labels.shape)

    scaled = intensities * 255
    return scaled.astype('uint8'), labels.astype('uint8')
#if __name__ == '__main__':
def _export_split_slices(raw_dir, image_dir, label_dir):
    """Write every volume found in ``raw_dir`` as per-slice PNG files.

    Files whose name contains ``'label'`` are skipped; every other entry is
    treated as an image volume whose label file has the same name with
    ``'image'`` replaced by ``'label'``.

    Args:
        raw_dir: directory holding the raw image and label volumes.
        image_dir: output directory for image slices (created if missing).
        label_dir: output directory for label slices (created if missing).
    """
    # Sorted so the processing (and printing) order is reproducible.
    volume_names = sorted(os.listdir(raw_dir))
    os.makedirs(image_dir, exist_ok=True)
    os.makedirs(label_dir, exist_ok=True)

    for vol_name in volume_names:
        if 'label' in vol_name:
            continue
        # Rename only the file name, never the directory part of the path.
        label_vol_name = vol_name.replace('image', 'label')
        img_flair, mask = preprocess_vol(
            os.path.join(raw_dir, vol_name),
            os.path.join(raw_dir, label_vol_name),
        )
        print(img_flair.shape, mask.shape)

        image_stem = vol_name.split('.')[0]
        label_stem = label_vol_name.split('.')[0]
        # axis 2 is the one the volume is sliced along
        for depth in range(img_flair.shape[2]):
            frame_suffix = '_frame_' + str(depth).zfill(3) + '.png'
            imsave(os.path.join(image_dir, image_stem + frame_suffix),
                   img_flair[:, :, depth], check_contrast=False)
            imsave(os.path.join(label_dir, label_stem + frame_suffix),
                   mask[:, :, depth], check_contrast=False)


def MMWHS_MR_Heart_split():
    """Convert the train/valid/test MMWHS MR heart volumes into PNG slices.

    For each split, volumes are read from
    ``./dataset_demo/MMWHS_MR_Heart/Raw/<split>/`` and their slices are
    written to ``./dataset_demo/MMWHS_MR_Heart/<split>/`` (images) and
    ``./dataset_demo/MMWHS_MR_Heart/<split>_labels/`` (labels). Afterwards
    ``save_fileLabel_3D`` is called with the dataset name.
    """
    dataset_name = "MMWHS_MR_Heart"
    dataset_root = './dataset_demo/MMWHS_MR_Heart/'

    for split in ('train', 'valid', 'test'):
        _export_split_slices(
            dataset_root + 'Raw/' + split + '/',
            dataset_root + split + '/',
            dataset_root + split + '_labels/',
        )

    save_fileLabel_3D(dataset_name)