# File size: 4,004 Bytes
# 64cc34d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image

# Module-level state shared by the hook and post-processing functions below.
# The raw_* dicts are filled by the forward hooks created in hook_fn():
# module name -> list of maps, one entry appended per hooked forward call.
raw_attn_maps = {}
raw_ip_attn_maps = {}
# Both start out as empty dicts; post_process_attn_maps() rebinds them to
# lists holding one {module name: map} dict per recorded forward call.
attn_maps = {}
ip_attn_maps = {}


def hook_fn(name):
    """Build a forward hook that records a module's attention maps under ``name``.

    The returned hook looks at ``module.processor``. When the processor
    carries an ``attn_map`` attribute, that map and the processor's
    ``ip_attn_map`` are appended to the module-level ``raw_attn_maps`` and
    ``raw_ip_attn_maps`` lists for ``name``, and both attributes are then
    deleted from the processor. Processors without ``attn_map`` are ignored.
    """

    def forward_hook(module, input, output):
        processor = module.processor
        if not hasattr(processor, "attn_map"):
            return

        # setdefault creates the per-module list the first time it is needed.
        raw_attn_maps.setdefault(name, []).append(processor.attn_map)
        raw_ip_attn_maps.setdefault(name, []).append(processor.ip_attn_map)

        # Remove the stored maps so the next forward pass starts clean.
        del processor.attn_map
        del processor.ip_attn_map

    return forward_hook


def post_process_attn_maps():
    """Regroup the recorded maps from per-module lists into per-call dicts.

    ``raw_attn_maps`` and ``raw_ip_attn_maps`` map a module name to the list
    of maps recorded for it. This transposes each of them into a list whose
    i-th element is a ``{module name: i-th recorded map}`` dict, stores the
    results in the module-level ``attn_maps`` / ``ip_attn_maps`` and returns
    them. Because the transposition uses ``zip``, the output is as long as the
    shortest per-module list.

    Returns:
        Tuple ``(attn_maps, ip_attn_maps)`` of the two regrouped lists.
    """
    global raw_attn_maps, raw_ip_attn_maps, attn_maps, ip_attn_maps

    def _by_call(per_module):
        module_names = per_module.keys()
        return [
            dict(zip(module_names, call_values))
            for call_values in zip(*per_module.values())
        ]

    attn_maps = _by_call(raw_attn_maps)
    ip_attn_maps = _by_call(raw_ip_attn_maps)

    return attn_maps, ip_attn_maps


def register_cross_attention_hook(unet):
    """Attach a recording forward hook to every ``attn2*`` submodule of ``unet``.

    A submodule is selected when the last dot-separated component of its name
    (as reported by ``unet.named_modules()``) starts with ``"attn2"``.

    Returns:
        The same ``unet`` object, to allow call chaining.
    """
    selected = (
        (module_name, submodule)
        for module_name, submodule in unet.named_modules()
        if module_name.rpartition(".")[2].startswith("attn2")
    )
    for module_name, submodule in selected:
        submodule.register_forward_hook(hook_fn(module_name))

    return unet


def upscale(attn_map, target_size):
    """Resize an attention map to ``target_size`` and softmax it over dim 0.

    The map is averaged over its first dimension, transposed so that its
    former last axis comes first, reshaped from a flat run of positions into
    a 2-D grid, bilinearly interpolated to ``target_size`` and finally
    normalised with a softmax across the leading axis.

    Args:
        attn_map: 3-D tensor. Presumably ``(heads, positions, tokens)`` --
            TODO confirm against the attention processor that produces it.
        target_size: ``(height, width)`` of the output.

    Returns:
        A float32 tensor of shape ``(attn_map.shape[2], height, width)``.

    Raises:
        AssertionError: if none of the downscale factors 1, 2, 4, 8, 16 makes
            ``target_size`` consistent with the number of positions.
    """
    attn_map = torch.mean(attn_map, dim=0).permute(1, 0)
    num_positions = attn_map.shape[1]

    # Find the factor at which the downscaled target area equals
    # num_positions * 64, i.e. the grid is (H // (8 * scale), W // (8 * scale)).
    # NOTE(review): the 8 looks like a latent downsampling factor -- confirm.
    for scale in (1, 2, 4, 8, 16):
        scaled_area = (target_size[0] // scale) * (target_size[1] // scale)
        if scaled_area == num_positions * 64:
            temp_size = (target_size[0] // (scale * 8), target_size[1] // (scale * 8))
            break
    else:
        # Raised explicitly (rather than via ``assert``) so the check also
        # runs under ``python -O``; the exception type is unchanged.
        raise AssertionError(
            "temp_size cannot be None: no scale in (1, 2, 4, 8, 16) maps "
            f"target_size {tuple(target_size)} onto {num_positions} positions"
        )

    attn_map = attn_map.view(attn_map.shape[0], *temp_size)

    # interpolate expects a batch dimension; add it and strip it again.
    attn_map = F.interpolate(
        attn_map.unsqueeze(0).to(dtype=torch.float32),
        size=target_size,
        mode="bilinear",
        align_corners=False,
    )[0]

    return torch.softmax(attn_map, dim=0)


def _mean_upscaled_maps(step_maps, image_size, batch_size, idx, detach):
    """Upscale one batch chunk of every map in ``step_maps`` and average them."""
    upscaled = []
    for layer_map in step_maps.values():
        if detach:
            layer_map = layer_map.cpu()
        # Split along dim 0 into ``batch_size`` chunks and keep chunk ``idx``.
        selected = torch.chunk(layer_map, batch_size)[idx].squeeze()
        upscaled.append(upscale(selected, image_size))

    return torch.mean(torch.stack(upscaled, dim=0), dim=0)


def get_net_attn_map(
    image_size, batch_size=2, instance_or_negative=False, detach=True, step=-1
):
    """Average the upscaled attention maps of all hooked modules for one step.

    Reads the module-level ``attn_maps`` and ``ip_attn_maps``, which must be
    indexable by ``step``; they are rebound to lists by
    ``post_process_attn_maps()``, so that function has to run first.

    Args:
        image_size: ``(height, width)`` passed to ``upscale`` as target size.
        batch_size: number of chunks each recorded map is split into along
            its first dimension.
        instance_or_negative: selects chunk 0 when true, chunk 1 otherwise.
        detach: when true, each map is moved to the CPU before processing
            (despite the name, the code calls ``.cpu()``, not ``.detach()``).
        step: index of the recorded forward call to use; ``-1`` is the last.

    Returns:
        Tuple ``(net_attn_maps, net_ip_attn_maps)``: the mean over modules of
        the upscaled maps from ``attn_maps`` and ``ip_attn_maps`` respectively.
    """
    idx = 0 if instance_or_negative else 1

    net_attn_maps = _mean_upscaled_maps(
        attn_maps[step], image_size, batch_size, idx, detach
    )
    net_ip_attn_maps = _mean_upscaled_maps(
        ip_attn_maps[step], image_size, batch_size, idx, detach
    )

    return net_attn_maps, net_ip_attn_maps


def attnmaps2images(net_attn_maps):
    """Convert each attention map into an 8-bit PIL image.

    Every map is min-max normalised independently to the range 0..255 and
    cast to ``uint8`` before being wrapped with ``Image.fromarray``.

    Args:
        net_attn_maps: iterable of tensors, one per output image.

    Returns:
        List of ``PIL.Image.Image`` objects in the same order as the input.
    """
    images = []

    for attn_map in net_attn_maps:
        # detach() first: Tensor.numpy() refuses tensors that require grad.
        values = attn_map.detach().cpu().numpy()
        lowest = np.min(values)
        value_range = np.max(values) - lowest

        if value_range > 0:
            normalized_attn_map = (values - lowest) / value_range * 255
        else:
            # A constant (or all-NaN) map has no contrast; dividing by the
            # zero range would yield NaN, so emit a black image instead.
            normalized_attn_map = np.zeros_like(values)

        images.append(Image.fromarray(normalized_attn_map.astype(np.uint8)))

    return images


def is_torch2_available():
    """Report whether ``torch.nn.functional`` has ``scaled_dot_product_attention``."""
    missing = object()
    return getattr(F, "scaled_dot_product_attention", missing) is not missing


def get_generator(seed, device):
    """Create seeded ``torch.Generator`` object(s) on ``device``.

    Args:
        seed: ``None``, a single seed, or a list of seeds.
        device: device handed to ``torch.Generator``.

    Returns:
        ``None`` when ``seed`` is ``None``; a list with one generator per
        entry when ``seed`` is a list; otherwise a single generator.
    """
    if seed is None:
        return None

    def _seeded(value):
        return torch.Generator(device).manual_seed(value)

    if isinstance(seed, list):
        return [_seeded(seed_item) for seed_item in seed]

    return _seeded(seed)