|
import monai |
|
import torch |
|
import itk |
|
import numpy as np |
|
import matplotlib.pyplot as plt |
|
import matplotlib as mpl |
|
import os |
|
import nibabel as nib |
|
import sys |
|
import json |
|
from pathlib import Path |
|
# Disable matplotlib's warning about the number of simultaneously open figures.
mpl.rc('figure', max_open_warning=0)

# Repository root, taken as two levels above the current working directory.
ROOT_DIR = str(Path.cwd().parent.parent.absolute())

# Prepend the project's helper directories to the import search path so the
# `process_data`, `utils` and `losses` imports below can be resolved.
# The order matches three successive `insert(0, ...)` calls of
# utils, loss_function, preprocess (the last inserted ends up first).
sys.path[:0] = [
    os.path.join(ROOT_DIR, 'deepatlas/preprocess'),
    os.path.join(ROOT_DIR, 'deepatlas/loss_function'),
    os.path.join(ROOT_DIR, 'deepatlas/utils'),
]
|
|
|
from process_data import ( |
|
take_data_pairs, subdivide_list_of_data_pairs |
|
) |
|
from utils import ( |
|
plot_2D_vector_field, jacobian_determinant, plot_2D_deformation, load_json |
|
) |
|
from losses import ( |
|
warp_func, warp_nearest_func, lncc_loss_func, dice_loss_func2, dice_loss_func |
|
) |
|
|
|
def load_seg_dataset(data_list):
    """Build the cached MONAI dataset used for segmentation inference.

    Args:
        data_list: sequence of dicts; each holds an ``'img'`` path and may
            hold a ``'seg'`` path (every transform tolerates missing keys).

    Returns:
        A ``monai.data.CacheDataset`` whose items have been loaded, given a
        leading channel axis, resampled to ``pixdim=(1., 1., 1.)`` and
        converted to tensors.
    """
    keys = ['img', 'seg']
    pipeline = [
        monai.transforms.LoadImageD(
            keys=keys, image_only=True, allow_missing_keys=True),
        monai.transforms.AddChannelD(keys=keys, allow_missing_keys=True),
        # The interpolation modes pair up with `keys`: image first, label second.
        monai.transforms.SpacingD(
            keys=keys,
            pixdim=(1., 1., 1.),
            mode=('trilinear', 'nearest'),
            allow_missing_keys=True,
        ),
        monai.transforms.ToTensorD(keys=keys, allow_missing_keys=True),
    ]
    transform_seg_available = monai.transforms.Compose(transforms=pipeline)

    # Silence ITK warnings before the dataset is built.
    itk.ProcessObject.SetGlobalWarningDisplay(False)

    return monai.data.CacheDataset(
        data=data_list,
        transform=transform_seg_available,
        cache_num=16,
        hash_as_key=True,
    )
|
|
|
|
|
def load_reg_dataset(data_list):
    """Build one cached MONAI dataset of image pairs per segmentation bucket.

    Args:
        data_list: mapping from a segmentation-availability key to the list
            of pair dicts for that bucket. Each pair dict may contain
            ``'img1'``, ``'seg1'``, ``'img2'`` and ``'seg2'`` paths.

    Returns:
        dict with the same keys as ``data_list``; each value is a
        ``monai.data.CacheDataset`` whose items carry ``'img12'`` (``img1``
        and ``img2`` concatenated along dim 0) in place of the two separate
        images, plus whichever segmentations were present.
    """
    pair_keys = ['img1', 'seg1', 'img2', 'seg2']
    image_keys = ['img1', 'img2']

    steps = [
        monai.transforms.LoadImageD(
            keys=pair_keys, image_only=True, allow_missing_keys=True),
        monai.transforms.ToTensorD(keys=pair_keys, allow_missing_keys=True),
        monai.transforms.AddChannelD(keys=pair_keys, allow_missing_keys=True),
        # The interpolation modes pair up with `pair_keys`, one per key.
        monai.transforms.SpacingD(
            keys=pair_keys,
            pixdim=(1., 1., 1.),
            mode=('trilinear', 'nearest', 'trilinear', 'nearest'),
            allow_missing_keys=True,
        ),
        # Stack target and moving image into one tensor, then drop the originals.
        monai.transforms.ConcatItemsD(keys=image_keys, name='img12', dim=0),
        monai.transforms.DeleteItemsD(keys=image_keys),
    ]
    transform_pair = monai.transforms.Compose(transforms=steps)

    dataset_pairs_train_subdivided = {}
    for seg_availability, data in data_list.items():
        dataset_pairs_train_subdivided[seg_availability] = monai.data.CacheDataset(
            data=data,
            transform=transform_pair,
            cache_num=32,
            hash_as_key=True,
        )
    return dataset_pairs_train_subdivided
|
|
|
def _read_nii_meta(path):
    """Return ``(scan_id, header, affine)`` for the NIfTI file at ``path``.

    ``scan_id`` is the file's base name up to its first dot.
    """
    scan_id = os.path.basename(path).split('.')[0]
    image = nib.load(path)
    return scan_id, image.header, image.affine


def get_nii_info(data, reg=False):
    """Collect NIfTI headers, affines and scan ids for a list of data entries.

    Args:
        data: sequence of dicts mapping keys to NIfTI file paths.
            With ``reg=False`` each dict has ``'img'`` and optionally
            ``'seg'``; with ``reg=True`` each dict describes a pair and every
            key in it is treated as a file path.
        reg: select the registration (pair) layout of the result.

    Returns:
        ``(headers, affines, ids)``.

        * ``reg=False``: three parallel lists with one entry per item, read
          from the ``'seg'`` file when the item has one, else from ``'img'``.
        * ``reg=True``: three dicts keyed by ``'00'``, ``'01'``, ``'10'``,
          ``'11'`` (first digit: ``'seg1'`` present, second digit: ``'seg2'``
          present). Each value is a list with one dict per pair that maps the
          pair's own keys to the header / affine / id of that file.
    """
    if not reg:
        headers, affines, ids = [], [], []
        for item in data:
            path = item['seg'] if 'seg' in item else item['img']
            # Only metadata is needed here, so the voxel data is never read.
            scan_id, header, affine = _read_nii_meta(path)
            headers.append(header)
            affines.append(affine)
            ids.append(scan_id)
        return headers, affines, ids

    headers = {'00': [], '01': [], '10': [], '11': []}
    affines = {'00': [], '01': [], '10': [], '11': []}
    ids = {'00': [], '01': [], '10': [], '11': []}
    for item in data:
        bucket = ('1' if 'seg1' in item else '0') + ('1' if 'seg2' in item else '0')
        pair_headers, pair_affines, pair_ids = {}, {}, {}
        for key, path in item.items():
            pair_ids[key], pair_headers[key], pair_affines[key] = _read_nii_meta(path)
        headers[bucket].append(pair_headers)
        affines[bucket].append(pair_affines)
        ids[bucket].append(pair_ids)
    return headers, affines, ids
|
|
|
|
|
def seg_training_inference(seg_net, device, model_path, output_path, num_label, json_path=None, data=None):
    """Run a trained segmentation network over a test set and save the results.

    For every scan the predicted label map is written to
    ``<output_path>/<scan id>.nii.gz``. When a scan has a ground-truth
    ``'seg'``, its Dice score (background excluded) is recorded.
    All scores plus their mean are written to ``<output_path>/seg_dsc.txt``.

    Args:
        seg_net: segmentation network; moved to ``device`` and put in eval mode.
        device: torch device used for inference.
        model_path: path of the state dict loaded into ``seg_net``.
        output_path: existing directory that receives the output files.
        num_label: number of classes used for one-hot encoding.
        json_path: JSON file whose ``'total_test'`` entry lists the test data.
            Mutually exclusive with ``data``.
        data: test data list, used when ``json_path`` is None.

    Raises:
        AssertionError: if both or neither of ``json_path`` / ``data`` are given.
    """
    if json_path is not None:
        assert data is None, "pass either json_path or data, not both"
        raw_data = load_json(json_path)['total_test']
    else:
        assert data is not None, "one of json_path or data is required"
        raw_data = data

    headers, affines, ids = get_nii_info(raw_data, reg=False)
    seg_net.to(device)
    seg_net.load_state_dict(torch.load(model_path, map_location=device))
    seg_net.eval()
    dice_metric = monai.metrics.DiceMetric(include_background=False, reduction='none')
    data_seg = load_seg_dataset(raw_data)

    dice_scores = []
    report_lines = []
    for index, data_item in enumerate(data_seg):
        scan_id = ids[index]

        with torch.no_grad():
            logits = seg_net(data_item['img'].unsqueeze(0).to(device)).cpu()

        # Class index per voxel, keeping batch and channel axes for one_hot.
        label_map = torch.argmax(torch.softmax(logits, dim=1), dim=1, keepdim=True)

        if 'seg' in data_item:
            onehot_pred = monai.networks.one_hot(label_map, num_label)
            onehot_gt = monai.networks.one_hot(data_item['seg'].unsqueeze(0), num_label)
            dsc = dice_metric(onehot_pred, onehot_gt).numpy()
            dice_scores.append(dsc)
            report_lines.append(f"Scan ID: {scan_id}, dice score: {dsc}")

        pred_np = label_map[0, 0].detach().cpu().numpy()
        print(f'{scan_id}: {np.unique(pred_np)}')

        # NOTE(review): the prediction lives on the grid produced by SpacingD in
        # load_seg_dataset, while affine/header come from the original file via
        # get_nii_info - confirm they agree when the source spacing is not (1, 1, 1).
        nii = nib.Nifti1Image(
            pred_np.astype('int16'), affine=affines[index], header=headers[index])
        nib.save(nii, os.path.join(output_path, scan_id + '.nii.gz'))

    # With no ground truth at all there is nothing to average; report nan
    # without triggering numpy's empty-slice warning.
    average = np.mean(dice_scores, 0) if dice_scores else float('nan')

    with open(os.path.join(output_path, 'seg_dsc.txt'), 'w') as f:
        f.writelines(line + '\n' for line in report_lines)
        f.write('\n\nAverage Dice Score: ' + str(average))
    torch.cuda.empty_cache()
|
|
|
|
|
def _register_and_save_pair(reg_net, warp, warp_nearest, data_item, pair_id, pair_affine,
                            pair_header, seg_ref_key, num_label, device, output_path):
    """Register one image pair and write the warped image, warped labels and figure.

    ``seg_ref_key`` names the entry of ``pair_affine`` / ``pair_header`` used
    for the warped segmentation file.

    Returns:
        ``(img12, warped_image, label_map, det, id_moving_img, id_target_img)``.
    """
    img12 = data_item['img12'].unsqueeze(0).to(device)
    moving_raw_seg = data_item['seg2'].unsqueeze(0).to(device)
    moving_seg = monai.networks.one_hot(moving_raw_seg, num_label)

    with torch.no_grad():
        deformation = reg_net(img12)
        # Channel 0 of img12 is the target (img1), channel 1 the moving image (img2).
        warped_image = warp(img12[:, [1], :, :, :], deformation)
        warped_seg = warp_nearest(moving_seg, deformation)
        # Collapse the warped one-hot labels back to a class-index map.
        label_map = torch.argmax(torch.softmax(warped_seg, dim=1), dim=1, keepdim=True)

    id_target_img = pair_id['img1']
    id_moving_img = pair_id['img2']

    warped_img_np = warped_image[0, 0].detach().cpu().numpy()
    warped_seg_np = label_map[0, 0].detach().cpu().numpy()

    nii_seg = nib.Nifti1Image(
        warped_seg_np, affine=pair_affine[seg_ref_key], header=pair_header[seg_ref_key])
    nii = nib.Nifti1Image(
        warped_img_np, affine=pair_affine['img1'], header=pair_header['img1'])
    nii.to_filename(os.path.join(
        output_path, id_moving_img + '_to_' + id_target_img + '.nii.gz'))
    nii_seg.to_filename(os.path.join(
        output_path, id_moving_img + '_to_' + id_target_img + '_seg.nii.gz'))

    grid_spacing = 5
    det = jacobian_determinant(deformation.cpu().detach()[0])
    visualize(img12[0, 0, :, :, :].cpu(),
              id_target_img,
              img12[0, 1, :, :, :].cpu(),
              id_moving_img,
              warped_image[0, 0].cpu(),
              deformation.cpu().detach()[0],
              det,
              grid_spacing,
              normalize_by='slice',
              cmap='gray',
              threshold=None,
              linewidth=1,
              color='darkblue',
              downsampling=None,
              threshold_det=0,
              output=output_path
              )
    return img12, warped_image, label_map, det, id_moving_img, id_target_img


def reg_training_inference(reg_net, device, model_path, output_path, num_label, json_path=None, data=None):
    """Run a trained registration network over all test pairs and save the results.

    Only pairs whose moving image has a segmentation are processed, i.e. the
    ``'01'`` and ``'11'`` buckets. For each such pair the warped image, the
    warped label map and a summary figure are written to ``output_path``.
    For ``'11'`` pairs (both segmentations present) the Dice score against the
    target segmentation and the image similarity loss are also collected, then
    written to ``reg_seg_dsc.txt`` and ``reg_img_losses.txt``.

    Args:
        reg_net: registration network; moved to ``device`` and put in eval mode.
        device: torch device used for inference.
        model_path: path of the state dict loaded into ``reg_net``.
        output_path: existing directory that receives the output files.
        num_label: number of classes used for one-hot encoding.
        json_path: JSON file whose ``'total_test'`` entry lists the test data.
            Mutually exclusive with ``data``.
        data: test data list, used when ``json_path`` is None.

    Raises:
        AssertionError: if both or neither of ``json_path`` / ``data`` are given.
    """
    if json_path is not None:
        assert data is None, "pass either json_path or data, not both"
        raw_data = load_json(json_path)['total_test']
    else:
        assert data is not None, "one of json_path or data is required"
        raw_data = data

    reg_net.to(device)
    reg_net.load_state_dict(torch.load(model_path, map_location=device))
    reg_net.eval()

    data_list = take_data_pairs(raw_data)
    headers, affines, ids = get_nii_info(data_list, reg=True)
    subvided_data_list = subdivide_list_of_data_pairs(data_list)
    subvided_dataset = load_reg_dataset(subvided_data_list)

    warp = warp_func()
    warp_nearest = warp_nearest_func()
    lncc_loss = lncc_loss_func()

    if len(subvided_data_list['01']) != 0:
        dataset01 = subvided_dataset['01']
        for idx in range(len(dataset01)):
            # No target segmentation here, so the warped labels reuse the
            # target image's header and affine.
            _register_and_save_pair(
                reg_net, warp, warp_nearest, dataset01[idx],
                ids['01'][idx], affines['01'][idx], headers['01'][idx],
                'img1', num_label, device, output_path)

    if len(subvided_data_list['11']) != 0:
        dataset11 = subvided_dataset['11']
        dice_metric = monai.metrics.DiceMetric(include_background=False, reduction='none')
        eval_losses_img = []
        eval_losses_seg = []
        dice_scores = []

        for idx in range(len(dataset11)):
            data_item = dataset11[idx]
            img12, warped_image, label_map, det, id_moving_img, id_target_img = \
                _register_and_save_pair(
                    reg_net, warp, warp_nearest, data_item,
                    ids['11'][idx], affines['11'][idx], headers['11'][idx],
                    'seg1', num_label, device, output_path)

            gt_raw_seg = data_item['seg1'].unsqueeze(0).to(device)
            gt_seg = monai.networks.one_hot(gt_raw_seg, num_label)
            onehot_pred = monai.networks.one_hot(label_map, num_label)
            dsc = dice_metric(onehot_pred, gt_seg).detach().cpu().numpy()
            dice_scores.append(dsc)
            eval_losses_seg.append(
                f"Scan {id_moving_img} to {id_target_img}, dice score: {dsc}")

            # Similarity between the warped moving image and the target image.
            loss = lncc_loss(warped_image, img12[:, [0], :, :, :]).item()
            eval_losses_img.append(
                f"Warped {id_moving_img} to {id_target_img}, similarity loss: {loss}, number of folds: {(det<=0).sum()}")

        with open(os.path.join(output_path, "reg_img_losses.txt"), 'w') as f:
            f.writelines(line + '\n' for line in eval_losses_img)

        average = np.mean(dice_scores, 0)
        with open(os.path.join(output_path, "reg_seg_dsc.txt"), 'w') as f:
            f.writelines(line + '\n' for line in eval_losses_seg)
            f.write('\n\nAverage Dice Score: ' + str(average))

    torch.cuda.empty_cache()
|
|
|
def _show_slice_row(slices, first_panel, title, vmin, vmax, cmap, threshold):
    """Draw three slices in consecutive panels of the 6x3 grid of the current figure.

    When ``threshold`` is not None, pixels with value ``<= threshold`` are
    painted opaque red on top of the slice.
    """
    for offset, im in enumerate(slices):
        plt.subplot(6, 3, first_panel + offset)
        plt.axis('off')
        plt.title(title)
        plt.imshow(im, origin='lower', vmin=vmin, vmax=vmax, cmap=cmap)
        if threshold is not None:
            red = np.zeros(im.shape+(4,))  # RGBA overlay, transparent by default
            red[im <= threshold] = [1, 0, 0, 1]
            plt.imshow(red, origin='lower')


def visualize(target,
              target_id,
              moving,
              moving_id,
              warped,
              vector_field,
              det,
              grid_spacing,
              normalize_by='volume',
              cmap=None,
              threshold=None,
              linewidth=1,
              color='red',
              downsampling=None,
              threshold_det=None,
              output=None
              ):
    """Plot a 6x3 summary figure of one registration result and optionally save it.

    Rows, top to bottom: moving image, target image, warped image, vector
    field, deformed grid, Jacobian determinant. Each row shows the three
    central orthogonal slices.

    Args:
        target, moving, warped: 3-D volumes, sliced at the centre of ``moving``.
        target_id, moving_id: identifiers used in titles and the file name.
        vector_field: 4-D array with three components along its first axis.
        det: Jacobian determinant volume, sliced at the same indices as ``moving``.
        grid_spacing: passed to ``plot_2D_deformation``.
        normalize_by: ``'slice'`` lets every panel pick its own colour range;
            ``'volume'`` uses 0 .. the volume maximum.
        cmap: colour map for the three image rows.
        threshold: if not None, image pixels ``<= threshold`` are marked red.
        linewidth, color: passed to ``plot_2D_deformation``.
        downsampling: passed to ``plot_2D_vector_field``; when None it defaults
            to ``max(1, largest spatial size >> 5)``.
        threshold_det: if not None, determinant pixels ``<= threshold_det``
            are marked red.
        output: directory for ``reg_net_infer_<moving_id>_to_<target_id>.png``.
            When given, the figure is saved and then closed so repeated calls
            do not accumulate open figures. When None, nothing is saved and
            the figure is left open for the caller.

    Raises:
        ValueError: if ``normalize_by`` is neither ``'slice'`` nor ``'volume'``.
    """
    if normalize_by == "slice":
        vmin = None
        vmax_moving = vmax_target = vmax_warped = vmax_det = None
    elif normalize_by == "volume":
        vmin = 0
        vmax_moving = moving.max().item()
        vmax_target = target.max().item()
        vmax_warped = warped.max().item()
        vmax_det = det.max().item()
    else:
        raise ValueError(
            f"Invalid value '{normalize_by}' given for normalize_by")

    fig = plt.figure(figsize=(24, 24))

    # Central slice along each axis of the moving volume.
    x, y, z = np.array(moving.shape)//2
    moving_imgs = (moving[x, :, :], moving[:, y, :], moving[:, :, z])
    target_imgs = (target[x, :, :], target[:, y, :], target[:, :, z])
    warped_imgs = (warped[x, :, :], warped[:, y, :], warped[:, :, z])
    det_imgs = (det[x, :, :], det[:, y, :], det[:, :, z])

    _show_slice_row(moving_imgs, 1, f'moving image: {moving_id}',
                    vmin, vmax_moving, cmap, threshold)
    _show_slice_row(target_imgs, 4, f'target image: {target_id}',
                    vmin, vmax_target, cmap, threshold)
    _show_slice_row(warped_imgs, 7, f'warped image: {moving_id} to {target_id}',
                    vmin, vmax_warped, cmap, threshold)

    if downsampling is None:
        downsampling = max(1, int(max(vector_field.shape[1:])) >> 5)

    # Central slices of the vector field; each keeps only the two in-plane components.
    x, y, z = np.array(vector_field.shape[1:])//2
    field_slices = (
        vector_field[[1, 2], x, :, :],
        vector_field[[0, 2], :, y, :],
        vector_field[[0, 1], :, :, z],
    )

    for offset, field_2d in enumerate(field_slices):
        plt.subplot(6, 3, 10 + offset)
        plt.axis('off')
        plt.title(f'deformation vector field: {moving_id} to {target_id}')
        plot_2D_vector_field(field_2d, downsampling)

    for offset, field_2d in enumerate(field_slices):
        plt.subplot(6, 3, 13 + offset)
        plt.axis('off')
        plt.title(f'deformation vector field on grid: {moving_id} to {target_id}')
        plot_2D_deformation(
            field_2d, grid_spacing, linewidth=linewidth, color=color)

    _show_slice_row(det_imgs, 16, f'jacobian determinant: {moving_id} to {target_id}',
                    vmin, vmax_det, None, threshold_det)

    if output is not None:
        plt.savefig(os.path.join(
            output, f'reg_net_infer_{moving_id}_to_{target_id}.png'))
        # Release the figure; otherwise every call leaves a 24x24 inch figure alive.
        plt.close(fig)
|
|