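"""Inference evaluation pipeline for image registration-segmentation.

Compares predicted segmentations against ground-truth segmentations and
reports per-label Dice scores and/or average Hausdorff distances, writing a
per-case CSV ('output_<name>.csv') and a summary table ('metric.txt').
"""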
import numpy as np
import nibabel as nib
import ants
import argparse
import pandas as pd
import glob
import os
import surface_distance
import nrrd
import shutil
import distanceVertex2Mesh
import textwrap
def parse_command_line():
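    """Parse the command-line arguments of the evaluation pipeline and return the argparse namespace."""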
print('---'*10)
print('Parsing Command Line Arguments')
parser = argparse.ArgumentParser(
description='Inference evaluation pipeline for image registration-segmentation', formatter_class=argparse.RawTextHelpFormatter)
parser.add_argument('-bp', metavar='base path', type=str,
help="Absolute path of the base directory")
parser.add_argument('-gp', metavar='ground truth path', type=str,
help="Relative path of the ground truth segmentation directory")
parser.add_argument('-pp', metavar='predicted path', type=str,
help="Relative path of predicted segmentation directory")
    parser.add_argument('-sp', metavar='save path', type=str,
                        help="Relative path of the directory to save the CSV file; if not specified, defaults to the base directory")
parser.add_argument('-vt', metavar='validation type', type=str, nargs='+',
help=textwrap.dedent('''Validation type:
dsc: Dice Score
ahd: Average Hausdorff Distance
whd: Weighted Hausdorff Distance
'''))
parser.add_argument('-pm', metavar='probability map path', type=str,
help="Relative path of text file directory of probability map")
parser.add_argument('-fn', metavar='file name', type=str,
help="name of output file")
    parser.add_argument('-reg', action='store_true',
                        help="set this flag if the input files are registration predictions")
parser.add_argument('-tp', metavar='type of segmentation', type=str,
help=textwrap.dedent('''Segmentation type:
ET: Eustachian Tube
NC: Nasal Cavity
HT: Head Tumor
'''))
    parser.add_argument('-sl', metavar='segmentation information list', type=str, nargs='+',
                        help='a list of label names and their corresponding values')
parser.add_argument('-cp', metavar='current prefix of filenames', type=str,
help='current prefix of filenames')
argv = parser.parse_args()
return argv
def rename(prefix, filename):
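    """Return '<prefix>_<id>' where <id> is the last three characters of the file stem."""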
name = filename.split('.')[0][-3:]
name = prefix + '_' + name
return name
def dice_coefficient_and_hausdorff_distance(filename, img_np_pred, img_np_gt, num_classes, spacing, probability_map, dsc, ahd, whd, average_DSC, average_HD):
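    """Evaluate one case: per-label Dice score and/or average Hausdorff distance.

    Returns a DataFrame with one row per label (label values 1..num_classes-1)
    together with the updated running sums average_DSC and average_HD, which
    main() later divides by the number of cases.
    """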
df = pd.DataFrame()
data_gt, bool_gt = make_one_hot(img_np_gt, num_classes)
data_pred, bool_pred = make_one_hot(img_np_pred, num_classes)
    for i in range(1, num_classes):
        df1 = pd.DataFrame([[filename, i]], columns=[
                           'File ID', 'Label Value'])
        if dsc:
            if data_pred[i].any():
                # the prediction has foreground voxels, so volume_sum > 0
                volume_sum = data_gt[i].sum() + data_pred[i].sum()
                volume_intersect = (data_gt[i] & data_pred[i]).sum()
                dice = 2 * volume_intersect / volume_sum
            else:
                # no predicted foreground for this label: score it as 0
                dice = 0.0
            df1['Dice Score'] = dice
            average_DSC[i-1] += dice
        if ahd:
            if data_pred[i].any():
                avd = average_hausdorff_distance(bool_gt[i], bool_pred[i], spacing)
                average_HD[i-1] += avd
            else:
                # the distance is undefined for an empty prediction; record NaN
                # in the table but keep it out of the running sum so the final
                # average is not poisoned
                avd = np.nan
            df1['Average Hausdorff Distance'] = avd
        if whd:
            # weighted Hausdorff distance is not implemented yet
            # wgd = weighted_hausdorff_distance(gt, pred, probability_map)
            # df1['Weighted Hausdorff Distance'] = wgd
            pass
        df = pd.concat([df, df1])
    return df, average_DSC, average_HD
def make_one_hot(img_np, num_classes):
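    """Split a labelmap volume into per-class binary volumes (int8 for Dice, bool for the surface-distance code)."""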
img_one_hot_dice = np.zeros(
(num_classes, img_np.shape[0], img_np.shape[1], img_np.shape[2]), dtype=np.int8)
img_one_hot_hd = np.zeros(
(num_classes, img_np.shape[0], img_np.shape[1], img_np.shape[2]), dtype=bool)
for i in range(num_classes):
a = (img_np == i)
img_one_hot_dice[i, :, :, :] = a
img_one_hot_hd[i, :, :, :] = a
return img_one_hot_dice, img_one_hot_hd
def average_hausdorff_distance(img_np_gt, img_np_pred, spacing):
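    """Mean of the two directed average surface distances (gt-to-pred and pred-to-gt) under the given voxel spacing."""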
surf_distance = surface_distance.compute_surface_distances(
img_np_gt, img_np_pred, spacing)
gp, pg = surface_distance.compute_average_surface_distance(surf_distance)
return (gp + pg) / 2
def checkSegFormat(base, segmentation, type, prefix=None):
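    """Normalize every segmentation in `segmentation` to a .nii.gz labelmap.

    Slicer .seg.nrrd files are collapsed to labelmaps, .nii files are
    converted, and .nii.gz files are copied. `type` ('gt' or anything else)
    selects the output directory, whose path is returned.
    """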
if type == 'gt':
save_dir = os.path.join(base, 'gt_reformat_labels')
path = segmentation
else:
save_dir = os.path.join(base, 'pred_reformat_labels')
path = os.path.join(base, segmentation)
    try:
        os.mkdir(save_dir)
    except FileExistsError:
        print(f'{save_dir} already exists')
for file in os.listdir(path):
if type == 'gt':
if prefix is not None:
name = rename(prefix, file)
else:
name = file.split('.')[0]
else:
name = file.split('.')[0]
if file.endswith('seg.nrrd'):
ants_img = ants.image_read(os.path.join(path, file))
header = nrrd.read_header(os.path.join(path, file))
filename = os.path.join(save_dir, name + '.nii.gz')
nrrd2nifti(ants_img, header, filename)
elif file.endswith('nii'):
image = ants.image_read(os.path.join(path, file))
image.to_file(os.path.join(save_dir, name + '.nii.gz'))
elif file.endswith('nii.gz'):
shutil.copy(os.path.join(path, file), os.path.join(save_dir, name + '.nii.gz'))
return save_dir
def nrrd2nifti(img, header, filename):
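    """Collapse a Slicer segmentation image into a single labelmap (argmax over one-hot layers) and save it as NIfTI."""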
img_as_np = img.view(single_components=True)
data = convert_to_one_hot(img_as_np, header)
foreground = np.max(data, axis=0)
labelmap = np.multiply(np.argmax(data, axis=0) + 1,
foreground).astype('uint8')
segmentation_img = ants.from_numpy(
labelmap, origin=img.origin, spacing=img.spacing, direction=img.direction)
print('-- Saving NII Segmentations')
segmentation_img.to_file(filename)
def convert_to_one_hot(data, header, segment_indices=None):
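    """Build a one-hot array from a Slicer NRRD segmentation using its header.

    Handles newer Slicer NRRDs (per-segment layer/label values), plain binary
    labelmaps, and older NRRDs that are already stored one-hot.
    """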
print('---'*10)
print("converting to one hot")
layer_values = get_layer_values(header)
label_values = get_label_values(header)
# Newer Slicer NRRD (compressed layers)
if layer_values and label_values:
assert len(layer_values) == len(label_values)
if len(data.shape) == 3:
x_dim, y_dim, z_dim = data.shape
elif len(data.shape) == 4:
x_dim, y_dim, z_dim = data.shape[1:]
num_segments = len(layer_values)
one_hot = np.zeros((num_segments, x_dim, y_dim, z_dim))
if segment_indices is None:
segment_indices = list(range(num_segments))
elif isinstance(segment_indices, int):
segment_indices = [segment_indices]
elif not isinstance(segment_indices, list):
print("incorrectly specified segment indices")
return
# Check if NRRD is composed of one layer 0
if np.max(layer_values) == 0:
for i, seg_idx in enumerate(segment_indices):
layer = layer_values[seg_idx]
label = label_values[seg_idx]
one_hot[i] = 1*(data == label).astype(np.uint8)
else:
for i, seg_idx in enumerate(segment_indices):
layer = layer_values[seg_idx]
label = label_values[seg_idx]
one_hot[i] = 1*(data[layer] == label).astype(np.uint8)
# Binary labelmap
elif len(data.shape) == 3:
x_dim, y_dim, z_dim = data.shape
        num_segments = int(np.max(data))
one_hot = np.zeros((num_segments, x_dim, y_dim, z_dim))
if segment_indices is None:
segment_indices = list(range(1, num_segments + 1))
elif isinstance(segment_indices, int):
segment_indices = [segment_indices]
elif not isinstance(segment_indices, list):
print("incorrectly specified segment indices")
return
for i, seg_idx in enumerate(segment_indices):
one_hot[i] = 1*(data == seg_idx).astype(np.uint8)
# Older Slicer NRRD (already one-hot)
else:
return data
return one_hot
def get_layer_values(header):
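    """Collect the Segment<i>_Layer entries from a Slicer NRRD header."""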
layer_values = []
num_segments = len([key for key in header.keys() if "Layer" in key])
for i in range(num_segments):
layer_values.append(int(header['Segment{}_Layer'.format(i)]))
return layer_values
def get_label_values(header):
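    """Collect the Segment<i>_LabelValue entries from a Slicer NRRD header."""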
label_values = []
num_segments = len([key for key in header.keys() if "LabelValue" in key])
for i in range(num_segments):
label_values.append(int(header['Segment{}_LabelValue'.format(i)]))
return label_values
def main():
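    """Run the pipeline: reformat the segmentations, evaluate every prediction, and write the CSV and metric.txt outputs."""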
args = parse_command_line()
base = args.bp
gt_path = args.gp
pred_path = args.pp
if args.sp is None:
save_path = base
else:
save_path = args.sp
validation_type = args.vt
probability_map_path = args.pm
filename = args.fn
reg = args.reg
seg_type = args.tp
label_list = args.sl
current_prefix = args.cp
if probability_map_path is not None:
probability_map = np.loadtxt(os.path.join(base, probability_map_path))
else:
probability_map = None
    dsc = False
    ahd = False
    whd = False
    if not validation_type:
        print('no validation type specified, please choose at least one of dsc / ahd / whd !!!')
        return
    for v in validation_type:
        if v == 'dsc':
            dsc = True
        elif v == 'ahd':
            ahd = True
        elif v == 'whd':
            whd = True
        else:
            print('wrong validation type, please choose correct one !!!')
            return
filepath = os.path.join(base, save_path, 'output_' + filename + '.csv')
save_dir = os.path.join(base, save_path)
gt_output_path = checkSegFormat(base, gt_path, 'gt', current_prefix)
pred_output_path = checkSegFormat(base, pred_path, 'pred', current_prefix)
    try:
        os.mkdir(save_dir)
    except FileExistsError:
        print(f'{save_dir} already exists')
    try:
        # os.mknod is not available on every platform; create the file portably
        with open(filepath, 'x'):
            pass
    except FileExistsError:
        print(f'{filepath} already exists')
DSC = pd.DataFrame()
    file = glob.glob(os.path.join(gt_output_path, '*.nii.gz'))[0]
seg_file = ants.image_read(file)
num_class = np.unique(seg_file.numpy().ravel()).shape[0]
average_DSC = np.zeros((num_class-1))
average_HD = np.zeros((num_class-1))
k = 0
    for i in glob.glob(os.path.join(pred_output_path, '*.nii.gz')):
k += 1
pred_img = ants.image_read(i)
pred_spacing = list(pred_img.spacing)
        # registration outputs carry the case ID inside a longer, underscore-
        # separated filename; recover it per segmentation type
        stem = os.path.basename(i).split('.')[0]
        parts = stem.split('_')
        if reg and seg_type == 'ET':
            file_name = '_'.join(parts[4:7])
        elif reg and seg_type == 'NC':
            file_name = '_'.join(parts[3:5])
        elif reg and seg_type == 'HT':
            file_name = parts[2]
        else:
            file_name = stem
        file_name1 = stem
gt_seg = os.path.join(base, gt_output_path, file_name + '.nii.gz')
gt_img = ants.image_read(gt_seg)
gt_spacing = list(gt_img.spacing)
        if gt_spacing != pred_spacing:
            print(
                "Spacings of the prediction and ground truth do not match, please check again !!!")
            return
        pred_np = pred_img.numpy()
        gt_np = gt_img.numpy()
        # the number of classes is taken from the ground truth of this case;
        # the running-average arrays assume every case has the same label set
        num_class = len(np.unique(gt_np))
        ds, aver_DSC, aver_HD = dice_coefficient_and_hausdorff_distance(
            file_name1, pred_np, gt_np, num_class, pred_spacing, probability_map, dsc, ahd, whd, average_DSC, average_HD)
DSC = pd.concat([DSC, ds])
average_DSC = aver_DSC
average_HD = aver_HD
    if k == 0:
        print('no prediction files found, nothing to evaluate !!!')
        return
    avg_DSC = average_DSC / k
    avg_HD = average_HD / k
    print(avg_DSC)
with open(os.path.join(base, save_path, "metric.txt"), 'w') as f:
f.write("Label Value Label Name Average Dice Score Average Mean HD\n")
for i in range(len(avg_DSC)):
f.write(f'{str(i+1):^12}{str(label_list[2*i+1]):^12}{str(avg_DSC[i]):^20}{str(avg_HD[i]):^18}\n')
DSC.to_csv(filepath)
if __name__ == '__main__':
main()
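
# Example invocation (a sketch; the script name, paths, and label list below are
# illustrative placeholders, not taken from the repository):
#   python evaluate.py -bp /abs/path/to/experiment -gp gt_segs -pp pred_segs \
#       -sp results -vt dsc ahd -fn trial1 -tp ET -sl 1 EustachianTube
# The -sl list is presumably given as value/name pairs; only the odd-indexed
# entries (the label names) are read when metric.txt is written.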