AutoSeg4Sinonasal / deepatlas /preprocess /crop_flip_training.py
Chris Xiao
upload files
2ca2f68
raw
history blame
13.3 kB
import numpy as np
import glob
import ants
import nibabel as nib
import os
import argparse
import sys
from crop import crop, cropV2, save_fileV2
from pathlib import Path
def parse_command_line():
parser = argparse.ArgumentParser(
description='pipeline for data preprocessing')
parser.add_argument('-rs', metavar='shape after resizing', type=int, nargs='+',
help='shape after resizing the image and segmentation. Expected to be 2^N')
parser.add_argument('-fp', action='store_true',
help='check if need to flip the data')
parser.add_argument('-ti', metavar='task id and name', type=str,
help='task name and id')
argv = parser.parse_args()
return argv
def path_to_id(path):
ids = []
for i in glob.glob(path + '/*nii.gz'):
id = os.path.basename(i).split('.')[0]
ids.append(id)
return ids
def pad(raw_image, bound_x, bound_y, bound_z, resize, seg=False):
diff_x = resize[0] - (bound_x[1]-bound_x[0])
diff_y = resize[1] - (bound_y[1]-bound_y[0])
diff_z = resize[2] - (bound_z[1]-bound_z[0])
if diff_x < 0 or diff_y < 0 or diff_z < 0:
sys.exit(
'the dimension of ROI is larger than the resizing dimension, please choose a different padding dimension')
left_y, right_y = split(diff_y)
left_z, right_z = split(diff_z)
left_x, right_x = split(diff_x)
new_bound_x_left = bound_x[0] - left_x
new_bound_x_right = bound_x[1] + right_x
new_bound_y_left = bound_y[0] - left_y
new_bound_y_right = bound_y[1] + right_y
new_bound_z_left = bound_z[0] - left_z
new_bound_z_right = bound_z[1] + right_z
# check if x_dim out of bounds
if new_bound_x_left < 0:
new_bound_x_left = 0
new_bound_x_right = bound_x[1] + diff_x - bound_x[0]
elif new_bound_x_right > raw_image.shape[0]:
new_bound_x_right = raw_image.shape[0]
new_bound_x_left = bound_x[0] - \
(diff_x - (raw_image.shape[0] - bound_x[1]))
# check if y_dim out of bounds
if new_bound_y_left < 0:
new_bound_y_left = 0
new_bound_y_right = bound_y[1] + diff_y - bound_y[0]
elif new_bound_y_right > raw_image.shape[1]:
new_bound_y_right = raw_image.shape[1]
new_bound_y_left = bound_y[0] - \
(diff_y - (raw_image.shape[1] - bound_y[1]))
# check if z_dim out of bounds
if new_bound_z_left < 0:
new_bound_z_left = 0
new_bound_z_right = bound_z[1] + diff_z - bound_z[0]
elif new_bound_z_right > raw_image.shape[2]:
new_bound_z_right = raw_image.shape[2]
new_bound_z_left = bound_z[0] - \
(diff_z - (raw_image.shape[2] - bound_z[1]))
assert new_bound_x_right - new_bound_x_left == resize[0]
assert new_bound_y_right - new_bound_y_left == resize[1]
assert new_bound_z_right - new_bound_z_left == resize[2]
if not seg:
return raw_image[new_bound_x_left:new_bound_x_right, new_bound_y_left:new_bound_y_right, new_bound_z_left:new_bound_z_right]
else:
new_seg = np.zeros_like(raw_image)
new_seg[bound_x[0]:bound_x[1],
bound_y[0]:bound_y[1], bound_z[0]:bound_z[1]] = raw_image[bound_x[0]:bound_x[1], bound_y[0]:bound_y[1], bound_z[0]:bound_z[1]]
return new_seg[new_bound_x_left:new_bound_x_right, new_bound_y_left:new_bound_y_right, new_bound_z_left:new_bound_z_right]
def split(distance):
if distance == 0:
return 0, 0
half_dist = int(distance / 2)
left = int(half_dist * 0.8)
right = distance - left
return left, right
def crop_and_flip(nib_img, nib_seg, ants_img, ants_seg, resize):
img = nib_img.get_fdata()
seg = nib_seg.get_fdata()
gem = ants.label_geometry_measures(ants_seg, ants_img)
low_x = min(list(gem.loc[:, 'BoundingBoxLower_x']))
upp_x = max(list(gem.loc[:, 'BoundingBoxUpper_x']))
low_y = min(list(gem.loc[:, 'BoundingBoxLower_y']))
upp_y = max(list(gem.loc[:, 'BoundingBoxUpper_y']))
low_z = min(list(gem.loc[:, 'BoundingBoxLower_z']))
upp_z = max(list(gem.loc[:, 'BoundingBoxUpper_z']))
img = Zscore_normalization(img)
#img = MinMax_normalization(img)
# Compute mid point
mid_x = int((low_x + upp_x) / 2)
tuple_x_left = tuple([low_x, mid_x])
tuple_x_right = tuple([mid_x, upp_x])
tuple_y = tuple([low_y, upp_y])
tuple_z = tuple([low_z, upp_z])
left_img = pad(img, tuple_x_left, tuple_y, tuple_z, resize, seg=False)
left_seg = pad(seg, tuple_x_left, tuple_y, tuple_z, resize, seg=True)
right_img = pad(img, tuple_x_right, tuple_y, tuple_z, resize, seg=False)
right_seg = pad(seg, tuple_x_right, tuple_y, tuple_z, resize, seg=True)
flipped_right_img = np.flip(right_img, axis=0)
flipped_right_seg = np.flip(right_seg, axis=0)
return left_img, left_seg, flipped_right_img, flipped_right_seg
def crop_and_flip_V2(nib_img, ants_img, resize, geo_info):
img = nib_img.get_fdata()
tuple_x = geo_info[0]
tuple_y = geo_info[1]
tuple_z = geo_info[2]
low_x = tuple_x[0]
upp_x = tuple_x[1]
img = Zscore_normalization(img)
#img = MinMax_normalization(img)
# Compute mid point
mid_x = int((low_x + upp_x) / 2)
tuple_x_left = tuple([low_x, mid_x])
tuple_x_right = tuple([mid_x, upp_x])
left_img = pad(img, tuple_x_left, tuple_y, tuple_z, resize, seg=False)
right_img = pad(img, tuple_x_right, tuple_y, tuple_z, resize, seg=False)
flipped_right_img = np.flip(right_img, axis=0)
return left_img, flipped_right_img
def MinMax_normalization(scan):
lb = np.amin(scan)
ub = np.amax(scan)
scan = (scan - lb) / (ub - lb)
return scan
def Zscore_normalization(scan):
mean = np.mean(scan)
std = np.std(scan)
lb = np.percentile(scan, 0.05)
ub = np.percentile(scan, 99.5)
scan = np.clip(scan, lb, ub)
scan = (scan - mean) / std
return scan
def load_data(img_path, seg_path):
nib_seg = nib.load(seg_path)
nib_img = nib.load(img_path)
ants_seg = ants.image_read(seg_path)
ants_img = ants.image_read(img_path)
return nib_img, nib_seg, ants_img, ants_seg
def crop_flip_save_file(left_img, left_seg, flipped_right_img, flipped_right_seg, nib_img, nib_seg, output_img, output_seg, scan_id):
left_img_nii = nib.Nifti1Image(
left_img, affine=nib_img.affine, header=nib_img.header)
left_seg_nii = nib.Nifti1Image(
left_seg, affine=nib_seg.affine, header=nib_seg.header)
right_img_nii = nib.Nifti1Image(
flipped_right_img, affine=nib_img.affine, header=nib_img.header)
right_seg_nii = nib.Nifti1Image(
flipped_right_seg, affine=nib_seg.affine, header=nib_seg.header)
left_img_nii.to_filename(os.path.join(
output_img, 'right_' + scan_id + '.nii.gz'))
left_seg_nii.to_filename(os.path.join(
output_seg, 'right_' + scan_id + '.nii.gz'))
right_img_nii.to_filename(os.path.join(
output_img, 'left_' + scan_id + '.nii.gz'))
right_seg_nii.to_filename(os.path.join(
output_seg, 'left_' + scan_id + '.nii.gz'))
def get_geometry_info(seg_path, img_path):
abs_low_x = np.Inf
abs_upp_x = -np.Inf
abs_low_y = np.Inf
abs_upp_y = -np.Inf
abs_low_z = np.Inf
abs_upp_z = -np.Inf
for i in sorted(glob.glob(os.path.join(img_path, '*.nii.gz'))):
name = os.path.basename(i)
if os.path.exists(os.path.join(seg_path, name)):
seg = ants.image_read(os.path.join(seg_path, name))
img = ants.image_read(i)
gem = ants.label_geometry_measures(seg, img)
low_x = min(list(gem.loc[:, 'BoundingBoxLower_x']))
upp_x = max(list(gem.loc[:, 'BoundingBoxUpper_x']))
low_y = min(list(gem.loc[:, 'BoundingBoxLower_y']))
upp_y = max(list(gem.loc[:, 'BoundingBoxUpper_y']))
low_z = min(list(gem.loc[:, 'BoundingBoxLower_z']))
upp_z = max(list(gem.loc[:, 'BoundingBoxUpper_z']))
if low_x < abs_low_x:
abs_low_x = low_x
if upp_x > abs_upp_x:
abs_upp_x = upp_x
if low_y < abs_low_y:
abs_low_y = low_y
if upp_y > abs_upp_y:
abs_upp_y = upp_y
if low_z < abs_low_z:
abs_low_z = low_z
if upp_z > abs_upp_z:
abs_upp_z = upp_z
tuple_x = tuple([abs_low_x, abs_upp_x])
tuple_y = tuple([abs_low_y, abs_upp_y])
tuple_z = tuple([abs_low_z, abs_upp_z])
return [tuple_x, tuple_y, tuple_z]
def crop_flip_save_file_V2(left_img, flipped_right_img, nib_img, output_img, scan_id):
left_img_nii = nib.Nifti1Image(
left_img, affine=nib_img.affine, header=nib_img.header)
right_img_nii = nib.Nifti1Image(
flipped_right_img, affine=nib_img.affine, header=nib_img.header)
left_img_nii.to_filename(os.path.join(
output_img, 'right_' + scan_id + '.nii.gz'))
right_img_nii.to_filename(os.path.join(
output_img, 'left_' + scan_id + '.nii.gz'))
def crop_save_file(left_img, left_seg, nib_img, nib_seg, output_img, output_seg, scan_id):
left_img_nii = nib.Nifti1Image(
left_img, affine=nib_img.affine, header=nib_img.header)
left_seg_nii = nib.Nifti1Image(
left_seg, affine=nib_seg.affine, header=nib_seg.header)
left_img_nii.to_filename(os.path.join(
output_img, scan_id + '.nii.gz'))
left_seg_nii.to_filename(os.path.join(
output_seg, scan_id + '.nii.gz'))
def main():
ROOT_DIR = str(Path(os.getcwd()).parent.parent.absolute())
args = parse_command_line()
resize_shape = args.rs
flipped = args.fp
deepatlas_path = ROOT_DIR
task_id = args.ti
base_path = os.path.join(deepatlas_path, 'deepatlas_raw_data_base', task_id, 'Training_dataset')
image_path = os.path.join(base_path, 'images')
seg_path = os.path.join(base_path, 'labels')
output_path = os.path.join(deepatlas_path, 'deepatlas_preprocessed')
task_path = os.path.join(deepatlas_path, 'deepatlas_preprocessed', task_id)
training_data_path = os.path.join(deepatlas_path, 'deepatlas_preprocessed', task_id, 'Training_dataset')
output_img = os.path.join(deepatlas_path, 'deepatlas_preprocessed', task_id, 'Training_dataset', 'images')
output_seg = os.path.join(deepatlas_path, 'deepatlas_preprocessed', task_id, 'Training_dataset', 'labels')
label_list = path_to_id(seg_path)
geo_info = get_geometry_info(seg_path, image_path)
print(geo_info)
try:
os.mkdir(output_path)
except:
print(f'{output_path} is already existed')
try:
os.mkdir(task_path)
except:
print(f'{task_path} is already existed')
try:
os.mkdir(training_data_path)
except:
print(f"{training_data_path} already exists")
try:
os.mkdir(output_path)
except:
print(f'{output_path} is already existed')
try:
os.mkdir(output_img)
except:
print(f'{output_img} is already existed')
try:
os.mkdir(output_seg)
except:
print(f'{output_seg} is already existed')
for i in sorted(glob.glob(image_path + '/*nii.gz')):
id = os.path.basename(i).split('.')[0]
if id in label_list:
label_path = os.path.join(seg_path, id + '.nii.gz')
nib_img, nib_seg, ants_img, ants_seg = load_data(i, label_path)
if flipped:
left_img, left_seg, flipped_right_img, flipped_right_seg = crop_and_flip(
nib_img, nib_seg, ants_img, ants_seg, resize_shape)
print(
'Scan ID: ' + id + f', img & seg before cropping: {nib_img.get_fdata().shape}, after cropping, flipping and padding: {left_img.shape} and {flipped_right_img.shape}')
crop_flip_save_file(left_img, left_seg, flipped_right_img, flipped_right_seg,
nib_img, nib_seg, output_img, output_seg, id)
else:
left_img, left_seg = crop(
nib_img, nib_seg, ants_img, ants_seg, resize_shape)
print(
'Scan ID: ' + id + f', img & seg before cropping: {nib_img.get_fdata().shape}, after cropping and padding the image and seg: {left_img.shape}')
crop_save_file(left_img, left_seg, nib_img,
nib_seg, output_img, output_seg, id)
else:
nib_img = nib.load(i)
ant_img = ants.image_read(i)
if flipped:
left_img, flipped_right_img = crop_and_flip_V2(nib_img, ant_img, resize_shape, geo_info)
print(
'Scan ID: ' + id + f', img before cropping: {nib_img.get_fdata().shape}, after cropping, flipping and padding: {left_img.shape} and {flipped_right_img.shape}')
crop_flip_save_file_V2(left_img, flipped_right_img, nib_img, output_img, id)
else:
outImg = cropV2(nib_img, ant_img, resize_shape, geo_info)
print(
'Scan ID: ' + id + f', img before cropping: {nib_img.get_fdata().shape}, after cropping and padding the image: {outImg.shape}')
save_fileV2(outImg, nib_img, output_img, id)
if __name__ == '__main__':
main()