import monai
import torch
import itk
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
import os
import nibabel as nib
import sys
import json
from pathlib import Path

# Batch inference creates one figure per registered pair; silence matplotlib's
# "too many open figures" warning (visualize() also closes its figure by default).
mpl.rc('figure', max_open_warning=0)

# ROOT_DIR is two directories above the current working directory, so the
# project imports below depend on where the process is launched from.
ROOT_DIR = str(Path(os.getcwd()).parent.parent.absolute())
sys.path.insert(0, os.path.join(ROOT_DIR, 'deepatlas/utils'))
sys.path.insert(0, os.path.join(ROOT_DIR, 'deepatlas/loss_function'))
sys.path.insert(0, os.path.join(ROOT_DIR, 'deepatlas/preprocess'))

from process_data import (  # noqa: E402
    take_data_pairs, subdivide_list_of_data_pairs
)
from utils import (  # noqa: E402
    plot_2D_vector_field, jacobian_determinant, plot_2D_deformation, load_json
)
from losses import (  # noqa: E402
    warp_func, warp_nearest_func, lncc_loss_func, dice_loss_func2, dice_loss_func
)


def load_seg_dataset(data_list):
    """Build a cached MONAI dataset for segmentation inference.

    Args:
        data_list: list of dicts with an ``'img'`` path and, optionally, a
            ``'seg'`` path (missing ``'seg'`` keys are allowed).

    Returns:
        ``monai.data.CacheDataset`` (at most 16 cached items) whose items hold
        the loaded volumes with a channel axis added, resampled to
        ``pixdim=(1, 1, 1)`` (trilinear for the image, nearest for the label)
        and converted to tensors.

    Side effect: disables ITK's global warning display.
    """
    keys = ['img', 'seg']
    transform_seg_available = monai.transforms.Compose(
        transforms=[
            monai.transforms.LoadImageD(keys=keys, image_only=True, allow_missing_keys=True),
            # No transpose or orientation transform is applied here.
            monai.transforms.AddChannelD(keys=keys, allow_missing_keys=True),
            # ``mode`` is matched to ``keys`` positionally: image, then label.
            monai.transforms.SpacingD(keys=keys, pixdim=(1., 1., 1.),
                                      mode=('trilinear', 'nearest'), allow_missing_keys=True),
            monai.transforms.ToTensorD(keys=keys, allow_missing_keys=True),
        ]
    )
    itk.ProcessObject.SetGlobalWarningDisplay(False)
    return monai.data.CacheDataset(
        data=data_list,
        transform=transform_seg_available,
        cache_num=16,
        hash_as_key=True,
    )


def load_reg_dataset(data_list):
    """Build cached MONAI datasets for pairwise registration inference.

    Args:
        data_list: dict mapping a segmentation-availability key to a list of
            pair dicts with ``'img1'``/``'img2'`` paths and optional
            ``'seg1'``/``'seg2'`` paths.

    Returns:
        dict with the same keys; each value is a ``monai.data.CacheDataset``
        (at most 32 cached items). Volumes are loaded, converted to tensors,
        given a channel axis and resampled to ``pixdim=(1, 1, 1)`` (trilinear
        for images, nearest for segmentations). ``'img1'`` and ``'img2'`` are
        then concatenated along dim 0 into ``'img12'`` (img1 first) and removed.
    """
    keys = ['img1', 'seg1', 'img2', 'seg2']
    transform_pair = monai.transforms.Compose(
        transforms=[
            monai.transforms.LoadImageD(keys=keys, image_only=True, allow_missing_keys=True),
            # No transpose or orientation transform is applied here.
            monai.transforms.ToTensorD(keys=keys, allow_missing_keys=True),
            monai.transforms.AddChannelD(keys=keys, allow_missing_keys=True),
            # ``mode`` is matched to ``keys`` positionally.
            monai.transforms.SpacingD(
                keys=keys, pixdim=(1., 1., 1.),
                mode=('trilinear', 'nearest', 'trilinear', 'nearest'),
                allow_missing_keys=True),
            monai.transforms.ConcatItemsD(keys=['img1', 'img2'], name='img12', dim=0),
            monai.transforms.DeleteItemsD(keys=['img1', 'img2']),
        ]
    )
    return {
        seg_availability: monai.data.CacheDataset(
            data=pairs,
            transform=transform_pair,
            cache_num=32,
            hash_as_key=True,
        )
        for seg_availability, pairs in data_list.items()
    }


def _scan_id(path):
    """Return the base name of *path* up to its first '.', e.g. 'case01' for '/d/case01.nii.gz'."""
    return os.path.basename(path).split('.')[0]


def get_nii_info(data, reg=False):
    """Read NIfTI headers, affines and scan IDs for every item in *data*.

    Args:
        data: sequence of dicts mapping keys to file paths.
        reg: if False, items are single scans (``'img'`` plus optional
            ``'seg'``) and the segmentation's header/affine/ID is used when
            present, otherwise the image's. If True, items are registration
            pairs (``'img1'``, ``'img2'`` plus optional ``'seg1'``/``'seg2'``)
            and every file of the pair is read.

    Returns:
        ``(headers, affines, ids)``. For ``reg=False`` these are parallel lists.
        For ``reg=True`` each is a dict keyed by segmentation availability --
        ``'11'`` (seg1 and seg2), ``'10'`` (seg1 only), ``'01'`` (seg2 only),
        ``'00'`` (neither) -- whose values are lists of per-pair dicts mapping
        each item key to that file's header / affine / ID.
    """
    if not reg:
        headers, affines, ids = [], [], []
        for item in data:
            path = item['seg'] if 'seg' in item else item['img']
            nii = nib.load(path)
            headers.append(nii.header)
            affines.append(nii.affine)
            ids.append(_scan_id(path))
        return headers, affines, ids

    categories = ('00', '01', '10', '11')
    headers = {c: [] for c in categories}
    affines = {c: [] for c in categories}
    ids = {c: [] for c in categories}
    for item in data:
        pair_headers, pair_affines, pair_ids = {}, {}, {}
        for key, path in item.items():
            nii = nib.load(path)
            pair_headers[key] = nii.header
            pair_affines[key] = nii.affine
            pair_ids[key] = _scan_id(path)
        # First digit: seg1 present; second digit: seg2 present.
        category = f"{int('seg1' in item)}{int('seg2' in item)}"
        headers[category].append(pair_headers)
        affines[category].append(pair_affines)
        ids[category].append(pair_ids)
    return headers, affines, ids


def _resolve_test_data(json_path, data):
    """Return the test items from exactly one of *json_path* or *data*.

    When *json_path* is given, its ``'total_test'`` entry (read with
    ``load_json``) is returned.

    Raises:
        ValueError: if both or neither of *json_path* and *data* are given.
    """
    if json_path is not None:
        if data is not None:
            raise ValueError("Pass either json_path or data, not both.")
        return load_json(json_path)['total_test']
    if data is None:
        raise ValueError("Either json_path or data must be given.")
    return data


def _hard_labels(scores):
    """Return ``argmax(softmax(scores, dim=1), dim=1, keepdim=True)`` (channel axis kept, size 1)."""
    return torch.argmax(torch.softmax(scores, dim=1), dim=1, keepdim=True)


def seg_training_inference(seg_net, device, model_path, output_path, num_label, json_path=None, data=None):
    """Run a trained segmentation network on test scans and save the results.

    For every item, an int16 label map ``<scan id>.nii.gz`` is written to
    *output_path*. ``seg_dsc.txt`` lists the per-scan Dice scores (only for
    items that have a ``'seg'`` ground truth) followed by their average.

    Args:
        seg_net: segmentation network; weights are loaded from *model_path*.
        device: torch device used for inference.
        model_path: state-dict file loadable with ``torch.load``.
        output_path: existing directory that receives all outputs.
        num_label: number of classes, used for one-hot encoding.
        json_path: JSON file whose ``'total_test'`` entry lists the test items.
        data: the test items directly. Give exactly one of *json_path* / *data*.

    Raises:
        ValueError: if both or neither of *json_path* and *data* are given.
    """
    raw_data = _resolve_test_data(json_path, data)
    headers, affines, ids = get_nii_info(raw_data, reg=False)

    seg_net.to(device)
    seg_net.load_state_dict(torch.load(model_path, map_location=device))
    seg_net.eval()

    dice_metric = monai.metrics.DiceMetric(include_background=False, reduction='none')
    data_seg = load_seg_dataset(raw_data)

    eval_losses = []  # per-scan report lines
    eval_los = []     # per-scan Dice arrays, averaged at the end
    for k, data_item in enumerate(data_seg):
        scan_id = ids[k]
        with torch.no_grad():
            logits = seg_net(data_item['img'].unsqueeze(0).to(device)).cpu()
        labels = _hard_labels(logits)  # (batch=1, 1, ...)

        if 'seg' in data_item:
            onehot_pred = monai.networks.one_hot(labels, num_label)
            onehot_gt = monai.networks.one_hot(data_item['seg'].unsqueeze(0), num_label)
            dsc = dice_metric(onehot_pred, onehot_gt).numpy()
            eval_los.append(dsc)
            eval_losses.append(f"Scan ID: {scan_id}, dice score: {dsc}")

        pred_np = labels[0, 0].detach().cpu().numpy()
        print(f'{scan_id}: {np.unique(pred_np)}')
        # NOTE(review): pred_np is on the 1 mm grid produced by SpacingD in
        # load_seg_dataset but is saved with the source file's affine/header;
        # if the source spacing is not 1 mm these may not describe it -- confirm.
        nii = nib.Nifti1Image(pred_np.astype('int16'), affine=affines[k], header=headers[k])
        nib.save(nii, os.path.join(output_path, scan_id + '.nii.gz'))
        del logits

    # np.mean([]) warns and yields nan; write the same 'nan' text without the warning.
    average = np.mean(eval_los, 0) if eval_los else float('nan')
    with open(os.path.join(output_path, 'seg_dsc.txt'), 'w') as f:
        for s in eval_losses:
            f.write(s + '\n')
        f.write('\n\nAverage Dice Score: ' + str(average))
    torch.cuda.empty_cache()


def _register_and_save_pair(reg_net, warp, warp_nearest, data_item, device, num_label,
                            pair_ids, pair_affines, pair_headers, seg_meta_key, output_path):
    """Register one cached pair, save the warped image/segmentation and a summary PNG.

    In ``data_item['img12']``, channel 0 is the target (img1) and channel 1
    the moving image (img2); ``data_item['seg2']`` is the moving segmentation.

    The following are written to *output_path*:
    - ``<moving>_to_<target>.nii.gz``, with the target image's header/affine;
    - ``<moving>_to_<target>_seg.nii.gz``, with the header/affine stored under
      *seg_meta_key*;
    - the figure produced by :func:`visualize`.

    Returns:
        ``(img12, warped_image, labels, det, id_moving_img, id_target_img)``.
        *labels* is the warped segmentation as hard labels (channel axis
        kept). *det* is the output of ``jacobian_determinant``.
    """
    img12 = data_item['img12'].unsqueeze(0).to(device)
    moving_seg = monai.networks.one_hot(data_item['seg2'].unsqueeze(0).to(device), num_label)

    with torch.no_grad():
        warp_field = reg_net(img12)
        warped_image = warp(img12[:, [1], :, :, :], warp_field)  # warp the moving image
        warped_seg = warp_nearest(moving_seg, warp_field)
    labels = _hard_labels(warped_seg)

    id_target_img = pair_ids['img1']
    id_moving_img = pair_ids['img2']
    stem = id_moving_img + '_to_' + id_target_img

    # NOTE(review): outputs are on the 1 mm grid from SpacingD in
    # load_reg_dataset but saved with the source affine/header -- confirm
    # this matches when the source spacing is not 1 mm.
    warped_img_np = warped_image[0, 0].detach().cpu().numpy()
    warped_seg_np = labels[0, 0].detach().cpu().numpy()
    nii = nib.Nifti1Image(warped_img_np, affine=pair_affines['img1'], header=pair_headers['img1'])
    nii_seg = nib.Nifti1Image(
        warped_seg_np, affine=pair_affines[seg_meta_key], header=pair_headers[seg_meta_key])
    nii.to_filename(os.path.join(output_path, stem + '.nii.gz'))
    nii_seg.to_filename(os.path.join(output_path, stem + '_seg.nii.gz'))

    field_cpu = warp_field.cpu().detach()[0]
    det = jacobian_determinant(field_cpu)
    grid_spacing = 5
    visualize(img12[0, 0].cpu(), id_target_img, img12[0, 1].cpu(), id_moving_img,
              warped_image[0, 0].cpu(), field_cpu, det, grid_spacing,
              normalize_by='slice', cmap='gray', threshold=None, linewidth=1,
              color='darkblue', downsampling=None, threshold_det=0, output=output_path)
    del warp_field, warped_seg, moving_seg
    return img12, warped_image, labels, det, id_moving_img, id_target_img


def reg_training_inference(reg_net, device, model_path, output_path, num_label, json_path=None, data=None):
    """Run a trained registration network on test pairs and save the results.

    Test scans are paired with ``take_data_pairs`` and grouped by
    segmentation availability with ``subdivide_list_of_data_pairs``. Only two
    groups are processed:

    * ``'01'`` (moving segmentation only): the warped image, warped
      segmentation and a PNG summary are saved per pair. The segmentation
      uses the target image's header/affine.
    * ``'11'`` (both segmentations): the same outputs, plus two report files.
      ``reg_img_losses.txt`` gives, per pair, the LNCC similarity loss and
      the number of voxels with Jacobian determinant <= 0.
      ``reg_seg_dsc.txt`` gives the per-pair Dice scores and their average.

    Args are as in :func:`seg_training_inference`, with *reg_net* being the
    registration network.

    Raises:
        ValueError: if both or neither of *json_path* and *data* are given.
    """
    raw_data = _resolve_test_data(json_path, data)

    reg_net.to(device)
    reg_net.load_state_dict(torch.load(model_path, map_location=device))
    reg_net.eval()

    data_list = take_data_pairs(raw_data)
    headers, affines, ids = get_nii_info(data_list, reg=True)
    subdivided_data_list = subdivide_list_of_data_pairs(data_list)
    subdivided_dataset = load_reg_dataset(subdivided_data_list)

    warp = warp_func()
    warp_nearest = warp_nearest_func()
    lncc_loss = lncc_loss_func()

    if len(subdivided_data_list['01']) != 0:
        dataset01 = subdivided_dataset['01']
        for j in range(len(dataset01)):
            # No target segmentation exists in this group, so the warped
            # segmentation reuses the target image's header/affine.
            _register_and_save_pair(
                reg_net, warp, warp_nearest, dataset01[j], device, num_label,
                ids['01'][j], affines['01'][j], headers['01'][j], 'img1', output_path)

    if len(subdivided_data_list['11']) != 0:
        dataset11 = subdivided_dataset['11']
        dice_metric = monai.metrics.DiceMetric(include_background=False, reduction='none')
        eval_losses_img = []
        eval_losses_seg = []
        eval_los = []
        for j in range(len(dataset11)):
            data_item = dataset11[j]
            img12, warped_image, labels, det, id_moving_img, id_target_img = _register_and_save_pair(
                reg_net, warp, warp_nearest, data_item, device, num_label,
                ids['11'][j], affines['11'][j], headers['11'][j], 'seg1', output_path)

            gt_seg = monai.networks.one_hot(data_item['seg1'].unsqueeze(0).to(device), num_label)
            onehot_pred = monai.networks.one_hot(labels, num_label)
            dsc = dice_metric(onehot_pred, gt_seg).detach().cpu().numpy()
            eval_los.append(dsc)
            eval_losses_seg.append(f"Scan {id_moving_img} to {id_target_img}, dice score: {dsc}")

            with torch.no_grad():
                loss = lncc_loss(warped_image, img12[:, [0], :, :, :]).item()
            eval_losses_img.append(
                f"Warped {id_moving_img} to {id_target_img}, similarity loss: {loss}, "
                f"number of folds: {(det<=0).sum()}")
            # Release device tensors before the next pair to keep peak memory down.
            del img12, warped_image, labels, gt_seg, onehot_pred

        with open(os.path.join(output_path, "reg_img_losses.txt"), 'w') as f:
            for s in eval_losses_img:
                f.write(s + '\n')
        average = np.mean(eval_los, 0)
        with open(os.path.join(output_path, "reg_seg_dsc.txt"), 'w') as f:
            for s in eval_losses_seg:
                f.write(s + '\n')
            f.write('\n\nAverage Dice Score: ' + str(average))
    torch.cuda.empty_cache()


def _mid_slices(volume, centre):
    """Return the three orthogonal slices of a 3D *volume* through *centre* = (x, y, z)."""
    x, y, z = centre
    return volume[x, :, :], volume[:, y, :], volume[:, :, z]


def _field_mid_planes(vector_field, centre):
    """Return the in-plane components of *vector_field* on its three mid-planes.

    *vector_field* is indexed ``[component, x, y, z]``. The x-plane keeps
    components (1, 2), the y-plane (0, 2) and the z-plane (0, 1).
    """
    x, y, z = centre
    return (
        vector_field[[1, 2], x, :, :],
        vector_field[[0, 2], :, y, :],
        vector_field[[0, 1], :, :, z],
    )


def _plot_slice_row(slices, first_subplot, title, vmin, vmax, cmap, threshold):
    """Draw three slices into subplots first_subplot..first_subplot+2 of a 6x3 grid.

    When *threshold* is not None, pixels <= threshold are overlaid in opaque red.
    """
    for offset, im in enumerate(slices):
        plt.subplot(6, 3, first_subplot + offset)
        plt.axis('off')
        plt.title(title)
        plt.imshow(im, origin='lower', vmin=vmin, vmax=vmax, cmap=cmap)
        if threshold is not None:
            red = np.zeros(im.shape + (4,))  # RGBA overlay
            red[im <= threshold] = [1, 0, 0, 1]
            plt.imshow(red, origin='lower')


def visualize(target, target_id, moving, moving_id, warped, vector_field, det, grid_spacing,
              normalize_by='volume', cmap=None, threshold=None, linewidth=1, color='red',
              downsampling=None, threshold_det=None, output=None, close_figure=True):
    """Save a 6x3 panel figure summarising one registration result.

    Each row shows three orthogonal mid-volume slices. Rows, top to bottom:
    moving image, target image, warped image, deformation vector field
    (``plot_2D_vector_field``), deformation on a grid
    (``plot_2D_deformation``), and Jacobian determinant. Image and
    determinant slices are centred on *moving*'s shape; field slices are
    centred on the field's spatial shape.

    Args:
        target, moving, warped: 3D volumes indexed ``[x, y, z]``.
        target_id, moving_id: names used in titles and in the file name.
        vector_field: deformation field indexed ``[component, x, y, z]``.
        det: Jacobian determinant volume, indexed like *moving*.
        grid_spacing: grid spacing passed to ``plot_2D_deformation``.
        normalize_by: ``'slice'`` lets matplotlib scale each panel.
            ``'volume'`` uses ``vmin=0`` and each volume's maximum.
        cmap: colormap for the three image rows (the determinant row uses
            matplotlib's default).
        threshold: if not None, image-row pixels <= threshold are overlaid in red.
        linewidth, color: passed to ``plot_2D_deformation``.
        downsampling: arrow downsampling for ``plot_2D_vector_field``. When
            None, it is derived from the field size.
        threshold_det: if not None, determinant pixels <= threshold_det are
            overlaid in red.
        output: directory that receives ``reg_net_infer_<moving>_to_<target>.png``.
        close_figure: close the figure after saving (default) so repeated calls
            do not accumulate open figures. Pass False to keep it open, e.g.
            for inline notebook display.

    Raises:
        ValueError: for an unknown *normalize_by*.
    """
    if normalize_by == "slice":
        vmin = vmax_moving = vmax_target = vmax_warped = vmax_det = None
    elif normalize_by == "volume":
        vmin = 0
        vmax_moving = moving.max().item()
        vmax_target = target.max().item()
        vmax_warped = warped.max().item()
        vmax_det = det.max().item()
    else:
        raise ValueError(f"Invalid value '{normalize_by}' given for normalize_by")

    fig = plt.figure(figsize=(24, 24))
    try:
        image_centre = np.array(moving.shape) // 2
        _plot_slice_row(_mid_slices(moving, image_centre), 1,
                        f'moving image: {moving_id}', vmin, vmax_moving, cmap, threshold)
        _plot_slice_row(_mid_slices(target, image_centre), 4,
                        f'target image: {target_id}', vmin, vmax_target, cmap, threshold)
        _plot_slice_row(_mid_slices(warped, image_centre), 7,
                        f'warped image: {moving_id} to {target_id}', vmin, vmax_warped, cmap, threshold)

        if downsampling is None:
            # guess a reasonable downsampling value to make a nice plot
            downsampling = max(1, int(max(vector_field.shape[1:])) >> 5)
        field_centre = np.array(vector_field.shape[1:]) // 2

        for col, plane in enumerate(_field_mid_planes(vector_field, field_centre)):
            plt.subplot(6, 3, 10 + col)
            plt.axis('off')
            plt.title(f'deformation vector field: {moving_id} to {target_id}')
            plot_2D_vector_field(plane, downsampling)

        # Slice the field again so this row never sees data the previous helper touched.
        for col, plane in enumerate(_field_mid_planes(vector_field, field_centre)):
            plt.subplot(6, 3, 13 + col)
            plt.axis('off')
            plt.title(f'deformation vector field on grid: {moving_id} to {target_id}')
            plot_2D_deformation(plane, grid_spacing, linewidth=linewidth, color=color)

        _plot_slice_row(_mid_slices(det, image_centre), 16,
                        f'jacobian determinant: {moving_id} to {target_id}',
                        vmin, vmax_det, None, threshold_det)

        plt.savefig(os.path.join(
            output, f'reg_net_infer_{moving_id}_to_{target_id}.png'))
    finally:
        if close_figure:
            plt.close(fig)