|
import numpy as np |
|
import pandas as pd |
|
import matplotlib.pyplot as plt |
|
from PIL import Image |
|
|
|
|
|
def preprocess(image, output, binarize, threshold):
    """Turn a model input/output tensor pair into NumPy arrays for plotting.

    Both arguments are moved to the CPU, detached, converted to NumPy and
    squeezed.  The image is then transposed with axes ``(1, 2, 0)`` and
    mapped through ``(x + 1) * 0.5``.

    Args:
        image: Tensor-like object exposing ``.cpu().detach().numpy()``.
        output: Tensor-like object exposing ``.cpu().detach().numpy()``.
        binarize: When truthy, the output is replaced by ``1.0`` where it is
            strictly greater than ``threshold`` and ``0.0`` elsewhere.
        threshold: Cut-off used for binarization; unused when ``binarize``
            is falsy.

    Returns:
        A ``(image, output)`` tuple of NumPy arrays.
    """
    # NOTE(review): assumes the squeezed image is channels-first (C, H, W)
    # with values in [-1, 1] -- confirm against the caller.
    channels_first = image.cpu().detach().numpy().squeeze()
    channels_last = np.transpose(channels_first, (1, 2, 0))
    rescaled = (channels_last + 1) * 0.5

    scores = output.cpu().detach().numpy().squeeze()
    if not binarize:
        return rescaled, scores

    return rescaled, np.where(scores > threshold, 1., 0.)
|
|
|
|
|
def enlarge_array(output, grid_shape=(14, 14), scale=16):
    """Upsample a flat grid of values by repeating every cell along both axes.

    ``output`` is reshaped to ``grid_shape`` and each cell is then repeated
    ``scale`` times along the rows and ``scale`` times along the columns.
    With the defaults a 196-element input becomes a 224 x 224 array.

    Args:
        output: Array-like holding ``grid_shape[0] * grid_shape[1]`` values.
        grid_shape: ``(rows, cols)`` the values are arranged into before
            enlarging.  Defaults to ``(14, 14)``.
        scale: Number of times each cell is repeated along each axis.
            Defaults to ``16``.

    Returns:
        A 2-D ``numpy.ndarray`` of shape
        ``(grid_shape[0] * scale, grid_shape[1] * scale)``.

    Raises:
        ValueError: Propagated from ``numpy.reshape`` when ``output`` does
            not contain exactly ``grid_shape[0] * grid_shape[1]`` elements.
    """
    grid = np.reshape(output, grid_shape)
    # np.repeat works on the ndarray directly; wrapping each intermediate
    # result in a pandas DataFrame only added copies.
    stretched_rows = np.repeat(grid, scale, axis=0)
    return np.repeat(stretched_rows, scale, axis=1)
|
|
|
|
|
def visualize_output(image, output, binarize, threshold):
    """Render ``image`` with the enlarged output map overlaid; return RGB pixels.

    The inputs are converted with ``preprocess``, the output map is enlarged
    with ``enlarge_array``, scaled by 255 and wrapped in a PIL image, and the
    result is drawn semi-transparently (``alpha=.45``) on top of the image in
    a 6 x 6 inch figure with the axes hidden.

    Args:
        image: Input image tensor, forwarded to ``preprocess``.
        output: Model output tensor, forwarded to ``preprocess``.
            NOTE(review): ``enlarge_array`` is called with its defaults, so
            this presumably has to hold 14 * 14 values -- confirm.
        binarize: Forwarded to ``preprocess``; the overlay itself is drawn
            the same way for either value.
        threshold: Forwarded to ``preprocess``.

    Returns:
        A ``uint8`` ``numpy.ndarray`` of shape ``(height, width, 3)`` holding
        the rendered figure's RGB pixels.
    """
    image, output = preprocess(image, output, binarize, threshold)
    output = enlarge_array(output)
    output_mask = Image.fromarray(output * 255)

    fig = plt.figure(figsize=(6, 6))
    try:
        # Draw on an explicit Axes so the result does not depend on which
        # figure pyplot currently considers "current".
        ax = fig.add_subplot()
        ax.axis('off')
        ax.imshow(image)
        # The original had identical binarize / non-binarize branches here.
        ax.imshow(output_mask, alpha=.45)
        fig.tight_layout(pad=0)
        fig.canvas.draw()

        # tostring_rgb() was deprecated in Matplotlib 3.8 and later removed;
        # buffer_rgba() exposes the rendered (height, width, 4) buffer.
        # NOTE(review): requires an Agg-based canvas, as tostring_rgb() did.
        rgba = np.asarray(fig.canvas.buffer_rgba())
        # Drop alpha and copy so the result outlives the canvas buffer.
        data = rgba[..., :3].copy()
    finally:
        # pyplot keeps a reference to every figure it creates; without an
        # explicit close each call would leak one.
        plt.close(fig)

    return data