Spaces:
Build error
Build error
| from matplotlib import pyplot as plt | |
| from matplotlib import gridspec | |
| import matplotlib.patches as mpatches | |
| import torch | |
| import numpy as np | |
| from PIL import Image | |
def get_cols():
    """Return the fixed palette of perceptually distinct colours.

    The palette is used to colour-code spatial factor plots. A new
    integer array with one ``[R, G, B]`` row (0-255) per colour is built
    on every call.
    """
    palette = [
        (255, 0, 0),
        (255, 255, 0),
        (0, 234, 255),
        (170, 0, 255),
        (255, 127, 0),
        (191, 255, 0),
        (0, 149, 255),
        (255, 0, 170),
        (255, 212, 0),
        (106, 255, 0),
        (0, 64, 255),
        (237, 185, 185),
        (185, 215, 237),
        (231, 233, 185),
        (220, 185, 237),
        (185, 237, 224),
        (143, 35, 35),
        (35, 98, 143),
        (143, 106, 35),
        (107, 35, 143),
        (79, 143, 35),
        (0, 0, 0),
        (115, 115, 115),
        (204, 204, 204),
    ]
    return np.array(palette)
def mapRange(value, inMin, inMax, outMin, outMax):
    """Linearly rescale ``value`` from ``[inMin, inMax]`` to ``[outMin, outMax]``.

    Parameters
    ----------
    value : number, np.ndarray or torch.Tensor
        Value(s) to rescale; arrays/tensors are rescaled element-wise.
    inMin, inMax : number or 0-d array/tensor
        Bounds of the input range.
    outMin, outMax : number
        Bounds of the output range.

    Returns
    -------
    Same kind of object as ``value``, rescaled. If the input range has zero
    width (``inMax == inMin``), every element maps to ``outMin`` instead of
    dividing by zero.
    """
    in_span = inMax - inMin
    # A zero-width input range (e.g. a constant mask) would otherwise divide
    # by zero: NaN for arrays/tensors, ZeroDivisionError for Python numbers.
    # The ndim guard keeps element-wise (array-valued) bounds on the original
    # code path, where a truth test would be ambiguous.
    if getattr(in_span, "ndim", 0) == 0 and in_span == 0:
        # Multiplying by 0.0 keeps the shape and container type of ``value``.
        return outMin + 0.0 * (value - inMin)
    return outMin + (((value - inMin) / in_span) * (outMax - outMin))
def plot_masks(Us, r, s, rs=256, save_path=None, title_factors=True):
    """
    Plot the first ``r`` parts factors side by side as greyscale images.

    Parameters
    ----------
    Us : torch.Tensor
        Learnt parts factor matrix. Row ``i`` is reshaped to ``s x s`` and is
        processed with ``torch.min``/``torch.max`` and ``.detach()``, so a
        torch tensor is required (presumably of shape ``(n_parts, s * s)`` --
        confirm against the caller).
    r : int
        Number of factors to show.
    s : int
        Side length of each (square) part; every row holds ``s * s`` values.
    rs : int
        Side length, in pixels, that each part is resized to for display.
    save_path : str or None
        If not None, the figure is saved to this path with ``plt.savefig``.
    title_factors : bool
        Print a matplotlib title ("Part i") on each part?
    """
    fig = plt.figure(constrained_layout=True, figsize=(20, 3))
    # NOTE(review): the grid has r + 1 columns but only r are filled, so the
    # last cell stays empty -- confirm this spacing is intentional.
    spec = gridspec.GridSpec(ncols=r + 1, nrows=1, figure=fig)

    for i in range(r):
        # add_subplot makes the new axes current, so the plt.* calls below
        # all act on this part's subplot.
        fig.add_subplot(spec[i])
        if title_factors:
            plt.title(f'Part {i}')

        # Stretch the part to the full 0-255 range before quantising to uint8.
        part = Us[i].reshape([s, s])
        part = mapRange(part, torch.min(part), torch.max(part), 0.0, 1.0) * 255
        part = part.detach().cpu().numpy()

        # Nearest-neighbour upscaling keeps the blocky cell structure visible.
        resized = Image.fromarray(np.uint8(part)).convert('RGBA').resize((rs, rs), Image.NEAREST)
        part = np.array(resized) / 255

        plt.axis('off')
        # ``part`` is RGBA, for which matplotlib ignores cmap/vmin/vmax, so
        # none are passed.
        plt.imshow(part, alpha=1.00)

    if save_path is not None:
        plt.savefig(save_path)
def plot_colours(image, Us, r, s, rs=128, save_path=None, alpha=1.0, seed=-1, legend=True):
    """
    Overlay the first ``r`` parts factors, colour-coded, on top of an image.

    Everything is drawn with ``plt.imshow`` onto the current matplotlib axes.

    Parameters
    ----------
    image : np.array
        Image to visualize (passed to ``PIL.Image.fromarray``).
    Us : torch.Tensor
        Learnt parts factor matrix. Row ``i`` is detached, moved to the CPU and
        reshaped to ``s x s``, so a torch tensor is required (presumably of
        shape ``(n_parts, s * s)`` -- confirm against the caller).
    r : int
        Number of factors to show.
    s : int
        Side length of each (square) part; every row holds ``s * s`` values.
    rs : int
        Side length, in pixels, that the image and masks are resized to.
    save_path : str or None
        If not None, the figure is saved to this path with ``plt.savefig``.
    alpha : float
        Alpha value for the masks.
    seed : int
        Random seed used to generate the colour palette. Use a negative value
        (default -1) to use the built-in "perceptually distinct" palette from
        ``get_cols``, which holds 24 colours and therefore limits ``r`` to 24.
    legend : bool
        Plot the legend, detailing the colour-coded parts key?

    Raises
    ------
    IndexError
        If the built-in palette is selected (``seed < 0``) and ``r`` exceeds
        the number of colours it provides.
    """
    img = Image.fromarray(image).resize((rs, rs)).convert('RGBA')

    # Use the perceptually distinct colour list, or a seeded random palette
    # (e.g. when there are more factors than built-in colours).
    if seed >= 0:
        # A private RandomState draws exactly the values that seeding the
        # global generator would, without clobbering the caller's RNG state.
        # randint's upper bound is exclusive, so channels lie in 0..254.
        cols = np.random.RandomState(seed).randint(0, 255, [r, 3])
    else:
        cols = get_cols()
        if r > len(cols):
            # Fail before drawing anything rather than with a bare index
            # error halfway through the overlay loop.
            raise IndexError(
                f'plot_colours: r={r} exceeds the {len(cols)} colours of the built-in palette; '
                f'pass seed >= 0 to generate a random palette instead'
            )

    plt.imshow(img, alpha=1.0)
    plt.axis('off')

    patches = []
    for i in range(r):
        colour = cols[i] / 255

        # Stretch the factor to 0-255 and resize it to the image resolution.
        mask = Us[i].detach().cpu().numpy().reshape([s, s])
        mask = mapRange(mask, np.min(mask), np.max(mask), 0, 255)
        mask = np.array(Image.fromarray(np.uint8(mask)).convert('L').resize((rs, rs)))

        # Build an RGBA overlay: the part's colour (and its opacity) scaled by
        # the mask intensity, so weak activations fade out.
        mask = (mask[:, :, None] / 255.) * np.concatenate([colour, [1]])

        patches.append(mpatches.Patch(color=colour, label=f'Part {i}'))
        # The overlay is RGBA, for which matplotlib ignores vmin/vmax.
        plt.imshow(mask, alpha=alpha)

    if legend:
        plt.legend(title='Spatial factors', handles=patches, bbox_to_anchor=(1.01, 1.01), loc="upper left")

    if save_path is not None:
        plt.savefig(save_path, bbox_inches='tight', pad_inches=0.0)