File size: 5,884 Bytes
89f6983
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import matplotlib.pyplot as plt
import math
import utils
from . import parse

# Running counter used by `display` as the trailing number of each saved
# file name; `display` increments it after every successful save.
save_ind = 0

def visualize(image, title, colorbar=False, show_plot=True, **kwargs):
    """Draw a single image with a title through the pyplot interface.

    image: data handed to `plt.imshow`
    title: text handed to `plt.title`
    colorbar: when truthy, `plt.colorbar()` is called after drawing
    show_plot: when truthy, `plt.show()` is called at the end
    **kwargs: forwarded unchanged to `plt.imshow`
    """
    plt.title(title)
    plt.imshow(image, **kwargs)

    # Optional finishing steps, executed in this order when enabled.
    finishing_steps = (
        (colorbar, plt.colorbar),
        (show_plot, plt.show),
    )
    for enabled, step in finishing_steps:
        if enabled:
            step()

def visualize_arrays(image_title_pairs, colorbar_index=-1, show_plot=True, figsize=None, **kwargs):
    """Draw several images side by side in one row of subplots.

    image_title_pairs: sequence whose entries are either an `(image, title)`
        list/tuple or a bare image (drawn without a title). A `None` title
        is also drawn without a title.
    colorbar_index: 0-based position of the subplot that gets a colorbar.
        The default of -1 never equals a position, so no colorbar is drawn.
    show_plot: when truthy, `plt.show()` is called after all subplots.
    figsize: when not None, a new figure of this size is created first.
    **kwargs: forwarded unchanged to every `plt.imshow` call.
    """
    if figsize is not None:
        plt.figure(figsize=figsize)

    total = len(image_title_pairs)
    for position, entry in enumerate(image_title_pairs, start=1):
        plt.subplot(1, total, position)

        # A list/tuple entry is unpacked as (image, title); anything else is the image itself.
        picture, caption = entry if isinstance(entry, (list, tuple)) else (entry, None)

        if caption is not None:
            plt.title(caption)
        plt.imshow(picture, **kwargs)

        # `position` is 1-based, `colorbar_index` is 0-based.
        if position - 1 == colorbar_index:
            plt.colorbar()

    if show_plot:
        plt.show()

def visualize_masked_latents(latents_all, masked_latents, timestep_T=False, timestep_0=True):
    """Plot `latents_all` next to `masked_latents` for the first and/or last stored step.

    latents_all, masked_latents: tensors indexed as `[step, 0, :3]`, leaving a
        3-D slice whose leading axis is moved last before plotting.
        NOTE(review): presumably (step, batch, channel, height, width) with
        steps stored from t=T (index 0) to t=0 (index -1) — confirm with callers.
    timestep_T: when truthy, show the pair at step index 0 (titled "t=T").
    timestep_0: when truthy, show the pair at step index -1 (titled "t=0").

    Each enabled pair is drawn in a 1x2 subplot grid and followed by its own
    `plt.show()` call.
    """
    def _draw(slot, caption, latents, step_index):
        # Select the panel and title first, then convert and draw the slice.
        plt.subplot(1, 2, slot)
        plt.title(caption)
        picture = latents[step_index, 0, :3].cpu().permute(1, 2, 0).numpy().astype(float)
        # Scale by 1/1.5 and clamp into [0, 1] for display.
        plt.imshow((picture / 1.5).clip(0., 1.), cmap="gray")

    # (enabled flag, step index, left title, right title), in drawing order.
    comparisons = (
        (timestep_T, 0, "latents_all (t=T)", "mask latents (t=T)"),
        (timestep_0, -1, "latents_all (t=0)", "mask latents (t=0)"),
    )
    for enabled, step_index, left_caption, right_caption in comparisons:
        if not enabled:
            continue
        _draw(1, left_caption, latents_all, step_index)
        _draw(2, right_caption, masked_latents, step_index)
        plt.show()

# This function has not been adapted to new `saved_attn`.
def visualize_attn(token_map, cross_attention_probs_tensors, stage_id, block_id, visualize_step_start=10, input_ca_has_condition_only=False):
    """
    Visualize cross attention, one subplot per token: `stage_id`th block, mean over all timesteps starting from `visualize_step_start`, `block_id`th Transformer block, conditioned item, mean over heads.

    token_map: indexed with 0..len(token_map)-1; each entry is used as the title of its subplot
    cross_attention_probs_tensors:
    One of `cross_attention_probs_down_tensors`, `cross_attention_probs_mid_tensors`, and `cross_attention_probs_up_tensors`
    stage_id: index of downsampling/mid/upsampling block
    block_id: index of the transformer block
    visualize_step_start: index of the first timestep included in the mean
    input_ca_has_condition_only: if False, the selected attention must hold 2 items (uncond and cond) and the second is shown; if True, it must hold exactly 1 item (cond only)

    Raises AssertionError when the number of items does not match `input_ca_has_condition_only`.
    The number of spatial positions is assumed to be a perfect square (the map is reshaped to H x W with H == W).
    """
    num_tokens = len(token_map)
    plt.figure(figsize=(20, 8))

    if num_tokens == 0:
        # Nothing to draw; leave the attention tensors untouched and show the empty figure.
        plt.show()
        return

    # Everything up to the head-mean is identical for all tokens, so it is
    # computed once here instead of once per subplot.
    attn = cross_attention_probs_tensors[stage_id][visualize_step_start:].mean(dim=0)[block_id]

    if not input_ca_has_condition_only:
        assert attn.shape[0] == 2, f"Expect to have 2 items (uncond and cond), but found {attn.shape[0]} items"
        attn = attn[1]
    else:
        assert attn.shape[0] == 1, f"Expect to have 1 item (cond only), but found {attn.shape[0]} items"
        attn = attn[0]

    # Mean over heads; rows are spatial positions, columns are tokens.
    attn = attn.mean(dim=0)
    H = W = int(math.sqrt(attn.shape[0]))

    for token_id in range(num_tokens):
        plt.subplot(1, num_tokens, token_id + 1)
        plt.title(token_map[token_id])
        token_attn = attn[:, token_id].reshape((H, W))
        plt.imshow(token_attn.cpu().numpy())

    plt.show()

# This function has not been adapted to new `saved_attn`.
def visualize_across_timesteps(token_id, cross_attention_probs_tensors, stage_id, block_id, visualize_step_start=10, input_ca_has_condition_only=False):
    """
    Visualize cross attention for one token across timesteps, one subplot per timestep: `stage_id`th block, `block_id`th Transformer block, conditioned item, mean over heads, column `token_id`. No averaging over timesteps is performed.

    token_id: index of the token whose attention map is shown
    cross_attention_probs_tensors:
    One of `cross_attention_probs_down_tensors`, `cross_attention_probs_mid_tensors`, and `cross_attention_probs_up_tensors`
    stage_id: index of downsampling/mid/upsampling block
    block_id: index of the transformer block
    input_ca_has_condition_only: if False, each selected attention must hold 2 items (uncond and cond) and the second is shown; if True, it must hold exactly 1 item (cond only)

    `visualize_step_start` is not used. We visualize all timesteps.

    Raises AssertionError when the number of items does not match `input_ca_has_condition_only`.
    The number of spatial positions is assumed to be a perfect square (the map is reshaped to H x W with H == W).
    """
    plt.figure(figsize=(50, 8))

    attn_stage = cross_attention_probs_tensors[stage_id]
    num_inference_steps = attn_stage.shape[0]

    for t in range(num_inference_steps):
        plt.subplot(1, num_inference_steps, t + 1)
        plt.title(f"t: {t}")

        attn = attn_stage[t][block_id]

        if not input_ca_has_condition_only:
            assert attn.shape[0] == 2, f"Expect to have 2 items (uncond and cond), but found {attn.shape[0]} items"
            attn = attn[1]
        else:
            assert attn.shape[0] == 1, f"Expect to have 1 item (cond only), but found {attn.shape[0]} items"
            attn = attn[0]

        # Mean over heads, then pick the column of the requested token.
        attn = attn.mean(dim=0)[:, token_id]
        H = W = int(math.sqrt(attn.shape[0]))
        attn = attn.reshape((H, W))
        plt.imshow(attn.cpu().numpy())
        plt.axis("off")

    if num_inference_steps > 0:
        # Lay the figure out once, after every subplot exists, rather than
        # re-running the layout pass on each iteration.
        plt.tight_layout()

    plt.show()

def visualize_bboxes(bboxes, H, W):
    """Show the mask of every box in `bboxes`, one subplot per box in a single row.

    bboxes: sequence of boxes; each is passed to `utils.proportion_to_mask`
    H, W: passed through to `utils.proportion_to_mask` together with the box

    The value returned by `utils.proportion_to_mask` is moved to CPU and
    converted with `.numpy()` before plotting, then `plt.show()` is called
    once after all boxes.
    """
    total = len(bboxes)
    for box_no, box in enumerate(bboxes):
        plt.subplot(1, total, box_no + 1)
        mask = utils.proportion_to_mask(box, H, W)
        plt.title(f"transformed bbox ({box_no})")
        mask_array = mask.cpu().numpy()
        plt.imshow(mask_array)
    plt.show()

def display(image, save_prefix="", ind=None):
    """Save `image` as a PNG under `parse.img_dir` and advance the module-level `save_ind` counter.

    The file name is `{save_prefix}_{ind}_{save_ind}.png`; the `{save_prefix}_`
    part is omitted when `save_prefix` is "" and the `{ind}_` part is omitted
    when `ind` is None.

    image: object exposing `save(path)` (presumably a PIL image — confirm with callers)
    save_prefix: optional string prefix of the file name
    ind: optional extra index placed before the running counter

    The counter is only incremented when `image.save` returns without raising.
    """
    global save_ind

    prefix_part = save_prefix + "_" if save_prefix != "" else ""
    ind_part = f"{ind}_" if ind is not None else ""
    path = f"{parse.img_dir}/{prefix_part}{ind_part}{save_ind}.png"

    image.save(path)
    # Report only after the write succeeded, so the message never names a
    # file that was not actually written.
    print(f"Saved to {path}")

    save_ind += 1