File size: 4,441 Bytes
6ffe23f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 |
from matplotlib import pyplot as plt
import math
import numpy as np
def visualize_results(img, mask, pred, n_slices: int=3, slices: list=None, title: str=""):
    """Plot image slices next to ground-truth and predicted label overlays.

    Each row shows one slice in three columns: the raw image, the image with
    the ground-truth overlay, and the image with the predicted overlay.

    Args:
        img: volume indexed as [C, H, W, Z]; only channel 0 is drawn.
        mask: ground-truth volume indexed as [C, H, W, Z]. Channel 1 is drawn
            with the 'jet' colormap (vmin=1, vmax=4) and channel 2 with 'Reds';
            zero-valued pixels are masked out so the image shows through.
        pred: predicted volume with the same layout as ``mask``.
        n_slices: number of slices to draw when ``slices`` is None; they are
            taken every ``Z // n_slices`` steps starting at slice 0.
        slices: explicit slice indices along the last axis; overrides ``n_slices``.
        title: figure super-title.

    Raises:
        ValueError: if fewer than one slice would be drawn.
    """
    if slices is not None:
        n_slices = len(slices)
    if n_slices < 1:
        raise ValueError(f"need at least one slice to plot, got n_slices={n_slices}")
    # squeeze=False keeps `ax` 2-D even for a single row, so ax[i, c] always
    # works (with squeeze=True a 1x3 grid is returned as a 1-D array).
    _, ax = plt.subplots(n_slices, 3, figsize=(14, 5 * n_slices), squeeze=False)
    inc = img.shape[-1] // n_slices
    # Mask out zeros so only labelled pixels are painted over the image.
    mask_masked = np.ma.masked_where(mask == 0, mask)
    pred_masked = np.ma.masked_where(pred == 0, pred)
    for i in range(n_slices):
        slice_num = i * inc if slices is None else slices[i]
        # Same grayscale background in every column.
        for c in range(3):
            ax[i, c].imshow(img[0, :, :, slice_num], cmap="gray")
            ax[i, c].axis("off")
        ax[i, 0].set_title("image")
        # Columns 1 and 2 get identical overlay styling, differing only in source.
        for col, label, overlay in ((1, "ground truth", mask_masked), (2, "predicted", pred_masked)):
            ax[i, col].imshow(overlay[1, :, :, slice_num], cmap="jet", vmin=1, vmax=4, interpolation="none", alpha=0.5)
            ax[i, col].imshow(overlay[2, :, :, slice_num], cmap="Reds", vmin=0, vmax=1.3, interpolation="none", alpha=0.8)
            ax[i, col].set_title(label)
    plt.suptitle(title, size=14)
    plt.tight_layout()
    plt.show()
def visualize_patient(img, mask=None, n_slices: int=3, slices: list=None, z_dim_last=True, mask_channel=0, title: str=""):
    """Plot a grid of slices from one volume, optionally with a label overlay.

    Slices are laid out three per row. When ``mask`` is given, the first figure
    shows each slice with the mask overlaid and a second figure shows the same
    slices without it, for comparison. Without a mask, a single figure is shown.

    Args:
        img: volume indexed as [C, H, W, Z] when ``z_dim_last`` is True, or as
            [Z, C, H, W] otherwise; only channel 0 is drawn.
        mask: optional label volume with the same layout as ``img``; zero-valued
            pixels are masked out so only labelled pixels are drawn over the image.
        n_slices: number of slices to draw when ``slices`` is None; they are
            taken every ``depth // n_slices`` steps starting at slice 0.
        slices: explicit slice indices along the depth axis; overrides ``n_slices``.
        z_dim_last: True if the depth axis is the last axis, False if it is the first.
        mask_channel: channel of ``mask`` to overlay.
        title: super-title used for each figure.

    Raises:
        ValueError: if fewer than one slice would be drawn.
    """
    if slices is not None:
        n_slices = len(slices)
    if n_slices < 1:
        raise ValueError(f"need at least one slice to plot, got n_slices={n_slices}")
    depth = img.shape[-1] if z_dim_last else img.shape[0]
    inc = depth // n_slices
    slice_nums = [i * inc for i in range(n_slices)] if slices is None else list(slices)

    # Only build the masked array when there is a mask to overlay.
    overlay = None if mask is None else np.ma.masked_where(mask == 0, mask)
    _plot_slice_grid(img, overlay, slice_nums, z_dim_last, mask_channel, title)
    # Plain slices for comparison; skipped without a mask, where it would just
    # repeat the figure above.
    if overlay is not None:
        _plot_slice_grid(img, None, slice_nums, z_dim_last, mask_channel, title)


def _take_slice(volume, slice_num, channel, z_dim_last):
    """Return the 2-D plane of ``channel`` at depth ``slice_num``."""
    if z_dim_last:
        return volume[channel, :, :, slice_num]  # [C, H, W, Z]
    return volume[slice_num, channel, :, :]  # [Z, C, H, W]


def _plot_slice_grid(img, overlay, slice_nums, z_dim_last, mask_channel, title):
    """Show one figure with three slices per row, optionally overlaying ``overlay``."""
    n_rows = math.ceil(len(slice_nums) / 3)
    # squeeze=False gives a 2-D axes array for any row count, so one indexing
    # form works whether there is one row or several.
    _, ax = plt.subplots(n_rows, 3, figsize=(14, 5 * n_rows), squeeze=False)
    # Hide every cell up front so unused trailing cells don't show empty frames.
    for cell in ax.flat:
        cell.axis("off")
    for i, slice_num in enumerate(slice_nums):
        cell = ax[divmod(i, 3)]
        cell.imshow(_take_slice(img, slice_num, 0, z_dim_last), cmap="gray")
        cell.set_title(f"slice {slice_num}")
        if overlay is not None:
            cell.imshow(
                _take_slice(overlay, slice_num, mask_channel, z_dim_last),
                cmap="jet", vmin=1, vmax=4, interpolation="none", alpha=0.4,
            )
    plt.suptitle(title, size=14)
    plt.tight_layout()
    plt.show()