AutoSeg4ETICA / preprocessing /registration.py
Chris Xiao
upload files
c642393
import os
import ants
import nrrd
import numpy as np
import glob
import slicerio
import shutil
import argparse
def parse_command_line():
print('-----'*10)
print('Parsing Command Line Arguments')
parser = argparse.ArgumentParser(
description='pipeline for dataset nnUNet preprocessing')
parser.add_argument('-bp', metavar='base path', type=str,
help="Absolute path of the base directory")
parser.add_argument('-ip', metavar='image path', type=str,
help="Relative path of the image directory")
parser.add_argument('-sp', metavar='segmentation path', type=str,
help="Relative path of the image directory")
parser.add_argument('-sl', metavar='segmentation information list', type=str, nargs='+',
help='a list of label name and corresponding value')
argv = parser.parse_args()
return argv
def split_and_registration(template, target, base, images_path, seg_path, fomat, checked=False, has_label=False):
print('-----'*10)
print('Creating file paths')
# Define the path for template, target, and segmentations (from template)
fixed_path = os.path.join(base, images_path, template + '.' + fomat)
moving_path = os.path.join(base, images_path, target + '.' + fomat)
images_output = os.path.join(base, 'imagesRS/', target + '.nii.gz')
print('-----'*10)
print(f'Reading in the template {template} and target {target} image')
# Read the template and target image
template_image = ants.image_read(fixed_path)
target_image = ants.image_read(moving_path)
print('-----'*10)
print('Performing the template and target image registration')
transform_forward = ants.registration(fixed=template_image, moving=target_image,
type_of_transform="AffineFast", verbose=False)
if has_label:
segmentation_path = os.path.join(
base, seg_path, target + '.nii.gz')
segmentation_output = os.path.join(
base, 'labelsRS/', target + '.nii.gz')
print('-----'*10)
print('Reading in the segmentation')
# Split segmentations into individual components
segment_target = ants.image_read(segmentation_path)
print('-----'*10)
print('Applying the transformation for label propagation and image registration')
predicted_targets_image = ants.apply_transforms(
fixed=template_image,
moving=segment_target,
transformlist=transform_forward["fwdtransforms"],
interpolator="genericLabel",
verbose=False)
predicted_targets_image.to_file(segmentation_output)
reg_img = ants.apply_transforms(
fixed=template_image,
moving=target_image,
transformlist=transform_forward["fwdtransforms"],
interpolator="linear",
verbose=False)
print('-----'*10)
print("writing out transformed template segmentation")
reg_img.to_file(images_output)
print('Label Propagation & Image Registration complete')
def convert_to_one_hot(data, header, segment_indices=None):
print('-----'*10)
print("converting to one hot")
layer_values = get_layer_values(header)
label_values = get_label_values(header)
# Newer Slicer NRRD (compressed layers)
if layer_values and label_values:
assert len(layer_values) == len(label_values)
if len(data.shape) == 3:
x_dim, y_dim, z_dim = data.shape
elif len(data.shape) == 4:
x_dim, y_dim, z_dim = data.shape[1:]
num_segments = len(layer_values)
one_hot = np.zeros((num_segments, x_dim, y_dim, z_dim))
if segment_indices is None:
segment_indices = list(range(num_segments))
elif isinstance(segment_indices, int):
segment_indices = [segment_indices]
elif not isinstance(segment_indices, list):
print("incorrectly specified segment indices")
return
# Check if NRRD is composed of one layer 0
if np.max(layer_values) == 0:
for i, seg_idx in enumerate(segment_indices):
layer = layer_values[seg_idx]
label = label_values[seg_idx]
one_hot[i] = 1*(data == label).astype(np.uint8)
else:
for i, seg_idx in enumerate(segment_indices):
layer = layer_values[seg_idx]
label = label_values[seg_idx]
one_hot[i] = 1*(data[layer] == label).astype(np.uint8)
# Binary labelmap
elif len(data.shape) == 3:
x_dim, y_dim, z_dim = data.shape
num_segments = np.max(data)
one_hot = np.zeros((num_segments, x_dim, y_dim, z_dim))
if segment_indices is None:
segment_indices = list(range(1, num_segments + 1))
elif isinstance(segment_indices, int):
segment_indices = [segment_indices]
elif not isinstance(segment_indices, list):
print("incorrectly specified segment indices")
return
for i, seg_idx in enumerate(segment_indices):
one_hot[i] = 1*(data == seg_idx).astype(np.uint8)
# Older Slicer NRRD (already one-hot)
else:
return data
return one_hot
def get_layer_values(header, indices=None):
layer_values = []
num_segments = len([key for key in header.keys() if "Layer" in key])
for i in range(num_segments):
layer_values.append(int(header['Segment{}_Layer'.format(i)]))
return layer_values
def get_label_values(header, indices=None):
label_values = []
num_segments = len([key for key in header.keys() if "LabelValue" in key])
for i in range(num_segments):
label_values.append(int(header['Segment{}_LabelValue'.format(i)]))
return label_values
def get_num_segments(header, indices=None):
num_segments = len([key for key in header.keys() if "LabelValue" in key])
return num_segments
def checkCorrespondence(segmentation, base, paired_list, filename):
print(filename)
assert type(paired_list) == list
data, tempSeg = nrrd.read(os.path.join(base, segmentation, filename))
seg_info = slicerio.read_segmentation_info(
os.path.join(base, segmentation, filename))
output_voxels, output_header = slicerio.extract_segments(
data, tempSeg, seg_info, paired_list)
output = os.path.join(base, 'MatchedSegs/' +
filename)
nrrd.write(output, output_voxels, output_header)
print('---'*10)
print('Check the label names and values')
print(slicerio.read_segmentation_info(output))
return output
def checkSegFormat(base, segmentation, paired_list, check=False):
path = os.path.join(base, segmentation)
save_dir = os.path.join(base, 're-format_labels')
try:
os.mkdir(save_dir)
except:
print(f'{save_dir} already exists')
for file in os.listdir(path):
name = file.split('.')[0]
if file.endswith('seg.nrrd') or file.endswith('nrrd'):
if check:
output_path = checkCorrespondence(
segmentation, base, paired_list, file)
ants_img = ants.image_read(output_path)
header = nrrd.read_header(output_path)
else:
ants_img = ants.image_read(os.path.join(path, file))
header = nrrd.read_header(os.path.join(path, file))
segmentations = True
filename = os.path.join(save_dir, name + '.nii.gz')
nrrd2nifti(ants_img, header, filename, segmentations)
elif file.endswith('nii'):
image = ants.image_read(os.path.join(path, file))
image.to_file(os.path.join(save_dir, name + '.nii.gz'))
elif file.endswith('nii.gz'):
shutil.copy(os.path.join(path, file), save_dir)
return save_dir
def nrrd2nifti(img, header, filename, segmentations=True):
img_as_np = img.view(single_components=segmentations)
if segmentations:
data = convert_to_one_hot(img_as_np, header)
foreground = np.max(data, axis=0)
labelmap = np.multiply(np.argmax(data, axis=0) + 1,
foreground).astype('uint8')
segmentation_img = ants.from_numpy(
labelmap, origin=img.origin, spacing=img.spacing, direction=img.direction)
print('-- Saving NII Segmentations')
segmentation_img.to_file(filename)
else:
print('-- Saving NII Volume')
img.to_file(filename)
def find_template(base, image_path, fomat):
scans = sorted(glob.glob(os.path.join(base, image_path) + '/*' + fomat))
template = os.path.basename(scans[0]).split('.')[0]
return template
def find_template_V2(base, image_path, fomat):
maxD = -np.inf
for i in glob.glob(os.path.join(base, image_path) + '/*' + fomat):
id = os.path.basename(i).split('.')[0]
img = ants.image_read(i)
thirdD = img.shape[2]
if thirdD > maxD:
template = id
maxD = thirdD
print(maxD, template)
return template
def path_to_id(path, fomat):
ids = []
for i in glob.glob(path + '/*' + fomat):
id = os.path.basename(i).split('.')[0]
ids.append(id)
return ids
def checkFormat(base, images_path):
path = os.path.join(base, images_path)
for file in os.listdir(path):
if file.endswith('.nii'):
ret = 'nii'
break
elif file.endswith('.nii.gz'):
ret = 'nii.gz'
break
elif file.endswith('.nrrd'):
ret = 'nrrd'
break
elif file.endswith('.seg.nrrd'):
ret = 'seg.nrrd'
break
return ret
def main():
args = parse_command_line()
base = args.bp
images_path = args.ip
segmentation = args.sp
label_list = args.sl
images_output = os.path.join(base, 'imagesRS')
labels_output = os.path.join(base, 'labelsRS')
fomat = checkFormat(base, images_path)
fomat_seg = checkFormat(base, segmentation)
template = find_template(base, images_path, fomat)
label_lists = path_to_id(os.path.join(base, segmentation), fomat_seg)
if label_list is not None:
matched_output = os.path.join(base, 'MatchedSegs')
try:
os.mkdir(matched_output)
except:
print(f"{matched_output} already exists")
try:
os.mkdir(images_output)
except:
print(f"{images_output} already exists")
try:
os.mkdir(labels_output)
except:
print(f"{labels_output} already exists")
paired_list = []
if label_list is not None:
for i in range(0, len(label_list), 2):
if not label_list[i].isdigit():
print(
"Wrong order of input argument for pair-wising label value and its name !!!")
return
else:
value = label_list[i]
if not label_list[i+1].isdigit():
key = label_list[i+1]
ele = tuple((key, value))
paired_list.append(ele)
else:
print(
"Wrong input argument for pair-wising label value and its name !!!")
return
# print(new_segmentation)
seg_output_path = checkSegFormat(
base, segmentation, paired_list, check=True)
for j in sorted(glob.glob(os.path.join(base, images_path) + '/*' + fomat)):
id = os.path.basename(j).split('.')[0]
if id == template:
pass
else:
target = id
if id in label_lists:
split_and_registration(
template, target, base, images_path, seg_output_path, fomat, checked=True, has_label=True)
else:
split_and_registration(
template, target, base, images_path, seg_output_path, fomat, checked=True, has_label=False)
image = ants.image_read(os.path.join(
base, images_path, template + '.' + fomat))
image.to_file(os.path.join(base, images_output, template + '.nii.gz'))
fomat = 'nii.gz'
images_path = os.path.join(base, 'imagesRS/')
if template in label_lists:
split_and_registration(
target, template, base, images_path, seg_output_path, fomat, checked=True, has_label=True)
else:
split_and_registration(
target, template, base, images_path, seg_output_path, fomat, checked=True, has_label=False)
else:
seg_output_path = checkSegFormat(
base, segmentation, paired_list, check=False)
for i in sorted(glob.glob(os.path.join(base, images_path) + '/*' + fomat)):
id = os.path.basename(i).split('.')[0]
if id == template:
pass
else:
target = id
if id in label_lists:
split_and_registration(
template, target, base, images_path, seg_output_path, fomat, checked=False, has_label=True)
else:
split_and_registration(
template, target, base, images_path, seg_output_path, fomat, checked=False, has_label=False)
image = ants.image_read(os.path.join(
base, images_path, template + '.' + fomat))
image.to_file(os.path.join(base, images_output, template + '.nii.gz'))
images_path = os.path.join(base, 'imagesRS/')
fomat = 'nii.gz'
if template in label_lists:
split_and_registration(
target, template, base, images_path, seg_output_path, fomat, checked=True, has_label=True)
else:
split_and_registration(
target, template, base, images_path, seg_output_path, fomat, checked=True, has_label=False)
if __name__ == '__main__':
main()