Kurokabe's picture
Upload 84 files
3be620b
raw
history blame
1.39 kB
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
def display_images(data, n_rows=3, n_cols=3, figsize=(24, 12)):
    """Plot images from ``data`` on an ``n_rows`` x ``n_cols`` grid.

    Ticks are removed and inter-axes spacing is set to zero. Only the first
    ``n_rows * n_cols`` images are drawn. If ``data`` holds fewer images, the
    remaining axes stay empty.

    Args:
        data: Iterable of images. Each is passed through
            ``unnormalize_if_necessary`` and then drawn with ``Axes.imshow``.
        n_rows: Number of grid rows.
        n_cols: Number of grid columns.
        figsize: Figure size in inches, forwarded to ``plt.subplots``.

    Returns:
        The created ``matplotlib.figure.Figure``.
    """
    # squeeze=False keeps ``axs`` a 2-D array even for a 1x1 grid. Otherwise
    # plt.subplots returns a bare Axes, which has no ``flatten`` method.
    figure, axs = plt.subplots(n_rows, n_cols, figsize=figsize, squeeze=False)
    axs = axs.flatten()
    plt.setp(axs, xticks=[], yticks=[])
    # Adjust this figure explicitly rather than pyplot's "current" figure.
    figure.subplots_adjust(wspace=0, hspace=0)
    # zip stops at the shorter sequence, so surplus images (or axes) are ignored.
    for img, ax in zip(data, axs):
        ax.imshow(unnormalize_if_necessary(img))
    return figure
def unnormalize_if_necessary(x):
    """Rescale ``x`` with ``x * 0.5 + 0.5`` when it contains negative values.

    The affine map sends [-1, 1] to [0, 1]. Presumably callers pass images that
    were normalised to [-1, 1]; confirm against the data pipeline.

    Args:
        x: A NumPy array or a ``tf.Tensor``. Any other type is returned as-is.
            Tensors are inspected via ``.numpy()``, which requires eager
            execution.

    Returns:
        The rescaled value (same container type as ``x``) if its minimum is
        negative. Otherwise ``x`` itself, unchanged. Empty inputs are returned
        unchanged instead of raising on an empty ``min`` reduction.
    """
    if isinstance(x, np.ndarray):
        values = x
    elif isinstance(x, tf.Tensor):
        values = x.numpy()
    else:
        return x
    # Guard on size: ``min`` of a zero-size array raises ValueError.
    if values.size and values.min() < 0:
        return (x * 0.5) + 0.5
    return x
def display_true_pred(y_true, y_pred, n_cols=3, figsize=(24, 12)):
    """Plot predicted images in a top row and ground-truth images below.

    Each batch is passed through ``unnormalize_if_necessary`` as a whole. The
    rescaling decision is therefore made from the batch's overall minimum, not
    per image.

    Args:
        y_true: Ground-truth images, iterable along the first axis.
        y_pred: Predicted images, same layout as ``y_true``.
        n_cols: Number of images shown per row. If a batch holds fewer images,
            the remaining axes are left empty instead of raising IndexError.
        figsize: Figure size in inches.

    Returns:
        The created ``matplotlib.figure.Figure``.
    """
    fig = plt.figure(constrained_layout=True, figsize=figsize)
    rows = [
        ("Prediction", unnormalize_if_necessary(y_pred)),
        ("Ground truth", unnormalize_if_necessary(y_true)),
    ]
    # One sub-figure per row: predictions on top, ground truth below.
    subfigs = fig.subfigures(nrows=2, ncols=1)
    for subfig, (title, images) in zip(subfigs, rows):
        subfig.suptitle(title, fontsize=24)
        # squeeze=False gives shape (1, n_cols) even when n_cols == 1, so the
        # axes are always iterable; [0] selects that single row.
        axs = subfig.subplots(nrows=1, ncols=n_cols, squeeze=False)[0]
        # zip stops at the shorter of axes/images instead of indexing past
        # the end of a short batch.
        for ax, img in zip(axs, images):
            ax.imshow(img)
    return fig