Anyou commited on
Commit
0ffaa52
1 Parent(s): b2b0303

updated models

Browse files
models/blip_override/blip.py ADDED
@@ -0,0 +1,240 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ * Copyright (c) 2022, salesforce.com, inc.
3
+ * All rights reserved.
4
+ * SPDX-License-Identifier: BSD-3-Clause
5
+ * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ * By Junnan Li
7
+ '''
8
+ import warnings
9
+
10
+ warnings.filterwarnings("ignore")
11
+
12
+ from .vit import VisionTransformer, interpolate_pos_embed
13
+ from .med import BertModel, BertLMHeadModel
14
+ from transformers import BertTokenizer, BertConfig
15
+
16
+ import torch
17
+ from torch import nn
18
+
19
+ import os
20
+ from urllib.parse import urlparse
21
+ from timm.models.hub import download_cached_file
22
+
23
+
24
+ class BLIP_Base(nn.Module):
25
+ def __init__(self,
26
+ med_config='models/blip_override/med_config.json',
27
+ image_size=224,
28
+ vit='base',
29
+ vit_grad_ckpt=False,
30
+ vit_ckpt_layer=0,
31
+ ):
32
+ """
33
+ Args:
34
+ med_config (str): path for the mixture of encoder-decoder model's configuration file
35
+ image_size (int): input image size
36
+ vit (str): model size of vision transformer
37
+ """
38
+ super().__init__()
39
+
40
+ self.visual_encoder, vision_width = create_vit(vit, image_size, vit_grad_ckpt, vit_ckpt_layer)
41
+ self.tokenizer = init_tokenizer()
42
+ med_config = BertConfig.from_json_file(med_config)
43
+ med_config.encoder_width = vision_width
44
+ self.text_encoder = BertModel(config=med_config, add_pooling_layer=False)
45
+
46
+ def forward(self, image, text, attention_mask, mode):
47
+ assert mode in ['image', 'text', 'multimodal'], "mode parameter must be image, text, or multimodal"
48
+ if mode == 'image':
49
+ # return image features
50
+ image_embeds = self.visual_encoder(image)
51
+ return image_embeds
52
+
53
+ elif mode == 'text':
54
+ # return text features
55
+ text_output = self.text_encoder(text, attention_mask=attention_mask, return_dict=True, mode='text')
56
+ return text_output.last_hidden_state
57
+
58
+ elif mode == 'multimodal':
59
+ # return multimodel features
60
+ image_embeds = self.visual_encoder(image)
61
+ image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device)
62
+
63
+ text[:, 0] = self.tokenizer.enc_token_id
64
+ output = self.text_encoder(text,
65
+ attention_mask=attention_mask,
66
+ encoder_hidden_states=image_embeds,
67
+ encoder_attention_mask=image_atts,
68
+ return_dict=True,
69
+ )
70
+ return output.last_hidden_state
71
+
72
+
73
+ class BLIP_Decoder(nn.Module):
74
+ def __init__(self,
75
+ med_config='models/blip_override/med_config.json',
76
+ image_size=384,
77
+ vit='base',
78
+ vit_grad_ckpt=False,
79
+ vit_ckpt_layer=0,
80
+ prompt='a picture of ',
81
+ ):
82
+ """
83
+ Args:
84
+ med_config (str): path for the mixture of encoder-decoder model's configuration file
85
+ image_size (int): input image size
86
+ vit (str): model size of vision transformer
87
+ """
88
+ super().__init__()
89
+
90
+ self.visual_encoder, vision_width = create_vit(vit, image_size, vit_grad_ckpt, vit_ckpt_layer)
91
+ self.tokenizer = init_tokenizer()
92
+ med_config = BertConfig.from_json_file(med_config)
93
+ med_config.encoder_width = vision_width
94
+ self.text_decoder = BertLMHeadModel(config=med_config)
95
+
96
+ self.prompt = prompt
97
+ self.prompt_length = len(self.tokenizer(self.prompt).input_ids) - 1
98
+
99
+ def forward(self, image, caption):
100
+
101
+ image_embeds = self.visual_encoder(image)
102
+ image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device)
103
+
104
+ text = self.tokenizer(caption, padding='longest', truncation=True, max_length=40, return_tensors="pt").to(
105
+ image.device)
106
+
107
+ text.input_ids[:, 0] = self.tokenizer.bos_token_id
108
+
109
+ decoder_targets = text.input_ids.masked_fill(text.input_ids == self.tokenizer.pad_token_id, -100)
110
+ decoder_targets[:, :self.prompt_length] = -100
111
+
112
+ decoder_output = self.text_decoder(text.input_ids,
113
+ attention_mask=text.attention_mask,
114
+ encoder_hidden_states=image_embeds,
115
+ encoder_attention_mask=image_atts,
116
+ labels=decoder_targets,
117
+ return_dict=True,
118
+ )
119
+ loss_lm = decoder_output.loss
120
+
121
+ return loss_lm
122
+
123
+ def generate(self, image, sample=False, num_beams=3, max_length=30, min_length=10, top_p=0.9,
124
+ repetition_penalty=1.0):
125
+ image_embeds = self.visual_encoder(image)
126
+
127
+ if not sample:
128
+ image_embeds = image_embeds.repeat_interleave(num_beams, dim=0)
129
+
130
+ image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device)
131
+ model_kwargs = {"encoder_hidden_states": image_embeds, "encoder_attention_mask": image_atts}
132
+
133
+ prompt = [self.prompt] * image.size(0)
134
+ input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(image.device)
135
+ input_ids[:, 0] = self.tokenizer.bos_token_id
136
+ input_ids = input_ids[:, :-1]
137
+
138
+ if sample:
139
+ # nucleus sampling
140
+ outputs = self.text_decoder.generate(input_ids=input_ids,
141
+ max_length=max_length,
142
+ min_length=min_length,
143
+ do_sample=True,
144
+ top_p=top_p,
145
+ num_return_sequences=1,
146
+ eos_token_id=self.tokenizer.sep_token_id,
147
+ pad_token_id=self.tokenizer.pad_token_id,
148
+ repetition_penalty=1.1,
149
+ **model_kwargs)
150
+ else:
151
+ # beam search
152
+ outputs = self.text_decoder.generate(input_ids=input_ids,
153
+ max_length=max_length,
154
+ min_length=min_length,
155
+ num_beams=num_beams,
156
+ eos_token_id=self.tokenizer.sep_token_id,
157
+ pad_token_id=self.tokenizer.pad_token_id,
158
+ repetition_penalty=repetition_penalty,
159
+ **model_kwargs)
160
+
161
+ captions = []
162
+ for output in outputs:
163
+ caption = self.tokenizer.decode(output, skip_special_tokens=True)
164
+ captions.append(caption[len(self.prompt):])
165
+ return captions
166
+
167
+
168
+ def blip_decoder(pretrained='', **kwargs):
169
+ model = BLIP_Decoder(**kwargs)
170
+ if pretrained:
171
+ model, msg = load_checkpoint(model, pretrained)
172
+ assert (len(msg.missing_keys) == 0)
173
+ return model
174
+
175
+
176
+ def blip_feature_extractor(pretrained='', **kwargs):
177
+ model = BLIP_Base(**kwargs)
178
+ if pretrained:
179
+ model, msg = load_checkpoint(model, pretrained)
180
+ assert (len(msg.missing_keys) == 0)
181
+ return model
182
+
183
+
184
+ def init_tokenizer():
185
+ tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
186
+ tokenizer.add_special_tokens({'bos_token': '[DEC]'})
187
+ tokenizer.add_special_tokens({'additional_special_tokens': ['[ENC]']})
188
+ tokenizer.enc_token_id = tokenizer.additional_special_tokens_ids[0]
189
+ return tokenizer
190
+
191
+
192
+ def create_vit(vit, image_size, use_grad_checkpointing=False, ckpt_layer=0, drop_path_rate=0):
193
+ assert vit in ['base', 'large'], "vit parameter must be base or large"
194
+ assert use_grad_checkpointing is False, 'grad checkpointing is not supported yet'
195
+ if vit == 'base':
196
+ vision_width = 768
197
+ visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=12,
198
+ num_heads=12, use_grad_checkpointing=use_grad_checkpointing,
199
+ ckpt_layer=ckpt_layer,
200
+ drop_path_rate=0 or drop_path_rate
201
+ )
202
+ elif vit == 'large':
203
+ vision_width = 1024
204
+ visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=24,
205
+ num_heads=16, use_grad_checkpointing=use_grad_checkpointing,
206
+ ckpt_layer=ckpt_layer,
207
+ drop_path_rate=0.1 or drop_path_rate
208
+ )
209
+ return visual_encoder, vision_width
210
+
211
+
212
+ def is_url(url_or_filename):
213
+ parsed = urlparse(url_or_filename)
214
+ return parsed.scheme in ("http", "https")
215
+
216
+
217
+ def load_checkpoint(model, url_or_filename):
218
+ if is_url(url_or_filename):
219
+ cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True)
220
+ checkpoint = torch.load(cached_file, map_location='cpu')
221
+ elif os.path.isfile(url_or_filename):
222
+ checkpoint = torch.load(url_or_filename, map_location='cpu')
223
+ else:
224
+ raise RuntimeError('checkpoint url or path is invalid')
225
+
226
+ state_dict = checkpoint['model']
227
+
228
+ state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'],
229
+ model.visual_encoder)
230
+ if 'visual_encoder_m.pos_embed' in model.state_dict().keys():
231
+ state_dict['visual_encoder_m.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder_m.pos_embed'],
232
+ model.visual_encoder_m)
233
+ for key in model.state_dict().keys():
234
+ if key in state_dict.keys():
235
+ if state_dict[key].shape != model.state_dict()[key].shape:
236
+ del state_dict[key]
237
+
238
+ msg = model.load_state_dict(state_dict, strict=False)
239
+ print('load checkpoint from %s' % url_or_filename)
240
+ return model, msg
models/blip_override/med.py ADDED
@@ -0,0 +1,955 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ * Copyright (c) 2022, salesforce.com, inc.
3
+ * All rights reserved.
4
+ * SPDX-License-Identifier: BSD-3-Clause
5
+ * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ * By Junnan Li
7
+ * Based on huggingface code base
8
+ * https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert
9
+ '''
10
+
11
+ import math
12
+ import os
13
+ import warnings
14
+ from dataclasses import dataclass
15
+ from typing import Optional, Tuple
16
+
17
+ import torch
18
+ from torch import Tensor, device, dtype, nn
19
+ import torch.utils.checkpoint
20
+ from torch import nn
21
+ from torch.nn import CrossEntropyLoss
22
+ import torch.nn.functional as F
23
+
24
+ from transformers.activations import ACT2FN
25
+ from transformers.file_utils import (
26
+ ModelOutput,
27
+ )
28
+ from transformers.modeling_outputs import (
29
+ BaseModelOutputWithPastAndCrossAttentions,
30
+ BaseModelOutputWithPoolingAndCrossAttentions,
31
+ CausalLMOutputWithCrossAttentions,
32
+ MaskedLMOutput,
33
+ MultipleChoiceModelOutput,
34
+ NextSentencePredictorOutput,
35
+ QuestionAnsweringModelOutput,
36
+ SequenceClassifierOutput,
37
+ TokenClassifierOutput,
38
+ )
39
+ from transformers.modeling_utils import (
40
+ PreTrainedModel,
41
+ apply_chunking_to_forward,
42
+ find_pruneable_heads_and_indices,
43
+ prune_linear_layer,
44
+ )
45
+ from transformers.utils import logging
46
+ from transformers.models.bert.configuration_bert import BertConfig
47
+
48
+
49
+ logger = logging.get_logger(__name__)
50
+
51
+
52
+ class BertEmbeddings(nn.Module):
53
+ """Construct the embeddings from word and position embeddings."""
54
+
55
+ def __init__(self, config):
56
+ super().__init__()
57
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
58
+ self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
59
+
60
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
61
+ # any TensorFlow checkpoint file
62
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
63
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
64
+
65
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
66
+ self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
67
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
68
+
69
+ self.config = config
70
+
71
+ def forward(
72
+ self, input_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
73
+ ):
74
+ if input_ids is not None:
75
+ input_shape = input_ids.size()
76
+ else:
77
+ input_shape = inputs_embeds.size()[:-1]
78
+
79
+ seq_length = input_shape[1]
80
+
81
+ if position_ids is None:
82
+ position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
83
+
84
+ if inputs_embeds is None:
85
+ inputs_embeds = self.word_embeddings(input_ids)
86
+
87
+ embeddings = inputs_embeds
88
+
89
+ if self.position_embedding_type == "absolute":
90
+ position_embeddings = self.position_embeddings(position_ids)
91
+ embeddings += position_embeddings
92
+ embeddings = self.LayerNorm(embeddings)
93
+ embeddings = self.dropout(embeddings)
94
+ return embeddings
95
+
96
+
97
+ class BertSelfAttention(nn.Module):
98
+ def __init__(self, config, is_cross_attention):
99
+ super().__init__()
100
+ self.config = config
101
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
102
+ raise ValueError(
103
+ "The hidden size (%d) is not a multiple of the number of attention "
104
+ "heads (%d)" % (config.hidden_size, config.num_attention_heads)
105
+ )
106
+
107
+ self.num_attention_heads = config.num_attention_heads
108
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
109
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
110
+
111
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
112
+ if is_cross_attention:
113
+ self.key = nn.Linear(config.encoder_width, self.all_head_size)
114
+ self.value = nn.Linear(config.encoder_width, self.all_head_size)
115
+ else:
116
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
117
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
118
+
119
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
120
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
121
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
122
+ self.max_position_embeddings = config.max_position_embeddings
123
+ self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
124
+ self.save_attention = False
125
+
126
+ def save_attn_gradients(self, attn_gradients):
127
+ self.attn_gradients = attn_gradients
128
+
129
+ def get_attn_gradients(self):
130
+ return self.attn_gradients
131
+
132
+ def save_attention_map(self, attention_map):
133
+ self.attention_map = attention_map
134
+
135
+ def get_attention_map(self):
136
+ return self.attention_map
137
+
138
+ def transpose_for_scores(self, x):
139
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
140
+ x = x.view(*new_x_shape)
141
+ return x.permute(0, 2, 1, 3)
142
+
143
+ def forward(
144
+ self,
145
+ hidden_states,
146
+ attention_mask=None,
147
+ head_mask=None,
148
+ encoder_hidden_states=None,
149
+ encoder_attention_mask=None,
150
+ past_key_value=None,
151
+ output_attentions=False,
152
+ ):
153
+ mixed_query_layer = self.query(hidden_states)
154
+
155
+ # If this is instantiated as a cross-attention module, the keys
156
+ # and values come from an encoder; the attention mask needs to be
157
+ # such that the encoder's padding tokens are not attended to.
158
+ is_cross_attention = encoder_hidden_states is not None
159
+
160
+ if is_cross_attention:
161
+ key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
162
+ value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
163
+ attention_mask = encoder_attention_mask
164
+ elif past_key_value is not None:
165
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
166
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
167
+ key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
168
+ value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
169
+ else:
170
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
171
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
172
+
173
+ query_layer = self.transpose_for_scores(mixed_query_layer)
174
+
175
+ past_key_value = (key_layer, value_layer)
176
+
177
+ # Take the dot product between "query" and "key" to get the raw attention scores.
178
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
179
+
180
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
181
+ seq_length = hidden_states.size()[1]
182
+ position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
183
+ position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
184
+ distance = position_ids_l - position_ids_r
185
+ positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
186
+ positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
187
+
188
+ if self.position_embedding_type == "relative_key":
189
+ relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
190
+ attention_scores = attention_scores + relative_position_scores
191
+ elif self.position_embedding_type == "relative_key_query":
192
+ relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
193
+ relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
194
+ attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
195
+
196
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
197
+ if attention_mask is not None:
198
+ # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
199
+ attention_scores = attention_scores + attention_mask
200
+
201
+ # Normalize the attention scores to probabilities.
202
+ attention_probs = nn.Softmax(dim=-1)(attention_scores)
203
+
204
+ if is_cross_attention and self.save_attention:
205
+ self.save_attention_map(attention_probs)
206
+ attention_probs.register_hook(self.save_attn_gradients)
207
+
208
+ # This is actually dropping out entire tokens to attend to, which might
209
+ # seem a bit unusual, but is taken from the original Transformer paper.
210
+ attention_probs_dropped = self.dropout(attention_probs)
211
+
212
+ # Mask heads if we want to
213
+ if head_mask is not None:
214
+ attention_probs_dropped = attention_probs_dropped * head_mask
215
+
216
+ context_layer = torch.matmul(attention_probs_dropped, value_layer)
217
+
218
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
219
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
220
+ context_layer = context_layer.view(*new_context_layer_shape)
221
+
222
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
223
+
224
+ outputs = outputs + (past_key_value,)
225
+ return outputs
226
+
227
+
228
+ class BertSelfOutput(nn.Module):
229
+ def __init__(self, config):
230
+ super().__init__()
231
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
232
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
233
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
234
+
235
+ def forward(self, hidden_states, input_tensor):
236
+ hidden_states = self.dense(hidden_states)
237
+ hidden_states = self.dropout(hidden_states)
238
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
239
+ return hidden_states
240
+
241
+
242
+ class BertAttention(nn.Module):
243
+ def __init__(self, config, is_cross_attention=False):
244
+ super().__init__()
245
+ self.self = BertSelfAttention(config, is_cross_attention)
246
+ self.output = BertSelfOutput(config)
247
+ self.pruned_heads = set()
248
+
249
+ def prune_heads(self, heads):
250
+ if len(heads) == 0:
251
+ return
252
+ heads, index = find_pruneable_heads_and_indices(
253
+ heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
254
+ )
255
+
256
+ # Prune linear layers
257
+ self.self.query = prune_linear_layer(self.self.query, index)
258
+ self.self.key = prune_linear_layer(self.self.key, index)
259
+ self.self.value = prune_linear_layer(self.self.value, index)
260
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
261
+
262
+ # Update hyper params and store pruned heads
263
+ self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
264
+ self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
265
+ self.pruned_heads = self.pruned_heads.union(heads)
266
+
267
+ def forward(
268
+ self,
269
+ hidden_states,
270
+ attention_mask=None,
271
+ head_mask=None,
272
+ encoder_hidden_states=None,
273
+ encoder_attention_mask=None,
274
+ past_key_value=None,
275
+ output_attentions=False,
276
+ ):
277
+ self_outputs = self.self(
278
+ hidden_states,
279
+ attention_mask,
280
+ head_mask,
281
+ encoder_hidden_states,
282
+ encoder_attention_mask,
283
+ past_key_value,
284
+ output_attentions,
285
+ )
286
+ attention_output = self.output(self_outputs[0], hidden_states)
287
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
288
+ return outputs
289
+
290
+
291
+ class BertIntermediate(nn.Module):
292
+ def __init__(self, config):
293
+ super().__init__()
294
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
295
+ if isinstance(config.hidden_act, str):
296
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
297
+ else:
298
+ self.intermediate_act_fn = config.hidden_act
299
+
300
+ def forward(self, hidden_states):
301
+ hidden_states = self.dense(hidden_states)
302
+ hidden_states = self.intermediate_act_fn(hidden_states)
303
+ return hidden_states
304
+
305
+
306
+ class BertOutput(nn.Module):
307
+ def __init__(self, config):
308
+ super().__init__()
309
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
310
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
311
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
312
+
313
+ def forward(self, hidden_states, input_tensor):
314
+ hidden_states = self.dense(hidden_states)
315
+ hidden_states = self.dropout(hidden_states)
316
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
317
+ return hidden_states
318
+
319
+
320
+ class BertLayer(nn.Module):
321
+ def __init__(self, config, layer_num):
322
+ super().__init__()
323
+ self.config = config
324
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
325
+ self.seq_len_dim = 1
326
+ self.attention = BertAttention(config)
327
+ self.layer_num = layer_num
328
+ if self.config.add_cross_attention:
329
+ self.crossattention = BertAttention(config, is_cross_attention=self.config.add_cross_attention)
330
+ self.intermediate = BertIntermediate(config)
331
+ self.output = BertOutput(config)
332
+
333
+ def forward(
334
+ self,
335
+ hidden_states,
336
+ attention_mask=None,
337
+ head_mask=None,
338
+ encoder_hidden_states=None,
339
+ encoder_attention_mask=None,
340
+ past_key_value=None,
341
+ output_attentions=False,
342
+ mode=None,
343
+ ):
344
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
345
+ self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
346
+ self_attention_outputs = self.attention(
347
+ hidden_states,
348
+ attention_mask,
349
+ head_mask,
350
+ output_attentions=output_attentions,
351
+ past_key_value=self_attn_past_key_value,
352
+ )
353
+ attention_output = self_attention_outputs[0]
354
+
355
+ outputs = self_attention_outputs[1:-1]
356
+ present_key_value = self_attention_outputs[-1]
357
+
358
+ if mode=='multimodal':
359
+ assert encoder_hidden_states is not None, "encoder_hidden_states must be given for cross-attention layers"
360
+
361
+ cross_attention_outputs = self.crossattention(
362
+ attention_output,
363
+ attention_mask,
364
+ head_mask,
365
+ encoder_hidden_states,
366
+ encoder_attention_mask,
367
+ output_attentions=output_attentions,
368
+ )
369
+ attention_output = cross_attention_outputs[0]
370
+ outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
371
+ layer_output = apply_chunking_to_forward(
372
+ self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
373
+ )
374
+ outputs = (layer_output,) + outputs
375
+
376
+ outputs = outputs + (present_key_value,)
377
+
378
+ return outputs
379
+
380
+ def feed_forward_chunk(self, attention_output):
381
+ intermediate_output = self.intermediate(attention_output)
382
+ layer_output = self.output(intermediate_output, attention_output)
383
+ return layer_output
384
+
385
+
386
+ class BertEncoder(nn.Module):
387
+ def __init__(self, config):
388
+ super().__init__()
389
+ self.config = config
390
+ self.layer = nn.ModuleList([BertLayer(config,i) for i in range(config.num_hidden_layers)])
391
+ self.gradient_checkpointing = False
392
+
393
+ def forward(
394
+ self,
395
+ hidden_states,
396
+ attention_mask=None,
397
+ head_mask=None,
398
+ encoder_hidden_states=None,
399
+ encoder_attention_mask=None,
400
+ past_key_values=None,
401
+ use_cache=None,
402
+ output_attentions=False,
403
+ output_hidden_states=False,
404
+ return_dict=True,
405
+ mode='multimodal',
406
+ ):
407
+ all_hidden_states = () if output_hidden_states else None
408
+ all_self_attentions = () if output_attentions else None
409
+ all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
410
+
411
+ next_decoder_cache = () if use_cache else None
412
+
413
+ for i in range(self.config.num_hidden_layers):
414
+ layer_module = self.layer[i]
415
+ if output_hidden_states:
416
+ all_hidden_states = all_hidden_states + (hidden_states,)
417
+
418
+ layer_head_mask = head_mask[i] if head_mask is not None else None
419
+ past_key_value = past_key_values[i] if past_key_values is not None else None
420
+
421
+ if self.gradient_checkpointing and self.training:
422
+
423
+ if use_cache:
424
+ logger.warn(
425
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
426
+ )
427
+ use_cache = False
428
+
429
+ def create_custom_forward(module):
430
+ def custom_forward(*inputs):
431
+ return module(*inputs, past_key_value, output_attentions)
432
+
433
+ return custom_forward
434
+
435
+ layer_outputs = torch.utils.checkpoint.checkpoint(
436
+ create_custom_forward(layer_module),
437
+ hidden_states,
438
+ attention_mask,
439
+ layer_head_mask,
440
+ encoder_hidden_states,
441
+ encoder_attention_mask,
442
+ mode=mode,
443
+ )
444
+ else:
445
+ layer_outputs = layer_module(
446
+ hidden_states,
447
+ attention_mask,
448
+ layer_head_mask,
449
+ encoder_hidden_states,
450
+ encoder_attention_mask,
451
+ past_key_value,
452
+ output_attentions,
453
+ mode=mode,
454
+ )
455
+
456
+ hidden_states = layer_outputs[0]
457
+ if use_cache:
458
+ next_decoder_cache += (layer_outputs[-1],)
459
+ if output_attentions:
460
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
461
+
462
+ if output_hidden_states:
463
+ all_hidden_states = all_hidden_states + (hidden_states,)
464
+
465
+ if not return_dict:
466
+ return tuple(
467
+ v
468
+ for v in [
469
+ hidden_states,
470
+ next_decoder_cache,
471
+ all_hidden_states,
472
+ all_self_attentions,
473
+ all_cross_attentions,
474
+ ]
475
+ if v is not None
476
+ )
477
+ return BaseModelOutputWithPastAndCrossAttentions(
478
+ last_hidden_state=hidden_states,
479
+ past_key_values=next_decoder_cache,
480
+ hidden_states=all_hidden_states,
481
+ attentions=all_self_attentions,
482
+ cross_attentions=all_cross_attentions,
483
+ )
484
+
485
+
486
+ class BertPooler(nn.Module):
487
+ def __init__(self, config):
488
+ super().__init__()
489
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
490
+ self.activation = nn.Tanh()
491
+
492
+ def forward(self, hidden_states):
493
+ # We "pool" the model by simply taking the hidden state corresponding
494
+ # to the first token.
495
+ first_token_tensor = hidden_states[:, 0]
496
+ pooled_output = self.dense(first_token_tensor)
497
+ pooled_output = self.activation(pooled_output)
498
+ return pooled_output
499
+
500
+
501
+ class BertPredictionHeadTransform(nn.Module):
502
+ def __init__(self, config):
503
+ super().__init__()
504
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
505
+ if isinstance(config.hidden_act, str):
506
+ self.transform_act_fn = ACT2FN[config.hidden_act]
507
+ else:
508
+ self.transform_act_fn = config.hidden_act
509
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
510
+
511
+ def forward(self, hidden_states):
512
+ hidden_states = self.dense(hidden_states)
513
+ hidden_states = self.transform_act_fn(hidden_states)
514
+ hidden_states = self.LayerNorm(hidden_states)
515
+ return hidden_states
516
+
517
+
518
+ class BertLMPredictionHead(nn.Module):
519
+ def __init__(self, config):
520
+ super().__init__()
521
+ self.transform = BertPredictionHeadTransform(config)
522
+
523
+ # The output weights are the same as the input embeddings, but there is
524
+ # an output-only bias for each token.
525
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
526
+
527
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
528
+
529
+ # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
530
+ self.decoder.bias = self.bias
531
+
532
+ def forward(self, hidden_states):
533
+ hidden_states = self.transform(hidden_states)
534
+ hidden_states = self.decoder(hidden_states)
535
+ return hidden_states
536
+
537
+
538
+ class BertOnlyMLMHead(nn.Module):
539
+ def __init__(self, config):
540
+ super().__init__()
541
+ self.predictions = BertLMPredictionHead(config)
542
+
543
+ def forward(self, sequence_output):
544
+ prediction_scores = self.predictions(sequence_output)
545
+ return prediction_scores
546
+
547
+
548
+ class BertPreTrainedModel(PreTrainedModel):
549
+ """
550
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
551
+ models.
552
+ """
553
+
554
+ config_class = BertConfig
555
+ base_model_prefix = "bert"
556
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
557
+
558
+ def _init_weights(self, module):
559
+ """ Initialize the weights """
560
+ if isinstance(module, (nn.Linear, nn.Embedding)):
561
+ # Slightly different from the TF version which uses truncated_normal for initialization
562
+ # cf https://github.com/pytorch/pytorch/pull/5617
563
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
564
+ elif isinstance(module, nn.LayerNorm):
565
+ module.bias.data.zero_()
566
+ module.weight.data.fill_(1.0)
567
+ if isinstance(module, nn.Linear) and module.bias is not None:
568
+ module.bias.data.zero_()
569
+
570
+
571
+ class BertModel(BertPreTrainedModel):
572
+ """
573
+ The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
574
+ cross-attention is added between the self-attention layers, following the architecture described in `Attention is
575
+ all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
576
+ Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
577
+ argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
578
+ input to the forward pass.
579
+ """
580
+
581
+ def __init__(self, config, add_pooling_layer=True):
582
+ super().__init__(config)
583
+ self.config = config
584
+
585
+ self.embeddings = BertEmbeddings(config)
586
+
587
+ self.encoder = BertEncoder(config)
588
+
589
+ self.pooler = BertPooler(config) if add_pooling_layer else None
590
+
591
+ self.init_weights()
592
+
593
+
594
+ def get_input_embeddings(self):
595
+ return self.embeddings.word_embeddings
596
+
597
+ def set_input_embeddings(self, value):
598
+ self.embeddings.word_embeddings = value
599
+
600
+ def _prune_heads(self, heads_to_prune):
601
+ """
602
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
603
+ class PreTrainedModel
604
+ """
605
+ for layer, heads in heads_to_prune.items():
606
+ self.encoder.layer[layer].attention.prune_heads(heads)
607
+
608
+
609
+ def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple[int], device: device, is_decoder: bool) -> Tensor:
610
+ """
611
+ Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
612
+
613
+ Arguments:
614
+ attention_mask (:obj:`torch.Tensor`):
615
+ Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
616
+ input_shape (:obj:`Tuple[int]`):
617
+ The shape of the input to the model.
618
+ device: (:obj:`torch.device`):
619
+ The device of the input to the model.
620
+
621
+ Returns:
622
+ :obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`.
623
+ """
624
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
625
+ # ourselves in which case we just need to make it broadcastable to all heads.
626
+ if attention_mask.dim() == 3:
627
+ extended_attention_mask = attention_mask[:, None, :, :]
628
+ elif attention_mask.dim() == 2:
629
+ # Provided a padding mask of dimensions [batch_size, seq_length]
630
+ # - if the model is a decoder, apply a causal mask in addition to the padding mask
631
+ # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
632
+ if is_decoder:
633
+ batch_size, seq_length = input_shape
634
+
635
+ seq_ids = torch.arange(seq_length, device=device)
636
+ causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
637
+ # in case past_key_values are used we need to add a prefix ones mask to the causal mask
638
+ # causal and attention masks must have same type with pytorch version < 1.3
639
+ causal_mask = causal_mask.to(attention_mask.dtype)
640
+
641
+ if causal_mask.shape[1] < attention_mask.shape[1]:
642
+ prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
643
+ causal_mask = torch.cat(
644
+ [
645
+ torch.ones((batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype),
646
+ causal_mask,
647
+ ],
648
+ axis=-1,
649
+ )
650
+
651
+ extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
652
+ else:
653
+ extended_attention_mask = attention_mask[:, None, None, :]
654
+ else:
655
+ raise ValueError(
656
+ "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
657
+ input_shape, attention_mask.shape
658
+ )
659
+ )
660
+
661
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
662
+ # masked positions, this operation will create a tensor which is 0.0 for
663
+ # positions we want to attend and -10000.0 for masked positions.
664
+ # Since we are adding it to the raw scores before the softmax, this is
665
+ # effectively the same as removing these entirely.
666
+ extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
667
+ extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
668
+ return extended_attention_mask
669
+
670
+ def forward(
671
+ self,
672
+ input_ids=None,
673
+ attention_mask=None,
674
+ position_ids=None,
675
+ head_mask=None,
676
+ inputs_embeds=None,
677
+ encoder_embeds=None,
678
+ encoder_hidden_states=None,
679
+ encoder_attention_mask=None,
680
+ past_key_values=None,
681
+ use_cache=None,
682
+ output_attentions=None,
683
+ output_hidden_states=None,
684
+ return_dict=None,
685
+ is_decoder=False,
686
+ mode='multimodal',
687
+ ):
688
+ r"""
689
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
690
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
691
+ the model is configured as a decoder.
692
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
693
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
694
+ the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
695
+ - 1 for tokens that are **not masked**,
696
+ - 0 for tokens that are **masked**.
697
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
698
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
699
+ If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
700
+ (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
701
+ instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
702
+ use_cache (:obj:`bool`, `optional`):
703
+ If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
704
+ decoding (see :obj:`past_key_values`).
705
+ """
706
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
707
+ output_hidden_states = (
708
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
709
+ )
710
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
711
+
712
+ if is_decoder:
713
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
714
+ else:
715
+ use_cache = False
716
+
717
+ if input_ids is not None and inputs_embeds is not None:
718
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
719
+ elif input_ids is not None:
720
+ input_shape = input_ids.size()
721
+ batch_size, seq_length = input_shape
722
+ device = input_ids.device
723
+ elif inputs_embeds is not None:
724
+ input_shape = inputs_embeds.size()[:-1]
725
+ batch_size, seq_length = input_shape
726
+ device = inputs_embeds.device
727
+ elif encoder_embeds is not None:
728
+ input_shape = encoder_embeds.size()[:-1]
729
+ batch_size, seq_length = input_shape
730
+ device = encoder_embeds.device
731
+ else:
732
+ raise ValueError("You have to specify either input_ids or inputs_embeds or encoder_embeds")
733
+
734
+ # past_key_values_length
735
+ past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
736
+
737
+ if attention_mask is None:
738
+ attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
739
+
740
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
741
+ # ourselves in which case we just need to make it broadcastable to all heads.
742
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape,
743
+ device, is_decoder)
744
+
745
+ # If a 2D or 3D attention mask is provided for the cross-attention
746
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
747
+ if encoder_hidden_states is not None:
748
+ if type(encoder_hidden_states) == list:
749
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size()
750
+ else:
751
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
752
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
753
+
754
+ if type(encoder_attention_mask) == list:
755
+ encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask]
756
+ elif encoder_attention_mask is None:
757
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
758
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
759
+ else:
760
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
761
+ else:
762
+ encoder_extended_attention_mask = None
763
+
764
+ # Prepare head mask if needed
765
+ # 1.0 in head_mask indicate we keep the head
766
+ # attention_probs has shape bsz x n_heads x N x N
767
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
768
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
769
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
770
+
771
+ if encoder_embeds is None:
772
+ embedding_output = self.embeddings(
773
+ input_ids=input_ids,
774
+ position_ids=position_ids,
775
+ inputs_embeds=inputs_embeds,
776
+ past_key_values_length=past_key_values_length,
777
+ )
778
+ else:
779
+ embedding_output = encoder_embeds
780
+
781
+ encoder_outputs = self.encoder(
782
+ embedding_output,
783
+ attention_mask=extended_attention_mask,
784
+ head_mask=head_mask,
785
+ encoder_hidden_states=encoder_hidden_states,
786
+ encoder_attention_mask=encoder_extended_attention_mask,
787
+ past_key_values=past_key_values,
788
+ use_cache=use_cache,
789
+ output_attentions=output_attentions,
790
+ output_hidden_states=output_hidden_states,
791
+ return_dict=return_dict,
792
+ mode=mode,
793
+ )
794
+ sequence_output = encoder_outputs[0]
795
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
796
+
797
+ if not return_dict:
798
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
799
+
800
+ return BaseModelOutputWithPoolingAndCrossAttentions(
801
+ last_hidden_state=sequence_output,
802
+ pooler_output=pooled_output,
803
+ past_key_values=encoder_outputs.past_key_values,
804
+ hidden_states=encoder_outputs.hidden_states,
805
+ attentions=encoder_outputs.attentions,
806
+ cross_attentions=encoder_outputs.cross_attentions,
807
+ )
808
+
809
+
810
+
811
+ class BertLMHeadModel(BertPreTrainedModel):
812
+
813
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
814
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
815
+
816
+ def __init__(self, config):
817
+ super().__init__(config)
818
+
819
+ self.bert = BertModel(config, add_pooling_layer=False)
820
+ self.cls = BertOnlyMLMHead(config)
821
+
822
+ self.init_weights()
823
+
824
+ def get_output_embeddings(self):
825
+ return self.cls.predictions.decoder
826
+
827
+ def set_output_embeddings(self, new_embeddings):
828
+ self.cls.predictions.decoder = new_embeddings
829
+
830
+ def forward(
831
+ self,
832
+ input_ids=None,
833
+ attention_mask=None,
834
+ position_ids=None,
835
+ head_mask=None,
836
+ inputs_embeds=None,
837
+ encoder_hidden_states=None,
838
+ encoder_attention_mask=None,
839
+ labels=None,
840
+ past_key_values=None,
841
+ use_cache=None,
842
+ output_attentions=None,
843
+ output_hidden_states=None,
844
+ return_dict=None,
845
+ return_logits=False,
846
+ is_decoder=True,
847
+ reduction='mean',
848
+ mode='multimodal',
849
+ ):
850
+ r"""
851
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
852
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
853
+ the model is configured as a decoder.
854
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
855
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
856
+ the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
857
+ - 1 for tokens that are **not masked**,
858
+ - 0 for tokens that are **masked**.
859
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
860
+ Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
861
+ ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are
862
+ ignored (masked), the loss is only computed for the tokens with labels n ``[0, ..., config.vocab_size]``
863
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
864
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
865
+ If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
866
+ (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
867
+ instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
868
+ use_cache (:obj:`bool`, `optional`):
869
+ If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
870
+ decoding (see :obj:`past_key_values`).
871
+ Returns:
872
+ Example::
873
+ >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig
874
+ >>> import torch
875
+ >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
876
+ >>> config = BertConfig.from_pretrained("bert-base-cased")
877
+ >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)
878
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
879
+ >>> outputs = model(**inputs)
880
+ >>> prediction_logits = outputs.logits
881
+ """
882
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
883
+ if labels is not None:
884
+ use_cache = False
885
+
886
+ outputs = self.bert(
887
+ input_ids,
888
+ attention_mask=attention_mask,
889
+ position_ids=position_ids,
890
+ head_mask=head_mask,
891
+ inputs_embeds=inputs_embeds,
892
+ encoder_hidden_states=encoder_hidden_states,
893
+ encoder_attention_mask=encoder_attention_mask,
894
+ past_key_values=past_key_values,
895
+ use_cache=use_cache,
896
+ output_attentions=output_attentions,
897
+ output_hidden_states=output_hidden_states,
898
+ return_dict=return_dict,
899
+ is_decoder=is_decoder,
900
+ mode=mode,
901
+ )
902
+
903
+ sequence_output = outputs[0]
904
+ prediction_scores = self.cls(sequence_output)
905
+
906
+ if return_logits:
907
+ return prediction_scores[:, :-1, :].contiguous()
908
+
909
+ lm_loss = None
910
+ if labels is not None:
911
+ # we are doing next-token prediction; shift prediction scores and input ids by one
912
+ shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
913
+ labels = labels[:, 1:].contiguous()
914
+ loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1)
915
+ lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
916
+ if reduction=='none':
917
+ lm_loss = lm_loss.view(prediction_scores.size(0),-1).sum(1)
918
+
919
+ if not return_dict:
920
+ output = (prediction_scores,) + outputs[2:]
921
+ return ((lm_loss,) + output) if lm_loss is not None else output
922
+
923
+ return CausalLMOutputWithCrossAttentions(
924
+ loss=lm_loss,
925
+ logits=prediction_scores,
926
+ past_key_values=outputs.past_key_values,
927
+ hidden_states=outputs.hidden_states,
928
+ attentions=outputs.attentions,
929
+ cross_attentions=outputs.cross_attentions,
930
+ )
931
+
932
+ def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
933
+ input_shape = input_ids.shape
934
+ # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
935
+ if attention_mask is None:
936
+ attention_mask = input_ids.new_ones(input_shape)
937
+
938
+ # cut decoder_input_ids if past is used
939
+ if past is not None:
940
+ input_ids = input_ids[:, -1:]
941
+
942
+ return {
943
+ "input_ids": input_ids,
944
+ "attention_mask": attention_mask,
945
+ "past_key_values": past,
946
+ "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None),
947
+ "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None),
948
+ "is_decoder": True,
949
+ }
950
+
951
+ def _reorder_cache(self, past, beam_idx):
952
+ reordered_past = ()
953
+ for layer_past in past:
954
+ reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
955
+ return reordered_past
models/blip_override/med_config.json ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "BertModel"
4
+ ],
5
+ "attention_probs_dropout_prob": 0.1,
6
+ "hidden_act": "gelu",
7
+ "hidden_dropout_prob": 0.1,
8
+ "hidden_size": 768,
9
+ "initializer_range": 0.02,
10
+ "intermediate_size": 3072,
11
+ "layer_norm_eps": 1e-12,
12
+ "max_position_embeddings": 512,
13
+ "model_type": "bert",
14
+ "num_attention_heads": 12,
15
+ "num_hidden_layers": 12,
16
+ "pad_token_id": 0,
17
+ "type_vocab_size": 2,
18
+ "vocab_size": 30524,
19
+ "encoder_width": 768,
20
+ "add_cross_attention": true
21
+ }
models/blip_override/vit.py ADDED
@@ -0,0 +1,302 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ * Copyright (c) 2022, salesforce.com, inc.
3
+ * All rights reserved.
4
+ * SPDX-License-Identifier: BSD-3-Clause
5
+ * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ * By Junnan Li
7
+ * Based on timm code base
8
+ * https://github.com/rwightman/pytorch-image-models/tree/master/timm
9
+ '''
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+ from functools import partial
15
+
16
+ from timm.models.vision_transformer import _cfg, PatchEmbed
17
+ from timm.models.registry import register_model
18
+ from timm.models.layers import trunc_normal_, DropPath
19
+ from timm.models.helpers import named_apply, adapt_input_conv
20
+
21
+
22
+ class Mlp(nn.Module):
23
+ """ MLP as used in Vision Transformer, MLP-Mixer and related networks
24
+ """
25
+
26
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
27
+ super().__init__()
28
+ out_features = out_features or in_features
29
+ hidden_features = hidden_features or in_features
30
+ self.fc1 = nn.Linear(in_features, hidden_features)
31
+ self.act = act_layer()
32
+ self.fc2 = nn.Linear(hidden_features, out_features)
33
+ self.drop = nn.Dropout(drop)
34
+
35
+ def forward(self, x):
36
+ x = self.fc1(x)
37
+ x = self.act(x)
38
+ x = self.drop(x)
39
+ x = self.fc2(x)
40
+ x = self.drop(x)
41
+ return x
42
+
43
+
44
+ class Attention(nn.Module):
45
+ def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
46
+ super().__init__()
47
+ self.num_heads = num_heads
48
+ head_dim = dim // num_heads
49
+ # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
50
+ self.scale = qk_scale or head_dim ** -0.5
51
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
52
+ self.attn_drop = nn.Dropout(attn_drop)
53
+ self.proj = nn.Linear(dim, dim)
54
+ self.proj_drop = nn.Dropout(proj_drop)
55
+ self.attn_gradients = None
56
+ self.attention_map = None
57
+
58
+ def save_attn_gradients(self, attn_gradients):
59
+ self.attn_gradients = attn_gradients
60
+
61
+ def get_attn_gradients(self):
62
+ return self.attn_gradients
63
+
64
+ def save_attention_map(self, attention_map):
65
+ self.attention_map = attention_map
66
+
67
+ def get_attention_map(self):
68
+ return self.attention_map
69
+
70
+ def forward(self, x, register_hook=False):
71
+ B, N, C = x.shape
72
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
73
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
74
+
75
+ attn = (q @ k.transpose(-2, -1)) * self.scale
76
+ attn = attn.softmax(dim=-1)
77
+ attn = self.attn_drop(attn)
78
+
79
+ if register_hook:
80
+ self.save_attention_map(attn)
81
+ attn.register_hook(self.save_attn_gradients)
82
+
83
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
84
+ x = self.proj(x)
85
+ x = self.proj_drop(x)
86
+ return x
87
+
88
+
89
+ class Block(nn.Module):
90
+
91
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
92
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_grad_checkpointing=False):
93
+ super().__init__()
94
+ self.norm1 = norm_layer(dim)
95
+ self.attn = Attention(
96
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
97
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
98
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
99
+ self.norm2 = norm_layer(dim)
100
+ mlp_hidden_dim = int(dim * mlp_ratio)
101
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
102
+
103
+ def forward(self, x, register_hook=False):
104
+ x = x + self.drop_path(self.attn(self.norm1(x), register_hook=register_hook))
105
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
106
+ return x
107
+
108
+
109
+ class VisionTransformer(nn.Module):
110
+ """ Vision Transformer
111
+ A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` -
112
+ https://arxiv.org/abs/2010.11929
113
+ """
114
+
115
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
116
+ num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None,
117
+ drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=None,
118
+ use_grad_checkpointing=False, ckpt_layer=0):
119
+ """
120
+ Args:
121
+ img_size (int, tuple): input image size
122
+ patch_size (int, tuple): patch size
123
+ in_chans (int): number of input channels
124
+ num_classes (int): number of classes for classification head
125
+ embed_dim (int): embedding dimension
126
+ depth (int): depth of transformer
127
+ num_heads (int): number of attention heads
128
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim
129
+ qkv_bias (bool): enable bias for qkv if True
130
+ qk_scale (float): override default qk scale of head_dim ** -0.5 if set
131
+ representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
132
+ drop_rate (float): dropout rate
133
+ attn_drop_rate (float): attention dropout rate
134
+ drop_path_rate (float): stochastic depth rate
135
+ norm_layer (nn.Module): normalization layer
136
+ """
137
+ super().__init__()
138
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
139
+ norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
140
+
141
+ self.patch_embed = PatchEmbed(
142
+ img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
143
+
144
+ num_patches = self.patch_embed.num_patches
145
+
146
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
147
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
148
+ self.pos_drop = nn.Dropout(p=drop_rate)
149
+
150
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
151
+ self.blocks = nn.ModuleList([
152
+ Block(
153
+ dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
154
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
155
+ use_grad_checkpointing=(use_grad_checkpointing and i >= depth - ckpt_layer)
156
+ )
157
+ for i in range(depth)])
158
+ self.norm = norm_layer(embed_dim)
159
+
160
+ trunc_normal_(self.pos_embed, std=.02)
161
+ trunc_normal_(self.cls_token, std=.02)
162
+ self.apply(self._init_weights)
163
+
164
+ def _init_weights(self, m):
165
+ if isinstance(m, nn.Linear):
166
+ trunc_normal_(m.weight, std=.02)
167
+ if isinstance(m, nn.Linear) and m.bias is not None:
168
+ nn.init.constant_(m.bias, 0)
169
+ elif isinstance(m, nn.LayerNorm):
170
+ nn.init.constant_(m.bias, 0)
171
+ nn.init.constant_(m.weight, 1.0)
172
+
173
+ @torch.jit.ignore
174
+ def no_weight_decay(self):
175
+ return {'pos_embed', 'cls_token'}
176
+
177
+ def forward(self, x, register_blk=-1):
178
+ B = x.shape[0]
179
+ x = self.patch_embed(x)
180
+
181
+ cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
182
+ x = torch.cat((cls_tokens, x), dim=1)
183
+
184
+ x = x + self.pos_embed[:, :x.size(1), :]
185
+ x = self.pos_drop(x)
186
+
187
+ for i, blk in enumerate(self.blocks):
188
+ x = blk(x, register_blk == i)
189
+ x = self.norm(x)
190
+
191
+ return x
192
+
193
+ @torch.jit.ignore()
194
+ def load_pretrained(self, checkpoint_path, prefix=''):
195
+ _load_weights(self, checkpoint_path, prefix)
196
+
197
+
198
+ @torch.no_grad()
199
+ def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ''):
200
+ """ Load weights from .npz checkpoints for official Google Brain Flax implementation
201
+ """
202
+ import numpy as np
203
+
204
+ def _n2p(w, t=True):
205
+ if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:
206
+ w = w.flatten()
207
+ if t:
208
+ if w.ndim == 4:
209
+ w = w.transpose([3, 2, 0, 1])
210
+ elif w.ndim == 3:
211
+ w = w.transpose([2, 0, 1])
212
+ elif w.ndim == 2:
213
+ w = w.transpose([1, 0])
214
+ return torch.from_numpy(w)
215
+
216
+ w = np.load(checkpoint_path)
217
+ if not prefix and 'opt/target/embedding/kernel' in w:
218
+ prefix = 'opt/target/'
219
+
220
+ if hasattr(model.patch_embed, 'backbone'):
221
+ # hybrid
222
+ backbone = model.patch_embed.backbone
223
+ stem_only = not hasattr(backbone, 'stem')
224
+ stem = backbone if stem_only else backbone.stem
225
+ stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel'])))
226
+ stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale']))
227
+ stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias']))
228
+ if not stem_only:
229
+ for i, stage in enumerate(backbone.stages):
230
+ for j, block in enumerate(stage.blocks):
231
+ bp = f'{prefix}block{i + 1}/unit{j + 1}/'
232
+ for r in range(3):
233
+ getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel']))
234
+ getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale']))
235
+ getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias']))
236
+ if block.downsample is not None:
237
+ block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel']))
238
+ block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale']))
239
+ block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias']))
240
+ embed_conv_w = _n2p(w[f'{prefix}embedding/kernel'])
241
+ else:
242
+ embed_conv_w = adapt_input_conv(
243
+ model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel']))
244
+ model.patch_embed.proj.weight.copy_(embed_conv_w)
245
+ model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias']))
246
+ model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False))
247
+ pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False)
248
+ if pos_embed_w.shape != model.pos_embed.shape:
249
+ pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights
250
+ pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size)
251
+ model.pos_embed.copy_(pos_embed_w)
252
+ model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
253
+ model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))
254
+ # if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]:
255
+ # model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
256
+ # model.head.bias.copy_(_n2p(w[f'{prefix}head/bias']))
257
+ # if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w:
258
+ # model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel']))
259
+ # model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias']))
260
+ for i, block in enumerate(model.blocks.children()):
261
+ block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
262
+ mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/'
263
+ block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
264
+ block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
265
+ block.attn.qkv.weight.copy_(torch.cat([
266
+ _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')]))
267
+ block.attn.qkv.bias.copy_(torch.cat([
268
+ _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))
269
+ block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
270
+ block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
271
+ for r in range(2):
272
+ getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel']))
273
+ getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias']))
274
+ block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale']))
275
+ block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias']))
276
+
277
+
278
+ def interpolate_pos_embed(pos_embed_checkpoint, visual_encoder):
279
+ # interpolate position embedding
280
+ embedding_size = pos_embed_checkpoint.shape[-1]
281
+ num_patches = visual_encoder.patch_embed.num_patches
282
+ num_extra_tokens = visual_encoder.pos_embed.shape[-2] - num_patches
283
+ # height (== width) for the checkpoint position embedding
284
+ orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
285
+ # height (== width) for the new position embedding
286
+ new_size = int(num_patches ** 0.5)
287
+
288
+ if orig_size != new_size:
289
+ # class_token and dist_token are kept unchanged
290
+ extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
291
+ # only the position tokens are interpolated
292
+ pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
293
+ pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
294
+ pos_tokens = torch.nn.functional.interpolate(
295
+ pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
296
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
297
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
298
+ print('reshape position embedding from %d to %d' % (orig_size ** 2, new_size ** 2))
299
+
300
+ return new_pos_embed
301
+ else:
302
+ return pos_embed_checkpoint
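A minimal usage sketch (editor's illustration; it assumes this file is importable as `models.blip_override.vit`, and the checkpoint path is hypothetical): `interpolate_pos_embed` lets a position embedding trained at 224 px be resized for an encoder built at 384 px before the state dict is loaded.

import torch
from models.blip_override.vit import VisionTransformer, interpolate_pos_embed  # assumed module path

encoder = VisionTransformer(img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12)
state_dict = torch.load('vit_base_224.pth', map_location='cpu')                # hypothetical checkpoint
state_dict['pos_embed'] = interpolate_pos_embed(state_dict['pos_embed'], encoder)
missing, unexpected = encoder.load_state_dict(state_dict, strict=False)        # head keys may not match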
models/diffusers_override/attention.py ADDED
@@ -0,0 +1,669 @@
1
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import math
15
+ from dataclasses import dataclass
16
+ from typing import Optional
17
+
18
+ import torch
19
+ import torch.nn.functional as F
20
+ from torch import nn
21
+
22
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
23
+ from diffusers.modeling_utils import ModelMixin
24
+ from diffusers.models.embeddings import ImagePositionalEmbeddings
25
+ from diffusers.utils import BaseOutput
26
+ from diffusers.utils.import_utils import is_xformers_available
27
+
28
+
29
+ @dataclass
30
+ class Transformer2DModelOutput(BaseOutput):
31
+ """
32
+ Args:
33
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
34
+ Hidden states conditioned on `encoder_hidden_states` input. If discrete, returns probability distributions
35
+ for the unnoised latent pixels.
36
+ """
37
+
38
+ sample: torch.FloatTensor
39
+
40
+
41
+ if is_xformers_available():
42
+ import xformers
43
+ import xformers.ops
44
+ else:
45
+ xformers = None
46
+
47
+
48
+ class Transformer2DModel(ModelMixin, ConfigMixin):
49
+ """
50
+ Transformer model for image-like data. Takes either discrete (classes of vector embeddings) or continuous (actual
51
+ embeddings) inputs.
52
+
53
+ When input is continuous: First, project the input (aka embedding) and reshape to b, t, d. Then apply standard
54
+ transformer action. Finally, reshape to image.
55
+
56
+ When input is discrete: First, input (classes of latent pixels) is converted to embeddings and has positional
57
+ embeddings applied, see `ImagePositionalEmbeddings`. Then apply standard transformer action. Finally, predict
58
+ classes of unnoised image.
59
+
60
+ Note that it is assumed one of the input classes is the masked latent pixel. The predicted classes of the unnoised
61
+ image do not contain a prediction for the masked pixel as the unnoised image cannot be masked.
62
+
63
+ Parameters:
64
+ num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
65
+ attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
66
+ in_channels (`int`, *optional*):
67
+ Pass if the input is continuous. The number of channels in the input and output.
68
+ num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
69
+ dropout (`float`, *optional*, defaults to 0.1): The dropout probability to use.
70
+ cross_attention_dim (`int`, *optional*): The number of context dimensions to use.
71
+ sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
72
+ Note that this is fixed at training time as it is used for learning a number of position embeddings. See
73
+ `ImagePositionalEmbeddings`.
74
+ num_vector_embeds (`int`, *optional*):
75
+ Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels.
76
+ Includes the class for the masked latent pixel.
77
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
78
+ num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`.
79
+ The number of diffusion steps used during training. Note that this is fixed at training time as it is used
80
+ to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for
81
+ up to but not more than steps than `num_embeds_ada_norm`.
82
+ attention_bias (`bool`, *optional*):
83
+ Configure if the TransformerBlocks' attention should contain a bias parameter.
84
+ """
85
+
86
+ @register_to_config
87
+ def __init__(
88
+ self,
89
+ num_attention_heads: int = 16,
90
+ attention_head_dim: int = 88,
91
+ in_channels: Optional[int] = None,
92
+ num_layers: int = 1,
93
+ dropout: float = 0.0,
94
+ norm_num_groups: int = 32,
95
+ cross_attention_dim: Optional[int] = None,
96
+ attention_bias: bool = False,
97
+ sample_size: Optional[int] = None,
98
+ num_vector_embeds: Optional[int] = None,
99
+ activation_fn: str = "geglu",
100
+ num_embeds_ada_norm: Optional[int] = None,
101
+ ):
102
+ super().__init__()
103
+ self.num_attention_heads = num_attention_heads
104
+ self.attention_head_dim = attention_head_dim
105
+ inner_dim = num_attention_heads * attention_head_dim
106
+
107
+ # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
108
+ # Define whether input is continuous or discrete depending on configuration
109
+ self.is_input_continuous = in_channels is not None
110
+ self.is_input_vectorized = num_vector_embeds is not None
111
+
112
+ if self.is_input_continuous and self.is_input_vectorized:
113
+ raise ValueError(
114
+ f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
115
+ " sure that either `in_channels` or `num_vector_embeds` is None."
116
+ )
117
+ elif not self.is_input_continuous and not self.is_input_vectorized:
118
+ raise ValueError(
119
+ f"Has to define either `in_channels`: {in_channels} or `num_vector_embeds`: {num_vector_embeds}. Make"
120
+ " sure that either `in_channels` or `num_vector_embeds` is not None."
121
+ )
122
+
123
+ # 2. Define input layers
124
+ if self.is_input_continuous:
125
+ self.in_channels = in_channels
126
+
127
+ self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
128
+ self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
129
+ elif self.is_input_vectorized:
130
+ assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
131
+ assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"
132
+
133
+ self.height = sample_size
134
+ self.width = sample_size
135
+ self.num_vector_embeds = num_vector_embeds
136
+ self.num_latent_pixels = self.height * self.width
137
+
138
+ self.latent_image_embedding = ImagePositionalEmbeddings(
139
+ num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
140
+ )
141
+
142
+ # 3. Define transformers blocks
143
+ self.transformer_blocks = nn.ModuleList(
144
+ [
145
+ BasicTransformerBlock(
146
+ inner_dim,
147
+ num_attention_heads,
148
+ attention_head_dim,
149
+ dropout=dropout,
150
+ cross_attention_dim=cross_attention_dim,
151
+ activation_fn=activation_fn,
152
+ num_embeds_ada_norm=num_embeds_ada_norm,
153
+ attention_bias=attention_bias,
154
+ )
155
+ for d in range(num_layers)
156
+ ]
157
+ )
158
+
159
+ # 4. Define output layers
160
+ if self.is_input_continuous:
161
+ self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
162
+ elif self.is_input_vectorized:
163
+ self.norm_out = nn.LayerNorm(inner_dim)
164
+ self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
165
+
166
+ def _set_attention_slice(self, slice_size):
167
+ for block in self.transformer_blocks:
168
+ block._set_attention_slice(slice_size)
169
+
170
+ def forward(self, hidden_states, encoder_hidden_states=None, encoder_attention_mask=None, timestep=None,
171
+ return_dict: bool = True):
172
+ """
173
+ Args:
174
+ hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
175
+ When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input
176
+ hidden_states
177
+ encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, context dim)`, *optional*):
178
+ Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
179
+ self-attention.
180
+ encoder_attention_mask ( `torch.LongTensor` of shape `(batch size, context)`, *optional*):
181
+ Attention mask for cross attention layer.
182
+ timestep ( `torch.long`, *optional*):
183
+ Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
184
+ return_dict (`bool`, *optional*, defaults to `True`):
185
+ Whether or not to return a [`~models.attention.Transformer2DModelOutput`] instead of a plain tuple.
186
+
187
+ Returns:
188
+ [`~models.attention.Transformer2DModelOutput`] or `tuple`: [`~models.attention.Transformer2DModelOutput`]
189
+ if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample
190
+ tensor.
191
+ """
192
+ # 1. Input
193
+ if self.is_input_continuous:
194
+ batch, channel, height, width = hidden_states.shape
195
+ residual = hidden_states
196
+ hidden_states = self.norm(hidden_states)
197
+ hidden_states = self.proj_in(hidden_states)
198
+ inner_dim = hidden_states.shape[1]
199
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
200
+ elif self.is_input_vectorized:
201
+ hidden_states = self.latent_image_embedding(hidden_states)
202
+
203
+ # 2. Blocks
204
+ for block in self.transformer_blocks:
205
+ hidden_states = block(hidden_states, context=encoder_hidden_states, mask=encoder_attention_mask,
206
+ timestep=timestep)
207
+
208
+ # 3. Output
209
+ if self.is_input_continuous:
210
+ hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2)
211
+ hidden_states = self.proj_out(hidden_states)
212
+ output = hidden_states + residual
213
+ elif self.is_input_vectorized:
214
+ hidden_states = self.norm_out(hidden_states)
215
+ logits = self.out(hidden_states)
216
+ # (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
217
+ logits = logits.permute(0, 2, 1)
218
+
219
+ # log(p(x_0))
220
+ output = F.log_softmax(logits.double(), dim=1).float()
221
+
222
+ if not return_dict:
223
+ return (output,)
224
+
225
+ return Transformer2DModelOutput(sample=output)
226
+
227
+ def _set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
228
+ for block in self.transformer_blocks:
229
+ block._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)
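# Editor's sketch (illustrative, not part of the diff): the continuous-input path, exercising the
# `encoder_attention_mask` argument that this override threads down to the cross-attention layers.
# All shapes and the padding cut-off below are assumptions.
import torch

model = Transformer2DModel(num_attention_heads=8, attention_head_dim=40,
                           in_channels=320, cross_attention_dim=768)
latents = torch.randn(2, 320, 32, 32)                  # (batch, channels, height, width)
context = torch.randn(2, 77, 768)                      # e.g. text-encoder hidden states
context_mask = torch.zeros(2, 77, dtype=torch.bool)    # True = position is masked out
context_mask[:, 60:] = True                            # e.g. padding tokens
out = model(latents, encoder_hidden_states=context,
            encoder_attention_mask=context_mask).sample    # same shape as `latents`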
230
+
231
+
232
+ class AttentionBlock(nn.Module):
233
+ """
234
+ An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted
235
+ to the N-d case.
236
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
237
+ Uses three q, k, v linear layers to compute attention.
238
+
239
+ Parameters:
240
+ channels (`int`): The number of channels in the input and output.
241
+ num_head_channels (`int`, *optional*):
242
+ The number of channels in each head. If None, then `num_heads` = 1.
243
+ norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for group norm.
244
+ rescale_output_factor (`float`, *optional*, defaults to 1.0): The factor to rescale the output by.
245
+ eps (`float`, *optional*, defaults to 1e-5): The epsilon value to use for group norm.
246
+ """
247
+
248
+ def __init__(
249
+ self,
250
+ channels: int,
251
+ num_head_channels: Optional[int] = None,
252
+ norm_num_groups: int = 32,
253
+ rescale_output_factor: float = 1.0,
254
+ eps: float = 1e-5,
255
+ ):
256
+ super().__init__()
257
+ self.channels = channels
258
+
259
+ self.num_heads = channels // num_head_channels if num_head_channels is not None else 1
260
+ self.num_head_size = num_head_channels
261
+ self.group_norm = nn.GroupNorm(num_channels=channels, num_groups=norm_num_groups, eps=eps, affine=True)
262
+
263
+ # define q,k,v as linear layers
264
+ self.query = nn.Linear(channels, channels)
265
+ self.key = nn.Linear(channels, channels)
266
+ self.value = nn.Linear(channels, channels)
267
+
268
+ self.rescale_output_factor = rescale_output_factor
269
+ self.proj_attn = nn.Linear(channels, channels, 1)
270
+
271
+ def transpose_for_scores(self, projection: torch.Tensor) -> torch.Tensor:
272
+ new_projection_shape = projection.size()[:-1] + (self.num_heads, -1)
273
+ # move heads to 2nd position (B, T, H * D) -> (B, T, H, D) -> (B, H, T, D)
274
+ new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3)
275
+ return new_projection
276
+
277
+ def forward(self, hidden_states):
278
+ residual = hidden_states
279
+ batch, channel, height, width = hidden_states.shape
280
+
281
+ # norm
282
+ hidden_states = self.group_norm(hidden_states)
283
+
284
+ hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2)
285
+
286
+ # proj to q, k, v
287
+ query_proj = self.query(hidden_states)
288
+ key_proj = self.key(hidden_states)
289
+ value_proj = self.value(hidden_states)
290
+
291
+ # transpose
292
+ query_states = self.transpose_for_scores(query_proj)
293
+ key_states = self.transpose_for_scores(key_proj)
294
+ value_states = self.transpose_for_scores(value_proj)
295
+
296
+ # get scores
297
+ scale = 1 / math.sqrt(math.sqrt(self.channels / self.num_heads))
298
+ attention_scores = torch.matmul(query_states * scale, key_states.transpose(-1, -2) * scale) # TODO: use baddmm
299
+ attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype)
300
+
301
+ # compute attention output
302
+ hidden_states = torch.matmul(attention_probs, value_states)
303
+
304
+ hidden_states = hidden_states.permute(0, 2, 1, 3).contiguous()
305
+ new_hidden_states_shape = hidden_states.size()[:-2] + (self.channels,)
306
+ hidden_states = hidden_states.view(new_hidden_states_shape)
307
+
308
+ # compute next hidden_states
309
+ hidden_states = self.proj_attn(hidden_states)
310
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)
311
+
312
+ # res connect and rescale
313
+ hidden_states = (hidden_states + residual) / self.rescale_output_factor
314
+ return hidden_states
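# Editor's sketch (illustrative, not part of the diff): residual spatial self-attention over a
# small feature map; the channel and head sizes are assumptions.
import torch

block = AttentionBlock(channels=64, num_head_channels=8, norm_num_groups=32)
feats = torch.randn(1, 64, 16, 16)
out = block(feats)        # still (1, 64, 16, 16): attention output added to the residual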
315
+
316
+
317
+ class BasicTransformerBlock(nn.Module):
318
+ r"""
319
+ A basic Transformer block.
320
+
321
+ Parameters:
322
+ dim (`int`): The number of channels in the input and output.
323
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
324
+ attention_head_dim (`int`): The number of channels in each head.
325
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
326
+ cross_attention_dim (`int`, *optional*): The size of the context vector for cross attention.
327
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
328
+ num_embeds_ada_norm (`int`, *optional*):
329
+ The number of diffusion steps used during training. See `Transformer2DModel`.
330
+ attention_bias (`bool`, *optional*, defaults to `False`):
331
+ Configure if the attentions should contain a bias parameter.
332
+ """
333
+
334
+ def __init__(
335
+ self,
336
+ dim: int,
337
+ num_attention_heads: int,
338
+ attention_head_dim: int,
339
+ dropout=0.0,
340
+ cross_attention_dim: Optional[int] = None,
341
+ activation_fn: str = "geglu",
342
+ num_embeds_ada_norm: Optional[int] = None,
343
+ attention_bias: bool = False,
344
+ ):
345
+ super().__init__()
346
+ self.attn1 = CrossAttention(
347
+ query_dim=dim,
348
+ heads=num_attention_heads,
349
+ dim_head=attention_head_dim,
350
+ dropout=dropout,
351
+ bias=attention_bias,
352
+ ) # is a self-attention
353
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
354
+ self.attn2 = CrossAttention(
355
+ query_dim=dim,
356
+ cross_attention_dim=cross_attention_dim,
357
+ heads=num_attention_heads,
358
+ dim_head=attention_head_dim,
359
+ dropout=dropout,
360
+ bias=attention_bias,
361
+ ) # is self-attn if context is none
362
+
363
+ # layer norms
364
+ self.use_ada_layer_norm = num_embeds_ada_norm is not None
365
+ if self.use_ada_layer_norm:
366
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
367
+ self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm)
368
+ else:
369
+ self.norm1 = nn.LayerNorm(dim)
370
+ self.norm2 = nn.LayerNorm(dim)
371
+ self.norm3 = nn.LayerNorm(dim)
372
+
373
+ def _set_attention_slice(self, slice_size):
374
+ self.attn1._slice_size = slice_size
375
+ self.attn2._slice_size = slice_size
376
+
377
+ def _set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
378
+ if not is_xformers_available():
379
+ print("Here is how to install it")
380
+ raise ModuleNotFoundError(
381
+ "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
382
+ " xformers",
383
+ name="xformers",
384
+ )
385
+ elif not torch.cuda.is_available():
386
+ raise ValueError(
387
+ "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only"
388
+ " available for GPU "
389
+ )
390
+ else:
391
+ try:
392
+ # Make sure we can run the memory efficient attention
393
+ _ = xformers.ops.memory_efficient_attention(
394
+ torch.randn((1, 2, 40), device="cuda"),
395
+ torch.randn((1, 2, 40), device="cuda"),
396
+ torch.randn((1, 2, 40), device="cuda"),
397
+ )
398
+ except Exception as e:
399
+ raise e
400
+ self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
401
+ self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
402
+
403
+ def forward(self, hidden_states, context=None, mask=None, timestep=None):
404
+ # 1. Self-Attention
405
+ norm_hidden_states = (
406
+ self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states)
407
+ )
408
+ hidden_states = self.attn1(norm_hidden_states) + hidden_states
409
+
410
+ # 2. Cross-Attention
411
+ norm_hidden_states = (
412
+ self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
413
+ )
414
+ hidden_states = self.attn2(norm_hidden_states, context=context, mask=mask) + hidden_states
415
+
416
+ # 3. Feed-forward
417
+ hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
418
+
419
+ return hidden_states
420
+
421
+
422
+ class CrossAttention(nn.Module):
423
+ r"""
424
+ A cross attention layer.
425
+
426
+ Parameters:
427
+ query_dim (`int`): The number of channels in the query.
428
+ cross_attention_dim (`int`, *optional*):
429
+ The number of channels in the context. If not given, defaults to `query_dim`.
430
+ heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
431
+ dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
432
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
433
+ bias (`bool`, *optional*, defaults to False):
434
+ Set to `True` for the query, key, and value linear layers to contain a bias parameter.
435
+ """
436
+
437
+ def __init__(
438
+ self,
439
+ query_dim: int,
440
+ cross_attention_dim: Optional[int] = None,
441
+ heads: int = 8,
442
+ dim_head: int = 64,
443
+ dropout: float = 0.0,
444
+ bias=False,
445
+ ):
446
+ super().__init__()
447
+ inner_dim = dim_head * heads
448
+ cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
449
+
450
+ self.scale = dim_head ** -0.5
451
+ self.heads = heads
452
+ # for slice_size > 0 the attention score computation
453
+ # is split across the batch axis to save memory
454
+ # You can set slice_size with `set_attention_slice`
455
+ self._slice_size = None
456
+ self._use_memory_efficient_attention_xformers = False
457
+
458
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
459
+ self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
460
+ self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
461
+
462
+ self.to_out = nn.ModuleList([])
463
+ self.to_out.append(nn.Linear(inner_dim, query_dim))
464
+ self.to_out.append(nn.Dropout(dropout))
465
+
466
+ def reshape_heads_to_batch_dim(self, tensor):
467
+ batch_size, seq_len, dim = tensor.shape
468
+ head_size = self.heads
469
+ tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
470
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
471
+ return tensor
472
+
473
+ def reshape_batch_dim_to_heads(self, tensor):
474
+ batch_size, seq_len, dim = tensor.shape
475
+ head_size = self.heads
476
+ tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
477
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
478
+ return tensor
479
+
480
+ def forward(self, hidden_states, context=None, mask=None):
481
+ batch_size, sequence_length, _ = hidden_states.shape
482
+
483
+ query = self.to_q(hidden_states)
484
+ context = context if context is not None else hidden_states
485
+ key = self.to_k(context)
486
+ value = self.to_v(context)
487
+
488
+ dim = query.shape[-1]
489
+
490
+ query = self.reshape_heads_to_batch_dim(query)
491
+ key = self.reshape_heads_to_batch_dim(key)
492
+ value = self.reshape_heads_to_batch_dim(value)
493
+ mask = mask.repeat_interleave(self.heads, dim=0).unsqueeze(1) if mask is not None else None
494
+
495
+ # attention, what we cannot get enough of
496
+ if self._use_memory_efficient_attention_xformers:
497
+ hidden_states = self._memory_efficient_attention_xformers(query, key, value)
498
+ else:
499
+ if self._slice_size is None or query.shape[0] // self._slice_size == 1:
500
+ hidden_states = self._attention(query, key, value, mask)
501
+ else:
502
+ assert mask is None, "masking is not supported for sliced attention"
503
+ hidden_states = self._sliced_attention(query, key, value, sequence_length, dim)
504
+
505
+ # linear proj
506
+ hidden_states = self.to_out[0](hidden_states)
507
+ # dropout
508
+ hidden_states = self.to_out[1](hidden_states)
509
+ return hidden_states
510
+
511
+ def _attention(self, query, key, value, mask):
512
+ # TODO: use baddbmm for better performance
513
+ if query.device.type == "mps":
514
+ # Better performance on mps (~20-25%)
515
+ attention_scores = torch.einsum("b i d, b j d -> b i j", query, key) * self.scale
516
+ else:
517
+ attention_scores = torch.matmul(query, key.transpose(-1, -2)) * self.scale
518
+ attention_scores = attention_scores.masked_fill(mask.expand(attention_scores.shape), value=float("-inf")) \
519
+ if mask is not None else attention_scores
520
+ attention_probs = attention_scores.softmax(dim=-1)
521
+ # compute attention output
522
+
523
+ if query.device.type == "mps":
524
+ hidden_states = torch.einsum("b i j, b j d -> b i d", attention_probs, value)
525
+ else:
526
+ hidden_states = torch.matmul(attention_probs, value)
527
+
528
+ # reshape hidden_states
529
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
530
+ return hidden_states
531
+
532
+ def _sliced_attention(self, query, key, value, sequence_length, dim):
533
+ batch_size_attention = query.shape[0]
534
+ hidden_states = torch.zeros(
535
+ (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
536
+ )
537
+ slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0]
538
+ for i in range(hidden_states.shape[0] // slice_size):
539
+ start_idx = i * slice_size
540
+ end_idx = (i + 1) * slice_size
541
+ if query.device.type == "mps":
542
+ # Better performance on mps (~20-25%)
543
+ attn_slice = (
544
+ torch.einsum("b i d, b j d -> b i j", query[start_idx:end_idx], key[start_idx:end_idx])
545
+ * self.scale
546
+ )
547
+ else:
548
+ attn_slice = (
549
+ torch.matmul(query[start_idx:end_idx], key[start_idx:end_idx].transpose(1, 2)) * self.scale
550
+ ) # TODO: use baddbmm for better performance
551
+ attn_slice = attn_slice.softmax(dim=-1)
552
+ if query.device.type == "mps":
553
+ attn_slice = torch.einsum("b i j, b j d -> b i d", attn_slice, value[start_idx:end_idx])
554
+ else:
555
+ attn_slice = torch.matmul(attn_slice, value[start_idx:end_idx])
556
+
557
+ hidden_states[start_idx:end_idx] = attn_slice
558
+
559
+ # reshape hidden_states
560
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
561
+ return hidden_states
562
+
563
+ def _memory_efficient_attention_xformers(self, query, key, value):
564
+ hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=None)
565
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
566
+ return hidden_states
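# Editor's sketch (illustrative, not part of the diff): a direct cross-attention call in which
# flattened image tokens attend to a text context; sizes and the padding cut-off are assumptions.
import torch

xattn = CrossAttention(query_dim=320, cross_attention_dim=768, heads=8, dim_head=40)
image_tokens = torch.randn(2, 1024, 320)               # e.g. a flattened 32x32 feature map
text_context = torch.randn(2, 77, 768)
pad_mask = torch.zeros(2, 77, dtype=torch.bool)
pad_mask[:, 60:] = True                                # True = filled with -inf before softmax
out = xattn(image_tokens, context=text_context, mask=pad_mask)    # -> (2, 1024, 320)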
567
+
568
+
569
+ class FeedForward(nn.Module):
570
+ r"""
571
+ A feed-forward layer.
572
+
573
+ Parameters:
574
+ dim (`int`): The number of channels in the input.
575
+ dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
576
+ mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
577
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
578
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
579
+ """
580
+
581
+ def __init__(
582
+ self,
583
+ dim: int,
584
+ dim_out: Optional[int] = None,
585
+ mult: int = 4,
586
+ dropout: float = 0.0,
587
+ activation_fn: str = "geglu",
588
+ ):
589
+ super().__init__()
590
+ inner_dim = int(dim * mult)
591
+ dim_out = dim_out if dim_out is not None else dim
592
+
593
+ if activation_fn == "geglu":
594
+ geglu = GEGLU(dim, inner_dim)
595
+ elif activation_fn == "geglu-approximate":
596
+ geglu = ApproximateGELU(dim, inner_dim)
597
+
598
+ self.net = nn.ModuleList([])
599
+ # project in
600
+ self.net.append(geglu)
601
+ # project dropout
602
+ self.net.append(nn.Dropout(dropout))
603
+ # project out
604
+ self.net.append(nn.Linear(inner_dim, dim_out))
605
+
606
+ def forward(self, hidden_states):
607
+ for module in self.net:
608
+ hidden_states = module(hidden_states)
609
+ return hidden_states
610
+
611
+
612
+ # feedforward
613
+ class GEGLU(nn.Module):
614
+ r"""
615
+ A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.
616
+
617
+ Parameters:
618
+ dim_in (`int`): The number of channels in the input.
619
+ dim_out (`int`): The number of channels in the output.
620
+ """
621
+
622
+ def __init__(self, dim_in: int, dim_out: int):
623
+ super().__init__()
624
+ self.proj = nn.Linear(dim_in, dim_out * 2)
625
+
626
+ def gelu(self, gate):
627
+ if gate.device.type != "mps":
628
+ return F.gelu(gate)
629
+ # mps: gelu is not implemented for float16
630
+ return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
631
+
632
+ def forward(self, hidden_states):
633
+ hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
634
+ return hidden_states * self.gelu(gate)
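# Editor's sketch (illustrative, not part of the diff): GEGLU projects to twice the target width,
# then gates one half with GELU of the other. The shapes are assumptions.
import torch

geglu = GEGLU(dim_in=320, dim_out=1280)
y = geglu(torch.randn(2, 77, 320))    # -> (2, 77, 1280)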
635
+
636
+
637
+ class ApproximateGELU(nn.Module):
638
+ """
639
+ The approximate form of Gaussian Error Linear Unit (GELU)
640
+
641
+ For more details, see section 2: https://arxiv.org/abs/1606.08415
642
+ """
643
+
644
+ def __init__(self, dim_in: int, dim_out: int):
645
+ super().__init__()
646
+ self.proj = nn.Linear(dim_in, dim_out)
647
+
648
+ def forward(self, x):
649
+ x = self.proj(x)
650
+ return x * torch.sigmoid(1.702 * x)
651
+
652
+
653
+ class AdaLayerNorm(nn.Module):
654
+ """
655
+ Norm layer modified to incorporate timestep embeddings.
656
+ """
657
+
658
+ def __init__(self, embedding_dim, num_embeddings):
659
+ super().__init__()
660
+ self.emb = nn.Embedding(num_embeddings, embedding_dim)
661
+ self.silu = nn.SiLU()
662
+ self.linear = nn.Linear(embedding_dim, embedding_dim * 2)
663
+ self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False)
664
+
665
+ def forward(self, x, timestep):
666
+ emb = self.linear(self.silu(self.emb(timestep)))
667
+ scale, shift = torch.chunk(emb, 2)
668
+ x = self.norm(x) * (1 + scale) + shift
669
+ return x
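A short sketch of how `AdaLayerNorm` can be driven (editor's illustration; it assumes the classes defined in this file are in scope, and the sizes and scalar step index are arbitrary): the embedded diffusion step is mapped to a per-channel scale and shift that modulate a parameter-free LayerNorm.

import torch

ada_norm = AdaLayerNorm(embedding_dim=320, num_embeddings=1000)
x = torch.randn(2, 77, 320)
t = torch.tensor(10)          # a single diffusion-step index
y = ada_norm(x, t)            # LayerNorm(x) * (1 + scale(t)) + shift(t)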
models/diffusers_override/unet_2d_blocks.py ADDED
@@ -0,0 +1,1602 @@
1
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import numpy as np
15
+ import torch
16
+ from torch import nn
17
+
18
+ from .attention import AttentionBlock, Transformer2DModel
19
+ from diffusers.models.resnet import Downsample2D, FirDownsample2D, FirUpsample2D, ResnetBlock2D, Upsample2D
20
+
21
+
22
+ def get_down_block(
23
+ down_block_type,
24
+ num_layers,
25
+ in_channels,
26
+ out_channels,
27
+ temb_channels,
28
+ add_downsample,
29
+ resnet_eps,
30
+ resnet_act_fn,
31
+ attn_num_head_channels,
32
+ resnet_groups=None,
33
+ cross_attention_dim=None,
34
+ downsample_padding=None,
35
+ ):
36
+ down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
37
+ if down_block_type == "DownBlock2D":
38
+ return DownBlock2D(
39
+ num_layers=num_layers,
40
+ in_channels=in_channels,
41
+ out_channels=out_channels,
42
+ temb_channels=temb_channels,
43
+ add_downsample=add_downsample,
44
+ resnet_eps=resnet_eps,
45
+ resnet_act_fn=resnet_act_fn,
46
+ resnet_groups=resnet_groups,
47
+ downsample_padding=downsample_padding,
48
+ )
49
+ elif down_block_type == "AttnDownBlock2D":
50
+ return AttnDownBlock2D(
51
+ num_layers=num_layers,
52
+ in_channels=in_channels,
53
+ out_channels=out_channels,
54
+ temb_channels=temb_channels,
55
+ add_downsample=add_downsample,
56
+ resnet_eps=resnet_eps,
57
+ resnet_act_fn=resnet_act_fn,
58
+ resnet_groups=resnet_groups,
59
+ downsample_padding=downsample_padding,
60
+ attn_num_head_channels=attn_num_head_channels,
61
+ )
62
+ elif down_block_type == "CrossAttnDownBlock2D":
63
+ if cross_attention_dim is None:
64
+ raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock2D")
65
+ return CrossAttnDownBlock2D(
66
+ num_layers=num_layers,
67
+ in_channels=in_channels,
68
+ out_channels=out_channels,
69
+ temb_channels=temb_channels,
70
+ add_downsample=add_downsample,
71
+ resnet_eps=resnet_eps,
72
+ resnet_act_fn=resnet_act_fn,
73
+ resnet_groups=resnet_groups,
74
+ downsample_padding=downsample_padding,
75
+ cross_attention_dim=cross_attention_dim,
76
+ attn_num_head_channels=attn_num_head_channels,
77
+ )
78
+ elif down_block_type == "SkipDownBlock2D":
79
+ return SkipDownBlock2D(
80
+ num_layers=num_layers,
81
+ in_channels=in_channels,
82
+ out_channels=out_channels,
83
+ temb_channels=temb_channels,
84
+ add_downsample=add_downsample,
85
+ resnet_eps=resnet_eps,
86
+ resnet_act_fn=resnet_act_fn,
87
+ downsample_padding=downsample_padding,
88
+ )
89
+ elif down_block_type == "AttnSkipDownBlock2D":
90
+ return AttnSkipDownBlock2D(
91
+ num_layers=num_layers,
92
+ in_channels=in_channels,
93
+ out_channels=out_channels,
94
+ temb_channels=temb_channels,
95
+ add_downsample=add_downsample,
96
+ resnet_eps=resnet_eps,
97
+ resnet_act_fn=resnet_act_fn,
98
+ downsample_padding=downsample_padding,
99
+ attn_num_head_channels=attn_num_head_channels,
100
+ )
101
+ elif down_block_type == "DownEncoderBlock2D":
102
+ return DownEncoderBlock2D(
103
+ num_layers=num_layers,
104
+ in_channels=in_channels,
105
+ out_channels=out_channels,
106
+ add_downsample=add_downsample,
107
+ resnet_eps=resnet_eps,
108
+ resnet_act_fn=resnet_act_fn,
109
+ resnet_groups=resnet_groups,
110
+ downsample_padding=downsample_padding,
111
+ )
112
+ elif down_block_type == "AttnDownEncoderBlock2D":
113
+ return AttnDownEncoderBlock2D(
114
+ num_layers=num_layers,
115
+ in_channels=in_channels,
116
+ out_channels=out_channels,
117
+ add_downsample=add_downsample,
118
+ resnet_eps=resnet_eps,
119
+ resnet_act_fn=resnet_act_fn,
120
+ resnet_groups=resnet_groups,
121
+ downsample_padding=downsample_padding,
122
+ attn_num_head_channels=attn_num_head_channels,
123
+ )
124
+ raise ValueError(f"{down_block_type} does not exist.")
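# Editor's sketch (illustrative, not part of the diff): building one cross-attention down block
# through the factory above and running it; every value is an assumption, and the return of
# (hidden_states, skip_states) mirrors the down blocks defined below.
import torch

down_block = get_down_block(
    "CrossAttnDownBlock2D",
    num_layers=2,
    in_channels=320,
    out_channels=320,
    temb_channels=1280,
    add_downsample=True,
    resnet_eps=1e-5,
    resnet_act_fn="swish",
    attn_num_head_channels=8,
    resnet_groups=32,
    cross_attention_dim=768,
    downsample_padding=1,
)
h = torch.randn(2, 320, 32, 32)
temb = torch.randn(2, 1280)
context = torch.randn(2, 77, 768)
h, skips = down_block(h, temb=temb, encoder_hidden_states=context)   # h: (2, 320, 16, 16)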
125
+
126
+
127
+ def get_up_block(
128
+ up_block_type,
129
+ num_layers,
130
+ in_channels,
131
+ out_channels,
132
+ prev_output_channel,
133
+ temb_channels,
134
+ add_upsample,
135
+ resnet_eps,
136
+ resnet_act_fn,
137
+ attn_num_head_channels,
138
+ resnet_groups=None,
139
+ cross_attention_dim=None,
140
+ ):
141
+ up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
142
+ if up_block_type == "UpBlock2D":
143
+ return UpBlock2D(
144
+ num_layers=num_layers,
145
+ in_channels=in_channels,
146
+ out_channels=out_channels,
147
+ prev_output_channel=prev_output_channel,
148
+ temb_channels=temb_channels,
149
+ add_upsample=add_upsample,
150
+ resnet_eps=resnet_eps,
151
+ resnet_act_fn=resnet_act_fn,
152
+ resnet_groups=resnet_groups,
153
+ )
154
+ elif up_block_type == "CrossAttnUpBlock2D":
155
+ if cross_attention_dim is None:
156
+ raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock2D")
157
+ return CrossAttnUpBlock2D(
158
+ num_layers=num_layers,
159
+ in_channels=in_channels,
160
+ out_channels=out_channels,
161
+ prev_output_channel=prev_output_channel,
162
+ temb_channels=temb_channels,
163
+ add_upsample=add_upsample,
164
+ resnet_eps=resnet_eps,
165
+ resnet_act_fn=resnet_act_fn,
166
+ resnet_groups=resnet_groups,
167
+ cross_attention_dim=cross_attention_dim,
168
+ attn_num_head_channels=attn_num_head_channels,
169
+ )
170
+ elif up_block_type == "AttnUpBlock2D":
171
+ return AttnUpBlock2D(
172
+ num_layers=num_layers,
173
+ in_channels=in_channels,
174
+ out_channels=out_channels,
175
+ prev_output_channel=prev_output_channel,
176
+ temb_channels=temb_channels,
177
+ add_upsample=add_upsample,
178
+ resnet_eps=resnet_eps,
179
+ resnet_act_fn=resnet_act_fn,
180
+ resnet_groups=resnet_groups,
181
+ attn_num_head_channels=attn_num_head_channels,
182
+ )
183
+ elif up_block_type == "SkipUpBlock2D":
184
+ return SkipUpBlock2D(
185
+ num_layers=num_layers,
186
+ in_channels=in_channels,
187
+ out_channels=out_channels,
188
+ prev_output_channel=prev_output_channel,
189
+ temb_channels=temb_channels,
190
+ add_upsample=add_upsample,
191
+ resnet_eps=resnet_eps,
192
+ resnet_act_fn=resnet_act_fn,
193
+ )
194
+ elif up_block_type == "AttnSkipUpBlock2D":
195
+ return AttnSkipUpBlock2D(
196
+ num_layers=num_layers,
197
+ in_channels=in_channels,
198
+ out_channels=out_channels,
199
+ prev_output_channel=prev_output_channel,
200
+ temb_channels=temb_channels,
201
+ add_upsample=add_upsample,
202
+ resnet_eps=resnet_eps,
203
+ resnet_act_fn=resnet_act_fn,
204
+ attn_num_head_channels=attn_num_head_channels,
205
+ )
206
+ elif up_block_type == "UpDecoderBlock2D":
207
+ return UpDecoderBlock2D(
208
+ num_layers=num_layers,
209
+ in_channels=in_channels,
210
+ out_channels=out_channels,
211
+ add_upsample=add_upsample,
212
+ resnet_eps=resnet_eps,
213
+ resnet_act_fn=resnet_act_fn,
214
+ resnet_groups=resnet_groups,
215
+ )
216
+ elif up_block_type == "AttnUpDecoderBlock2D":
217
+ return AttnUpDecoderBlock2D(
218
+ num_layers=num_layers,
219
+ in_channels=in_channels,
220
+ out_channels=out_channels,
221
+ add_upsample=add_upsample,
222
+ resnet_eps=resnet_eps,
223
+ resnet_act_fn=resnet_act_fn,
224
+ resnet_groups=resnet_groups,
225
+ attn_num_head_channels=attn_num_head_channels,
226
+ )
227
+ raise ValueError(f"{up_block_type} does not exist.")
228
+
229
+
230
+ class UNetMidBlock2D(nn.Module):
231
+ def __init__(
232
+ self,
233
+ in_channels: int,
234
+ temb_channels: int,
235
+ dropout: float = 0.0,
236
+ num_layers: int = 1,
237
+ resnet_eps: float = 1e-6,
238
+ resnet_time_scale_shift: str = "default",
239
+ resnet_act_fn: str = "swish",
240
+ resnet_groups: int = 32,
241
+ resnet_pre_norm: bool = True,
242
+ attn_num_head_channels=1,
243
+ attention_type="default",
244
+ output_scale_factor=1.0,
245
+ **kwargs,
246
+ ):
247
+ super().__init__()
248
+
249
+ self.attention_type = attention_type
250
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
251
+
252
+ # there is always at least one resnet
253
+ resnets = [
254
+ ResnetBlock2D(
255
+ in_channels=in_channels,
256
+ out_channels=in_channels,
257
+ temb_channels=temb_channels,
258
+ eps=resnet_eps,
259
+ groups=resnet_groups,
260
+ dropout=dropout,
261
+ time_embedding_norm=resnet_time_scale_shift,
262
+ non_linearity=resnet_act_fn,
263
+ output_scale_factor=output_scale_factor,
264
+ pre_norm=resnet_pre_norm,
265
+ )
266
+ ]
267
+ attentions = []
268
+
269
+ for _ in range(num_layers):
270
+ attentions.append(
271
+ AttentionBlock(
272
+ in_channels,
273
+ num_head_channels=attn_num_head_channels,
274
+ rescale_output_factor=output_scale_factor,
275
+ eps=resnet_eps,
276
+ norm_num_groups=resnet_groups,
277
+ )
278
+ )
279
+ resnets.append(
280
+ ResnetBlock2D(
281
+ in_channels=in_channels,
282
+ out_channels=in_channels,
283
+ temb_channels=temb_channels,
284
+ eps=resnet_eps,
285
+ groups=resnet_groups,
286
+ dropout=dropout,
287
+ time_embedding_norm=resnet_time_scale_shift,
288
+ non_linearity=resnet_act_fn,
289
+ output_scale_factor=output_scale_factor,
290
+ pre_norm=resnet_pre_norm,
291
+ )
292
+ )
293
+
294
+ self.attentions = nn.ModuleList(attentions)
295
+ self.resnets = nn.ModuleList(resnets)
296
+
297
+ def forward(self, hidden_states, temb=None, encoder_states=None):
298
+ hidden_states = self.resnets[0](hidden_states, temb)
299
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
300
+ if self.attention_type == "default":
301
+ hidden_states = attn(hidden_states)
302
+ else:
303
+ hidden_states = attn(hidden_states, encoder_states)
304
+ hidden_states = resnet(hidden_states, temb)
305
+
306
+ return hidden_states
307
+
308
+
309
+ class UNetMidBlock2DCrossAttn(nn.Module):
310
+ def __init__(
311
+ self,
312
+ in_channels: int,
313
+ temb_channels: int,
314
+ dropout: float = 0.0,
315
+ num_layers: int = 1,
316
+ resnet_eps: float = 1e-6,
317
+ resnet_time_scale_shift: str = "default",
318
+ resnet_act_fn: str = "swish",
319
+ resnet_groups: int = 32,
320
+ resnet_pre_norm: bool = True,
321
+ attn_num_head_channels=1,
322
+ attention_type="default",
323
+ output_scale_factor=1.0,
324
+ cross_attention_dim=1280,
325
+ **kwargs,
326
+ ):
327
+ super().__init__()
328
+
329
+ self.attention_type = attention_type
330
+ self.attn_num_head_channels = attn_num_head_channels
331
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
332
+
333
+ # there is always at least one resnet
334
+ resnets = [
335
+ ResnetBlock2D(
336
+ in_channels=in_channels,
337
+ out_channels=in_channels,
338
+ temb_channels=temb_channels,
339
+ eps=resnet_eps,
340
+ groups=resnet_groups,
341
+ dropout=dropout,
342
+ time_embedding_norm=resnet_time_scale_shift,
343
+ non_linearity=resnet_act_fn,
344
+ output_scale_factor=output_scale_factor,
345
+ pre_norm=resnet_pre_norm,
346
+ )
347
+ ]
348
+ attentions = []
349
+
350
+ for _ in range(num_layers):
351
+ attentions.append(
352
+ Transformer2DModel(
353
+ attn_num_head_channels,
354
+ in_channels // attn_num_head_channels,
355
+ in_channels=in_channels,
356
+ num_layers=1,
357
+ cross_attention_dim=cross_attention_dim,
358
+ norm_num_groups=resnet_groups,
359
+ )
360
+ )
361
+ resnets.append(
362
+ ResnetBlock2D(
363
+ in_channels=in_channels,
364
+ out_channels=in_channels,
365
+ temb_channels=temb_channels,
366
+ eps=resnet_eps,
367
+ groups=resnet_groups,
368
+ dropout=dropout,
369
+ time_embedding_norm=resnet_time_scale_shift,
370
+ non_linearity=resnet_act_fn,
371
+ output_scale_factor=output_scale_factor,
372
+ pre_norm=resnet_pre_norm,
373
+ )
374
+ )
375
+
376
+ self.attentions = nn.ModuleList(attentions)
377
+ self.resnets = nn.ModuleList(resnets)
378
+
379
+ def set_attention_slice(self, slice_size):
380
+ if slice_size is not None and self.attn_num_head_channels % slice_size != 0:
381
+ raise ValueError(
382
+ f"Make sure slice_size {slice_size} is a divisor of "
383
+ f"the number of heads used in cross_attention {self.attn_num_head_channels}"
384
+ )
385
+ if slice_size is not None and slice_size > self.attn_num_head_channels:
386
+ raise ValueError(
387
+ f"slice_size {slice_size} has to be smaller than or equal to "
388
+ f"the number of heads used in cross_attention {self.attn_num_head_channels}"
389
+ )
390
+
391
+ for attn in self.attentions:
392
+ attn._set_attention_slice(slice_size)
393
+
394
+ def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
395
+ for attn in self.attentions:
396
+ attn._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)
397
+
398
+ def forward(self, hidden_states, temb=None, encoder_hidden_states=None, encoder_attention_mask=None):
399
+ hidden_states = self.resnets[0](hidden_states, temb)
400
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
401
+ hidden_states = attn(hidden_states, encoder_hidden_states, encoder_attention_mask).sample
402
+ hidden_states = resnet(hidden_states, temb)
403
+
404
+ return hidden_states
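# Editor's sketch (illustrative, not part of the diff): the mid block runs resnet ->
# (cross-attention transformer -> resnet) per layer, conditioned on text features.
# The channel counts and context length are assumptions.
import torch

mid = UNetMidBlock2DCrossAttn(in_channels=1280, temb_channels=1280,
                              attn_num_head_channels=8, cross_attention_dim=768)
h = torch.randn(2, 1280, 8, 8)
temb = torch.randn(2, 1280)
context = torch.randn(2, 77, 768)
h = mid(h, temb=temb, encoder_hidden_states=context)    # -> (2, 1280, 8, 8)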
405
+
406
+
407
+ class AttnDownBlock2D(nn.Module):
408
+ def __init__(
409
+ self,
410
+ in_channels: int,
411
+ out_channels: int,
412
+ temb_channels: int,
413
+ dropout: float = 0.0,
414
+ num_layers: int = 1,
415
+ resnet_eps: float = 1e-6,
416
+ resnet_time_scale_shift: str = "default",
417
+ resnet_act_fn: str = "swish",
418
+ resnet_groups: int = 32,
419
+ resnet_pre_norm: bool = True,
420
+ attn_num_head_channels=1,
421
+ attention_type="default",
422
+ output_scale_factor=1.0,
423
+ downsample_padding=1,
424
+ add_downsample=True,
425
+ ):
426
+ super().__init__()
427
+ resnets = []
428
+ attentions = []
429
+
430
+ self.attention_type = attention_type
431
+
432
+ for i in range(num_layers):
433
+ in_channels = in_channels if i == 0 else out_channels
434
+ resnets.append(
435
+ ResnetBlock2D(
436
+ in_channels=in_channels,
437
+ out_channels=out_channels,
438
+ temb_channels=temb_channels,
439
+ eps=resnet_eps,
440
+ groups=resnet_groups,
441
+ dropout=dropout,
442
+ time_embedding_norm=resnet_time_scale_shift,
443
+ non_linearity=resnet_act_fn,
444
+ output_scale_factor=output_scale_factor,
445
+ pre_norm=resnet_pre_norm,
446
+ )
447
+ )
448
+ attentions.append(
449
+ AttentionBlock(
450
+ out_channels,
451
+ num_head_channels=attn_num_head_channels,
452
+ rescale_output_factor=output_scale_factor,
453
+ eps=resnet_eps,
454
+ norm_num_groups=resnet_groups,
455
+ )
456
+ )
457
+
458
+ self.attentions = nn.ModuleList(attentions)
459
+ self.resnets = nn.ModuleList(resnets)
460
+
461
+ if add_downsample:
462
+ self.downsamplers = nn.ModuleList(
463
+ [
464
+ Downsample2D(
465
+ in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
466
+ )
467
+ ]
468
+ )
469
+ else:
470
+ self.downsamplers = None
471
+
472
+ def forward(self, hidden_states, temb=None):
473
+ output_states = ()
474
+
475
+ for resnet, attn in zip(self.resnets, self.attentions):
476
+ hidden_states = resnet(hidden_states, temb)
477
+ hidden_states = attn(hidden_states)
478
+ output_states += (hidden_states,)
479
+
480
+ if self.downsamplers is not None:
481
+ for downsampler in self.downsamplers:
482
+ hidden_states = downsampler(hidden_states)
483
+
484
+ output_states += (hidden_states,)
485
+
486
+ return hidden_states, output_states
487
+
488
+
489
+ class CrossAttnDownBlock2D(nn.Module):
490
+ def __init__(
491
+ self,
492
+ in_channels: int,
493
+ out_channels: int,
494
+ temb_channels: int,
495
+ dropout: float = 0.0,
496
+ num_layers: int = 1,
497
+ resnet_eps: float = 1e-6,
498
+ resnet_time_scale_shift: str = "default",
499
+ resnet_act_fn: str = "swish",
500
+ resnet_groups: int = 32,
501
+ resnet_pre_norm: bool = True,
502
+ attn_num_head_channels=1,
503
+ cross_attention_dim=1280,
504
+ attention_type="default",
505
+ output_scale_factor=1.0,
506
+ downsample_padding=1,
507
+ add_downsample=True,
508
+ ):
509
+ super().__init__()
510
+ resnets = []
511
+ attentions = []
512
+
513
+ self.attention_type = attention_type
514
+ self.attn_num_head_channels = attn_num_head_channels
515
+
516
+ for i in range(num_layers):
517
+ in_channels = in_channels if i == 0 else out_channels
518
+ resnets.append(
519
+ ResnetBlock2D(
520
+ in_channels=in_channels,
521
+ out_channels=out_channels,
522
+ temb_channels=temb_channels,
523
+ eps=resnet_eps,
524
+ groups=resnet_groups,
525
+ dropout=dropout,
526
+ time_embedding_norm=resnet_time_scale_shift,
527
+ non_linearity=resnet_act_fn,
528
+ output_scale_factor=output_scale_factor,
529
+ pre_norm=resnet_pre_norm,
530
+ )
531
+ )
532
+ attentions.append(
533
+ Transformer2DModel(
534
+ attn_num_head_channels,
535
+ out_channels // attn_num_head_channels,
536
+ in_channels=out_channels,
537
+ num_layers=1,
538
+ cross_attention_dim=cross_attention_dim,
539
+ norm_num_groups=resnet_groups,
540
+ )
541
+ )
542
+ self.attentions = nn.ModuleList(attentions)
543
+ self.resnets = nn.ModuleList(resnets)
544
+
545
+ if add_downsample:
546
+ self.downsamplers = nn.ModuleList(
547
+ [
548
+ Downsample2D(
549
+ in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
550
+ )
551
+ ]
552
+ )
553
+ else:
554
+ self.downsamplers = None
555
+
556
+ self.gradient_checkpointing = False
557
+
558
+ def set_attention_slice(self, slice_size):
559
+ if slice_size is not None and self.attn_num_head_channels % slice_size != 0:
560
+ raise ValueError(
561
+ f"Make sure slice_size {slice_size} is a divisor of "
562
+ f"the number of heads used in cross_attention {self.attn_num_head_channels}"
563
+ )
564
+ if slice_size is not None and slice_size > self.attn_num_head_channels:
565
+ raise ValueError(
566
+ f"Chunk_size {slice_size} has to be smaller or equal to "
567
+ f"the number of heads used in cross_attention {self.attn_num_head_channels}"
568
+ )
569
+
570
+ for attn in self.attentions:
571
+ attn._set_attention_slice(slice_size)
572
+
573
+ def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
574
+ for attn in self.attentions:
575
+ attn._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)
576
+
577
+ def forward(self, hidden_states, temb=None, encoder_hidden_states=None, encoder_attention_mask=None):
578
+ output_states = ()
579
+
580
+ for resnet, attn in zip(self.resnets, self.attentions):
581
+ if self.training and self.gradient_checkpointing:
582
+
583
+ def create_custom_forward(module, return_dict=None):
584
+ def custom_forward(*inputs):
585
+ if return_dict is not None:
586
+ return module(*inputs, return_dict=return_dict)
587
+ else:
588
+ return module(*inputs)
589
+
590
+ return custom_forward
591
+
592
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
593
+ hidden_states = torch.utils.checkpoint.checkpoint(
594
+ create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states,
595
+ encoder_attention_mask
596
+ )[0]
597
+ else:
598
+ hidden_states = resnet(hidden_states, temb)
599
+ hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states,
600
+ encoder_attention_mask=encoder_attention_mask).sample
601
+
602
+ output_states += (hidden_states,)
603
+
604
+ if self.downsamplers is not None:
605
+ for downsampler in self.downsamplers:
606
+ hidden_states = downsampler(hidden_states)
607
+
608
+ output_states += (hidden_states,)
609
+
610
+ return hidden_states, output_states
611
+
612
+
613
+ class DownBlock2D(nn.Module):
614
+ def __init__(
615
+ self,
616
+ in_channels: int,
617
+ out_channels: int,
618
+ temb_channels: int,
619
+ dropout: float = 0.0,
620
+ num_layers: int = 1,
621
+ resnet_eps: float = 1e-6,
622
+ resnet_time_scale_shift: str = "default",
623
+ resnet_act_fn: str = "swish",
624
+ resnet_groups: int = 32,
625
+ resnet_pre_norm: bool = True,
626
+ output_scale_factor=1.0,
627
+ add_downsample=True,
628
+ downsample_padding=1,
629
+ ):
630
+ super().__init__()
631
+ resnets = []
632
+
633
+ for i in range(num_layers):
634
+ in_channels = in_channels if i == 0 else out_channels
635
+ resnets.append(
636
+ ResnetBlock2D(
637
+ in_channels=in_channels,
638
+ out_channels=out_channels,
639
+ temb_channels=temb_channels,
640
+ eps=resnet_eps,
641
+ groups=resnet_groups,
642
+ dropout=dropout,
643
+ time_embedding_norm=resnet_time_scale_shift,
644
+ non_linearity=resnet_act_fn,
645
+ output_scale_factor=output_scale_factor,
646
+ pre_norm=resnet_pre_norm,
647
+ )
648
+ )
649
+
650
+ self.resnets = nn.ModuleList(resnets)
651
+
652
+ if add_downsample:
653
+ self.downsamplers = nn.ModuleList(
654
+ [
655
+ Downsample2D(
656
+ in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
657
+ )
658
+ ]
659
+ )
660
+ else:
661
+ self.downsamplers = None
662
+
663
+ self.gradient_checkpointing = False
664
+
665
+ def forward(self, hidden_states, temb=None):
666
+ output_states = ()
667
+
668
+ for resnet in self.resnets:
669
+ if self.training and self.gradient_checkpointing:
670
+
671
+ def create_custom_forward(module):
672
+ def custom_forward(*inputs):
673
+ return module(*inputs)
674
+
675
+ return custom_forward
676
+
677
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
678
+ else:
679
+ hidden_states = resnet(hidden_states, temb)
680
+
681
+ output_states += (hidden_states,)
682
+
683
+ if self.downsamplers is not None:
684
+ for downsampler in self.downsamplers:
685
+ hidden_states = downsampler(hidden_states)
686
+
687
+ output_states += (hidden_states,)
688
+
689
+ return hidden_states, output_states
690
+
691
+
692
+ class DownEncoderBlock2D(nn.Module):
693
+ def __init__(
694
+ self,
695
+ in_channels: int,
696
+ out_channels: int,
697
+ dropout: float = 0.0,
698
+ num_layers: int = 1,
699
+ resnet_eps: float = 1e-6,
700
+ resnet_time_scale_shift: str = "default",
701
+ resnet_act_fn: str = "swish",
702
+ resnet_groups: int = 32,
703
+ resnet_pre_norm: bool = True,
704
+ output_scale_factor=1.0,
705
+ add_downsample=True,
706
+ downsample_padding=1,
707
+ ):
708
+ super().__init__()
709
+ resnets = []
710
+
711
+ for i in range(num_layers):
712
+ in_channels = in_channels if i == 0 else out_channels
713
+ resnets.append(
714
+ ResnetBlock2D(
715
+ in_channels=in_channels,
716
+ out_channels=out_channels,
717
+ temb_channels=None,
718
+ eps=resnet_eps,
719
+ groups=resnet_groups,
720
+ dropout=dropout,
721
+ time_embedding_norm=resnet_time_scale_shift,
722
+ non_linearity=resnet_act_fn,
723
+ output_scale_factor=output_scale_factor,
724
+ pre_norm=resnet_pre_norm,
725
+ )
726
+ )
727
+
728
+ self.resnets = nn.ModuleList(resnets)
729
+
730
+ if add_downsample:
731
+ self.downsamplers = nn.ModuleList(
732
+ [
733
+ Downsample2D(
734
+ in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
735
+ )
736
+ ]
737
+ )
738
+ else:
739
+ self.downsamplers = None
740
+
741
+ def forward(self, hidden_states):
742
+ for resnet in self.resnets:
743
+ hidden_states = resnet(hidden_states, temb=None)
744
+
745
+ if self.downsamplers is not None:
746
+ for downsampler in self.downsamplers:
747
+ hidden_states = downsampler(hidden_states)
748
+
749
+ return hidden_states
750
+
751
+
752
+ class AttnDownEncoderBlock2D(nn.Module):
753
+ def __init__(
754
+ self,
755
+ in_channels: int,
756
+ out_channels: int,
757
+ dropout: float = 0.0,
758
+ num_layers: int = 1,
759
+ resnet_eps: float = 1e-6,
760
+ resnet_time_scale_shift: str = "default",
761
+ resnet_act_fn: str = "swish",
762
+ resnet_groups: int = 32,
763
+ resnet_pre_norm: bool = True,
764
+ attn_num_head_channels=1,
765
+ output_scale_factor=1.0,
766
+ add_downsample=True,
767
+ downsample_padding=1,
768
+ ):
769
+ super().__init__()
770
+ resnets = []
771
+ attentions = []
772
+
773
+ for i in range(num_layers):
774
+ in_channels = in_channels if i == 0 else out_channels
775
+ resnets.append(
776
+ ResnetBlock2D(
777
+ in_channels=in_channels,
778
+ out_channels=out_channels,
779
+ temb_channels=None,
780
+ eps=resnet_eps,
781
+ groups=resnet_groups,
782
+ dropout=dropout,
783
+ time_embedding_norm=resnet_time_scale_shift,
784
+ non_linearity=resnet_act_fn,
785
+ output_scale_factor=output_scale_factor,
786
+ pre_norm=resnet_pre_norm,
787
+ )
788
+ )
789
+ attentions.append(
790
+ AttentionBlock(
791
+ out_channels,
792
+ num_head_channels=attn_num_head_channels,
793
+ rescale_output_factor=output_scale_factor,
794
+ eps=resnet_eps,
795
+ norm_num_groups=resnet_groups,
796
+ )
797
+ )
798
+
799
+ self.attentions = nn.ModuleList(attentions)
800
+ self.resnets = nn.ModuleList(resnets)
801
+
802
+ if add_downsample:
803
+ self.downsamplers = nn.ModuleList(
804
+ [
805
+ Downsample2D(
806
+ in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
807
+ )
808
+ ]
809
+ )
810
+ else:
811
+ self.downsamplers = None
812
+
813
+ def forward(self, hidden_states):
814
+ for resnet, attn in zip(self.resnets, self.attentions):
815
+ hidden_states = resnet(hidden_states, temb=None)
816
+ hidden_states = attn(hidden_states)
817
+
818
+ if self.downsamplers is not None:
819
+ for downsampler in self.downsamplers:
820
+ hidden_states = downsampler(hidden_states)
821
+
822
+ return hidden_states
823
+
824
+
825
+ class AttnSkipDownBlock2D(nn.Module):
826
+ def __init__(
827
+ self,
828
+ in_channels: int,
829
+ out_channels: int,
830
+ temb_channels: int,
831
+ dropout: float = 0.0,
832
+ num_layers: int = 1,
833
+ resnet_eps: float = 1e-6,
834
+ resnet_time_scale_shift: str = "default",
835
+ resnet_act_fn: str = "swish",
836
+ resnet_pre_norm: bool = True,
837
+ attn_num_head_channels=1,
838
+ attention_type="default",
839
+ output_scale_factor=np.sqrt(2.0),
840
+ downsample_padding=1,
841
+ add_downsample=True,
842
+ ):
843
+ super().__init__()
844
+ self.attentions = nn.ModuleList([])
845
+ self.resnets = nn.ModuleList([])
846
+
847
+ self.attention_type = attention_type
848
+
849
+ for i in range(num_layers):
850
+ in_channels = in_channels if i == 0 else out_channels
851
+ self.resnets.append(
852
+ ResnetBlock2D(
853
+ in_channels=in_channels,
854
+ out_channels=out_channels,
855
+ temb_channels=temb_channels,
856
+ eps=resnet_eps,
857
+ groups=min(in_channels // 4, 32),
858
+ groups_out=min(out_channels // 4, 32),
859
+ dropout=dropout,
860
+ time_embedding_norm=resnet_time_scale_shift,
861
+ non_linearity=resnet_act_fn,
862
+ output_scale_factor=output_scale_factor,
863
+ pre_norm=resnet_pre_norm,
864
+ )
865
+ )
866
+ self.attentions.append(
867
+ AttentionBlock(
868
+ out_channels,
869
+ num_head_channels=attn_num_head_channels,
870
+ rescale_output_factor=output_scale_factor,
871
+ eps=resnet_eps,
872
+ )
873
+ )
874
+
875
+ if add_downsample:
876
+ self.resnet_down = ResnetBlock2D(
877
+ in_channels=out_channels,
878
+ out_channels=out_channels,
879
+ temb_channels=temb_channels,
880
+ eps=resnet_eps,
881
+ groups=min(out_channels // 4, 32),
882
+ dropout=dropout,
883
+ time_embedding_norm=resnet_time_scale_shift,
884
+ non_linearity=resnet_act_fn,
885
+ output_scale_factor=output_scale_factor,
886
+ pre_norm=resnet_pre_norm,
887
+ use_in_shortcut=True,
888
+ down=True,
889
+ kernel="fir",
890
+ )
891
+ self.downsamplers = nn.ModuleList([FirDownsample2D(in_channels, out_channels=out_channels)])
892
+ self.skip_conv = nn.Conv2d(3, out_channels, kernel_size=(1, 1), stride=(1, 1))
893
+ else:
894
+ self.resnet_down = None
895
+ self.downsamplers = None
896
+ self.skip_conv = None
897
+
898
+ def forward(self, hidden_states, temb=None, skip_sample=None):
899
+ output_states = ()
900
+
901
+ for resnet, attn in zip(self.resnets, self.attentions):
902
+ hidden_states = resnet(hidden_states, temb)
903
+ hidden_states = attn(hidden_states)
904
+ output_states += (hidden_states,)
905
+
906
+ if self.downsamplers is not None:
907
+ hidden_states = self.resnet_down(hidden_states, temb)
908
+ for downsampler in self.downsamplers:
909
+ skip_sample = downsampler(skip_sample)
910
+
911
+ hidden_states = self.skip_conv(skip_sample) + hidden_states
912
+
913
+ output_states += (hidden_states,)
914
+
915
+ return hidden_states, output_states, skip_sample
916
+
917
+
918
+ class SkipDownBlock2D(nn.Module):
919
+ def __init__(
920
+ self,
921
+ in_channels: int,
922
+ out_channels: int,
923
+ temb_channels: int,
924
+ dropout: float = 0.0,
925
+ num_layers: int = 1,
926
+ resnet_eps: float = 1e-6,
927
+ resnet_time_scale_shift: str = "default",
928
+ resnet_act_fn: str = "swish",
929
+ resnet_pre_norm: bool = True,
930
+ output_scale_factor=np.sqrt(2.0),
931
+ add_downsample=True,
932
+ downsample_padding=1,
933
+ ):
934
+ super().__init__()
935
+ self.resnets = nn.ModuleList([])
936
+
937
+ for i in range(num_layers):
938
+ in_channels = in_channels if i == 0 else out_channels
939
+ self.resnets.append(
940
+ ResnetBlock2D(
941
+ in_channels=in_channels,
942
+ out_channels=out_channels,
943
+ temb_channels=temb_channels,
944
+ eps=resnet_eps,
945
+ groups=min(in_channels // 4, 32),
946
+ groups_out=min(out_channels // 4, 32),
947
+ dropout=dropout,
948
+ time_embedding_norm=resnet_time_scale_shift,
949
+ non_linearity=resnet_act_fn,
950
+ output_scale_factor=output_scale_factor,
951
+ pre_norm=resnet_pre_norm,
952
+ )
953
+ )
954
+
955
+ if add_downsample:
956
+ self.resnet_down = ResnetBlock2D(
957
+ in_channels=out_channels,
958
+ out_channels=out_channels,
959
+ temb_channels=temb_channels,
960
+ eps=resnet_eps,
961
+ groups=min(out_channels // 4, 32),
962
+ dropout=dropout,
963
+ time_embedding_norm=resnet_time_scale_shift,
964
+ non_linearity=resnet_act_fn,
965
+ output_scale_factor=output_scale_factor,
966
+ pre_norm=resnet_pre_norm,
967
+ use_in_shortcut=True,
968
+ down=True,
969
+ kernel="fir",
970
+ )
971
+ self.downsamplers = nn.ModuleList([FirDownsample2D(in_channels, out_channels=out_channels)])
972
+ self.skip_conv = nn.Conv2d(3, out_channels, kernel_size=(1, 1), stride=(1, 1))
973
+ else:
974
+ self.resnet_down = None
975
+ self.downsamplers = None
976
+ self.skip_conv = None
977
+
978
+ def forward(self, hidden_states, temb=None, skip_sample=None):
979
+ output_states = ()
980
+
981
+ for resnet in self.resnets:
982
+ hidden_states = resnet(hidden_states, temb)
983
+ output_states += (hidden_states,)
984
+
985
+ if self.downsamplers is not None:
986
+ hidden_states = self.resnet_down(hidden_states, temb)
987
+ for downsampler in self.downsamplers:
988
+ skip_sample = downsampler(skip_sample)
989
+
990
+ hidden_states = self.skip_conv(skip_sample) + hidden_states
991
+
992
+ output_states += (hidden_states,)
993
+
994
+ return hidden_states, output_states, skip_sample
995
+
996
+
997
+ class AttnUpBlock2D(nn.Module):
998
+ def __init__(
999
+ self,
1000
+ in_channels: int,
1001
+ prev_output_channel: int,
1002
+ out_channels: int,
1003
+ temb_channels: int,
1004
+ dropout: float = 0.0,
1005
+ num_layers: int = 1,
1006
+ resnet_eps: float = 1e-6,
1007
+ resnet_time_scale_shift: str = "default",
1008
+ resnet_act_fn: str = "swish",
1009
+ resnet_groups: int = 32,
1010
+ resnet_pre_norm: bool = True,
1011
+ attention_type="default",
1012
+ attn_num_head_channels=1,
1013
+ output_scale_factor=1.0,
1014
+ add_upsample=True,
1015
+ ):
1016
+ super().__init__()
1017
+ resnets = []
1018
+ attentions = []
1019
+
1020
+ self.attention_type = attention_type
1021
+
1022
+ for i in range(num_layers):
1023
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
1024
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
1025
+
1026
+ resnets.append(
1027
+ ResnetBlock2D(
1028
+ in_channels=resnet_in_channels + res_skip_channels,
1029
+ out_channels=out_channels,
1030
+ temb_channels=temb_channels,
1031
+ eps=resnet_eps,
1032
+ groups=resnet_groups,
1033
+ dropout=dropout,
1034
+ time_embedding_norm=resnet_time_scale_shift,
1035
+ non_linearity=resnet_act_fn,
1036
+ output_scale_factor=output_scale_factor,
1037
+ pre_norm=resnet_pre_norm,
1038
+ )
1039
+ )
1040
+ attentions.append(
1041
+ AttentionBlock(
1042
+ out_channels,
1043
+ num_head_channels=attn_num_head_channels,
1044
+ rescale_output_factor=output_scale_factor,
1045
+ eps=resnet_eps,
1046
+ norm_num_groups=resnet_groups,
1047
+ )
1048
+ )
1049
+
1050
+ self.attentions = nn.ModuleList(attentions)
1051
+ self.resnets = nn.ModuleList(resnets)
1052
+
1053
+ if add_upsample:
1054
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
1055
+ else:
1056
+ self.upsamplers = None
1057
+
1058
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None):
1059
+ for resnet, attn in zip(self.resnets, self.attentions):
1060
+ # pop res hidden states
1061
+ res_hidden_states = res_hidden_states_tuple[-1]
1062
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
1063
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
1064
+
1065
+ hidden_states = resnet(hidden_states, temb)
1066
+ hidden_states = attn(hidden_states)
1067
+
1068
+ if self.upsamplers is not None:
1069
+ for upsampler in self.upsamplers:
1070
+ hidden_states = upsampler(hidden_states)
1071
+
1072
+ return hidden_states
1073
+
1074
+
1075
+ class CrossAttnUpBlock2D(nn.Module):
1076
+ def __init__(
1077
+ self,
1078
+ in_channels: int,
1079
+ out_channels: int,
1080
+ prev_output_channel: int,
1081
+ temb_channels: int,
1082
+ dropout: float = 0.0,
1083
+ num_layers: int = 1,
1084
+ resnet_eps: float = 1e-6,
1085
+ resnet_time_scale_shift: str = "default",
1086
+ resnet_act_fn: str = "swish",
1087
+ resnet_groups: int = 32,
1088
+ resnet_pre_norm: bool = True,
1089
+ attn_num_head_channels=1,
1090
+ cross_attention_dim=1280,
1091
+ attention_type="default",
1092
+ output_scale_factor=1.0,
1093
+ add_upsample=True,
1094
+ ):
1095
+ super().__init__()
1096
+ resnets = []
1097
+ attentions = []
1098
+
1099
+ self.attention_type = attention_type
1100
+ self.attn_num_head_channels = attn_num_head_channels
1101
+
1102
+ for i in range(num_layers):
1103
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
1104
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
1105
+
1106
+ resnets.append(
1107
+ ResnetBlock2D(
1108
+ in_channels=resnet_in_channels + res_skip_channels,
1109
+ out_channels=out_channels,
1110
+ temb_channels=temb_channels,
1111
+ eps=resnet_eps,
1112
+ groups=resnet_groups,
1113
+ dropout=dropout,
1114
+ time_embedding_norm=resnet_time_scale_shift,
1115
+ non_linearity=resnet_act_fn,
1116
+ output_scale_factor=output_scale_factor,
1117
+ pre_norm=resnet_pre_norm,
1118
+ )
1119
+ )
1120
+ attentions.append(
1121
+ Transformer2DModel(
1122
+ attn_num_head_channels,
1123
+ out_channels // attn_num_head_channels,
1124
+ in_channels=out_channels,
1125
+ num_layers=1,
1126
+ cross_attention_dim=cross_attention_dim,
1127
+ norm_num_groups=resnet_groups,
1128
+ )
1129
+ )
1130
+ self.attentions = nn.ModuleList(attentions)
1131
+ self.resnets = nn.ModuleList(resnets)
1132
+
1133
+ if add_upsample:
1134
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
1135
+ else:
1136
+ self.upsamplers = None
1137
+
1138
+ self.gradient_checkpointing = False
1139
+
1140
+ def set_attention_slice(self, slice_size):
1141
+ if slice_size is not None and self.attn_num_head_channels % slice_size != 0:
1142
+ raise ValueError(
1143
+ f"Make sure slice_size {slice_size} is a divisor of "
1144
+ f"the number of heads used in cross_attention {self.attn_num_head_channels}"
1145
+ )
1146
+ if slice_size is not None and slice_size > self.attn_num_head_channels:
1147
+ raise ValueError(
1148
+ f"Chunk_size {slice_size} has to be smaller or equal to "
1149
+ f"the number of heads used in cross_attention {self.attn_num_head_channels}"
1150
+ )
1151
+
1152
+ for attn in self.attentions:
1153
+ attn._set_attention_slice(slice_size)
1154
+
1155
+ self.gradient_checkpointing = False
1156
+
1157
+ def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
1158
+ for attn in self.attentions:
1159
+ attn._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)
1160
+
1161
+ def forward(
1162
+ self,
1163
+ hidden_states,
1164
+ res_hidden_states_tuple,
1165
+ temb=None,
1166
+ encoder_hidden_states=None,
1167
+ encoder_attention_mask=None,
1168
+ upsample_size=None,
1169
+ ):
1170
+ for resnet, attn in zip(self.resnets, self.attentions):
1171
+ # pop res hidden states
1172
+ res_hidden_states = res_hidden_states_tuple[-1]
1173
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
1174
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
1175
+
1176
+ if self.training and self.gradient_checkpointing:
1177
+
1178
+ def create_custom_forward(module, return_dict=None):
1179
+ def custom_forward(*inputs):
1180
+ if return_dict is not None:
1181
+ return module(*inputs, return_dict=return_dict)
1182
+ else:
1183
+ return module(*inputs)
1184
+
1185
+ return custom_forward
1186
+
1187
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
1188
+ hidden_states = torch.utils.checkpoint.checkpoint(
1189
+ create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states,
1190
+ encoder_attention_mask
1191
+ )[0]
1192
+ else:
1193
+ hidden_states = resnet(hidden_states, temb)
1194
+ hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states,
1195
+ encoder_attention_mask=encoder_attention_mask).sample
1196
+
1197
+ if self.upsamplers is not None:
1198
+ for upsampler in self.upsamplers:
1199
+ hidden_states = upsampler(hidden_states, upsample_size)
1200
+
1201
+ return hidden_states
1202
+
1203
+
1204
+ class UpBlock2D(nn.Module):
1205
+ def __init__(
1206
+ self,
1207
+ in_channels: int,
1208
+ prev_output_channel: int,
1209
+ out_channels: int,
1210
+ temb_channels: int,
1211
+ dropout: float = 0.0,
1212
+ num_layers: int = 1,
1213
+ resnet_eps: float = 1e-6,
1214
+ resnet_time_scale_shift: str = "default",
1215
+ resnet_act_fn: str = "swish",
1216
+ resnet_groups: int = 32,
1217
+ resnet_pre_norm: bool = True,
1218
+ output_scale_factor=1.0,
1219
+ add_upsample=True,
1220
+ ):
1221
+ super().__init__()
1222
+ resnets = []
1223
+
1224
+ for i in range(num_layers):
1225
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
1226
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
1227
+
1228
+ resnets.append(
1229
+ ResnetBlock2D(
1230
+ in_channels=resnet_in_channels + res_skip_channels,
1231
+ out_channels=out_channels,
1232
+ temb_channels=temb_channels,
1233
+ eps=resnet_eps,
1234
+ groups=resnet_groups,
1235
+ dropout=dropout,
1236
+ time_embedding_norm=resnet_time_scale_shift,
1237
+ non_linearity=resnet_act_fn,
1238
+ output_scale_factor=output_scale_factor,
1239
+ pre_norm=resnet_pre_norm,
1240
+ )
1241
+ )
1242
+
1243
+ self.resnets = nn.ModuleList(resnets)
1244
+
1245
+ if add_upsample:
1246
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
1247
+ else:
1248
+ self.upsamplers = None
1249
+
1250
+ self.gradient_checkpointing = False
1251
+
1252
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
1253
+ for resnet in self.resnets:
1254
+ # pop res hidden states
1255
+ res_hidden_states = res_hidden_states_tuple[-1]
1256
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
1257
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
1258
+
1259
+ if self.training and self.gradient_checkpointing:
1260
+
1261
+ def create_custom_forward(module):
1262
+ def custom_forward(*inputs):
1263
+ return module(*inputs)
1264
+
1265
+ return custom_forward
1266
+
1267
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
1268
+ else:
1269
+ hidden_states = resnet(hidden_states, temb)
1270
+
1271
+ if self.upsamplers is not None:
1272
+ for upsampler in self.upsamplers:
1273
+ hidden_states = upsampler(hidden_states, upsample_size)
1274
+
1275
+ return hidden_states
1276
+
1277
+
1278
+ class UpDecoderBlock2D(nn.Module):
1279
+ def __init__(
1280
+ self,
1281
+ in_channels: int,
1282
+ out_channels: int,
1283
+ dropout: float = 0.0,
1284
+ num_layers: int = 1,
1285
+ resnet_eps: float = 1e-6,
1286
+ resnet_time_scale_shift: str = "default",
1287
+ resnet_act_fn: str = "swish",
1288
+ resnet_groups: int = 32,
1289
+ resnet_pre_norm: bool = True,
1290
+ output_scale_factor=1.0,
1291
+ add_upsample=True,
1292
+ ):
1293
+ super().__init__()
1294
+ resnets = []
1295
+
1296
+ for i in range(num_layers):
1297
+ input_channels = in_channels if i == 0 else out_channels
1298
+
1299
+ resnets.append(
1300
+ ResnetBlock2D(
1301
+ in_channels=input_channels,
1302
+ out_channels=out_channels,
1303
+ temb_channels=None,
1304
+ eps=resnet_eps,
1305
+ groups=resnet_groups,
1306
+ dropout=dropout,
1307
+ time_embedding_norm=resnet_time_scale_shift,
1308
+ non_linearity=resnet_act_fn,
1309
+ output_scale_factor=output_scale_factor,
1310
+ pre_norm=resnet_pre_norm,
1311
+ )
1312
+ )
1313
+
1314
+ self.resnets = nn.ModuleList(resnets)
1315
+
1316
+ if add_upsample:
1317
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
1318
+ else:
1319
+ self.upsamplers = None
1320
+
1321
+ def forward(self, hidden_states):
1322
+ for resnet in self.resnets:
1323
+ hidden_states = resnet(hidden_states, temb=None)
1324
+
1325
+ if self.upsamplers is not None:
1326
+ for upsampler in self.upsamplers:
1327
+ hidden_states = upsampler(hidden_states)
1328
+
1329
+ return hidden_states
1330
+
1331
+
1332
+ class AttnUpDecoderBlock2D(nn.Module):
1333
+ def __init__(
1334
+ self,
1335
+ in_channels: int,
1336
+ out_channels: int,
1337
+ dropout: float = 0.0,
1338
+ num_layers: int = 1,
1339
+ resnet_eps: float = 1e-6,
1340
+ resnet_time_scale_shift: str = "default",
1341
+ resnet_act_fn: str = "swish",
1342
+ resnet_groups: int = 32,
1343
+ resnet_pre_norm: bool = True,
1344
+ attn_num_head_channels=1,
1345
+ output_scale_factor=1.0,
1346
+ add_upsample=True,
1347
+ ):
1348
+ super().__init__()
1349
+ resnets = []
1350
+ attentions = []
1351
+
1352
+ for i in range(num_layers):
1353
+ input_channels = in_channels if i == 0 else out_channels
1354
+
1355
+ resnets.append(
1356
+ ResnetBlock2D(
1357
+ in_channels=input_channels,
1358
+ out_channels=out_channels,
1359
+ temb_channels=None,
1360
+ eps=resnet_eps,
1361
+ groups=resnet_groups,
1362
+ dropout=dropout,
1363
+ time_embedding_norm=resnet_time_scale_shift,
1364
+ non_linearity=resnet_act_fn,
1365
+ output_scale_factor=output_scale_factor,
1366
+ pre_norm=resnet_pre_norm,
1367
+ )
1368
+ )
1369
+ attentions.append(
1370
+ AttentionBlock(
1371
+ out_channels,
1372
+ num_head_channels=attn_num_head_channels,
1373
+ rescale_output_factor=output_scale_factor,
1374
+ eps=resnet_eps,
1375
+ norm_num_groups=resnet_groups,
1376
+ )
1377
+ )
1378
+
1379
+ self.attentions = nn.ModuleList(attentions)
1380
+ self.resnets = nn.ModuleList(resnets)
1381
+
1382
+ if add_upsample:
1383
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
1384
+ else:
1385
+ self.upsamplers = None
1386
+
1387
+ def forward(self, hidden_states):
1388
+ for resnet, attn in zip(self.resnets, self.attentions):
1389
+ hidden_states = resnet(hidden_states, temb=None)
1390
+ hidden_states = attn(hidden_states)
1391
+
1392
+ if self.upsamplers is not None:
1393
+ for upsampler in self.upsamplers:
1394
+ hidden_states = upsampler(hidden_states)
1395
+
1396
+ return hidden_states
1397
+
1398
+
1399
+ class AttnSkipUpBlock2D(nn.Module):
1400
+ def __init__(
1401
+ self,
1402
+ in_channels: int,
1403
+ prev_output_channel: int,
1404
+ out_channels: int,
1405
+ temb_channels: int,
1406
+ dropout: float = 0.0,
1407
+ num_layers: int = 1,
1408
+ resnet_eps: float = 1e-6,
1409
+ resnet_time_scale_shift: str = "default",
1410
+ resnet_act_fn: str = "swish",
1411
+ resnet_pre_norm: bool = True,
1412
+ attn_num_head_channels=1,
1413
+ attention_type="default",
1414
+ output_scale_factor=np.sqrt(2.0),
1415
+ upsample_padding=1,
1416
+ add_upsample=True,
1417
+ ):
1418
+ super().__init__()
1419
+ self.attentions = nn.ModuleList([])
1420
+ self.resnets = nn.ModuleList([])
1421
+
1422
+ self.attention_type = attention_type
1423
+
1424
+ for i in range(num_layers):
1425
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
1426
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
1427
+
1428
+ self.resnets.append(
1429
+ ResnetBlock2D(
1430
+ in_channels=resnet_in_channels + res_skip_channels,
1431
+ out_channels=out_channels,
1432
+ temb_channels=temb_channels,
1433
+ eps=resnet_eps,
1434
+ groups=min((resnet_in_channels + res_skip_channels) // 4, 32),
1435
+ groups_out=min(out_channels // 4, 32),
1436
+ dropout=dropout,
1437
+ time_embedding_norm=resnet_time_scale_shift,
1438
+ non_linearity=resnet_act_fn,
1439
+ output_scale_factor=output_scale_factor,
1440
+ pre_norm=resnet_pre_norm,
1441
+ )
1442
+ )
1443
+
1444
+ self.attentions.append(
1445
+ AttentionBlock(
1446
+ out_channels,
1447
+ num_head_channels=attn_num_head_channels,
1448
+ rescale_output_factor=output_scale_factor,
1449
+ eps=resnet_eps,
1450
+ )
1451
+ )
1452
+
1453
+ self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels)
1454
+ if add_upsample:
1455
+ self.resnet_up = ResnetBlock2D(
1456
+ in_channels=out_channels,
1457
+ out_channels=out_channels,
1458
+ temb_channels=temb_channels,
1459
+ eps=resnet_eps,
1460
+ groups=min(out_channels // 4, 32),
1461
+ groups_out=min(out_channels // 4, 32),
1462
+ dropout=dropout,
1463
+ time_embedding_norm=resnet_time_scale_shift,
1464
+ non_linearity=resnet_act_fn,
1465
+ output_scale_factor=output_scale_factor,
1466
+ pre_norm=resnet_pre_norm,
1467
+ use_in_shortcut=True,
1468
+ up=True,
1469
+ kernel="fir",
1470
+ )
1471
+ self.skip_conv = nn.Conv2d(out_channels, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
1472
+ self.skip_norm = torch.nn.GroupNorm(
1473
+ num_groups=min(out_channels // 4, 32), num_channels=out_channels, eps=resnet_eps, affine=True
1474
+ )
1475
+ self.act = nn.SiLU()
1476
+ else:
1477
+ self.resnet_up = None
1478
+ self.skip_conv = None
1479
+ self.skip_norm = None
1480
+ self.act = None
1481
+
1482
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None):
1483
+ for resnet in self.resnets:
1484
+ # pop res hidden states
1485
+ res_hidden_states = res_hidden_states_tuple[-1]
1486
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
1487
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
1488
+
1489
+ hidden_states = resnet(hidden_states, temb)
1490
+
1491
+ hidden_states = self.attentions[0](hidden_states)
1492
+
1493
+ if skip_sample is not None:
1494
+ skip_sample = self.upsampler(skip_sample)
1495
+ else:
1496
+ skip_sample = 0
1497
+
1498
+ if self.resnet_up is not None:
1499
+ skip_sample_states = self.skip_norm(hidden_states)
1500
+ skip_sample_states = self.act(skip_sample_states)
1501
+ skip_sample_states = self.skip_conv(skip_sample_states)
1502
+
1503
+ skip_sample = skip_sample + skip_sample_states
1504
+
1505
+ hidden_states = self.resnet_up(hidden_states, temb)
1506
+
1507
+ return hidden_states, skip_sample
1508
+
1509
+
1510
+ class SkipUpBlock2D(nn.Module):
1511
+ def __init__(
1512
+ self,
1513
+ in_channels: int,
1514
+ prev_output_channel: int,
1515
+ out_channels: int,
1516
+ temb_channels: int,
1517
+ dropout: float = 0.0,
1518
+ num_layers: int = 1,
1519
+ resnet_eps: float = 1e-6,
1520
+ resnet_time_scale_shift: str = "default",
1521
+ resnet_act_fn: str = "swish",
1522
+ resnet_pre_norm: bool = True,
1523
+ output_scale_factor=np.sqrt(2.0),
1524
+ add_upsample=True,
1525
+ upsample_padding=1,
1526
+ ):
1527
+ super().__init__()
1528
+ self.resnets = nn.ModuleList([])
1529
+
1530
+ for i in range(num_layers):
1531
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
1532
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
1533
+
1534
+ self.resnets.append(
1535
+ ResnetBlock2D(
1536
+ in_channels=resnet_in_channels + res_skip_channels,
1537
+ out_channels=out_channels,
1538
+ temb_channels=temb_channels,
1539
+ eps=resnet_eps,
1540
+ groups=min((resnet_in_channels + res_skip_channels) // 4, 32),
1541
+ groups_out=min(out_channels // 4, 32),
1542
+ dropout=dropout,
1543
+ time_embedding_norm=resnet_time_scale_shift,
1544
+ non_linearity=resnet_act_fn,
1545
+ output_scale_factor=output_scale_factor,
1546
+ pre_norm=resnet_pre_norm,
1547
+ )
1548
+ )
1549
+
1550
+ self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels)
1551
+ if add_upsample:
1552
+ self.resnet_up = ResnetBlock2D(
1553
+ in_channels=out_channels,
1554
+ out_channels=out_channels,
1555
+ temb_channels=temb_channels,
1556
+ eps=resnet_eps,
1557
+ groups=min(out_channels // 4, 32),
1558
+ groups_out=min(out_channels // 4, 32),
1559
+ dropout=dropout,
1560
+ time_embedding_norm=resnet_time_scale_shift,
1561
+ non_linearity=resnet_act_fn,
1562
+ output_scale_factor=output_scale_factor,
1563
+ pre_norm=resnet_pre_norm,
1564
+ use_in_shortcut=True,
1565
+ up=True,
1566
+ kernel="fir",
1567
+ )
1568
+ self.skip_conv = nn.Conv2d(out_channels, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
1569
+ self.skip_norm = torch.nn.GroupNorm(
1570
+ num_groups=min(out_channels // 4, 32), num_channels=out_channels, eps=resnet_eps, affine=True
1571
+ )
1572
+ self.act = nn.SiLU()
1573
+ else:
1574
+ self.resnet_up = None
1575
+ self.skip_conv = None
1576
+ self.skip_norm = None
1577
+ self.act = None
1578
+
1579
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None):
1580
+ for resnet in self.resnets:
1581
+ # pop res hidden states
1582
+ res_hidden_states = res_hidden_states_tuple[-1]
1583
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
1584
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
1585
+
1586
+ hidden_states = resnet(hidden_states, temb)
1587
+
1588
+ if skip_sample is not None:
1589
+ skip_sample = self.upsampler(skip_sample)
1590
+ else:
1591
+ skip_sample = 0
1592
+
1593
+ if self.resnet_up is not None:
1594
+ skip_sample_states = self.skip_norm(hidden_states)
1595
+ skip_sample_states = self.act(skip_sample_states)
1596
+ skip_sample_states = self.skip_conv(skip_sample_states)
1597
+
1598
+ skip_sample = skip_sample + skip_sample_states
1599
+
1600
+ hidden_states = self.resnet_up(hidden_states, temb)
1601
+
1602
+ return hidden_states, skip_sample
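
A minimal usage sketch (not part of the commit) for the overridden down block above. It assumes this repository's modules are importable and that the overridden Transformer2DModel accepts the extra encoder_attention_mask the block forwards to it; all channel counts, sequence lengths, and mask values below are illustrative assumptions.

import torch
from models.diffusers_override.unet_2d_blocks import CrossAttnDownBlock2D

# Shapes and hyperparameters here are illustrative assumptions, not values fixed by this commit.
block = CrossAttnDownBlock2D(
    in_channels=320,
    out_channels=320,
    temb_channels=1280,
    num_layers=2,
    attn_num_head_channels=8,
    cross_attention_dim=768,
)
sample = torch.randn(1, 320, 32, 32)      # (batch, channels, height, width)
temb = torch.randn(1, 1280)               # time embedding
text_states = torch.randn(1, 77, 768)     # encoder hidden states
text_mask = torch.ones(1, 77)             # encoder attention mask (the argument this override adds)
hidden_states, res_samples = block(
    sample,
    temb=temb,
    encoder_hidden_states=text_states,
    encoder_attention_mask=text_mask,
)
# hidden_states is the downsampled output; res_samples holds the skip connections for the up blocks.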
models/diffusers_override/unet_2d_condition.py ADDED
@@ -0,0 +1,359 @@
1
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from dataclasses import dataclass
15
+ from typing import Optional, Tuple, Union
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+ import torch.utils.checkpoint
20
+
21
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
22
+ from diffusers.modeling_utils import ModelMixin
23
+ from diffusers.utils import BaseOutput, logging
24
+ from diffusers.models.embeddings import TimestepEmbedding, Timesteps
25
+ from .unet_2d_blocks import (
26
+ CrossAttnDownBlock2D,
27
+ CrossAttnUpBlock2D,
28
+ DownBlock2D,
29
+ UNetMidBlock2DCrossAttn,
30
+ UpBlock2D,
31
+ get_down_block,
32
+ get_up_block,
33
+ )
34
+
35
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
36
+
37
+
38
+ @dataclass
39
+ class UNet2DConditionOutput(BaseOutput):
40
+ """
41
+ Args:
42
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
43
+ Hidden states conditioned on `encoder_hidden_states` input. Output of last layer of model.
44
+ """
45
+
46
+ sample: torch.FloatTensor
47
+
48
+
49
+ class UNet2DConditionModel(ModelMixin, ConfigMixin):
50
+ r"""
51
+ UNet2DConditionModel is a conditional 2D UNet model that takes in a noisy sample, conditional state, and a timestep
52
+ and returns sample shaped output.
53
+
54
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
55
+ implements for all the models (such as downloading or saving, etc.)
56
+
57
+ Parameters:
58
+ sample_size (`int`, *optional*): The size of the input sample.
59
+ in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample.
60
+ out_channels (`int`, *optional*, defaults to 4): The number of channels in the output.
61
+ center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
62
+ flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
63
+ Whether to flip the sin to cos in the time embedding.
64
+ freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
65
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
66
+ The tuple of downsample blocks to use.
67
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D",)`):
68
+ The tuple of upsample blocks to use.
69
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
70
+ The tuple of output channels for each block.
71
+ layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
72
+ downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
73
+ mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
74
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
75
+ norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
76
+ norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
77
+ cross_attention_dim (`int`, *optional*, defaults to 1280): The dimension of the cross attention features.
78
+ attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
79
+ """
80
+
81
+ _supports_gradient_checkpointing = True
82
+
83
+ @register_to_config
84
+ def __init__(
85
+ self,
86
+ sample_size: Optional[int] = None,
87
+ in_channels: int = 4,
88
+ out_channels: int = 4,
89
+ center_input_sample: bool = False,
90
+ flip_sin_to_cos: bool = True,
91
+ freq_shift: int = 0,
92
+ down_block_types: Tuple[str] = (
93
+ "CrossAttnDownBlock2D",
94
+ "CrossAttnDownBlock2D",
95
+ "CrossAttnDownBlock2D",
96
+ "DownBlock2D",
97
+ ),
98
+ up_block_types: Tuple[str] = (
99
+ "UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
100
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
101
+ layers_per_block: int = 2,
102
+ downsample_padding: int = 1,
103
+ mid_block_scale_factor: float = 1,
104
+ act_fn: str = "silu",
105
+ norm_num_groups: int = 32,
106
+ norm_eps: float = 1e-5,
107
+ cross_attention_dim: int = 1280,
108
+ attention_head_dim: int = 8,
109
+ ):
110
+ super().__init__()
111
+
112
+ self.sample_size = sample_size
113
+ time_embed_dim = block_out_channels[0] * 4
114
+
115
+ # input
116
+ self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))
117
+
118
+ # time
119
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
120
+ timestep_input_dim = block_out_channels[0]
121
+
122
+ self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
123
+
124
+ self.down_blocks = nn.ModuleList([])
125
+ self.mid_block = None
126
+ self.up_blocks = nn.ModuleList([])
127
+
128
+ # down
129
+ output_channel = block_out_channels[0]
130
+ for i, down_block_type in enumerate(down_block_types):
131
+ input_channel = output_channel
132
+ output_channel = block_out_channels[i]
133
+ is_final_block = i == len(block_out_channels) - 1
134
+
135
+ down_block = get_down_block(
136
+ down_block_type,
137
+ num_layers=layers_per_block,
138
+ in_channels=input_channel,
139
+ out_channels=output_channel,
140
+ temb_channels=time_embed_dim,
141
+ add_downsample=not is_final_block,
142
+ resnet_eps=norm_eps,
143
+ resnet_act_fn=act_fn,
144
+ resnet_groups=norm_num_groups,
145
+ cross_attention_dim=cross_attention_dim,
146
+ attn_num_head_channels=attention_head_dim,
147
+ downsample_padding=downsample_padding,
148
+ )
149
+ self.down_blocks.append(down_block)
150
+
151
+ # mid
152
+ self.mid_block = UNetMidBlock2DCrossAttn(
153
+ in_channels=block_out_channels[-1],
154
+ temb_channels=time_embed_dim,
155
+ resnet_eps=norm_eps,
156
+ resnet_act_fn=act_fn,
157
+ output_scale_factor=mid_block_scale_factor,
158
+ resnet_time_scale_shift="default",
159
+ cross_attention_dim=cross_attention_dim,
160
+ attn_num_head_channels=attention_head_dim,
161
+ resnet_groups=norm_num_groups,
162
+ )
163
+
164
+ # count how many layers upsample the images
165
+ self.num_upsamplers = 0
166
+
167
+ # up
168
+ reversed_block_out_channels = list(reversed(block_out_channels))
169
+ output_channel = reversed_block_out_channels[0]
170
+ for i, up_block_type in enumerate(up_block_types):
171
+ is_final_block = i == len(block_out_channels) - 1
172
+
173
+ prev_output_channel = output_channel
174
+ output_channel = reversed_block_out_channels[i]
175
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
176
+
177
+ # add upsample block for all BUT final layer
178
+ if not is_final_block:
179
+ add_upsample = True
180
+ self.num_upsamplers += 1
181
+ else:
182
+ add_upsample = False
183
+
184
+ up_block = get_up_block(
185
+ up_block_type,
186
+ num_layers=layers_per_block + 1,
187
+ in_channels=input_channel,
188
+ out_channels=output_channel,
189
+ prev_output_channel=prev_output_channel,
190
+ temb_channels=time_embed_dim,
191
+ add_upsample=add_upsample,
192
+ resnet_eps=norm_eps,
193
+ resnet_act_fn=act_fn,
194
+ resnet_groups=norm_num_groups,
195
+ cross_attention_dim=cross_attention_dim,
196
+ attn_num_head_channels=attention_head_dim,
197
+ )
198
+ self.up_blocks.append(up_block)
199
+ prev_output_channel = output_channel
200
+
201
+ # out
202
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
203
+ self.conv_act = nn.SiLU()
204
+ self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
205
+
206
+ def set_attention_slice(self, slice_size):
207
+ if slice_size is not None and self.config.attention_head_dim % slice_size != 0:
208
+ raise ValueError(
209
+ f"Make sure slice_size {slice_size} is a divisor of "
210
+ f"the number of heads used in cross_attention {self.config.attention_head_dim}"
211
+ )
212
+ if slice_size is not None and slice_size > self.config.attention_head_dim:
213
+ raise ValueError(
214
+ f"Chunk_size {slice_size} has to be smaller or equal to "
215
+ f"the number of heads used in cross_attention {self.config.attention_head_dim}"
216
+ )
217
+
218
+ for block in self.down_blocks:
219
+ if hasattr(block, "attentions") and block.attentions is not None:
220
+ block.set_attention_slice(slice_size)
221
+
222
+ self.mid_block.set_attention_slice(slice_size)
223
+
224
+ for block in self.up_blocks:
225
+ if hasattr(block, "attentions") and block.attentions is not None:
226
+ block.set_attention_slice(slice_size)
227
+
228
+ def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
229
+ for block in self.down_blocks:
230
+ if hasattr(block, "attentions") and block.attentions is not None:
231
+ block.set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)
232
+
233
+ self.mid_block.set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)
234
+
235
+ for block in self.up_blocks:
236
+ if hasattr(block, "attentions") and block.attentions is not None:
237
+ block.set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)
238
+
239
+ def _set_gradient_checkpointing(self, module, value=False):
240
+ if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D, CrossAttnUpBlock2D, UpBlock2D)):
241
+ module.gradient_checkpointing = value
242
+
243
+ def forward(
244
+ self,
245
+ sample: torch.FloatTensor,
246
+ timestep: Union[torch.Tensor, float, int],
247
+ encoder_hidden_states: torch.Tensor,
248
+ encoder_attention_mask: torch.Tensor,
249
+ return_dict: bool = True,
250
+ ) -> Union[UNet2DConditionOutput, Tuple]:
251
+ r"""
252
+ Args:
253
+ sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
254
+ timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
255
+ encoder_hidden_states (`torch.FloatTensor`):
256
+ (batch_size, sequence_length, hidden_size) encoder hidden states
257
+ encoder_attention_mask (`torch.FloatTensor`):
258
+ (batch_size, sequence_length) encoder attention mask
259
+ return_dict (`bool`, *optional*, defaults to `True`):
260
+ Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
261
+
262
+ Returns:
263
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
264
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
265
+ returning a tuple, the first element is the sample tensor.
266
+ """
267
+ # By default, samples have to be at least a multiple of the overall upsampling factor.
268
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
269
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
270
+ # on the fly if necessary.
271
+ default_overall_up_factor = 2 ** self.num_upsamplers
272
+
273
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
274
+ forward_upsample_size = False
275
+ upsample_size = None
276
+
277
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
278
+ logger.info("Forward upsample size to force interpolation output size.")
279
+ forward_upsample_size = True
280
+
281
+ # 0. center input if necessary
282
+ if self.config.center_input_sample:
283
+ sample = 2 * sample - 1.0
284
+
285
+ # 1. time
286
+ timesteps = timestep
287
+ if not torch.is_tensor(timesteps):
288
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
289
+ timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
290
+ elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
291
+ timesteps = timesteps[None].to(sample.device)
292
+
293
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
294
+ timesteps = timesteps.expand(sample.shape[0])
295
+
296
+ t_emb = self.time_proj(timesteps)
297
+
298
+ # timesteps does not contain any weights and will always return f32 tensors
299
+ # but time_embedding might actually be running in fp16. so we need to cast here.
300
+ # there might be better ways to encapsulate this.
301
+ t_emb = t_emb.to(dtype=self.dtype)
302
+ emb = self.time_embedding(t_emb)
303
+
304
+ # 2. pre-process
305
+ sample = self.conv_in(sample)
306
+
307
+ # 3. down
308
+ down_block_res_samples = (sample,)
309
+ for downsample_block in self.down_blocks:
310
+ if hasattr(downsample_block, "attentions") and downsample_block.attentions is not None:
311
+ sample, res_samples = downsample_block(
312
+ hidden_states=sample,
313
+ temb=emb,
314
+ encoder_hidden_states=encoder_hidden_states,
315
+ encoder_attention_mask=encoder_attention_mask,
316
+ )
317
+ else:
318
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
319
+
320
+ down_block_res_samples += res_samples
321
+
322
+ # 4. mid
323
+ sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states,
324
+ encoder_attention_mask=encoder_attention_mask)
325
+
326
+ # 5. up
327
+ for i, upsample_block in enumerate(self.up_blocks):
328
+ is_final_block = i == len(self.up_blocks) - 1
329
+
330
+ res_samples = down_block_res_samples[-len(upsample_block.resnets):]
331
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
332
+
333
+ # if we have not reached the final block and need to forward the
334
+ # upsample size, we do it here
335
+ if not is_final_block and forward_upsample_size:
336
+ upsample_size = down_block_res_samples[-1].shape[2:]
337
+
338
+ if hasattr(upsample_block, "attentions") and upsample_block.attentions is not None:
339
+ sample = upsample_block(
340
+ hidden_states=sample,
341
+ temb=emb,
342
+ res_hidden_states_tuple=res_samples,
343
+ encoder_hidden_states=encoder_hidden_states,
344
+ encoder_attention_mask=encoder_attention_mask,
345
+ upsample_size=upsample_size,
346
+ )
347
+ else:
348
+ sample = upsample_block(
349
+ hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
350
+ )
351
+ # 6. post-process
352
+ sample = self.conv_norm_out(sample)
353
+ sample = self.conv_act(sample)
354
+ sample = self.conv_out(sample)
355
+
356
+ if not return_dict:
357
+ return (sample,)
358
+
359
+ return UNet2DConditionOutput(sample=sample)
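
A hedged end-to-end sketch of calling this overridden UNet (not part of the commit): compared with stock diffusers, forward additionally takes an encoder_attention_mask, which is threaded into every cross-attention block. The latent and text-embedding shapes are assumptions for a Stable-Diffusion-sized setup, not values fixed by this file.

import torch
from models.diffusers_override.unet_2d_condition import UNet2DConditionModel

# sample_size and cross_attention_dim are illustrative assumptions.
unet = UNet2DConditionModel(sample_size=64, cross_attention_dim=768)
latents = torch.randn(2, 4, 64, 64)       # noisy latents: (batch, in_channels, height, width)
timesteps = torch.tensor([10, 10])        # one timestep per sample
text_states = torch.randn(2, 77, 768)     # encoder hidden states
text_mask = torch.ones(2, 77)             # encoder attention mask (the argument this override adds)
noise_pred = unet(latents, timesteps, text_states, text_mask).sample
# noise_pred has the same shape as latents: (2, 4, 64, 64)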
models/inception.py ADDED
@@ -0,0 +1,314 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torchvision import models
5
+
6
+ try:
7
+ from torchvision.models.utils import load_state_dict_from_url
8
+ except ImportError:
9
+ from torch.utils.model_zoo import load_url as load_state_dict_from_url
10
+
11
+ # Inception weights ported to Pytorch from
12
+ # http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
13
+ FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth'
14
+
15
+
16
+ class InceptionV3(nn.Module):
17
+ """Pretrained InceptionV3 network returning feature maps"""
18
+
19
+ # Index of default block of inception to return,
20
+ # corresponds to output of final average pooling
21
+ DEFAULT_BLOCK_INDEX = 3
22
+
23
+ # Maps feature dimensionality to their output blocks indices
24
+ BLOCK_INDEX_BY_DIM = {
25
+ 64: 0, # First max pooling features
26
+ 192: 1, # Second max pooling features
27
+ 768: 2, # Pre-aux classifier features
28
+ 2048: 3 # Final average pooling features
29
+ }
30
+
+     def __init__(self,
+                  output_blocks=[DEFAULT_BLOCK_INDEX],
+                  resize_input=True,
+                  normalize_input=True,
+                  requires_grad=False,
+                  use_fid_inception=True):
+         """Build pretrained InceptionV3
+
+         Parameters
+         ----------
+         output_blocks : list of int
+             Indices of blocks to return features of. Possible values are:
+                 - 0: corresponds to output of first max pooling
+                 - 1: corresponds to output of second max pooling
+                 - 2: corresponds to output which is fed to aux classifier
+                 - 3: corresponds to output of final average pooling
+         resize_input : bool
+             If true, bilinearly resizes input to width and height 299 before
+             feeding input to model. As the network without fully connected
+             layers is fully convolutional, it should be able to handle inputs
+             of arbitrary size, so resizing might not be strictly needed
+         normalize_input : bool
+             If true, scales the input from range (0, 1) to the range the
+             pretrained Inception network expects, namely (-1, 1)
+         requires_grad : bool
+             If true, parameters of the model require gradients. Possibly useful
+             for finetuning the network
+         use_fid_inception : bool
+             If true, uses the pretrained Inception model used in Tensorflow's
+             FID implementation. If false, uses the pretrained Inception model
+             available in torchvision. The FID Inception model has different
+             weights and a slightly different structure from torchvision's
+             Inception model. If you want to compute FID scores, you are
+             strongly advised to set this parameter to true to get comparable
+             results.
+         """
+         super(InceptionV3, self).__init__()
+
+         self.resize_input = resize_input
+         self.normalize_input = normalize_input
+         self.output_blocks = sorted(output_blocks)
+         self.last_needed_block = max(output_blocks)
+
+         assert self.last_needed_block <= 3, \
+             'Last possible output block index is 3'
+
+         self.blocks = nn.ModuleList()
+
+         if use_fid_inception:
+             inception = fid_inception_v3()
+         else:
+             inception = models.inception_v3(pretrained=True)
+
+         # Block 0: input to maxpool1
+         block0 = [
+             inception.Conv2d_1a_3x3,
+             inception.Conv2d_2a_3x3,
+             inception.Conv2d_2b_3x3,
+             nn.MaxPool2d(kernel_size=3, stride=2)
+         ]
+         self.blocks.append(nn.Sequential(*block0))
+
+         # Block 1: maxpool1 to maxpool2
+         if self.last_needed_block >= 1:
+             block1 = [
+                 inception.Conv2d_3b_1x1,
+                 inception.Conv2d_4a_3x3,
+                 nn.MaxPool2d(kernel_size=3, stride=2)
+             ]
+             self.blocks.append(nn.Sequential(*block1))
+
+         # Block 2: maxpool2 to aux classifier
+         if self.last_needed_block >= 2:
+             block2 = [
+                 inception.Mixed_5b,
+                 inception.Mixed_5c,
+                 inception.Mixed_5d,
+                 inception.Mixed_6a,
+                 inception.Mixed_6b,
+                 inception.Mixed_6c,
+                 inception.Mixed_6d,
+                 inception.Mixed_6e,
+             ]
+             self.blocks.append(nn.Sequential(*block2))
+
+         # Block 3: aux classifier to final avgpool
+         if self.last_needed_block >= 3:
+             block3 = [
+                 inception.Mixed_7a,
+                 inception.Mixed_7b,
+                 inception.Mixed_7c,
+                 nn.AdaptiveAvgPool2d(output_size=(1, 1))
+             ]
+             self.blocks.append(nn.Sequential(*block3))
+
+         for param in self.parameters():
+             param.requires_grad = requires_grad
+
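+     # Construction sketch (assumes a CUDA device is available): switch the
+     # wrapper to eval() before extracting features, since the backbone
+     # contains batch-norm layers, and leave requires_grad=False for plain
+     # evaluation:
+     #
+     #     model = InceptionV3(output_blocks=[3], use_fid_inception=True)
+     #     model = model.eval().to('cuda')
+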
+     def forward(self, inp):
+         """Get Inception feature maps
+
+         Parameters
+         ----------
+         inp : torch.Tensor
+             Input tensor of shape Bx3xHxW. Values are expected to be in
+             range (0, 1)
+
+         Returns
+         -------
+         List of torch.Tensor, corresponding to the selected output
+         blocks, sorted ascending by index
+         """
+         outp = []
+         x = inp
+
+         if self.resize_input:
+             x = F.interpolate(x,
+                               size=(299, 299),
+                               mode='bilinear',
+                               align_corners=False)
+
+         if self.normalize_input:
+             x = 2 * x - 1  # Scale from range (0, 1) to range (-1, 1)
+
+         for idx, block in enumerate(self.blocks):
+             x = block(x)
+             if idx in self.output_blocks:
+                 outp.append(x)
+
+             if idx == self.last_needed_block:
+                 break
+
+         return outp
+
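+     # Feature-extraction sketch, assuming `images` is a float tensor of shape
+     # (B, 3, H, W) with values in [0, 1] (the range forward() expects):
+     #
+     #     with torch.no_grad():
+     #         pool3 = model(images)[0]          # (B, 2048, 1, 1)
+     #     act = pool3.squeeze(3).squeeze(2)     # (B, 2048) activations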
+
+ def fid_inception_v3():
+     """Build pretrained Inception model for FID computation
+
+     The Inception model for FID computation uses a different set of weights
+     and has a slightly different structure than torchvision's Inception.
+
+     This method first constructs torchvision's Inception and then patches the
+     necessary parts that are different in the FID Inception model.
+     """
+     inception = models.inception_v3(num_classes=1008,
+                                     aux_logits=False,
+                                     pretrained=False)
+     inception.Mixed_5b = FIDInceptionA(192, pool_features=32)
+     inception.Mixed_5c = FIDInceptionA(256, pool_features=64)
+     inception.Mixed_5d = FIDInceptionA(288, pool_features=64)
+     inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128)
+     inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160)
+     inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160)
+     inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192)
+     inception.Mixed_7b = FIDInceptionE_1(1280)
+     inception.Mixed_7c = FIDInceptionE_2(2048)
+
+     state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True)
+     inception.load_state_dict(state_dict)
+     return inception
+
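+ # Sanity-check sketch: the patched backbone keeps torchvision's Inception
+ # interface, so the 1008-way classifier head that ships with the FID weights
+ # can be inspected directly:
+ #
+ #     net = fid_inception_v3()
+ #     assert net.fc.out_features == 1008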
+
+ class FIDInceptionA(models.inception.InceptionA):
+     """InceptionA block patched for FID computation"""
+
+     def __init__(self, in_channels, pool_features):
+         super(FIDInceptionA, self).__init__(in_channels, pool_features)
+
+     def forward(self, x):
+         branch1x1 = self.branch1x1(x)
+
+         branch5x5 = self.branch5x5_1(x)
+         branch5x5 = self.branch5x5_2(branch5x5)
+
+         branch3x3dbl = self.branch3x3dbl_1(x)
+         branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
+         branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
+
+         # Patch: Tensorflow's average pool does not use the padded zeros in
+         # its average calculation
+         branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
+                                    count_include_pad=False)
+         branch_pool = self.branch_pool(branch_pool)
+
+         outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
+         return torch.cat(outputs, 1)
+
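+ # The count_include_pad=False patch above, in concrete terms: for a corner
+ # pixel, the 3x3 window covers 4 real values and 5 padded zeros, so the
+ # patched pooling divides by 4 (as TensorFlow's SAME average pooling does),
+ # whereas PyTorch's default would divide by 9 and shrink border activations.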
+
+ class FIDInceptionC(models.inception.InceptionC):
+     """InceptionC block patched for FID computation"""
+
+     def __init__(self, in_channels, channels_7x7):
+         super(FIDInceptionC, self).__init__(in_channels, channels_7x7)
+
+     def forward(self, x):
+         branch1x1 = self.branch1x1(x)
+
+         branch7x7 = self.branch7x7_1(x)
+         branch7x7 = self.branch7x7_2(branch7x7)
+         branch7x7 = self.branch7x7_3(branch7x7)
+
+         branch7x7dbl = self.branch7x7dbl_1(x)
+         branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
+         branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
+         branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
+         branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)
+
+         # Patch: Tensorflow's average pool does not use the padded zeros in
+         # its average calculation
+         branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
+                                    count_include_pad=False)
+         branch_pool = self.branch_pool(branch_pool)
+
+         outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
+         return torch.cat(outputs, 1)
+
+
+ class FIDInceptionE_1(models.inception.InceptionE):
+     """First InceptionE block patched for FID computation"""
+
+     def __init__(self, in_channels):
+         super(FIDInceptionE_1, self).__init__(in_channels)
+
+     def forward(self, x):
+         branch1x1 = self.branch1x1(x)
+
+         branch3x3 = self.branch3x3_1(x)
+         branch3x3 = [
+             self.branch3x3_2a(branch3x3),
+             self.branch3x3_2b(branch3x3),
+         ]
+         branch3x3 = torch.cat(branch3x3, 1)
+
+         branch3x3dbl = self.branch3x3dbl_1(x)
+         branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
+         branch3x3dbl = [
+             self.branch3x3dbl_3a(branch3x3dbl),
+             self.branch3x3dbl_3b(branch3x3dbl),
+         ]
+         branch3x3dbl = torch.cat(branch3x3dbl, 1)
+
+         # Patch: Tensorflow's average pool does not use the padded zeros in
+         # its average calculation
+         branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
+                                    count_include_pad=False)
+         branch_pool = self.branch_pool(branch_pool)
+
+         outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
+         return torch.cat(outputs, 1)
+
+
+ class FIDInceptionE_2(models.inception.InceptionE):
+     """Second InceptionE block patched for FID computation"""
+
+     def __init__(self, in_channels):
+         super(FIDInceptionE_2, self).__init__(in_channels)
+
+     def forward(self, x):
+         branch1x1 = self.branch1x1(x)
+
+         branch3x3 = self.branch3x3_1(x)
+         branch3x3 = [
+             self.branch3x3_2a(branch3x3),
+             self.branch3x3_2b(branch3x3),
+         ]
+         branch3x3 = torch.cat(branch3x3, 1)
+
+         branch3x3dbl = self.branch3x3dbl_1(x)
+         branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
+         branch3x3dbl = [
+             self.branch3x3dbl_3a(branch3x3dbl),
+             self.branch3x3dbl_3b(branch3x3dbl),
+         ]
+         branch3x3dbl = torch.cat(branch3x3dbl, 1)
+
+         # Patch: The FID Inception model uses max pooling instead of average
+         # pooling. This is likely an error in this specific Inception
+         # implementation, as other Inception models use average pooling here
+         # (which matches the description in the paper).
+         branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1)
+         branch_pool = self.branch_pool(branch_pool)
+
+         outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
+         return torch.cat(outputs, 1)
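+
+ # End-to-end sketch of how the extracted activations are typically turned into
+ # an FID score (assumes numpy arrays `act1`, `act2` of shape (N, 2048) gathered
+ # from two image sets with the InceptionV3 wrapper above):
+ #
+ #     import numpy as np
+ #     from scipy import linalg
+ #
+ #     mu1, sigma1 = act1.mean(axis=0), np.cov(act1, rowvar=False)
+ #     mu2, sigma2 = act2.mean(axis=0), np.cov(act2, rowvar=False)
+ #     covmean = linalg.sqrtm(sigma1.dot(sigma2)).real
+ #     fid = ((mu1 - mu2) ** 2).sum() + np.trace(sigma1 + sigma2 - 2 * covmean)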