PandA / utils.py
james-oldfield's picture
Upload 18 files
ceb80dd
from matplotlib import pyplot as plt
from matplotlib import gridspec
import matplotlib.patches as mpatches
import torch
import numpy as np
from PIL import Image
def get_cols():
    """
    Return the built-in colour palette used for spatial factor plots.

    Returns
    -------
    np.array
        Integer array of shape (24, 3); each row is one perceptually
        distinct colour given as 0-255 RGB components.
    """
    palette = [
        (255, 0, 0), (255, 255, 0), (0, 234, 255), (170, 0, 255),
        (255, 127, 0), (191, 255, 0), (0, 149, 255), (255, 0, 170),
        (255, 212, 0), (106, 255, 0), (0, 64, 255), (237, 185, 185),
        (185, 215, 237), (231, 233, 185), (220, 185, 237), (185, 237, 224),
        (143, 35, 35), (35, 98, 143), (143, 106, 35), (107, 35, 143),
        (79, 143, 35), (0, 0, 0), (115, 115, 115), (204, 204, 204),
    ]
    return np.array(palette)
def mapRange(value, inMin, inMax, outMin, outMax):
    """
    Linearly rescale `value` from the range [inMin, inMax] to [outMin, outMax].

    Values outside the input range are extrapolated, not clamped. Only
    arithmetic operators are used, so anything supporting them (scalars,
    arrays) works element-wise. `inMin` must differ from `inMax`.
    """
    # Position of `value` inside the input range, as a fraction.
    fraction = (value - inMin) / (inMax - inMin)
    out_span = outMax - outMin
    return outMin + (fraction * out_span)
def plot_masks(Us, r, s, rs=256, save_path=None, title_factors=True):
    """
    Plots the parts factors with matplotlib for visualization

    Each factor is min-max normalised on its own, so brightness is relative
    within one mask and is not comparable between masks.

    Parameters
    ----------
    Us : torch.Tensor
        Learnt parts factor matrix (or any indexable of tensors). Each
        ``Us[i]`` must hold ``s * s`` values, as it is reshaped to ``[s, s]``.
    r : int
        Number of factors to show.
    s : int
        Side length of each (square) part before resizing.
    rs : int
        Side length, in pixels, that each mask is resized to for display.
    save_path : str or None
        If not None, the figure is saved to this path with ``plt.savefig``.
    title_factors : bool
        Print matplotlib title on each part?
    """
    fig = plt.figure(constrained_layout=True, figsize=(20, 3))
    # The grid has r + 1 columns but only the first r are filled below,
    # leaving the last cell empty.
    # NOTE(review): the spare column looks like a leftover - confirm before removing.
    spec = gridspec.GridSpec(ncols=r + 1, nrows=1, figure=fig)

    for i in range(r):
        fig.add_subplot(spec[i])
        if title_factors:
            plt.title(f'Part {i}')

        part = Us[i].reshape([s, s])
        lo, hi = torch.min(part), torch.max(part)
        if hi > lo:
            part = mapRange(part, lo, hi, 0.0, 1.0) * 255
        else:
            # A constant factor has zero range: normalising it would divide
            # by zero and push NaNs into the uint8 cast. Show it as blank.
            part = torch.zeros_like(part)
        part = part.detach().cpu().numpy()

        # Nearest-neighbour resize keeps the hard cell boundaries of the mask.
        resized = Image.fromarray(np.uint8(part)).convert('RGBA').resize((rs, rs), Image.NEAREST)
        part = np.array(resized) / 255

        plt.axis('off')
        # `part` is RGBA, for which imshow ignores cmap/vmin/vmax, so none are passed.
        plt.imshow(part, alpha=1.0)

    if save_path is not None:
        plt.savefig(save_path)
def plot_colours(image, Us, r, s, rs=128, save_path=None, alpha=1.0, seed=-1, legend=True):
    """
    Plots the parts factors over an image with matplotlib for visualization

    The image is drawn first, then one colour-coded overlay per factor is
    drawn on top of it on the current matplotlib axes. Each factor is min-max
    normalised on its own and that normalised value is used as the overlay's
    per-pixel opacity.

    Parameters
    ----------
    image : np.array
        Image to visualize (passed to ``PIL.Image.fromarray``).
    Us : torch.Tensor
        Learnt parts factor matrix (or any indexable of tensors). Each
        ``Us[i]`` must hold ``s * s`` values, as it is reshaped to ``[s, s]``.
    r : int
        Number of factors to show.
    s : int
        Side length of each (square) part before resizing.
    rs : int
        Side length, in pixels, that the image and masks are resized to.
    save_path : str or None
        If not None, the figure is saved to this path with ``plt.savefig``.
    alpha : float
        Alpha value for the masks.
    seed : int
        Random seed when generating the colour palette. Any negative value
        (e.g. the default -1) selects the built-in "perceptually distinct"
        palette, which holds 24 colours; use a seed >= 0 for more factors.
    legend : bool
        Plot the legend, detailing the colour-coded parts key?

    Raises
    ------
    IndexError
        If the built-in palette is selected and `r` exceeds its size.
    """
    img = Image.fromarray(image).resize((rs, rs)).convert('RGBA')

    # Use perceptually distinct colour list, or random seed (for e.g. if you have too many factors)
    if seed >= 0:
        # A private RandomState yields the same colours as seeding the global
        # generator would, without disturbing the caller's global RNG state.
        cols = np.random.RandomState(seed).randint(0, 255, [r, 3])
    else:
        cols = get_cols()
        if r > len(cols):
            # Fail before anything is drawn instead of part-way through the loop.
            raise IndexError(
                f'The built-in palette has only {len(cols)} colours but r={r} '
                'factors were requested; pass seed >= 0 to generate a random palette.'
            )

    plt.imshow(img, alpha=1.0)
    plt.axis('off')

    patches = []
    for i in range(r):
        colour = cols[i] / 255

        mask = Us[i].detach().cpu().numpy().reshape([s, s])
        lo, hi = np.min(mask), np.max(mask)
        if hi > lo:
            mask = mapRange(mask, lo, hi, 0, 255)
        else:
            # A constant factor has zero range: normalising it would divide
            # by zero and push NaNs into the uint8 cast. Make it transparent.
            mask = np.zeros_like(mask)
        mask = np.uint8(mask)
        mask = np.array(Image.fromarray(mask).convert('L').resize((rs, rs)))

        # Broadcast to RGBA: colour channels are scaled by the mask and the
        # alpha channel is the mask itself, so weak activations fade out.
        rgba = np.array(np.concatenate([colour, [1]]))
        mask = (mask[:, :, None] / 255.) * rgba

        patches += [mpatches.Patch(color=colour, label=f'Part {i}')]
        plt.imshow(mask, vmin=0, vmax=1, alpha=alpha)

    if legend:
        plt.legend(title='Spatial factors', handles=patches, bbox_to_anchor=(1.01, 1.01), loc="upper left")

    if save_path is not None:
        plt.savefig(save_path, bbox_inches='tight', pad_inches=0.0)