File size: 3,754 Bytes
ceb80dd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
from matplotlib import pyplot as plt
from matplotlib import gridspec
import matplotlib.patches as mpatches
import torch
import numpy as np
from PIL import Image


def get_cols():
    """
    Return the fixed palette of perceptually distinct colours used for the
    spatial factor plots.

    Returns
    -------
    np.array
        Integer array of shape (24, 3); each row is an RGB triple in 0-255.
        A new array is built on every call, so callers may modify it freely.
    """
    palette = [
        # saturated hues
        [255, 0, 0], [255, 255, 0], [0, 234, 255], [170, 0, 255],
        [255, 127, 0], [191, 255, 0], [0, 149, 255], [255, 0, 170],
        [255, 212, 0], [106, 255, 0], [0, 64, 255],
        # pastels
        [237, 185, 185], [185, 215, 237], [231, 233, 185], [220, 185, 237],
        [185, 237, 224],
        # dark tones
        [143, 35, 35], [35, 98, 143], [143, 106, 35], [107, 35, 143],
        [79, 143, 35],
        # black and greys
        [0, 0, 0], [115, 115, 115], [204, 204, 204],
    ]
    return np.array(palette)


def mapRange(value, inMin, inMax, outMin, outMax):
    """
    Linearly rescale ``value`` from the interval [inMin, inMax] onto
    [outMin, outMax].

    Works element-wise on anything supporting ``-``, ``/``, ``*`` and ``+``
    (Python numbers, NumPy arrays, torch tensors). No clamping is applied,
    and ``inMin == inMax`` is not guarded: it divides by zero.
    """
    # Relative position of ``value`` inside the input interval.
    fraction = (value - inMin) / (inMax - inMin)
    return outMin + fraction * (outMax - outMin)


def plot_masks(Us, r, s, rs=256, save_path=None, title_factors=True):
    """
    Plots the parts factors with matplotlib for visualization

    Each of the first ``r`` factors is reshaped to an ``s x s`` map, min-max
    normalised to [0, 255], resized to ``rs x rs`` with nearest-neighbour
    resampling and drawn as a greyscale image in its own column of a single
    figure. A constant factor (max == min) is drawn black instead of
    producing NaNs.

    Parameters
    ----------
    Us : torch.Tensor or np.array
        Learnt parts factor matrix; ``Us[i]`` must hold ``s * s`` values.
    r : int
        Number of factors to show.
    s : int
        Side length of each part (each part is reshaped to ``s x s``).
    rs : int
        Side length, in pixels, that each part is resized to before plotting.
    save_path : str or None
        If given, the figure is saved to this path via ``plt.savefig``.
    title_factors : bool
        Print matplotlib title on each part?
    """

    fig = plt.figure(constrained_layout=True, figsize=(20, 3))
    # NOTE(review): r + 1 columns leaves the last grid slot empty; kept to
    # preserve the existing layout -- confirm whether this is intended.
    spec = gridspec.GridSpec(ncols=r + 1, nrows=1, figure=fig)

    for i in range(r):
        fig.add_subplot(spec[i])

        if title_factors:
            plt.title(f'Part {i}')

        # Leave torch (and any autograd graph) before doing the arithmetic.
        part = Us[i]
        if torch.is_tensor(part):
            part = part.detach().cpu().numpy()
        part = np.asarray(part).reshape([s, s])

        lo, hi = part.min(), part.max()
        if hi > lo:
            part = mapRange(part, lo, hi, 0.0, 1.0) * 255
        else:
            # Zero range would divide by zero and feed NaN into np.uint8.
            part = np.zeros((s, s))

        part = np.array(
            Image.fromarray(np.uint8(part)).convert('RGBA').resize((rs, rs), Image.NEAREST)
        ) / 255

        plt.axis('off')
        # ``part`` is RGBA, so matplotlib draws it as-is (cmap/vmin/vmax would be ignored).
        plt.imshow(part, alpha=1.00)

    if save_path is not None:
        plt.savefig(save_path)


def plot_colours(image, Us, r, s, rs=128, save_path=None, alpha=1.0, seed=-1, legend=True):
    """
    Plots the parts factors over an image with matplotlib for visualization

    The image is resized to ``rs x rs`` and drawn first. Each of the first
    ``r`` factors is then reshaped to ``s x s``, min-max normalised, resized
    to ``rs x rs`` and overlaid as an RGBA layer. The layer's colour is the
    part's palette colour and its opacity is the normalised factor value.
    A constant factor (max == min) yields a fully transparent layer instead
    of NaNs.

    Parameters
    ----------
    image : np.array
        Image to visualize (anything accepted by ``PIL.Image.fromarray``).
    Us : torch.Tensor or np.array
        Learnt parts factor matrix; ``Us[i]`` must hold ``s * s`` values.
    r : int
        Number of factors to show.
    s : int
        Side length of each part (each part is reshaped to ``s x s``).
    rs : int
        Side length, in pixels, that the image and the masks are resized to.
    save_path : str or None
        If given, the figure is saved to this path via ``plt.savefig``.
    alpha : float
        Alpha value for the masks.
    seed : int
        Random seed when generating the colour palette. Use -1 to use the
        provided "perceptually distinct" palette, which has 24 colours. The
        seed only drives a private generator; NumPy's global random state is
        left untouched.
    legend : bool
        Plot the legend, detailing the colour-coded parts key?

    Raises
    ------
    ValueError
        If ``seed < 0`` and ``r`` exceeds the size of the fixed palette.
    """

    img = Image.fromarray(image).resize((rs, rs)).convert('RGBA')

    # Use perceptually distinct colour list, or random seed (for e.g. if you have too many factors)
    if seed >= 0:
        # A private RandomState yields the same colours as seeding the global
        # generator would, without clobbering the caller's np.random state.
        cols = np.random.RandomState(seed).randint(0, 255, [r, 3])
    else:
        cols = get_cols()
        if r > len(cols):
            raise ValueError(
                f'The fixed palette has only {len(cols)} colours but r={r}; '
                'pass seed >= 0 to use a random palette instead.'
            )

    plt.imshow(img, alpha=1.0)
    plt.axis('off')

    patches = []
    for i in range(r):
        colour = cols[i] / 255

        mask = Us[i]
        if torch.is_tensor(mask):
            mask = mask.detach().cpu().numpy()
        mask = np.asarray(mask).reshape([s, s])

        lo, hi = np.min(mask), np.max(mask)
        if hi > lo:
            mask = mapRange(mask, lo, hi, 0, 255)
        else:
            # Zero range would divide by zero and feed NaN into np.uint8.
            mask = np.zeros((s, s))
        mask = np.uint8(mask)
        mask = np.array(Image.fromarray(mask).convert('L').resize((rs, rs)))
        # Build an RGBA layer: the RGB channels are the part colour and the
        # alpha channel is 1. All four are scaled by the mask intensity.
        mask = (mask[:, :, None] / 255.) * np.concatenate([colour, [1]])

        patches.append(mpatches.Patch(color=colour, label=f'Part {i}'))

        # RGBA input: matplotlib ignores vmin/vmax, so they are not passed.
        plt.imshow(mask, alpha=alpha)

    if legend:
        plt.legend(title='Spatial factors', handles=patches, bbox_to_anchor=(1.01, 1.01), loc="upper left")

    if save_path is not None:
        plt.savefig(save_path, bbox_inches='tight', pad_inches=0.0)