flamehaze1115 committed
Commit 3d356e9
1 Parent(s): f67ff5d

Upload 3 files

mvdiffusion/models/transformer_mv2d.py ADDED
@@ -0,0 +1,986 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from dataclasses import dataclass
15
+ from typing import Any, Dict, Optional
16
+
17
+ import torch
18
+ import torch.nn.functional as F
19
+ from torch import nn
20
+
21
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
22
+ from diffusers.models.embeddings import ImagePositionalEmbeddings
23
+ from diffusers.utils import BaseOutput, deprecate, maybe_allow_in_graph
24
+ from diffusers.models.attention import FeedForward, AdaLayerNorm, AdaLayerNormZero, Attention
25
+ from diffusers.models.embeddings import PatchEmbed
26
+ from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
27
+ from diffusers.models.modeling_utils import ModelMixin
28
+ from diffusers.utils.import_utils import is_xformers_available
29
+
30
+ from einops import rearrange, repeat
31
+ import pdb
32
+ import random
33
+
34
+
35
+ if is_xformers_available():
36
+ import xformers
37
+ import xformers.ops
38
+ else:
39
+ xformers = None
40
+
41
+ def my_repeat(tensor, num_repeats):
42
+ """
43
+ Repeat each element of a tensor along its batch dimension `num_repeats` times
44
+ """
45
+ if len(tensor.shape) == 3:
46
+ return repeat(tensor, "b d c -> (b v) d c", v=num_repeats)
47
+ elif len(tensor.shape) == 4:
48
+ return repeat(tensor, "a b d c -> (a v) b d c", v=num_repeats)
49
+
50
+
51
+ @dataclass
52
+ class TransformerMV2DModelOutput(BaseOutput):
53
+ """
54
+ The output of [`TransformerMV2DModel`].
55
+
56
+ Args:
57
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
58
+ The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
59
+ distributions for the unnoised latent pixels.
60
+ """
61
+
62
+ sample: torch.FloatTensor
63
+
64
+
65
+ class TransformerMV2DModel(ModelMixin, ConfigMixin):
66
+ """
67
+ A multi-view 2D Transformer model for image-like data.
68
+
69
+ Parameters:
70
+ num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
71
+ attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
72
+ in_channels (`int`, *optional*):
73
+ The number of channels in the input and output (specify if the input is **continuous**).
74
+ num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
75
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
76
+ cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
77
+ sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
78
+ This is fixed during training since it is used to learn a number of position embeddings.
79
+ num_vector_embeds (`int`, *optional*):
80
+ The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).
81
+ Includes the class for the masked latent pixel.
82
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
83
+ num_embeds_ada_norm ( `int`, *optional*):
84
+ The number of diffusion steps used during training. Pass if at least one of the norm_layers is
85
+ `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
86
+ added to the hidden states.
87
+
88
+ During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`.
89
+ attention_bias (`bool`, *optional*):
90
+ Configure if the `TransformerBlocks` attention should contain a bias parameter.
91
+ """
92
+
93
+ @register_to_config
94
+ def __init__(
95
+ self,
96
+ num_attention_heads: int = 16,
97
+ attention_head_dim: int = 88,
98
+ in_channels: Optional[int] = None,
99
+ out_channels: Optional[int] = None,
100
+ num_layers: int = 1,
101
+ dropout: float = 0.0,
102
+ norm_num_groups: int = 32,
103
+ cross_attention_dim: Optional[int] = None,
104
+ attention_bias: bool = False,
105
+ sample_size: Optional[int] = None,
106
+ num_vector_embeds: Optional[int] = None,
107
+ patch_size: Optional[int] = None,
108
+ activation_fn: str = "geglu",
109
+ num_embeds_ada_norm: Optional[int] = None,
110
+ use_linear_projection: bool = False,
111
+ only_cross_attention: bool = False,
112
+ upcast_attention: bool = False,
113
+ norm_type: str = "layer_norm",
114
+ norm_elementwise_affine: bool = True,
115
+ num_views: int = 1,
116
+ cd_attention_last: bool=False,
117
+ cd_attention_mid: bool=False,
118
+ multiview_attention: bool=True,
119
+ sparse_mv_attention: bool = False,
120
+ mvcd_attention: bool=False
121
+ ):
122
+ super().__init__()
123
+ self.use_linear_projection = use_linear_projection
124
+ self.num_attention_heads = num_attention_heads
125
+ self.attention_head_dim = attention_head_dim
126
+ inner_dim = num_attention_heads * attention_head_dim
127
+
128
+ # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
129
+ # Define whether input is continuous or discrete depending on configuration
130
+ self.is_input_continuous = (in_channels is not None) and (patch_size is None)
131
+ self.is_input_vectorized = num_vector_embeds is not None
132
+ self.is_input_patches = in_channels is not None and patch_size is not None
133
+
134
+ if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
135
+ deprecation_message = (
136
+ f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
137
+ " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config."
138
+ " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect"
139
+ " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
140
+ " would be very nice if you could open a Pull request for the `transformer/config.json` file"
141
+ )
142
+ deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False)
143
+ norm_type = "ada_norm"
144
+
145
+ if self.is_input_continuous and self.is_input_vectorized:
146
+ raise ValueError(
147
+ f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
148
+ " sure that either `in_channels` or `num_vector_embeds` is None."
149
+ )
150
+ elif self.is_input_vectorized and self.is_input_patches:
151
+ raise ValueError(
152
+ f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make"
153
+ " sure that either `num_vector_embeds` or `num_patches` is None."
154
+ )
155
+ elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches:
156
+ raise ValueError(
157
+ f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:"
158
+ f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
159
+ )
160
+
161
+ # 2. Define input layers
162
+ if self.is_input_continuous:
163
+ self.in_channels = in_channels
164
+
165
+ self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
166
+ if use_linear_projection:
167
+ self.proj_in = LoRACompatibleLinear(in_channels, inner_dim)
168
+ else:
169
+ self.proj_in = LoRACompatibleConv(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
170
+ elif self.is_input_vectorized:
171
+ assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
172
+ assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"
173
+
174
+ self.height = sample_size
175
+ self.width = sample_size
176
+ self.num_vector_embeds = num_vector_embeds
177
+ self.num_latent_pixels = self.height * self.width
178
+
179
+ self.latent_image_embedding = ImagePositionalEmbeddings(
180
+ num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
181
+ )
182
+ elif self.is_input_patches:
183
+ assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size"
184
+
185
+ self.height = sample_size
186
+ self.width = sample_size
187
+
188
+ self.patch_size = patch_size
189
+ self.pos_embed = PatchEmbed(
190
+ height=sample_size,
191
+ width=sample_size,
192
+ patch_size=patch_size,
193
+ in_channels=in_channels,
194
+ embed_dim=inner_dim,
195
+ )
196
+
197
+ # 3. Define transformers blocks
198
+ self.transformer_blocks = nn.ModuleList(
199
+ [
200
+ BasicMVTransformerBlock(
201
+ inner_dim,
202
+ num_attention_heads,
203
+ attention_head_dim,
204
+ dropout=dropout,
205
+ cross_attention_dim=cross_attention_dim,
206
+ activation_fn=activation_fn,
207
+ num_embeds_ada_norm=num_embeds_ada_norm,
208
+ attention_bias=attention_bias,
209
+ only_cross_attention=only_cross_attention,
210
+ upcast_attention=upcast_attention,
211
+ norm_type=norm_type,
212
+ norm_elementwise_affine=norm_elementwise_affine,
213
+ num_views=num_views,
214
+ cd_attention_last=cd_attention_last,
215
+ cd_attention_mid=cd_attention_mid,
216
+ multiview_attention=multiview_attention,
217
+ sparse_mv_attention=sparse_mv_attention,
218
+ mvcd_attention=mvcd_attention
219
+ )
220
+ for d in range(num_layers)
221
+ ]
222
+ )
223
+
224
+ # 4. Define output layers
225
+ self.out_channels = in_channels if out_channels is None else out_channels
226
+ if self.is_input_continuous:
227
+ # TODO: should use out_channels for continuous projections
228
+ if use_linear_projection:
229
+ self.proj_out = LoRACompatibleLinear(inner_dim, in_channels)
230
+ else:
231
+ self.proj_out = LoRACompatibleConv(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
232
+ elif self.is_input_vectorized:
233
+ self.norm_out = nn.LayerNorm(inner_dim)
234
+ self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
235
+ elif self.is_input_patches:
236
+ self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
237
+ self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
238
+ self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
239
+
240
+ def forward(
241
+ self,
242
+ hidden_states: torch.Tensor,
243
+ encoder_hidden_states: Optional[torch.Tensor] = None,
244
+ timestep: Optional[torch.LongTensor] = None,
245
+ class_labels: Optional[torch.LongTensor] = None,
246
+ cross_attention_kwargs: Dict[str, Any] = None,
247
+ attention_mask: Optional[torch.Tensor] = None,
248
+ encoder_attention_mask: Optional[torch.Tensor] = None,
249
+ return_dict: bool = True,
250
+ ):
251
+ """
252
+ The [`TransformerMV2DModel`] forward method.
253
+
254
+ Args:
255
+ hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
256
+ Input `hidden_states`.
257
+ encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
258
+ Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
259
+ self-attention.
260
+ timestep ( `torch.LongTensor`, *optional*):
261
+ Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
262
+ class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
263
+ Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
264
+ `AdaLayerNormZero`.
265
+ encoder_attention_mask ( `torch.Tensor`, *optional*):
266
+ Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
267
+
268
+ * Mask `(batch, sequence_length)` True = keep, False = discard.
269
+ * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
270
+
271
+ If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
272
+ above. This bias will be added to the cross-attention scores.
273
+ return_dict (`bool`, *optional*, defaults to `True`):
274
+ Whether or not to return a [`TransformerMV2DModelOutput`] instead of a plain
275
+ tuple.
276
+
277
+ Returns:
278
+ If `return_dict` is True, a [`TransformerMV2DModelOutput`] is returned, otherwise a
279
+ `tuple` where the first element is the sample tensor.
280
+ """
281
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
282
+ # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
283
+ # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
284
+ # expects mask of shape:
285
+ # [batch, key_tokens]
286
+ # adds singleton query_tokens dimension:
287
+ # [batch, 1, key_tokens]
288
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
289
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
290
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
291
+ if attention_mask is not None and attention_mask.ndim == 2:
292
+ # assume that mask is expressed as:
293
+ # (1 = keep, 0 = discard)
294
+ # convert mask into a bias that can be added to attention scores:
295
+ # (keep = +0, discard = -10000.0)
296
+ attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
297
+ attention_mask = attention_mask.unsqueeze(1)
298
+
299
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
300
+ if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
301
+ encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
302
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
303
+
304
+ # 1. Input
305
+ if self.is_input_continuous:
306
+ batch, _, height, width = hidden_states.shape
307
+ residual = hidden_states
308
+
309
+ hidden_states = self.norm(hidden_states)
310
+ if not self.use_linear_projection:
311
+ hidden_states = self.proj_in(hidden_states)
312
+ inner_dim = hidden_states.shape[1]
313
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
314
+ else:
315
+ inner_dim = hidden_states.shape[1]
316
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
317
+ hidden_states = self.proj_in(hidden_states)
318
+ elif self.is_input_vectorized:
319
+ hidden_states = self.latent_image_embedding(hidden_states)
320
+ elif self.is_input_patches:
321
+ hidden_states = self.pos_embed(hidden_states)
322
+
323
+ # 2. Blocks
324
+ for block in self.transformer_blocks:
325
+ hidden_states = block(
326
+ hidden_states,
327
+ attention_mask=attention_mask,
328
+ encoder_hidden_states=encoder_hidden_states,
329
+ encoder_attention_mask=encoder_attention_mask,
330
+ timestep=timestep,
331
+ cross_attention_kwargs=cross_attention_kwargs,
332
+ class_labels=class_labels,
333
+ )
334
+
335
+ # 3. Output
336
+ if self.is_input_continuous:
337
+ if not self.use_linear_projection:
338
+ hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
339
+ hidden_states = self.proj_out(hidden_states)
340
+ else:
341
+ hidden_states = self.proj_out(hidden_states)
342
+ hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
343
+
344
+ output = hidden_states + residual
345
+ elif self.is_input_vectorized:
346
+ hidden_states = self.norm_out(hidden_states)
347
+ logits = self.out(hidden_states)
348
+ # (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
349
+ logits = logits.permute(0, 2, 1)
350
+
351
+ # log(p(x_0))
352
+ output = F.log_softmax(logits.double(), dim=1).float()
353
+ elif self.is_input_patches:
354
+ # TODO: cleanup!
355
+ conditioning = self.transformer_blocks[0].norm1.emb(
356
+ timestep, class_labels, hidden_dtype=hidden_states.dtype
357
+ )
358
+ shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
359
+ hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
360
+ hidden_states = self.proj_out_2(hidden_states)
361
+
362
+ # unpatchify
363
+ height = width = int(hidden_states.shape[1] ** 0.5)
364
+ hidden_states = hidden_states.reshape(
365
+ shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
366
+ )
367
+ hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
368
+ output = hidden_states.reshape(
369
+ shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
370
+ )
371
+
372
+ if not return_dict:
373
+ return (output,)
374
+
375
+ return TransformerMV2DModelOutput(sample=output)
376
+
377
+
378
+ @maybe_allow_in_graph
379
+ class BasicMVTransformerBlock(nn.Module):
380
+ r"""
381
+ A basic Transformer block.
382
+
383
+ Parameters:
384
+ dim (`int`): The number of channels in the input and output.
385
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
386
+ attention_head_dim (`int`): The number of channels in each head.
387
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
388
+ cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
389
+ only_cross_attention (`bool`, *optional*):
390
+ Whether to use only cross-attention layers. In this case two cross attention layers are used.
391
+ double_self_attention (`bool`, *optional*):
392
+ Whether to use two self-attention layers. In this case no cross attention layers are used.
393
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
394
+ num_embeds_ada_norm (:
395
+ obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
396
+ attention_bias (:
397
+ obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
398
+ """
399
+
400
+ def __init__(
401
+ self,
402
+ dim: int,
403
+ num_attention_heads: int,
404
+ attention_head_dim: int,
405
+ dropout=0.0,
406
+ cross_attention_dim: Optional[int] = None,
407
+ activation_fn: str = "geglu",
408
+ num_embeds_ada_norm: Optional[int] = None,
409
+ attention_bias: bool = False,
410
+ only_cross_attention: bool = False,
411
+ double_self_attention: bool = False,
412
+ upcast_attention: bool = False,
413
+ norm_elementwise_affine: bool = True,
414
+ norm_type: str = "layer_norm",
415
+ final_dropout: bool = False,
416
+ num_views: int = 1,
417
+ cd_attention_last: bool = False,
418
+ cd_attention_mid: bool = False,
419
+ multiview_attention: bool = True,
420
+ sparse_mv_attention: bool = False,
421
+ mvcd_attention: bool = False
422
+ ):
423
+ super().__init__()
424
+ self.only_cross_attention = only_cross_attention
425
+
426
+ self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
427
+ self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
428
+
429
+ if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
430
+ raise ValueError(
431
+ f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
432
+ f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
433
+ )
434
+
435
+ # Define 3 blocks. Each block has its own normalization layer.
436
+ # 1. Self-Attn
437
+ if self.use_ada_layer_norm:
438
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
439
+ elif self.use_ada_layer_norm_zero:
440
+ self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
441
+ else:
442
+ self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
443
+
444
+ self.multiview_attention = multiview_attention
445
+ self.sparse_mv_attention = sparse_mv_attention
446
+ self.mvcd_attention = mvcd_attention
447
+
448
+ self.attn1 = CustomAttention(
449
+ query_dim=dim,
450
+ heads=num_attention_heads,
451
+ dim_head=attention_head_dim,
452
+ dropout=dropout,
453
+ bias=attention_bias,
454
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
455
+ upcast_attention=upcast_attention,
456
+ processor=MVAttnProcessor()
457
+ )
458
+
459
+ # 2. Cross-Attn
460
+ if cross_attention_dim is not None or double_self_attention:
461
+ # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
462
+ # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
463
+ # the second cross attention block.
464
+ self.norm2 = (
465
+ AdaLayerNorm(dim, num_embeds_ada_norm)
466
+ if self.use_ada_layer_norm
467
+ else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
468
+ )
469
+ self.attn2 = Attention(
470
+ query_dim=dim,
471
+ cross_attention_dim=cross_attention_dim if not double_self_attention else None,
472
+ heads=num_attention_heads,
473
+ dim_head=attention_head_dim,
474
+ dropout=dropout,
475
+ bias=attention_bias,
476
+ upcast_attention=upcast_attention,
477
+ ) # is self-attn if encoder_hidden_states is none
478
+ else:
479
+ self.norm2 = None
480
+ self.attn2 = None
481
+
482
+ # 3. Feed-forward
483
+ self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
484
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
485
+
486
+ # let chunk size default to None
487
+ self._chunk_size = None
488
+ self._chunk_dim = 0
489
+
490
+ self.num_views = num_views
491
+
492
+ self.cd_attention_last = cd_attention_last
493
+
494
+ if self.cd_attention_last:
495
+ # Joint task -Attn
496
+ self.attn_joint_last = CustomJointAttention(
497
+ query_dim=dim,
498
+ heads=num_attention_heads,
499
+ dim_head=attention_head_dim,
500
+ dropout=dropout,
501
+ bias=attention_bias,
502
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
503
+ upcast_attention=upcast_attention,
504
+ processor=JointAttnProcessor()
505
+ )
506
+ nn.init.zeros_(self.attn_joint_last.to_out[0].weight.data)
507
+ self.norm_joint_last = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
508
+
509
+
510
+ self.cd_attention_mid = cd_attention_mid
511
+
512
+ if self.cd_attention_mid:
513
+ # print("cross-domain attn in the middle")
514
+ # Joint task -Attn
515
+ self.attn_joint_mid = CustomJointAttention(
516
+ query_dim=dim,
517
+ heads=num_attention_heads,
518
+ dim_head=attention_head_dim,
519
+ dropout=dropout,
520
+ bias=attention_bias,
521
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
522
+ upcast_attention=upcast_attention,
523
+ processor=JointAttnProcessor()
524
+ )
525
+ nn.init.zeros_(self.attn_joint_mid.to_out[0].weight.data)
526
+ self.norm_joint_mid = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
527
+
528
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
529
+ # Sets chunk feed-forward
530
+ self._chunk_size = chunk_size
531
+ self._chunk_dim = dim
532
+
533
+ def forward(
534
+ self,
535
+ hidden_states: torch.FloatTensor,
536
+ attention_mask: Optional[torch.FloatTensor] = None,
537
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
538
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
539
+ timestep: Optional[torch.LongTensor] = None,
540
+ cross_attention_kwargs: Dict[str, Any] = None,
541
+ class_labels: Optional[torch.LongTensor] = None,
542
+ ):
543
+ assert attention_mask is None # not supported yet
544
+ # Notice that normalization is always applied before the real computation in the following blocks.
545
+ # 1. Self-Attention
546
+ if self.use_ada_layer_norm:
547
+ norm_hidden_states = self.norm1(hidden_states, timestep)
548
+ elif self.use_ada_layer_norm_zero:
549
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
550
+ hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
551
+ )
552
+ else:
553
+ norm_hidden_states = self.norm1(hidden_states)
554
+
555
+ cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
556
+
557
+ attn_output = self.attn1(
558
+ norm_hidden_states,
559
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
560
+ attention_mask=attention_mask,
561
+ num_views=self.num_views,
562
+ multiview_attention=self.multiview_attention,
563
+ sparse_mv_attention=self.sparse_mv_attention,
564
+ mvcd_attention=self.mvcd_attention,
565
+ **cross_attention_kwargs,
566
+ )
567
+
568
+
569
+ if self.use_ada_layer_norm_zero:
570
+ attn_output = gate_msa.unsqueeze(1) * attn_output
571
+ hidden_states = attn_output + hidden_states
572
+
573
+ # joint attention twice
574
+ if self.cd_attention_mid:
575
+ norm_hidden_states = (
576
+ self.norm_joint_mid(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_joint_mid(hidden_states)
577
+ )
578
+ hidden_states = self.attn_joint_mid(norm_hidden_states) + hidden_states
579
+
580
+ # 2. Cross-Attention
581
+ if self.attn2 is not None:
582
+ norm_hidden_states = (
583
+ self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
584
+ )
585
+
586
+ attn_output = self.attn2(
587
+ norm_hidden_states,
588
+ encoder_hidden_states=encoder_hidden_states,
589
+ attention_mask=encoder_attention_mask,
590
+ **cross_attention_kwargs,
591
+ )
592
+ hidden_states = attn_output + hidden_states
593
+
594
+ # 3. Feed-forward
595
+ norm_hidden_states = self.norm3(hidden_states)
596
+
597
+ if self.use_ada_layer_norm_zero:
598
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
599
+
600
+ if self._chunk_size is not None:
601
+ # "feed_forward_chunk_size" can be used to save memory
602
+ if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
603
+ raise ValueError(
604
+ f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
605
+ )
606
+
607
+ num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
608
+ ff_output = torch.cat(
609
+ [self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)],
610
+ dim=self._chunk_dim,
611
+ )
612
+ else:
613
+ ff_output = self.ff(norm_hidden_states)
614
+
615
+ if self.use_ada_layer_norm_zero:
616
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
617
+
618
+ hidden_states = ff_output + hidden_states
619
+
620
+ if self.cd_attention_last:
621
+ norm_hidden_states = (
622
+ self.norm_joint_last(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_joint_last(hidden_states)
623
+ )
624
+ hidden_states = self.attn_joint_last(norm_hidden_states) + hidden_states
625
+
626
+ return hidden_states
627
+
628
+
629
+ class CustomAttention(Attention):
630
+ def set_use_memory_efficient_attention_xformers(
631
+ self, use_memory_efficient_attention_xformers: bool, *args, **kwargs
632
+ ):
633
+ processor = XFormersMVAttnProcessor()
634
+ self.set_processor(processor)
635
+ # print("using xformers attention processor")
636
+
637
+
638
+ class CustomJointAttention(Attention):
639
+ def set_use_memory_efficient_attention_xformers(
640
+ self, use_memory_efficient_attention_xformers: bool, *args, **kwargs
641
+ ):
642
+ processor = XFormersJointAttnProcessor()
643
+ self.set_processor(processor)
644
+ # print("using xformers attention processor")
645
+
646
+ class MVAttnProcessor:
647
+ r"""
648
+ Processor that performs multi-view self-attention using the standard (non-xformers) attention path.
649
+ """
650
+
651
+ def __call__(
652
+ self,
653
+ attn: Attention,
654
+ hidden_states,
655
+ encoder_hidden_states=None,
656
+ attention_mask=None,
657
+ temb=None,
658
+ num_views=1,
659
+ multiview_attention=True
660
+ ):
661
+ residual = hidden_states
662
+
663
+ if attn.spatial_norm is not None:
664
+ hidden_states = attn.spatial_norm(hidden_states, temb)
665
+
666
+ input_ndim = hidden_states.ndim
667
+
668
+ if input_ndim == 4:
669
+ batch_size, channel, height, width = hidden_states.shape
670
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
671
+
672
+ batch_size, sequence_length, _ = (
673
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
674
+ )
675
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
676
+
677
+ if attn.group_norm is not None:
678
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
679
+
680
+ query = attn.to_q(hidden_states)
681
+
682
+ if encoder_hidden_states is None:
683
+ encoder_hidden_states = hidden_states
684
+ elif attn.norm_cross:
685
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
686
+
687
+ key = attn.to_k(encoder_hidden_states)
688
+ value = attn.to_v(encoder_hidden_states)
689
+
690
+ # print('query', query.shape, 'key', key.shape, 'value', value.shape)
691
+ #([bx4, 1024, 320]) key torch.Size([bx4, 1024, 320]) value torch.Size([bx4, 1024, 320])
692
+ # pdb.set_trace()
693
+ # multi-view self-attention
694
+ if multiview_attention:
695
+ key = rearrange(key, "(b t) d c -> b (t d) c", t=num_views).repeat_interleave(num_views, dim=0)
696
+ value = rearrange(value, "(b t) d c -> b (t d) c", t=num_views).repeat_interleave(num_views, dim=0)
697
+
698
+ query = attn.head_to_batch_dim(query).contiguous()
699
+ key = attn.head_to_batch_dim(key).contiguous()
700
+ value = attn.head_to_batch_dim(value).contiguous()
701
+
702
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
703
+ hidden_states = torch.bmm(attention_probs, value)
704
+ hidden_states = attn.batch_to_head_dim(hidden_states)
705
+
706
+ # linear proj
707
+ hidden_states = attn.to_out[0](hidden_states)
708
+ # dropout
709
+ hidden_states = attn.to_out[1](hidden_states)
710
+
711
+ if input_ndim == 4:
712
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
713
+
714
+ if attn.residual_connection:
715
+ hidden_states = hidden_states + residual
716
+
717
+ hidden_states = hidden_states / attn.rescale_output_factor
718
+
719
+ return hidden_states
720
+
721
+
722
+ class XFormersMVAttnProcessor:
723
+ r"""
724
+ Processor that performs multi-view self-attention using xformers memory-efficient attention.
725
+ """
726
+
727
+ def __call__(
728
+ self,
729
+ attn: Attention,
730
+ hidden_states,
731
+ encoder_hidden_states=None,
732
+ attention_mask=None,
733
+ temb=None,
734
+ num_views=1,
735
+ multiview_attention=True,
736
+ sparse_mv_attention=False,
737
+ mvcd_attention=False,
738
+ ):
739
+ residual = hidden_states
740
+
741
+ if attn.spatial_norm is not None:
742
+ hidden_states = attn.spatial_norm(hidden_states, temb)
743
+
744
+ input_ndim = hidden_states.ndim
745
+
746
+ if input_ndim == 4:
747
+ batch_size, channel, height, width = hidden_states.shape
748
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
749
+
750
+ batch_size, sequence_length, _ = (
751
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
752
+ )
753
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
754
+
755
+ # from yuancheng; here attention_mask is None
756
+ if attention_mask is not None:
757
+ # expand our mask's singleton query_tokens dimension:
758
+ # [batch*heads, 1, key_tokens] ->
759
+ # [batch*heads, query_tokens, key_tokens]
760
+ # so that it can be added as a bias onto the attention scores that xformers computes:
761
+ # [batch*heads, query_tokens, key_tokens]
762
+ # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
763
+ _, query_tokens, _ = hidden_states.shape
764
+ attention_mask = attention_mask.expand(-1, query_tokens, -1)
765
+
766
+ if attn.group_norm is not None:
767
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
768
+
769
+ query = attn.to_q(hidden_states)
770
+
771
+ if encoder_hidden_states is None:
772
+ encoder_hidden_states = hidden_states
773
+ elif attn.norm_cross:
774
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
775
+
776
+ key_raw = attn.to_k(encoder_hidden_states)
777
+ value_raw = attn.to_v(encoder_hidden_states)
778
+
779
+ # print('query', query.shape, 'key', key.shape, 'value', value.shape)
780
+ #([bx4, 1024, 320]) key torch.Size([bx4, 1024, 320]) value torch.Size([bx4, 1024, 320])
781
+ # pdb.set_trace()
782
+ # multi-view self-attention
783
+ if multiview_attention:
784
+ if not sparse_mv_attention:
785
+ key = my_repeat(rearrange(key_raw, "(b t) d c -> b (t d) c", t=num_views), num_views)
786
+ value = my_repeat(rearrange(value_raw, "(b t) d c -> b (t d) c", t=num_views), num_views)
787
+ else:
788
+ key_front = my_repeat(rearrange(key_raw, "(b t) d c -> b t d c", t=num_views)[:, 0, :, :], num_views) # [(b t), d, c]
789
+ value_front = my_repeat(rearrange(value_raw, "(b t) d c -> b t d c", t=num_views)[:, 0, :, :], num_views)
790
+ key = torch.cat([key_front, key_raw], dim=1) # shape (b t) (2 d) c
791
+ value = torch.cat([value_front, value_raw], dim=1)
792
+
793
+ else:
794
+ # print("don't use multiview attention.")
795
+ key = key_raw
796
+ value = value_raw
797
+
798
+ query = attn.head_to_batch_dim(query)
799
+ key = attn.head_to_batch_dim(key)
800
+ value = attn.head_to_batch_dim(value)
801
+
802
+ hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
803
+ hidden_states = attn.batch_to_head_dim(hidden_states)
804
+
805
+ # linear proj
806
+ hidden_states = attn.to_out[0](hidden_states)
807
+ # dropout
808
+ hidden_states = attn.to_out[1](hidden_states)
809
+
810
+ if input_ndim == 4:
811
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
812
+
813
+ if attn.residual_connection:
814
+ hidden_states = hidden_states + residual
815
+
816
+ hidden_states = hidden_states / attn.rescale_output_factor
817
+
818
+ return hidden_states
819
+
820
+
821
+
822
+ class XFormersJointAttnProcessor:
823
+ r"""
824
+ Processor that performs cross-domain (joint task) attention using xformers memory-efficient attention.
825
+ """
826
+
827
+ def __call__(
828
+ self,
829
+ attn: Attention,
830
+ hidden_states,
831
+ encoder_hidden_states=None,
832
+ attention_mask=None,
833
+ temb=None,
834
+ num_tasks=2
835
+ ):
836
+
837
+ residual = hidden_states
838
+
839
+ if attn.spatial_norm is not None:
840
+ hidden_states = attn.spatial_norm(hidden_states, temb)
841
+
842
+ input_ndim = hidden_states.ndim
843
+
844
+ if input_ndim == 4:
845
+ batch_size, channel, height, width = hidden_states.shape
846
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
847
+
848
+ batch_size, sequence_length, _ = (
849
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
850
+ )
851
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
852
+
853
+ # from yuancheng; here attention_mask is None
854
+ if attention_mask is not None:
855
+ # expand our mask's singleton query_tokens dimension:
856
+ # [batch*heads, 1, key_tokens] ->
857
+ # [batch*heads, query_tokens, key_tokens]
858
+ # so that it can be added as a bias onto the attention scores that xformers computes:
859
+ # [batch*heads, query_tokens, key_tokens]
860
+ # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
861
+ _, query_tokens, _ = hidden_states.shape
862
+ attention_mask = attention_mask.expand(-1, query_tokens, -1)
863
+
864
+ if attn.group_norm is not None:
865
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
866
+
867
+ query = attn.to_q(hidden_states)
868
+
869
+ if encoder_hidden_states is None:
870
+ encoder_hidden_states = hidden_states
871
+ elif attn.norm_cross:
872
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
873
+
874
+ key = attn.to_k(encoder_hidden_states)
875
+ value = attn.to_v(encoder_hidden_states)
876
+
877
+ assert num_tasks == 2 # only support two tasks now
878
+
879
+ key_0, key_1 = torch.chunk(key, dim=0, chunks=2) # keys shape (b t) d c
880
+ value_0, value_1 = torch.chunk(value, dim=0, chunks=2)
881
+ key = torch.cat([key_0, key_1], dim=1) # (b t) 2d c
882
+ value = torch.cat([value_0, value_1], dim=1) # (b t) 2d c
883
+ key = torch.cat([key]*2, dim=0) # ( 2 b t) 2d c
884
+ value = torch.cat([value]*2, dim=0) # (2 b t) 2d c
885
+
886
+
887
+ query = attn.head_to_batch_dim(query).contiguous()
888
+ key = attn.head_to_batch_dim(key).contiguous()
889
+ value = attn.head_to_batch_dim(value).contiguous()
890
+
891
+ hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
892
+ hidden_states = attn.batch_to_head_dim(hidden_states)
893
+
894
+ # linear proj
895
+ hidden_states = attn.to_out[0](hidden_states)
896
+ # dropout
897
+ hidden_states = attn.to_out[1](hidden_states)
898
+
899
+ if input_ndim == 4:
900
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
901
+
902
+ if attn.residual_connection:
903
+ hidden_states = hidden_states + residual
904
+
905
+ hidden_states = hidden_states / attn.rescale_output_factor
906
+
907
+ return hidden_states
908
+
909
+
910
+ class JointAttnProcessor:
911
+ r"""
912
+ Processor that performs cross-domain (joint task) attention using the standard attention path.
913
+ """
914
+
915
+ def __call__(
916
+ self,
917
+ attn: Attention,
918
+ hidden_states,
919
+ encoder_hidden_states=None,
920
+ attention_mask=None,
921
+ temb=None,
922
+ num_tasks=2
923
+ ):
924
+
925
+ residual = hidden_states
926
+
927
+ if attn.spatial_norm is not None:
928
+ hidden_states = attn.spatial_norm(hidden_states, temb)
929
+
930
+ input_ndim = hidden_states.ndim
931
+
932
+ if input_ndim == 4:
933
+ batch_size, channel, height, width = hidden_states.shape
934
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
935
+
936
+ batch_size, sequence_length, _ = (
937
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
938
+ )
939
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
940
+
941
+
942
+ if attn.group_norm is not None:
943
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
944
+
945
+ query = attn.to_q(hidden_states)
946
+
947
+ if encoder_hidden_states is None:
948
+ encoder_hidden_states = hidden_states
949
+ elif attn.norm_cross:
950
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
951
+
952
+ key = attn.to_k(encoder_hidden_states)
953
+ value = attn.to_v(encoder_hidden_states)
954
+
955
+ assert num_tasks == 2 # only support two tasks now
956
+
957
+ key_0, key_1 = torch.chunk(key, dim=0, chunks=2) # keys shape (b t) d c
958
+ value_0, value_1 = torch.chunk(value, dim=0, chunks=2)
959
+ key = torch.cat([key_0, key_1], dim=1) # (b t) 2d c
960
+ value = torch.cat([value_0, value_1], dim=1) # (b t) 2d c
961
+ key = torch.cat([key]*2, dim=0) # ( 2 b t) 2d c
962
+ value = torch.cat([value]*2, dim=0) # (2 b t) 2d c
963
+
964
+
965
+ query = attn.head_to_batch_dim(query).contiguous()
966
+ key = attn.head_to_batch_dim(key).contiguous()
967
+ value = attn.head_to_batch_dim(value).contiguous()
968
+
969
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
970
+ hidden_states = torch.bmm(attention_probs, value)
971
+ hidden_states = attn.batch_to_head_dim(hidden_states)
972
+
973
+ # linear proj
974
+ hidden_states = attn.to_out[0](hidden_states)
975
+ # dropout
976
+ hidden_states = attn.to_out[1](hidden_states)
977
+
978
+ if input_ndim == 4:
979
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
980
+
981
+ if attn.residual_connection:
982
+ hidden_states = hidden_states + residual
983
+
984
+ hidden_states = hidden_states / attn.rescale_output_factor
985
+
986
+ return hidden_states
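
Editor's note (not part of the uploaded files): the core trick in MVAttnProcessor and XFormersMVAttnProcessor above is that the self-attention keys and values are rearranged so that every view of an object attends over the concatenated tokens of all of its views, while the queries keep their per-view shape. A minimal, self-contained sketch of that key/value expansion follows; the helper name expand_kv_for_multiview and the example shapes are illustrative assumptions, not code from this commit.

import torch
from einops import rearrange

def expand_kv_for_multiview(key: torch.Tensor, value: torch.Tensor, num_views: int):
    """Concatenate key/value tokens across views so each view attends to all views.

    key, value: (batch * num_views, tokens, channels), with the views of each object
    contiguous along the batch dimension. Returns tensors of shape
    (batch * num_views, num_views * tokens, channels), matching the query batch layout.
    """
    key = rearrange(key, "(b v) d c -> b (v d) c", v=num_views).repeat_interleave(num_views, dim=0)
    value = rearrange(value, "(b v) d c -> b (v d) c", v=num_views).repeat_interleave(num_views, dim=0)
    return key, value

if __name__ == "__main__":
    b, v, d, c = 2, 4, 1024, 320  # batch, views, tokens per view, channels (shapes echo the comments above)
    key = torch.randn(b * v, d, c)
    value = torch.randn(b * v, d, c)
    key_mv, value_mv = expand_kv_for_multiview(key, value, num_views=v)
    print(key_mv.shape, value_mv.shape)  # torch.Size([8, 4096, 320]) for both

Because the queries stay at shape (batch * num_views, tokens, channels), the attention scores become (batch * num_views * heads, tokens, num_views * tokens): each view's tokens attend over all views of the same object. The sparse_mv_attention branch in XFormersMVAttnProcessor instead concatenates only the front view's tokens onto each view's own keys and values.
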
mvdiffusion/models/unet_mv2d_blocks.py ADDED
@@ -0,0 +1,922 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Any, Dict, Optional, Tuple
15
+
16
+ import numpy as np
17
+ import torch
18
+ import torch.nn.functional as F
19
+ from torch import nn
20
+
21
+ from diffusers.utils import is_torch_version, logging
22
+ from diffusers.models.attention import AdaGroupNorm
23
+ from diffusers.models.attention_processor import Attention, AttnAddedKVProcessor, AttnAddedKVProcessor2_0
24
+ from diffusers.models.dual_transformer_2d import DualTransformer2DModel
25
+ from diffusers.models.resnet import Downsample2D, FirDownsample2D, FirUpsample2D, KDownsample2D, KUpsample2D, ResnetBlock2D, Upsample2D
26
+ from mvdiffusion.models.transformer_mv2d import TransformerMV2DModel
27
+
28
+ from diffusers.models.unet_2d_blocks import DownBlock2D, ResnetDownsampleBlock2D, AttnDownBlock2D, CrossAttnDownBlock2D, SimpleCrossAttnDownBlock2D, SkipDownBlock2D, AttnSkipDownBlock2D, DownEncoderBlock2D, AttnDownEncoderBlock2D, KDownBlock2D, KCrossAttnDownBlock2D
29
+ from diffusers.models.unet_2d_blocks import UpBlock2D, ResnetUpsampleBlock2D, CrossAttnUpBlock2D, SimpleCrossAttnUpBlock2D, AttnUpBlock2D, SkipUpBlock2D, AttnSkipUpBlock2D, UpDecoderBlock2D, AttnUpDecoderBlock2D, KUpBlock2D, KCrossAttnUpBlock2D
30
+
31
+
32
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
33
+
34
+
35
+ def get_down_block(
36
+ down_block_type,
37
+ num_layers,
38
+ in_channels,
39
+ out_channels,
40
+ temb_channels,
41
+ add_downsample,
42
+ resnet_eps,
43
+ resnet_act_fn,
44
+ transformer_layers_per_block=1,
45
+ num_attention_heads=None,
46
+ resnet_groups=None,
47
+ cross_attention_dim=None,
48
+ downsample_padding=None,
49
+ dual_cross_attention=False,
50
+ use_linear_projection=False,
51
+ only_cross_attention=False,
52
+ upcast_attention=False,
53
+ resnet_time_scale_shift="default",
54
+ resnet_skip_time_act=False,
55
+ resnet_out_scale_factor=1.0,
56
+ cross_attention_norm=None,
57
+ attention_head_dim=None,
58
+ downsample_type=None,
59
+ num_views=1,
60
+ cd_attention_last: bool = False,
61
+ cd_attention_mid: bool = False,
62
+ multiview_attention: bool = True,
63
+ sparse_mv_attention: bool = False,
64
+ mvcd_attention: bool=False
65
+ ):
66
+ # If attn head dim is not defined, we default it to the number of heads
67
+ if attention_head_dim is None:
68
+ logger.warning(
69
+ f"It is recommended to provide `attention_head_dim` when calling `get_down_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
70
+ )
71
+ attention_head_dim = num_attention_heads
72
+
73
+ down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
74
+ if down_block_type == "DownBlock2D":
75
+ return DownBlock2D(
76
+ num_layers=num_layers,
77
+ in_channels=in_channels,
78
+ out_channels=out_channels,
79
+ temb_channels=temb_channels,
80
+ add_downsample=add_downsample,
81
+ resnet_eps=resnet_eps,
82
+ resnet_act_fn=resnet_act_fn,
83
+ resnet_groups=resnet_groups,
84
+ downsample_padding=downsample_padding,
85
+ resnet_time_scale_shift=resnet_time_scale_shift,
86
+ )
87
+ elif down_block_type == "ResnetDownsampleBlock2D":
88
+ return ResnetDownsampleBlock2D(
89
+ num_layers=num_layers,
90
+ in_channels=in_channels,
91
+ out_channels=out_channels,
92
+ temb_channels=temb_channels,
93
+ add_downsample=add_downsample,
94
+ resnet_eps=resnet_eps,
95
+ resnet_act_fn=resnet_act_fn,
96
+ resnet_groups=resnet_groups,
97
+ resnet_time_scale_shift=resnet_time_scale_shift,
98
+ skip_time_act=resnet_skip_time_act,
99
+ output_scale_factor=resnet_out_scale_factor,
100
+ )
101
+ elif down_block_type == "AttnDownBlock2D":
102
+ if add_downsample is False:
103
+ downsample_type = None
104
+ else:
105
+ downsample_type = downsample_type or "conv" # default to 'conv'
106
+ return AttnDownBlock2D(
107
+ num_layers=num_layers,
108
+ in_channels=in_channels,
109
+ out_channels=out_channels,
110
+ temb_channels=temb_channels,
111
+ resnet_eps=resnet_eps,
112
+ resnet_act_fn=resnet_act_fn,
113
+ resnet_groups=resnet_groups,
114
+ downsample_padding=downsample_padding,
115
+ attention_head_dim=attention_head_dim,
116
+ resnet_time_scale_shift=resnet_time_scale_shift,
117
+ downsample_type=downsample_type,
118
+ )
119
+ elif down_block_type == "CrossAttnDownBlock2D":
120
+ if cross_attention_dim is None:
121
+ raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock2D")
122
+ return CrossAttnDownBlock2D(
123
+ num_layers=num_layers,
124
+ transformer_layers_per_block=transformer_layers_per_block,
125
+ in_channels=in_channels,
126
+ out_channels=out_channels,
127
+ temb_channels=temb_channels,
128
+ add_downsample=add_downsample,
129
+ resnet_eps=resnet_eps,
130
+ resnet_act_fn=resnet_act_fn,
131
+ resnet_groups=resnet_groups,
132
+ downsample_padding=downsample_padding,
133
+ cross_attention_dim=cross_attention_dim,
134
+ num_attention_heads=num_attention_heads,
135
+ dual_cross_attention=dual_cross_attention,
136
+ use_linear_projection=use_linear_projection,
137
+ only_cross_attention=only_cross_attention,
138
+ upcast_attention=upcast_attention,
139
+ resnet_time_scale_shift=resnet_time_scale_shift,
140
+ )
141
+ # custom MV2D attention block
142
+ elif down_block_type == "CrossAttnDownBlockMV2D":
143
+ if cross_attention_dim is None:
144
+ raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlockMV2D")
145
+ return CrossAttnDownBlockMV2D(
146
+ num_layers=num_layers,
147
+ transformer_layers_per_block=transformer_layers_per_block,
148
+ in_channels=in_channels,
149
+ out_channels=out_channels,
150
+ temb_channels=temb_channels,
151
+ add_downsample=add_downsample,
152
+ resnet_eps=resnet_eps,
153
+ resnet_act_fn=resnet_act_fn,
154
+ resnet_groups=resnet_groups,
155
+ downsample_padding=downsample_padding,
156
+ cross_attention_dim=cross_attention_dim,
157
+ num_attention_heads=num_attention_heads,
158
+ dual_cross_attention=dual_cross_attention,
159
+ use_linear_projection=use_linear_projection,
160
+ only_cross_attention=only_cross_attention,
161
+ upcast_attention=upcast_attention,
162
+ resnet_time_scale_shift=resnet_time_scale_shift,
163
+ num_views=num_views,
164
+ cd_attention_last=cd_attention_last,
165
+ cd_attention_mid=cd_attention_mid,
166
+ multiview_attention=multiview_attention,
167
+ sparse_mv_attention=sparse_mv_attention,
168
+ mvcd_attention=mvcd_attention
169
+ )
170
+ elif down_block_type == "SimpleCrossAttnDownBlock2D":
171
+ if cross_attention_dim is None:
172
+ raise ValueError("cross_attention_dim must be specified for SimpleCrossAttnDownBlock2D")
173
+ return SimpleCrossAttnDownBlock2D(
174
+ num_layers=num_layers,
175
+ in_channels=in_channels,
176
+ out_channels=out_channels,
177
+ temb_channels=temb_channels,
178
+ add_downsample=add_downsample,
179
+ resnet_eps=resnet_eps,
180
+ resnet_act_fn=resnet_act_fn,
181
+ resnet_groups=resnet_groups,
182
+ cross_attention_dim=cross_attention_dim,
183
+ attention_head_dim=attention_head_dim,
184
+ resnet_time_scale_shift=resnet_time_scale_shift,
185
+ skip_time_act=resnet_skip_time_act,
186
+ output_scale_factor=resnet_out_scale_factor,
187
+ only_cross_attention=only_cross_attention,
188
+ cross_attention_norm=cross_attention_norm,
189
+ )
190
+ elif down_block_type == "SkipDownBlock2D":
191
+ return SkipDownBlock2D(
192
+ num_layers=num_layers,
193
+ in_channels=in_channels,
194
+ out_channels=out_channels,
195
+ temb_channels=temb_channels,
196
+ add_downsample=add_downsample,
197
+ resnet_eps=resnet_eps,
198
+ resnet_act_fn=resnet_act_fn,
199
+ downsample_padding=downsample_padding,
200
+ resnet_time_scale_shift=resnet_time_scale_shift,
201
+ )
202
+ elif down_block_type == "AttnSkipDownBlock2D":
203
+ return AttnSkipDownBlock2D(
204
+ num_layers=num_layers,
205
+ in_channels=in_channels,
206
+ out_channels=out_channels,
207
+ temb_channels=temb_channels,
208
+ add_downsample=add_downsample,
209
+ resnet_eps=resnet_eps,
210
+ resnet_act_fn=resnet_act_fn,
211
+ attention_head_dim=attention_head_dim,
212
+ resnet_time_scale_shift=resnet_time_scale_shift,
213
+ )
214
+ elif down_block_type == "DownEncoderBlock2D":
215
+ return DownEncoderBlock2D(
216
+ num_layers=num_layers,
217
+ in_channels=in_channels,
218
+ out_channels=out_channels,
219
+ add_downsample=add_downsample,
220
+ resnet_eps=resnet_eps,
221
+ resnet_act_fn=resnet_act_fn,
222
+ resnet_groups=resnet_groups,
223
+ downsample_padding=downsample_padding,
224
+ resnet_time_scale_shift=resnet_time_scale_shift,
225
+ )
226
+ elif down_block_type == "AttnDownEncoderBlock2D":
227
+ return AttnDownEncoderBlock2D(
228
+ num_layers=num_layers,
229
+ in_channels=in_channels,
230
+ out_channels=out_channels,
231
+ add_downsample=add_downsample,
232
+ resnet_eps=resnet_eps,
233
+ resnet_act_fn=resnet_act_fn,
234
+ resnet_groups=resnet_groups,
235
+ downsample_padding=downsample_padding,
236
+ attention_head_dim=attention_head_dim,
237
+ resnet_time_scale_shift=resnet_time_scale_shift,
238
+ )
239
+ elif down_block_type == "KDownBlock2D":
240
+ return KDownBlock2D(
241
+ num_layers=num_layers,
242
+ in_channels=in_channels,
243
+ out_channels=out_channels,
244
+ temb_channels=temb_channels,
245
+ add_downsample=add_downsample,
246
+ resnet_eps=resnet_eps,
247
+ resnet_act_fn=resnet_act_fn,
248
+ )
249
+ elif down_block_type == "KCrossAttnDownBlock2D":
250
+ return KCrossAttnDownBlock2D(
251
+ num_layers=num_layers,
252
+ in_channels=in_channels,
253
+ out_channels=out_channels,
254
+ temb_channels=temb_channels,
255
+ add_downsample=add_downsample,
256
+ resnet_eps=resnet_eps,
257
+ resnet_act_fn=resnet_act_fn,
258
+ cross_attention_dim=cross_attention_dim,
259
+ attention_head_dim=attention_head_dim,
260
+ add_self_attention=not add_downsample,
261
+ )
262
+ raise ValueError(f"{down_block_type} does not exist.")
263
+
264
+
265
+ def get_up_block(
266
+ up_block_type,
267
+ num_layers,
268
+ in_channels,
269
+ out_channels,
270
+ prev_output_channel,
271
+ temb_channels,
272
+ add_upsample,
273
+ resnet_eps,
274
+ resnet_act_fn,
275
+ transformer_layers_per_block=1,
276
+ num_attention_heads=None,
277
+ resnet_groups=None,
278
+ cross_attention_dim=None,
279
+ dual_cross_attention=False,
280
+ use_linear_projection=False,
281
+ only_cross_attention=False,
282
+ upcast_attention=False,
283
+ resnet_time_scale_shift="default",
284
+ resnet_skip_time_act=False,
285
+ resnet_out_scale_factor=1.0,
286
+ cross_attention_norm=None,
287
+ attention_head_dim=None,
288
+ upsample_type=None,
289
+ num_views=1,
290
+ cd_attention_last: bool = False,
291
+ cd_attention_mid: bool = False,
292
+ multiview_attention: bool = True,
293
+ sparse_mv_attention: bool = False,
294
+ mvcd_attention: bool=False
295
+ ):
296
+ # If attn head dim is not defined, we default it to the number of heads
297
+ if attention_head_dim is None:
298
+ logger.warning(
299
+ f"It is recommended to provide `attention_head_dim` when calling `get_up_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
300
+ )
301
+ attention_head_dim = num_attention_heads
302
+
303
+ up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
304
+ if up_block_type == "UpBlock2D":
305
+ return UpBlock2D(
306
+ num_layers=num_layers,
307
+ in_channels=in_channels,
308
+ out_channels=out_channels,
309
+ prev_output_channel=prev_output_channel,
310
+ temb_channels=temb_channels,
311
+ add_upsample=add_upsample,
312
+ resnet_eps=resnet_eps,
313
+ resnet_act_fn=resnet_act_fn,
314
+ resnet_groups=resnet_groups,
315
+ resnet_time_scale_shift=resnet_time_scale_shift,
316
+ )
317
+ elif up_block_type == "ResnetUpsampleBlock2D":
318
+ return ResnetUpsampleBlock2D(
319
+ num_layers=num_layers,
320
+ in_channels=in_channels,
321
+ out_channels=out_channels,
322
+ prev_output_channel=prev_output_channel,
323
+ temb_channels=temb_channels,
324
+ add_upsample=add_upsample,
325
+ resnet_eps=resnet_eps,
326
+ resnet_act_fn=resnet_act_fn,
327
+ resnet_groups=resnet_groups,
328
+ resnet_time_scale_shift=resnet_time_scale_shift,
329
+ skip_time_act=resnet_skip_time_act,
330
+ output_scale_factor=resnet_out_scale_factor,
331
+ )
332
+ elif up_block_type == "CrossAttnUpBlock2D":
333
+ if cross_attention_dim is None:
334
+ raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock2D")
335
+ return CrossAttnUpBlock2D(
336
+ num_layers=num_layers,
337
+ transformer_layers_per_block=transformer_layers_per_block,
338
+ in_channels=in_channels,
339
+ out_channels=out_channels,
340
+ prev_output_channel=prev_output_channel,
341
+ temb_channels=temb_channels,
342
+ add_upsample=add_upsample,
343
+ resnet_eps=resnet_eps,
344
+ resnet_act_fn=resnet_act_fn,
345
+ resnet_groups=resnet_groups,
346
+ cross_attention_dim=cross_attention_dim,
347
+ num_attention_heads=num_attention_heads,
348
+ dual_cross_attention=dual_cross_attention,
349
+ use_linear_projection=use_linear_projection,
350
+ only_cross_attention=only_cross_attention,
351
+ upcast_attention=upcast_attention,
352
+ resnet_time_scale_shift=resnet_time_scale_shift,
353
+ )
354
+ # custom MV2D attention block
355
+ elif up_block_type == "CrossAttnUpBlockMV2D":
356
+ if cross_attention_dim is None:
357
+ raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlockMV2D")
358
+ return CrossAttnUpBlockMV2D(
359
+ num_layers=num_layers,
360
+ transformer_layers_per_block=transformer_layers_per_block,
361
+ in_channels=in_channels,
362
+ out_channels=out_channels,
363
+ prev_output_channel=prev_output_channel,
364
+ temb_channels=temb_channels,
365
+ add_upsample=add_upsample,
366
+ resnet_eps=resnet_eps,
367
+ resnet_act_fn=resnet_act_fn,
368
+ resnet_groups=resnet_groups,
369
+ cross_attention_dim=cross_attention_dim,
370
+ num_attention_heads=num_attention_heads,
371
+ dual_cross_attention=dual_cross_attention,
372
+ use_linear_projection=use_linear_projection,
373
+ only_cross_attention=only_cross_attention,
374
+ upcast_attention=upcast_attention,
375
+ resnet_time_scale_shift=resnet_time_scale_shift,
376
+ num_views=num_views,
377
+ cd_attention_last=cd_attention_last,
378
+ cd_attention_mid=cd_attention_mid,
379
+ multiview_attention=multiview_attention,
380
+ sparse_mv_attention=sparse_mv_attention,
381
+ mvcd_attention=mvcd_attention
382
+ )
383
+ elif up_block_type == "SimpleCrossAttnUpBlock2D":
384
+ if cross_attention_dim is None:
385
+ raise ValueError("cross_attention_dim must be specified for SimpleCrossAttnUpBlock2D")
386
+ return SimpleCrossAttnUpBlock2D(
387
+ num_layers=num_layers,
388
+ in_channels=in_channels,
389
+ out_channels=out_channels,
390
+ prev_output_channel=prev_output_channel,
391
+ temb_channels=temb_channels,
392
+ add_upsample=add_upsample,
393
+ resnet_eps=resnet_eps,
394
+ resnet_act_fn=resnet_act_fn,
395
+ resnet_groups=resnet_groups,
396
+ cross_attention_dim=cross_attention_dim,
397
+ attention_head_dim=attention_head_dim,
398
+ resnet_time_scale_shift=resnet_time_scale_shift,
399
+ skip_time_act=resnet_skip_time_act,
400
+ output_scale_factor=resnet_out_scale_factor,
401
+ only_cross_attention=only_cross_attention,
402
+ cross_attention_norm=cross_attention_norm,
403
+ )
404
+ elif up_block_type == "AttnUpBlock2D":
405
+ if add_upsample is False:
406
+ upsample_type = None
407
+ else:
408
+ upsample_type = upsample_type or "conv" # default to 'conv'
409
+
410
+ return AttnUpBlock2D(
411
+ num_layers=num_layers,
412
+ in_channels=in_channels,
413
+ out_channels=out_channels,
414
+ prev_output_channel=prev_output_channel,
415
+ temb_channels=temb_channels,
416
+ resnet_eps=resnet_eps,
417
+ resnet_act_fn=resnet_act_fn,
418
+ resnet_groups=resnet_groups,
419
+ attention_head_dim=attention_head_dim,
420
+ resnet_time_scale_shift=resnet_time_scale_shift,
421
+ upsample_type=upsample_type,
422
+ )
423
+ elif up_block_type == "SkipUpBlock2D":
424
+ return SkipUpBlock2D(
425
+ num_layers=num_layers,
426
+ in_channels=in_channels,
427
+ out_channels=out_channels,
428
+ prev_output_channel=prev_output_channel,
429
+ temb_channels=temb_channels,
430
+ add_upsample=add_upsample,
431
+ resnet_eps=resnet_eps,
432
+ resnet_act_fn=resnet_act_fn,
433
+ resnet_time_scale_shift=resnet_time_scale_shift,
434
+ )
435
+ elif up_block_type == "AttnSkipUpBlock2D":
436
+ return AttnSkipUpBlock2D(
437
+ num_layers=num_layers,
438
+ in_channels=in_channels,
439
+ out_channels=out_channels,
440
+ prev_output_channel=prev_output_channel,
441
+ temb_channels=temb_channels,
442
+ add_upsample=add_upsample,
443
+ resnet_eps=resnet_eps,
444
+ resnet_act_fn=resnet_act_fn,
445
+ attention_head_dim=attention_head_dim,
446
+ resnet_time_scale_shift=resnet_time_scale_shift,
447
+ )
448
+ elif up_block_type == "UpDecoderBlock2D":
449
+ return UpDecoderBlock2D(
450
+ num_layers=num_layers,
451
+ in_channels=in_channels,
452
+ out_channels=out_channels,
453
+ add_upsample=add_upsample,
454
+ resnet_eps=resnet_eps,
455
+ resnet_act_fn=resnet_act_fn,
456
+ resnet_groups=resnet_groups,
457
+ resnet_time_scale_shift=resnet_time_scale_shift,
458
+ temb_channels=temb_channels,
459
+ )
460
+ elif up_block_type == "AttnUpDecoderBlock2D":
461
+ return AttnUpDecoderBlock2D(
462
+ num_layers=num_layers,
463
+ in_channels=in_channels,
464
+ out_channels=out_channels,
465
+ add_upsample=add_upsample,
466
+ resnet_eps=resnet_eps,
467
+ resnet_act_fn=resnet_act_fn,
468
+ resnet_groups=resnet_groups,
469
+ attention_head_dim=attention_head_dim,
470
+ resnet_time_scale_shift=resnet_time_scale_shift,
471
+ temb_channels=temb_channels,
472
+ )
473
+ elif up_block_type == "KUpBlock2D":
474
+ return KUpBlock2D(
475
+ num_layers=num_layers,
476
+ in_channels=in_channels,
477
+ out_channels=out_channels,
478
+ temb_channels=temb_channels,
479
+ add_upsample=add_upsample,
480
+ resnet_eps=resnet_eps,
481
+ resnet_act_fn=resnet_act_fn,
482
+ )
483
+ elif up_block_type == "KCrossAttnUpBlock2D":
484
+ return KCrossAttnUpBlock2D(
485
+ num_layers=num_layers,
486
+ in_channels=in_channels,
487
+ out_channels=out_channels,
488
+ temb_channels=temb_channels,
489
+ add_upsample=add_upsample,
490
+ resnet_eps=resnet_eps,
491
+ resnet_act_fn=resnet_act_fn,
492
+ cross_attention_dim=cross_attention_dim,
493
+ attention_head_dim=attention_head_dim,
494
+ )
495
+
496
+ raise ValueError(f"{up_block_type} does not exist.")
497
+
498
+
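`get_up_block` mirrors the down-block dispatcher: an optional "UNetRes" prefix is stripped from the type string, and a missing `attention_head_dim` falls back to `num_attention_heads` after a logged warning. A hedged call sketch with illustrative sizes only:

    from mvdiffusion.models.unet_mv2d_blocks import get_up_block

    up_block = get_up_block(
        "UNetResCrossAttnUpBlockMV2D",   # "UNetRes" prefix is stripped before dispatch
        num_layers=3,
        in_channels=640,
        out_channels=320,
        prev_output_channel=320,         # output width of the previous (deeper) up block
        temb_channels=1280,
        add_upsample=True,
        resnet_eps=1e-5,
        resnet_act_fn="silu",
        resnet_groups=32,
        cross_attention_dim=1024,
        num_attention_heads=8,           # attention_head_dim defaults to this value
        num_views=6,
    )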
499
+ class UNetMidBlockMV2DCrossAttn(nn.Module):
500
+ def __init__(
501
+ self,
502
+ in_channels: int,
503
+ temb_channels: int,
504
+ dropout: float = 0.0,
505
+ num_layers: int = 1,
506
+ transformer_layers_per_block: int = 1,
507
+ resnet_eps: float = 1e-6,
508
+ resnet_time_scale_shift: str = "default",
509
+ resnet_act_fn: str = "swish",
510
+ resnet_groups: int = 32,
511
+ resnet_pre_norm: bool = True,
512
+ num_attention_heads=1,
513
+ output_scale_factor=1.0,
514
+ cross_attention_dim=1280,
515
+ dual_cross_attention=False,
516
+ use_linear_projection=False,
517
+ upcast_attention=False,
518
+ num_views: int = 1,
519
+ cd_attention_last: bool = False,
520
+ cd_attention_mid: bool = False,
521
+ multiview_attention: bool = True,
522
+ sparse_mv_attention: bool = False,
523
+ mvcd_attention: bool = False
524
+ ):
525
+ super().__init__()
526
+
527
+ self.has_cross_attention = True
528
+ self.num_attention_heads = num_attention_heads
529
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
530
+
531
+ # there is always at least one resnet
532
+ resnets = [
533
+ ResnetBlock2D(
534
+ in_channels=in_channels,
535
+ out_channels=in_channels,
536
+ temb_channels=temb_channels,
537
+ eps=resnet_eps,
538
+ groups=resnet_groups,
539
+ dropout=dropout,
540
+ time_embedding_norm=resnet_time_scale_shift,
541
+ non_linearity=resnet_act_fn,
542
+ output_scale_factor=output_scale_factor,
543
+ pre_norm=resnet_pre_norm,
544
+ )
545
+ ]
546
+ attentions = []
547
+
548
+ for _ in range(num_layers):
549
+ if not dual_cross_attention:
550
+ attentions.append(
551
+ TransformerMV2DModel(
552
+ num_attention_heads,
553
+ in_channels // num_attention_heads,
554
+ in_channels=in_channels,
555
+ num_layers=transformer_layers_per_block,
556
+ cross_attention_dim=cross_attention_dim,
557
+ norm_num_groups=resnet_groups,
558
+ use_linear_projection=use_linear_projection,
559
+ upcast_attention=upcast_attention,
560
+ num_views=num_views,
561
+ cd_attention_last=cd_attention_last,
562
+ cd_attention_mid=cd_attention_mid,
563
+ multiview_attention=multiview_attention,
564
+ sparse_mv_attention=sparse_mv_attention,
565
+ mvcd_attention=mvcd_attention
566
+ )
567
+ )
568
+ else:
569
+ raise NotImplementedError
570
+ resnets.append(
571
+ ResnetBlock2D(
572
+ in_channels=in_channels,
573
+ out_channels=in_channels,
574
+ temb_channels=temb_channels,
575
+ eps=resnet_eps,
576
+ groups=resnet_groups,
577
+ dropout=dropout,
578
+ time_embedding_norm=resnet_time_scale_shift,
579
+ non_linearity=resnet_act_fn,
580
+ output_scale_factor=output_scale_factor,
581
+ pre_norm=resnet_pre_norm,
582
+ )
583
+ )
584
+
585
+ self.attentions = nn.ModuleList(attentions)
586
+ self.resnets = nn.ModuleList(resnets)
587
+
588
+ def forward(
589
+ self,
590
+ hidden_states: torch.FloatTensor,
591
+ temb: Optional[torch.FloatTensor] = None,
592
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
593
+ attention_mask: Optional[torch.FloatTensor] = None,
594
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
595
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
596
+ ) -> torch.FloatTensor:
597
+ hidden_states = self.resnets[0](hidden_states, temb)
598
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
599
+ hidden_states = attn(
600
+ hidden_states,
601
+ encoder_hidden_states=encoder_hidden_states,
602
+ cross_attention_kwargs=cross_attention_kwargs,
603
+ attention_mask=attention_mask,
604
+ encoder_attention_mask=encoder_attention_mask,
605
+ return_dict=False,
606
+ )[0]
607
+ hidden_states = resnet(hidden_states, temb)
608
+
609
+ return hidden_states
610
+
611
+
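A minimal smoke test of the mid block above, assuming `TransformerMV2DModel` accepts a (batch * num_views) leading dimension the way diffusers' `Transformer2DModel` does; the 1280-channel bottleneck and the 77 x 1024 text context are illustrative SD2-style sizes, not values from any specific checkpoint:

    import torch
    from mvdiffusion.models.unet_mv2d_blocks import UNetMidBlockMV2DCrossAttn

    mid = UNetMidBlockMV2DCrossAttn(
        in_channels=1280,
        temb_channels=1280,
        num_attention_heads=8,
        cross_attention_dim=1024,
        num_views=6,
    )
    x = torch.randn(6, 1280, 8, 8)        # one bottleneck latent per view
    temb = torch.randn(6, 1280)           # time embedding per view
    context = torch.randn(6, 77, 1024)    # cross-attention context per view
    out = mid(x, temb, encoder_hidden_states=context)
    assert out.shape == x.shape           # resnet -> (attn, resnet) pairs preserve the shape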
612
+ class CrossAttnUpBlockMV2D(nn.Module):
613
+ def __init__(
614
+ self,
615
+ in_channels: int,
616
+ out_channels: int,
617
+ prev_output_channel: int,
618
+ temb_channels: int,
619
+ dropout: float = 0.0,
620
+ num_layers: int = 1,
621
+ transformer_layers_per_block: int = 1,
622
+ resnet_eps: float = 1e-6,
623
+ resnet_time_scale_shift: str = "default",
624
+ resnet_act_fn: str = "swish",
625
+ resnet_groups: int = 32,
626
+ resnet_pre_norm: bool = True,
627
+ num_attention_heads=1,
628
+ cross_attention_dim=1280,
629
+ output_scale_factor=1.0,
630
+ add_upsample=True,
631
+ dual_cross_attention=False,
632
+ use_linear_projection=False,
633
+ only_cross_attention=False,
634
+ upcast_attention=False,
635
+ num_views: int = 1,
636
+ cd_attention_last: bool = False,
637
+ cd_attention_mid: bool = False,
638
+ multiview_attention: bool = True,
639
+ sparse_mv_attention: bool = False,
640
+ mvcd_attention: bool = False
641
+ ):
642
+ super().__init__()
643
+ resnets = []
644
+ attentions = []
645
+
646
+ self.has_cross_attention = True
647
+ self.num_attention_heads = num_attention_heads
648
+
649
+ for i in range(num_layers):
650
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
651
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
652
+
653
+ resnets.append(
654
+ ResnetBlock2D(
655
+ in_channels=resnet_in_channels + res_skip_channels,
656
+ out_channels=out_channels,
657
+ temb_channels=temb_channels,
658
+ eps=resnet_eps,
659
+ groups=resnet_groups,
660
+ dropout=dropout,
661
+ time_embedding_norm=resnet_time_scale_shift,
662
+ non_linearity=resnet_act_fn,
663
+ output_scale_factor=output_scale_factor,
664
+ pre_norm=resnet_pre_norm,
665
+ )
666
+ )
667
+ if not dual_cross_attention:
668
+ attentions.append(
669
+ TransformerMV2DModel(
670
+ num_attention_heads,
671
+ out_channels // num_attention_heads,
672
+ in_channels=out_channels,
673
+ num_layers=transformer_layers_per_block,
674
+ cross_attention_dim=cross_attention_dim,
675
+ norm_num_groups=resnet_groups,
676
+ use_linear_projection=use_linear_projection,
677
+ only_cross_attention=only_cross_attention,
678
+ upcast_attention=upcast_attention,
679
+ num_views=num_views,
680
+ cd_attention_last=cd_attention_last,
681
+ cd_attention_mid=cd_attention_mid,
682
+ multiview_attention=multiview_attention,
683
+ sparse_mv_attention=sparse_mv_attention,
684
+ mvcd_attention=mvcd_attention
685
+ )
686
+ )
687
+ else:
688
+ raise NotImplementedError
689
+ self.attentions = nn.ModuleList(attentions)
690
+ self.resnets = nn.ModuleList(resnets)
691
+
692
+ if add_upsample:
693
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
694
+ else:
695
+ self.upsamplers = None
696
+
697
+ self.gradient_checkpointing = False
698
+
699
+ def forward(
700
+ self,
701
+ hidden_states: torch.FloatTensor,
702
+ res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
703
+ temb: Optional[torch.FloatTensor] = None,
704
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
705
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
706
+ upsample_size: Optional[int] = None,
707
+ attention_mask: Optional[torch.FloatTensor] = None,
708
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
709
+ ):
710
+ for resnet, attn in zip(self.resnets, self.attentions):
711
+ # pop res hidden states
712
+ res_hidden_states = res_hidden_states_tuple[-1]
713
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
714
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
715
+
716
+ if self.training and self.gradient_checkpointing:
717
+
718
+ def create_custom_forward(module, return_dict=None):
719
+ def custom_forward(*inputs):
720
+ if return_dict is not None:
721
+ return module(*inputs, return_dict=return_dict)
722
+ else:
723
+ return module(*inputs)
724
+
725
+ return custom_forward
726
+
727
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
728
+ hidden_states = torch.utils.checkpoint.checkpoint(
729
+ create_custom_forward(resnet),
730
+ hidden_states,
731
+ temb,
732
+ **ckpt_kwargs,
733
+ )
734
+ hidden_states = torch.utils.checkpoint.checkpoint(
735
+ create_custom_forward(attn, return_dict=False),
736
+ hidden_states,
737
+ encoder_hidden_states,
738
+ None, # timestep
739
+ None, # class_labels
740
+ cross_attention_kwargs,
741
+ attention_mask,
742
+ encoder_attention_mask,
743
+ **ckpt_kwargs,
744
+ )[0]
745
+ else:
746
+ hidden_states = resnet(hidden_states, temb)
747
+ hidden_states = attn(
748
+ hidden_states,
749
+ encoder_hidden_states=encoder_hidden_states,
750
+ cross_attention_kwargs=cross_attention_kwargs,
751
+ attention_mask=attention_mask,
752
+ encoder_attention_mask=encoder_attention_mask,
753
+ return_dict=False,
754
+ )[0]
755
+
756
+ if self.upsamplers is not None:
757
+ for upsampler in self.upsamplers:
758
+ hidden_states = upsampler(hidden_states, upsample_size)
759
+
760
+ return hidden_states
761
+
762
+
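The constructor loop above also fixes the channel bookkeeping for the skip connections: resnet `i` consumes `resnet_in_channels + res_skip_channels`, where the first layer takes the previous up block's output and the last layer's skip tensor comes in at the block's own `in_channels`. A small arithmetic sketch (pure bookkeeping, no model code), with illustrative widths:

    num_layers, in_channels, out_channels, prev_output_channel = 3, 320, 320, 640
    for i in range(num_layers):
        res_skip_channels = in_channels if i == num_layers - 1 else out_channels
        resnet_in_channels = prev_output_channel if i == 0 else out_channels
        print(i, resnet_in_channels + res_skip_channels, "->", out_channels)
    # prints: 0 960 -> 320, 1 640 -> 320, 2 640 -> 320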
763
+ class CrossAttnDownBlockMV2D(nn.Module):
764
+ def __init__(
765
+ self,
766
+ in_channels: int,
767
+ out_channels: int,
768
+ temb_channels: int,
769
+ dropout: float = 0.0,
770
+ num_layers: int = 1,
771
+ transformer_layers_per_block: int = 1,
772
+ resnet_eps: float = 1e-6,
773
+ resnet_time_scale_shift: str = "default",
774
+ resnet_act_fn: str = "swish",
775
+ resnet_groups: int = 32,
776
+ resnet_pre_norm: bool = True,
777
+ num_attention_heads=1,
778
+ cross_attention_dim=1280,
779
+ output_scale_factor=1.0,
780
+ downsample_padding=1,
781
+ add_downsample=True,
782
+ dual_cross_attention=False,
783
+ use_linear_projection=False,
784
+ only_cross_attention=False,
785
+ upcast_attention=False,
786
+ num_views: int = 1,
787
+ cd_attention_last: bool = False,
788
+ cd_attention_mid: bool = False,
789
+ multiview_attention: bool = True,
790
+ sparse_mv_attention: bool = False,
791
+ mvcd_attention: bool = False
792
+ ):
793
+ super().__init__()
794
+ resnets = []
795
+ attentions = []
796
+
797
+ self.has_cross_attention = True
798
+ self.num_attention_heads = num_attention_heads
799
+
800
+ for i in range(num_layers):
801
+ in_channels = in_channels if i == 0 else out_channels
802
+ resnets.append(
803
+ ResnetBlock2D(
804
+ in_channels=in_channels,
805
+ out_channels=out_channels,
806
+ temb_channels=temb_channels,
807
+ eps=resnet_eps,
808
+ groups=resnet_groups,
809
+ dropout=dropout,
810
+ time_embedding_norm=resnet_time_scale_shift,
811
+ non_linearity=resnet_act_fn,
812
+ output_scale_factor=output_scale_factor,
813
+ pre_norm=resnet_pre_norm,
814
+ )
815
+ )
816
+ if not dual_cross_attention:
817
+ attentions.append(
818
+ TransformerMV2DModel(
819
+ num_attention_heads,
820
+ out_channels // num_attention_heads,
821
+ in_channels=out_channels,
822
+ num_layers=transformer_layers_per_block,
823
+ cross_attention_dim=cross_attention_dim,
824
+ norm_num_groups=resnet_groups,
825
+ use_linear_projection=use_linear_projection,
826
+ only_cross_attention=only_cross_attention,
827
+ upcast_attention=upcast_attention,
828
+ num_views=num_views,
829
+ cd_attention_last=cd_attention_last,
830
+ cd_attention_mid=cd_attention_mid,
831
+ multiview_attention=multiview_attention,
832
+ sparse_mv_attention=sparse_mv_attention,
833
+ mvcd_attention=mvcd_attention
834
+ )
835
+ )
836
+ else:
837
+ raise NotImplementedError
838
+ self.attentions = nn.ModuleList(attentions)
839
+ self.resnets = nn.ModuleList(resnets)
840
+
841
+ if add_downsample:
842
+ self.downsamplers = nn.ModuleList(
843
+ [
844
+ Downsample2D(
845
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
846
+ )
847
+ ]
848
+ )
849
+ else:
850
+ self.downsamplers = None
851
+
852
+ self.gradient_checkpointing = False
853
+
854
+ def forward(
855
+ self,
856
+ hidden_states: torch.FloatTensor,
857
+ temb: Optional[torch.FloatTensor] = None,
858
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
859
+ attention_mask: Optional[torch.FloatTensor] = None,
860
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
861
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
862
+ additional_residuals=None,
863
+ ):
864
+ output_states = ()
865
+
866
+ blocks = list(zip(self.resnets, self.attentions))
867
+
868
+ for i, (resnet, attn) in enumerate(blocks):
869
+ if self.training and self.gradient_checkpointing:
870
+
871
+ def create_custom_forward(module, return_dict=None):
872
+ def custom_forward(*inputs):
873
+ if return_dict is not None:
874
+ return module(*inputs, return_dict=return_dict)
875
+ else:
876
+ return module(*inputs)
877
+
878
+ return custom_forward
879
+
880
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
881
+ hidden_states = torch.utils.checkpoint.checkpoint(
882
+ create_custom_forward(resnet),
883
+ hidden_states,
884
+ temb,
885
+ **ckpt_kwargs,
886
+ )
887
+ hidden_states = torch.utils.checkpoint.checkpoint(
888
+ create_custom_forward(attn, return_dict=False),
889
+ hidden_states,
890
+ encoder_hidden_states,
891
+ None, # timestep
892
+ None, # class_labels
893
+ cross_attention_kwargs,
894
+ attention_mask,
895
+ encoder_attention_mask,
896
+ **ckpt_kwargs,
897
+ )[0]
898
+ else:
899
+ hidden_states = resnet(hidden_states, temb)
900
+ hidden_states = attn(
901
+ hidden_states,
902
+ encoder_hidden_states=encoder_hidden_states,
903
+ cross_attention_kwargs=cross_attention_kwargs,
904
+ attention_mask=attention_mask,
905
+ encoder_attention_mask=encoder_attention_mask,
906
+ return_dict=False,
907
+ )[0]
908
+
909
+ # apply additional residuals to the output of the last pair of resnet and attention blocks
910
+ if i == len(blocks) - 1 and additional_residuals is not None:
911
+ hidden_states = hidden_states + additional_residuals
912
+
913
+ output_states = output_states + (hidden_states,)
914
+
915
+ if self.downsamplers is not None:
916
+ for downsampler in self.downsamplers:
917
+ hidden_states = downsampler(hidden_states)
918
+
919
+ output_states = output_states + (hidden_states,)
920
+
921
+ return hidden_states, output_states
922
+
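For reference before moving on to the UNet itself: `CrossAttnDownBlockMV2D.forward` returns both the downsampled features and an `output_states` tuple with one entry per (resnet, attention) pair plus, when `add_downsample` is set, one more entry after the downsampler; the UNet keeps these as skip connections for the up path. A hedged shape check with illustrative sizes, again assuming `TransformerMV2DModel` behaves like diffusers' `Transformer2DModel`:

    import torch
    from mvdiffusion.models.unet_mv2d_blocks import CrossAttnDownBlockMV2D

    block = CrossAttnDownBlockMV2D(
        in_channels=320, out_channels=320, temb_channels=1280,
        num_layers=2, num_attention_heads=8, cross_attention_dim=1024, num_views=6,
    )
    x = torch.randn(6, 320, 64, 64)
    temb = torch.randn(6, 1280)
    ctx = torch.randn(6, 77, 1024)
    y, skips = block(x, temb, encoder_hidden_states=ctx)
    # y has shape (6, 320, 32, 32) after the downsampler, and len(skips) == 2 + 1 == 3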
mvdiffusion/models/unet_mv2d_condition.py ADDED
@@ -0,0 +1,1492 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from dataclasses import dataclass
15
+ from typing import Any, Dict, List, Optional, Tuple, Union
16
+ import os
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.utils.checkpoint
21
+
22
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
23
+ from diffusers.loaders import UNet2DConditionLoadersMixin
24
+ from diffusers.utils import BaseOutput, logging
25
+ from diffusers.models.activations import get_activation
26
+ from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor
27
+ from diffusers.models.embeddings import (
28
+ GaussianFourierProjection,
29
+ ImageHintTimeEmbedding,
30
+ ImageProjection,
31
+ ImageTimeEmbedding,
32
+ TextImageProjection,
33
+ TextImageTimeEmbedding,
34
+ TextTimeEmbedding,
35
+ TimestepEmbedding,
36
+ Timesteps,
37
+ )
38
+ from diffusers.models.modeling_utils import ModelMixin, load_state_dict, _load_state_dict_into_model
39
+ from diffusers.models.unet_2d_blocks import (
40
+ CrossAttnDownBlock2D,
41
+ CrossAttnUpBlock2D,
42
+ DownBlock2D,
43
+ UNetMidBlock2DCrossAttn,
44
+ UNetMidBlock2DSimpleCrossAttn,
45
+ UpBlock2D,
46
+ )
47
+ from diffusers.utils import (
48
+ CONFIG_NAME,
49
+ DIFFUSERS_CACHE,
50
+ FLAX_WEIGHTS_NAME,
51
+ HF_HUB_OFFLINE,
52
+ SAFETENSORS_WEIGHTS_NAME,
53
+ WEIGHTS_NAME,
54
+ _add_variant,
55
+ _get_model_file,
56
+ deprecate,
57
+ is_accelerate_available,
58
+ is_safetensors_available,
59
+ is_torch_version,
60
+ logging,
61
+ )
62
+ from diffusers import __version__
63
+ from mvdiffusion.models.unet_mv2d_blocks import (
64
+ CrossAttnDownBlockMV2D,
65
+ CrossAttnUpBlockMV2D,
66
+ UNetMidBlockMV2DCrossAttn,
67
+ get_down_block,
68
+ get_up_block,
69
+ )
70
+
71
+
72
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
73
+
74
+
75
+ @dataclass
76
+ class UNetMV2DConditionOutput(BaseOutput):
77
+ """
78
+ The output of [`UNet2DConditionModel`].
79
+
80
+ Args:
81
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
82
+ The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
83
+ """
84
+
85
+ sample: torch.FloatTensor = None
86
+
87
+
88
+ class UNetMV2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
89
+ r"""
90
+ A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
91
+ shaped output.
92
+
93
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
94
+ for all models (such as downloading or saving).
95
+
96
+ Parameters:
97
+ sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
98
+ Height and width of input/output sample.
99
+ in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
100
+ out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
101
+ center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
102
+ flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
103
+ Whether to flip the sin to cos in the time embedding.
104
+ freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
105
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
106
+ The tuple of downsample blocks to use.
107
+ mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
108
+ Block type for middle of UNet, it can be either `UNetMidBlock2DCrossAttn` or
109
+ `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
110
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
111
+ The tuple of upsample blocks to use.
112
+ only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`):
113
+ Whether to include self-attention in the basic transformer blocks, see
114
+ [`~models.attention.BasicTransformerBlock`].
115
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
116
+ The tuple of output channels for each block.
117
+ layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
118
+ downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
119
+ mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
120
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
121
+ norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
122
+ If `None`, normalization and activation layers are skipped in post-processing.
123
+ norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
124
+ cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
125
+ The dimension of the cross attention features.
126
+ transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
127
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
128
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
129
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
130
+ encoder_hid_dim (`int`, *optional*, defaults to None):
131
+ If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
132
+ dimension to `cross_attention_dim`.
133
+ encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
134
+ If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
135
+ embeddings of dimension `cross_attention_dim` according to `encoder_hid_dim_type`.
136
+ attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
137
+ num_attention_heads (`int`, *optional*):
138
+ The number of attention heads. If not defined, defaults to `attention_head_dim`
139
+ resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
140
+ for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
141
+ class_embed_type (`str`, *optional*, defaults to `None`):
142
+ The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
143
+ `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
144
+ addition_embed_type (`str`, *optional*, defaults to `None`):
145
+ Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
146
+ "text". "text" will use the `TextTimeEmbedding` layer.
147
+ addition_time_embed_dim (`int`, *optional*, defaults to `None`):
148
+ Dimension for the timestep embeddings.
149
+ num_class_embeds (`int`, *optional*, defaults to `None`):
150
+ Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
151
+ class conditioning with `class_embed_type` equal to `None`.
152
+ time_embedding_type (`str`, *optional*, defaults to `positional`):
153
+ The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
154
+ time_embedding_dim (`int`, *optional*, defaults to `None`):
155
+ An optional override for the dimension of the projected time embedding.
156
+ time_embedding_act_fn (`str`, *optional*, defaults to `None`):
157
+ Optional activation function to use only once on the time embeddings before they are passed to the rest of
158
+ the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
159
+ timestep_post_act (`str`, *optional*, defaults to `None`):
160
+ The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
161
+ time_cond_proj_dim (`int`, *optional*, defaults to `None`):
162
+ The dimension of `cond_proj` layer in the timestep embedding.
163
+ conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer.
164
+ conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer.
165
+ projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
166
+ `class_embed_type="projection"`. Required when `class_embed_type="projection"`.
167
+ class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
168
+ embeddings with the class embeddings.
169
+ mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
170
+ Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
171
+ `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the
172
+ `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False`
173
+ otherwise.
174
+ """
175
+
176
+ _supports_gradient_checkpointing = True
177
+
178
+ @register_to_config
179
+ def __init__(
180
+ self,
181
+ sample_size: Optional[int] = None,
182
+ in_channels: int = 4,
183
+ out_channels: int = 4,
184
+ center_input_sample: bool = False,
185
+ flip_sin_to_cos: bool = True,
186
+ freq_shift: int = 0,
187
+ down_block_types: Tuple[str] = (
188
+ "CrossAttnDownBlockMV2D",
189
+ "CrossAttnDownBlockMV2D",
190
+ "CrossAttnDownBlockMV2D",
191
+ "DownBlock2D",
192
+ ),
193
+ mid_block_type: Optional[str] = "UNetMidBlockMV2DCrossAttn",
194
+ up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlockMV2D", "CrossAttnUpBlockMV2D", "CrossAttnUpBlockMV2D"),
195
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
196
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
197
+ layers_per_block: Union[int, Tuple[int]] = 2,
198
+ downsample_padding: int = 1,
199
+ mid_block_scale_factor: float = 1,
200
+ act_fn: str = "silu",
201
+ norm_num_groups: Optional[int] = 32,
202
+ norm_eps: float = 1e-5,
203
+ cross_attention_dim: Union[int, Tuple[int]] = 1280,
204
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
205
+ encoder_hid_dim: Optional[int] = None,
206
+ encoder_hid_dim_type: Optional[str] = None,
207
+ attention_head_dim: Union[int, Tuple[int]] = 8,
208
+ num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
209
+ dual_cross_attention: bool = False,
210
+ use_linear_projection: bool = False,
211
+ class_embed_type: Optional[str] = None,
212
+ addition_embed_type: Optional[str] = None,
213
+ addition_time_embed_dim: Optional[int] = None,
214
+ num_class_embeds: Optional[int] = None,
215
+ upcast_attention: bool = False,
216
+ resnet_time_scale_shift: str = "default",
217
+ resnet_skip_time_act: bool = False,
218
+ resnet_out_scale_factor: float = 1.0,
219
+ time_embedding_type: str = "positional",
220
+ time_embedding_dim: Optional[int] = None,
221
+ time_embedding_act_fn: Optional[str] = None,
222
+ timestep_post_act: Optional[str] = None,
223
+ time_cond_proj_dim: Optional[int] = None,
224
+ conv_in_kernel: int = 3,
225
+ conv_out_kernel: int = 3,
226
+ projection_class_embeddings_input_dim: Optional[int] = None,
227
+ class_embeddings_concat: bool = False,
228
+ mid_block_only_cross_attention: Optional[bool] = None,
229
+ cross_attention_norm: Optional[str] = None,
230
+ addition_embed_type_num_heads=64,
231
+ num_views: int = 1,
232
+ cd_attention_last: bool = False,
233
+ cd_attention_mid: bool = False,
234
+ multiview_attention: bool = True,
235
+ sparse_mv_attention: bool = False,
236
+ mvcd_attention: bool = False
237
+ ):
238
+ super().__init__()
239
+
240
+ self.sample_size = sample_size
241
+
242
+ if num_attention_heads is not None:
243
+ raise ValueError(
244
+ "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
245
+ )
246
+
247
+ # If `num_attention_heads` is not defined (which is the case for most models)
248
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
249
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
250
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
251
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
252
+ # which is why we correct for the naming here.
253
+ num_attention_heads = num_attention_heads or attention_head_dim
254
+
255
+ # Check inputs
256
+ if len(down_block_types) != len(up_block_types):
257
+ raise ValueError(
258
+ f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
259
+ )
260
+
261
+ if len(block_out_channels) != len(down_block_types):
262
+ raise ValueError(
263
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
264
+ )
265
+
266
+ if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
267
+ raise ValueError(
268
+ f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
269
+ )
270
+
271
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
272
+ raise ValueError(
273
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
274
+ )
275
+
276
+ if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
277
+ raise ValueError(
278
+ f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
279
+ )
280
+
281
+ if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
282
+ raise ValueError(
283
+ f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
284
+ )
285
+
286
+ if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
287
+ raise ValueError(
288
+ f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
289
+ )
290
+
291
+ # input
292
+ conv_in_padding = (conv_in_kernel - 1) // 2
293
+ self.conv_in = nn.Conv2d(
294
+ in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
295
+ )
296
+
297
+ # time
298
+ if time_embedding_type == "fourier":
299
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
300
+ if time_embed_dim % 2 != 0:
301
+ raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
302
+ self.time_proj = GaussianFourierProjection(
303
+ time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
304
+ )
305
+ timestep_input_dim = time_embed_dim
306
+ elif time_embedding_type == "positional":
307
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
308
+
309
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
310
+ timestep_input_dim = block_out_channels[0]
311
+ else:
312
+ raise ValueError(
313
+ f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
314
+ )
315
+
316
+ self.time_embedding = TimestepEmbedding(
317
+ timestep_input_dim,
318
+ time_embed_dim,
319
+ act_fn=act_fn,
320
+ post_act_fn=timestep_post_act,
321
+ cond_proj_dim=time_cond_proj_dim,
322
+ )
323
+
324
+ if encoder_hid_dim_type is None and encoder_hid_dim is not None:
325
+ encoder_hid_dim_type = "text_proj"
326
+ self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
327
+ logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
328
+
329
+ if encoder_hid_dim is None and encoder_hid_dim_type is not None:
330
+ raise ValueError(
331
+ f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
332
+ )
333
+
334
+ if encoder_hid_dim_type == "text_proj":
335
+ self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
336
+ elif encoder_hid_dim_type == "text_image_proj":
337
+ # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
338
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
339
+ # case when `encoder_hid_dim_type == "text_image_proj"` (Kandinsky 2.1).
340
+ self.encoder_hid_proj = TextImageProjection(
341
+ text_embed_dim=encoder_hid_dim,
342
+ image_embed_dim=cross_attention_dim,
343
+ cross_attention_dim=cross_attention_dim,
344
+ )
345
+ elif encoder_hid_dim_type == "image_proj":
346
+ # Kandinsky 2.2
347
+ self.encoder_hid_proj = ImageProjection(
348
+ image_embed_dim=encoder_hid_dim,
349
+ cross_attention_dim=cross_attention_dim,
350
+ )
351
+ elif encoder_hid_dim_type is not None:
352
+ raise ValueError(
353
+ f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
354
+ )
355
+ else:
356
+ self.encoder_hid_proj = None
357
+
358
+ # class embedding
359
+ if class_embed_type is None and num_class_embeds is not None:
360
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
361
+ elif class_embed_type == "timestep":
362
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
363
+ elif class_embed_type == "identity":
364
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
365
+ elif class_embed_type == "projection":
366
+ if projection_class_embeddings_input_dim is None:
367
+ raise ValueError(
368
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
369
+ )
370
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
371
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
372
+ # 2. it projects from an arbitrary input dimension.
373
+ #
374
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
375
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
376
+ # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
377
+ self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
378
+ elif class_embed_type == "simple_projection":
379
+ if projection_class_embeddings_input_dim is None:
380
+ raise ValueError(
381
+ "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
382
+ )
383
+ self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
384
+ else:
385
+ self.class_embedding = None
386
+
387
+ if addition_embed_type == "text":
388
+ if encoder_hid_dim is not None:
389
+ text_time_embedding_from_dim = encoder_hid_dim
390
+ else:
391
+ text_time_embedding_from_dim = cross_attention_dim
392
+
393
+ self.add_embedding = TextTimeEmbedding(
394
+ text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
395
+ )
396
+ elif addition_embed_type == "text_image":
397
+ # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
398
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
399
+ # case when `addition_embed_type == "text_image"` (Kandinsky 2.1).
400
+ self.add_embedding = TextImageTimeEmbedding(
401
+ text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
402
+ )
403
+ elif addition_embed_type == "text_time":
404
+ self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
405
+ self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
406
+ elif addition_embed_type == "image":
407
+ # Kandinsky 2.2
408
+ self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
409
+ elif addition_embed_type == "image_hint":
410
+ # Kandinsky 2.2 ControlNet
411
+ self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
412
+ elif addition_embed_type is not None:
413
+ raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
414
+
415
+ if time_embedding_act_fn is None:
416
+ self.time_embed_act = None
417
+ else:
418
+ self.time_embed_act = get_activation(time_embedding_act_fn)
419
+
420
+ self.down_blocks = nn.ModuleList([])
421
+ self.up_blocks = nn.ModuleList([])
422
+
423
+ if isinstance(only_cross_attention, bool):
424
+ if mid_block_only_cross_attention is None:
425
+ mid_block_only_cross_attention = only_cross_attention
426
+
427
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
428
+
429
+ if mid_block_only_cross_attention is None:
430
+ mid_block_only_cross_attention = False
431
+
432
+ if isinstance(num_attention_heads, int):
433
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
434
+
435
+ if isinstance(attention_head_dim, int):
436
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
437
+
438
+ if isinstance(cross_attention_dim, int):
439
+ cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
440
+
441
+ if isinstance(layers_per_block, int):
442
+ layers_per_block = [layers_per_block] * len(down_block_types)
443
+
444
+ if isinstance(transformer_layers_per_block, int):
445
+ transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
446
+
447
+ if class_embeddings_concat:
448
+ # The time embeddings are concatenated with the class embeddings. The dimension of the
449
+ # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
450
+ # regular time embeddings
451
+ blocks_time_embed_dim = time_embed_dim * 2
452
+ else:
453
+ blocks_time_embed_dim = time_embed_dim
454
+
455
+ # down
456
+ output_channel = block_out_channels[0]
457
+ for i, down_block_type in enumerate(down_block_types):
458
+ input_channel = output_channel
459
+ output_channel = block_out_channels[i]
460
+ is_final_block = i == len(block_out_channels) - 1
461
+
462
+ down_block = get_down_block(
463
+ down_block_type,
464
+ num_layers=layers_per_block[i],
465
+ transformer_layers_per_block=transformer_layers_per_block[i],
466
+ in_channels=input_channel,
467
+ out_channels=output_channel,
468
+ temb_channels=blocks_time_embed_dim,
469
+ add_downsample=not is_final_block,
470
+ resnet_eps=norm_eps,
471
+ resnet_act_fn=act_fn,
472
+ resnet_groups=norm_num_groups,
473
+ cross_attention_dim=cross_attention_dim[i],
474
+ num_attention_heads=num_attention_heads[i],
475
+ downsample_padding=downsample_padding,
476
+ dual_cross_attention=dual_cross_attention,
477
+ use_linear_projection=use_linear_projection,
478
+ only_cross_attention=only_cross_attention[i],
479
+ upcast_attention=upcast_attention,
480
+ resnet_time_scale_shift=resnet_time_scale_shift,
481
+ resnet_skip_time_act=resnet_skip_time_act,
482
+ resnet_out_scale_factor=resnet_out_scale_factor,
483
+ cross_attention_norm=cross_attention_norm,
484
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
485
+ num_views=num_views,
486
+ cd_attention_last=cd_attention_last,
487
+ cd_attention_mid=cd_attention_mid,
488
+ multiview_attention=multiview_attention,
489
+ sparse_mv_attention=sparse_mv_attention,
490
+ mvcd_attention=mvcd_attention
491
+ )
492
+ self.down_blocks.append(down_block)
493
+
494
+ # mid
495
+ if mid_block_type == "UNetMidBlock2DCrossAttn":
496
+ self.mid_block = UNetMidBlock2DCrossAttn(
497
+ transformer_layers_per_block=transformer_layers_per_block[-1],
498
+ in_channels=block_out_channels[-1],
499
+ temb_channels=blocks_time_embed_dim,
500
+ resnet_eps=norm_eps,
501
+ resnet_act_fn=act_fn,
502
+ output_scale_factor=mid_block_scale_factor,
503
+ resnet_time_scale_shift=resnet_time_scale_shift,
504
+ cross_attention_dim=cross_attention_dim[-1],
505
+ num_attention_heads=num_attention_heads[-1],
506
+ resnet_groups=norm_num_groups,
507
+ dual_cross_attention=dual_cross_attention,
508
+ use_linear_projection=use_linear_projection,
509
+ upcast_attention=upcast_attention,
510
+ )
511
+ # custom MV2D attention block
512
+ elif mid_block_type == "UNetMidBlockMV2DCrossAttn":
513
+ self.mid_block = UNetMidBlockMV2DCrossAttn(
514
+ transformer_layers_per_block=transformer_layers_per_block[-1],
515
+ in_channels=block_out_channels[-1],
516
+ temb_channels=blocks_time_embed_dim,
517
+ resnet_eps=norm_eps,
518
+ resnet_act_fn=act_fn,
519
+ output_scale_factor=mid_block_scale_factor,
520
+ resnet_time_scale_shift=resnet_time_scale_shift,
521
+ cross_attention_dim=cross_attention_dim[-1],
522
+ num_attention_heads=num_attention_heads[-1],
523
+ resnet_groups=norm_num_groups,
524
+ dual_cross_attention=dual_cross_attention,
525
+ use_linear_projection=use_linear_projection,
526
+ upcast_attention=upcast_attention,
527
+ num_views=num_views,
528
+ cd_attention_last=cd_attention_last,
529
+ cd_attention_mid=cd_attention_mid,
530
+ multiview_attention=multiview_attention,
531
+ sparse_mv_attention=sparse_mv_attention,
532
+ mvcd_attention=mvcd_attention
533
+ )
534
+ elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
535
+ self.mid_block = UNetMidBlock2DSimpleCrossAttn(
536
+ in_channels=block_out_channels[-1],
537
+ temb_channels=blocks_time_embed_dim,
538
+ resnet_eps=norm_eps,
539
+ resnet_act_fn=act_fn,
540
+ output_scale_factor=mid_block_scale_factor,
541
+ cross_attention_dim=cross_attention_dim[-1],
542
+ attention_head_dim=attention_head_dim[-1],
543
+ resnet_groups=norm_num_groups,
544
+ resnet_time_scale_shift=resnet_time_scale_shift,
545
+ skip_time_act=resnet_skip_time_act,
546
+ only_cross_attention=mid_block_only_cross_attention,
547
+ cross_attention_norm=cross_attention_norm,
548
+ )
549
+ elif mid_block_type is None:
550
+ self.mid_block = None
551
+ else:
552
+ raise ValueError(f"unknown mid_block_type : {mid_block_type}")
553
+
554
+ # count how many layers upsample the images
555
+ self.num_upsamplers = 0
556
+
557
+ # up
558
+ reversed_block_out_channels = list(reversed(block_out_channels))
559
+ reversed_num_attention_heads = list(reversed(num_attention_heads))
560
+ reversed_layers_per_block = list(reversed(layers_per_block))
561
+ reversed_cross_attention_dim = list(reversed(cross_attention_dim))
562
+ reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
563
+ only_cross_attention = list(reversed(only_cross_attention))
564
+
565
+ output_channel = reversed_block_out_channels[0]
566
+ for i, up_block_type in enumerate(up_block_types):
567
+ is_final_block = i == len(block_out_channels) - 1
568
+
569
+ prev_output_channel = output_channel
570
+ output_channel = reversed_block_out_channels[i]
571
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
572
+
573
+ # add upsample block for all BUT final layer
574
+ if not is_final_block:
575
+ add_upsample = True
576
+ self.num_upsamplers += 1
577
+ else:
578
+ add_upsample = False
579
+
580
+ up_block = get_up_block(
581
+ up_block_type,
582
+ num_layers=reversed_layers_per_block[i] + 1,
583
+ transformer_layers_per_block=reversed_transformer_layers_per_block[i],
584
+ in_channels=input_channel,
585
+ out_channels=output_channel,
586
+ prev_output_channel=prev_output_channel,
587
+ temb_channels=blocks_time_embed_dim,
588
+ add_upsample=add_upsample,
589
+ resnet_eps=norm_eps,
590
+ resnet_act_fn=act_fn,
591
+ resnet_groups=norm_num_groups,
592
+ cross_attention_dim=reversed_cross_attention_dim[i],
593
+ num_attention_heads=reversed_num_attention_heads[i],
594
+ dual_cross_attention=dual_cross_attention,
595
+ use_linear_projection=use_linear_projection,
596
+ only_cross_attention=only_cross_attention[i],
597
+ upcast_attention=upcast_attention,
598
+ resnet_time_scale_shift=resnet_time_scale_shift,
599
+ resnet_skip_time_act=resnet_skip_time_act,
600
+ resnet_out_scale_factor=resnet_out_scale_factor,
601
+ cross_attention_norm=cross_attention_norm,
602
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
603
+ num_views=num_views,
604
+ cd_attention_last=cd_attention_last,
605
+ cd_attention_mid=cd_attention_mid,
606
+ multiview_attention=multiview_attention,
607
+ sparse_mv_attention=sparse_mv_attention,
608
+ mvcd_attention=mvcd_attention
609
+ )
610
+ self.up_blocks.append(up_block)
611
+ prev_output_channel = output_channel
612
+
613
+ # out
614
+ if norm_num_groups is not None:
615
+ self.conv_norm_out = nn.GroupNorm(
616
+ num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
617
+ )
618
+
619
+ self.conv_act = get_activation(act_fn)
620
+
621
+ else:
622
+ self.conv_norm_out = None
623
+ self.conv_act = None
624
+
625
+ conv_out_padding = (conv_out_kernel - 1) // 2
626
+ self.conv_out = nn.Conv2d(
627
+ block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
628
+ )
629
+
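With the constructor above in place, a hedged instantiation sketch; every size here is an assumption chosen to resemble a Stable-Diffusion-2-style backbone rather than a value read from any released config:

    from mvdiffusion.models.unet_mv2d_condition import UNetMV2DConditionModel

    unet = UNetMV2DConditionModel(
        sample_size=32,
        in_channels=4,
        out_channels=4,
        cross_attention_dim=1024,
        attention_head_dim=(5, 10, 20, 20),   # num_attention_heads falls back to this per block
        num_views=6,
        multiview_attention=True,
        sparse_mv_attention=False,
        mvcd_attention=False,
    )
    # the default down/up/mid block types already point at the MV2D variants
    # imported from mvdiffusion.models.unet_mv2d_blocks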
630
+ @property
631
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
632
+ r"""
633
+ Returns:
634
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
+ indexed by their weight names.
636
+ """
637
+ # set recursively
638
+ processors = {}
639
+
640
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
641
+ if hasattr(module, "set_processor"):
642
+ processors[f"{name}.processor"] = module.processor
643
+
644
+ for sub_name, child in module.named_children():
645
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
646
+
647
+ return processors
648
+
649
+ for name, module in self.named_children():
650
+ fn_recursive_add_processors(name, module, processors)
651
+
652
+ return processors
653
+
654
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
655
+ r"""
656
+ Sets the attention processor to use to compute attention.
657
+
658
+ Parameters:
659
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
660
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
661
+ for **all** `Attention` layers.
662
+
663
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
664
+ processor. This is strongly recommended when setting trainable attention processors.
665
+
666
+ """
667
+ count = len(self.attn_processors.keys())
668
+
669
+ if isinstance(processor, dict) and len(processor) != count:
670
+ raise ValueError(
671
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
672
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
673
+ )
674
+
675
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
676
+ if hasattr(module, "set_processor"):
677
+ if not isinstance(processor, dict):
678
+ module.set_processor(processor)
679
+ else:
680
+ module.set_processor(processor.pop(f"{name}.processor"))
681
+
682
+ for sub_name, child in module.named_children():
683
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
684
+
685
+ for name, module in self.named_children():
686
+ fn_recursive_attn_processor(name, module, processor)
687
+
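A short usage sketch for the two attention-processor hooks above, using the `AttnProcessor` class already imported at the top of this file (the dictionary key shown is illustrative):

    procs = unet.attn_processors                # {"down_blocks.0....processor": <processor>, ...}
    unet.set_attn_processor(AttnProcessor())    # one processor instance for every attention layer
    unet.set_attn_processor({name: AttnProcessor() for name in procs})   # or one per layer, keyed by path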
688
+ def set_default_attn_processor(self):
689
+ """
690
+ Disables custom attention processors and sets the default attention implementation.
691
+ """
692
+ self.set_attn_processor(AttnProcessor())
693
+
694
+ def set_attention_slice(self, slice_size):
695
+ r"""
696
+ Enable sliced attention computation.
697
+
698
+ When this option is enabled, the attention module splits the input tensor in slices to compute attention in
699
+ several steps. This is useful for saving some memory in exchange for a small decrease in speed.
700
+
701
+ Args:
702
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
703
+ When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
704
+ `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
705
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
706
+ must be a multiple of `slice_size`.
707
+ """
708
+ sliceable_head_dims = []
709
+
710
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
711
+ if hasattr(module, "set_attention_slice"):
712
+ sliceable_head_dims.append(module.sliceable_head_dim)
713
+
714
+ for child in module.children():
715
+ fn_recursive_retrieve_sliceable_dims(child)
716
+
717
+ # retrieve number of attention layers
718
+ for module in self.children():
719
+ fn_recursive_retrieve_sliceable_dims(module)
720
+
721
+ num_sliceable_layers = len(sliceable_head_dims)
722
+
723
+ if slice_size == "auto":
724
+ # half the attention head size is usually a good trade-off between
725
+ # speed and memory
726
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
727
+ elif slice_size == "max":
728
+ # make smallest slice possible
729
+ slice_size = num_sliceable_layers * [1]
730
+
731
+ slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
732
+
733
+ if len(slice_size) != len(sliceable_head_dims):
734
+ raise ValueError(
735
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
736
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
737
+ )
738
+
739
+ for i in range(len(slice_size)):
740
+ size = slice_size[i]
741
+ dim = sliceable_head_dims[i]
742
+ if size is not None and size > dim:
743
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
744
+
745
+ # Recursively walk through all the children.
746
+ # Any children which exposes the set_attention_slice method
747
+ # gets the message
748
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
749
+ if hasattr(module, "set_attention_slice"):
750
+ module.set_attention_slice(slice_size.pop())
751
+
752
+ for child in module.children():
753
+ fn_recursive_set_attention_slice(child, slice_size)
754
+
755
+ reversed_slice_size = list(reversed(slice_size))
756
+ for module in self.children():
757
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
758
+
759
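+ # Hypothetical usage sketch for `set_attention_slice` (`model` is an instance of
+ # this UNet); slicing trades some speed for lower peak attention memory:
+ #
+ #   model.set_attention_slice("auto")     # each sliceable layer uses head_dim // 2
+ #   model.set_attention_slice("max")      # slice size 1 everywhere, lowest memory
+ #   model.set_attention_slice(4)          # head dims must be multiples of 4
+ #   model.set_attention_slice([2, 4, 8])  # one value per sliceable layer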
+ def _set_gradient_checkpointing(self, module, value=False):
760
+ if isinstance(module, (CrossAttnDownBlock2D, CrossAttnDownBlockMV2D, DownBlock2D, CrossAttnUpBlock2D, CrossAttnUpBlockMV2D, UpBlock2D)):
761
+ module.gradient_checkpointing = value
762
+
763
+ def forward(
764
+ self,
765
+ sample: torch.FloatTensor,
766
+ timestep: Union[torch.Tensor, float, int],
767
+ encoder_hidden_states: torch.Tensor,
768
+ class_labels: Optional[torch.Tensor] = None,
769
+ timestep_cond: Optional[torch.Tensor] = None,
770
+ attention_mask: Optional[torch.Tensor] = None,
771
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
772
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
773
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
774
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
775
+ encoder_attention_mask: Optional[torch.Tensor] = None,
776
+ return_dict: bool = True,
777
+ ) -> Union[UNetMV2DConditionOutput, Tuple]:
778
+ r"""
779
+ The [`UNet2DConditionModel`] forward method.
780
+
781
+ Args:
782
+ sample (`torch.FloatTensor`):
783
+ The noisy input tensor with the following shape `(batch, channel, height, width)`.
784
+ timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
785
+ encoder_hidden_states (`torch.FloatTensor`):
786
+ The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
787
+ encoder_attention_mask (`torch.Tensor`):
788
+ A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
789
+ `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
790
+ which adds large negative values to the attention scores corresponding to "discard" tokens.
791
+ return_dict (`bool`, *optional*, defaults to `True`):
792
+ Whether or not to return a [`UNetMV2DConditionOutput`] instead of a plain
793
+ tuple.
794
+ cross_attention_kwargs (`dict`, *optional*):
795
+ A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
796
+ added_cond_kwargs: (`dict`, *optional*):
797
+ A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
798
+ are passed along to the UNet blocks.
799
+
800
+ Returns:
801
+ [`UNetMV2DConditionOutput`] or `tuple`:
802
+ If `return_dict` is True, an [`UNetMV2DConditionOutput`] is returned, otherwise
803
+ a `tuple` is returned where the first element is the sample tensor.
804
+ """
805
+ # By default samples have to be at least a multiple of the overall upsampling factor.
806
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
807
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
808
+ # on the fly if necessary.
809
+ default_overall_up_factor = 2**self.num_upsamplers
810
+
811
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
812
+ forward_upsample_size = False
813
+ upsample_size = None
814
+
815
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
816
+ logger.info("Forward upsample size to force interpolation output size.")
817
+ forward_upsample_size = True
818
+
819
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
820
+ # expects mask of shape:
821
+ # [batch, key_tokens]
822
+ # adds singleton query_tokens dimension:
823
+ # [batch, 1, key_tokens]
824
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
825
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
826
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
827
+ if attention_mask is not None:
828
+ # assume that mask is expressed as:
829
+ # (1 = keep, 0 = discard)
830
+ # convert mask into a bias that can be added to attention scores:
831
+ # (keep = +0, discard = -10000.0)
832
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
833
+ attention_mask = attention_mask.unsqueeze(1)
834
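+ # e.g. an `attention_mask` of [[1, 1, 0]] becomes the additive bias
+ # [[[0.0, 0.0, -10000.0]]], which broadcasts over heads and query tokens
+ # when added to the attention scores.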
+
835
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
836
+ if encoder_attention_mask is not None:
837
+ encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
838
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
839
+
840
+ # 0. center input if necessary
841
+ if self.config.center_input_sample:
842
+ sample = 2 * sample - 1.0
843
+
844
+ # 1. time
845
+ timesteps = timestep
846
+ if not torch.is_tensor(timesteps):
847
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
848
+ # This would be a good case for the `match` statement (Python 3.10+)
849
+ is_mps = sample.device.type == "mps"
850
+ if isinstance(timestep, float):
851
+ dtype = torch.float32 if is_mps else torch.float64
852
+ else:
853
+ dtype = torch.int32 if is_mps else torch.int64
854
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
855
+ elif len(timesteps.shape) == 0:
856
+ timesteps = timesteps[None].to(sample.device)
857
+
858
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
859
+ timesteps = timesteps.expand(sample.shape[0])
860
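+ # e.g. a scalar timestep of 500 with a batch of 6 becomes tensor([500, 500, 500, 500, 500, 500])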
+
861
+ t_emb = self.time_proj(timesteps)
862
+
863
+ # `Timesteps` does not contain any weights and will always return f32 tensors
864
+ # but time_embedding might actually be running in fp16. so we need to cast here.
865
+ # there might be better ways to encapsulate this.
866
+ t_emb = t_emb.to(dtype=sample.dtype)
867
+
868
+ emb = self.time_embedding(t_emb, timestep_cond)
869
+ aug_emb = None
870
+
871
+ if self.class_embedding is not None:
872
+ if class_labels is None:
873
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
874
+
875
+ if self.config.class_embed_type == "timestep":
876
+ class_labels = self.time_proj(class_labels)
877
+
878
+ # `Timesteps` does not contain any weights and will always return f32 tensors
879
+ # there might be better ways to encapsulate this.
880
+ class_labels = class_labels.to(dtype=sample.dtype)
881
+
882
+ class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
883
+
884
+ if self.config.class_embeddings_concat:
885
+ emb = torch.cat([emb, class_emb], dim=-1)
886
+ else:
887
+ emb = emb + class_emb
888
+
889
+ if self.config.addition_embed_type == "text":
890
+ aug_emb = self.add_embedding(encoder_hidden_states)
891
+ elif self.config.addition_embed_type == "text_image":
892
+ # Kandinsky 2.1 - style
893
+ if "image_embeds" not in added_cond_kwargs:
894
+ raise ValueError(
895
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
896
+ )
897
+
898
+ image_embs = added_cond_kwargs.get("image_embeds")
899
+ text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
900
+ aug_emb = self.add_embedding(text_embs, image_embs)
901
+ elif self.config.addition_embed_type == "text_time":
902
+ # SDXL - style
903
+ if "text_embeds" not in added_cond_kwargs:
904
+ raise ValueError(
905
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
906
+ )
907
+ text_embeds = added_cond_kwargs.get("text_embeds")
908
+ if "time_ids" not in added_cond_kwargs:
909
+ raise ValueError(
910
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
911
+ )
912
+ time_ids = added_cond_kwargs.get("time_ids")
913
+ time_embeds = self.add_time_proj(time_ids.flatten())
914
+ time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
915
+
916
+ add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
917
+ add_embeds = add_embeds.to(emb.dtype)
918
+ aug_emb = self.add_embedding(add_embeds)
919
+ elif self.config.addition_embed_type == "image":
920
+ # Kandinsky 2.2 - style
921
+ if "image_embeds" not in added_cond_kwargs:
922
+ raise ValueError(
923
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
924
+ )
925
+ image_embs = added_cond_kwargs.get("image_embeds")
926
+ aug_emb = self.add_embedding(image_embs)
927
+ elif self.config.addition_embed_type == "image_hint":
928
+ # Kandinsky 2.2 - style
929
+ if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
930
+ raise ValueError(
931
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
932
+ )
933
+ image_embs = added_cond_kwargs.get("image_embeds")
934
+ hint = added_cond_kwargs.get("hint")
935
+ aug_emb, hint = self.add_embedding(image_embs, hint)
936
+ sample = torch.cat([sample, hint], dim=1)
937
+
938
+ emb = emb + aug_emb if aug_emb is not None else emb
939
+
940
+ if self.time_embed_act is not None:
941
+ emb = self.time_embed_act(emb)
942
+
943
+ if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
944
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
945
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
946
+ # Kandinsky 2.1 - style
947
+ if "image_embeds" not in added_cond_kwargs:
948
+ raise ValueError(
949
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
950
+ )
951
+
952
+ image_embeds = added_cond_kwargs.get("image_embeds")
953
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
954
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
955
+ # Kandinsky 2.2 - style
956
+ if "image_embeds" not in added_cond_kwargs:
957
+ raise ValueError(
958
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
959
+ )
960
+ image_embeds = added_cond_kwargs.get("image_embeds")
961
+ encoder_hidden_states = self.encoder_hid_proj(image_embeds)
962
+ # 2. pre-process
963
+ sample = self.conv_in(sample)
964
+
965
+ # 3. down
966
+
967
+ is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
968
+ is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None
969
+
970
+ down_block_res_samples = (sample,)
971
+ for downsample_block in self.down_blocks:
972
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
973
+ # For t2i-adapter CrossAttnDownBlock2D
974
+ additional_residuals = {}
975
+ if is_adapter and len(down_block_additional_residuals) > 0:
976
+ additional_residuals["additional_residuals"] = down_block_additional_residuals.pop(0)
977
+
978
+ sample, res_samples = downsample_block(
979
+ hidden_states=sample,
980
+ temb=emb,
981
+ encoder_hidden_states=encoder_hidden_states,
982
+ attention_mask=attention_mask,
983
+ cross_attention_kwargs=cross_attention_kwargs,
984
+ encoder_attention_mask=encoder_attention_mask,
985
+ **additional_residuals,
986
+ )
987
+ else:
988
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
989
+
990
+ if is_adapter and len(down_block_additional_residuals) > 0:
991
+ sample += down_block_additional_residuals.pop(0)
992
+
993
+ down_block_res_samples += res_samples
994
+
995
+ if is_controlnet:
996
+ new_down_block_res_samples = ()
997
+
998
+ for down_block_res_sample, down_block_additional_residual in zip(
999
+ down_block_res_samples, down_block_additional_residuals
1000
+ ):
1001
+ down_block_res_sample = down_block_res_sample + down_block_additional_residual
1002
+ new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
1003
+
1004
+ down_block_res_samples = new_down_block_res_samples
1005
+
1006
+ # 4. mid
1007
+ if self.mid_block is not None:
1008
+ sample = self.mid_block(
1009
+ sample,
1010
+ emb,
1011
+ encoder_hidden_states=encoder_hidden_states,
1012
+ attention_mask=attention_mask,
1013
+ cross_attention_kwargs=cross_attention_kwargs,
1014
+ encoder_attention_mask=encoder_attention_mask,
1015
+ )
1016
+
1017
+ if is_controlnet:
1018
+ sample = sample + mid_block_additional_residual
1019
+
1020
+ # 5. up
1021
+ for i, upsample_block in enumerate(self.up_blocks):
1022
+ is_final_block = i == len(self.up_blocks) - 1
1023
+
1024
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
1025
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
1026
+
1027
+ # if we have not reached the final block and need to forward the
1028
+ # upsample size, we do it here
1029
+ if not is_final_block and forward_upsample_size:
1030
+ upsample_size = down_block_res_samples[-1].shape[2:]
1031
+
1032
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
1033
+ sample = upsample_block(
1034
+ hidden_states=sample,
1035
+ temb=emb,
1036
+ res_hidden_states_tuple=res_samples,
1037
+ encoder_hidden_states=encoder_hidden_states,
1038
+ cross_attention_kwargs=cross_attention_kwargs,
1039
+ upsample_size=upsample_size,
1040
+ attention_mask=attention_mask,
1041
+ encoder_attention_mask=encoder_attention_mask,
1042
+ )
1043
+ else:
1044
+ sample = upsample_block(
1045
+ hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
1046
+ )
1047
+
1048
+ # 6. post-process
1049
+ if self.conv_norm_out:
1050
+ sample = self.conv_norm_out(sample)
1051
+ sample = self.conv_act(sample)
1052
+ sample = self.conv_out(sample)
1053
+
1054
+ if not return_dict:
1055
+ return (sample,)
1056
+
1057
+ return UNetMV2DConditionOutput(sample=sample)
1058
+
1059
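+ # Hypothetical call sketch for `forward` (names and shapes are illustrative,
+ # assuming in_channels=8, a cross-attention dim of 1024 and a 6-dim camera
+ # embedding passed as `class_labels` for the 'projection' class embedding):
+ #
+ #   sample = torch.randn(6, 8, 32, 32)      # (batch * num_views, C_in, H, W)
+ #   text   = torch.randn(6, 77, 1024)       # encoder_hidden_states
+ #   cams   = torch.randn(6, 6)              # camera parameters
+ #   out = model(sample, timestep=500, encoder_hidden_states=text, class_labels=cams)
+ #   out.sample                              # shape (6, out_channels, 32, 32)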
+ @classmethod
1060
+ def from_pretrained_2d(
1061
+ cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
1062
+ camera_embedding_type: str, num_views: int, sample_size: int,
1063
+ zero_init_conv_in: bool = True, zero_init_camera_projection: bool = False,
1064
+ projection_class_embeddings_input_dim: int=6, cd_attention_last: bool = False,
1065
+ cd_attention_mid: bool = False, multiview_attention: bool = True,
1066
+ sparse_mv_attention: bool = False, mvcd_attention: bool = False,
1067
+ in_channels: int = 8, out_channels: int = 4,
1068
+ **kwargs
1069
+ ):
1070
+ r"""
1071
+ Instantiate a pretrained PyTorch model from a pretrained model configuration.
1072
+
1073
+ The model is set in evaluation mode - `model.eval()` - by default, and dropout modules are deactivated. To
1074
+ train the model, set it back in training mode with `model.train()`.
1075
+
1076
+ Parameters:
1077
+ pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
1078
+ Can be either:
1079
+
1080
+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
1081
+ the Hub.
1082
+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
1083
+ with [`~ModelMixin.save_pretrained`].
1084
+
1085
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
1086
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
1087
+ is not used.
1088
+ torch_dtype (`str` or `torch.dtype`, *optional*):
1089
+ Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
1090
+ dtype is automatically derived from the model's weights.
1091
+ force_download (`bool`, *optional*, defaults to `False`):
1092
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
1093
+ cached versions if they exist.
1094
+ resume_download (`bool`, *optional*, defaults to `False`):
1095
+ Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
1096
+ incompletely downloaded files are deleted.
1097
+ proxies (`Dict[str, str]`, *optional*):
1098
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
1099
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
1100
+ output_loading_info (`bool`, *optional*, defaults to `False`):
1101
+ Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
1102
+ local_files_only (`bool`, *optional*, defaults to `False`):
1103
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
1104
+ won't be downloaded from the Hub.
1105
+ use_auth_token (`str` or *bool*, *optional*):
1106
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
1107
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
1108
+ revision (`str`, *optional*, defaults to `"main"`):
1109
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
1110
+ allowed by Git.
1111
+ from_flax (`bool`, *optional*, defaults to `False`):
1112
+ Load the model weights from a Flax checkpoint save file.
1113
+ subfolder (`str`, *optional*, defaults to `""`):
1114
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
1115
+ mirror (`str`, *optional*):
1116
+ Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
1117
+ guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
1118
+ information.
1119
+ device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
1120
+ A map that specifies where each submodule should go. It doesn't need to be defined for each
1121
+ parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the
1122
+ same device.
1123
+
1124
+ Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For
1125
+ more information about each option see [designing a device
1126
+ map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
1127
+ max_memory (`Dict`, *optional*):
1128
+ A dictionary mapping device identifiers to their maximum memory. Will default to the maximum memory available for
1129
+ each GPU and the available CPU RAM if unset.
1130
+ offload_folder (`str` or `os.PathLike`, *optional*):
1131
+ The path to offload weights if `device_map` contains the value `"disk"`.
1132
+ offload_state_dict (`bool`, *optional*):
1133
+ If `True`, temporarily offloads the CPU state dict to the hard drive to avoid running out of CPU RAM if
1134
+ the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to `True`
1135
+ when there is some disk offload.
1136
+ low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
1137
+ Speed up model loading by only loading the pretrained weights and not initializing the weights. This also
1138
+ tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
1139
+ Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
1140
+ argument to `True` will raise an error.
1141
+ variant (`str`, *optional*):
1142
+ Load weights from a specified `variant` filename such as `"fp16"` or `"ema"`. This is ignored when
1143
+ loading `from_flax`.
1144
+ use_safetensors (`bool`, *optional*, defaults to `None`):
1145
+ If set to `None`, the `safetensors` weights are downloaded if they're available **and** if the
1146
+ `safetensors` library is installed. If set to `True`, the model is forcibly loaded from `safetensors`
1147
+ weights. If set to `False`, `safetensors` weights are not loaded.
1148
+
1149
+ <Tip>
1150
+
1151
+ To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with
1152
+ `huggingface-cli login`. You can also activate the special
1153
+ ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a
1154
+ firewalled environment.
1155
+
1156
+ </Tip>
1157
+
1158
+ Example:
1159
+
1160
+ ```py
1161
+ from diffusers import UNet2DConditionModel
1162
+
1163
+ unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
1164
+ ```
1165
+
1166
+ If you get the error message below, you need to finetune the weights for your downstream task:
1167
+
1168
+ ```bash
1169
+ Some weights of UNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match:
1170
+ - conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated
1171
+ You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
1172
+ ```
1173
+ """
1174
+ cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
1175
+ ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
1176
+ force_download = kwargs.pop("force_download", False)
1177
+ from_flax = kwargs.pop("from_flax", False)
1178
+ resume_download = kwargs.pop("resume_download", False)
1179
+ proxies = kwargs.pop("proxies", None)
1180
+ output_loading_info = kwargs.pop("output_loading_info", False)
1181
+ local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
1182
+ use_auth_token = kwargs.pop("use_auth_token", None)
1183
+ revision = kwargs.pop("revision", None)
1184
+ torch_dtype = kwargs.pop("torch_dtype", None)
1185
+ subfolder = kwargs.pop("subfolder", None)
1186
+ device_map = kwargs.pop("device_map", None)
1187
+ max_memory = kwargs.pop("max_memory", None)
1188
+ offload_folder = kwargs.pop("offload_folder", None)
1189
+ offload_state_dict = kwargs.pop("offload_state_dict", False)
1190
+ variant = kwargs.pop("variant", None)
1191
+ use_safetensors = kwargs.pop("use_safetensors", None)
1192
+
1193
+ if use_safetensors and not is_safetensors_available():
1194
+ raise ValueError(
1195
+ "`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors"
1196
+ )
1197
+
1198
+ allow_pickle = False
1199
+ if use_safetensors is None:
1200
+ use_safetensors = is_safetensors_available()
1201
+ allow_pickle = True
1202
+
1203
+ if device_map is not None and not is_accelerate_available():
1204
+ raise NotImplementedError(
1205
+ "Loading and dispatching requires `accelerate`. Please make sure to install accelerate or set"
1206
+ " `device_map=None`. You can install accelerate with `pip install accelerate`."
1207
+ )
1208
+
1209
+ # Check if we can handle device_map and dispatching the weights
1210
+ if device_map is not None and not is_torch_version(">=", "1.9.0"):
1211
+ raise NotImplementedError(
1212
+ "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
1213
+ " `device_map=None`."
1214
+ )
1215
+
1216
+ # Load config if we don't provide a configuration
1217
+ config_path = pretrained_model_name_or_path
1218
+
1219
+ user_agent = {
1220
+ "diffusers": __version__,
1221
+ "file_type": "model",
1222
+ "framework": "pytorch",
1223
+ }
1224
+
1225
+ # load config
1226
+ config, unused_kwargs, commit_hash = cls.load_config(
1227
+ config_path,
1228
+ cache_dir=cache_dir,
1229
+ return_unused_kwargs=True,
1230
+ return_commit_hash=True,
1231
+ force_download=force_download,
1232
+ resume_download=resume_download,
1233
+ proxies=proxies,
1234
+ local_files_only=local_files_only,
1235
+ use_auth_token=use_auth_token,
1236
+ revision=revision,
1237
+ subfolder=subfolder,
1238
+ device_map=device_map,
1239
+ max_memory=max_memory,
1240
+ offload_folder=offload_folder,
1241
+ offload_state_dict=offload_state_dict,
1242
+ user_agent=user_agent,
1243
+ **kwargs,
1244
+ )
1245
+
1246
+ # modify config
1247
+ config["_class_name"] = cls.__name__
1248
+ config['in_channels'] = in_channels
1249
+ config['out_channels'] = out_channels
1250
+ config['sample_size'] = sample_size # training resolution
1251
+ config['num_views'] = num_views
1252
+ config['cd_attention_last'] = cd_attention_last
1253
+ config['cd_attention_mid'] = cd_attention_mid
1254
+ config['multiview_attention'] = multiview_attention
1255
+ config['sparse_mv_attention'] = sparse_mv_attention
1256
+ config['mvcd_attention'] = mvcd_attention
1257
+ config["down_block_types"] = [
1258
+ "CrossAttnDownBlockMV2D",
1259
+ "CrossAttnDownBlockMV2D",
1260
+ "CrossAttnDownBlockMV2D",
1261
+ "DownBlock2D"
1262
+ ]
1263
+ config['mid_block_type'] = "UNetMidBlockMV2DCrossAttn"
1264
+ config["up_block_types"] = [
1265
+ "UpBlock2D",
1266
+ "CrossAttnUpBlockMV2D",
1267
+ "CrossAttnUpBlockMV2D",
1268
+ "CrossAttnUpBlockMV2D"
1269
+ ]
1270
+ config['class_embed_type'] = 'projection'
1271
+ if camera_embedding_type == 'e_de_da_sincos':
1272
+ config['projection_class_embeddings_input_dim'] = projection_class_embeddings_input_dim # default 6
1273
+ else:
1274
+ raise NotImplementedError
1275
+
1276
+ # load model
1277
+ model_file = None
1278
+ if from_flax:
1279
+ raise NotImplementedError
1280
+ else:
1281
+ if use_safetensors:
1282
+ try:
1283
+ model_file = _get_model_file(
1284
+ pretrained_model_name_or_path,
1285
+ weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant),
1286
+ cache_dir=cache_dir,
1287
+ force_download=force_download,
1288
+ resume_download=resume_download,
1289
+ proxies=proxies,
1290
+ local_files_only=local_files_only,
1291
+ use_auth_token=use_auth_token,
1292
+ revision=revision,
1293
+ subfolder=subfolder,
1294
+ user_agent=user_agent,
1295
+ commit_hash=commit_hash,
1296
+ )
1297
+ except IOError as e:
1298
+ if not allow_pickle:
1299
+ raise e
1300
+ pass
1301
+ if model_file is None:
1302
+ model_file = _get_model_file(
1303
+ pretrained_model_name_or_path,
1304
+ weights_name=_add_variant(WEIGHTS_NAME, variant),
1305
+ cache_dir=cache_dir,
1306
+ force_download=force_download,
1307
+ resume_download=resume_download,
1308
+ proxies=proxies,
1309
+ local_files_only=local_files_only,
1310
+ use_auth_token=use_auth_token,
1311
+ revision=revision,
1312
+ subfolder=subfolder,
1313
+ user_agent=user_agent,
1314
+ commit_hash=commit_hash,
1315
+ )
1316
+
1317
+ model = cls.from_config(config, **unused_kwargs)
1318
+ import copy
1319
+ state_dict_v0 = load_state_dict(model_file, variant=variant)
1320
+ state_dict = copy.deepcopy(state_dict_v0)
1321
+ # attn_joint -> attn_joint_last; norm_joint -> norm_joint_last
1322
+ # attn_joint_twice -> attn_joint_mid; norm_joint_twice -> norm_joint_mid
1323
+ for key in state_dict_v0:
1324
+ if 'attn_joint.' in key:
1325
+ tmp = copy.deepcopy(key)
1326
+ state_dict[key.replace("attn_joint.", "attn_joint_last.")] = state_dict.pop(tmp)
1327
+ if 'norm_joint.' in key:
1328
+ tmp = copy.deepcopy(key)
1329
+ state_dict[key.replace("norm_joint.", "norm_joint_last.")] = state_dict.pop(tmp)
1330
+ if 'attn_joint_twice.' in key:
1331
+ tmp = copy.deepcopy(key)
1332
+ state_dict[key.replace("attn_joint_twice.", "attn_joint_mid.")] = state_dict.pop(tmp)
1333
+ if 'norm_joint_twice.' in key:
1334
+ tmp = copy.deepcopy(key)
1335
+ state_dict[key.replace("norm_joint_twice.", "norm_joint_mid.")] = state_dict.pop(tmp)
1336
+
1337
+ model._convert_deprecated_attention_blocks(state_dict)
1338
+
1339
+ conv_in_weight = state_dict['conv_in.weight']
1340
+ conv_out_weight = state_dict['conv_out.weight']
1341
+ model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model_2d(
1342
+ model,
1343
+ state_dict,
1344
+ model_file,
1345
+ pretrained_model_name_or_path,
1346
+ ignore_mismatched_sizes=True,
1347
+ )
1348
+ if any([key == 'conv_in.weight' for key, _, _ in mismatched_keys]):
1349
+ # initialize from the original SD structure
1350
+ model.conv_in.weight.data[:,:4] = conv_in_weight
1351
+
1352
+ # optionally zero-initialize the weights for the newly added input channels
1353
+ if zero_init_conv_in:
1354
+ model.conv_in.weight.data[:,4:] = 0.
1355
+
1356
+ if any([key == 'conv_out.weight' for key, _, _ in mismatched_keys]):
1357
+ # initialize from the original SD structure
1358
+ model.conv_out.weight.data[:,:4] = conv_out_weight
1359
+ if out_channels == 8: # copy for the last 4 channels
1360
+ model.conv_out.weight.data[:, 4:] = conv_out_weight
1361
+
1362
+ if zero_init_camera_projection:
1363
+ for p in model.class_embedding.parameters():
1364
+ torch.nn.init.zeros_(p)
1365
+
1366
+ loading_info = {
1367
+ "missing_keys": missing_keys,
1368
+ "unexpected_keys": unexpected_keys,
1369
+ "mismatched_keys": mismatched_keys,
1370
+ "error_msgs": error_msgs,
1371
+ }
1372
+
1373
+ if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
1374
+ raise ValueError(
1375
+ f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
1376
+ )
1377
+ elif torch_dtype is not None:
1378
+ model = model.to(torch_dtype)
1379
+
1380
+ model.register_to_config(_name_or_path=pretrained_model_name_or_path)
1381
+
1382
+ # Set model in evaluation mode to deactivate DropOut modules by default
1383
+ model.eval()
1384
+ if output_loading_info:
1385
+ return model, loading_info
1386
+
1387
+ return model
1388
+
1389
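+ # Hypothetical loading sketch (`MVUNet` stands for this UNet class; the checkpoint
+ # id is the one used in the docstring example above, and the other values are
+ # illustrative). Only 'e_de_da_sincos' is currently implemented for
+ # `camera_embedding_type`:
+ #
+ #   unet = MVUNet.from_pretrained_2d(
+ #       "runwayml/stable-diffusion-v1-5", subfolder="unet",
+ #       camera_embedding_type="e_de_da_sincos", num_views=6, sample_size=32,
+ #       projection_class_embeddings_input_dim=10,
+ #       zero_init_conv_in=True, in_channels=8, out_channels=4,
+ #   )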
+ @classmethod
1390
+ def _load_pretrained_model_2d(
1391
+ cls,
1392
+ model,
1393
+ state_dict,
1394
+ resolved_archive_file,
1395
+ pretrained_model_name_or_path,
1396
+ ignore_mismatched_sizes=False,
1397
+ ):
1398
+ # Retrieve missing & unexpected_keys
1399
+ model_state_dict = model.state_dict()
1400
+ loaded_keys = list(state_dict.keys())
1401
+
1402
+ expected_keys = list(model_state_dict.keys())
1403
+
1404
+ original_loaded_keys = loaded_keys
1405
+
1406
+ missing_keys = list(set(expected_keys) - set(loaded_keys))
1407
+ unexpected_keys = list(set(loaded_keys) - set(expected_keys))
1408
+
1409
+ # Make sure we are able to load base models as well as derived models (with heads)
1410
+ model_to_load = model
1411
+
1412
+ def _find_mismatched_keys(
1413
+ state_dict,
1414
+ model_state_dict,
1415
+ loaded_keys,
1416
+ ignore_mismatched_sizes,
1417
+ ):
1418
+ mismatched_keys = []
1419
+ if ignore_mismatched_sizes:
1420
+ for checkpoint_key in loaded_keys:
1421
+ model_key = checkpoint_key
1422
+
1423
+ if (
1424
+ model_key in model_state_dict
1425
+ and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
1426
+ ):
1427
+ mismatched_keys.append(
1428
+ (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
1429
+ )
1430
+ del state_dict[checkpoint_key]
1431
+ return mismatched_keys
1432
+
1433
+ if state_dict is not None:
1434
+ # Whole checkpoint
1435
+ mismatched_keys = _find_mismatched_keys(
1436
+ state_dict,
1437
+ model_state_dict,
1438
+ original_loaded_keys,
1439
+ ignore_mismatched_sizes,
1440
+ )
1441
+ error_msgs = _load_state_dict_into_model(model_to_load, state_dict)
1442
+
1443
+ if len(error_msgs) > 0:
1444
+ error_msg = "\n\t".join(error_msgs)
1445
+ if "size mismatch" in error_msg:
1446
+ error_msg += (
1447
+ "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method."
1448
+ )
1449
+ raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")
1450
+
1451
+ if len(unexpected_keys) > 0:
1452
+ logger.warning(
1453
+ f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
1454
+ f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
1455
+ f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task"
1456
+ " or with another architecture (e.g. initializing a BertForSequenceClassification model from a"
1457
+ " BertForPreTraining model).\n- This IS NOT expected if you are initializing"
1458
+ f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly"
1459
+ " identical (initializing a BertForSequenceClassification model from a"
1460
+ " BertForSequenceClassification model)."
1461
+ )
1462
+ else:
1463
+ logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
1464
+ if len(missing_keys) > 0:
1465
+ logger.warning(
1466
+ f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
1467
+ f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
1468
+ " TRAIN this model on a down-stream task to be able to use it for predictions and inference."
1469
+ )
1470
+ elif len(mismatched_keys) == 0:
1471
+ logger.info(
1472
+ f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
1473
+ f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the"
1474
+ f" checkpoint was trained on, you can already use {model.__class__.__name__} for predictions"
1475
+ " without further training."
1476
+ )
1477
+ if len(mismatched_keys) > 0:
1478
+ mismatched_warning = "\n".join(
1479
+ [
1480
+ f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
1481
+ for key, shape1, shape2 in mismatched_keys
1482
+ ]
1483
+ )
1484
+ logger.warning(
1485
+ f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
1486
+ f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
1487
+ f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be"
1488
+ " able to use it for predictions and inference."
1489
+ )
1490
+
1491
+ return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs
1492
+