| import numpy as np | |
| import pandas as pd | |
| import matplotlib.pyplot as plt | |
| from PIL import Image | |
def preprocess(image, output, binarize, threshold):
    """Turn an image tensor and its prediction tensor into plottable NumPy arrays.

    Both inputs are moved to the CPU, detached, converted with ``.numpy()``
    and have all singleton dimensions squeezed away.

    Parameters
    ----------
    image
        Tensor-like object exposing ``.cpu().detach().numpy()``. After
        squeezing it must be 3-D in channel-first order; it is returned
        channel-last and mapped through ``(x + 1) * 0.5``. That maps
        [-1, 1] to [0, 1], assuming the input was normalized to that range
        (TODO confirm against the data pipeline).
    output
        Tensor-like object exposing ``.cpu().detach().numpy()``.
    binarize
        If truthy, ``output`` becomes 1.0 where it is strictly greater than
        ``threshold`` and 0.0 everywhere else.
    threshold
        Cut-off used only when ``binarize`` is truthy.

    Returns
    -------
    tuple
        ``(image, output)`` as NumPy arrays.
    """
    def _as_array(tensor):
        # Drop device placement and autograd tracking before handing to NumPy.
        return tensor.cpu().detach().numpy().squeeze()

    chw = _as_array(image)
    hwc = chw.transpose(1, 2, 0)  # channel-first -> channel-last for imshow
    rescaled = (hwc + 1) * 0.5

    mask = _as_array(output)
    if binarize:
        mask = np.where(mask > threshold, 1., 0.)
    return rescaled, mask
def enlarge_array(output, grid_size=14, scale=16):
    """Upsample a square grid of values into a larger array by block repetition.

    The input is reshaped to ``(grid_size, grid_size)``. Every cell is then
    repeated ``scale`` times along rows and ``scale`` times along columns
    (nearest-neighbour upsampling), giving an array of shape
    ``(grid_size * scale, grid_size * scale)``. With the defaults this turns
    196 values into a 224x224 array. Presumably that is one value per 16x16
    patch of a 224x224 input image; confirm against the caller.

    Parameters
    ----------
    output : array_like
        Values to upsample. Must contain exactly ``grid_size ** 2`` elements;
        any shape that reshapes to ``(grid_size, grid_size)`` is accepted.
    grid_size : int, optional
        Side length of the square grid (default 14, the previously
        hard-coded value).
    scale : int, optional
        Repetition factor per axis (default 16, the previously hard-coded
        value).

    Returns
    -------
    numpy.ndarray
        Array of shape ``(grid_size * scale, grid_size * scale)`` with the
        same dtype as the reshaped input.

    Raises
    ------
    ValueError
        If ``output`` cannot be reshaped to ``(grid_size, grid_size)``.
    """
    grid = np.reshape(output, (grid_size, grid_size))
    # Plain np.repeat does the job; the former pandas DataFrame round-trips
    # only added copies and a needless dependency.
    return np.repeat(np.repeat(grid, scale, axis=0), scale, axis=1)
def visualize_output(image, output, binarize, threshold):
    """Render ``image`` with the (upsampled) prediction mask overlaid.

    The inputs go through ``preprocess`` and the prediction through
    ``enlarge_array``. The image is drawn on a 6x6-inch figure with the axes
    hidden, the mask is drawn over it at alpha 0.45, and the rendered canvas
    is returned as pixels. The figure is always closed before returning.

    Parameters
    ----------
    image, output
        Tensor-like objects accepted by ``preprocess``.
    binarize : bool
        Forwarded to ``preprocess``: threshold the mask into 0/1 values.
    threshold : float
        Forwarded to ``preprocess``.

    Returns
    -------
    numpy.ndarray
        ``uint8`` array of shape ``(height, width, 3)`` with the RGB pixels
        of the rendered figure.
    """
    image, output = preprocess(image, output, binarize, threshold)
    output = enlarge_array(output)
    # Scale mask values by 255 before building the PIL image, as before.
    output_mask = Image.fromarray(output * 255)

    fig = plt.figure(figsize=(6, 6))
    try:
        ax = fig.add_subplot(1, 1, 1)
        ax.axis('off')
        ax.imshow(image)
        # The binarized and non-binarized branches used to draw the mask
        # identically, so a single call covers both cases.
        ax.imshow(output_mask, alpha=.45)
        fig.tight_layout(pad=0)
        fig.canvas.draw()
        # tostring_rgb() is deprecated (Matplotlib 3.8) and removed (3.10).
        # buffer_rgba() exposes the rendered buffer with its own
        # (height, width, 4) shape, so no manual reshape is needed. Drop alpha
        # and copy so the result outlives the figure.
        rgba = np.asarray(fig.canvas.buffer_rgba())
        data = rgba[..., :3].copy()
    finally:
        # pyplot keeps every figure registered until closed; without this
        # each call leaks one figure.
        plt.close(fig)
    return data