AutoLink / utils_ /visualization.py
xingzhehe's picture
try fitst commit
91fc62a
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
import torchvision
from matplotlib import colors
def get_part_color(n_parts):
    """Return one RGB colour per part, for colouring keypoints in plots.

    Colours are taken in order from a fixed palette of 32 named matplotlib
    colours and wrap around once the palette is exhausted, so any number of
    parts is supported. The previous hard-coded tuple (this same palette
    listed twice) raised ``IndexError`` for ``n_parts > 64``. For
    ``n_parts <= 64`` the colours are identical to that implementation.

    Args:
        n_parts: Number of colours to return.

    Returns:
        ``np.ndarray`` of shape ``(n_parts, 3)`` with float RGB values in
        [0, 1] (shape ``(0, 3)`` when ``n_parts`` is 0).
    """
    # NOTE: 'yellowgreen' appears twice (indices 9 and 15), so those parts share
    # a colour; kept as-is so existing plots keep the same colours.
    palette = ('red', 'blue', 'yellow', 'magenta', 'green', 'indigo', 'darkorange', 'cyan', 'pink', 'yellowgreen',
               'rosybrown', 'coral', 'chocolate', 'bisque', 'gold', 'yellowgreen', 'aquamarine', 'deepskyblue',
               'navy', 'orchid',
               'maroon', 'sienna', 'olive', 'lightgreen', 'teal', 'steelblue', 'slateblue', 'darkviolet',
               'fuchsia', 'crimson',
               'honeydew', 'thistle')
    # Wrap around the palette instead of indexing past its end.
    rgb = [colors.to_rgb(palette[i % len(palette)]) for i in range(n_parts)]
    return np.array(rgb, dtype=float).reshape(-1, 3)
def denormalize(img):
    """Map ``img`` back from the [-1, 1] range to [0, 1] and clip.

    Computes ``img * 0.5 + 0.5`` (the inverse of ``(x - 0.5) / 0.5``) and then
    clamps every element into [0, 1].

    Args:
        img: Tensor broadcastable against shape ``(1, 3, 1, 1)``, e.g.
            ``(N, 3, H, W)``.

    Returns:
        The rescaled, clamped tensor, computed on ``img``'s device.
    """
    # Mean and std are both 0.5 for every channel, so one tensor serves as both.
    half = torch.full((1, 3, 1, 1), 0.5, device=img.device)
    return img.mul(half).add(half).clamp(min=0, max=1)
def draw_matrix(mat):
    """Render ``mat`` as an annotated seaborn heatmap and return its pixels.

    Each cell is annotated with its value formatted as ``'.2f'`` and coloured
    with the ``"YlGnBu"`` colormap.

    Args:
        mat: 2-D array-like accepted by ``seaborn.heatmap``.

    Returns:
        ``np.ndarray`` of dtype uint8 and shape ``(H, W, 3)`` holding the RGB
        pixels of the rendered figure.
    """
    fig = plt.figure()
    try:
        sns.heatmap(mat, annot=True, fmt='.2f', cmap="YlGnBu")
        fig.canvas.draw()
        # ``buffer_rgba`` is sized by the real rendered pixel buffer. The old
        # ``tostring_rgb`` + ``get_width_height`` pair is deprecated (Matplotlib
        # 3.8) and mis-shapes the buffer when the device pixel ratio != 1.
        # ``.copy()`` detaches the result from the renderer's buffer.
        plot = np.asarray(fig.canvas.buffer_rgba())[..., :3].copy()
    finally:
        # Release the figure even if rendering fails, so figures don't pile up.
        plt.close(fig)
    return plot
def draw_kp_grid(img, kp):
    """Overlay keypoints on up to 64 images in an 8x8 grid; return the pixels.

    Args:
        img: Image batch, permuted with ``permute(0, 2, 3, 1)`` before display,
            so presumably ``(N, C, H, W)`` — TODO confirm. Shown with
            ``vmin=0, vmax=1``.
        kp: Keypoint tensor indexed as ``kp[i, k, :]``. Column 1 is plotted as
            x and column 0 as y over the image, so presumably ``(N, K, 2)`` in
            (row, col) pixel order — TODO confirm. ``K`` sets the colour count.

    Returns:
        ``np.ndarray`` of dtype uint8 and shape ``(H, W, 3)`` holding the RGB
        pixels of the rendered figure. Only the first 64 samples are drawn.
    """
    kp_color = get_part_color(kp.shape[1])
    img = img[:64].permute(0, 2, 3, 1).detach().cpu()
    kp = kp.detach().cpu()[:64]
    fig = plt.figure(figsize=(8, 8))
    try:
        gs = gridspec.GridSpec(8, 8)
        gs.update(wspace=0, hspace=0)
        for i, sample in enumerate(img):
            ax = fig.add_subplot(gs[i])
            ax.axis('off')  # hides ticks, tick labels and frame
            ax.imshow(sample, vmin=0, vmax=1)
            ax.scatter(kp[i, :, 1], kp[i, :, 0], c=kp_color, s=20, marker='+')
        fig.canvas.draw()
        # ``buffer_rgba`` replaces the deprecated ``tostring_rgb`` and is sized
        # by the real pixel buffer (correct for any device pixel ratio).
        plot = np.asarray(fig.canvas.buffer_rgba())[..., :3].copy()
    finally:
        # Release the figure even if plotting fails.
        plt.close(fig)
    return plot
def draw_kp_grid_unnorm(img, kp):
    """Overlay keypoints on up to 64 images in an 8x8 grid; return the pixels.

    Same as ``draw_kp_grid`` except that ``imshow`` is called without
    ``vmin``/``vmax``.

    Args:
        img: Image batch, permuted with ``permute(0, 2, 3, 1)`` before display,
            so presumably ``(N, C, H, W)`` — TODO confirm the expected value
            range for this "unnorm" variant.
        kp: Keypoint tensor indexed as ``kp[i, k, :]``. Column 1 is plotted as
            x and column 0 as y over the image, so presumably ``(N, K, 2)`` in
            (row, col) pixel order — TODO confirm. ``K`` sets the colour count.

    Returns:
        ``np.ndarray`` of dtype uint8 and shape ``(H, W, 3)`` holding the RGB
        pixels of the rendered figure. Only the first 64 samples are drawn.
    """
    kp_color = get_part_color(kp.shape[1])
    img = img[:64].permute(0, 2, 3, 1).detach().cpu()
    kp = kp.detach().cpu()[:64]
    fig = plt.figure(figsize=(8, 8))
    try:
        gs = gridspec.GridSpec(8, 8)
        gs.update(wspace=0, hspace=0)
        for i, sample in enumerate(img):
            ax = fig.add_subplot(gs[i])
            ax.axis('off')  # hides ticks, tick labels and frame
            ax.imshow(sample)
            ax.scatter(kp[i, :, 1], kp[i, :, 0], c=kp_color, s=20, marker='+')
        fig.canvas.draw()
        # ``buffer_rgba`` replaces the deprecated ``tostring_rgb`` and is sized
        # by the real pixel buffer (correct for any device pixel ratio).
        plot = np.asarray(fig.canvas.buffer_rgba())[..., :3].copy()
    finally:
        # Release the figure even if plotting fails.
        plt.close(fig)
    return plot
def draw_img_grid(img):
    """Tile up to 64 images into a grid (at most 8 per row) as a uint8 array.

    Args:
        img: Image batch accepted by ``torchvision.utils.make_grid``,
            presumably ``(N, C, H, W)`` with values in [0, 1] — TODO confirm.

    Returns:
        Channels-last ``np.ndarray`` of dtype uint8: the grid multiplied by
        255, clamped to [0, 255] and truncated (not rounded) to integers.
    """
    batch = img.detach().cpu()[:64]
    per_row = 8 if batch.shape[0] > 8 else batch.shape[0]
    grid = torchvision.utils.make_grid(batch, nrow=per_row)
    scaled = (grid.permute(1, 2, 0) * 255).clamp(min=0, max=255)
    return scaled.numpy().astype(np.uint8)