MykolaL committed
Commit fb91e14
1 Parent(s): 9661d01

Upload EVPRefer

Files changed (5)
  1. config.json +12 -0
  2. evpconfig.py +10 -0
  3. model.py +320 -0
  4. model.safetensors +3 -0
  5. models.py +349 -0
config.json ADDED
@@ -0,0 +1,12 @@
+ {
+   "architectures": [
+     "EVPRefer"
+   ],
+   "auto_map": {
+     "AutoConfig": "evpconfig.EVPConfig",
+     "AutoModel": "model.EVPRefer"
+   },
+   "model_type": "EVP",
+   "torch_dtype": "float32",
+   "transformers_version": "4.35.2"
+ }
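
The auto_map block above lets transformers resolve the custom classes shipped in this commit when the repository is loaded with trust_remote_code=True. A minimal loading sketch (not part of the commit), assuming a placeholder repo id and that the extra dependencies imported by model.py (ldm, lib.mask_predictor, ./v1-inference.yaml) are importable on the local machine:

from transformers import AutoConfig, AutoModel

# "MykolaL/EVP" is a hypothetical repo id; substitute the actual repository name.
config = AutoConfig.from_pretrained("MykolaL/EVP", trust_remote_code=True)  # resolves evpconfig.EVPConfig
model = AutoModel.from_pretrained("MykolaL/EVP", trust_remote_code=True)    # resolves model.EVPRefer
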
evpconfig.py ADDED
@@ -0,0 +1,10 @@
+ from transformers import PretrainedConfig
+
+
+ class EVPConfig(PretrainedConfig):
+     model_type = "EVP"
+
+     def __init__(
+         self,
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
model.py ADDED
@@ -0,0 +1,320 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import os
+ import sys
+ from ldm.util import instantiate_from_config
+ from transformers.models.clip.modeling_clip import CLIPTextModel
+ from omegaconf import OmegaConf
+ from lib.mask_predictor import SimpleDecoding
+ from transformers import PreTrainedModel
+ from .models import UNetWrapper, TextAdapterRefer
+ from evpconfig import EVPConfig
+ from transformers import CLIPTokenizer
+ import torchvision.transforms as transforms
+
+
+ def icnr(x, scale=2, init=nn.init.kaiming_normal_):
+     """
+     Checkerboard artifact free sub-pixel convolution
+     https://arxiv.org/abs/1707.02937
+     """
+     ni, nf, h, w = x.shape
+     ni2 = int(ni / (scale**2))
+     k = init(torch.zeros([ni2, nf, h, w])).transpose(0, 1)
+     k = k.contiguous().view(ni2, nf, -1)
+     k = k.repeat(1, 1, scale**2)
+     k = k.contiguous().view([nf, ni, h, w]).transpose(0, 1)
+     x.data.copy_(k)
+
+
+ class PixelShuffle(nn.Module):
+     """
+     Real-Time Single Image and Video Super-Resolution
+     https://arxiv.org/abs/1609.05158
+     """
+     def __init__(self, n_channels, scale):
+         super(PixelShuffle, self).__init__()
+         self.conv = nn.Conv2d(n_channels, n_channels * (scale**2), kernel_size=1)
+         icnr(self.conv.weight)
+         self.shuf = nn.PixelShuffle(scale)
+         self.relu = nn.ReLU()
+
+     def forward(self, x):
+         x = self.shuf(self.relu(self.conv(x)))
+         return x
+
+
+ class AttentionModule(nn.Module):
+     def __init__(self, in_channels, out_channels):
+         super(AttentionModule, self).__init__()
+
+         # Convolutional Layers
+         self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+
+         # Group Normalization
+         self.group_norm = nn.GroupNorm(20, out_channels)
+
+         # ReLU Activation
+         self.relu = nn.ReLU()
+
+         # Spatial Attention
+         self.spatial_attention = nn.Sequential(
+             nn.Conv2d(in_channels, 1, kernel_size=1),
+             nn.Sigmoid()
+         )
+
+     def forward(self, x):
+         # Apply spatial attention
+         spatial_attention = self.spatial_attention(x)
+         x = x * spatial_attention
+
+         # Apply convolutional layer
+         x = self.conv1(x)
+         x = self.group_norm(x)
+         x = self.relu(x)
+
+         return x
+
+
+ class AttentionDownsamplingModule(nn.Module):
+     def __init__(self, in_channels, out_channels, scale_factor=2):
+         super(AttentionDownsamplingModule, self).__init__()
+
+         # Spatial Attention
+         self.spatial_attention = nn.Sequential(
+             nn.Conv2d(in_channels, 1, kernel_size=1),
+             nn.Sigmoid()
+         )
+
+         # Channel Attention
+         self.channel_attention = nn.Sequential(
+             nn.AdaptiveAvgPool2d(1),
+             nn.Conv2d(in_channels, in_channels // 8, kernel_size=1),
+             nn.ReLU(inplace=True),
+             nn.Conv2d(in_channels // 8, in_channels, kernel_size=1),
+             nn.Sigmoid()
+         )
+
+         # Convolutional Layers
+         if scale_factor == 2:
+             self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+         elif scale_factor == 4:
+             self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1)
+
+         self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=2, padding=1)
+
+         # Group Normalization
+         self.group_norm = nn.GroupNorm(20, out_channels)
+
+         # ReLU Activation
+         self.relu = nn.ReLU(inplace=True)
+
+     def forward(self, x):
+         # Apply spatial attention
+         spatial_attention = self.spatial_attention(x)
+         x = x * spatial_attention
+
+         # Apply channel attention
+         channel_attention = self.channel_attention(x)
+         x = x * channel_attention
+
+         # Apply convolutional layers
+         x = self.conv1(x)
+         x = self.group_norm(x)
+         x = self.relu(x)
+         x = self.conv2(x)
+         x = self.group_norm(x)
+         x = self.relu(x)
+
+         return x
+
+
+ class AttentionUpsamplingModule(nn.Module):
+     def __init__(self, in_channels, out_channels):
+         super(AttentionUpsamplingModule, self).__init__()
+
+         # Spatial Attention for outs[2]
+         self.spatial_attention = nn.Sequential(
+             nn.Conv2d(in_channels, 1, kernel_size=1),
+             nn.Sigmoid()
+         )
+
+         # Channel Attention for outs[2]
+         self.channel_attention = nn.Sequential(
+             nn.AdaptiveAvgPool2d(1),
+             nn.Conv2d(in_channels, in_channels // 8, kernel_size=1),
+             nn.ReLU(),
+             nn.Conv2d(in_channels // 8, in_channels, kernel_size=1),
+             nn.Sigmoid()
+         )
+
+         self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+         self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
+
+         # Group Normalization
+         self.group_norm = nn.GroupNorm(20, out_channels)
+
+         # ReLU Activation
+         self.relu = nn.ReLU()
+         self.upscale = PixelShuffle(in_channels, 2)
+
+     def forward(self, x):
+         # Apply spatial attention
+         spatial_attention = self.spatial_attention(x)
+         x = x * spatial_attention
+
+         # Apply channel attention
+         channel_attention = self.channel_attention(x)
+         x = x * channel_attention
+
+         # Apply convolutional layers
+         x = self.conv1(x)
+         x = self.group_norm(x)
+         x = self.relu(x)
+         x = self.conv2(x)
+         x = self.group_norm(x)
+         x = self.relu(x)
+
+         # Upsample
+         x = self.upscale(x)
+
+         return x
+
+
+ class ConvLayer(nn.Module):
+     def __init__(self, in_channels, out_channels):
+         super(ConvLayer, self).__init__()
+
+         self.conv1 = nn.Sequential(
+             nn.Conv2d(in_channels, out_channels, 1),
+             nn.GroupNorm(20, out_channels),
+             nn.ReLU(),
+         )
+
+     def forward(self, x):
+         x = self.conv1(x)
+
+         return x
+
+
+ class InverseMultiAttentiveFeatureRefinement(nn.Module):
+     def __init__(self, in_channels_list):
+         super(InverseMultiAttentiveFeatureRefinement, self).__init__()
+
+         self.layer1 = AttentionModule(in_channels_list[0], in_channels_list[0])
+         self.layer2 = AttentionDownsamplingModule(in_channels_list[0], in_channels_list[0] // 2, scale_factor=2)
+         self.layer3 = ConvLayer(in_channels_list[0] // 2 + in_channels_list[1], in_channels_list[1])
+         self.layer4 = AttentionDownsamplingModule(in_channels_list[1], in_channels_list[1] // 2, scale_factor=2)
+         self.layer5 = ConvLayer(in_channels_list[1] // 2 + in_channels_list[2], in_channels_list[2])
+         self.layer6 = AttentionDownsamplingModule(in_channels_list[2], in_channels_list[2] // 2, scale_factor=2)
+         self.layer7 = ConvLayer(in_channels_list[2] // 2 + in_channels_list[3], in_channels_list[3])
+
+         '''
+         self.layer8 = AttentionUpsamplingModule(in_channels_list[3], in_channels_list[3])
+         self.layer9 = ConvLayer(in_channels_list[2] + in_channels_list[3], in_channels_list[2])
+         self.layer10 = AttentionUpsamplingModule(in_channels_list[2], in_channels_list[2])
+         self.layer11 = ConvLayer(in_channels_list[1] + in_channels_list[2], in_channels_list[1])
+         self.layer12 = AttentionUpsamplingModule(in_channels_list[1], in_channels_list[1])
+         self.layer13 = ConvLayer(in_channels_list[0] + in_channels_list[1], in_channels_list[0])
+         '''
+
+     def forward(self, inputs):
+         x_c4, x_c3, x_c2, x_c1 = inputs
+         x_c4 = self.layer1(x_c4)
+         x_c4_3 = self.layer2(x_c4)
+         x_c3 = torch.cat([x_c4_3, x_c3], dim=1)
+         x_c3 = self.layer3(x_c3)
+         x_c3_2 = self.layer4(x_c3)
+         x_c2 = torch.cat([x_c3_2, x_c2], dim=1)
+         x_c2 = self.layer5(x_c2)
+         x_c2_1 = self.layer6(x_c2)
+         x_c1 = torch.cat([x_c2_1, x_c1], dim=1)
+         x_c1 = self.layer7(x_c1)
+         '''
+         x_c1_2 = self.layer8(x_c1)
+         x_c2 = torch.cat([x_c1_2, x_c2], dim=1)
+         x_c2 = self.layer9(x_c2)
+         x_c2_3 = self.layer10(x_c2)
+         x_c3 = torch.cat([x_c2_3, x_c3], dim=1)
+         x_c3 = self.layer11(x_c3)
+         x_c3_4 = self.layer12(x_c3)
+         x_c4 = torch.cat([x_c3_4, x_c4], dim=1)
+         x_c4 = self.layer13(x_c4)
+         '''
+         return [x_c4, x_c3, x_c2, x_c1]
+
+
+ class EVPRefer(PreTrainedModel):
+     """Encoder Decoder segmentors.
+
+     EncoderDecoder typically consists of backbone, decode_head, auxiliary_head.
+     Note that auxiliary_head is only used for deep supervision during training,
+     which could be dumped during inference.
+     """
+     config_class = EVPConfig
+
+     def __init__(self, config,
+                  sd_path=None,
+                  base_size=512,
+                  token_embed_dim=768,
+                  neck_dim=[320, 680, 1320, 1280],
+                  **args):
+         super().__init__(config)
+         config = OmegaConf.load('./v1-inference.yaml')
+         if os.path.exists(f'{sd_path}'):
+             config.model.params.ckpt_path = f'{sd_path}'
+         else:
+             config.model.params.ckpt_path = None
+
+         sd_model = instantiate_from_config(config.model)
+         self.encoder_vq = sd_model.first_stage_model
+         self.unet = UNetWrapper(sd_model.model, base_size=base_size)
+         del sd_model.cond_stage_model
+         del self.encoder_vq.decoder
+         for param in self.encoder_vq.parameters():
+             param.requires_grad = True
+
+         self.text_adapter = TextAdapterRefer(text_dim=token_embed_dim)
+
+         self.classifier = SimpleDecoding(dims=neck_dim)
+
+         self.gamma = nn.Parameter(torch.ones(token_embed_dim) * 1e-4)
+         self.aggregation = InverseMultiAttentiveFeatureRefinement([320, 680, 1320, 1280])
+         self.clip_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
+         self.tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
+
+         for param in self.clip_model.parameters():
+             param.requires_grad = True
+
+     def forward(self, img, sentences):
+         image_t = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])(img)
+         shape = image_t.shape
+         img = torch.nn.functional.interpolate(image_t, (512, 512), mode='bilinear', align_corners=True)
+
+         input_ids = self.tokenizer(text=sentences, truncation=True, max_length=40, return_length=True,
+                                    return_overflowing_tokens=False, padding="max_length",
+                                    return_tensors="pt")['input_ids'].to(image_t.device)
+
+         input_shape = img.shape[-2:]
+
+         latents = self.encoder_vq.encode(img).mode()
+         latents = latents / 4.7164
+
+         l_feats = self.clip_model(input_ids=input_ids).last_hidden_state
+         c_crossattn = self.text_adapter(latents, l_feats, self.gamma)  # NOTE: here the c_crossattn should be expand_dim as latents
+         t = torch.ones((img.shape[0],), device=img.device).long()
+         outs = self.unet(latents, t, c_crossattn=[c_crossattn])
+
+         outs = self.aggregation(outs)
+
+         x_c1, x_c2, x_c3, x_c4 = outs
+         x = self.classifier(x_c4, x_c3, x_c2, x_c1)
+         x = F.interpolate(x, size=input_shape, mode='bilinear', align_corners=True)
+         pred = torch.nn.functional.interpolate(x, shape[2:], mode='bilinear', align_corners=True)
+         output_mask = pred.detach().cpu().argmax(1).data.numpy().squeeze()
+         return output_mask
+
+     def get_latent(self, x):
+         return self.encoder_vq.encode(x).mode()
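
For reference, a minimal inference sketch for the EVPRefer module above (not part of this commit). It assumes the ldm and lib.mask_predictor packages plus ./v1-inference.yaml are available locally, that no Stable Diffusion checkpoint is passed (sd_path=None), and that model.py/models.py are importable (model.py uses a relative import of models.py, so the packaging may need adjusting):

import torch
from evpconfig import EVPConfig
from model import EVPRefer  # assumes the files are importable; adjust to your package layout

model = EVPRefer(EVPConfig()).eval()      # builds the VQ encoder, wrapped UNet and CLIP text encoder
img = torch.rand(1, 3, 480, 640)          # RGB image in [0, 1]; resized to 512x512 internally
sentences = ["the person holding the umbrella"]

with torch.no_grad():
    mask = model(img, sentences)          # numpy array of per-pixel class indices at the input resolution
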
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:07d8a7895e79a4defbf5445cada1a9973bcda51f7c73e4a91c878074a5f758f5
+ size 4317946624
models.py ADDED
@@ -0,0 +1,349 @@
+ from omegaconf import OmegaConf
+
+ import torch as th
+ import torch
+ import math
+ import abc
+
+ from torch import nn, einsum
+
+ from einops import rearrange, repeat
+ from timm.models.layers import DropPath, to_2tuple, trunc_normal_
+ from transformers import CLIPTokenizer
+ from transformers.models.clip.modeling_clip import CLIPTextConfig, CLIPTextModel, CLIPTextTransformer  # , _expand_mask
+ from inspect import isfunction
+
+
+ def exists(val):
+     return val is not None
+
+
+ def uniq(arr):
+     return {el: True for el in arr}.keys()
+
+
+ def default(val, d):
+     if exists(val):
+         return val
+     return d() if isfunction(d) else d
+
+
+ def register_attention_control(model, controller):
+     def ca_forward(self, place_in_unet):
+         def forward(x, context=None, mask=None):
+             h = self.heads
+
+             q = self.to_q(x)
+             is_cross = context is not None
+             context = default(context, x)
+             k = self.to_k(context)
+             v = self.to_v(context)
+
+             q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
+
+             sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
+
+             if exists(mask):
+                 mask = rearrange(mask, 'b ... -> b (...)')
+                 max_neg_value = -torch.finfo(sim.dtype).max
+                 mask = repeat(mask, 'b j -> (b h) () j', h=h)
+                 sim.masked_fill_(~mask, max_neg_value)
+
+             # attention, what we cannot get enough of
+             attn = sim.softmax(dim=-1)
+
+             attn2 = rearrange(attn, '(b h) k c -> h b k c', h=h).mean(0)
+             controller(attn2, is_cross, place_in_unet)
+
+             out = einsum('b i j, b j d -> b i d', attn, v)
+             out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
+             return self.to_out(out)
+
+         return forward
+
+     class DummyController:
+         def __call__(self, *args):
+             return args[0]
+
+         def __init__(self):
+             self.num_att_layers = 0
+
+     if controller is None:
+         controller = DummyController()
+
+     def register_recr(net_, count, place_in_unet):
+         if net_.__class__.__name__ == 'CrossAttention':
+             net_.forward = ca_forward(net_, place_in_unet)
+             return count + 1
+         elif hasattr(net_, 'children'):
+             for net__ in net_.children():
+                 count = register_recr(net__, count, place_in_unet)
+         return count
+
+     cross_att_count = 0
+     sub_nets = model.diffusion_model.named_children()
+
+     for net in sub_nets:
+         if "input_blocks" in net[0]:
+             cross_att_count += register_recr(net[1], 0, "down")
+         elif "output_blocks" in net[0]:
+             cross_att_count += register_recr(net[1], 0, "up")
+         elif "middle_block" in net[0]:
+             cross_att_count += register_recr(net[1], 0, "mid")
+
+     controller.num_att_layers = cross_att_count
+
+
+ class AttentionControl(abc.ABC):
+
+     def step_callback(self, x_t):
+         return x_t
+
+     def between_steps(self):
+         return
+
+     @property
+     def num_uncond_att_layers(self):
+         return 0
+
+     @abc.abstractmethod
+     def forward(self, attn, is_cross: bool, place_in_unet: str):
+         raise NotImplementedError
+
+     def __call__(self, attn, is_cross: bool, place_in_unet: str):
+         attn = self.forward(attn, is_cross, place_in_unet)
+         return attn
+
+     def reset(self):
+         self.cur_step = 0
+         self.cur_att_layer = 0
+
+     def __init__(self):
+         self.cur_step = 0
+         self.num_att_layers = -1
+         self.cur_att_layer = 0
+
+
+ class AttentionStore(AttentionControl):
+     @staticmethod
+     def get_empty_store():
+         return {"down_cross": [], "mid_cross": [], "up_cross": [],
+                 "down_self": [], "mid_self": [], "up_self": []}
+
+     def forward(self, attn, is_cross: bool, place_in_unet: str):
+         key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
+         if attn.shape[1] <= (self.max_size) ** 2:  # avoid memory overhead
+             self.step_store[key].append(attn)
+         return attn
+
+     def between_steps(self):
+         if len(self.attention_store) == 0:
+             self.attention_store = self.step_store
+         else:
+             for key in self.attention_store:
+                 for i in range(len(self.attention_store[key])):
+                     self.attention_store[key][i] += self.step_store[key][i]
+         self.step_store = self.get_empty_store()
+
+     def get_average_attention(self):
+         average_attention = {key: [item for item in self.step_store[key]] for key in self.step_store}
+         return average_attention
+
+     def reset(self):
+         super(AttentionStore, self).reset()
+         self.step_store = self.get_empty_store()
+         self.attention_store = {}
+
+     def __init__(self, base_size=64, max_size=None):
+         super(AttentionStore, self).__init__()
+         self.step_store = self.get_empty_store()
+         self.attention_store = {}
+         self.base_size = base_size
+         if max_size is None:
+             self.max_size = self.base_size // 2
+         else:
+             self.max_size = max_size
+
+
+ def register_hier_output(model):
+     self = model.diffusion_model
+     from ldm.modules.diffusionmodules.util import checkpoint, timestep_embedding
+
+     def forward(x, timesteps=None, context=None, y=None, **kwargs):
+         """
+         Apply the model to an input batch.
+         :param x: an [N x C x ...] Tensor of inputs.
+         :param timesteps: a 1-D batch of timesteps.
+         :param context: conditioning plugged in via crossattn
+         :param y: an [N] Tensor of labels, if class-conditional.
+         :return: an [N x C x ...] Tensor of outputs.
+         """
+         assert (y is not None) == (
+             self.num_classes is not None
+         ), "must specify y if and only if the model is class-conditional"
+         hs = []
+         t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
+         emb = self.time_embed(t_emb)
+
+         if self.num_classes is not None:
+             assert y.shape == (x.shape[0],)
+             emb = emb + self.label_emb(y)
+
+         h = x.type(self.dtype)
+         for module in self.input_blocks:
+             # import pdb; pdb.set_trace()
+             if context.shape[1] == 2:
+                 h = module(h, emb, context[:, 0, :].unsqueeze(1))
+             else:
+                 h = module(h, emb, context)
+             hs.append(h)
+         if context.shape[1] == 2:
+             h = self.middle_block(h, emb, context[:, 0, :].unsqueeze(1))
+         else:
+             h = self.middle_block(h, emb, context)
+         out_list = []
+
+         for i_out, module in enumerate(self.output_blocks):
+             h = th.cat([h, hs.pop()], dim=1)
+             if context.shape[1] == 2:
+                 h = module(h, emb, context[:, 1, :].unsqueeze(1))
+             else:
+                 h = module(h, emb, context)
+             if i_out in [1, 4, 7]:
+                 out_list.append(h)
+         h = h.type(x.dtype)
+
+         out_list.append(h)
+         return out_list
+
+     self.forward = forward
+
+
+ class UNetWrapper(nn.Module):
+     def __init__(self, unet, use_attn=True, base_size=512, max_attn_size=None, attn_selector='up_cross+down_cross') -> None:
+         super().__init__()
+         self.unet = unet
+         self.attention_store = AttentionStore(base_size=base_size // 8, max_size=max_attn_size)
+         self.size16 = base_size // 32
+         self.size32 = base_size // 16
+         self.size64 = base_size // 8
+         self.use_attn = use_attn
+         if self.use_attn:
+             register_attention_control(unet, self.attention_store)
+         register_hier_output(unet)
+         self.attn_selector = attn_selector.split('+')
+
+     def forward(self, *args, **kwargs):
+         if self.use_attn:
+             self.attention_store.reset()
+         out_list = self.unet(*args, **kwargs)
+         if self.use_attn:
+             avg_attn = self.attention_store.get_average_attention()
+             attn16, attn32, attn64 = self.process_attn(avg_attn)
+             out_list[1] = torch.cat([out_list[1], attn16], dim=1)
+             out_list[2] = torch.cat([out_list[2], attn32], dim=1)
+             if attn64 is not None:
+                 out_list[3] = torch.cat([out_list[3], attn64], dim=1)
+         return out_list[::-1]
+
+     def process_attn(self, avg_attn):
+         attns = {self.size16: [], self.size32: [], self.size64: []}
+         for k in self.attn_selector:
+             for up_attn in avg_attn[k]:
+                 size = int(math.sqrt(up_attn.shape[1]))
+                 attns[size].append(rearrange(up_attn, 'b (h w) c -> b c h w', h=size))
+         attn16 = torch.stack(attns[self.size16]).mean(0)
+         attn32 = torch.stack(attns[self.size32]).mean(0)
+         if len(attns[self.size64]) > 0:
+             attn64 = torch.stack(attns[self.size64]).mean(0)
+         else:
+             attn64 = None
+         return attn16, attn32, attn64
+
+
+ class TextAdapter(nn.Module):
+     def __init__(self, text_dim=768, hidden_dim=None):
+         super().__init__()
+         if hidden_dim is None:
+             hidden_dim = text_dim
+         self.fc = nn.Sequential(
+             nn.Linear(text_dim, hidden_dim),
+             nn.GELU(),
+             nn.Linear(hidden_dim, text_dim)
+         )
+
+     def forward(self, latents, texts, gamma):
+         n_class, channel = texts.shape
+         bs = latents.shape[0]
+
+         texts_after = self.fc(texts)
+         texts = texts + gamma * texts_after
+         texts = repeat(texts, 'n c -> b n c', b=bs)
+         return texts
+
+
+ class TextAdapterRefer(nn.Module):
+     def __init__(self, text_dim=768):
+         super().__init__()
+
+         self.fc = nn.Sequential(
+             nn.Linear(text_dim, text_dim),
+             nn.GELU(),
+             nn.Linear(text_dim, text_dim)
+         )
+
+     def forward(self, latents, texts, gamma):
+         texts_after = self.fc(texts)
+         texts = texts + gamma * texts_after
+         return texts
+
+
+ class TextAdapterDepth(nn.Module):
+     def __init__(self, text_dim=768):
+         super().__init__()
+
+         self.fc = nn.Sequential(
+             nn.Linear(text_dim, text_dim),
+             nn.GELU(),
+             nn.Linear(text_dim, text_dim)
+         )
+
+     def forward(self, latents, texts, gamma):
+         # use the gamma to blend
+         n_sen, channel = texts.shape
+         bs = latents.shape[0]
+
+         texts_after = self.fc(texts)
+         texts = texts + gamma * texts_after
+         texts = repeat(texts, 'n c -> n b c', b=1)
+         return texts
+
+
+ class FrozenCLIPEmbedder(nn.Module):
+     """Uses the CLIP transformer encoder for text (from Hugging Face)"""
+     def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77, pool=True):
+         super().__init__()
+         self.tokenizer = CLIPTokenizer.from_pretrained(version)
+         self.transformer = CLIPTextModel.from_pretrained(version)
+         self.device = device
+         self.max_length = max_length
+         self.freeze()
+
+         self.pool = pool
+
+     def freeze(self):
+         self.transformer = self.transformer.eval()
+         for param in self.parameters():
+             param.requires_grad = False
+
+     def forward(self, text):
+         batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
+                                         return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
+         tokens = batch_encoding["input_ids"].to(self.device)
+         outputs = self.transformer(input_ids=tokens)
+
+         if self.pool:
+             z = outputs.pooler_output
+         else:
+             z = outputs.last_hidden_state
+         return z
+
+     def encode(self, text):
+         return self(text)