AutoSeg4Sinonasal / deepatlas /preprocess /registration_test.py
Chris Xiao
upload files
2ca2f68
raw
history blame
13.6 kB
import os
import ants
import nrrd
import numpy as np
import glob
import slicerio
import shutil
import argparse
from pathlib import Path
def parse_command_line():
    """Parse the command-line options of the co-alignment pipeline.

    Every option is an optional string; ``-sl`` collects one or more strings
    into a list. Options that are not supplied come back as ``None``.

    Returns:
        argparse.Namespace: attributes ``bp``, ``op``, ``template``,
        ``target_scan``, ``target_seg``, ``sl`` and ``ti``.
    """
    print('---' * 10)
    print('Parsing Command Line Arguments')
    cli = argparse.ArgumentParser(
        description='pipeline for dataset co-alignment')
    # One row per option: (flag, metavar, help text, extra add_argument keywords).
    option_table = (
        ('-bp', 'base path',
         "absolute path of the base directory", {}),
        ('-op', 'output path for both registration & crop steps',
         "relative path of the output directory, should be same name in the registration, crop and final prediction steps", {}),
        ('-template', 'template scan path',
         "relative path of the template scan directory", {}),
        ('-target_scan', 'target scan path',
         "relative path of the target image directory", {}),
        ('-target_seg', 'target segmentation path',
         "relative path of the target segmentation directory", {}),
        ('-sl', 'segmentation information list',
         'a list of label name and corresponding value', {'nargs': '+'}),
        ('-ti', 'task id and name',
         'task name and id', {}),
    )
    for flag, meta, description, extra in option_table:
        cli.add_argument(flag, metavar=meta, type=str,
                         help=description, **extra)
    return cli.parse_args()
def split_and_registration(template, target, base, template_images_path, target_images_path, seg_path, img_out_path, seg_out_path, template_fomat, target_fomat, has_label=False):
    """Register one target scan onto the template and save the warped outputs.

    The target image is registered to the template with an ANTs
    ``"Similarity"`` transform. The resulting forward transform is applied to
    the target image (linear interpolation) and, when ``has_label`` is true,
    to the target's segmentation (``genericLabel`` interpolation).

    NOTE(review): the indentation of this file was lost; the image resampling
    and write at the end are assumed to run for every target, not only for
    labelled ones - confirm against the upstream source.

    Args:
        template: file stem of the template (fixed) scan.
        target: file stem of the target (moving) scan.
        base: base directory the relative input paths are joined onto.
        template_images_path: directory of the template scan, relative to ``base``.
        target_images_path: directory of the target scans, relative to ``base``.
        seg_path: directory holding ``<target>.nii.gz`` segmentations
            (joined onto ``base``).
        img_out_path: directory receiving the registered image.
        seg_out_path: directory receiving the registered segmentation.
        template_fomat: file extension of the template scan, without leading dot.
        target_fomat: file extension of the target scan, without leading dot.
        has_label: whether a segmentation exists for this target.
    """
    banner = '---' * 10
    nifti_name = target + '.nii.gz'

    print(banner)
    print('Creating file paths')
    # Fixed (template) scan, moving (target) scan and registered-image output.
    fixed_path = os.path.join(
        base, template_images_path, '.'.join((template, template_fomat)))
    moving_path = os.path.join(
        base, target_images_path, '.'.join((target, target_fomat)))
    images_output = os.path.join(img_out_path, nifti_name)

    print(banner)
    print('Reading in the template and target image')
    template_image = ants.image_read(fixed_path)
    target_image = ants.image_read(moving_path)

    print(banner)
    print('Performing the template and target image registration')
    transform_forward = ants.registration(
        fixed=template_image,
        moving=target_image,
        type_of_transform="Similarity",
        verbose=False)
    forward_chain = transform_forward["fwdtransforms"]

    if has_label:
        segmentation_path = os.path.join(base, seg_path, nifti_name)
        segmentation_output = os.path.join(seg_out_path, nifti_name)
        print(banner)
        print('Reading in the segmentation')
        segment_target = ants.image_read(segmentation_path)
        print(banner)
        print('Applying the transformation for label propagation and image registration')
        # genericLabel interpolation is used for the label image.
        warped_labels = ants.apply_transforms(
            fixed=template_image,
            moving=segment_target,
            transformlist=forward_chain,
            interpolator="genericLabel",
            verbose=False)
        warped_labels.to_file(segmentation_output)

    # Resample the target image itself into the template space.
    reg_img = ants.apply_transforms(
        fixed=template_image,
        moving=target_image,
        transformlist=forward_chain,
        interpolator="linear",
        verbose=False)
    print(banner)
    print("writing out transformed template segmentation")
    reg_img.to_file(images_output)
    print('Label Propagation & Image Registration complete')
def convert_to_one_hot(data, header, segment_indices=None):
    """Expand a segmentation array into one binary channel per segment.

    Three input layouts are distinguished:

    * Header carries ``Segment<i>_Layer`` / ``Segment<i>_LabelValue`` entries
      ("newer Slicer NRRD"): segment ``i`` is the set of voxels equal to its
      label value, looked up in ``data`` directly when every layer is 0, or
      in ``data[layer]`` otherwise.
    * No such header entries and ``data`` is 3-D ("binary labelmap"): the
      labels ``1..max(data)`` each become one channel.
    * Anything else is assumed to be one-hot already and is returned as is.

    Args:
        data: array with a ``shape`` attribute; 3-D, or 4-D with the layer
            axis first.
        header: mapping of NRRD header fields.
        segment_indices: ``None`` (all segments), a single int, or a list of
            ints. For the header-driven layout these index the segment list;
            for the binary labelmap they are label values.

    Returns:
        A float array of shape ``(num_segments, x, y, z)``; ``data`` itself
        for the already-one-hot case; ``None`` when ``segment_indices`` has an
        unsupported type (a message is printed).

    Raises:
        ValueError: if the header lists different numbers of layer and label
            entries, or ``data`` is neither 3-D nor 4-D in the header-driven
            layout.
    """
    print('---'*10)
    print("converting to one hot")
    layer_values = get_layer_values(header)
    label_values = get_label_values(header)

    def _as_index_list(indices, default):
        # Normalise the three accepted spellings; None signals a bad type.
        if indices is None:
            return default
        if isinstance(indices, int):
            return [indices]
        if isinstance(indices, list):
            return indices
        print("incorrectly specified segment indices")
        return None

    # Newer Slicer NRRD (compressed layers)
    if layer_values and label_values:
        if len(layer_values) != len(label_values):
            raise ValueError(
                "header lists {} layer entries but {} label values".format(
                    len(layer_values), len(label_values)))
        rank = len(data.shape)
        if rank == 3:
            spatial_shape = tuple(data.shape)
        elif rank == 4:
            spatial_shape = tuple(data.shape[1:])
        else:
            # Previously this fell through to an unbound-variable error.
            raise ValueError(
                "expected 3-D or 4-D segmentation data, got shape {}".format(
                    data.shape))
        num_segments = len(layer_values)
        one_hot = np.zeros((num_segments,) + spatial_shape)
        segment_indices = _as_index_list(
            segment_indices, list(range(num_segments)))
        if segment_indices is None:
            return
        # When every segment lives on layer 0 the array is compared as a whole.
        single_layer = max(layer_values) == 0
        for i, seg_idx in enumerate(segment_indices):
            label = label_values[seg_idx]
            source = data if single_layer else data[layer_values[seg_idx]]
            one_hot[i] = 1*(source == label).astype(np.uint8)
    # Binary labelmap
    elif len(data.shape) == 3:
        # np.max yields a float scalar for floating-point data, which
        # np.zeros and range() reject as a size; cast explicitly.
        num_segments = int(np.max(data))
        one_hot = np.zeros((num_segments,) + tuple(data.shape))
        segment_indices = _as_index_list(
            segment_indices, list(range(1, num_segments + 1)))
        if segment_indices is None:
            return
        for i, seg_idx in enumerate(segment_indices):
            one_hot[i] = 1*(data == seg_idx).astype(np.uint8)
    # Older Slicer NRRD (already one-hot)
    else:
        return data
    return one_hot
def get_layer_values(header, indices=None):
    """Return the ``Segment<i>_Layer`` header entries as a list of ints.

    The number of segments is the count of header keys containing ``"Layer"``;
    entries are then read for ``i = 0 .. count - 1``. ``indices`` is accepted
    but not used.
    """
    layer_key_count = sum(1 for name in header.keys() if "Layer" in name)
    return [int(header[f'Segment{idx}_Layer'])
            for idx in range(layer_key_count)]
def get_label_values(header, indices=None):
    """Return the ``Segment<i>_LabelValue`` header entries as a list of ints.

    The number of segments is the count of header keys containing
    ``"LabelValue"``; entries are then read for ``i = 0 .. count - 1``.
    ``indices`` is accepted but not used.
    """
    label_key_count = sum(1 for name in header.keys() if "LabelValue" in name)
    return [int(header[f'Segment{idx}_LabelValue'])
            for idx in range(label_key_count)]
def get_num_segments(header, indices=None):
    """Count the header keys that contain ``"LabelValue"``.

    ``indices`` is accepted but not used.
    """
    return sum(1 for name in header.keys() if "LabelValue" in name)
def checkCorrespondence(segmentation, base, paired_list, filename):
    """Extract the requested segments from one NRRD file and save the result.

    The file ``base/segmentation/filename`` is read, ``slicerio.extract_segments``
    is applied with ``paired_list``, and the result is written under
    ``base/MatchedSegs`` with the same file name. The segmentation info of the
    written file is printed.

    Args:
        segmentation: segmentation directory, relative to ``base``.
        base: base directory.
        paired_list: list passed to ``slicerio.extract_segments`` as the
            segment selection.
        filename: name of the NRRD file inside the segmentation directory.

    Returns:
        str: path of the written file.

    Raises:
        TypeError: if ``paired_list`` is not a list.
    """
    print(filename)
    # Explicit check instead of ``assert``, which is stripped under ``-O``.
    if not isinstance(paired_list, list):
        raise TypeError(
            "paired_list must be a list, got {}".format(
                type(paired_list).__name__))
    source_path = os.path.join(base, segmentation, filename)
    data, tempSeg = nrrd.read(source_path)
    seg_info = slicerio.read_segmentation_info(source_path)
    output_voxels, output_header = slicerio.extract_segments(
        data, tempSeg, seg_info, paired_list)
    matched_dir = os.path.join(base, 'MatchedSegs')
    # Do not rely on the caller having created the output directory.
    os.makedirs(matched_dir, exist_ok=True)
    output = os.path.join(matched_dir, filename)
    nrrd.write(output, output_voxels, output_header)
    print('---'*10)
    print('Check the label names and values')
    print(slicerio.read_segmentation_info(output))
    return output
def checkSegFormat(base, segmentation, paired_list, check=False):
    """Convert every segmentation in a directory to ``.nii.gz``.

    Files from ``base/segmentation`` are written to ``base/re-format_labels``:

    * names ending in ``nrrd`` are converted through ``nrrd2nifti`` (after
      ``checkCorrespondence`` when ``check`` is true);
    * names ending in ``nii`` are re-saved as ``<stem>.nii.gz``;
    * names ending in ``nii.gz`` are copied unchanged.

    ``<stem>`` is the part of the file name before its first dot.

    Args:
        base: base directory.
        segmentation: segmentation directory, relative to ``base``.
        paired_list: segment selection forwarded to ``checkCorrespondence``.
        check: whether to run ``checkCorrespondence`` on NRRD inputs first.

    Returns:
        str: the output directory ``base/re-format_labels``.
    """
    path = os.path.join(base, segmentation)
    save_dir = os.path.join(base, 're-format_labels')
    try:
        os.mkdir(save_dir)
    except FileExistsError:
        # Only this case is expected; other OS errors (missing parent,
        # permissions) used to be reported as "already exists" as well.
        print(f'{save_dir} already exists')
    for entry in os.listdir(path):
        name = entry.split('.')[0]
        source = os.path.join(path, entry)
        # 'nrrd' also matches names ending in 'seg.nrrd'.
        if entry.endswith('nrrd'):
            if check:
                source = checkCorrespondence(
                    segmentation, base, paired_list, entry)
            ants_img = ants.image_read(source)
            header = nrrd.read_header(source)
            filename = os.path.join(save_dir, name + '.nii.gz')
            nrrd2nifti(ants_img, header, filename, True)
        elif entry.endswith('nii'):
            image = ants.image_read(source)
            image.to_file(os.path.join(save_dir, name + '.nii.gz'))
        elif entry.endswith('nii.gz'):
            shutil.copy(source, save_dir)
    return save_dir
def nrrd2nifti(img, header, filename, segmentations=True):
    """Write an ANTs image to ``filename``, flattening segmentations to a labelmap.

    Args:
        img: ANTs image to save.
        header: NRRD header mapping, forwarded to ``convert_to_one_hot``.
        filename: output file path.
        segmentations: when true, the image is treated as a segmentation and
            collapsed to a single ``uint8`` label volume before saving;
            otherwise the image is written unchanged.
    """
    voxels = img.view(single_components=segmentations)

    if not segmentations:
        print('-- Saving NII Volume')
        img.to_file(filename)
        return

    one_hot = convert_to_one_hot(voxels, header)
    # Per-voxel maximum over the segment axis.
    any_segment = np.max(one_hot, axis=0)
    # Index of the strongest channel, shifted so labels start at 1.
    winning_label = np.argmax(one_hot, axis=0) + 1
    # Multiplying by the maximum zeroes voxels where no channel is set.
    labelmap = np.multiply(winning_label, any_segment).astype('uint8')
    segmentation_img = ants.from_numpy(
        labelmap,
        origin=img.origin,
        spacing=img.spacing,
        direction=img.direction)
    print('-- Saving NII Segmentations')
    segmentation_img.to_file(filename)
def find_template(base, image_path, fomat):
    """Return the file stem of the alphabetically first scan in a directory.

    Scans are the entries of ``base/image_path`` whose names end in ``fomat``;
    the stem is the part of the file name before its first dot.
    """
    pattern = os.path.join(base, image_path) + '/*' + fomat
    candidates = glob.glob(pattern)
    candidates.sort()
    first_name = os.path.basename(candidates[0])
    stem, _, _ = first_name.partition('.')
    return stem
def find_template_V2(base, image_path, fomat):
    """Return the file stem of the scan with the largest third dimension.

    Scans are the entries of ``base/image_path`` whose names end in ``fomat``.
    They are visited in sorted order and only a strictly larger size replaces
    the current choice, so ties go to the alphabetically first scan.

    Args:
        base: base directory.
        image_path: scan directory, relative to ``base``.
        fomat: file-name suffix to match.

    Returns:
        str: file name of the chosen scan up to its first dot.

    Raises:
        FileNotFoundError: if no file matches (previously this surfaced as an
            unbound-variable error).
    """
    pattern = os.path.join(base, image_path) + '/*' + fomat
    template = None
    maxD = -np.inf
    # Sorted so the result does not depend on directory listing order.
    for scan_path in sorted(glob.glob(pattern)):
        scan_id = os.path.basename(scan_path).split('.')[0]
        thirdD = ants.image_read(scan_path).shape[2]
        if thirdD > maxD:
            template = scan_id
            maxD = thirdD
    if template is None:
        raise FileNotFoundError(f'no scan matching {pattern}')
    return template
def path_to_id(path, fomat):
    """List the file stems of all entries of ``path`` whose names end in ``fomat``.

    The stem is the part of the file name before its first dot. Entries are
    returned in the order ``glob`` yields them.
    """
    matches = glob.glob(path + '/*' + fomat)
    return [os.path.basename(match).split('.')[0] for match in matches]
def checkFormat(base, images_path):
    """Detect the image file format used in a directory.

    The first entry of ``base/images_path`` (in ``os.listdir`` order) whose
    name ends in a known suffix decides the result.

    Args:
        base: base directory.
        images_path: directory to inspect, relative to ``base``.

    Returns:
        str: ``'nii'``, ``'nii.gz'``, ``'nrrd'`` or ``'seg.nrrd'``.

    Raises:
        FileNotFoundError: if no entry has a known suffix (previously this
            surfaced as an unbound-variable error).
    """
    path = os.path.join(base, images_path)
    # '.seg.nrrd' must be tested before '.nrrd': every name ending in
    # '.seg.nrrd' also ends in '.nrrd', which made that branch unreachable.
    known_suffixes = ('.nii.gz', '.seg.nrrd', '.nii', '.nrrd')
    for file in os.listdir(path):
        for suffix in known_suffixes:
            if file.endswith(suffix):
                return suffix[1:]
    raise FileNotFoundError(
        f'no .nii, .nii.gz, .nrrd or .seg.nrrd file found in {path}')
def main():
    """Run the registration step of the pipeline from the command line.

    Steps:

    1. Parse the command line and derive the output directories under
       ``<cwd>/../../deepatlas_raw_data_base/<ti>/customize_test_data/<op>``.
    2. Detect the file formats of the template, target scans and target
       segmentations, and pick the template scan.
    3. Create the output directories.
    4. Convert the segmentations to ``.nii.gz`` (restricting them to the
       ``-sl`` label/name pairs when given).
    5. Register every target scan to the template; targets whose id also
       appears among the segmentations get their labels propagated too.

    Returns ``None``; prints a message and stops early when the ``-sl`` list
    is malformed.
    """
    ROOT_DIR = str(Path(os.getcwd()).parent.parent.absolute())
    args = parse_command_line()
    base = args.bp
    template_path = args.template
    target_seg = args.target_seg
    target_scan = args.target_scan
    label_list = args.sl
    task_path = os.path.join(ROOT_DIR, 'deepatlas_raw_data_base', args.ti)
    output_data_path = os.path.join(task_path, 'customize_test_data')
    out_data_path = os.path.join(output_data_path, args.op)
    images_output = os.path.join(out_data_path, 'images')
    labels_output = os.path.join(out_data_path, 'labels')

    template_fomat = checkFormat(base, template_path)
    target_fomat = checkFormat(base, target_scan)
    fomat_seg = checkFormat(base, target_seg)
    # Sorted pick via find_template, so the template choice no longer depends
    # on the order glob happens to return files in.
    template = find_template(base, template_path, template_fomat)
    # A set makes the per-scan membership test below O(1).
    labelled_ids = set(path_to_id(os.path.join(base, target_seg), fomat_seg))

    if label_list is not None:
        _ensure_dir(os.path.join(base, 'MatchedSegs'))
    for directory in (output_data_path, out_data_path,
                      images_output, labels_output):
        _ensure_dir(directory)

    paired_list = []
    if label_list is not None:
        paired_list = _pair_labels(label_list)
        if paired_list is None:
            return
    seg_output_path = checkSegFormat(
        base, target_seg, paired_list, check=label_list is not None)

    scan_pattern = os.path.join(base, target_scan) + '/*' + target_fomat
    for scan_path in sorted(glob.glob(scan_pattern)):
        target = os.path.basename(scan_path).split('.')[0]
        split_and_registration(
            template, target, base, template_path, target_scan,
            seg_output_path, images_output, labels_output,
            template_fomat, target_fomat,
            has_label=target in labelled_ids)


def _ensure_dir(path):
    """Create ``path`` including missing parents, reporting when it already exists."""
    try:
        os.makedirs(path)
    except FileExistsError:
        print(f"{path} already exists")


def _pair_labels(label_list):
    """Turn ``[value, name, value, name, ...]`` into ``[(name, value), ...]``.

    Values must be digit strings and names must not be. Prints a message and
    returns ``None`` when the list is malformed, including an odd number of
    items (which previously raised ``IndexError``).
    """
    if len(label_list) % 2:
        print(
            "Wrong input argument for pair-wising label value and its name !!!")
        return None
    pairs = []
    for value, key in zip(label_list[::2], label_list[1::2]):
        if not value.isdigit():
            print(
                "Wrong order of input argument for pair-wising label value and its name !!!")
            return None
        if key.isdigit():
            print(
                "Wrong input argument for pair-wising label value and its name !!!")
            return None
        pairs.append((key, value))
    return pairs
# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()