# DDMR/Brain_study/Evaluate_network__test_fixed.py
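"""
Evaluate trained deformable-registration models on a held-out brain test set.

For each checkpoint, the script registers every moving/fixed image pair, warps the
moving segmentation, and appends SSIM, MS-SSIM, NCC, MSE, DICE (volume-scaled and
macro), Hausdorff distances, runtime and TRE to <model_dir>/<outdirname>/metrics.csv,
optionally saving images and segmentations as NIfTI files.

Example invocation (paths are placeholders, not those used in the study):
    python Evaluate_network__test_fixed.py \
        -d /path/to/train_output/MODEL_DIR --gpu 0 \
        --dataset /path/to/IXI_dataset/T1/test
"""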
import os
import sys
import shutil
import time
import h5py
import matplotlib.pyplot as plt
currentdir = os.path.dirname(os.path.realpath(__file__))
parentdir = os.path.dirname(currentdir)
sys.path.append(parentdir)  # Make the parent directory importable when this file is run as a script
import tensorflow as tf
if tf.executing_eagerly():
    tf.compat.v1.disable_eager_execution()  # The metric graph below uses TF1-style placeholders and sessions
import numpy as np
import pandas as pd
import voxelmorph as vxm
import ddmr.utils.constants as C
from ddmr.utils.operators import min_max_norm, safe_medpy_metric
from ddmr.utils.nifti_utils import save_nifti
from ddmr.layers import AugmentationLayer, UncertaintyWeighting
from ddmr.losses import StructuralSimilarity_simplified, NCC, GeneralizedDICEScore, HausdorffDistanceErosion, target_registration_error
from ddmr.ms_ssim_tf import MultiScaleStructuralSimilarity
from ddmr.utils.acummulated_optimizer import AdamAccumulated
from ddmr.utils.visualization import save_disp_map_img, plot_predictions
from ddmr.utils.misc import resize_displacement_map, scale_transformation
from ddmr.utils.misc import DisplacementMapInterpolator, get_segmentations_centroids, segmentation_ohe_to_cardinal
from EvaluationScripts.Evaluate_class import EvaluationFigures, resize_pts_to_original_space, resize_img_to_original_space, resize_transformation
from scipy.ndimage import zoom
from tqdm import tqdm
import medpy.metric as medpy_metrics
import re
from Brain_study.data_generator import BatchGenerator
from voxelmorph.tf.layers import SpatialTransformer
import argparse
from skimage.transform import warp
import neurite as ne
# Default paths; the __main__ block below overrides these from the CLI arguments.
DATASET = '/mnt/EncryptedData1/Users/javier/ext_datasets/IXI_dataset/T1/training'
MODEL_FILE = '/mnt/EncryptedData1/Users/javier/train_output/Brain_study/ERASE/MS_SSIM/BASELINE_L_ssim__MET_mse_ncc_ssim_162756-29062021/checkpoints/best_model.h5'
DATA_ROOT_DIR = '/mnt/EncryptedData1/Users/javier/train_output/Brain_study/ERASE/MS_SSIM/BASELINE_L_ssim__MET_mse_ncc_ssim_162756-29062021/'
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-m', '--model', nargs='+', type=str, help='.h5 of the model', default=None)
parser.add_argument('-d', '--dir', nargs='+', type=str, help='Directory where ./checkpoints/best_model.h5 is located', default=None)
parser.add_argument('--gpu', type=int, help='GPU', default=0)
parser.add_argument('--dataset', type=str, help='Dataset to run predictions on',
default='/mnt/EncryptedData1/Users/javier/ext_datasets/IXI_dataset/T1/test')
    # argparse's type=bool treats any non-empty string as True, so boolean flags use store_true/store_false
    parser.add_argument('--erase', action='store_true', help='Erase the content of the output folder before writing')
    parser.add_argument('--outdirname', type=str, default='Evaluate')
    parser.add_argument('--no-savenifti', dest='savenifti', action='store_false', help='Do not save NIfTI outputs (saved by default)')
args = parser.parse_args()
if args.model is not None:
        assert '.h5' in args.model[0], 'Expected a .h5 checkpoint file; use -d/--dir for directories'
MODEL_FILE_LIST = args.model
DATA_ROOT_DIR_LIST = [os.path.split(model_path)[0] for model_path in args.model]
elif args.dir is not None:
        assert '.h5' not in args.dir[0], 'Checkpoint file provided, use -m/--model instead'
MODEL_FILE_LIST = [os.path.join(dir_path, 'checkpoints', 'best_model.h5') for dir_path in args.dir]
DATA_ROOT_DIR_LIST = args.dir
else:
raise ValueError("Provide either the model file or the directory ./containing checkpoints/best_model.h5")
os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu) # Check availability before running using 'nvidia-smi'
DATASET = args.dataset
list_test_files = [os.path.join(DATASET, f) for f in os.listdir(DATASET) if f.endswith('h5') and 'dm' not in f]
list_test_files.sort()
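    # Expected layout of each test .h5 file, inferred from the reads in this script
    # (illustrative, not an authoritative spec); segmentations are one-hot with
    # channel 0 as background:
    #   fix_image / mov_image:                 (H, W, D, 1)
    #   fix_segmentations / mov_segmentations: (H, W, D, nb_labels + 1)
    #   fix_centroids:                         (nb_labels + 1, 3)
    #   isotropic_shape:                       (3,)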
with h5py.File(list_test_files[0], 'r') as f:
image_input_shape = image_output_shape = list(f['fix_image'][:].shape[:-1])
nb_labels = f['fix_segmentations'][:].shape[-1] - 1 # Skip background label
# Header of the metrics csv file
csv_header = ['File', 'SSIM', 'MS-SSIM', 'NCC', 'MSE', 'DICE', 'DICE_MACRO', 'HD', 'HD95', 'Time', 'TRE', 'No_missing_lbls', 'Missing_lbls']
# TF stuff
config = tf.compat.v1.ConfigProto() # device_count={'GPU':0})
config.gpu_options.allow_growth = True
    config.log_device_placement = False  # Log device placement (which device each operation ran on)
config.allow_soft_placement = True
sess = tf.compat.v1.Session(config=config)
tf.compat.v1.keras.backend.set_session(sess)
# Loss and metric functions. Common to all models
loss_fncs = [StructuralSimilarity_simplified(patch_size=2, dim=3, dynamic_range=1.).loss,
NCC(image_input_shape).loss,
vxm.losses.MSE().loss,
MultiScaleStructuralSimilarity(max_val=1., filter_size=3).loss]
metric_fncs = [StructuralSimilarity_simplified(patch_size=2, dim=3, dynamic_range=1.).metric,
NCC(image_input_shape).metric,
vxm.losses.MSE().loss,
MultiScaleStructuralSimilarity(max_val=1., filter_size=3).metric,
GeneralizedDICEScore(image_output_shape + [nb_labels], num_labels=nb_labels).metric,
HausdorffDistanceErosion(3, 10, im_shape=image_output_shape + [nb_labels]).metric,
GeneralizedDICEScore(image_output_shape + [nb_labels], num_labels=nb_labels).metric_macro]
### METRICS GRAPH ###
fix_img_ph = tf.compat.v1.placeholder(tf.float32, (1, None, None, None, 1), name='fix_img')
pred_img_ph = tf.compat.v1.placeholder(tf.float32, (1, None, None, None, 1), name='pred_img')
fix_seg_ph = tf.compat.v1.placeholder(tf.float32, (1, None, None, None, nb_labels), name='fix_seg')
pred_seg_ph = tf.compat.v1.placeholder(tf.float32, (1, None, None, None, nb_labels), name='pred_seg')
ssim_tf = metric_fncs[0](fix_img_ph, pred_img_ph)
ncc_tf = metric_fncs[1](fix_img_ph, pred_img_ph)
mse_tf = metric_fncs[2](fix_img_ph, pred_img_ph)
ms_ssim_tf = metric_fncs[3](fix_img_ph, pred_img_ph)
dice_tf = metric_fncs[4](fix_seg_ph, pred_seg_ph)
hd_tf = metric_fncs[5](fix_seg_ph, pred_seg_ph)
dice_macro_tf = metric_fncs[6](fix_seg_ph, pred_seg_ph)
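    # The metric tensors above are evaluated inside the loop below via
    # sess.run(..., feed_dict={'fix_img:0': ..., 'pred_img:0': ...}), feeding the
    # placeholders by name rather than rebuilding the graph per test file.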
# hd_exact_tf = HausdorffDistance_exact(fix_seg_ph, pred_seg_ph, ohe=True)
# Needed for VxmDense type of network
warp_segmentation = vxm.networks.Transform(image_output_shape, interp_method='nearest', nb_feats=nb_labels)
dm_interp = DisplacementMapInterpolator(image_output_shape, 'griddata', step=4)
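    # warp_segmentation resamples the moving segmentation with the predicted
    # displacement map (nearest-neighbour, one channel per label), while dm_interp
    # evaluates a displacement map at arbitrary points; it is used below to move the
    # segmentation centroids for the TRE computation.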
for MODEL_FILE, DATA_ROOT_DIR in zip(MODEL_FILE_LIST, DATA_ROOT_DIR_LIST):
print('MODEL LOCATION: ', MODEL_FILE)
# data_folder = '/mnt/EncryptedData1/Users/javier/train_output/DDMR/THESIS/BASELINE_Affine_ncc___mse_ncc_160606-25022021'
output_folder = os.path.join(DATA_ROOT_DIR, args.outdirname) # '/mnt/EncryptedData1/Users/javier/train_output/DDMR/THESIS/eval/BASELINE_TRAIN_Affine_ncc_EVAL_Affine'
# os.makedirs(os.path.join(output_folder, 'images'), exist_ok=True)
if args.erase:
shutil.rmtree(output_folder, ignore_errors=True)
os.makedirs(output_folder, exist_ok=True)
print('DESTINATION FOLDER: ', output_folder)
try:
network = tf.keras.models.load_model(MODEL_FILE, {#'VxmDenseSemiSupervisedSeg': vxm.networks.VxmDenseSemiSupervisedSeg,
'VxmDense': vxm.networks.VxmDense,
'AdamAccumulated': AdamAccumulated,
'loss': loss_fncs,
'metric': metric_fncs},
compile=False)
except ValueError as e:
enc_features = [32, 64, 128, 256, 512, 1024] # const.ENCODER_FILTERS
dec_features = enc_features[::-1] + [16, 16] # const.ENCODER_FILTERS[::-1]
nb_features = [enc_features, dec_features]
            if False:  # Disabled; originally: re.search('^UW|SEGGUIDED_', MODEL_FILE)
network = vxm.networks.VxmDenseSemiSupervisedSeg(inshape=image_output_shape,
nb_labels=nb_labels,
nb_unet_features=nb_features,
int_steps=0,
int_downsize=1,
seg_downsize=1)
else:
                # Only load the weights into an equivalent architecture so the runtime matches the original model
network = vxm.networks.VxmDense(inshape=image_output_shape,
nb_unet_features=nb_features,
int_steps=0)
network.load_weights(MODEL_FILE, by_name=True)
# Record metrics
metrics_file = os.path.join(output_folder, 'metrics.csv')
with open(metrics_file, 'w') as f:
f.write(';'.join(csv_header)+'\n')
ssim = ncc = mse = ms_ssim = dice = hd = 0
with sess.as_default():
            sess.run(tf.compat.v1.global_variables_initializer())
network.load_weights(MODEL_FILE, by_name=True)
network.summary(line_length=C.SUMMARY_LINE_LENGTH)
progress_bar = tqdm(enumerate(list_test_files, 1), desc='Evaluation', total=len(list_test_files))
for step, in_batch in progress_bar:
with h5py.File(in_batch, 'r') as f:
fix_img = f['fix_image'][:][np.newaxis, ...] # Add batch axis
mov_img = f['mov_image'][:][np.newaxis, ...]
fix_seg = f['fix_segmentations'][..., 1:][np.newaxis, ...].astype(np.float32)
mov_seg = f['mov_segmentations'][..., 1:][np.newaxis, ...].astype(np.float32)
fix_centroids = f['fix_centroids'][1:, ...]
isotropic_shape = f['isotropic_shape'][:]
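                    # mm per voxel of the network grid, assuming the isotropic grid
                    # has 1 mm spacing; used below to express the centroid TRE in mm.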
voxel_size = np.divide(fix_img.shape[1:-1], isotropic_shape)
if network.name == 'vxm_dense_semi_supervised_seg':
t0 = time.time()
pred_img, disp_map, pred_seg = network.predict([mov_img, fix_img, mov_seg, fix_seg]) # predict([source, target])
t1 = time.time()
else:
t0 = time.time()
pred_img, disp_map = network.predict([mov_img, fix_img])
pred_seg = warp_segmentation.predict([mov_seg, disp_map])
t1 = time.time()
pred_img = min_max_norm(pred_img)
mov_centroids, missing_lbls = get_segmentations_centroids(mov_seg[0, ...], ohe=True, expected_lbls=range(1, nb_labels + 1))
# pred_centroids = dm_interp(disp_map, mov_centroids, backwards=True) # with tps, it returns the pred_centroids directly
pred_centroids = dm_interp(disp_map, mov_centroids, backwards=True) + mov_centroids
# Up sample the segmentation masks to isotropic resolution
zoom_factors = np.diag(scale_transformation(image_output_shape, isotropic_shape))
pred_seg_isot = zoom(pred_seg[0, ...], zoom_factors, order=0)[np.newaxis, ...]
fix_seg_isot = zoom(fix_seg[0, ...], zoom_factors, order=0)[np.newaxis, ...]
pred_img_isot = zoom(pred_img[0, ...], zoom_factors, order=3)[np.newaxis, ...]
fix_img_isot = zoom(fix_img[0, ...], zoom_factors, order=3)[np.newaxis, ...]
# I need the labels to be OHE to compute the segmentation metrics.
# dice, hd, dice_macro = sess.run([dice_tf, hd_tf, dice_macro_tf], {'fix_seg:0': fix_seg, 'pred_seg:0': pred_seg})
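                # 'DICE' scales each label's DC by the inverse of the fixed-label volume
                # before averaging; 'DICE_MACRO' is the plain unweighted mean over labels.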
dice = np.mean([medpy_metrics.dc(pred_seg_isot[0, ..., l], fix_seg_isot[0, ..., l]) / np.sum(fix_seg_isot[0, ..., l]) for l in range(nb_labels)])
hd = np.mean(safe_medpy_metric(pred_seg_isot[0, ...], fix_seg_isot[0, ...], nb_labels, medpy_metrics.hd, {'voxelspacing': voxel_size}))
hd95 = np.mean(safe_medpy_metric(pred_seg_isot[0, ...], fix_seg_isot[0, ...], nb_labels, medpy_metrics.hd95, {'voxelspacing': voxel_size}))
dice_macro = np.mean([medpy_metrics.dc(pred_seg_isot[0, ..., l], fix_seg_isot[0, ..., l]) for l in range(nb_labels)])
pred_seg_card = segmentation_ohe_to_cardinal(pred_seg).astype(np.float32)
mov_seg_card = segmentation_ohe_to_cardinal(mov_seg).astype(np.float32)
fix_seg_card = segmentation_ohe_to_cardinal(fix_seg).astype(np.float32)
ssim, ncc, mse, ms_ssim = sess.run([ssim_tf, ncc_tf, mse_tf, ms_ssim_tf], {'fix_img:0': fix_img_isot, 'pred_img:0': pred_img_isot})
ssim = np.mean(ssim) # returns a list of values, which correspond to the ssim of each patch
ms_ssim = ms_ssim[0] # returns an array of shape (1,)
# Rescale the points back to isotropic space, where we have a correspondence voxel <-> mm
# upsample_scale = 128 / 64
fix_centroids_isotropic = fix_centroids * voxel_size
# mov_centroids_isotropic = mov_centroids * upsample_scale
pred_centroids_isotropic = pred_centroids * voxel_size
# fix_centroids_isotropic = np.divide(fix_centroids_isotropic, C.IXI_DATASET_iso_to_cubic_scales)
# # mov_centroids_isotropic = np.divide(mov_centroids_isotropic, C.IXI_DATASET_iso_to_cubic_scales)
# pred_centroids_isotropic = np.divide(pred_centroids_isotropic, C.IXI_DATASET_iso_to_cubic_scales)
# Now we can measure the TRE in mm
tre_array = target_registration_error(fix_centroids_isotropic, pred_centroids_isotropic, False).eval()
                tre = np.mean([v for v in tre_array if not np.isnan(v)])  # NaNs can arise from centroids of labels missing in a segmentation
                # ['File', 'SSIM', 'MS-SSIM', 'NCC', 'MSE', 'DICE', 'DICE_MACRO', 'HD', 'HD95', 'Time', 'TRE', 'No_missing_lbls', 'Missing_lbls']
new_line = [step, ssim, ms_ssim, ncc, mse, dice, dice_macro, hd, hd95, t1-t0, tre, len(missing_lbls), missing_lbls]
with open(metrics_file, 'a') as f:
f.write(';'.join(map(str, new_line))+'\n')
if args.savenifti:
save_nifti(fix_img[0, ...], os.path.join(output_folder, '{:03d}_fix_img_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
save_nifti(mov_img[0, ...], os.path.join(output_folder, '{:03d}_mov_img_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
save_nifti(pred_img[0, ...], os.path.join(output_folder, '{:03d}_pred_img_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
save_nifti(fix_seg[0, ...], os.path.join(output_folder, '{:03d}_fix_seg_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
save_nifti(mov_seg[0, ...], os.path.join(output_folder, '{:03d}_mov_seg_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
save_nifti(pred_seg[0, ...], os.path.join(output_folder, '{:03d}_pred_seg_ssim_{:.03f}_dice_{:.03f}.nii.gz'.format(step, ssim, dice)), verbose=False)
# with h5py.File(os.path.join(output_folder, '{:03d}_centroids.h5'.format(step)), 'w') as f:
# f.create_dataset('fix_centroids', dtype=np.float32, data=fix_centroids)
# f.create_dataset('mov_centroids', dtype=np.float32, data=mov_centroids)
# f.create_dataset('pred_centroids', dtype=np.float32, data=pred_centroids)
# f.create_dataset('fix_centroids_isotropic', dtype=np.float32, data=fix_centroids_isotropic)
# f.create_dataset('mov_centroids_isotropic', dtype=np.float32, data=mov_centroids_isotropic)
# magnitude = np.sqrt(np.sum(disp_map[0, ...] ** 2, axis=-1))
# _ = plt.hist(magnitude.flatten())
# plt.title('Histogram of disp. magnitudes')
# plt.show(block=False)
# plt.savefig(os.path.join(output_folder, '{:03d}_hist_mag_ssim_{:.03f}_dice_{:.03f}.png'.format(step, ssim, dice)))
# plt.close()
plot_predictions(img_batches=[fix_img, mov_img, pred_img], disp_map_batch=disp_map, seg_batches=[fix_seg_card, mov_seg_card, pred_seg_card], filename=os.path.join(output_folder, '{:03d}_figures_seg.png'.format(step)), show=False, step=16)
plot_predictions(img_batches=[fix_img, mov_img, pred_img], disp_map_batch=disp_map, filename=os.path.join(output_folder, '{:03d}_figures_img.png'.format(step)), show=False, step=16)
save_disp_map_img(disp_map, 'Displacement map', os.path.join(output_folder, '{:03d}_disp_map_fig.png'.format(step)), show=False, step=16)
progress_bar.set_description('SSIM {:.04f}\tDICE: {:.04f}'.format(ssim, dice))
print('Summary\n=======\n')
metrics_df = pd.read_csv(metrics_file, sep=';', header=0)
print('\nAVG:\n')
print(metrics_df.mean(axis=0))
print('\nSTD:\n')
print(metrics_df.std(axis=0))
        print('\nHD (95th percentile):\n')
print(metrics_df['HD'].describe(percentiles=[.95]))
print('\n=======\n')
tf.keras.backend.clear_session()
# sess.close()
del network
print('Done')