File size: 7,725 Bytes
4c022fe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b5268ad
 
4c022fe
 
 
 
 
ab7db7f
 
 
 
4c022fe
 
ab7db7f
4c022fe
 
ab7db7f
4c022fe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ab7db7f
 
 
f6e53b7
ab7db7f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
import numpy as np
import os
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torchvision

from pathlib import Path


def split_attention_maps_over_steps(attention_maps):
    r"""Regroup per-layer attention maps by step and split the two score branches.
    Args:
        attention_maps (dict): Maps a layer name to a sequence with one entry
            per step. Each entry is sliced along its leading axis: ``[:1]``
            is stored as the unconditional part and ``[1:2]`` as the
            conditional part.
    Returns:
        tuple: ``(attention_maps_cond, attention_maps_uncond)``. Both are
        dicts keyed by step index whose values are dicts keyed by layer name.
    """
    cond_by_step = {}    # step index -> {layer name: conditional slice}
    uncond_by_step = {}  # step index -> {layer name: unconditional slice}

    for layer_name, maps_per_step in attention_maps.items():
        for step_idx, step_map in enumerate(maps_per_step):
            # Slicing (rather than indexing) keeps the leading axis in place.
            uncond_by_step.setdefault(step_idx, {})[layer_name] = step_map[:1]
            cond_by_step.setdefault(step_idx, {})[layer_name] = step_map[1:2]

    return cond_by_step, uncond_by_step


def plot_attention_maps(atten_map_list, obj_tokens, save_dir, seed, tokens_vis=None):
    r"""Draw rows of attention heatmaps and return the last figure as an RGB image.
    One figure is drawn per group in ``atten_map_list``: one heatmap panel per
    map in the group, all sharing a colour scale, followed by a narrow
    colorbar axis. Every figure is closed after it has been rasterised.
    Args:
        atten_map_list (list): Groups of attention maps. The 2-D array that is
            plotted for panel ``tid`` is ``group[tid][0]`` (presumably each
            map is a tensor of shape [1, H, W] -- confirm against the caller).
        obj_tokens (list): One sequence of token indices per panel. Each index
            must support ``.item()``; it is only read when ``tokens_vis`` is
            given.
        save_dir: Unused here; kept so existing callers keep working.
        seed: Unused here; kept so existing callers keep working.
        tokens_vis (list, optional): Token strings used to title the panels.
            Index ``token_id - 1`` is looked up for every token of a panel.
    Returns:
        numpy.ndarray or None: ``uint8`` array of shape (rows, cols, 3) holding
        the last figure that was drawn, or ``None`` if no figure was drawn.
    """
    cmap = plt.get_cmap('OrRd')
    word_end = '</w>'
    img = None

    # NOTE(review): zipping with obj_tokens caps the number of figures at
    # len(obj_tokens). That looks accidental, but it is kept so the same
    # figure is returned as before.
    for attn_map, _ in zip(atten_map_list, obj_tokens):
        n_obj = len(attn_map)

        fig, axs = plt.subplots(
            ncols=n_obj+1, gridspec_kw=dict(width_ratios=[1]*n_obj+[0.1]))
        try:
            fig.set_figheight(3)
            fig.set_figwidth(3*n_obj+0.1)

            # Shared colour range for all panels; it always spans at least
            # [vmin <= 1, vmax >= 0] because of the starting values.
            vmax = 0
            vmin = 1
            for tid in range(n_obj):
                attention_map_cur = attn_map[tid]
                vmax = max(vmax, float(attention_map_cur.max()))
                vmin = min(vmin, float(attention_map_cur.min()))

            for tid in range(n_obj):
                sns.heatmap(
                    attn_map[tid][0], annot=False, cbar=False, ax=axs[tid],
                    cmap=cmap, vmin=vmin, vmax=vmax
                )
                axs[tid].set_axis_off()
                if tokens_vis is not None:
                    if tid == n_obj-1:
                        # The last panel is labelled as the catch-all map.
                        axs_xlabel = 'other tokens'
                    else:
                        axs_xlabel = ''
                        for token_id in obj_tokens[tid]:
                            token = tokens_vis[token_id.item() - 1]
                            # Strip the end-of-word marker only when it is
                            # actually there, so other tokens stay intact.
                            if token.endswith(word_end):
                                token = token[:-len(word_end)]
                            axs_xlabel += ' ' + token
                    axs[tid].set_title(axs_xlabel)

            norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
            sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
            fig.colorbar(sm, cax=axs[-1])

            # The layout must be applied before drawing to affect the image.
            fig.tight_layout()
            canvas = fig.canvas
            canvas.draw()
            # buffer_rgba() replaces the removed tostring_rgb(); drop alpha and
            # copy so the array does not reference the closed figure's buffer.
            img = np.asarray(canvas.buffer_rgba())[..., :3].copy()
        finally:
            # Release the figure so repeated calls do not accumulate figures.
            plt.close(fig)
    return img


def get_token_maps(attention_maps, save_dir, width, height, obj_tokens, seed=0, tokens_vis=None):
    r"""Build one spatial mask per token group from recorded attention maps.
    Conditional attention maps from the selected layers, for steps 10 and
    later, are reduced to one map per entry of ``obj_tokens``, resized to
    ``(height, width)``, averaged, and turned into masks with a
    low-temperature softmax across the token groups.
    Args:
        attention_maps (dict): Maps a layer name to a sequence with one entry
            per step; passed to ``split_attention_maps_over_steps``. After the
            split each entry must be 3-D, unpacked as ``(bs, hw, nclip)``.
        save_dir: Forwarded to ``plot_attention_maps``.
        width (int): Target map width; ``width * height`` is compared with a
            layer's ``hw`` to derive that layer's downsampling ratio.
        height (int): Target map height.
        obj_tokens (list): One index tensor per token group, used to index the
            last axis of the attention maps. A group whose first element is
            ``-1`` receives the complement of the summed maps of the groups
            before it, so it has to come after them in the list.
        seed (int): Forwarded to ``plot_attention_maps``.
        tokens_vis (list, optional): Forwarded to ``plot_attention_maps``.
    Returns:
        tuple: ``(masks, token_maps_vis)``. ``masks`` is a list with one
        tensor per row of the concatenated averaged maps (one per token group
        when ``bs`` is 1). Each has a new axis inserted at position 1 and
        repeated 4 times (presumably matching the latent channel count --
        confirm with the caller). Masks are on CUDA when it is available and
        on CPU otherwise. ``token_maps_vis`` is the image returned by
        ``plot_attention_maps``.
    Raises:
        RuntimeError: If no attention map was collected for some token group
            (for example fewer than 11 steps or no selected layer present).
    """

    # Split attention maps over steps; only the conditional branch is used.
    attention_maps_cond, _ = split_attention_maps_over_steps(
        attention_maps
    )

    selected_layers = [
        # 'down_blocks.0.attentions.0.transformer_blocks.0.attn2',
        # 'down_blocks.0.attentions.1.transformer_blocks.0.attn2',
        'down_blocks.1.attentions.0.transformer_blocks.0.attn2',
        # 'down_blocks.1.attentions.1.transformer_blocks.0.attn2',
        'down_blocks.2.attentions.0.transformer_blocks.0.attn2',
        'down_blocks.2.attentions.1.transformer_blocks.0.attn2',
        'mid_block.attentions.0.transformer_blocks.0.attn2',
        'up_blocks.1.attentions.0.transformer_blocks.0.attn2',
        'up_blocks.1.attentions.1.transformer_blocks.0.attn2',
        'up_blocks.1.attentions.2.transformer_blocks.0.attn2',
        # 'up_blocks.2.attentions.0.transformer_blocks.0.attn2',
        'up_blocks.2.attentions.1.transformer_blocks.0.attn2',
        # 'up_blocks.2.attentions.2.transformer_blocks.0.attn2',
        # 'up_blocks.3.attentions.0.transformer_blocks.0.attn2',
        # 'up_blocks.3.attentions.1.transformer_blocks.0.attn2',
        # 'up_blocks.3.attentions.2.transformer_blocks.0.attn2'
    ]

    first_step = 10      # steps before this index are ignored
    temperature = 0.001  # softmax temperature used to sharpen the masks

    nsteps = len(attention_maps_cond)
    hw_ori = width * height

    # One growing list of resized maps per token group.
    per_object_maps = [[] for _ in obj_tokens]

    for step_num in range(first_step, nsteps):
        for layer, layer_map in attention_maps_cond[step_num].items():
            if layer not in selected_layers:
                continue

            attention_ind = layer_map.cpu()

            # Recover the 2-D spatial layout of this layer from the flattened
            # key axis, using the ratio between target and layer resolution.
            bs, hw, nclip = attention_ind.shape
            down_ratio = np.sqrt(hw_ori // hw)
            width_cur = int(width // down_ratio)
            height_cur = int(height // down_ratio)
            attention_ind = attention_ind.reshape(
                bs, height_cur, width_cur, nclip)

            for obj_id, obj_token in enumerate(obj_tokens):
                if obj_token[0] == -1:
                    # Complement of everything the earlier groups attend to,
                    # built from the maps they just appended for this layer.
                    attention_map_prev = torch.stack(
                        [per_object_maps[i][-1] for i in range(obj_id)]).sum(0)
                    per_object_maps[obj_id].append(
                        attention_map_prev.max()-attention_map_prev)
                else:
                    # Max over the group's tokens, then move the reduced axis
                    # to the front: [bs, h, w, 1] -> [1, bs, h, w].
                    obj_attention_map = attention_ind[:, :, :, obj_token].max(-1, True)[
                        0].permute([3, 0, 1, 2])
                    obj_attention_map = torchvision.transforms.functional.resize(
                        obj_attention_map, (height, width),
                        interpolation=torchvision.transforms.InterpolationMode.BICUBIC,
                        antialias=True)
                    per_object_maps[obj_id].append(obj_attention_map)

    if not per_object_maps or any(len(maps) == 0 for maps in per_object_maps):
        raise RuntimeError(
            'No attention maps were collected: need at least one token group '
            'and maps from a selected layer at step %d or later.' % first_step)

    # Average attention maps over steps and layers.
    attention_maps_averaged = [
        torch.cat(maps).mean(0) for maps in per_object_maps]

    # Low-temperature softmax across groups.
    competed = (torch.cat(attention_maps_averaged)/temperature).softmax(0)
    attention_maps_averaged_normalized = [
        competed[i:i+1] for i in range(competed.shape[0])]

    token_maps_vis = plot_attention_maps([attention_maps_averaged, attention_maps_averaged_normalized],
                        obj_tokens, save_dir, seed, tokens_vis)

    # Same placement as before when CUDA exists; falls back to CPU otherwise.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    attention_maps_averaged_normalized = [attn_mask.unsqueeze(1).repeat(
        [1, 4, 1, 1]).to(device) for attn_mask in attention_maps_averaged_normalized]
    return attention_maps_averaged_normalized, token_maps_vis