import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image


def _to_numpy(x):
    """Return *x* as a NumPy array, detaching and moving tensors to CPU first."""
    # Objects exposing ``detach`` (e.g. torch tensors) must leave the autograd
    # graph and the device before ``.numpy()`` can be called on them.
    if hasattr(x, "detach"):
        return x.detach().cpu().numpy()
    return np.asarray(x)


def preprocess(image, output, binarize, threshold):
    """Convert an input image and its output scores to NumPy arrays for plotting.

    Parameters
    ----------
    image : tensor or array-like
        Input image. After squeezing singleton dimensions it is transposed
        with axes ``(1, 2, 0)``, so it is expected to be channels-first
        ``(C, H, W)``. Values are mapped through ``(x + 1) * 0.5``, which
        presumably undoes a [-1, 1] normalisation (TODO confirm against
        the data pipeline).
    output : tensor or array-like
        Output scores; singleton dimensions are squeezed away.
    binarize : bool
        If true, replace ``output`` by 1.0 where it exceeds ``threshold``
        and 0.0 elsewhere.
    threshold : float
        Strict lower bound for a score to become 1.0 when binarizing.

    Returns
    -------
    tuple of numpy.ndarray
        ``(image, output)`` with ``image`` in channels-last ``(H, W, C)`` order.
    """
    image = _to_numpy(image).squeeze()
    image = np.transpose(image, (1, 2, 0))
    # Map [-1, 1] to [0, 1] so imshow receives float pixels in its valid range.
    image = (image + 1) * 0.5

    output = _to_numpy(output).squeeze()
    if binarize:
        output = np.where(output > threshold, 1.0, 0.0)
    return image, output


def enlarge_array(output, grid_shape=(14, 14), scale=16):
    """Reshape scores to a 2-D grid and upsample it by pixel repetition.

    Parameters
    ----------
    output : array-like
        Scores with exactly ``grid_shape[0] * grid_shape[1]`` elements.
    grid_shape : tuple of int, optional
        2-D shape to reshape ``output`` into (default ``(14, 14)``).
    scale : int, optional
        Number of times each grid cell is repeated along both axes
        (default 16, giving a 224x224 result for the default grid).

    Returns
    -------
    numpy.ndarray
        Array of shape ``(grid_shape[0] * scale, grid_shape[1] * scale)``
        with the same dtype as ``output``.

    Raises
    ------
    ValueError
        If ``output`` cannot be reshaped to ``grid_shape``.
    """
    grid = np.reshape(output, grid_shape)
    # Nearest-neighbour upsampling: repeat rows, then columns.
    return np.repeat(np.repeat(grid, scale, axis=0), scale, axis=1)


def visualize_output(image, output, binarize, threshold):
    """Render ``image`` with the upsampled ``output`` mask overlaid and return the pixels.

    Parameters
    ----------
    image : tensor or array-like
        Input image, channels-first; see :func:`preprocess`.
    output : tensor or array-like
        Scores with 14 * 14 elements once squeezed; see :func:`enlarge_array`.
    binarize : bool
        Whether to threshold the scores to 0/1 before overlaying them.
    threshold : float
        Threshold used when ``binarize`` is true.

    Returns
    -------
    numpy.ndarray
        ``uint8`` array of shape ``(height, width, 3)`` holding the rendered
        RGB figure (a 6x6-inch figure at the canvas resolution).
    """
    image, output = preprocess(image, output, binarize, threshold)
    output = enlarge_array(output)
    # Scale scores to the 0..255 range before handing them to PIL.
    output_mask = Image.fromarray(output * 255)

    fig = plt.figure(figsize=(6, 6))
    try:
        plt.axis("off")
        plt.imshow(image)
        # The overlay is drawn the same way whether or not scores were binarized.
        plt.imshow(output_mask, alpha=0.45)
        fig.tight_layout(pad=0)
        fig.canvas.draw()
        # buffer_rgba() replaces tostring_rgb() (deprecated in Matplotlib 3.8,
        # removed in 3.10). It already has the physical pixel shape, so no
        # reshape via get_width_height() is needed; that reshape broke on
        # HiDPI canvases. Copy so the result outlives the closed figure.
        data = np.asarray(fig.canvas.buffer_rgba())[..., :3].copy()
    finally:
        # Release the figure; otherwise each call leaks an open pyplot figure.
        plt.close(fig)
    return data