pookiefoof committed
Commit ff0340e
1 Parent(s): eaa1b39
This view is limited to 50 files because the commit contains too many changes.
Files changed (50)
  1. 2D_Stage/configs/infer.yaml +24 -0
  2. 2D_Stage/input.png +0 -0
  3. 2D_Stage/material/examples/1.png +0 -0
  4. 2D_Stage/material/examples/2.png +0 -0
  5. 2D_Stage/material/examples/3.png +0 -0
  6. 2D_Stage/material/examples/4.png +0 -0
  7. 2D_Stage/material/examples/5.png +0 -0
  8. 2D_Stage/material/examples/6.png +0 -0
  9. 2D_Stage/material/examples/7.png +0 -0
  10. 2D_Stage/material/examples/8.png +0 -0
  11. 2D_Stage/material/pose.json +38 -0
  12. 2D_Stage/material/pose0.png +0 -0
  13. 2D_Stage/material/pose1.png +0 -0
  14. 2D_Stage/material/pose2.png +0 -0
  15. 2D_Stage/material/pose3.png +0 -0
  16. 2D_Stage/tuneavideo/__pycache__/util.cpython-310.pyc +0 -0
  17. 2D_Stage/tuneavideo/models/PoseGuider.py +59 -0
  18. 2D_Stage/tuneavideo/models/__pycache__/PoseGuider.cpython-310.pyc +0 -0
  19. 2D_Stage/tuneavideo/models/__pycache__/refunet.cpython-310.pyc +0 -0
  20. 2D_Stage/tuneavideo/models/__pycache__/resnet.cpython-310.pyc +0 -0
  21. 2D_Stage/tuneavideo/models/__pycache__/transformer_mv2d.cpython-310.pyc +0 -0
  22. 2D_Stage/tuneavideo/models/__pycache__/unet.cpython-310.pyc +0 -0
  23. 2D_Stage/tuneavideo/models/__pycache__/unet_blocks.cpython-310.pyc +0 -0
  24. 2D_Stage/tuneavideo/models/__pycache__/unet_mv2d_blocks.cpython-310.pyc +0 -0
  25. 2D_Stage/tuneavideo/models/__pycache__/unet_mv2d_condition.cpython-310.pyc +0 -0
  26. 2D_Stage/tuneavideo/models/__pycache__/unet_mv2d_ref.cpython-310.pyc +0 -0
  27. 2D_Stage/tuneavideo/models/attention.py +344 -0
  28. 2D_Stage/tuneavideo/models/imageproj.py +118 -0
  29. 2D_Stage/tuneavideo/models/refunet.py +125 -0
  30. 2D_Stage/tuneavideo/models/resnet.py +210 -0
  31. 2D_Stage/tuneavideo/models/transformer_mv2d.py +1010 -0
  32. 2D_Stage/tuneavideo/models/unet.py +497 -0
  33. 2D_Stage/tuneavideo/models/unet_blocks.py +596 -0
  34. 2D_Stage/tuneavideo/models/unet_mv2d_blocks.py +926 -0
  35. 2D_Stage/tuneavideo/models/unet_mv2d_condition.py +1509 -0
  36. 2D_Stage/tuneavideo/models/unet_mv2d_ref.py +1570 -0
  37. 2D_Stage/tuneavideo/pipelines/__pycache__/pipeline_tuneavideo.cpython-310.pyc +0 -0
  38. 2D_Stage/tuneavideo/pipelines/pipeline_tuneavideo.py +585 -0
  39. 2D_Stage/tuneavideo/util.py +128 -0
  40. 2D_Stage/webui.py +323 -0
  41. 3D_Stage/__pycache__/refine.cpython-310.pyc +0 -0
  42. 3D_Stage/configs/infer.yaml +104 -0
  43. 3D_Stage/load/tets/128_tets.npz +3 -0
  44. 3D_Stage/load/tets/256_tets.npz +3 -0
  45. 3D_Stage/load/tets/32_tets.npz +3 -0
  46. 3D_Stage/load/tets/64_tets.npz +3 -0
  47. 3D_Stage/load/tets/generate_tets.py +58 -0
  48. 3D_Stage/lrm/__init__.py +29 -0
  49. 3D_Stage/lrm/__pycache__/__init__.cpython-310.pyc +0 -0
  50. 3D_Stage/lrm/models/__init__.py +0 -0
2D_Stage/configs/infer.yaml ADDED
@@ -0,0 +1,24 @@
+ pretrained_model_path: "stabilityai/stable-diffusion-2-1"
+ image_encoder_path: "./models/image_encoder"
+ ckpt_dir: "./models/checkpoint"
+
+ validation:
+   guidance_scale: 5.0
+   use_inv_latent: False
+   video_length: 4
+
+ use_pose_guider: True
+ use_noise: False
+ use_shifted_noise: False
+ unet_condition_type: image
+
+ unet_from_pretrained_kwargs:
+   camera_embedding_type: 'e_de_da_sincos'
+   projection_class_embeddings_input_dim: 10 # modify
+   joint_attention: false # modify
+   num_views: 4
+   sample_size: 96
+   zero_init_conv_in: false
+   zero_init_camera_projection: false
+   in_channels: 4
+   use_safetensors: true
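For reference, a config in this shape is typically consumed with OmegaConf. A minimal loading sketch, assuming the nesting shown above (the repo's actual entry point may parse it differently):

```python
from omegaconf import OmegaConf

# load the 2D-stage inference config and read a few fields
cfg = OmegaConf.load("2D_Stage/configs/infer.yaml")
print(cfg.pretrained_model_path)                   # stabilityai/stable-diffusion-2-1
print(cfg.validation.guidance_scale)               # 5.0
print(cfg.unet_from_pretrained_kwargs.num_views)   # 4
```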
2D_Stage/input.png ADDED
2D_Stage/material/examples/1.png ADDED
2D_Stage/material/examples/2.png ADDED
2D_Stage/material/examples/3.png ADDED
2D_Stage/material/examples/4.png ADDED
2D_Stage/material/examples/5.png ADDED
2D_Stage/material/examples/6.png ADDED
2D_Stage/material/examples/7.png ADDED
2D_Stage/material/examples/8.png ADDED
2D_Stage/material/pose.json ADDED
@@ -0,0 +1,38 @@
+ [
+     [
+         [
+             0, 0, -1, 0,
+             0, 1, 0, 0,
+             1, 0, 0, 0,
+             1.5, 0, 0, 1
+         ],
+         "pose0.png"
+     ],
+     [
+         [
+             0, 0, 1, 0,
+             0, 1, 0, 0,
+             -1, 0, 0, 0,
+             -1.5, 0, 0, 1
+         ],
+         "pose1.png"
+     ],
+     [
+         [
+             0, 0, 1, 0,
+             0, 1, 0, 0,
+             -1, 0, 0, 0,
+             -1.5, 0, 0, 1
+         ],
+         "pose2.png"
+     ],
+     [
+         [
+             -1, 0, 0, 0,
+             0, 1, 0, 0,
+             0, 0, -1, 0,
+             0, 0, -1.5, 1
+         ],
+         "pose3.png"
+     ]
+ ]
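Each entry pairs 16 numbers with a pose image. The numbers read naturally as a flattened 4x4 camera matrix with the translation in the last row (i.e. the transpose of the usual camera-to-world form); that convention is inferred from the layout, not documented here. A minimal loading sketch under that assumption:

```python
import json
import numpy as np

with open("2D_Stage/material/pose.json") as f:
    entries = json.load(f)

for values, image_name in entries:
    # 16 values -> 4x4, transposed so the translation lands in the last column
    pose = np.array(values, dtype=np.float32).reshape(4, 4).T
    rotation, translation = pose[:3, :3], pose[:3, 3]
    print(image_name, translation)   # pose0.png -> [1.5 0. 0.]
```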
2D_Stage/material/pose0.png ADDED
2D_Stage/material/pose1.png ADDED
2D_Stage/material/pose2.png ADDED
2D_Stage/material/pose3.png ADDED
2D_Stage/tuneavideo/__pycache__/util.cpython-310.pyc ADDED
Binary file (4.36 kB)
2D_Stage/tuneavideo/models/PoseGuider.py ADDED
@@ -0,0 +1,59 @@
+ import os
+ import torch
+ import torch.nn as nn
+ import torch.nn.init as init
+ from einops import rearrange
+
+ class PoseGuider(nn.Module):
+     def __init__(self, noise_latent_channels=4):
+         super(PoseGuider, self).__init__()
+
+         # three stride-2 convs: the pose image is downsampled 8x, to the latent resolution
+         self.conv_layers = nn.Sequential(
+             nn.Conv2d(in_channels=3, out_channels=16, kernel_size=4, stride=2, padding=1),
+             nn.ReLU(),
+             nn.Conv2d(in_channels=16, out_channels=32, kernel_size=4, stride=2, padding=1),
+             nn.ReLU(),
+             nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2, padding=1),
+             nn.ReLU(),
+             nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1),
+             nn.ReLU()
+         )
+
+         # Final projection layer
+         self.final_proj = nn.Conv2d(in_channels=128, out_channels=noise_latent_channels, kernel_size=1)
+
+         # Initialize layers
+         self._initialize_weights()
+
+     def _initialize_weights(self):
+         # Initialize weights with a Gaussian distribution and zero out the final layer
+         for m in self.conv_layers:
+             if isinstance(m, nn.Conv2d):
+                 init.normal_(m.weight, mean=0.0, std=0.02)
+                 if m.bias is not None:
+                     init.zeros_(m.bias)
+
+         init.zeros_(self.final_proj.weight)
+         if self.final_proj.bias is not None:
+             init.zeros_(self.final_proj.bias)
+
+     def forward(self, pose_image):
+         x = self.conv_layers(pose_image)
+         x = self.final_proj(x)
+
+         return x
+
+     @classmethod
+     def from_pretrained(cls, pretrained_model_path):  # `cls` was missing from the original signature
+         if not os.path.exists(pretrained_model_path):
+             raise FileNotFoundError(f"There is no model file in {pretrained_model_path}")
+         print(f"loading PoseGuider's pretrained weights from {pretrained_model_path} ...")
+
+         state_dict = torch.load(pretrained_model_path, map_location="cpu")
+         model = cls(noise_latent_channels=4)
+         m, u = model.load_state_dict(state_dict, strict=False)
+         print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
+         params = [p.numel() for n, p in model.named_parameters()]  # count all parameters (the original filtered on "temporal", which never matches here)
+         print(f"### PoseGuider's Parameters: {sum(params) / 1e6} M")
+
+         return model
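A minimal usage sketch: the three stride-2 convs downsample by 8x, so a pose image at the diffusion model's pixel resolution maps onto its latent grid (the 768x768 input here is illustrative, chosen to match the sample_size of 96 in infer.yaml):

```python
import torch

guider = PoseGuider(noise_latent_channels=4)
pose = torch.randn(1, 3, 768, 768)   # an RGB pose image
feat = guider(pose)                  # final layer is zero-initialized, so this adds nothing at the start of training
print(feat.shape)                    # torch.Size([1, 4, 96, 96]) -- matches the latent resolution
```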
2D_Stage/tuneavideo/models/__pycache__/PoseGuider.cpython-310.pyc ADDED
Binary file (2.41 kB)
2D_Stage/tuneavideo/models/__pycache__/refunet.cpython-310.pyc ADDED
Binary file (4.05 kB)
2D_Stage/tuneavideo/models/__pycache__/resnet.cpython-310.pyc ADDED
Binary file (5.13 kB)
2D_Stage/tuneavideo/models/__pycache__/transformer_mv2d.cpython-310.pyc ADDED
Binary file (23 kB)
2D_Stage/tuneavideo/models/__pycache__/unet.cpython-310.pyc ADDED
Binary file (11.9 kB)
2D_Stage/tuneavideo/models/__pycache__/unet_blocks.cpython-310.pyc ADDED
Binary file (10.9 kB)
2D_Stage/tuneavideo/models/__pycache__/unet_mv2d_blocks.cpython-310.pyc ADDED
Binary file (15.2 kB)
2D_Stage/tuneavideo/models/__pycache__/unet_mv2d_condition.cpython-310.pyc ADDED
Binary file (45.7 kB)
2D_Stage/tuneavideo/models/__pycache__/unet_mv2d_ref.cpython-310.pyc ADDED
Binary file (48.1 kB)
2D_Stage/tuneavideo/models/attention.py ADDED
@@ -0,0 +1,344 @@
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py
+
+ from dataclasses import dataclass
+ from typing import Optional
+
+ import torch
+ import torch.nn.functional as F
+ from torch import nn
+
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
+ from diffusers import ModelMixin
+ from diffusers.utils import BaseOutput
+ from diffusers.utils.import_utils import is_xformers_available
+ from diffusers.models.attention import CrossAttention, FeedForward, AdaLayerNorm
+
+ from einops import rearrange, repeat
+
+
+ @dataclass
+ class Transformer3DModelOutput(BaseOutput):
+     sample: torch.FloatTensor
+
+
+ if is_xformers_available():
+     import xformers
+     import xformers.ops
+ else:
+     xformers = None
+
+
+ class Transformer3DModel(ModelMixin, ConfigMixin):
+     @register_to_config
+     def __init__(
+         self,
+         num_attention_heads: int = 16,
+         attention_head_dim: int = 88,
+         in_channels: Optional[int] = None,
+         num_layers: int = 1,
+         dropout: float = 0.0,
+         norm_num_groups: int = 32,
+         cross_attention_dim: Optional[int] = None,
+         attention_bias: bool = False,
+         activation_fn: str = "geglu",
+         num_embeds_ada_norm: Optional[int] = None,
+         use_linear_projection: bool = False,
+         only_cross_attention: bool = False,
+         upcast_attention: bool = False,
+         use_attn_temp: bool = False,
+     ):
+         super().__init__()
+         self.use_linear_projection = use_linear_projection
+         self.num_attention_heads = num_attention_heads
+         self.attention_head_dim = attention_head_dim
+         inner_dim = num_attention_heads * attention_head_dim
+
+         # Define input layers
+         self.in_channels = in_channels
+
+         self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
+         if use_linear_projection:
+             self.proj_in = nn.Linear(in_channels, inner_dim)
+         else:
+             self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
+
+         # Define transformer blocks
+         self.transformer_blocks = nn.ModuleList(
+             [
+                 BasicTransformerBlock(
+                     inner_dim,
+                     num_attention_heads,
+                     attention_head_dim,
+                     dropout=dropout,
+                     cross_attention_dim=cross_attention_dim,
+                     activation_fn=activation_fn,
+                     num_embeds_ada_norm=num_embeds_ada_norm,
+                     attention_bias=attention_bias,
+                     only_cross_attention=only_cross_attention,
+                     upcast_attention=upcast_attention,
+                     use_attn_temp=use_attn_temp,
+                 )
+                 for d in range(num_layers)
+             ]
+         )
+
+         # 4. Define output layers
+         if use_linear_projection:
+             self.proj_out = nn.Linear(inner_dim, in_channels)  # fixed: the original had the arguments swapped
+         else:
+             self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
+
+     def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True):
+         # Input
+         assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
+         video_length = hidden_states.shape[2]
+         hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
+         encoder_hidden_states = repeat(encoder_hidden_states, 'b n c -> (b f) n c', f=video_length)
+
+         batch, channel, height, width = hidden_states.shape
+         residual = hidden_states
+
+         hidden_states = self.norm(hidden_states)
+         if not self.use_linear_projection:
+             hidden_states = self.proj_in(hidden_states)
+             inner_dim = hidden_states.shape[1]
+             hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
+         else:
+             inner_dim = hidden_states.shape[1]
+             hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
+             hidden_states = self.proj_in(hidden_states)
+
+         # Blocks
+         for block in self.transformer_blocks:
+             hidden_states = block(
+                 hidden_states,
+                 encoder_hidden_states=encoder_hidden_states,
+                 timestep=timestep,
+                 video_length=video_length
+             )
+
+         # Output
+         if not self.use_linear_projection:
+             hidden_states = (
+                 hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
+             )
+             hidden_states = self.proj_out(hidden_states)
+         else:
+             hidden_states = self.proj_out(hidden_states)
+             hidden_states = (
+                 hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
+             )
+
+         output = hidden_states + residual
+
+         output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
+         if not return_dict:
+             return (output,)
+
+         return Transformer3DModelOutput(sample=output)
+
+
+ class BasicTransformerBlock(nn.Module):
+     def __init__(
+         self,
+         dim: int,
+         num_attention_heads: int,
+         attention_head_dim: int,
+         dropout=0.0,
+         cross_attention_dim: Optional[int] = None,
+         activation_fn: str = "geglu",
+         num_embeds_ada_norm: Optional[int] = None,
+         attention_bias: bool = False,
+         only_cross_attention: bool = False,
+         upcast_attention: bool = False,
+         use_attn_temp: bool = False
+     ):
+         super().__init__()
+         self.only_cross_attention = only_cross_attention
+         self.use_ada_layer_norm = num_embeds_ada_norm is not None
+         self.use_attn_temp = use_attn_temp
+         # SC-Attn
+         self.attn1 = SparseCausalAttention(
+             query_dim=dim,
+             heads=num_attention_heads,
+             dim_head=attention_head_dim,
+             dropout=dropout,
+             bias=attention_bias,
+             cross_attention_dim=cross_attention_dim if only_cross_attention else None,
+             upcast_attention=upcast_attention,
+         )
+         self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
+
+         # Cross-Attn
+         if cross_attention_dim is not None:
+             self.attn2 = CrossAttention(
+                 query_dim=dim,
+                 cross_attention_dim=cross_attention_dim,
+                 heads=num_attention_heads,
+                 dim_head=attention_head_dim,
+                 dropout=dropout,
+                 bias=attention_bias,
+                 upcast_attention=upcast_attention,
+             )
+         else:
+             self.attn2 = None
+
+         if cross_attention_dim is not None:
+             self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
+         else:
+             self.norm2 = None
+
+         # Feed-forward
+         self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
+         self.norm3 = nn.LayerNorm(dim)
+
+         # Temp-Attn
+         if self.use_attn_temp:
+             self.attn_temp = CrossAttention(
+                 query_dim=dim,
+                 heads=num_attention_heads,
+                 dim_head=attention_head_dim,
+                 dropout=dropout,
+                 bias=attention_bias,
+                 upcast_attention=upcast_attention,
+             )
+             # zero-init so temporal attention starts as an identity residual
+             nn.init.zeros_(self.attn_temp.to_out[0].weight.data)
+             self.norm_temp = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
+
+     def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
+         if not is_xformers_available():
+             raise ModuleNotFoundError(
+                 "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
+                 " xformers",
+                 name="xformers",
+             )
+         elif not torch.cuda.is_available():
+             raise ValueError(
+                 "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only"
+                 " available for GPU "
+             )
+         else:
+             try:
+                 # Make sure we can run the memory efficient attention
+                 _ = xformers.ops.memory_efficient_attention(
+                     torch.randn((1, 2, 40), device="cuda"),
+                     torch.randn((1, 2, 40), device="cuda"),
+                     torch.randn((1, 2, 40), device="cuda"),
+                 )
+             except Exception as e:
+                 raise e
+         self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
+         if self.attn2 is not None:
+             self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
+         # self.attn_temp._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
+
+     def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None, video_length=None):
+         # SparseCausal-Attention
+         norm_hidden_states = (
+             self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states)
+         )
+
+         if self.only_cross_attention:
+             hidden_states = (
+                 self.attn1(norm_hidden_states, encoder_hidden_states, attention_mask=attention_mask) + hidden_states
+             )
+         else:
+             hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length) + hidden_states
+
+         if self.attn2 is not None:
+             # Cross-Attention
+             norm_hidden_states = (
+                 self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
+             )
+             hidden_states = (
+                 self.attn2(
+                     norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
+                 )
+                 + hidden_states
+             )
+
+         # Feed-forward
+         hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
+
+         # Temporal-Attention
+         if self.use_attn_temp:
+             d = hidden_states.shape[1]
+             hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length)
+             norm_hidden_states = (
+                 self.norm_temp(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_temp(hidden_states)
+             )
+             hidden_states = self.attn_temp(norm_hidden_states) + hidden_states
+             hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
+
+         return hidden_states
+
+
+ class SparseCausalAttention(CrossAttention):
+     def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None, use_full_attn=True):
+         batch_size, sequence_length, _ = hidden_states.shape
+
+         if self.group_norm is not None:
+             hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+         query = self.to_q(hidden_states)
+         # query = rearrange(query, "(b f) d c -> b (f d) c", f=video_length)
+         dim = query.shape[-1]
+         query = self.reshape_heads_to_batch_dim(query)
+
+         if self.added_kv_proj_dim is not None:
+             raise NotImplementedError
+
+         encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
+         key = self.to_k(encoder_hidden_states)
+         value = self.to_v(encoder_hidden_states)
+
+         former_frame_index = torch.arange(video_length) - 1
+         former_frame_index[0] = 0
+
+         key = rearrange(key, "(b f) d c -> b f d c", f=video_length)
+         if not use_full_attn:
+             # sparse-causal: each frame attends to the first frame and its previous frame
+             key = torch.cat([key[:, [0] * video_length], key[:, former_frame_index]], dim=2)
+         else:
+             # full attention: each frame attends to the keys of all frames
+             key_video_length = [key[:, [i] * video_length] for i in range(video_length)]
+             key = torch.cat(key_video_length, dim=2)
+         key = rearrange(key, "b f d c -> (b f) d c")
+
+         value = rearrange(value, "(b f) d c -> b f d c", f=video_length)
+         if not use_full_attn:
+             value = torch.cat([value[:, [0] * video_length], value[:, former_frame_index]], dim=2)
+         else:
+             value_video_length = [value[:, [i] * video_length] for i in range(video_length)]
+             value = torch.cat(value_video_length, dim=2)
+         value = rearrange(value, "b f d c -> (b f) d c")
+
+         key = self.reshape_heads_to_batch_dim(key)
+         value = self.reshape_heads_to_batch_dim(value)
+
+         if attention_mask is not None:
+             if attention_mask.shape[-1] != query.shape[1]:
+                 target_length = query.shape[1]
+                 attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
+                 attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
+
+         # attention, what we cannot get enough of
+         if self._use_memory_efficient_attention_xformers:
+             hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
+             # Some versions of xformers return output in fp32, cast it back to the dtype of the input
+             hidden_states = hidden_states.to(query.dtype)
+         else:
+             if self._slice_size is None or query.shape[0] // self._slice_size == 1:
+                 hidden_states = self._attention(query, key, value, attention_mask)
+             else:
+                 hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
+
+         # linear proj
+         hidden_states = self.to_out[0](hidden_states)
+
+         # dropout
+         hidden_states = self.to_out[1](hidden_states)
+         return hidden_states
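With use_full_attn, the key/value construction above turns per-frame self-attention into joint attention across all frames/views: each frame's queries see the concatenated tokens of every frame. The tensor manipulation in isolation (shapes only):

```python
import torch
from einops import rearrange

b, f, d, c = 1, 4, 6, 8            # batch, frames/views, tokens per frame, channels
key = torch.randn(b * f, d, c)     # per-frame keys, as inside SparseCausalAttention

key = rearrange(key, "(b f) d c -> b f d c", f=f)
key = torch.cat([key[:, [i] * f] for i in range(f)], dim=2)  # every frame gathers all frames' tokens
key = rearrange(key, "b f d c -> (b f) d c")
print(key.shape)                   # torch.Size([4, 24, 8]): each frame now attends over f*d keys
```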
2D_Stage/tuneavideo/models/imageproj.py ADDED
@@ -0,0 +1,118 @@
+ # modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
+ import math
+
+ import torch
+ import torch.nn as nn
+
+
+ # FFN
+ def FeedForward(dim, mult=4):
+     inner_dim = int(dim * mult)
+     return nn.Sequential(
+         nn.LayerNorm(dim),
+         nn.Linear(dim, inner_dim, bias=False),
+         nn.GELU(),
+         nn.Linear(inner_dim, dim, bias=False),
+     )
+
+
+ def reshape_tensor(x, heads):
+     bs, length, width = x.shape
+     # (bs, length, width) --> (bs, length, n_heads, dim_per_head)
+     x = x.view(bs, length, heads, -1)
+     # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
+     x = x.transpose(1, 2)
+     # keep the head dimension explicit: (bs, n_heads, length, dim_per_head)
+     x = x.reshape(bs, heads, length, -1)
+     return x
+
+
+ class PerceiverAttention(nn.Module):
+     def __init__(self, *, dim, dim_head=64, heads=8):
+         super().__init__()
+         self.scale = dim_head**-0.5
+         self.dim_head = dim_head
+         self.heads = heads
+         inner_dim = dim_head * heads
+
+         self.norm1 = nn.LayerNorm(dim)
+         self.norm2 = nn.LayerNorm(dim)
+
+         self.to_q = nn.Linear(dim, inner_dim, bias=False)
+         self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
+         self.to_out = nn.Linear(inner_dim, dim, bias=False)
+
+     def forward(self, x, latents):
+         """
+         Args:
+             x (torch.Tensor): image features
+                 shape (b, n1, D)
+             latents (torch.Tensor): latent features
+                 shape (b, n2, D)
+         """
+         x = self.norm1(x)
+         latents = self.norm2(latents)
+
+         b, l, _ = latents.shape
+
+         q = self.to_q(latents)
+         kv_input = torch.cat((x, latents), dim=-2)
+         k, v = self.to_kv(kv_input).chunk(2, dim=-1)
+
+         q = reshape_tensor(q, self.heads)
+         k = reshape_tensor(k, self.heads)
+         v = reshape_tensor(v, self.heads)
+
+         # attention
+         scale = 1 / math.sqrt(math.sqrt(self.dim_head))
+         weight = (q * scale) @ (k * scale).transpose(-2, -1)  # More stable with f16 than dividing afterwards
+         weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
+         out = weight @ v
+
+         out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
+
+         return self.to_out(out)
+
+
+ class Resampler(nn.Module):
+     def __init__(
+         self,
+         dim=1024,
+         depth=8,
+         dim_head=64,
+         heads=16,
+         num_queries=8,
+         embedding_dim=768,
+         output_dim=1024,
+         ff_mult=4,
+     ):
+         super().__init__()
+
+         self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
+
+         self.proj_in = nn.Linear(embedding_dim, dim)
+
+         self.proj_out = nn.Linear(dim, output_dim)
+         self.norm_out = nn.LayerNorm(output_dim)
+
+         self.layers = nn.ModuleList([])
+         for _ in range(depth):
+             self.layers.append(
+                 nn.ModuleList(
+                     [
+                         PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
+                         FeedForward(dim=dim, mult=ff_mult),
+                     ]
+                 )
+             )
+
+     def forward(self, x):
+         latents = self.latents.repeat(x.size(0), 1, 1)
+
+         x = self.proj_in(x)
+
+         for attn, ff in self.layers:
+             latents = attn(x, latents) + latents
+             latents = ff(latents) + latents
+
+         latents = self.proj_out(latents)
+         return self.norm_out(latents)
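A quick shape check of the Resampler: it compresses a variable-length token sequence from an image encoder into a fixed number of query tokens. The 768-wide input below just exercises the constructor default for embedding_dim; the encoder shipped in ./models/image_encoder may emit a different width:

```python
import torch

resampler = Resampler(dim=1024, depth=8, num_queries=8, embedding_dim=768, output_dim=1024)
tokens = torch.randn(2, 257, 768)   # e.g. patch tokens from a CLIP-like image encoder
out = resampler(tokens)
print(out.shape)                    # torch.Size([2, 8, 1024]): fixed-length image condition tokens
```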
2D_Stage/tuneavideo/models/refunet.py ADDED
@@ -0,0 +1,125 @@
+ import torch
+ from einops import rearrange
+ from typing import Any, Dict, Optional
+ from diffusers.utils.import_utils import is_xformers_available
+ from tuneavideo.models.transformer_mv2d import XFormersMVAttnProcessor, MVAttnProcessor
+
+
+ class ReferenceOnlyAttnProc(torch.nn.Module):
+     def __init__(
+         self,
+         chained_proc,
+         enabled=False,
+         name=None
+     ) -> None:
+         super().__init__()
+         self.enabled = enabled
+         self.chained_proc = chained_proc
+         self.name = name
+
+     def __call__(
+         self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None,
+         mode="w", ref_dict: dict = None, is_cfg_guidance=False, num_views=4,
+         multiview_attention=True,
+         cross_domain_attention=False,
+     ) -> Any:
+         if encoder_hidden_states is None:
+             encoder_hidden_states = hidden_states
+         if self.enabled:
+             if mode == 'w':
+                 # write pass: stash this block's reference states for the later read pass
+                 ref_dict[self.name] = encoder_hidden_states
+                 res = self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask, num_views=1,
+                                         multiview_attention=False,
+                                         cross_domain_attention=False)
+             elif mode == 'r':
+                 # read pass: concatenate the stored reference states onto every view's keys/values
+                 encoder_hidden_states = rearrange(encoder_hidden_states, '(b t) d c -> b (t d) c', t=num_views)
+                 if self.name in ref_dict:
+                     encoder_hidden_states = torch.cat([encoder_hidden_states, ref_dict.pop(self.name)], dim=1).unsqueeze(1).repeat(1, num_views, 1, 1).flatten(0, 1)
+                 res = self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask, num_views=num_views,
+                                         multiview_attention=False,
+                                         cross_domain_attention=False)
+             elif mode == 'm':
+                 encoder_hidden_states = torch.cat([encoder_hidden_states, ref_dict[self.name]], dim=1)
+                 # NOTE: assumed fall-through -- the original code never assigned `res` in this branch
+                 res = self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask)
+             elif mode == 'n':
+                 encoder_hidden_states = rearrange(encoder_hidden_states, '(b t) d c -> b (t d) c', t=num_views)
+                 encoder_hidden_states = encoder_hidden_states.unsqueeze(1).repeat(1, num_views, 1, 1).flatten(0, 1)
+                 res = self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask, num_views=num_views,
+                                         multiview_attention=False,
+                                         cross_domain_attention=False)
+             else:
+                 raise ValueError(f"unknown reference attention mode: {mode}")
+         else:
+             res = self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask)
+         return res
+
+
+ class RefOnlyNoisedUNet(torch.nn.Module):
+     def __init__(self, unet, train_sched, val_sched) -> None:
+         super().__init__()
+         self.unet = unet
+         self.train_sched = train_sched
+         self.val_sched = val_sched
+
+         unet_lora_attn_procs = dict()
+         for name, _ in unet.attn_processors.items():
+             if is_xformers_available():
+                 default_attn_proc = XFormersMVAttnProcessor()
+             else:
+                 default_attn_proc = MVAttnProcessor()
+             # only self-attention ("attn1") processors take part in the reference-only mechanism
+             unet_lora_attn_procs[name] = ReferenceOnlyAttnProc(
+                 default_attn_proc, enabled=name.endswith("attn1.processor"), name=name)
+
+         self.unet.set_attn_processor(unet_lora_attn_procs)
+
+     def __getattr__(self, name: str):
+         try:
+             return super().__getattr__(name)
+         except AttributeError:
+             return getattr(self.unet, name)
+
+     def forward_cond(self, noisy_cond_lat, timestep, encoder_hidden_states, class_labels, ref_dict, is_cfg_guidance, **kwargs):
+         if is_cfg_guidance:
+             encoder_hidden_states = encoder_hidden_states[1:]
+             class_labels = class_labels[1:]
+         self.unet(
+             noisy_cond_lat, timestep,
+             encoder_hidden_states=encoder_hidden_states,
+             class_labels=class_labels,
+             cross_attention_kwargs=dict(mode="w", ref_dict=ref_dict),
+             **kwargs
+         )
+
+     def forward(
+         self, sample, timestep, encoder_hidden_states, class_labels=None,
+         *args, cross_attention_kwargs,
+         down_block_res_samples=None, mid_block_res_sample=None,
+         **kwargs
+     ):
+         cond_lat = cross_attention_kwargs['cond_lat']
+         is_cfg_guidance = cross_attention_kwargs.get('is_cfg_guidance', False)
+         noise = torch.randn_like(cond_lat)
+         if self.training:
+             noisy_cond_lat = self.train_sched.add_noise(cond_lat, noise, timestep)
+             noisy_cond_lat = self.train_sched.scale_model_input(noisy_cond_lat, timestep)
+         else:
+             noisy_cond_lat = self.val_sched.add_noise(cond_lat, noise, timestep.reshape(-1))
+             noisy_cond_lat = self.val_sched.scale_model_input(noisy_cond_lat, timestep.reshape(-1))
+         ref_dict = {}
+         # write pass on the noised reference latent, then read pass on the actual sample
+         self.forward_cond(
+             noisy_cond_lat, timestep,
+             encoder_hidden_states, class_labels,
+             ref_dict, is_cfg_guidance, **kwargs
+         )
+         weight_dtype = self.unet.dtype
+         return self.unet(
+             sample, timestep,
+             encoder_hidden_states, *args,
+             class_labels=class_labels,
+             cross_attention_kwargs=dict(mode="r", ref_dict=ref_dict, is_cfg_guidance=is_cfg_guidance),
+             down_block_additional_residuals=[
+                 sample.to(dtype=weight_dtype) for sample in down_block_res_samples
+             ] if down_block_res_samples is not None else None,
+             mid_block_additional_residual=(
+                 mid_block_res_sample.to(dtype=weight_dtype)
+                 if mid_block_res_sample is not None else None
+             ),
+             **kwargs
+         )
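The 'w'/'r' protocol above is the reference-only trick: a write pass runs the UNet on the noised reference latent and stashes each self-attention block's tokens in ref_dict, then the read pass concatenates them onto every view's keys/values. A toy trace of just the processor, with a stand-in chained processor (ToyProc is hypothetical, used only to expose the shapes):

```python
import torch

class ToyProc:
    # stand-in for (XFormers)MVAttnProcessor: just returns the query-side states
    def __call__(self, attn, hidden_states, encoder_hidden_states, attention_mask, **kwargs):
        return hidden_states

proc = ReferenceOnlyAttnProc(ToyProc(), enabled=True, name="blk.attn1.processor")
ref, (b, t, d, c) = {}, (2, 4, 8, 16)

proc(None, torch.randn(b, d, c), mode="w", ref_dict=ref)          # write: stores (b, d, c) tokens
out = proc(None, torch.randn(b * t, d, c), mode="r", ref_dict=ref, num_views=t)
# during the read pass each of the t views attended over t*d + d reference-augmented tokens
```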
2D_Stage/tuneavideo/models/resnet.py ADDED
@@ -0,0 +1,210 @@
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from einops import rearrange
+
+
+ class InflatedConv3d(nn.Conv2d):
+     def forward(self, x):
+         video_length = x.shape[2]
+
+         # fold the frame/view axis into the batch, run a 2D conv, then unfold
+         x = rearrange(x, "b c f h w -> (b f) c h w")
+         x = super().forward(x)
+         x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
+
+         return x
+
+
+ class Upsample3D(nn.Module):
+     def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
+         super().__init__()
+         self.channels = channels
+         self.out_channels = out_channels or channels
+         self.use_conv = use_conv
+         self.use_conv_transpose = use_conv_transpose
+         self.name = name
+
+         conv = None
+         if use_conv_transpose:
+             raise NotImplementedError
+         elif use_conv:
+             conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1)
+
+         if name == "conv":
+             self.conv = conv
+         else:
+             self.Conv2d_0 = conv
+
+     def forward(self, hidden_states, output_size=None):
+         assert hidden_states.shape[1] == self.channels
+
+         if self.use_conv_transpose:
+             raise NotImplementedError
+
+         # Cast to float32 as the 'upsample_nearest2d_out_frame' op does not support bfloat16
+         dtype = hidden_states.dtype
+         if dtype == torch.bfloat16:
+             hidden_states = hidden_states.to(torch.float32)
+
+         # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
+         if hidden_states.shape[0] >= 64:
+             hidden_states = hidden_states.contiguous()
+
+         # if `output_size` is passed we force the interpolation output
+         # size and do not make use of `scale_factor=2`
+         if output_size is None:
+             hidden_states = F.interpolate(hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest")
+         else:
+             hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")
+
+         # If the input was bfloat16, cast back to bfloat16
+         if dtype == torch.bfloat16:
+             hidden_states = hidden_states.to(dtype)
+
+         if self.use_conv:
+             if self.name == "conv":
+                 hidden_states = self.conv(hidden_states)
+             else:
+                 hidden_states = self.Conv2d_0(hidden_states)
+
+         return hidden_states
+
+
+ class Downsample3D(nn.Module):
+     def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
+         super().__init__()
+         self.channels = channels
+         self.out_channels = out_channels or channels
+         self.use_conv = use_conv
+         self.padding = padding
+         stride = 2
+         self.name = name
+
+         if use_conv:
+             conv = InflatedConv3d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
+         else:
+             raise NotImplementedError
+
+         if name == "conv":
+             self.Conv2d_0 = conv
+             self.conv = conv
+         else:
+             self.conv = conv
+
+     def forward(self, hidden_states):
+         assert hidden_states.shape[1] == self.channels
+         if self.use_conv and self.padding == 0:
+             raise NotImplementedError
+
+         hidden_states = self.conv(hidden_states)
+
+         return hidden_states
+
+
+ class ResnetBlock3D(nn.Module):
+     def __init__(
+         self,
+         *,
+         in_channels,
+         out_channels=None,
+         conv_shortcut=False,
+         dropout=0.0,
+         temb_channels=512,
+         groups=32,
+         groups_out=None,
+         pre_norm=True,
+         eps=1e-6,
+         non_linearity="swish",
+         time_embedding_norm="default",
+         output_scale_factor=1.0,
+         use_in_shortcut=None,
+     ):
+         super().__init__()
+         self.pre_norm = True  # this implementation always normalizes before the convs
+         self.in_channels = in_channels
+         out_channels = in_channels if out_channels is None else out_channels
+         self.out_channels = out_channels
+         self.use_conv_shortcut = conv_shortcut
+         self.time_embedding_norm = time_embedding_norm
+         self.output_scale_factor = output_scale_factor
+
+         if groups_out is None:
+             groups_out = groups
+
+         self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
+
+         self.conv1 = InflatedConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+
+         if temb_channels is not None:
+             if self.time_embedding_norm == "default":
+                 time_emb_proj_out_channels = out_channels
+             elif self.time_embedding_norm == "scale_shift":
+                 time_emb_proj_out_channels = out_channels * 2
+             else:
+                 raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm}")
+
+             self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels)
+         else:
+             self.time_emb_proj = None
+
+         self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
+         self.dropout = torch.nn.Dropout(dropout)
+         self.conv2 = InflatedConv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
+
+         if non_linearity == "swish":
+             self.nonlinearity = lambda x: F.silu(x)
+         elif non_linearity == "mish":
+             self.nonlinearity = Mish()
+         elif non_linearity == "silu":
+             self.nonlinearity = nn.SiLU()
+
+         self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut
+
+         self.conv_shortcut = None
+         if self.use_in_shortcut:
+             self.conv_shortcut = InflatedConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
+
+     def forward(self, input_tensor, temb):
+         hidden_states = input_tensor
+
+         hidden_states = self.norm1(hidden_states)
+         hidden_states = self.nonlinearity(hidden_states)
+
+         hidden_states = self.conv1(hidden_states)
+
+         if temb is not None:
+             # broadcast the time embedding over the frame and spatial axes
+             temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, :, None, None].permute(0, 2, 1, 3, 4)
+
+         if temb is not None and self.time_embedding_norm == "default":
+             hidden_states = hidden_states + temb
+
+         hidden_states = self.norm2(hidden_states)
+
+         if temb is not None and self.time_embedding_norm == "scale_shift":
+             scale, shift = torch.chunk(temb, 2, dim=1)
+             hidden_states = hidden_states * (1 + scale) + shift
+
+         hidden_states = self.nonlinearity(hidden_states)
+
+         hidden_states = self.dropout(hidden_states)
+         hidden_states = self.conv2(hidden_states)
+
+         if self.conv_shortcut is not None:
+             input_tensor = self.conv_shortcut(input_tensor)
+
+         output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
+
+         return output_tensor
+
+
+ class Mish(torch.nn.Module):
+     def forward(self, hidden_states):
+         return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states))
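InflatedConv3d is the standard pseudo-3D trick: a plain 2D convolution applied independently to every frame/view by folding the frame axis into the batch. A quick check:

```python
import torch

conv = InflatedConv3d(4, 8, kernel_size=3, padding=1)
x = torch.randn(2, 4, 5, 32, 32)   # (batch, channels, frames/views, height, width)
y = conv(x)                        # the same 2D kernel is applied to all 5 frames
print(y.shape)                     # torch.Size([2, 8, 5, 32, 32])
```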
2D_Stage/tuneavideo/models/transformer_mv2d.py ADDED
@@ -0,0 +1,1010 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from dataclasses import dataclass
15
+ from typing import Any, Dict, Optional
16
+
17
+ import torch
18
+ import torch.nn.functional as F
19
+ from torch import nn
20
+
21
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
22
+ from diffusers.models.embeddings import ImagePositionalEmbeddings
23
+ from diffusers.utils import BaseOutput, deprecate
24
+ try:
25
+ from diffusers.utils import maybe_allow_in_graph
26
+ except:
27
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
28
+ from diffusers.models.attention import FeedForward, AdaLayerNorm, AdaLayerNormZero, Attention
29
+ from diffusers.models.embeddings import PatchEmbed
30
+ from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
31
+ from diffusers.models.modeling_utils import ModelMixin
32
+ from diffusers.utils.import_utils import is_xformers_available
33
+
34
+ from einops import rearrange
35
+ import pdb
36
+ import random
37
+
38
+
39
+ if is_xformers_available():
40
+ import xformers
41
+ import xformers.ops
42
+ else:
43
+ xformers = None
44
+
45
+
46
+ @dataclass
47
+ class TransformerMV2DModelOutput(BaseOutput):
48
+ """
49
+ The output of [`Transformer2DModel`].
50
+
51
+ Args:
52
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
53
+ The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
54
+ distributions for the unnoised latent pixels.
55
+ """
56
+
57
+ sample: torch.FloatTensor
58
+
59
+
60
+ class TransformerMV2DModel(ModelMixin, ConfigMixin):
61
+ """
62
+ A 2D Transformer model for image-like data.
63
+
64
+ Parameters:
65
+ num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
66
+ attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
67
+ in_channels (`int`, *optional*):
68
+ The number of channels in the input and output (specify if the input is **continuous**).
69
+ num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
70
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
71
+ cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
72
+ sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
73
+ This is fixed during training since it is used to learn a number of position embeddings.
74
+ num_vector_embeds (`int`, *optional*):
75
+ The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).
76
+ Includes the class for the masked latent pixel.
77
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
78
+ num_embeds_ada_norm ( `int`, *optional*):
79
+ The number of diffusion steps used during training. Pass if at least one of the norm_layers is
80
+ `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
81
+ added to the hidden states.
82
+
83
+ During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`.
84
+ attention_bias (`bool`, *optional*):
85
+ Configure if the `TransformerBlocks` attention should contain a bias parameter.
86
+ """
87
+
88
+ @register_to_config
89
+ def __init__(
90
+ self,
91
+ num_attention_heads: int = 16,
92
+ attention_head_dim: int = 88,
93
+ in_channels: Optional[int] = None,
94
+ out_channels: Optional[int] = None,
95
+ num_layers: int = 1,
96
+ dropout: float = 0.0,
97
+ norm_num_groups: int = 32,
98
+ cross_attention_dim: Optional[int] = None,
99
+ attention_bias: bool = False,
100
+ sample_size: Optional[int] = None,
101
+ num_vector_embeds: Optional[int] = None,
102
+ patch_size: Optional[int] = None,
103
+ activation_fn: str = "geglu",
104
+ num_embeds_ada_norm: Optional[int] = None,
105
+ use_linear_projection: bool = False,
106
+ only_cross_attention: bool = False,
107
+ upcast_attention: bool = False,
108
+ norm_type: str = "layer_norm",
109
+ norm_elementwise_affine: bool = True,
110
+ num_views: int = 1,
111
+ joint_attention: bool=False,
112
+ joint_attention_twice: bool=False,
113
+ multiview_attention: bool=True,
114
+ cross_domain_attention: bool=False
115
+ ):
116
+ super().__init__()
117
+ self.use_linear_projection = use_linear_projection
118
+ self.num_attention_heads = num_attention_heads
119
+ self.attention_head_dim = attention_head_dim
120
+ inner_dim = num_attention_heads * attention_head_dim
121
+
122
+ # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
123
+ # Define whether input is continuous or discrete depending on configuration
124
+ self.is_input_continuous = (in_channels is not None) and (patch_size is None)
125
+ self.is_input_vectorized = num_vector_embeds is not None
126
+ self.is_input_patches = in_channels is not None and patch_size is not None
127
+
128
+ if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
129
+ deprecation_message = (
130
+ f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
131
+ " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config."
132
+ " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect"
133
+ " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
134
+ " would be very nice if you could open a Pull request for the `transformer/config.json` file"
135
+ )
136
+ deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False)
137
+ norm_type = "ada_norm"
138
+
139
+ if self.is_input_continuous and self.is_input_vectorized:
140
+ raise ValueError(
141
+ f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
142
+ " sure that either `in_channels` or `num_vector_embeds` is None."
143
+ )
144
+ elif self.is_input_vectorized and self.is_input_patches:
145
+ raise ValueError(
146
+ f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make"
147
+ " sure that either `num_vector_embeds` or `num_patches` is None."
148
+ )
149
+ elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches:
150
+ raise ValueError(
151
+ f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:"
152
+ f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
153
+ )
154
+
155
+ # 2. Define input layers
156
+ if self.is_input_continuous:
157
+ self.in_channels = in_channels
158
+
159
+ self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
160
+ if use_linear_projection:
161
+ self.proj_in = LoRACompatibleLinear(in_channels, inner_dim)
162
+ else:
163
+ self.proj_in = LoRACompatibleConv(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
164
+ elif self.is_input_vectorized:
165
+ assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
166
+ assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"
167
+
168
+ self.height = sample_size
169
+ self.width = sample_size
170
+ self.num_vector_embeds = num_vector_embeds
171
+ self.num_latent_pixels = self.height * self.width
172
+
173
+ self.latent_image_embedding = ImagePositionalEmbeddings(
174
+ num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
175
+ )
176
+ elif self.is_input_patches:
177
+ assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size"
178
+
179
+ self.height = sample_size
180
+ self.width = sample_size
181
+
182
+ self.patch_size = patch_size
183
+ self.pos_embed = PatchEmbed(
184
+ height=sample_size,
185
+ width=sample_size,
186
+ patch_size=patch_size,
187
+ in_channels=in_channels,
188
+ embed_dim=inner_dim,
189
+ )
190
+
191
+ # 3. Define transformers blocks
192
+ self.transformer_blocks = nn.ModuleList(
193
+ [
194
+ BasicMVTransformerBlock(
195
+ inner_dim,
196
+ num_attention_heads,
197
+ attention_head_dim,
198
+ dropout=dropout,
199
+ cross_attention_dim=cross_attention_dim,
200
+ activation_fn=activation_fn,
201
+ num_embeds_ada_norm=num_embeds_ada_norm,
202
+ attention_bias=attention_bias,
203
+ only_cross_attention=only_cross_attention,
204
+ upcast_attention=upcast_attention,
205
+ norm_type=norm_type,
206
+ norm_elementwise_affine=norm_elementwise_affine,
207
+ num_views=num_views,
208
+ joint_attention=joint_attention,
209
+ joint_attention_twice=joint_attention_twice,
210
+ multiview_attention=multiview_attention,
211
+ cross_domain_attention=cross_domain_attention
212
+ )
213
+ for d in range(num_layers)
214
+ ]
215
+ )
216
+
217
+ # 4. Define output layers
218
+ self.out_channels = in_channels if out_channels is None else out_channels
219
+ if self.is_input_continuous:
220
+ # TODO: should use out_channels for continuous projections
221
+ if use_linear_projection:
222
+ self.proj_out = LoRACompatibleLinear(inner_dim, in_channels)
223
+ else:
224
+ self.proj_out = LoRACompatibleConv(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
225
+ elif self.is_input_vectorized:
226
+ self.norm_out = nn.LayerNorm(inner_dim)
227
+ self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
228
+ elif self.is_input_patches:
229
+ self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
230
+ self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
231
+ self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
232
+
233
+ def forward(
234
+ self,
235
+ hidden_states: torch.Tensor,
236
+ encoder_hidden_states: Optional[torch.Tensor] = None,
237
+ timestep: Optional[torch.LongTensor] = None,
238
+ class_labels: Optional[torch.LongTensor] = None,
239
+ cross_attention_kwargs: Dict[str, Any] = None,
240
+ attention_mask: Optional[torch.Tensor] = None,
241
+ encoder_attention_mask: Optional[torch.Tensor] = None,
242
+ return_dict: bool = True,
243
+ ):
244
+ """
245
+ The [`Transformer2DModel`] forward method.
246
+
247
+ Args:
248
+ hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
249
+ Input `hidden_states`.
250
+ encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
251
+ Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
252
+ self-attention.
253
+ timestep ( `torch.LongTensor`, *optional*):
254
+ Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
255
+ class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
256
+ Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
257
+ `AdaLayerZeroNorm`.
258
+ encoder_attention_mask ( `torch.Tensor`, *optional*):
259
+ Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
260
+
261
+ * Mask `(batch, sequence_length)` True = keep, False = discard.
262
+ * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
263
+
264
+ If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
265
+ above. This bias will be added to the cross-attention scores.
266
+ return_dict (`bool`, *optional*, defaults to `True`):
267
+ Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
268
+ tuple.
269
+
270
+ Returns:
271
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
272
+ `tuple` where the first element is the sample tensor.
273
+ """
274
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
275
+ # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
276
+ # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
277
+ # expects mask of shape:
278
+ # [batch, key_tokens]
279
+ # adds singleton query_tokens dimension:
280
+ # [batch, 1, key_tokens]
281
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
282
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
283
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
284
+ if attention_mask is not None and attention_mask.ndim == 2:
285
+ # assume that mask is expressed as:
286
+ # (1 = keep, 0 = discard)
287
+ # convert mask into a bias that can be added to attention scores:
288
+ # (keep = +0, discard = -10000.0)
289
+ attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
290
+ attention_mask = attention_mask.unsqueeze(1)
291
+
292
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
293
+ if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
294
+ encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
295
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
296
+
297
+ # 1. Input
298
+ if self.is_input_continuous:
299
+ batch, _, height, width = hidden_states.shape
300
+ residual = hidden_states
301
+
302
+ hidden_states = self.norm(hidden_states)
303
+ if not self.use_linear_projection:
304
+ hidden_states = self.proj_in(hidden_states)
305
+ inner_dim = hidden_states.shape[1]
306
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
307
+ else:
308
+ inner_dim = hidden_states.shape[1]
309
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
310
+ hidden_states = self.proj_in(hidden_states)
311
+ elif self.is_input_vectorized:
312
+ hidden_states = self.latent_image_embedding(hidden_states)
313
+ elif self.is_input_patches:
314
+ hidden_states = self.pos_embed(hidden_states)
315
+
316
+ # 2. Blocks
317
+ for block in self.transformer_blocks:
318
+ hidden_states = block(
319
+ hidden_states,
320
+ attention_mask=attention_mask,
321
+ encoder_hidden_states=encoder_hidden_states,
322
+ encoder_attention_mask=encoder_attention_mask,
323
+ timestep=timestep,
324
+ cross_attention_kwargs=cross_attention_kwargs,
325
+ class_labels=class_labels,
326
+ )
327
+
328
+ # 3. Output
329
+ if self.is_input_continuous:
330
+ if not self.use_linear_projection:
331
+ hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
332
+ hidden_states = self.proj_out(hidden_states)
333
+ else:
334
+ hidden_states = self.proj_out(hidden_states)
335
+ hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
336
+
337
+ output = hidden_states + residual
338
+ elif self.is_input_vectorized:
339
+ hidden_states = self.norm_out(hidden_states)
340
+ logits = self.out(hidden_states)
341
+ # (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
342
+ logits = logits.permute(0, 2, 1)
343
+
344
+ # log(p(x_0))
345
+ output = F.log_softmax(logits.double(), dim=1).float()
346
+ elif self.is_input_patches:
347
+ # TODO: cleanup!
348
+ conditioning = self.transformer_blocks[0].norm1.emb(
349
+ timestep, class_labels, hidden_dtype=hidden_states.dtype
350
+ )
351
+ shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
352
+ hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
353
+ hidden_states = self.proj_out_2(hidden_states)
354
+
355
+ # unpatchify
356
+ height = width = int(hidden_states.shape[1] ** 0.5)
357
+ hidden_states = hidden_states.reshape(
358
+ shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
359
+ )
360
+ hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
361
+ output = hidden_states.reshape(
362
+ shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
363
+ )
364
+
365
+ if not return_dict:
366
+ return (output,)
367
+
368
+ return TransformerMV2DModelOutput(sample=output)
369
+
370
+
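+ # --- Illustrative sketch (not part of the original model code): the 2-D mask ->
+ # additive-bias conversion performed at the top of `forward`, shown standalone.
+ # Shapes and the -10000.0 fill value follow the comments in the method.
+ def _demo_mask_to_bias():
+     import torch
+     mask = torch.tensor([[1, 1, 0, 0]])      # [batch, key_tokens], 1 = keep, 0 = discard
+     bias = (1 - mask.float()) * -10000.0     # keep -> 0.0, discard -> -10000.0
+     bias = bias.unsqueeze(1)                 # [batch, 1, key_tokens]
+     scores = torch.zeros(1, 2, 3, 4)         # [batch, heads, query_tokens, key_tokens]
+     masked = scores + bias.unsqueeze(1)      # singleton dims broadcast over heads/queries
+     assert torch.allclose(masked[..., 2:], torch.full_like(masked[..., 2:], -10000.0))
+     return masked
+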
371
+ @maybe_allow_in_graph
372
+ class BasicMVTransformerBlock(nn.Module):
373
+ r"""
374
+ A basic Transformer block.
375
+
376
+ Parameters:
377
+ dim (`int`): The number of channels in the input and output.
378
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
379
+ attention_head_dim (`int`): The number of channels in each head.
380
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
381
+ cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
382
+ only_cross_attention (`bool`, *optional*):
383
+ Whether to use only cross-attention layers. In this case two cross attention layers are used.
384
+ double_self_attention (`bool`, *optional*):
385
+ Whether to use two self-attention layers. In this case no cross attention layers are used.
386
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
387
+ num_embeds_ada_norm (`int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
+ attention_bias (`bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
391
+ """
392
+
393
+ def __init__(
394
+ self,
395
+ dim: int,
396
+ num_attention_heads: int,
397
+ attention_head_dim: int,
398
+ dropout=0.0,
399
+ cross_attention_dim: Optional[int] = None,
400
+ activation_fn: str = "geglu",
401
+ num_embeds_ada_norm: Optional[int] = None,
402
+ attention_bias: bool = False,
403
+ only_cross_attention: bool = False,
404
+ double_self_attention: bool = False,
405
+ upcast_attention: bool = False,
406
+ norm_elementwise_affine: bool = True,
407
+ norm_type: str = "layer_norm",
408
+ final_dropout: bool = False,
409
+ num_views: int = 1,
410
+ joint_attention: bool = False,
411
+ joint_attention_twice: bool = False,
412
+ multiview_attention: bool = True,
413
+ cross_domain_attention: bool = False
414
+ ):
415
+ super().__init__()
416
+ self.only_cross_attention = only_cross_attention
417
+
418
+ self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
419
+ self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
420
+
421
+ if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
422
+ raise ValueError(
423
+ f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
424
+ f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
425
+ )
426
+
427
+ # Define 3 blocks. Each block has its own normalization layer.
428
+ # 1. Self-Attn
429
+ if self.use_ada_layer_norm:
430
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
431
+ elif self.use_ada_layer_norm_zero:
432
+ self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
433
+ else:
434
+ self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
435
+
436
+ self.multiview_attention = multiview_attention
437
+ self.cross_domain_attention = cross_domain_attention
438
+ self.attn1 = CustomAttention(
440
+ query_dim=dim,
441
+ heads=num_attention_heads,
442
+ dim_head=attention_head_dim,
443
+ dropout=dropout,
444
+ bias=attention_bias,
445
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
446
+ upcast_attention=upcast_attention,
447
+ processor=MVAttnProcessor()
448
+ )
449
+
450
+ # 2. Cross-Attn
451
+ if cross_attention_dim is not None or double_self_attention:
452
+ # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
453
+ # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
454
+ # the second cross attention block.
455
+ self.norm2 = (
456
+ AdaLayerNorm(dim, num_embeds_ada_norm)
457
+ if self.use_ada_layer_norm
458
+ else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
459
+ )
460
+ self.attn2 = Attention(
461
+ query_dim=dim,
462
+ cross_attention_dim=cross_attention_dim if not double_self_attention else None,
463
+ heads=num_attention_heads,
464
+ dim_head=attention_head_dim,
465
+ dropout=dropout,
466
+ bias=attention_bias,
467
+ upcast_attention=upcast_attention,
468
+ ) # is self-attn if encoder_hidden_states is none
469
+ else:
470
+ self.norm2 = None
471
+ self.attn2 = None
472
+
473
+ # 3. Feed-forward
474
+ self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
475
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
476
+
477
+ # let chunk size default to None
478
+ self._chunk_size = None
479
+ self._chunk_dim = 0
480
+
481
+ self.num_views = num_views
482
+
483
+ self.joint_attention = joint_attention
484
+
485
+ if self.joint_attention:
486
+ # Joint (cross-task) attention
487
+ self.attn_joint = CustomJointAttention(
488
+ query_dim=dim,
489
+ heads=num_attention_heads,
490
+ dim_head=attention_head_dim,
491
+ dropout=dropout,
492
+ bias=attention_bias,
493
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
494
+ upcast_attention=upcast_attention,
495
+ processor=JointAttnProcessor()
496
+ )
497
+ nn.init.zeros_(self.attn_joint.to_out[0].weight.data)
498
+ self.norm_joint = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
499
+
500
+
501
+ self.joint_attention_twice = joint_attention_twice
502
+
503
+ if self.joint_attention_twice:
504
+ # second joint-attention block enabled (applied right after self-attention in forward)
505
+ # Joint (cross-task) attention
506
+ self.attn_joint_twice = CustomJointAttention(
507
+ query_dim=dim,
508
+ heads=num_attention_heads,
509
+ dim_head=attention_head_dim,
510
+ dropout=dropout,
511
+ bias=attention_bias,
512
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
513
+ upcast_attention=upcast_attention,
514
+ processor=JointAttnProcessor()
515
+ )
516
+ nn.init.zeros_(self.attn_joint_twice.to_out[0].weight.data)
517
+ self.norm_joint_twice = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
518
+
519
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
520
+ # Sets chunk feed-forward
521
+ self._chunk_size = chunk_size
522
+ self._chunk_dim = dim
523
+
524
+ def forward(
525
+ self,
526
+ hidden_states: torch.FloatTensor,
527
+ attention_mask: Optional[torch.FloatTensor] = None,
528
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
529
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
530
+ timestep: Optional[torch.LongTensor] = None,
531
+ cross_attention_kwargs: Dict[str, Any] = None,
532
+ class_labels: Optional[torch.LongTensor] = None,
533
+ ):
534
+ assert attention_mask is None # not supported yet
535
+ # Notice that normalization is always applied before the real computation in the following blocks.
536
+ # 1. Self-Attention
537
+ if self.use_ada_layer_norm:
538
+ norm_hidden_states = self.norm1(hidden_states, timestep)
539
+ elif self.use_ada_layer_norm_zero:
540
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
541
+ hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
542
+ )
543
+ else:
544
+ norm_hidden_states = self.norm1(hidden_states)
545
+
546
+ cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
547
+ attn_output = self.attn1(
548
+ norm_hidden_states,
549
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
550
+ attention_mask=attention_mask,
551
+ num_views=self.num_views,
552
+ multiview_attention=self.multiview_attention,
553
+ cross_domain_attention=self.cross_domain_attention,
554
+ **cross_attention_kwargs,
555
+ )
556
+
557
+
558
+ if self.use_ada_layer_norm_zero:
559
+ attn_output = gate_msa.unsqueeze(1) * attn_output
560
+ hidden_states = attn_output + hidden_states
561
+
562
+ # joint attention twice
563
+ if self.joint_attention_twice:
564
+ norm_hidden_states = (
565
+ self.norm_joint_twice(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_joint_twice(hidden_states)
566
+ )
567
+ hidden_states = self.attn_joint_twice(norm_hidden_states) + hidden_states
568
+
569
+ # 2. Cross-Attention
570
+ if self.attn2 is not None:
571
+ norm_hidden_states = (
572
+ self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
573
+ )
574
+ attn_output = self.attn2(
575
+ norm_hidden_states,
576
+ encoder_hidden_states=encoder_hidden_states,
577
+ attention_mask=encoder_attention_mask,
578
+ **cross_attention_kwargs,
579
+ )
580
+ hidden_states = attn_output + hidden_states
581
+
582
+ # 3. Feed-forward
583
+ norm_hidden_states = self.norm3(hidden_states)
584
+
585
+ if self.use_ada_layer_norm_zero:
586
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
587
+
588
+ if self._chunk_size is not None:
589
+ # "feed_forward_chunk_size" can be used to save memory
590
+ if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
591
+ raise ValueError(
592
+ f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
593
+ )
594
+
595
+ num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
596
+ ff_output = torch.cat(
597
+ [self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)],
598
+ dim=self._chunk_dim,
599
+ )
600
+ else:
601
+ ff_output = self.ff(norm_hidden_states)
602
+
603
+ if self.use_ada_layer_norm_zero:
604
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
605
+
606
+ hidden_states = ff_output + hidden_states
607
+
608
+ if self.joint_attention:
609
+ norm_hidden_states = (
610
+ self.norm_joint(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_joint(hidden_states)
611
+ )
612
+ hidden_states = self.attn_joint(norm_hidden_states) + hidden_states
613
+
614
+ return hidden_states
615
+
616
+
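+ # --- Illustrative sketch (standalone, torch only): what the `_chunk_size` path in
+ # `forward` computes — the feed-forward is run over slices of one dimension and the
+ # results concatenated, trading a little speed for lower peak memory.
+ def _demo_chunked_feed_forward():
+     import torch
+     ff = torch.nn.Linear(8, 8)     # stand-in for self.ff
+     x = torch.randn(2, 6, 8)       # [batch, tokens, dim]
+     chunk_dim, chunk_size = 1, 2   # the token count must divide evenly, as checked above
+     num_chunks = x.shape[chunk_dim] // chunk_size
+     out = torch.cat([ff(s) for s in x.chunk(num_chunks, dim=chunk_dim)], dim=chunk_dim)
+     assert torch.allclose(out, ff(x), atol=1e-6)  # identical to the unchunked result
+     return out
+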
617
+ class CustomAttention(Attention):
618
+ def set_use_memory_efficient_attention_xformers(
619
+ self, use_memory_efficient_attention_xformers: bool, *args, **kwargs
620
+ ):
621
+ processor = XFormersMVAttnProcessor()
622
+ self.set_processor(processor)
623
+ # print("using xformers attention processor")
624
+
625
+
626
+ class CustomJointAttention(Attention):
627
+ def set_use_memory_efficient_attention_xformers(
628
+ self, use_memory_efficient_attention_xformers: bool, *args, **kwargs
629
+ ):
630
+ processor = XFormersJointAttnProcessor()
631
+ self.set_processor(processor)
632
+ # print("using xformers attention processor")
633
+
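+ # --- Usage note (sketch, not original code): the two subclasses above exist only to
+ # pin a multi-view / joint processor when xformers is enabled, instead of diffusers'
+ # stock xformers processor. Roughly:
+ #
+ #     attn = CustomAttention(query_dim=320, heads=8, dim_head=40,
+ #                            processor=MVAttnProcessor())
+ #     attn.set_use_memory_efficient_attention_xformers(True)  # -> XFormersMVAttnProcessor
+ #
+ # The boolean flag is deliberately ignored: the override installs the xformers
+ # processor unconditionally.
+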
634
+ class MVAttnProcessor:
635
+ r"""
636
+ Processor that performs multi-view self-attention across the views of each sample.
637
+ """
638
+
639
+ def __call__(
640
+ self,
641
+ attn: Attention,
642
+ hidden_states,
643
+ encoder_hidden_states=None,
644
+ attention_mask=None,
645
+ temb=None,
646
+ num_views=1,
647
+ multiview_attention=True
648
+ ):
649
+ residual = hidden_states
650
+
651
+ if attn.spatial_norm is not None:
652
+ hidden_states = attn.spatial_norm(hidden_states, temb)
653
+
654
+ input_ndim = hidden_states.ndim
655
+
656
+ if input_ndim == 4:
657
+ batch_size, channel, height, width = hidden_states.shape
658
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
659
+
660
+ batch_size, sequence_length, _ = (
661
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
662
+ )
663
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
664
+
665
+ if attn.group_norm is not None:
666
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
667
+
668
+ query = attn.to_q(hidden_states)
669
+
670
+ if encoder_hidden_states is None:
671
+ encoder_hidden_states = hidden_states
672
+ elif attn.norm_cross:
673
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
674
+
675
+ key = attn.to_k(encoder_hidden_states)
676
+ value = attn.to_v(encoder_hidden_states)
677
+
678
+ # query/key/value each have shape [b * num_views, tokens, dim], e.g. torch.Size([b*4, 1024, 320])
681
+ # multi-view self-attention
682
+ if multiview_attention:
683
+ if num_views <= 6:
684
+ # with xformers enabled it is possible to train with up to 6 views this way
685
+ # key = rearrange(key, "(b t) d c -> b (t d) c", t=num_views).repeat_interleave(num_views, dim=0)
686
+ # value = rearrange(value, "(b t) d c -> b (t d) c", t=num_views).repeat_interleave(num_views, dim=0)
687
+ key = rearrange(key, '(b t) d c-> b (t d) c', t=num_views).unsqueeze(1).repeat(1,num_views,1,1).flatten(0,1)
688
+ value = rearrange(value, '(b t) d c-> b (t d) c', t=num_views).unsqueeze(1).repeat(1,num_views,1,1).flatten(0,1)
689
+
690
+ else: # num_views > 6
+ # A sparse multi-view attention was prototyped here (each view attending to a
+ # fixed subset of views, e.g. view 1 -> views [0, 1, 3, 5]), but it is disabled:
+ # the sparse sampling caused training problems, so key/value pass through
+ # unchanged.
+ pass
719
+
720
+
721
+ query = attn.head_to_batch_dim(query).contiguous()
722
+ key = attn.head_to_batch_dim(key).contiguous()
723
+ value = attn.head_to_batch_dim(value).contiguous()
724
+
725
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
726
+ hidden_states = torch.bmm(attention_probs, value)
727
+ hidden_states = attn.batch_to_head_dim(hidden_states)
728
+
729
+ # linear proj
730
+ hidden_states = attn.to_out[0](hidden_states)
731
+ # dropout
732
+ hidden_states = attn.to_out[1](hidden_states)
733
+
734
+ if input_ndim == 4:
735
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
736
+
737
+ if attn.residual_connection:
738
+ hidden_states = hidden_states + residual
739
+
740
+ hidden_states = hidden_states / attn.rescale_output_factor
741
+
742
+ return hidden_states
743
+
744
+
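+ # --- Illustrative sketch (assumes only torch + einops, both imported by this module):
+ # the multi-view key/value expansion used above. Tokens from all `t` views of a sample
+ # are concatenated along the token axis and replicated once per view, so every view
+ # attends to every view's tokens.
+ def _demo_multiview_kv(b=2, t=4, d=3, c=5):
+     import torch
+     from einops import rearrange
+     key = torch.randn(b * t, d, c)                          # (b t) d c
+     merged = rearrange(key, '(b t) d c -> b (t d) c', t=t)  # b (t d) c
+     expanded = merged.unsqueeze(1).repeat(1, t, 1, 1).flatten(0, 1)  # (b t) (t d) c
+     # equivalent to the repeat_interleave form shown in the comments above
+     assert torch.equal(expanded, merged.repeat_interleave(t, dim=0))
+     return expanded
+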
745
+ class XFormersMVAttnProcessor:
746
+ r"""
747
+ xFormers-based processor for multi-view self-attention.
748
+ """
749
+
750
+ def __call__(
751
+ self,
752
+ attn: Attention,
753
+ hidden_states,
754
+ encoder_hidden_states=None,
755
+ attention_mask=None,
756
+ temb=None,
757
+ num_views=1,
758
+ multiview_attention=True,
759
+ cross_domain_attention=False,
760
+ ):
761
+ residual = hidden_states
762
+
763
+ if attn.spatial_norm is not None:
764
+ hidden_states = attn.spatial_norm(hidden_states, temb)
765
+
766
+ input_ndim = hidden_states.ndim
767
+
768
+ if input_ndim == 4:
769
+ batch_size, channel, height, width = hidden_states.shape
770
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
771
+
772
+ batch_size, sequence_length, _ = (
773
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
774
+ )
775
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
776
+
777
+ # attention_mask is typically None in this pipeline; handle the non-None case anyway
778
+ if attention_mask is not None:
779
+ # expand our mask's singleton query_tokens dimension:
780
+ # [batch*heads, 1, key_tokens] ->
781
+ # [batch*heads, query_tokens, key_tokens]
782
+ # so that it can be added as a bias onto the attention scores that xformers computes:
783
+ # [batch*heads, query_tokens, key_tokens]
784
+ # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
785
+ _, query_tokens, _ = hidden_states.shape
786
+ attention_mask = attention_mask.expand(-1, query_tokens, -1)
787
+
788
+ if attn.group_norm is not None:
789
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
790
+
791
+ query = attn.to_q(hidden_states)
792
+
793
+ if encoder_hidden_states is None:
794
+ encoder_hidden_states = hidden_states
795
+ elif attn.norm_cross:
796
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
797
+
798
+ key_raw = attn.to_k(encoder_hidden_states)
799
+ value_raw = attn.to_v(encoder_hidden_states)
800
+
801
+ # query/key/value each have shape [b * num_views, tokens, dim], e.g. torch.Size([b*4, 1024, 320])
804
+ # multi-view self-attention
805
+ if multiview_attention:
806
+ key = rearrange(key_raw, "(b t) d c -> b (t d) c", t=num_views).repeat_interleave(num_views, dim=0)
807
+ value = rearrange(value_raw, "(b t) d c -> b (t d) c", t=num_views).repeat_interleave(num_views, dim=0)
808
+
809
+ if cross_domain_attention:
810
+ # memory efficient, cross domain attention
811
+ key_0, key_1 = torch.chunk(key_raw, dim=0, chunks=2) # keys shape (b t) d c
812
+ value_0, value_1 = torch.chunk(value_raw, dim=0, chunks=2)
813
+ key_cross = torch.concat([key_1, key_0], dim=0)
814
+ value_cross = torch.concat([value_1, value_0], dim=0) # shape (b t) d c
815
+ key = torch.cat([key, key_cross], dim=1)
816
+ value = torch.cat([value, value_cross], dim=1) # shape (b t) (t+1 d) c
817
+ else:
818
+ # multi-view attention disabled: plain per-view self-attention
819
+ key = key_raw
820
+ value = value_raw
821
+
822
+ query = attn.head_to_batch_dim(query)
823
+ key = attn.head_to_batch_dim(key)
824
+ value = attn.head_to_batch_dim(value)
825
+
826
+ hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
827
+ hidden_states = attn.batch_to_head_dim(hidden_states)
828
+
829
+ # linear proj
830
+ hidden_states = attn.to_out[0](hidden_states)
831
+ # dropout
832
+ hidden_states = attn.to_out[1](hidden_states)
833
+
834
+ if input_ndim == 4:
835
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
836
+
837
+ if attn.residual_connection:
838
+ hidden_states = hidden_states + residual
839
+
840
+ hidden_states = hidden_states / attn.rescale_output_factor
841
+
842
+ return hidden_states
843
+
844
+
845
+
846
+ class XFormersJointAttnProcessor:
847
+ r"""
848
+ xFormers-based processor for joint cross-domain attention.
849
+ """
850
+
851
+ def __call__(
852
+ self,
853
+ attn: Attention,
854
+ hidden_states,
855
+ encoder_hidden_states=None,
856
+ attention_mask=None,
857
+ temb=None,
858
+ num_tasks=2
859
+ ):
860
+
861
+ residual = hidden_states
862
+
863
+ if attn.spatial_norm is not None:
864
+ hidden_states = attn.spatial_norm(hidden_states, temb)
865
+
866
+ input_ndim = hidden_states.ndim
867
+
868
+ if input_ndim == 4:
869
+ batch_size, channel, height, width = hidden_states.shape
870
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
871
+
872
+ batch_size, sequence_length, _ = (
873
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
874
+ )
875
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
876
+
877
+ # attention_mask is typically None in this pipeline; handle the non-None case anyway
878
+ if attention_mask is not None:
879
+ # expand our mask's singleton query_tokens dimension:
880
+ # [batch*heads, 1, key_tokens] ->
881
+ # [batch*heads, query_tokens, key_tokens]
882
+ # so that it can be added as a bias onto the attention scores that xformers computes:
883
+ # [batch*heads, query_tokens, key_tokens]
884
+ # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
885
+ _, query_tokens, _ = hidden_states.shape
886
+ attention_mask = attention_mask.expand(-1, query_tokens, -1)
887
+
888
+ if attn.group_norm is not None:
889
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
890
+
891
+ query = attn.to_q(hidden_states)
892
+
893
+ if encoder_hidden_states is None:
894
+ encoder_hidden_states = hidden_states
895
+ elif attn.norm_cross:
896
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
897
+
898
+ key = attn.to_k(encoder_hidden_states)
899
+ value = attn.to_v(encoder_hidden_states)
900
+
901
+ assert num_tasks == 2 # only support two tasks now
902
+
903
+ key_0, key_1 = torch.chunk(key, dim=0, chunks=2) # keys shape (b t) d c
904
+ value_0, value_1 = torch.chunk(value, dim=0, chunks=2)
905
+ key = torch.cat([key_0, key_1], dim=1) # (b t) 2d c
906
+ value = torch.cat([value_0, value_1], dim=1) # (b t) 2d c
907
+ key = torch.cat([key]*2, dim=0) # ( 2 b t) 2d c
908
+ value = torch.cat([value]*2, dim=0) # (2 b t) 2d c
909
+
910
+
911
+ query = attn.head_to_batch_dim(query).contiguous()
912
+ key = attn.head_to_batch_dim(key).contiguous()
913
+ value = attn.head_to_batch_dim(value).contiguous()
914
+
915
+ hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
916
+ hidden_states = attn.batch_to_head_dim(hidden_states)
917
+
918
+ # linear proj
919
+ hidden_states = attn.to_out[0](hidden_states)
920
+ # dropout
921
+ hidden_states = attn.to_out[1](hidden_states)
922
+
923
+ if input_ndim == 4:
924
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
925
+
926
+ if attn.residual_connection:
927
+ hidden_states = hidden_states + residual
928
+
929
+ hidden_states = hidden_states / attn.rescale_output_factor
930
+
931
+ return hidden_states
932
+
933
+
934
+ class JointAttnProcessor:
935
+ r"""
936
+ Processor for joint cross-domain attention.
937
+ """
938
+
939
+ def __call__(
940
+ self,
941
+ attn: Attention,
942
+ hidden_states,
943
+ encoder_hidden_states=None,
944
+ attention_mask=None,
945
+ temb=None,
946
+ num_tasks=2
947
+ ):
948
+
949
+ residual = hidden_states
950
+
951
+ if attn.spatial_norm is not None:
952
+ hidden_states = attn.spatial_norm(hidden_states, temb)
953
+
954
+ input_ndim = hidden_states.ndim
955
+
956
+ if input_ndim == 4:
957
+ batch_size, channel, height, width = hidden_states.shape
958
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
959
+
960
+ batch_size, sequence_length, _ = (
961
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
962
+ )
963
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
964
+
965
+
966
+ if attn.group_norm is not None:
967
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
968
+
969
+ query = attn.to_q(hidden_states)
970
+
971
+ if encoder_hidden_states is None:
972
+ encoder_hidden_states = hidden_states
973
+ elif attn.norm_cross:
974
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
975
+
976
+ key = attn.to_k(encoder_hidden_states)
977
+ value = attn.to_v(encoder_hidden_states)
978
+
979
+ assert num_tasks == 2 # only support two tasks now
980
+
981
+ key_0, key_1 = torch.chunk(key, dim=0, chunks=2) # keys shape (b t) d c
982
+ value_0, value_1 = torch.chunk(value, dim=0, chunks=2)
983
+ key = torch.cat([key_0, key_1], dim=1) # (b t) 2d c
984
+ value = torch.cat([value_0, value_1], dim=1) # (b t) 2d c
985
+ key = torch.cat([key]*2, dim=0) # ( 2 b t) 2d c
986
+ value = torch.cat([value]*2, dim=0) # (2 b t) 2d c
987
+
988
+
989
+ query = attn.head_to_batch_dim(query).contiguous()
990
+ key = attn.head_to_batch_dim(key).contiguous()
991
+ value = attn.head_to_batch_dim(value).contiguous()
992
+
993
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
994
+ hidden_states = torch.bmm(attention_probs, value)
995
+ hidden_states = attn.batch_to_head_dim(hidden_states)
996
+
997
+ # linear proj
998
+ hidden_states = attn.to_out[0](hidden_states)
999
+ # dropout
1000
+ hidden_states = attn.to_out[1](hidden_states)
1001
+
1002
+ if input_ndim == 4:
1003
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
1004
+
1005
+ if attn.residual_connection:
1006
+ hidden_states = hidden_states + residual
1007
+
1008
+ hidden_states = hidden_states / attn.rescale_output_factor
1009
+
1010
+ return hidden_states
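+
+ # --- Illustrative sketch (standalone, torch only): the joint-attention key/value
+ # construction shared by both joint processors above. The batch holds two task
+ # domains stacked along dim 0; their tokens are concatenated so each domain's
+ # queries attend to both domains.
+ def _demo_joint_kv(bt=2, d=3, c=4):
+     import torch
+     key = torch.randn(2 * bt, d, c)                  # two tasks stacked: (2 b t) d c
+     key_0, key_1 = torch.chunk(key, chunks=2, dim=0)
+     joint = torch.cat([key_0, key_1], dim=1)         # (b t) 2d c — tokens of both tasks
+     joint = torch.cat([joint] * 2, dim=0)            # (2 b t) 2d c — same K for both halves
+     assert joint.shape == (2 * bt, 2 * d, c)
+     return joint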
2D_Stage/tuneavideo/models/unet.py ADDED
@@ -0,0 +1,497 @@
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_condition.py
2
+
3
+ from dataclasses import dataclass
4
+ from typing import List, Optional, Tuple, Union
5
+
6
+ import os
7
+ import json
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.utils.checkpoint
12
+
13
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
14
+ from diffusers import ModelMixin
15
+ from diffusers.utils import BaseOutput, logging
16
+ from diffusers.models.embeddings import TimestepEmbedding, Timesteps
17
+ from .unet_blocks import (
18
+ CrossAttnDownBlock3D,
19
+ CrossAttnUpBlock3D,
20
+ DownBlock3D,
21
+ UNetMidBlock3DCrossAttn,
22
+ UpBlock3D,
23
+ get_down_block,
24
+ get_up_block,
25
+ )
26
+ from .resnet import InflatedConv3d
27
+
28
+
29
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
30
+
31
+
32
+ @dataclass
33
+ class UNet3DConditionOutput(BaseOutput):
34
+ sample: torch.FloatTensor
35
+
36
+
37
+ class UNet3DConditionModel(ModelMixin, ConfigMixin):
38
+ _supports_gradient_checkpointing = True
39
+
40
+ @register_to_config
41
+ def __init__(
42
+ self,
43
+ sample_size: Optional[int] = None,
44
+ in_channels: int = 4,
45
+ out_channels: int = 4,
46
+ center_input_sample: bool = False,
47
+ flip_sin_to_cos: bool = True,
48
+ freq_shift: int = 0,
49
+ down_block_types: Tuple[str] = (
50
+ "CrossAttnDownBlock3D",
51
+ "CrossAttnDownBlock3D",
52
+ "CrossAttnDownBlock3D",
53
+ "DownBlock3D",
54
+ ),
55
+ mid_block_type: str = "UNetMidBlock3DCrossAttn",
56
+ up_block_types: Tuple[str] = (
57
+ "UpBlock3D",
58
+ "CrossAttnUpBlock3D",
59
+ "CrossAttnUpBlock3D",
60
+ "CrossAttnUpBlock3D"
61
+ ),
62
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
63
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
64
+ layers_per_block: int = 2,
65
+ downsample_padding: int = 1,
66
+ mid_block_scale_factor: float = 1,
67
+ act_fn: str = "silu",
68
+ norm_num_groups: int = 32,
69
+ norm_eps: float = 1e-5,
70
+ cross_attention_dim: int = 1280,
71
+ attention_head_dim: Union[int, Tuple[int]] = 8,
72
+ dual_cross_attention: bool = False,
73
+ use_linear_projection: bool = False,
74
+ class_embed_type: Optional[str] = None,
75
+ num_class_embeds: Optional[int] = None,
76
+ upcast_attention: bool = False,
77
+ resnet_time_scale_shift: str = "default",
78
+ use_attn_temp: bool = False,
79
+ camera_input_dim: int = 12,
80
+ camera_hidden_dim: int = 320,
81
+ camera_output_dim: int = 1280,
82
+ ):
83
+ super().__init__()
84
+
85
+ self.sample_size = sample_size
86
+ time_embed_dim = block_out_channels[0] * 4
87
+
88
+ # input
89
+ self.conv_in = InflatedConv3d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))
90
+
91
+ # time
92
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
93
+ timestep_input_dim = block_out_channels[0]
94
+
95
+ self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
96
+
97
+ # class embedding
98
+ if class_embed_type is None and num_class_embeds is not None:
99
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
100
+ elif class_embed_type == "timestep":
101
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
102
+ elif class_embed_type == "identity":
103
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
104
+ else:
105
+ self.class_embedding = None
106
+
107
+ # camera matrix embedding
+ # (an earlier variant used two separate normal-initialized linear layers,
+ # `camera_embedding_1` / `camera_embedding_2`; any such weights are dropped
+ # in `from_pretrained_2d` below)
116
+
117
+ self.camera_embedding = nn.Sequential(
118
+ nn.Linear(camera_input_dim, time_embed_dim),
119
+ nn.SiLU(),
120
+ nn.Linear(time_embed_dim, time_embed_dim),
121
+ )
122
+
123
+ self.down_blocks = nn.ModuleList([])
124
+ self.mid_block = None
125
+ self.up_blocks = nn.ModuleList([])
126
+
127
+ if isinstance(only_cross_attention, bool):
128
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
129
+
130
+ if isinstance(attention_head_dim, int):
131
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
132
+
133
+ # down
134
+ output_channel = block_out_channels[0]
135
+ for i, down_block_type in enumerate(down_block_types):
136
+ input_channel = output_channel
137
+ output_channel = block_out_channels[i]
138
+ is_final_block = i == len(block_out_channels) - 1
139
+
140
+ down_block = get_down_block(
141
+ down_block_type,
142
+ num_layers=layers_per_block,
143
+ in_channels=input_channel,
144
+ out_channels=output_channel,
145
+ temb_channels=time_embed_dim,
146
+ add_downsample=not is_final_block,
147
+ resnet_eps=norm_eps,
148
+ resnet_act_fn=act_fn,
149
+ resnet_groups=norm_num_groups,
150
+ cross_attention_dim=cross_attention_dim,
151
+ attn_num_head_channels=attention_head_dim[i],
152
+ downsample_padding=downsample_padding,
153
+ dual_cross_attention=dual_cross_attention,
154
+ use_linear_projection=use_linear_projection,
155
+ only_cross_attention=only_cross_attention[i],
156
+ upcast_attention=upcast_attention,
157
+ resnet_time_scale_shift=resnet_time_scale_shift,
158
+ use_attn_temp=use_attn_temp
159
+ )
160
+ self.down_blocks.append(down_block)
161
+
162
+ # mid
163
+ if mid_block_type == "UNetMidBlock3DCrossAttn":
164
+ self.mid_block = UNetMidBlock3DCrossAttn(
165
+ in_channels=block_out_channels[-1],
166
+ temb_channels=time_embed_dim,
167
+ resnet_eps=norm_eps,
168
+ resnet_act_fn=act_fn,
169
+ output_scale_factor=mid_block_scale_factor,
170
+ resnet_time_scale_shift=resnet_time_scale_shift,
171
+ cross_attention_dim=cross_attention_dim,
172
+ attn_num_head_channels=attention_head_dim[-1],
173
+ resnet_groups=norm_num_groups,
174
+ dual_cross_attention=dual_cross_attention,
175
+ use_linear_projection=use_linear_projection,
176
+ upcast_attention=upcast_attention,
177
+ )
178
+ else:
179
+ raise ValueError(f"unknown mid_block_type : {mid_block_type}")
180
+
181
+ # count how many layers upsample the videos
182
+ self.num_upsamplers = 0
183
+
184
+ # up
185
+ reversed_block_out_channels = list(reversed(block_out_channels))
186
+ reversed_attention_head_dim = list(reversed(attention_head_dim))
187
+ only_cross_attention = list(reversed(only_cross_attention))
188
+ output_channel = reversed_block_out_channels[0]
189
+ for i, up_block_type in enumerate(up_block_types):
190
+ is_final_block = i == len(block_out_channels) - 1
191
+
192
+ prev_output_channel = output_channel
193
+ output_channel = reversed_block_out_channels[i]
194
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
195
+
196
+ # add upsample block for all BUT final layer
197
+ if not is_final_block:
198
+ add_upsample = True
199
+ self.num_upsamplers += 1
200
+ else:
201
+ add_upsample = False
202
+
203
+ up_block = get_up_block(
204
+ up_block_type,
205
+ num_layers=layers_per_block + 1,
206
+ in_channels=input_channel,
207
+ out_channels=output_channel,
208
+ prev_output_channel=prev_output_channel,
209
+ temb_channels=time_embed_dim,
210
+ add_upsample=add_upsample,
211
+ resnet_eps=norm_eps,
212
+ resnet_act_fn=act_fn,
213
+ resnet_groups=norm_num_groups,
214
+ cross_attention_dim=cross_attention_dim,
215
+ attn_num_head_channels=reversed_attention_head_dim[i],
216
+ dual_cross_attention=dual_cross_attention,
217
+ use_linear_projection=use_linear_projection,
218
+ only_cross_attention=only_cross_attention[i],
219
+ upcast_attention=upcast_attention,
220
+ resnet_time_scale_shift=resnet_time_scale_shift,
221
+ use_attn_temp=use_attn_temp,
222
+ )
223
+ self.up_blocks.append(up_block)
224
+ prev_output_channel = output_channel
225
+
226
+ # out
227
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
228
+ self.conv_act = nn.SiLU()
229
+ self.conv_out = InflatedConv3d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
230
+
231
+ def set_attention_slice(self, slice_size):
232
+ r"""
233
+ Enable sliced attention computation.
234
+
235
+ When this option is enabled, the attention module will split the input tensor in slices, to compute attention
236
+ in several steps. This is useful to save some memory in exchange for a small speed decrease.
237
+
238
+ Args:
239
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
240
+ When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
241
+ `"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is
242
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
243
+ must be a multiple of `slice_size`.
244
+ """
245
+ sliceable_head_dims = []
246
+
247
+ def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module):
248
+ if hasattr(module, "set_attention_slice"):
249
+ sliceable_head_dims.append(module.sliceable_head_dim)
250
+
251
+ for child in module.children():
252
+ fn_recursive_retrieve_slicable_dims(child)
253
+
254
+ # retrieve number of attention layers
255
+ for module in self.children():
256
+ fn_recursive_retrieve_slicable_dims(module)
257
+
258
+ num_slicable_layers = len(sliceable_head_dims)
259
+
260
+ if slice_size == "auto":
261
+ # half the attention head size is usually a good trade-off between
262
+ # speed and memory
263
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
264
+ elif slice_size == "max":
265
+ # make smallest slice possible
266
+ slice_size = num_slicable_layers * [1]
267
+
268
+ slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
269
+
270
+ if len(slice_size) != len(sliceable_head_dims):
271
+ raise ValueError(
272
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
273
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
274
+ )
275
+
276
+ for i in range(len(slice_size)):
277
+ size = slice_size[i]
278
+ dim = sliceable_head_dims[i]
279
+ if size is not None and size > dim:
280
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
281
+
282
+ # Recursively walk through all the children.
283
+ # Any children which exposes the set_attention_slice method
284
+ # gets the message
285
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
286
+ if hasattr(module, "set_attention_slice"):
287
+ module.set_attention_slice(slice_size.pop())
288
+
289
+ for child in module.children():
290
+ fn_recursive_set_attention_slice(child, slice_size)
291
+
292
+ reversed_slice_size = list(reversed(slice_size))
293
+ for module in self.children():
294
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
295
+
296
+ def _set_gradient_checkpointing(self, module, value=False):
297
+ if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
298
+ module.gradient_checkpointing = value
299
+
300
+ def forward(
301
+ self,
302
+ sample: torch.FloatTensor,
303
+ timestep: Union[torch.Tensor, float, int],
304
+ encoder_hidden_states: torch.Tensor,
305
+ camera_matrixs: Optional[torch.Tensor] = None,
306
+ class_labels: Optional[torch.Tensor] = None,
307
+ attention_mask: Optional[torch.Tensor] = None,
308
+ return_dict: bool = True,
309
+ ) -> Union[UNet3DConditionOutput, Tuple]:
310
+ r"""
311
+ Args:
312
+ sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
313
+ timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
314
+ encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
315
+ return_dict (`bool`, *optional*, defaults to `True`):
316
+ Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
317
+
318
+ Returns:
319
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
320
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
321
+ returning a tuple, the first element is the sample tensor.
322
+ """
323
+ # By default, samples have to be at least a multiple of the overall upsampling factor.
+ # The overall upsampling factor is equal to 2 ** (number of upsampling layers).
325
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
326
+ # on the fly if necessary.
327
+ default_overall_up_factor = 2**self.num_upsamplers
328
+
329
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
330
+ forward_upsample_size = False
331
+ upsample_size = None
332
+
333
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
334
+ logger.info("Forward upsample size to force interpolation output size.")
335
+ forward_upsample_size = True
336
+
337
+ # prepare attention_mask
338
+ if attention_mask is not None:
339
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
340
+ attention_mask = attention_mask.unsqueeze(1)
341
+
342
+ # center input if necessary
343
+ if self.config.center_input_sample:
344
+ sample = 2 * sample - 1.0
345
+ # time
346
+ timesteps = timestep
347
+ if not torch.is_tensor(timesteps):
348
+ # This would be a good case for the `match` statement (Python 3.10+)
349
+ is_mps = sample.device.type == "mps"
350
+ if isinstance(timestep, float):
351
+ dtype = torch.float32 if is_mps else torch.float64
352
+ else:
353
+ dtype = torch.int32 if is_mps else torch.int64
354
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
355
+ elif len(timesteps.shape) == 0:
356
+ timesteps = timesteps[None].to(sample.device)
357
+
358
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
359
+ timesteps = timesteps.expand(sample.shape[0])
360
+
361
+ t_emb = self.time_proj(timesteps)
362
+
363
+ # timesteps does not contain any weights and will always return f32 tensors
364
+ # but time_embedding might actually be running in fp16. so we need to cast here.
365
+ # there might be better ways to encapsulate this.
366
+ t_emb = t_emb.to(dtype=self.dtype)
367
+ emb = self.time_embedding(t_emb) #torch.Size([32, 1280])
368
+ emb = torch.unsqueeze(emb, 1)
369
+ if camera_matrixs is not None:
370
+ # camera embedding: project per-view camera matrices and add them to the time embedding
+ cam_emb = self.camera_embedding(camera_matrixs)
+ emb = emb.repeat(1, cam_emb.shape[1], 1) # [batch, num_views, time_embed_dim], e.g. torch.Size([32, 4, 1280])
374
+ emb = emb + cam_emb
375
+
376
+ if self.class_embedding is not None:
378
+ # if class_labels is None:
379
+ # raise ValueError("class_labels should be provided when num_class_embeds > 0")
380
+ if class_labels is not None:
381
+
382
+ if self.config.class_embed_type == "timestep":
383
+ class_labels = self.time_proj(class_labels)
384
+
385
+ class_emb = self.class_embedding(class_labels)
386
+ emb = emb + class_emb
387
+
388
+ # pre-process
389
+ sample = self.conv_in(sample)
390
+
391
+ # down
392
+ down_block_res_samples = (sample,)
393
+ for downsample_block in self.down_blocks:
394
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
395
+ sample, res_samples = downsample_block(
396
+ hidden_states=sample,
397
+ temb=emb,
398
+ encoder_hidden_states=encoder_hidden_states,
399
+ attention_mask=attention_mask,
400
+ )
401
+ else:
402
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
403
+
404
+ down_block_res_samples += res_samples
405
+
406
+ # mid
407
+ sample = self.mid_block(
408
+ sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
409
+ )
410
+
411
+ # up
412
+ for i, upsample_block in enumerate(self.up_blocks):
413
+ is_final_block = i == len(self.up_blocks) - 1
414
+
415
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
416
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
417
+
418
+ # if we have not reached the final block and need to forward the
419
+ # upsample size, we do it here
420
+ if not is_final_block and forward_upsample_size:
421
+ upsample_size = down_block_res_samples[-1].shape[2:]
422
+
423
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
424
+ sample = upsample_block(
425
+ hidden_states=sample,
426
+ temb=emb,
427
+ res_hidden_states_tuple=res_samples,
428
+ encoder_hidden_states=encoder_hidden_states,
429
+ upsample_size=upsample_size,
430
+ attention_mask=attention_mask,
431
+ )
432
+ else:
433
+ sample = upsample_block(
434
+ hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
435
+ )
436
+ # post-process
437
+ sample = self.conv_norm_out(sample)
438
+ sample = self.conv_act(sample)
439
+ sample = self.conv_out(sample)
440
+
441
+ if not return_dict:
442
+ return (sample,)
443
+
444
+ return UNet3DConditionOutput(sample=sample)
445
+
446
+ @classmethod
447
+ def from_pretrained_2d(cls, pretrained_model_path, subfolder=None):
448
+ if subfolder is not None:
449
+ pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
450
+
451
+ config_file = os.path.join(pretrained_model_path, 'config.json')
452
+ if not os.path.isfile(config_file):
453
+ raise RuntimeError(f"{config_file} does not exist")
454
+ with open(config_file, "r") as f:
455
+ config = json.load(f)
456
+ config["_class_name"] = cls.__name__
457
+ config["down_block_types"] = [
458
+ "CrossAttnDownBlock3D",
459
+ "CrossAttnDownBlock3D",
460
+ "CrossAttnDownBlock3D",
461
+ "DownBlock3D"
462
+ ]
463
+ config["up_block_types"] = [
464
+ "UpBlock3D",
465
+ "CrossAttnUpBlock3D",
466
+ "CrossAttnUpBlock3D",
467
+ "CrossAttnUpBlock3D"
468
+ ]
469
+
470
+ from diffusers.utils import WEIGHTS_NAME, SAFETENSORS_WEIGHTS_NAME
471
+
477
+ import safetensors
478
+ model = cls.from_config(config)
479
+ model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
480
+ if not os.path.isfile(model_file):
481
+ model_file = os.path.join(pretrained_model_path, SAFETENSORS_WEIGHTS_NAME)
482
+ if not os.path.isfile(model_file):
483
+ raise RuntimeError(f"{model_file} does not exist")
484
+ else:
485
+ state_dict = safetensors.torch.load_file(model_file, device="cpu")
486
+ else:
487
+ state_dict = torch.load(model_file, map_location="cpu")
488
+
489
+ # keep the freshly initialized weights for layers that have no counterpart in the 2D checkpoint
+ for k, v in model.state_dict().items():
+ if '_temp.' in k or 'camera_embedding' in k or 'class_embedding' in k:
+ state_dict.update({k: v})
+ # drop any weights left over from the legacy two-linear camera embedding scheme
+ for k in list(state_dict.keys()):
+ if 'camera_embedding_' in k:
+ state_dict.pop(k)
+ model.load_state_dict(state_dict)
496
+
497
+ return model
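+
+ # --- Illustrative sketch (shapes assumed from the comments in `forward` above): how
+ # the per-view camera conditioning is fused with the shared timestep embedding — the
+ # time embedding is broadcast across views, then the MLP-projected camera matrices
+ # are added per view.
+ def _demo_camera_time_fusion(batch=2, views=4, cam_dim=12, embed_dim=1280):
+     import torch
+     camera_embedding = torch.nn.Sequential(   # mirrors self.camera_embedding
+         torch.nn.Linear(cam_dim, embed_dim),
+         torch.nn.SiLU(),
+         torch.nn.Linear(embed_dim, embed_dim),
+     )
+     t_emb = torch.randn(batch, embed_dim).unsqueeze(1)   # [B, 1, D] shared across views
+     camera_matrixs = torch.randn(batch, views, cam_dim)  # flattened per-view camera matrices
+     emb = t_emb.repeat(1, views, 1) + camera_embedding(camera_matrixs)  # [B, V, D]
+     assert emb.shape == (batch, views, embed_dim)
+     return emb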
2D_Stage/tuneavideo/models/unet_blocks.py ADDED
@@ -0,0 +1,596 @@
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_blocks.py
2
+
3
+ import torch
4
+ from torch import nn
5
+
6
+ from .attention import Transformer3DModel # needed by the cross-attention blocks below
7
+ from .resnet import Downsample3D, ResnetBlock3D, Upsample3D
8
+
9
+
10
+ def get_down_block(
11
+ down_block_type,
12
+ num_layers,
13
+ in_channels,
14
+ out_channels,
15
+ temb_channels,
16
+ add_downsample,
17
+ resnet_eps,
18
+ resnet_act_fn,
19
+ attn_num_head_channels,
20
+ resnet_groups=None,
21
+ cross_attention_dim=None,
22
+ downsample_padding=None,
23
+ dual_cross_attention=False,
24
+ use_linear_projection=False,
25
+ only_cross_attention=False,
26
+ upcast_attention=False,
27
+ resnet_time_scale_shift="default",
28
+ use_attn_temp=False,
29
+ ):
30
+ down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
31
+ if down_block_type == "DownBlock3D":
32
+ return DownBlock3D(
33
+ num_layers=num_layers,
34
+ in_channels=in_channels,
35
+ out_channels=out_channels,
36
+ temb_channels=temb_channels,
37
+ add_downsample=add_downsample,
38
+ resnet_eps=resnet_eps,
39
+ resnet_act_fn=resnet_act_fn,
40
+ resnet_groups=resnet_groups,
41
+ downsample_padding=downsample_padding,
42
+ resnet_time_scale_shift=resnet_time_scale_shift,
43
+ )
44
+ elif down_block_type == "CrossAttnDownBlock3D":
45
+ if cross_attention_dim is None:
46
+ raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D")
47
+ return CrossAttnDownBlock3D(
48
+ num_layers=num_layers,
49
+ in_channels=in_channels,
50
+ out_channels=out_channels,
51
+ temb_channels=temb_channels,
52
+ add_downsample=add_downsample,
53
+ resnet_eps=resnet_eps,
54
+ resnet_act_fn=resnet_act_fn,
55
+ resnet_groups=resnet_groups,
56
+ downsample_padding=downsample_padding,
57
+ cross_attention_dim=cross_attention_dim,
58
+ attn_num_head_channels=attn_num_head_channels,
59
+ dual_cross_attention=dual_cross_attention,
60
+ use_linear_projection=use_linear_projection,
61
+ only_cross_attention=only_cross_attention,
62
+ upcast_attention=upcast_attention,
63
+ resnet_time_scale_shift=resnet_time_scale_shift,
64
+ use_attn_temp=use_attn_temp,
65
+ )
66
+ raise ValueError(f"{down_block_type} does not exist.")
67
+
68
+
69
+ def get_up_block(
70
+ up_block_type,
71
+ num_layers,
72
+ in_channels,
73
+ out_channels,
74
+ prev_output_channel,
75
+ temb_channels,
76
+ add_upsample,
77
+ resnet_eps,
78
+ resnet_act_fn,
79
+ attn_num_head_channels,
80
+ resnet_groups=None,
81
+ cross_attention_dim=None,
82
+ dual_cross_attention=False,
83
+ use_linear_projection=False,
84
+ only_cross_attention=False,
85
+ upcast_attention=False,
86
+ resnet_time_scale_shift="default",
87
+ use_attn_temp=False,
88
+ ):
89
+ up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
90
+ if up_block_type == "UpBlock3D":
91
+ return UpBlock3D(
92
+ num_layers=num_layers,
93
+ in_channels=in_channels,
94
+ out_channels=out_channels,
95
+ prev_output_channel=prev_output_channel,
96
+ temb_channels=temb_channels,
97
+ add_upsample=add_upsample,
98
+ resnet_eps=resnet_eps,
99
+ resnet_act_fn=resnet_act_fn,
100
+ resnet_groups=resnet_groups,
101
+ resnet_time_scale_shift=resnet_time_scale_shift,
102
+ )
103
+ elif up_block_type == "CrossAttnUpBlock3D":
104
+ if cross_attention_dim is None:
105
+ raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D")
106
+ return CrossAttnUpBlock3D(
107
+ num_layers=num_layers,
108
+ in_channels=in_channels,
109
+ out_channels=out_channels,
110
+ prev_output_channel=prev_output_channel,
111
+ temb_channels=temb_channels,
112
+ add_upsample=add_upsample,
113
+ resnet_eps=resnet_eps,
114
+ resnet_act_fn=resnet_act_fn,
115
+ resnet_groups=resnet_groups,
116
+ cross_attention_dim=cross_attention_dim,
117
+ attn_num_head_channels=attn_num_head_channels,
118
+ dual_cross_attention=dual_cross_attention,
119
+ use_linear_projection=use_linear_projection,
120
+ only_cross_attention=only_cross_attention,
121
+ upcast_attention=upcast_attention,
122
+ resnet_time_scale_shift=resnet_time_scale_shift,
123
+ use_attn_temp=use_attn_temp,
124
+ )
125
+ raise ValueError(f"{up_block_type} does not exist.")
126
+
127
+
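+ # --- Illustrative sketch (channel sizes assumed from the UNet3DConditionModel
+ # defaults, not prescribed here): how the factory above is typically driven — one
+ # call per entry in `down_block_types`, mirroring the loop in the UNet constructor.
+ def _demo_get_down_block():
+     block = get_down_block(
+         "DownBlock3D",
+         num_layers=2,
+         in_channels=320,
+         out_channels=320,
+         temb_channels=1280,
+         add_downsample=True,
+         resnet_eps=1e-5,
+         resnet_act_fn="silu",
+         attn_num_head_channels=8,
+         resnet_groups=32,
+         downsample_padding=1,
+     )
+     return block
+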
128
+ class UNetMidBlock3DCrossAttn(nn.Module):
129
+ def __init__(
130
+ self,
131
+ in_channels: int,
132
+ temb_channels: int,
133
+ dropout: float = 0.0,
134
+ num_layers: int = 1,
135
+ resnet_eps: float = 1e-6,
136
+ resnet_time_scale_shift: str = "default",
137
+ resnet_act_fn: str = "swish",
138
+ resnet_groups: int = 32,
139
+ resnet_pre_norm: bool = True,
140
+ attn_num_head_channels=1,
141
+ output_scale_factor=1.0,
142
+ cross_attention_dim=1280,
143
+ dual_cross_attention=False,
144
+ use_linear_projection=False,
145
+ upcast_attention=False,
146
+ ):
147
+ super().__init__()
148
+
149
+ self.has_cross_attention = True
150
+ self.attn_num_head_channels = attn_num_head_channels
151
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
152
+
153
+ # there is always at least one resnet
154
+ resnets = [
155
+ ResnetBlock3D(
156
+ in_channels=in_channels,
157
+ out_channels=in_channels,
158
+ temb_channels=temb_channels,
159
+ eps=resnet_eps,
160
+ groups=resnet_groups,
161
+ dropout=dropout,
162
+ time_embedding_norm=resnet_time_scale_shift,
163
+ non_linearity=resnet_act_fn,
164
+ output_scale_factor=output_scale_factor,
165
+ pre_norm=resnet_pre_norm,
166
+ )
167
+ ]
168
+ attentions = []
169
+
170
+ for _ in range(num_layers):
171
+ if dual_cross_attention:
172
+                 raise NotImplementedError
+             attentions.append(
+                 Transformer3DModel(
+                     attn_num_head_channels,
+                     in_channels // attn_num_head_channels,
+                     in_channels=in_channels,
+                     num_layers=1,
+                     cross_attention_dim=cross_attention_dim,
+                     norm_num_groups=resnet_groups,
+                     use_linear_projection=use_linear_projection,
+                     upcast_attention=upcast_attention,
+                 )
+             )
+             resnets.append(
+                 ResnetBlock3D(
+                     in_channels=in_channels,
+                     out_channels=in_channels,
+                     temb_channels=temb_channels,
+                     eps=resnet_eps,
+                     groups=resnet_groups,
+                     dropout=dropout,
+                     time_embedding_norm=resnet_time_scale_shift,
+                     non_linearity=resnet_act_fn,
+                     output_scale_factor=output_scale_factor,
+                     pre_norm=resnet_pre_norm,
+                 )
+             )
+
+         self.attentions = nn.ModuleList(attentions)
+         self.resnets = nn.ModuleList(resnets)
+
+     def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):
+         hidden_states = self.resnets[0](hidden_states, temb)
+         for attn, resnet in zip(self.attentions, self.resnets[1:]):
+             hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
+             hidden_states = resnet(hidden_states, temb)
+
+         return hidden_states
+
+
+ class CrossAttnDownBlock3D(nn.Module):
+     def __init__(
+         self,
+         in_channels: int,
+         out_channels: int,
+         temb_channels: int,
+         dropout: float = 0.0,
+         num_layers: int = 1,
+         resnet_eps: float = 1e-6,
+         resnet_time_scale_shift: str = "default",
+         resnet_act_fn: str = "swish",
+         resnet_groups: int = 32,
+         resnet_pre_norm: bool = True,
+         attn_num_head_channels=1,
+         cross_attention_dim=1280,
+         output_scale_factor=1.0,
+         downsample_padding=1,
+         add_downsample=True,
+         dual_cross_attention=False,
+         use_linear_projection=False,
+         only_cross_attention=False,
+         upcast_attention=False,
+         use_attn_temp=False,
+     ):
+         super().__init__()
+         resnets = []
+         attentions = []
+
+         self.has_cross_attention = True
+         self.attn_num_head_channels = attn_num_head_channels
+
+         for i in range(num_layers):
+             in_channels = in_channels if i == 0 else out_channels
+             resnets.append(
+                 ResnetBlock3D(
+                     in_channels=in_channels,
+                     out_channels=out_channels,
+                     temb_channels=temb_channels,
+                     eps=resnet_eps,
+                     groups=resnet_groups,
+                     dropout=dropout,
+                     time_embedding_norm=resnet_time_scale_shift,
+                     non_linearity=resnet_act_fn,
+                     output_scale_factor=output_scale_factor,
+                     pre_norm=resnet_pre_norm,
+                 )
+             )
+             if dual_cross_attention:
+                 raise NotImplementedError
+             attentions.append(
+                 Transformer3DModel(
+                     attn_num_head_channels,
+                     out_channels // attn_num_head_channels,
+                     in_channels=out_channels,
+                     num_layers=1,
+                     cross_attention_dim=cross_attention_dim,
+                     norm_num_groups=resnet_groups,
+                     use_linear_projection=use_linear_projection,
+                     only_cross_attention=only_cross_attention,
+                     upcast_attention=upcast_attention,
+                     use_attn_temp=use_attn_temp,
+                 )
+             )
+         self.attentions = nn.ModuleList(attentions)
+         self.resnets = nn.ModuleList(resnets)
+
+         if add_downsample:
+             self.downsamplers = nn.ModuleList(
+                 [
+                     Downsample3D(
+                         out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
+                     )
+                 ]
+             )
+         else:
+             self.downsamplers = None
+
+         self.gradient_checkpointing = False
+
+     def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):
+         output_states = ()
+
+         for resnet, attn in zip(self.resnets, self.attentions):
+             if self.training and self.gradient_checkpointing:
+
+                 def create_custom_forward(module, return_dict=None):
+                     def custom_forward(*inputs):
+                         if return_dict is not None:
+                             return module(*inputs, return_dict=return_dict)
+                         else:
+                             return module(*inputs)
+
+                     return custom_forward
+
+                 hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
+                 hidden_states = torch.utils.checkpoint.checkpoint(
+                     create_custom_forward(attn, return_dict=False),
+                     hidden_states,
+                     encoder_hidden_states,
+                 )[0]
+             else:
+                 hidden_states = resnet(hidden_states, temb)
+                 hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
+
+             output_states += (hidden_states,)
+
+         if self.downsamplers is not None:
+             for downsampler in self.downsamplers:
+                 hidden_states = downsampler(hidden_states)
+
+             output_states += (hidden_states,)
+
+         return hidden_states, output_states
+
+
+ class DownBlock3D(nn.Module):
+     def __init__(
+         self,
+         in_channels: int,
+         out_channels: int,
+         temb_channels: int,
+         dropout: float = 0.0,
+         num_layers: int = 1,
+         resnet_eps: float = 1e-6,
+         resnet_time_scale_shift: str = "default",
+         resnet_act_fn: str = "swish",
+         resnet_groups: int = 32,
+         resnet_pre_norm: bool = True,
+         output_scale_factor=1.0,
+         add_downsample=True,
+         downsample_padding=1,
+     ):
+         super().__init__()
+         resnets = []
+
+         for i in range(num_layers):
+             in_channels = in_channels if i == 0 else out_channels
+             resnets.append(
+                 ResnetBlock3D(
+                     in_channels=in_channels,
+                     out_channels=out_channels,
+                     temb_channels=temb_channels,
+                     eps=resnet_eps,
+                     groups=resnet_groups,
+                     dropout=dropout,
+                     time_embedding_norm=resnet_time_scale_shift,
+                     non_linearity=resnet_act_fn,
+                     output_scale_factor=output_scale_factor,
+                     pre_norm=resnet_pre_norm,
+                 )
+             )
+
+         self.resnets = nn.ModuleList(resnets)
+
+         if add_downsample:
+             self.downsamplers = nn.ModuleList(
+                 [
+                     Downsample3D(
+                         out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
+                     )
+                 ]
+             )
+         else:
+             self.downsamplers = None
+
+         self.gradient_checkpointing = False
+
+     def forward(self, hidden_states, temb=None):
+         output_states = ()
+
+         for resnet in self.resnets:
+             if self.training and self.gradient_checkpointing:
+
+                 def create_custom_forward(module):
+                     def custom_forward(*inputs):
+                         return module(*inputs)
+
+                     return custom_forward
+
+                 hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
+             else:
+                 hidden_states = resnet(hidden_states, temb)
+
+             output_states += (hidden_states,)
+
+         if self.downsamplers is not None:
+             for downsampler in self.downsamplers:
+                 hidden_states = downsampler(hidden_states)
+
+             output_states += (hidden_states,)
+
+         return hidden_states, output_states
+
+
+ class CrossAttnUpBlock3D(nn.Module):
+     def __init__(
+         self,
+         in_channels: int,
+         out_channels: int,
+         prev_output_channel: int,
+         temb_channels: int,
+         dropout: float = 0.0,
+         num_layers: int = 1,
+         resnet_eps: float = 1e-6,
+         resnet_time_scale_shift: str = "default",
+         resnet_act_fn: str = "swish",
+         resnet_groups: int = 32,
+         resnet_pre_norm: bool = True,
+         attn_num_head_channels=1,
+         cross_attention_dim=1280,
+         output_scale_factor=1.0,
+         add_upsample=True,
+         dual_cross_attention=False,
+         use_linear_projection=False,
+         only_cross_attention=False,
+         upcast_attention=False,
+         use_attn_temp=False,
+     ):
+         super().__init__()
+         resnets = []
+         attentions = []
+
+         self.has_cross_attention = True
+         self.attn_num_head_channels = attn_num_head_channels
+
+         for i in range(num_layers):
+             res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+             resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+             resnets.append(
+                 ResnetBlock3D(
+                     in_channels=resnet_in_channels + res_skip_channels,
+                     out_channels=out_channels,
+                     temb_channels=temb_channels,
+                     eps=resnet_eps,
+                     groups=resnet_groups,
+                     dropout=dropout,
+                     time_embedding_norm=resnet_time_scale_shift,
+                     non_linearity=resnet_act_fn,
+                     output_scale_factor=output_scale_factor,
+                     pre_norm=resnet_pre_norm,
+                 )
+             )
+             if dual_cross_attention:
+                 raise NotImplementedError
+             attentions.append(
+                 Transformer3DModel(
+                     attn_num_head_channels,
+                     out_channels // attn_num_head_channels,
+                     in_channels=out_channels,
+                     num_layers=1,
+                     cross_attention_dim=cross_attention_dim,
+                     norm_num_groups=resnet_groups,
+                     use_linear_projection=use_linear_projection,
+                     only_cross_attention=only_cross_attention,
+                     upcast_attention=upcast_attention,
+                     use_attn_temp=use_attn_temp,
+                 )
+             )
+
+         self.attentions = nn.ModuleList(attentions)
+         self.resnets = nn.ModuleList(resnets)
+
+         if add_upsample:
+             self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)])
+         else:
+             self.upsamplers = None
+
+         self.gradient_checkpointing = False
+
+     def forward(
+         self,
+         hidden_states,
+         res_hidden_states_tuple,
+         temb=None,
+         encoder_hidden_states=None,
+         upsample_size=None,
+         attention_mask=None,
+     ):
+         for resnet, attn in zip(self.resnets, self.attentions):
+             # pop res hidden states
+             res_hidden_states = res_hidden_states_tuple[-1]
+             res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+             if self.training and self.gradient_checkpointing:
+
+                 def create_custom_forward(module, return_dict=None):
+                     def custom_forward(*inputs):
+                         if return_dict is not None:
+                             return module(*inputs, return_dict=return_dict)
+                         else:
+                             return module(*inputs)
+
+                     return custom_forward
+
+                 hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
+                 hidden_states = torch.utils.checkpoint.checkpoint(
+                     create_custom_forward(attn, return_dict=False),
+                     hidden_states,
+                     encoder_hidden_states,
+                 )[0]
+             else:
+                 hidden_states = resnet(hidden_states, temb)
+                 hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
+
+         if self.upsamplers is not None:
+             for upsampler in self.upsamplers:
+                 hidden_states = upsampler(hidden_states, upsample_size)
+
+         return hidden_states
+
+
+ class UpBlock3D(nn.Module):
+     def __init__(
+         self,
+         in_channels: int,
+         prev_output_channel: int,
+         out_channels: int,
+         temb_channels: int,
+         dropout: float = 0.0,
+         num_layers: int = 1,
+         resnet_eps: float = 1e-6,
+         resnet_time_scale_shift: str = "default",
+         resnet_act_fn: str = "swish",
+         resnet_groups: int = 32,
+         resnet_pre_norm: bool = True,
+         output_scale_factor=1.0,
+         add_upsample=True,
+     ):
+         super().__init__()
+         resnets = []
+
+         for i in range(num_layers):
+             res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+             resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+             resnets.append(
+                 ResnetBlock3D(
+                     in_channels=resnet_in_channels + res_skip_channels,
+                     out_channels=out_channels,
+                     temb_channels=temb_channels,
+                     eps=resnet_eps,
+                     groups=resnet_groups,
+                     dropout=dropout,
+                     time_embedding_norm=resnet_time_scale_shift,
+                     non_linearity=resnet_act_fn,
+                     output_scale_factor=output_scale_factor,
+                     pre_norm=resnet_pre_norm,
+                 )
+             )
+
+         self.resnets = nn.ModuleList(resnets)
+
+         if add_upsample:
+             self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)])
+         else:
+             self.upsamplers = None
+
+         self.gradient_checkpointing = False
+
+     def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
+         for resnet in self.resnets:
+             # pop res hidden states
+             res_hidden_states = res_hidden_states_tuple[-1]
+             res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+             if self.training and self.gradient_checkpointing:
+
+                 def create_custom_forward(module):
+                     def custom_forward(*inputs):
+                         return module(*inputs)
+
+                     return custom_forward
+
+                 hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
+             else:
+                 hidden_states = resnet(hidden_states, temb)
+
+         if self.upsamplers is not None:
+             for upsampler in self.upsamplers:
+                 hidden_states = upsampler(hidden_states, upsample_size)
+
+         return hidden_states
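
A minimal usage sketch for the 3D blocks above (not part of the committed files): down blocks return `(hidden_states, output_states)`, where `output_states` holds the skip tensors later concatenated back in by the matching up block. It assumes the repo's `2D_Stage` directory is on `PYTHONPATH`; the channel sizes and shapes below are illustrative, not values taken from this commit.

import torch
from tuneavideo.models.unet_blocks import DownBlock3D

# Video latents follow the (batch, channels, frames, height, width) layout
# expected by ResnetBlock3D / Downsample3D. All sizes here are hypothetical.
block = DownBlock3D(in_channels=320, out_channels=320, temb_channels=1280)
x = torch.randn(1, 320, 4, 32, 32)
temb = torch.randn(1, 1280)  # timestep embedding, projected inside each resnet
h, skips = block(x, temb)    # `skips` feeds the matching UpBlock3D via torch.cat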
2D_Stage/tuneavideo/models/unet_mv2d_blocks.py ADDED
@@ -0,0 +1,926 @@
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ from typing import Any, Dict, Optional, Tuple
+
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+ from torch import nn
+
+ from diffusers.utils import is_torch_version, logging
+ # from diffusers.models.attention import AdaGroupNorm
+ from diffusers.models.attention_processor import Attention, AttnAddedKVProcessor, AttnAddedKVProcessor2_0
+ from diffusers.models.dual_transformer_2d import DualTransformer2DModel
+ from diffusers.models.resnet import Downsample2D, FirDownsample2D, FirUpsample2D, KDownsample2D, KUpsample2D, ResnetBlock2D, Upsample2D
+ from tuneavideo.models.transformer_mv2d import TransformerMV2DModel
+
+ from diffusers.models.unet_2d_blocks import DownBlock2D, ResnetDownsampleBlock2D, AttnDownBlock2D, CrossAttnDownBlock2D, SimpleCrossAttnDownBlock2D, SkipDownBlock2D, AttnSkipDownBlock2D, DownEncoderBlock2D, AttnDownEncoderBlock2D, KDownBlock2D, KCrossAttnDownBlock2D
+ from diffusers.models.unet_2d_blocks import UpBlock2D, ResnetUpsampleBlock2D, CrossAttnUpBlock2D, SimpleCrossAttnUpBlock2D, AttnUpBlock2D, SkipUpBlock2D, AttnSkipUpBlock2D, UpDecoderBlock2D, AttnUpDecoderBlock2D, KUpBlock2D, KCrossAttnUpBlock2D
+
+
+ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+ def get_down_block(
+     down_block_type,
+     num_layers,
+     in_channels,
+     out_channels,
+     temb_channels,
+     add_downsample,
+     resnet_eps,
+     resnet_act_fn,
+     transformer_layers_per_block=1,
+     num_attention_heads=None,
+     resnet_groups=None,
+     cross_attention_dim=None,
+     downsample_padding=None,
+     dual_cross_attention=False,
+     use_linear_projection=False,
+     only_cross_attention=False,
+     upcast_attention=False,
+     resnet_time_scale_shift="default",
+     resnet_skip_time_act=False,
+     resnet_out_scale_factor=1.0,
+     cross_attention_norm=None,
+     attention_head_dim=None,
+     downsample_type=None,
+     num_views=1,
+     joint_attention: bool = False,
+     joint_attention_twice: bool = False,
+     multiview_attention: bool = True,
+     cross_domain_attention: bool = False
+ ):
+     # If attn head dim is not defined, we default it to the number of heads
+     if attention_head_dim is None:
+         logger.warn(
+             f"It is recommended to provide `attention_head_dim` when calling `get_down_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
+         )
+         attention_head_dim = num_attention_heads
+
+     down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
+     if down_block_type == "DownBlock2D":
+         return DownBlock2D(
+             num_layers=num_layers,
+             in_channels=in_channels,
+             out_channels=out_channels,
+             temb_channels=temb_channels,
+             add_downsample=add_downsample,
+             resnet_eps=resnet_eps,
+             resnet_act_fn=resnet_act_fn,
+             resnet_groups=resnet_groups,
+             downsample_padding=downsample_padding,
+             resnet_time_scale_shift=resnet_time_scale_shift,
+         )
+     elif down_block_type == "ResnetDownsampleBlock2D":
+         return ResnetDownsampleBlock2D(
+             num_layers=num_layers,
+             in_channels=in_channels,
+             out_channels=out_channels,
+             temb_channels=temb_channels,
+             add_downsample=add_downsample,
+             resnet_eps=resnet_eps,
+             resnet_act_fn=resnet_act_fn,
+             resnet_groups=resnet_groups,
+             resnet_time_scale_shift=resnet_time_scale_shift,
+             skip_time_act=resnet_skip_time_act,
+             output_scale_factor=resnet_out_scale_factor,
+         )
+     elif down_block_type == "AttnDownBlock2D":
+         if add_downsample is False:
+             downsample_type = None
+         else:
+             downsample_type = downsample_type or "conv"  # default to 'conv'
+         return AttnDownBlock2D(
+             num_layers=num_layers,
+             in_channels=in_channels,
+             out_channels=out_channels,
+             temb_channels=temb_channels,
+             resnet_eps=resnet_eps,
+             resnet_act_fn=resnet_act_fn,
+             resnet_groups=resnet_groups,
+             downsample_padding=downsample_padding,
+             attention_head_dim=attention_head_dim,
+             resnet_time_scale_shift=resnet_time_scale_shift,
+             downsample_type=downsample_type,
+         )
+     elif down_block_type == "CrossAttnDownBlock2D":
+         if cross_attention_dim is None:
+             raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock2D")
+         return CrossAttnDownBlock2D(
+             num_layers=num_layers,
+             transformer_layers_per_block=transformer_layers_per_block,
+             in_channels=in_channels,
+             out_channels=out_channels,
+             temb_channels=temb_channels,
+             add_downsample=add_downsample,
+             resnet_eps=resnet_eps,
+             resnet_act_fn=resnet_act_fn,
+             resnet_groups=resnet_groups,
+             downsample_padding=downsample_padding,
+             cross_attention_dim=cross_attention_dim,
+             num_attention_heads=num_attention_heads,
+             dual_cross_attention=dual_cross_attention,
+             use_linear_projection=use_linear_projection,
+             only_cross_attention=only_cross_attention,
+             upcast_attention=upcast_attention,
+             resnet_time_scale_shift=resnet_time_scale_shift,
+         )
+     # custom MV2D attention block
+     elif down_block_type == "CrossAttnDownBlockMV2D":
+         if cross_attention_dim is None:
+             raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlockMV2D")
+         return CrossAttnDownBlockMV2D(
+             num_layers=num_layers,
+             transformer_layers_per_block=transformer_layers_per_block,
+             in_channels=in_channels,
+             out_channels=out_channels,
+             temb_channels=temb_channels,
+             add_downsample=add_downsample,
+             resnet_eps=resnet_eps,
+             resnet_act_fn=resnet_act_fn,
+             resnet_groups=resnet_groups,
+             downsample_padding=downsample_padding,
+             cross_attention_dim=cross_attention_dim,
+             num_attention_heads=num_attention_heads,
+             dual_cross_attention=dual_cross_attention,
+             use_linear_projection=use_linear_projection,
+             only_cross_attention=only_cross_attention,
+             upcast_attention=upcast_attention,
+             resnet_time_scale_shift=resnet_time_scale_shift,
+             num_views=num_views,
+             joint_attention=joint_attention,
+             joint_attention_twice=joint_attention_twice,
+             multiview_attention=multiview_attention,
+             cross_domain_attention=cross_domain_attention
+         )
+     elif down_block_type == "SimpleCrossAttnDownBlock2D":
+         if cross_attention_dim is None:
+             raise ValueError("cross_attention_dim must be specified for SimpleCrossAttnDownBlock2D")
+         return SimpleCrossAttnDownBlock2D(
+             num_layers=num_layers,
+             in_channels=in_channels,
+             out_channels=out_channels,
+             temb_channels=temb_channels,
+             add_downsample=add_downsample,
+             resnet_eps=resnet_eps,
+             resnet_act_fn=resnet_act_fn,
+             resnet_groups=resnet_groups,
+             cross_attention_dim=cross_attention_dim,
+             attention_head_dim=attention_head_dim,
+             resnet_time_scale_shift=resnet_time_scale_shift,
+             skip_time_act=resnet_skip_time_act,
+             output_scale_factor=resnet_out_scale_factor,
+             only_cross_attention=only_cross_attention,
+             cross_attention_norm=cross_attention_norm,
+         )
+     elif down_block_type == "SkipDownBlock2D":
+         return SkipDownBlock2D(
+             num_layers=num_layers,
+             in_channels=in_channels,
+             out_channels=out_channels,
+             temb_channels=temb_channels,
+             add_downsample=add_downsample,
+             resnet_eps=resnet_eps,
+             resnet_act_fn=resnet_act_fn,
+             downsample_padding=downsample_padding,
+             resnet_time_scale_shift=resnet_time_scale_shift,
+         )
+     elif down_block_type == "AttnSkipDownBlock2D":
+         return AttnSkipDownBlock2D(
+             num_layers=num_layers,
+             in_channels=in_channels,
+             out_channels=out_channels,
+             temb_channels=temb_channels,
+             add_downsample=add_downsample,
+             resnet_eps=resnet_eps,
+             resnet_act_fn=resnet_act_fn,
+             attention_head_dim=attention_head_dim,
+             resnet_time_scale_shift=resnet_time_scale_shift,
+         )
+     elif down_block_type == "DownEncoderBlock2D":
+         return DownEncoderBlock2D(
+             num_layers=num_layers,
+             in_channels=in_channels,
+             out_channels=out_channels,
+             add_downsample=add_downsample,
+             resnet_eps=resnet_eps,
+             resnet_act_fn=resnet_act_fn,
+             resnet_groups=resnet_groups,
+             downsample_padding=downsample_padding,
+             resnet_time_scale_shift=resnet_time_scale_shift,
+         )
+     elif down_block_type == "AttnDownEncoderBlock2D":
+         return AttnDownEncoderBlock2D(
+             num_layers=num_layers,
+             in_channels=in_channels,
+             out_channels=out_channels,
+             add_downsample=add_downsample,
+             resnet_eps=resnet_eps,
+             resnet_act_fn=resnet_act_fn,
+             resnet_groups=resnet_groups,
+             downsample_padding=downsample_padding,
+             attention_head_dim=attention_head_dim,
+             resnet_time_scale_shift=resnet_time_scale_shift,
+         )
+     elif down_block_type == "KDownBlock2D":
+         return KDownBlock2D(
+             num_layers=num_layers,
+             in_channels=in_channels,
+             out_channels=out_channels,
+             temb_channels=temb_channels,
+             add_downsample=add_downsample,
+             resnet_eps=resnet_eps,
+             resnet_act_fn=resnet_act_fn,
+         )
+     elif down_block_type == "KCrossAttnDownBlock2D":
+         return KCrossAttnDownBlock2D(
+             num_layers=num_layers,
+             in_channels=in_channels,
+             out_channels=out_channels,
+             temb_channels=temb_channels,
+             add_downsample=add_downsample,
+             resnet_eps=resnet_eps,
+             resnet_act_fn=resnet_act_fn,
+             cross_attention_dim=cross_attention_dim,
+             attention_head_dim=attention_head_dim,
+             add_self_attention=True if not add_downsample else False,
+         )
+     raise ValueError(f"{down_block_type} does not exist.")
+
+
+ def get_up_block(
+     up_block_type,
+     num_layers,
+     in_channels,
+     out_channels,
+     prev_output_channel,
+     temb_channels,
+     add_upsample,
+     resnet_eps,
+     resnet_act_fn,
+     transformer_layers_per_block=1,
+     num_attention_heads=None,
+     resnet_groups=None,
+     cross_attention_dim=None,
+     dual_cross_attention=False,
+     use_linear_projection=False,
+     only_cross_attention=False,
+     upcast_attention=False,
+     resnet_time_scale_shift="default",
+     resnet_skip_time_act=False,
+     resnet_out_scale_factor=1.0,
+     cross_attention_norm=None,
+     attention_head_dim=None,
+     upsample_type=None,
+     num_views=1,
+     joint_attention: bool = False,
+     joint_attention_twice: bool = False,
+     multiview_attention: bool = True,
+     cross_domain_attention: bool = False
+ ):
+     # If attn head dim is not defined, we default it to the number of heads
+     if attention_head_dim is None:
+         logger.warn(
+             f"It is recommended to provide `attention_head_dim` when calling `get_up_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
+         )
+         attention_head_dim = num_attention_heads
+
+     up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
+     if up_block_type == "UpBlock2D":
+         return UpBlock2D(
+             num_layers=num_layers,
+             in_channels=in_channels,
+             out_channels=out_channels,
+             prev_output_channel=prev_output_channel,
+             temb_channels=temb_channels,
+             add_upsample=add_upsample,
+             resnet_eps=resnet_eps,
+             resnet_act_fn=resnet_act_fn,
+             resnet_groups=resnet_groups,
+             resnet_time_scale_shift=resnet_time_scale_shift,
+         )
+     elif up_block_type == "ResnetUpsampleBlock2D":
+         return ResnetUpsampleBlock2D(
+             num_layers=num_layers,
+             in_channels=in_channels,
+             out_channels=out_channels,
+             prev_output_channel=prev_output_channel,
+             temb_channels=temb_channels,
+             add_upsample=add_upsample,
+             resnet_eps=resnet_eps,
+             resnet_act_fn=resnet_act_fn,
+             resnet_groups=resnet_groups,
+             resnet_time_scale_shift=resnet_time_scale_shift,
+             skip_time_act=resnet_skip_time_act,
+             output_scale_factor=resnet_out_scale_factor,
+         )
+     elif up_block_type == "CrossAttnUpBlock2D":
+         if cross_attention_dim is None:
+             raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock2D")
+         return CrossAttnUpBlock2D(
+             num_layers=num_layers,
+             transformer_layers_per_block=transformer_layers_per_block,
+             in_channels=in_channels,
+             out_channels=out_channels,
+             prev_output_channel=prev_output_channel,
+             temb_channels=temb_channels,
+             add_upsample=add_upsample,
+             resnet_eps=resnet_eps,
+             resnet_act_fn=resnet_act_fn,
+             resnet_groups=resnet_groups,
+             cross_attention_dim=cross_attention_dim,
+             num_attention_heads=num_attention_heads,
+             dual_cross_attention=dual_cross_attention,
+             use_linear_projection=use_linear_projection,
+             only_cross_attention=only_cross_attention,
+             upcast_attention=upcast_attention,
+             resnet_time_scale_shift=resnet_time_scale_shift,
+         )
+     # custom MV2D attention block
+     elif up_block_type == "CrossAttnUpBlockMV2D":
+         if cross_attention_dim is None:
+             raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlockMV2D")
+         return CrossAttnUpBlockMV2D(
+             num_layers=num_layers,
+             transformer_layers_per_block=transformer_layers_per_block,
+             in_channels=in_channels,
+             out_channels=out_channels,
+             prev_output_channel=prev_output_channel,
+             temb_channels=temb_channels,
+             add_upsample=add_upsample,
+             resnet_eps=resnet_eps,
+             resnet_act_fn=resnet_act_fn,
+             resnet_groups=resnet_groups,
+             cross_attention_dim=cross_attention_dim,
+             num_attention_heads=num_attention_heads,
+             dual_cross_attention=dual_cross_attention,
+             use_linear_projection=use_linear_projection,
+             only_cross_attention=only_cross_attention,
+             upcast_attention=upcast_attention,
+             resnet_time_scale_shift=resnet_time_scale_shift,
+             num_views=num_views,
+             joint_attention=joint_attention,
+             joint_attention_twice=joint_attention_twice,
+             multiview_attention=multiview_attention,
+             cross_domain_attention=cross_domain_attention
+         )
+     elif up_block_type == "SimpleCrossAttnUpBlock2D":
+         if cross_attention_dim is None:
+             raise ValueError("cross_attention_dim must be specified for SimpleCrossAttnUpBlock2D")
+         return SimpleCrossAttnUpBlock2D(
+             num_layers=num_layers,
+             in_channels=in_channels,
+             out_channels=out_channels,
+             prev_output_channel=prev_output_channel,
+             temb_channels=temb_channels,
+             add_upsample=add_upsample,
+             resnet_eps=resnet_eps,
+             resnet_act_fn=resnet_act_fn,
+             resnet_groups=resnet_groups,
+             cross_attention_dim=cross_attention_dim,
+             attention_head_dim=attention_head_dim,
+             resnet_time_scale_shift=resnet_time_scale_shift,
+             skip_time_act=resnet_skip_time_act,
+             output_scale_factor=resnet_out_scale_factor,
+             only_cross_attention=only_cross_attention,
+             cross_attention_norm=cross_attention_norm,
+         )
+     elif up_block_type == "AttnUpBlock2D":
+         if add_upsample is False:
+             upsample_type = None
+         else:
+             upsample_type = upsample_type or "conv"  # default to 'conv'
+
+         return AttnUpBlock2D(
+             num_layers=num_layers,
+             in_channels=in_channels,
+             out_channels=out_channels,
+             prev_output_channel=prev_output_channel,
+             temb_channels=temb_channels,
+             resnet_eps=resnet_eps,
+             resnet_act_fn=resnet_act_fn,
+             resnet_groups=resnet_groups,
+             attention_head_dim=attention_head_dim,
+             resnet_time_scale_shift=resnet_time_scale_shift,
+             upsample_type=upsample_type,
+         )
+     elif up_block_type == "SkipUpBlock2D":
+         return SkipUpBlock2D(
+             num_layers=num_layers,
+             in_channels=in_channels,
+             out_channels=out_channels,
+             prev_output_channel=prev_output_channel,
+             temb_channels=temb_channels,
+             add_upsample=add_upsample,
+             resnet_eps=resnet_eps,
+             resnet_act_fn=resnet_act_fn,
+             resnet_time_scale_shift=resnet_time_scale_shift,
+         )
+     elif up_block_type == "AttnSkipUpBlock2D":
+         return AttnSkipUpBlock2D(
+             num_layers=num_layers,
+             in_channels=in_channels,
+             out_channels=out_channels,
+             prev_output_channel=prev_output_channel,
+             temb_channels=temb_channels,
+             add_upsample=add_upsample,
+             resnet_eps=resnet_eps,
+             resnet_act_fn=resnet_act_fn,
+             attention_head_dim=attention_head_dim,
+             resnet_time_scale_shift=resnet_time_scale_shift,
+         )
+     elif up_block_type == "UpDecoderBlock2D":
+         return UpDecoderBlock2D(
+             num_layers=num_layers,
+             in_channels=in_channels,
+             out_channels=out_channels,
+             add_upsample=add_upsample,
+             resnet_eps=resnet_eps,
+             resnet_act_fn=resnet_act_fn,
+             resnet_groups=resnet_groups,
+             resnet_time_scale_shift=resnet_time_scale_shift,
+             temb_channels=temb_channels,
+         )
+     elif up_block_type == "AttnUpDecoderBlock2D":
+         return AttnUpDecoderBlock2D(
+             num_layers=num_layers,
+             in_channels=in_channels,
+             out_channels=out_channels,
+             add_upsample=add_upsample,
+             resnet_eps=resnet_eps,
+             resnet_act_fn=resnet_act_fn,
+             resnet_groups=resnet_groups,
+             attention_head_dim=attention_head_dim,
+             resnet_time_scale_shift=resnet_time_scale_shift,
+             temb_channels=temb_channels,
+         )
+     elif up_block_type == "KUpBlock2D":
+         return KUpBlock2D(
+             num_layers=num_layers,
+             in_channels=in_channels,
+             out_channels=out_channels,
+             temb_channels=temb_channels,
+             add_upsample=add_upsample,
+             resnet_eps=resnet_eps,
+             resnet_act_fn=resnet_act_fn,
+         )
+     elif up_block_type == "KCrossAttnUpBlock2D":
+         return KCrossAttnUpBlock2D(
+             num_layers=num_layers,
+             in_channels=in_channels,
+             out_channels=out_channels,
+             temb_channels=temb_channels,
+             add_upsample=add_upsample,
+             resnet_eps=resnet_eps,
+             resnet_act_fn=resnet_act_fn,
+             cross_attention_dim=cross_attention_dim,
+             attention_head_dim=attention_head_dim,
+         )
+
+     raise ValueError(f"{up_block_type} does not exist.")
+
+
+ class UNetMidBlockMV2DCrossAttn(nn.Module):
+     def __init__(
+         self,
+         in_channels: int,
+         temb_channels: int,
+         dropout: float = 0.0,
+         num_layers: int = 1,
+         transformer_layers_per_block: int = 1,
+         resnet_eps: float = 1e-6,
+         resnet_time_scale_shift: str = "default",
+         resnet_act_fn: str = "swish",
+         resnet_groups: int = 32,
+         resnet_pre_norm: bool = True,
+         num_attention_heads=1,
+         output_scale_factor=1.0,
+         cross_attention_dim=1280,
+         dual_cross_attention=False,
+         use_linear_projection=False,
+         upcast_attention=False,
+         num_views: int = 1,
+         joint_attention: bool = False,
+         joint_attention_twice: bool = False,
+         multiview_attention: bool = True,
+         cross_domain_attention: bool = False
+     ):
+         super().__init__()
+
+         self.has_cross_attention = True
+         self.num_attention_heads = num_attention_heads
+         resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
+
+         # there is always at least one resnet
+         resnets = [
+             ResnetBlock2D(
+                 in_channels=in_channels,
+                 out_channels=in_channels,
+                 temb_channels=temb_channels,
+                 eps=resnet_eps,
+                 groups=resnet_groups,
+                 dropout=dropout,
+                 time_embedding_norm=resnet_time_scale_shift,
+                 non_linearity=resnet_act_fn,
+                 output_scale_factor=output_scale_factor,
+                 pre_norm=resnet_pre_norm,
+             )
+         ]
+         attentions = []
+
+         for _ in range(num_layers):
+             if not dual_cross_attention:
+                 attentions.append(
+                     TransformerMV2DModel(
+                         num_attention_heads,
+                         in_channels // num_attention_heads,
+                         in_channels=in_channels,
+                         num_layers=transformer_layers_per_block,
+                         cross_attention_dim=cross_attention_dim,
+                         norm_num_groups=resnet_groups,
+                         use_linear_projection=use_linear_projection,
+                         upcast_attention=upcast_attention,
+                         num_views=num_views,
+                         joint_attention=joint_attention,
+                         joint_attention_twice=joint_attention_twice,
+                         multiview_attention=multiview_attention,
+                         cross_domain_attention=cross_domain_attention
+                     )
+                 )
+             else:
+                 raise NotImplementedError
+             resnets.append(
+                 ResnetBlock2D(
+                     in_channels=in_channels,
+                     out_channels=in_channels,
+                     temb_channels=temb_channels,
+                     eps=resnet_eps,
+                     groups=resnet_groups,
+                     dropout=dropout,
+                     time_embedding_norm=resnet_time_scale_shift,
+                     non_linearity=resnet_act_fn,
+                     output_scale_factor=output_scale_factor,
+                     pre_norm=resnet_pre_norm,
+                 )
+             )
+
+         self.attentions = nn.ModuleList(attentions)
+         self.resnets = nn.ModuleList(resnets)
+
+     def forward(
+         self,
+         hidden_states: torch.FloatTensor,
+         temb: Optional[torch.FloatTensor] = None,
+         encoder_hidden_states: Optional[torch.FloatTensor] = None,
+         attention_mask: Optional[torch.FloatTensor] = None,
+         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+         encoder_attention_mask: Optional[torch.FloatTensor] = None,
+     ) -> torch.FloatTensor:
+         hidden_states = self.resnets[0](hidden_states, temb)
+         for attn, resnet in zip(self.attentions, self.resnets[1:]):
+             hidden_states = attn(
+                 hidden_states,
+                 encoder_hidden_states=encoder_hidden_states,
+                 cross_attention_kwargs=cross_attention_kwargs,
+                 attention_mask=attention_mask,
+                 encoder_attention_mask=encoder_attention_mask,
+                 return_dict=False,
+             )[0]
+             hidden_states = resnet(hidden_states, temb)
+
+         return hidden_states
+
+
+ class CrossAttnUpBlockMV2D(nn.Module):
+     def __init__(
+         self,
+         in_channels: int,
+         out_channels: int,
+         prev_output_channel: int,
+         temb_channels: int,
+         dropout: float = 0.0,
+         num_layers: int = 1,
+         transformer_layers_per_block: int = 1,
+         resnet_eps: float = 1e-6,
+         resnet_time_scale_shift: str = "default",
+         resnet_act_fn: str = "swish",
+         resnet_groups: int = 32,
+         resnet_pre_norm: bool = True,
+         num_attention_heads=1,
+         cross_attention_dim=1280,
+         output_scale_factor=1.0,
+         add_upsample=True,
+         dual_cross_attention=False,
+         use_linear_projection=False,
+         only_cross_attention=False,
+         upcast_attention=False,
+         num_views: int = 1,
+         joint_attention: bool = False,
+         joint_attention_twice: bool = False,
+         multiview_attention: bool = True,
+         cross_domain_attention: bool = False
+     ):
+         super().__init__()
+         resnets = []
+         attentions = []
+
+         self.has_cross_attention = True
+         self.num_attention_heads = num_attention_heads
+
+         for i in range(num_layers):
+             res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+             resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+             resnets.append(
+                 ResnetBlock2D(
+                     in_channels=resnet_in_channels + res_skip_channels,
+                     out_channels=out_channels,
+                     temb_channels=temb_channels,
+                     eps=resnet_eps,
+                     groups=resnet_groups,
+                     dropout=dropout,
+                     time_embedding_norm=resnet_time_scale_shift,
+                     non_linearity=resnet_act_fn,
+                     output_scale_factor=output_scale_factor,
+                     pre_norm=resnet_pre_norm,
+                 )
+             )
+             if not dual_cross_attention:
+                 attentions.append(
+                     TransformerMV2DModel(
+                         num_attention_heads,
+                         out_channels // num_attention_heads,
+                         in_channels=out_channels,
+                         num_layers=transformer_layers_per_block,
+                         cross_attention_dim=cross_attention_dim,
+                         norm_num_groups=resnet_groups,
+                         use_linear_projection=use_linear_projection,
+                         only_cross_attention=only_cross_attention,
+                         upcast_attention=upcast_attention,
+                         num_views=num_views,
+                         joint_attention=joint_attention,
+                         joint_attention_twice=joint_attention_twice,
+                         multiview_attention=multiview_attention,
+                         cross_domain_attention=cross_domain_attention
+                     )
+                 )
+             else:
+                 raise NotImplementedError
+         self.attentions = nn.ModuleList(attentions)
+         self.resnets = nn.ModuleList(resnets)
+
+         if add_upsample:
+             self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
+         else:
+             self.upsamplers = None
+         if num_views == 4:
+             self.gradient_checkpointing = False
+         else:
+             self.gradient_checkpointing = False
+
+     def forward(
+         self,
+         hidden_states: torch.FloatTensor,
+         res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
+         temb: Optional[torch.FloatTensor] = None,
+         encoder_hidden_states: Optional[torch.FloatTensor] = None,
+         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+         upsample_size: Optional[int] = None,
+         attention_mask: Optional[torch.FloatTensor] = None,
+         encoder_attention_mask: Optional[torch.FloatTensor] = None,
+     ):
+         for resnet, attn in zip(self.resnets, self.attentions):
+             # pop res hidden states
+             res_hidden_states = res_hidden_states_tuple[-1]
+             res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+             if self.training and self.gradient_checkpointing:
+
+                 def create_custom_forward(module, return_dict=None):
+                     def custom_forward(*inputs):
+                         if return_dict is not None:
+                             return module(*inputs, return_dict=return_dict)
+                         else:
+                             return module(*inputs)
+
+                     return custom_forward
+
+                 ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                 hidden_states = torch.utils.checkpoint.checkpoint(
+                     create_custom_forward(resnet),
+                     hidden_states,
+                     temb,
+                     **ckpt_kwargs,
+                 )
+                 hidden_states = torch.utils.checkpoint.checkpoint(
+                     create_custom_forward(attn, return_dict=False),
+                     hidden_states,
+                     encoder_hidden_states,
+                     None,  # timestep
+                     None,  # class_labels
+                     cross_attention_kwargs,
+                     attention_mask,
+                     encoder_attention_mask,
+                     **ckpt_kwargs,
+                 )[0]
+                 # hidden_states = attn(
+                 #     hidden_states,
+                 #     encoder_hidden_states=encoder_hidden_states,
+                 #     cross_attention_kwargs=cross_attention_kwargs,
+                 #     attention_mask=attention_mask,
+                 #     encoder_attention_mask=encoder_attention_mask,
+                 #     return_dict=False,
+                 # )[0]
+             else:
+                 hidden_states = resnet(hidden_states, temb)
+                 hidden_states = attn(
+                     hidden_states,
+                     encoder_hidden_states=encoder_hidden_states,
+                     cross_attention_kwargs=cross_attention_kwargs,
+                     attention_mask=attention_mask,
+                     encoder_attention_mask=encoder_attention_mask,
+                     return_dict=False,
+                 )[0]
+
+         if self.upsamplers is not None:
+             for upsampler in self.upsamplers:
+                 hidden_states = upsampler(hidden_states, upsample_size)
+
+         return hidden_states
+
+
+ class CrossAttnDownBlockMV2D(nn.Module):
+     def __init__(
+         self,
+         in_channels: int,
+         out_channels: int,
+         temb_channels: int,
+         dropout: float = 0.0,
+         num_layers: int = 1,
+         transformer_layers_per_block: int = 1,
+         resnet_eps: float = 1e-6,
+         resnet_time_scale_shift: str = "default",
+         resnet_act_fn: str = "swish",
+         resnet_groups: int = 32,
+         resnet_pre_norm: bool = True,
+         num_attention_heads=1,
+         cross_attention_dim=1280,
+         output_scale_factor=1.0,
+         downsample_padding=1,
+         add_downsample=True,
+         dual_cross_attention=False,
+         use_linear_projection=False,
+         only_cross_attention=False,
+         upcast_attention=False,
+         num_views: int = 1,
+         joint_attention: bool = False,
+         joint_attention_twice: bool = False,
+         multiview_attention: bool = True,
+         cross_domain_attention: bool = False
+     ):
+         super().__init__()
+         resnets = []
+         attentions = []
+
+         self.has_cross_attention = True
+         self.num_attention_heads = num_attention_heads
+
+         for i in range(num_layers):
+             in_channels = in_channels if i == 0 else out_channels
+             resnets.append(
+                 ResnetBlock2D(
+                     in_channels=in_channels,
+                     out_channels=out_channels,
+                     temb_channels=temb_channels,
+                     eps=resnet_eps,
+                     groups=resnet_groups,
+                     dropout=dropout,
+                     time_embedding_norm=resnet_time_scale_shift,
+                     non_linearity=resnet_act_fn,
+                     output_scale_factor=output_scale_factor,
+                     pre_norm=resnet_pre_norm,
+                 )
+             )
+             if not dual_cross_attention:
+                 attentions.append(
+                     TransformerMV2DModel(
+                         num_attention_heads,
+                         out_channels // num_attention_heads,
+                         in_channels=out_channels,
+                         num_layers=transformer_layers_per_block,
+                         cross_attention_dim=cross_attention_dim,
+                         norm_num_groups=resnet_groups,
+                         use_linear_projection=use_linear_projection,
+                         only_cross_attention=only_cross_attention,
+                         upcast_attention=upcast_attention,
+                         num_views=num_views,
+                         joint_attention=joint_attention,
+                         joint_attention_twice=joint_attention_twice,
+                         multiview_attention=multiview_attention,
+                         cross_domain_attention=cross_domain_attention
+                     )
+                 )
+             else:
+                 raise NotImplementedError
+         self.attentions = nn.ModuleList(attentions)
+         self.resnets = nn.ModuleList(resnets)
+
+         if add_downsample:
+             self.downsamplers = nn.ModuleList(
+                 [
+                     Downsample2D(
+                         out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
+                     )
+                 ]
+             )
+         else:
+             self.downsamplers = None
+         if num_views == 4:
+             self.gradient_checkpointing = False
+         else:
+             self.gradient_checkpointing = False
+
+     def forward(
+         self,
+         hidden_states: torch.FloatTensor,
+         temb: Optional[torch.FloatTensor] = None,
+         encoder_hidden_states: Optional[torch.FloatTensor] = None,
+         attention_mask: Optional[torch.FloatTensor] = None,
+         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+         encoder_attention_mask: Optional[torch.FloatTensor] = None,
+         additional_residuals=None,
+     ):
+         output_states = ()
+
+         blocks = list(zip(self.resnets, self.attentions))
+
+         for i, (resnet, attn) in enumerate(blocks):
+             if self.training and self.gradient_checkpointing:
+
+                 def create_custom_forward(module, return_dict=None):
+                     def custom_forward(*inputs):
+                         if return_dict is not None:
+                             return module(*inputs, return_dict=return_dict)
+                         else:
+                             return module(*inputs)
+
+                     return custom_forward
+
+                 ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                 hidden_states = torch.utils.checkpoint.checkpoint(
+                     create_custom_forward(resnet),
+                     hidden_states,
+                     temb,
+                     **ckpt_kwargs,
+                 )
+                 hidden_states = torch.utils.checkpoint.checkpoint(
+                     create_custom_forward(attn, return_dict=False),
+                     hidden_states,
+                     encoder_hidden_states,
+                     None,  # timestep
+                     None,  # class_labels
+                     cross_attention_kwargs,
+                     attention_mask,
+                     encoder_attention_mask,
+                     **ckpt_kwargs,
+                 )[0]
+             else:
+                 # import ipdb
+                 # ipdb.set_trace()
+                 hidden_states = resnet(hidden_states, temb)
+                 hidden_states = attn(
+                     hidden_states,
+                     encoder_hidden_states=encoder_hidden_states,
+                     cross_attention_kwargs=cross_attention_kwargs,
+                     attention_mask=attention_mask,
+                     encoder_attention_mask=encoder_attention_mask,
+                     return_dict=False,
+                 )[0]
+
+             # apply additional residuals to the output of the last pair of resnet and attention blocks
+             if i == len(blocks) - 1 and additional_residuals is not None:
+                 hidden_states = hidden_states + additional_residuals
+
+             output_states = output_states + (hidden_states,)
+
+         if self.downsamplers is not None:
+             for downsampler in self.downsamplers:
+                 hidden_states = downsampler(hidden_states)
+
+             output_states = output_states + (hidden_states,)
+
+         return hidden_states, output_states
+
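A hedged sketch (not part of the committed files) of how the factory functions above are meant to be called: block-type strings ending in "MV2D" route to the multiview classes defined in this file, while every other name falls through to the stock diffusers blocks. The concrete sizes below are illustrative only, not values read from this commit.

from tuneavideo.models.unet_mv2d_blocks import get_down_block

# Builds a CrossAttnDownBlockMV2D whose TransformerMV2DModel layers attend
# jointly across views. All numeric arguments here are hypothetical.
down_block = get_down_block(
    "CrossAttnDownBlockMV2D",
    num_layers=2,
    in_channels=320,
    out_channels=320,
    temb_channels=1280,
    add_downsample=True,
    resnet_eps=1e-5,
    resnet_act_fn="silu",
    resnet_groups=32,
    downsample_padding=1,
    cross_attention_dim=1024,
    num_attention_heads=8,
    attention_head_dim=40,   # passed explicitly to avoid the logger.warn fallback
    num_views=4,             # the MV2D transformer attends across all four views
    multiview_attention=True,
)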
2D_Stage/tuneavideo/models/unet_mv2d_condition.py ADDED
@@ -0,0 +1,1509 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from dataclasses import dataclass
15
+ from typing import Any, Dict, List, Optional, Tuple, Union
16
+ import os
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.utils.checkpoint
21
+ from einops import rearrange
22
+
23
+
24
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
25
+ from diffusers.loaders import UNet2DConditionLoadersMixin
26
+ from diffusers.utils import BaseOutput, logging
27
+ from diffusers.models.activations import get_activation
28
+ from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor
29
+ from diffusers.models.embeddings import (
30
+ GaussianFourierProjection,
31
+ ImageHintTimeEmbedding,
32
+ ImageProjection,
33
+ ImageTimeEmbedding,
34
+ TextImageProjection,
35
+ TextImageTimeEmbedding,
36
+ TextTimeEmbedding,
37
+ TimestepEmbedding,
38
+ Timesteps,
39
+ )
40
+ from diffusers.models.modeling_utils import ModelMixin, load_state_dict, _load_state_dict_into_model
41
+ from diffusers.models.unet_2d_blocks import (
42
+ CrossAttnDownBlock2D,
43
+ CrossAttnUpBlock2D,
44
+ DownBlock2D,
45
+ UNetMidBlock2DCrossAttn,
46
+ UNetMidBlock2DSimpleCrossAttn,
47
+ UpBlock2D,
48
+ )
49
+ from diffusers.utils import (
50
+ CONFIG_NAME,
51
+ DIFFUSERS_CACHE,
52
+ FLAX_WEIGHTS_NAME,
53
+ HF_HUB_OFFLINE,
54
+ SAFETENSORS_WEIGHTS_NAME,
55
+ WEIGHTS_NAME,
56
+ _add_variant,
57
+ _get_model_file,
58
+ deprecate,
59
+ is_accelerate_available,
60
+ is_torch_version,
61
+ logging,
62
+ )
63
+ from diffusers import __version__
64
+ from tuneavideo.models.unet_mv2d_blocks import (
65
+ CrossAttnDownBlockMV2D,
66
+ CrossAttnUpBlockMV2D,
67
+ UNetMidBlockMV2DCrossAttn,
68
+ get_down_block,
69
+ get_up_block,
70
+ )
71
+ from diffusers.models.attention_processor import Attention, AttnProcessor
72
+ from diffusers.utils.import_utils import is_xformers_available
73
+ from tuneavideo.models.transformer_mv2d import XFormersMVAttnProcessor, MVAttnProcessor
74
+ from tuneavideo.models.refunet import ReferenceOnlyAttnProc
75
+
76
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
77
+
78
+
79
+ @dataclass
80
+ class UNetMV2DConditionOutput(BaseOutput):
81
+ """
82
+ The output of [`UNet2DConditionModel`].
83
+
84
+ Args:
85
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
86
+ The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
87
+ """
88
+
89
+ sample: torch.FloatTensor = None
90
+
91
+ class UNetMV2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
92
+ r"""
93
+ A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
94
+ shaped output.
95
+
96
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
97
+ for all models (such as downloading or saving).
98
+
99
+ Parameters:
100
+ sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
101
+ Height and width of input/output sample.
102
+ in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
103
+ out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
104
+ center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
105
+ flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
106
+ Whether to flip the sin to cos in the time embedding.
107
+ freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
108
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
109
+ The tuple of downsample blocks to use.
110
+ mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
111
+ Block type for middle of UNet, it can be either `UNetMidBlock2DCrossAttn` or
112
+ `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
113
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
114
+ The tuple of upsample blocks to use.
115
+ only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`):
116
+ Whether to include self-attention in the basic transformer blocks, see
117
+ [`~models.attention.BasicTransformerBlock`].
118
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
119
+ The tuple of output channels for each block.
120
+ layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
121
+ downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
122
+ mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
123
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
124
+ norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
125
+            If `None`, normalization and activation layers are skipped in post-processing.
+        norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
+        cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
+            The dimension of the cross attention features.
+        transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
+            The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
+            [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
+            [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
+        encoder_hid_dim (`int`, *optional*, defaults to `None`):
+            If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
+            dimension to `cross_attention_dim`.
+        encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
+            If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
+            embeddings of dimension `cross_attention_dim` according to `encoder_hid_dim_type`.
+        attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
+        num_attention_heads (`int`, *optional*):
+            The number of attention heads. If not defined, defaults to `attention_head_dim`.
+        resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
+            for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
+        class_embed_type (`str`, *optional*, defaults to `None`):
+            The type of class embedding to use, which is ultimately summed with the time embeddings. Choose from
+            `None`, `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
+        addition_embed_type (`str`, *optional*, defaults to `None`):
+            Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
+            `"text"`. `"text"` will use the `TextTimeEmbedding` layer.
+        addition_time_embed_dim (`int`, *optional*, defaults to `None`):
+            Dimension for the timestep embeddings.
+        num_class_embeds (`int`, *optional*, defaults to `None`):
+            Input dimension of the learnable embedding matrix to be projected to `time_embed_dim` when performing
+            class conditioning with `class_embed_type` equal to `None`.
+        time_embedding_type (`str`, *optional*, defaults to `positional`):
+            The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
+        time_embedding_dim (`int`, *optional*, defaults to `None`):
+            An optional override for the dimension of the projected time embedding.
+        time_embedding_act_fn (`str`, *optional*, defaults to `None`):
+            Optional activation function to use only once on the time embeddings before they are passed to the rest
+            of the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
+        timestep_post_act (`str`, *optional*, defaults to `None`):
+            The second activation function to use in the timestep embedding. Choose from `silu`, `mish` and `gelu`.
+        time_cond_proj_dim (`int`, *optional*, defaults to `None`):
+            The dimension of the `cond_proj` layer in the timestep embedding.
+        conv_in_kernel (`int`, *optional*, defaults to `3`): The kernel size of the `conv_in` layer.
+        conv_out_kernel (`int`, *optional*, defaults to `3`): The kernel size of the `conv_out` layer.
+        projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
+            `class_embed_type="projection"`. Required when `class_embed_type="projection"`.
+        class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
+            embeddings with the class embeddings.
+        mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
+            Whether to use cross attention with the mid block when using `UNetMidBlock2DSimpleCrossAttn`. If
+            `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the
+            `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Defaults to
+            `False` otherwise.
+    """
+
+    _supports_gradient_checkpointing = True
+
+    @register_to_config
+    def __init__(
+        self,
+        sample_size: Optional[int] = None,
+        in_channels: int = 4,
+        out_channels: int = 4,
+        center_input_sample: bool = False,
+        flip_sin_to_cos: bool = True,
+        freq_shift: int = 0,
+        down_block_types: Tuple[str] = (
+            "CrossAttnDownBlockMV2D",
+            "CrossAttnDownBlockMV2D",
+            "CrossAttnDownBlockMV2D",
+            "DownBlock2D",
+        ),
+        mid_block_type: Optional[str] = "UNetMidBlockMV2DCrossAttn",
+        up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlockMV2D", "CrossAttnUpBlockMV2D", "CrossAttnUpBlockMV2D"),
+        only_cross_attention: Union[bool, Tuple[bool]] = False,
+        block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
+        layers_per_block: Union[int, Tuple[int]] = 2,
+        downsample_padding: int = 1,
+        mid_block_scale_factor: float = 1,
+        act_fn: str = "silu",
+        norm_num_groups: Optional[int] = 32,
+        norm_eps: float = 1e-5,
+        cross_attention_dim: Union[int, Tuple[int]] = 1280,
+        transformer_layers_per_block: Union[int, Tuple[int]] = 1,
+        encoder_hid_dim: Optional[int] = None,
+        encoder_hid_dim_type: Optional[str] = None,
+        attention_head_dim: Union[int, Tuple[int]] = 8,
+        num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
+        dual_cross_attention: bool = False,
+        use_linear_projection: bool = False,
+        class_embed_type: Optional[str] = None,
+        addition_embed_type: Optional[str] = None,
+        addition_time_embed_dim: Optional[int] = None,
+        num_class_embeds: Optional[int] = None,
+        upcast_attention: bool = False,
+        resnet_time_scale_shift: str = "default",
+        resnet_skip_time_act: bool = False,
+        resnet_out_scale_factor: float = 1.0,
+        time_embedding_type: str = "positional",
+        time_embedding_dim: Optional[int] = None,
+        time_embedding_act_fn: Optional[str] = None,
+        timestep_post_act: Optional[str] = None,
+        time_cond_proj_dim: Optional[int] = None,
+        conv_in_kernel: int = 3,
+        conv_out_kernel: int = 3,
+        projection_class_embeddings_input_dim: Optional[int] = None,
+        class_embeddings_concat: bool = False,
+        mid_block_only_cross_attention: Optional[bool] = None,
+        cross_attention_norm: Optional[str] = None,
+        addition_embed_type_num_heads: int = 64,
+        num_views: int = 1,
+        joint_attention: bool = False,
+        joint_attention_twice: bool = False,
+        multiview_attention: bool = True,
+        cross_domain_attention: bool = False,
+        camera_input_dim: int = 12,
+        camera_hidden_dim: int = 320,
+        camera_output_dim: int = 1280,
+    ):
+        super().__init__()
+
+        self.sample_size = sample_size
+
+        if num_attention_heads is not None:
+            raise ValueError(
+                "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
+            )
+
+        # If `num_attention_heads` is not defined (which is the case for most models)
+        # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
+        # The reason for this behavior is to correct for incorrectly named variables that were introduced
+        # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
+        # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
+        # which is why we correct for the naming here.
+        num_attention_heads = num_attention_heads or attention_head_dim
+
+        # Check inputs
+        if len(down_block_types) != len(up_block_types):
+            raise ValueError(
+                f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
+            )
+
+        if len(block_out_channels) != len(down_block_types):
+            raise ValueError(
+                f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
+            )
+
+        if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
+            raise ValueError(
+                f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
+            )
+
+        if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
+            raise ValueError(
+                f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
+            )
+
+        if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
+            raise ValueError(
+                f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
+            )
+
+        if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
+            raise ValueError(
+                f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
+            )
+
+        if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
+            raise ValueError(
+                f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
+            )
+
+        # input
+        conv_in_padding = (conv_in_kernel - 1) // 2
+        self.conv_in = nn.Conv2d(
+            in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
+        )
+
+        # time
+        if time_embedding_type == "fourier":
+            time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
+            if time_embed_dim % 2 != 0:
+                raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
+            self.time_proj = GaussianFourierProjection(
+                time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
+            )
+            timestep_input_dim = time_embed_dim
+        elif time_embedding_type == "positional":
+            time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
+
+            self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
+            timestep_input_dim = block_out_channels[0]
+        else:
+            raise ValueError(
+                f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
+            )
+
+        self.time_embedding = TimestepEmbedding(
+            timestep_input_dim,
+            time_embed_dim,
+            act_fn=act_fn,
+            post_act_fn=timestep_post_act,
+            cond_proj_dim=time_cond_proj_dim,
+        )
+
+        if encoder_hid_dim_type is None and encoder_hid_dim is not None:
+            encoder_hid_dim_type = "text_proj"
+            self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
+            logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
+
+        if encoder_hid_dim is None and encoder_hid_dim_type is not None:
+            raise ValueError(
+                f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
+            )
+
+        if encoder_hid_dim_type == "text_proj":
+            self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
+        elif encoder_hid_dim_type == "text_image_proj":
+            # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
+            # they are set to `cross_attention_dim` here as this is exactly the required dimension for the
+            # currently only use case when `encoder_hid_dim_type == "text_image_proj"` (Kandinsky 2.1)
+            self.encoder_hid_proj = TextImageProjection(
+                text_embed_dim=encoder_hid_dim,
+                image_embed_dim=cross_attention_dim,
+                cross_attention_dim=cross_attention_dim,
+            )
+        elif encoder_hid_dim_type == "image_proj":
+            # Kandinsky 2.2
+            self.encoder_hid_proj = ImageProjection(
+                image_embed_dim=encoder_hid_dim,
+                cross_attention_dim=cross_attention_dim,
+            )
+        elif encoder_hid_dim_type is not None:
+            raise ValueError(
+                f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj', 'text_image_proj' or 'image_proj'."
+            )
+        else:
+            self.encoder_hid_proj = None
+
+        # class embedding
+        if class_embed_type is None and num_class_embeds is not None:
+            self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
+        elif class_embed_type == "timestep":
+            self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
+        elif class_embed_type == "identity":
+            self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
+        elif class_embed_type == "projection":
+            if projection_class_embeddings_input_dim is None:
+                raise ValueError(
+                    "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
+                )
+            # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
+            # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
+            # 2. it projects from an arbitrary input dimension.
+            #
+            # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
+            # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
+            # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
+            self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
+        elif class_embed_type == "simple_projection":
+            if projection_class_embeddings_input_dim is None:
+                raise ValueError(
+                    "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
+                )
+            self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
+        else:
+            self.class_embedding = None
+
+        if addition_embed_type == "text":
+            if encoder_hid_dim is not None:
+                text_time_embedding_from_dim = encoder_hid_dim
+            else:
+                text_time_embedding_from_dim = cross_attention_dim
+
+            self.add_embedding = TextTimeEmbedding(
+                text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
+            )
+        elif addition_embed_type == "text_image":
+            # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
+            # they are set to `cross_attention_dim` here as this is exactly the required dimension for the
+            # currently only use case when `addition_embed_type == "text_image"` (Kandinsky 2.1)
+            self.add_embedding = TextImageTimeEmbedding(
+                text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
+            )
+        elif addition_embed_type == "text_time":
+            self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
+            self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
+        elif addition_embed_type == "image":
+            # Kandinsky 2.2
+            self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
+        elif addition_embed_type == "image_hint":
+            # Kandinsky 2.2 ControlNet
+            self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
+        elif addition_embed_type is not None:
+            raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text', 'text_image', 'text_time', 'image' or 'image_hint'.")
+
+        if time_embedding_act_fn is None:
+            self.time_embed_act = None
+        else:
+            self.time_embed_act = get_activation(time_embedding_act_fn)
+
+        self.camera_embedding = nn.Sequential(
+            nn.Linear(camera_input_dim, time_embed_dim),
+            nn.SiLU(),
+            nn.Linear(time_embed_dim, time_embed_dim),
+        )
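+        # Note: `camera_hidden_dim` and `camera_output_dim` are accepted by the config but are not
+        # used by this MLP, which maps the raw per-view camera vector directly into the
+        # time-embedding space so it can be summed with the timestep embedding in `forward`.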
+
+        self.down_blocks = nn.ModuleList([])
+        self.up_blocks = nn.ModuleList([])
+
+        if isinstance(only_cross_attention, bool):
+            if mid_block_only_cross_attention is None:
+                mid_block_only_cross_attention = only_cross_attention
+
+            only_cross_attention = [only_cross_attention] * len(down_block_types)
+
+        if mid_block_only_cross_attention is None:
+            mid_block_only_cross_attention = False
+
+        if isinstance(num_attention_heads, int):
+            num_attention_heads = (num_attention_heads,) * len(down_block_types)
+
+        if isinstance(attention_head_dim, int):
+            attention_head_dim = (attention_head_dim,) * len(down_block_types)
+
+        if isinstance(cross_attention_dim, int):
+            cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
+
+        if isinstance(layers_per_block, int):
+            layers_per_block = [layers_per_block] * len(down_block_types)
+
+        if isinstance(transformer_layers_per_block, int):
+            transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
+
+        if class_embeddings_concat:
+            # The time embeddings are concatenated with the class embeddings. The dimension of the
+            # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
+            # regular time embeddings
+            blocks_time_embed_dim = time_embed_dim * 2
+        else:
+            blocks_time_embed_dim = time_embed_dim
+
+        # down
+        output_channel = block_out_channels[0]
+        for i, down_block_type in enumerate(down_block_types):
+            input_channel = output_channel
+            output_channel = block_out_channels[i]
+            is_final_block = i == len(block_out_channels) - 1
+
+            down_block = get_down_block(
+                down_block_type,
+                num_layers=layers_per_block[i],
+                transformer_layers_per_block=transformer_layers_per_block[i],
+                in_channels=input_channel,
+                out_channels=output_channel,
+                temb_channels=blocks_time_embed_dim,
+                add_downsample=not is_final_block,
+                resnet_eps=norm_eps,
+                resnet_act_fn=act_fn,
+                resnet_groups=norm_num_groups,
+                cross_attention_dim=cross_attention_dim[i],
+                num_attention_heads=num_attention_heads[i],
+                downsample_padding=downsample_padding,
+                dual_cross_attention=dual_cross_attention,
+                use_linear_projection=use_linear_projection,
+                only_cross_attention=only_cross_attention[i],
+                upcast_attention=upcast_attention,
+                resnet_time_scale_shift=resnet_time_scale_shift,
+                resnet_skip_time_act=resnet_skip_time_act,
+                resnet_out_scale_factor=resnet_out_scale_factor,
+                cross_attention_norm=cross_attention_norm,
+                attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
+                num_views=num_views,
+                joint_attention=joint_attention,
+                joint_attention_twice=joint_attention_twice,
+                multiview_attention=multiview_attention,
+                cross_domain_attention=cross_domain_attention,
+            )
+            self.down_blocks.append(down_block)
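+        # Each CrossAttnDownBlockMV2D consumes the extra multi-view kwargs (num_views plus the
+        # joint / multiview / cross-domain attention flags); the plain DownBlock2D is expected
+        # to leave them unused.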
+
+        # mid
+        if mid_block_type == "UNetMidBlock2DCrossAttn":
+            self.mid_block = UNetMidBlock2DCrossAttn(
+                transformer_layers_per_block=transformer_layers_per_block[-1],
+                in_channels=block_out_channels[-1],
+                temb_channels=blocks_time_embed_dim,
+                resnet_eps=norm_eps,
+                resnet_act_fn=act_fn,
+                output_scale_factor=mid_block_scale_factor,
+                resnet_time_scale_shift=resnet_time_scale_shift,
+                cross_attention_dim=cross_attention_dim[-1],
+                num_attention_heads=num_attention_heads[-1],
+                resnet_groups=norm_num_groups,
+                dual_cross_attention=dual_cross_attention,
+                use_linear_projection=use_linear_projection,
+                upcast_attention=upcast_attention,
+            )
+        # custom MV2D attention block
+        elif mid_block_type == "UNetMidBlockMV2DCrossAttn":
+            self.mid_block = UNetMidBlockMV2DCrossAttn(
+                transformer_layers_per_block=transformer_layers_per_block[-1],
+                in_channels=block_out_channels[-1],
+                temb_channels=blocks_time_embed_dim,
+                resnet_eps=norm_eps,
+                resnet_act_fn=act_fn,
+                output_scale_factor=mid_block_scale_factor,
+                resnet_time_scale_shift=resnet_time_scale_shift,
+                cross_attention_dim=cross_attention_dim[-1],
+                num_attention_heads=num_attention_heads[-1],
+                resnet_groups=norm_num_groups,
+                dual_cross_attention=dual_cross_attention,
+                use_linear_projection=use_linear_projection,
+                upcast_attention=upcast_attention,
+                num_views=num_views,
+                joint_attention=joint_attention,
+                joint_attention_twice=joint_attention_twice,
+                multiview_attention=multiview_attention,
+                cross_domain_attention=cross_domain_attention,
+            )
+        elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
+            self.mid_block = UNetMidBlock2DSimpleCrossAttn(
+                in_channels=block_out_channels[-1],
+                temb_channels=blocks_time_embed_dim,
+                resnet_eps=norm_eps,
+                resnet_act_fn=act_fn,
+                output_scale_factor=mid_block_scale_factor,
+                cross_attention_dim=cross_attention_dim[-1],
+                attention_head_dim=attention_head_dim[-1],
+                resnet_groups=norm_num_groups,
+                resnet_time_scale_shift=resnet_time_scale_shift,
+                skip_time_act=resnet_skip_time_act,
+                only_cross_attention=mid_block_only_cross_attention,
+                cross_attention_norm=cross_attention_norm,
+            )
+        elif mid_block_type is None:
+            self.mid_block = None
+        else:
+            raise ValueError(f"unknown mid_block_type : {mid_block_type}")
+
+        # count how many layers upsample the images
+        self.num_upsamplers = 0
+
+        # up
+        reversed_block_out_channels = list(reversed(block_out_channels))
+        reversed_num_attention_heads = list(reversed(num_attention_heads))
+        reversed_layers_per_block = list(reversed(layers_per_block))
+        reversed_cross_attention_dim = list(reversed(cross_attention_dim))
+        reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
+        only_cross_attention = list(reversed(only_cross_attention))
+
+        output_channel = reversed_block_out_channels[0]
+        for i, up_block_type in enumerate(up_block_types):
+            is_final_block = i == len(block_out_channels) - 1
+
+            prev_output_channel = output_channel
+            output_channel = reversed_block_out_channels[i]
+            input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
+
+            # add upsample block for all BUT final layer
+            if not is_final_block:
+                add_upsample = True
+                self.num_upsamplers += 1
+            else:
+                add_upsample = False
+
+            up_block = get_up_block(
+                up_block_type,
+                num_layers=reversed_layers_per_block[i] + 1,
+                transformer_layers_per_block=reversed_transformer_layers_per_block[i],
+                in_channels=input_channel,
+                out_channels=output_channel,
+                prev_output_channel=prev_output_channel,
+                temb_channels=blocks_time_embed_dim,
+                add_upsample=add_upsample,
+                resnet_eps=norm_eps,
+                resnet_act_fn=act_fn,
+                resnet_groups=norm_num_groups,
+                cross_attention_dim=reversed_cross_attention_dim[i],
+                num_attention_heads=reversed_num_attention_heads[i],
+                dual_cross_attention=dual_cross_attention,
+                use_linear_projection=use_linear_projection,
+                only_cross_attention=only_cross_attention[i],
+                upcast_attention=upcast_attention,
+                resnet_time_scale_shift=resnet_time_scale_shift,
+                resnet_skip_time_act=resnet_skip_time_act,
+                resnet_out_scale_factor=resnet_out_scale_factor,
+                cross_attention_norm=cross_attention_norm,
+                attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
+                num_views=num_views,
+                joint_attention=joint_attention,
+                joint_attention_twice=joint_attention_twice,
+                multiview_attention=multiview_attention,
+                cross_domain_attention=cross_domain_attention,
+            )
+            self.up_blocks.append(up_block)
+            prev_output_channel = output_channel
+
+        # out
+        if norm_num_groups is not None:
+            self.conv_norm_out = nn.GroupNorm(
+                num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
+            )
+
+            self.conv_act = get_activation(act_fn)
+
+        else:
+            self.conv_norm_out = None
+            self.conv_act = None
+
+        conv_out_padding = (conv_out_kernel - 1) // 2
+        self.conv_out = nn.Conv2d(
+            block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
+        )
+
+    @property
+    def attn_processors(self) -> Dict[str, AttentionProcessor]:
+        r"""
+        Returns:
+            `dict` of attention processors: A dictionary containing all attention processors used in the model,
+            indexed by its weight name.
+        """
+        # set recursively
+        processors = {}
+
+        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+            if hasattr(module, "set_processor"):
+                processors[f"{name}.processor"] = module.processor
+
+            for sub_name, child in module.named_children():
+                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+            return processors
+
+        for name, module in self.named_children():
+            fn_recursive_add_processors(name, module, processors)
+
+        return processors
+
+    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+        r"""
+        Sets the attention processor to use to compute attention.
+
+        Parameters:
+            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+                The instantiated processor class or a dictionary of processor classes that will be set as the processor
+                for **all** `Attention` layers.
+
+                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+                processor. This is strongly recommended when setting trainable attention processors.
+
+        """
+        count = len(self.attn_processors.keys())
+
+        if isinstance(processor, dict) and len(processor) != count:
+            raise ValueError(
+                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+            )
+
+        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+            if hasattr(module, "set_processor"):
+                if not isinstance(processor, dict):
+                    module.set_processor(processor)
+                else:
+                    module.set_processor(processor.pop(f"{name}.processor"))
+
+            for sub_name, child in module.named_children():
+                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+        for name, module in self.named_children():
+            fn_recursive_attn_processor(name, module, processor)
+
+    def set_default_attn_processor(self):
+        """
+        Disables custom attention processors and sets the default attention implementation.
+        """
+        self.set_attn_processor(AttnProcessor())
+
+    def set_attention_slice(self, slice_size):
+        r"""
+        Enable sliced attention computation.
+
+        When this option is enabled, the attention module splits the input tensor in slices to compute attention in
+        several steps. This is useful for saving some memory in exchange for a small decrease in speed.
+
+        Args:
+            slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
+                When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
+                `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
+                provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
+                must be a multiple of `slice_size`.
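+
+        Example (a minimal sketch, assuming `unet` is an instance of this model):
+
+        ```py
+        unet.set_attention_slice("auto")  # compute each attention layer in two slices
+        ```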
+        """
+        sliceable_head_dims = []
+
+        def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
+            if hasattr(module, "set_attention_slice"):
+                sliceable_head_dims.append(module.sliceable_head_dim)
+
+            for child in module.children():
+                fn_recursive_retrieve_sliceable_dims(child)
+
+        # retrieve number of attention layers
+        for module in self.children():
+            fn_recursive_retrieve_sliceable_dims(module)
+
+        num_sliceable_layers = len(sliceable_head_dims)
+
+        if slice_size == "auto":
+            # half the attention head size is usually a good trade-off between
+            # speed and memory
+            slice_size = [dim // 2 for dim in sliceable_head_dims]
+        elif slice_size == "max":
+            # make smallest slice possible
+            slice_size = num_sliceable_layers * [1]
+
+        slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
+
+        if len(slice_size) != len(sliceable_head_dims):
+            raise ValueError(
+                f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
+                f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
+            )
+
+        for i in range(len(slice_size)):
+            size = slice_size[i]
+            dim = sliceable_head_dims[i]
+            if size is not None and size > dim:
+                raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
+
+        # Recursively walk through all the children.
+        # Any children which exposes the set_attention_slice method
+        # gets the message
+        def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
+            if hasattr(module, "set_attention_slice"):
+                module.set_attention_slice(slice_size.pop())
+
+            for child in module.children():
+                fn_recursive_set_attention_slice(child, slice_size)
+
+        reversed_slice_size = list(reversed(slice_size))
+        for module in self.children():
+            fn_recursive_set_attention_slice(module, reversed_slice_size)
+
+    # def _set_gradient_checkpointing(self, module, value=False):
+    #     if isinstance(module, (CrossAttnDownBlock2D, CrossAttnDownBlockMV2D, DownBlock2D, CrossAttnUpBlock2D, CrossAttnUpBlockMV2D, UpBlock2D)):
+    #         module.gradient_checkpointing = value
+
+    def forward(
+        self,
+        sample: torch.FloatTensor,
+        timestep: Union[torch.Tensor, float, int],
+        encoder_hidden_states: torch.Tensor,
+        camera_matrixs: Optional[torch.Tensor] = None,
+        class_labels: Optional[torch.Tensor] = None,
+        timestep_cond: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
+        down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
+        mid_block_additional_residual: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+        return_dict: bool = True,
+    ) -> Union[UNetMV2DConditionOutput, Tuple]:
+        r"""
+        The [`UNet2DConditionModel`] forward method.
+
+        Args:
+            sample (`torch.FloatTensor`):
+                The noisy input tensor with the following shape `(batch, channel, num_views, height, width)`.
+            timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
+            encoder_hidden_states (`torch.FloatTensor`):
+                The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
+            camera_matrixs (`torch.FloatTensor`, *optional*):
+                Per-view camera conditioning with expected shape `(batch, num_views, camera_input_dim)`; it is
+                embedded and added to the per-view time embeddings.
+            encoder_attention_mask (`torch.Tensor`):
+                A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
+                `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
+                which adds large negative values to the attention scores corresponding to "discard" tokens.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
+                tuple.
+            cross_attention_kwargs (`dict`, *optional*):
+                A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
+            added_cond_kwargs (`dict`, *optional*):
+                A kwargs dictionary containing additional embeddings that if specified are added to the embeddings
+                that are passed along to the UNet blocks.
+
+        Returns:
+            [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
+                If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
+                a `tuple` is returned where the first element is the sample tensor.
+        """
+        # By default samples have to be at least a multiple of the overall upsampling factor.
+        # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
+        # However, the upsampling interpolation output size can be forced to fit any upsampling size
+        # on the fly if necessary.
+        default_overall_up_factor = 2**self.num_upsamplers
+
+        # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
+        forward_upsample_size = False
+        upsample_size = None
+
+        if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
+            logger.info("Forward upsample size to force interpolation output size.")
+            forward_upsample_size = True
+
+        # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
+        # expects mask of shape:
+        #   [batch, key_tokens]
+        # adds singleton query_tokens dimension:
+        #   [batch, 1, key_tokens]
+        # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
+        #   [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
+        #   [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
+        if attention_mask is not None:
+            # assume that mask is expressed as:
+            #   (1 = keep, 0 = discard)
+            # convert mask into a bias that can be added to attention scores:
+            #   (keep = +0, discard = -10000.0)
+            attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
+            attention_mask = attention_mask.unsqueeze(1)
+
+        # convert encoder_attention_mask to a bias the same way we do for attention_mask
+        if encoder_attention_mask is not None:
+            encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
+            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
+
+        # 0. center input if necessary
+        if self.config.center_input_sample:
+            sample = 2 * sample - 1.0
+
+        # 1. time
+        timesteps = timestep
+        if not torch.is_tensor(timesteps):
+            # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
+            # This would be a good case for the `match` statement (Python 3.10+)
+            is_mps = sample.device.type == "mps"
+            if isinstance(timestep, float):
+                dtype = torch.float32 if is_mps else torch.float64
+            else:
+                dtype = torch.int32 if is_mps else torch.int64
+            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
+        elif len(timesteps.shape) == 0:
+            timesteps = timesteps[None].to(sample.device)
+
+        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+        timesteps = timesteps.expand(sample.shape[0])
+
+        t_emb = self.time_proj(timesteps)
+
+        # `Timesteps` does not contain any weights and will always return f32 tensors
+        # but time_embedding might actually be running in fp16. so we need to cast here.
+        # there might be better ways to encapsulate this.
+        t_emb = t_emb.to(dtype=sample.dtype)
+        emb = self.time_embedding(t_emb, timestep_cond)
+
+        if camera_matrixs is not None:
+            # camera conditioning: embed the per-view camera vectors and add them to the
+            # timestep embedding, one row per view
+            emb = torch.unsqueeze(emb, 1)
+            cam_emb = self.camera_embedding(camera_matrixs)
+            emb = emb.repeat(1, cam_emb.shape[1], 1)  # (batch, num_views, time_embed_dim)
+            emb = emb + cam_emb
+            emb = rearrange(emb, "b f c -> (b f) c", f=emb.shape[1])
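+            # `emb` now has one row per (batch, view) pair, matching the (b f) flattening
+            # applied to `sample` in step 2 below.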
+
+        aug_emb = None
+
+        if self.class_embedding is not None and class_labels is not None:
+            # (unlike the base UNet2DConditionModel, a missing `class_labels` simply skips the
+            # class embedding here instead of raising)
+            if self.config.class_embed_type == "timestep":
+                class_labels = self.time_proj(class_labels)
+
+                # `Timesteps` does not contain any weights and will always return f32 tensors
+                # there might be better ways to encapsulate this.
+                class_labels = class_labels.to(dtype=sample.dtype)
+
+            class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
+
+            if self.config.class_embeddings_concat:
+                emb = torch.cat([emb, class_emb], dim=-1)
+            else:
+                emb = emb + class_emb
+
+        if self.config.addition_embed_type == "text":
+            aug_emb = self.add_embedding(encoder_hidden_states)
+        elif self.config.addition_embed_type == "text_image":
+            # Kandinsky 2.1 - style
+            if "image_embeds" not in added_cond_kwargs:
+                raise ValueError(
+                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
+                )
+
+            image_embs = added_cond_kwargs.get("image_embeds")
+            text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
+            aug_emb = self.add_embedding(text_embs, image_embs)
+        elif self.config.addition_embed_type == "text_time":
+            # SDXL - style
+            if "text_embeds" not in added_cond_kwargs:
+                raise ValueError(
+                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
+                )
+            text_embeds = added_cond_kwargs.get("text_embeds")
+            if "time_ids" not in added_cond_kwargs:
+                raise ValueError(
+                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
+                )
+            time_ids = added_cond_kwargs.get("time_ids")
+            time_embeds = self.add_time_proj(time_ids.flatten())
+            time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
+
+            add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
+            add_embeds = add_embeds.to(emb.dtype)
+            aug_emb = self.add_embedding(add_embeds)
+        elif self.config.addition_embed_type == "image":
+            # Kandinsky 2.2 - style
+            if "image_embeds" not in added_cond_kwargs:
+                raise ValueError(
+                    f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
+                )
+            image_embs = added_cond_kwargs.get("image_embeds")
+            aug_emb = self.add_embedding(image_embs)
+        elif self.config.addition_embed_type == "image_hint":
+            # Kandinsky 2.2 - style
+            if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
+                raise ValueError(
+                    f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
+                )
+            image_embs = added_cond_kwargs.get("image_embeds")
+            hint = added_cond_kwargs.get("hint")
+            aug_emb, hint = self.add_embedding(image_embs, hint)
+            sample = torch.cat([sample, hint], dim=1)
+
+        emb = emb + aug_emb if aug_emb is not None else emb
+
+        if self.time_embed_act is not None:
+            emb = self.time_embed_act(emb)
+
+        if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
+            encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
+        elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
+            # Kandinsky 2.1 - style
+            if "image_embeds" not in added_cond_kwargs:
+                raise ValueError(
+                    f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
+                )
+
+            image_embeds = added_cond_kwargs.get("image_embeds")
+            encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
+        elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
+            # Kandinsky 2.2 - style
+            if "image_embeds" not in added_cond_kwargs:
+                raise ValueError(
+                    f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
+                )
+            image_embeds = added_cond_kwargs.get("image_embeds")
+            encoder_hidden_states = self.encoder_hid_proj(image_embeds)
+        # 2. pre-process
+        sample = rearrange(sample, "b c f h w -> (b f) c h w", f=sample.shape[2])
+        sample = self.conv_in(sample)
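+        # `sample` is now ((batch * num_views), channels, height, width): the view axis is folded
+        # into the batch so the 2D blocks run per view, while the MV2D attention blocks exchange
+        # information across views.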
+        # 3. down
+
+        is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
+        is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None
+
+        down_block_res_samples = (sample,)
+        for downsample_block in self.down_blocks:
+            if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
+                # For t2i-adapter CrossAttnDownBlock2D
+                additional_residuals = {}
+                if is_adapter and len(down_block_additional_residuals) > 0:
+                    additional_residuals["additional_residuals"] = down_block_additional_residuals.pop(0)
+
+                sample, res_samples = downsample_block(
+                    hidden_states=sample,
+                    temb=emb,
+                    encoder_hidden_states=encoder_hidden_states,
+                    attention_mask=attention_mask,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    encoder_attention_mask=encoder_attention_mask,
+                    **additional_residuals,
+                )
+            else:
+                sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
+
+                if is_adapter and len(down_block_additional_residuals) > 0:
+                    sample += down_block_additional_residuals.pop(0)
+
+            down_block_res_samples += res_samples
+
+        if is_controlnet:
+            new_down_block_res_samples = ()
+
+            for down_block_res_sample, down_block_additional_residual in zip(
+                down_block_res_samples, down_block_additional_residuals
+            ):
+                down_block_res_sample = down_block_res_sample + down_block_additional_residual
+                new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
+
+            down_block_res_samples = new_down_block_res_samples
+
+        # 4. mid
+        if self.mid_block is not None:
+            sample = self.mid_block(
+                sample,
+                emb,
+                encoder_hidden_states=encoder_hidden_states,
+                attention_mask=attention_mask,
+                cross_attention_kwargs=cross_attention_kwargs,
+                encoder_attention_mask=encoder_attention_mask,
+            )
+
+        if is_controlnet:
+            sample = sample + mid_block_additional_residual
+
+        # 5. up
+        for i, upsample_block in enumerate(self.up_blocks):
+            is_final_block = i == len(self.up_blocks) - 1
+
+            res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
+            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
+
+            # if we have not reached the final block and need to forward the
+            # upsample size, we do it here
+            if not is_final_block and forward_upsample_size:
+                upsample_size = down_block_res_samples[-1].shape[2:]
+
+            if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
+                sample = upsample_block(
+                    hidden_states=sample,
+                    temb=emb,
+                    res_hidden_states_tuple=res_samples,
+                    encoder_hidden_states=encoder_hidden_states,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    upsample_size=upsample_size,
+                    attention_mask=attention_mask,
+                    encoder_attention_mask=encoder_attention_mask,
+                )
+            else:
+                sample = upsample_block(
+                    hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
+                )
+
+        # 6. post-process
+        if self.conv_norm_out:
+            sample = self.conv_norm_out(sample)
+            sample = self.conv_act(sample)
+        sample = self.conv_out(sample)
+
+        if not return_dict:
+            return (sample,)
+
+        return UNetMV2DConditionOutput(sample=sample)
+
+    @classmethod
+    def from_pretrained_2d(
+        cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
+        camera_embedding_type: str, num_views: int, sample_size: int,
+        zero_init_conv_in: bool = True, zero_init_camera_projection: bool = False,
+        projection_class_embeddings_input_dim: int = 6, joint_attention: bool = False,
+        joint_attention_twice: bool = False, multiview_attention: bool = True,
+        cross_domain_attention: bool = False,
+        in_channels: int = 8, out_channels: int = 4, local_crossattn=False,
+        **kwargs
+    ):
+        r"""
+        Instantiate a pretrained PyTorch model from a pretrained model configuration.
+
+        The model is set in evaluation mode - `model.eval()` - by default, and dropout modules are deactivated. To
+        train the model, set it back in training mode with `model.train()`.
+
+        Parameters:
+            pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
+                Can be either:
+
+                    - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
+                      the Hub.
+                    - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
+                      with [`~ModelMixin.save_pretrained`].
+
+            cache_dir (`Union[str, os.PathLike]`, *optional*):
+                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+                is not used.
+            torch_dtype (`str` or `torch.dtype`, *optional*):
+                Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
+                dtype is automatically derived from the model's weights.
+            force_download (`bool`, *optional*, defaults to `False`):
+                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+                cached versions if they exist.
+            resume_download (`bool`, *optional*, defaults to `False`):
+                Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
+                incompletely downloaded files are deleted.
+            proxies (`Dict[str, str]`, *optional*):
+                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+            output_loading_info (`bool`, *optional*, defaults to `False`):
+                Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
+            local_files_only (`bool`, *optional*, defaults to `False`):
+                Whether to only load local model weights and configuration files or not. If set to `True`, the model
+                won't be downloaded from the Hub.
+            use_auth_token (`str` or *bool*, *optional*):
+                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+                `diffusers-cli login` (stored in `~/.huggingface`) is used.
+            revision (`str`, *optional*, defaults to `"main"`):
+                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+                allowed by Git.
+            from_flax (`bool`, *optional*, defaults to `False`):
+                Load the model weights from a Flax checkpoint save file.
+            subfolder (`str`, *optional*, defaults to `""`):
+                The subfolder location of a model file within a larger model repository on the Hub or locally.
+            mirror (`str`, *optional*):
+                Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
+                guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
+                information.
+            device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
+                A map that specifies where each submodule should go. It doesn't need to be defined for each
+                parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the
+                same device.
+
+                Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For
+                more information about each option see [designing a device
+                map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
+            max_memory (`Dict`, *optional*):
+                A dictionary mapping device identifiers to their maximum memory. Will default to the maximum memory
+                available for each GPU and the available CPU RAM if unset.
+            offload_folder (`str` or `os.PathLike`, *optional*):
+                The path to offload weights if `device_map` contains the value `"disk"`.
+            offload_state_dict (`bool`, *optional*):
+                If `True`, temporarily offloads the CPU state dict to the hard drive to avoid running out of CPU RAM if
+                the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to `True`
+                when there is some disk offload.
+            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
+                Speed up model loading by only loading the pretrained weights and not initializing the weights. This
+                also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
+                model. Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
+                argument to `True` will raise an error.
+            variant (`str`, *optional*):
+                Load weights from a specified `variant` filename such as `"fp16"` or `"ema"`. This is ignored when
+                loading `from_flax`.
+            use_safetensors (`bool`, *optional*, defaults to `None`):
+                If set to `None`, the `safetensors` weights are downloaded if they're available **and** if the
+                `safetensors` library is installed. If set to `True`, the model is forcibly loaded from `safetensors`
+                weights. If set to `False`, `safetensors` weights are not loaded.
+
+        <Tip>
+
+        To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with
+        `huggingface-cli login`. You can also activate the special
+        ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a
+        firewalled environment.
+
+        </Tip>
+
+        Example:
+
+        ```py
+        from diffusers import UNet2DConditionModel
+
+        unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
+        ```
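+
+        A minimal sketch of calling this classmethod itself (assuming the class is imported as
+        `UNetMV2DConditionModel`; the argument values are illustrative, not prescriptive):
+
+        ```py
+        unet = UNetMV2DConditionModel.from_pretrained_2d(
+            "stabilityai/stable-diffusion-2-1",
+            subfolder="unet",
+            camera_embedding_type="e_de_da_sincos",
+            num_views=4,
+            sample_size=96,
+        )
+        ```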
+
+        If you get the error message below, you need to finetune the weights for your downstream task:
+
+        ```bash
+        Some weights of UNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match:
+        - conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated
+        You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
+        ```
+        """
+        cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
+        ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
+        force_download = kwargs.pop("force_download", False)
+        from_flax = kwargs.pop("from_flax", False)
+        resume_download = kwargs.pop("resume_download", False)
+        proxies = kwargs.pop("proxies", None)
+        output_loading_info = kwargs.pop("output_loading_info", False)
+        local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
+        use_auth_token = kwargs.pop("use_auth_token", None)
+        revision = kwargs.pop("revision", None)
+        torch_dtype = kwargs.pop("torch_dtype", None)
+        subfolder = kwargs.pop("subfolder", None)
+        device_map = kwargs.pop("device_map", None)
+        max_memory = kwargs.pop("max_memory", None)
+        offload_folder = kwargs.pop("offload_folder", None)
+        offload_state_dict = kwargs.pop("offload_state_dict", False)
+        variant = kwargs.pop("variant", None)
+        use_safetensors = kwargs.pop("use_safetensors", None)
+
+        # if use_safetensors and not is_safetensors_available():
+        #     raise ValueError(
+        #         "`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors"
+        #     )
+
+        allow_pickle = False
+        if use_safetensors is None:
+            # use_safetensors = is_safetensors_available()
+            use_safetensors = False
+            allow_pickle = True
+
+        if device_map is not None and not is_accelerate_available():
+            raise NotImplementedError(
+                "Loading and dispatching requires `accelerate`. Please make sure to install accelerate or set"
+                " `device_map=None`. You can install accelerate with `pip install accelerate`."
+            )
+
+        # Check if we can handle device_map and dispatching the weights
+        if device_map is not None and not is_torch_version(">=", "1.9.0"):
+            raise NotImplementedError(
+                "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
+                " `device_map=None`."
+            )
+
+        # Load config if we don't provide a configuration
+        config_path = pretrained_model_name_or_path
+
+        user_agent = {
+            "diffusers": __version__,
+            "file_type": "model",
+            "framework": "pytorch",
+        }
+
+        # load config
+        config, unused_kwargs, commit_hash = cls.load_config(
+            config_path,
+            cache_dir=cache_dir,
+            return_unused_kwargs=True,
+            return_commit_hash=True,
+            force_download=force_download,
+            resume_download=resume_download,
+            proxies=proxies,
+            local_files_only=local_files_only,
+            use_auth_token=use_auth_token,
+            revision=revision,
+            subfolder=subfolder,
+            device_map=device_map,
+            max_memory=max_memory,
+            offload_folder=offload_folder,
+            offload_state_dict=offload_state_dict,
+            user_agent=user_agent,
+            **kwargs,
+        )
+
+        # modify config
+        config["_class_name"] = cls.__name__
+        config['in_channels'] = in_channels
+        config['out_channels'] = out_channels
+        config['sample_size'] = sample_size  # training resolution
+        config['num_views'] = num_views
+        config['joint_attention'] = joint_attention
+        config['joint_attention_twice'] = joint_attention_twice
+        config['multiview_attention'] = multiview_attention
+        config['cross_domain_attention'] = cross_domain_attention
+        config["down_block_types"] = [
+            "CrossAttnDownBlockMV2D",
+            "CrossAttnDownBlockMV2D",
+            "CrossAttnDownBlockMV2D",
+            "DownBlock2D"
+        ]
+        config['mid_block_type'] = "UNetMidBlockMV2DCrossAttn"
+        config["up_block_types"] = [
+            "UpBlock2D",
+            "CrossAttnUpBlockMV2D",
+            "CrossAttnUpBlockMV2D",
+            "CrossAttnUpBlockMV2D"
+        ]
+        config['class_embed_type'] = 'projection'
+        if camera_embedding_type == 'e_de_da_sincos':
+            config['projection_class_embeddings_input_dim'] = projection_class_embeddings_input_dim  # default 6
+        else:
+            raise NotImplementedError
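+        # With `class_embed_type='projection'`, the `e_de_da_sincos` camera pose features can be
+        # passed to `forward` as `class_labels` (of size `projection_class_embeddings_input_dim`)
+        # and are projected into the time-embedding space, alongside the `camera_matrixs` pathway.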
+
+        # load model
+        model_file = None
+        if from_flax:
+            raise NotImplementedError
+        else:
+            if use_safetensors:
+                try:
+                    model_file = _get_model_file(
+                        pretrained_model_name_or_path,
+                        weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant),
+                        cache_dir=cache_dir,
+                        force_download=force_download,
+                        resume_download=resume_download,
+                        proxies=proxies,
+                        local_files_only=local_files_only,
+                        use_auth_token=use_auth_token,
+                        revision=revision,
+                        subfolder=subfolder,
+                        user_agent=user_agent,
+                        commit_hash=commit_hash,
+                    )
+                except IOError as e:
+                    if not allow_pickle:
+                        raise e
+                    # otherwise fall back to the pickle (`.bin`) weights below
+            if model_file is None:
+                model_file = _get_model_file(
+                    pretrained_model_name_or_path,
+                    weights_name=_add_variant(WEIGHTS_NAME, variant),
+                    cache_dir=cache_dir,
+                    force_download=force_download,
+                    resume_download=resume_download,
+                    proxies=proxies,
+                    local_files_only=local_files_only,
+                    use_auth_token=use_auth_token,
+                    revision=revision,
+                    subfolder=subfolder,
+                    user_agent=user_agent,
+                    commit_hash=commit_hash,
+                )
+
+        model = cls.from_config(config, **unused_kwargs)
+        if local_crossattn:
+            unet_lora_attn_procs = dict()
+            for name, _ in model.attn_processors.items():
+                if not name.endswith("attn1.processor"):
+                    default_attn_proc = AttnProcessor()
+                elif is_xformers_available():
+                    default_attn_proc = XFormersMVAttnProcessor()
+                else:
+                    default_attn_proc = MVAttnProcessor()
+                unet_lora_attn_procs[name] = ReferenceOnlyAttnProc(
+                    default_attn_proc, enabled=name.endswith("attn1.processor"), name=name
+                )
+            model.set_attn_processor(unet_lora_attn_procs)
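+        # With `local_crossattn`, the self-attention ("attn1") layers get multi-view processors
+        # and have the reference-only path enabled; every other attention layer keeps the plain
+        # AttnProcessor.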
+        state_dict = load_state_dict(model_file, variant=variant)
+        model._convert_deprecated_attention_blocks(state_dict)
+
+        conv_in_weight = state_dict['conv_in.weight']
+        conv_out_weight = state_dict['conv_out.weight']
+        model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model_2d(
+            model,
+            state_dict,
+            model_file,
+            pretrained_model_name_or_path,
+            ignore_mismatched_sizes=True,
+        )
+        if any(key == 'conv_in.weight' for key, _, _ in mismatched_keys):
+            # initialize the first 4 input channels from the original SD conv_in
+            model.conv_in.weight.data[:, :4] = conv_in_weight
+
+            # optionally zero-initialize the weights of the newly added input channels
+            if zero_init_conv_in:
+                model.conv_in.weight.data[:, 4:] = 0.
+
+        if any(key == 'conv_out.weight' for key, _, _ in mismatched_keys):
+            # initialize the first 4 output channels from the original SD conv_out
+            model.conv_out.weight.data[:, :4] = conv_out_weight
+            if out_channels == 8:  # copy the weights for the last 4 channels as well
+                model.conv_out.weight.data[:, 4:] = conv_out_weight
+
+        if zero_init_camera_projection:
+            for p in model.class_embedding.parameters():
+                torch.nn.init.zeros_(p)
+
+        loading_info = {
+            "missing_keys": missing_keys,
+            "unexpected_keys": unexpected_keys,
+            "mismatched_keys": mismatched_keys,
+            "error_msgs": error_msgs,
+        }
+
+        if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
+            raise ValueError(
+                f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
+            )
+        elif torch_dtype is not None:
+            model = model.to(torch_dtype)
+
+        model.register_to_config(_name_or_path=pretrained_model_name_or_path)
+
+        # Set model in evaluation mode to deactivate DropOut modules by default
+        model.eval()
+        if output_loading_info:
+            return model, loading_info
+
+        return model
+
1406
+    @classmethod
+    def _load_pretrained_model_2d(
+        cls,
+        model,
+        state_dict,
+        resolved_archive_file,
+        pretrained_model_name_or_path,
+        ignore_mismatched_sizes=False,
+    ):
+        # Retrieve missing & unexpected keys
+        model_state_dict = model.state_dict()
+        loaded_keys = list(state_dict.keys())
+
+        expected_keys = list(model_state_dict.keys())
+
+        original_loaded_keys = loaded_keys
+
+        missing_keys = list(set(expected_keys) - set(loaded_keys))
+        unexpected_keys = list(set(loaded_keys) - set(expected_keys))
+
+        # Make sure we are able to load base models as well as derived models (with heads)
+        model_to_load = model
+
+        def _find_mismatched_keys(
+            state_dict,
+            model_state_dict,
+            loaded_keys,
+            ignore_mismatched_sizes,
+        ):
+            mismatched_keys = []
+            if ignore_mismatched_sizes:
+                for checkpoint_key in loaded_keys:
+                    model_key = checkpoint_key
+
+                    if (
+                        model_key in model_state_dict
+                        and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
+                    ):
+                        mismatched_keys.append(
+                            (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
+                        )
+                        del state_dict[checkpoint_key]
+            return mismatched_keys
+
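+        # Shape-mismatched tensors were deleted from `state_dict` inside
+        # `_find_mismatched_keys`, so the load below cannot fail on them; the
+        # caller re-fills those layers (e.g. the widened conv_in / conv_out).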
+        if state_dict is not None:
+            # Whole checkpoint
+            mismatched_keys = _find_mismatched_keys(
+                state_dict,
+                model_state_dict,
+                original_loaded_keys,
+                ignore_mismatched_sizes,
+            )
+            error_msgs = _load_state_dict_into_model(model_to_load, state_dict)
+
+        if len(error_msgs) > 0:
+            error_msg = "\n\t".join(error_msgs)
+            if "size mismatch" in error_msg:
+                error_msg += (
+                    "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method."
+                )
+            raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")
+
+        if len(unexpected_keys) > 0:
+            logger.warning(
+                f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
+                f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
+                f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task"
+                " or with another architecture (e.g. initializing a BertForSequenceClassification model from a"
+                " BertForPreTraining model).\n- This IS NOT expected if you are initializing"
+                f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly"
+                " identical (initializing a BertForSequenceClassification model from a"
+                " BertForSequenceClassification model)."
+            )
+        else:
+            logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
+        if len(missing_keys) > 0:
+            logger.warning(
+                f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
+                f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
+                " TRAIN this model on a down-stream task to be able to use it for predictions and inference."
+            )
+        elif len(mismatched_keys) == 0:
+            logger.info(
+                f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
+                f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the"
+                f" checkpoint was trained on, you can already use {model.__class__.__name__} for predictions"
+                " without further training."
+            )
+        if len(mismatched_keys) > 0:
+            mismatched_warning = "\n".join(
+                [
+                    f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
+                    for key, shape1, shape2 in mismatched_keys
+                ]
+            )
+            logger.warning(
+                f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
+                f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
+                f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be"
+                " able to use it for predictions and inference."
+            )
+
+        return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs
+
2D_Stage/tuneavideo/models/unet_mv2d_ref.py ADDED
@@ -0,0 +1,1570 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Tuple, Union
+import os
+
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint
+from einops import rearrange
+
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.loaders import UNet2DConditionLoadersMixin
+from diffusers.utils import BaseOutput, logging
+from diffusers.models.activations import get_activation
+from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor
+from diffusers.models.embeddings import (
+    GaussianFourierProjection,
+    ImageHintTimeEmbedding,
+    ImageProjection,
+    ImageTimeEmbedding,
+    TextImageProjection,
+    TextImageTimeEmbedding,
+    TextTimeEmbedding,
+    TimestepEmbedding,
+    Timesteps,
+)
+from diffusers.models.lora import LoRALinearLayer
+from diffusers.models.modeling_utils import ModelMixin, load_state_dict, _load_state_dict_into_model
+from diffusers.models.unet_2d_blocks import (
+    CrossAttnDownBlock2D,
+    CrossAttnUpBlock2D,
+    DownBlock2D,
+    UNetMidBlock2DCrossAttn,
+    UNetMidBlock2DSimpleCrossAttn,
+    UpBlock2D,
+)
+from diffusers.utils import (
+    CONFIG_NAME,
+    DIFFUSERS_CACHE,
+    FLAX_WEIGHTS_NAME,
+    HF_HUB_OFFLINE,
+    SAFETENSORS_WEIGHTS_NAME,
+    WEIGHTS_NAME,
+    _add_variant,
+    _get_model_file,
+    deprecate,
+    is_accelerate_available,
+    is_torch_version,
+)
+from diffusers import __version__
+from tuneavideo.models.unet_mv2d_blocks import (
+    CrossAttnDownBlockMV2D,
+    CrossAttnUpBlockMV2D,
+    UNetMidBlockMV2DCrossAttn,
+    get_down_block,
+    get_up_block,
+)
+from diffusers.models.attention_processor import Attention
+from diffusers.utils.import_utils import is_xformers_available
+from tuneavideo.models.transformer_mv2d import XFormersMVAttnProcessor, MVAttnProcessor
+from tuneavideo.models.refunet import ReferenceOnlyAttnProc
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+@dataclass
+class UNetMV2DRefOutput(BaseOutput):
+    """
+    The output of [`UNetMV2DRefModel`].
+
+    Args:
+        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+            The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
+    """
+
+    sample: torch.FloatTensor = None
+
+class Identity(torch.nn.Module):
+    r"""A placeholder identity operator that is argument-insensitive.
+
+    Args:
+        args: any argument (unused)
+        kwargs: any keyword argument (unused)
+
+    Shape:
+        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
+        - Output: :math:`(*)`, same shape as the input.
+
+    Examples::
+
+        >>> m = Identity(54, unused_argument1=0.1, unused_argument2=False)
+        >>> input = torch.randn(128, 20)
+        >>> output = m(input)
+        >>> print(output.size())
+        torch.Size([128, 20])
+
+    """
+
+    def __init__(self, scale=None, *args, **kwargs) -> None:
+        super().__init__()
+
+    def forward(self, input, *args, **kwargs):
+        return input
+
+
+class _LoRACompatibleLinear(nn.Module):
+    """
+    A no-op stand-in that keeps the `LoRACompatibleLinear` interface but simply
+    passes its input through unchanged.
+    """
+
+    def __init__(self, *args, lora_layer: Optional[LoRALinearLayer] = None, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.lora_layer = lora_layer
+
+    def set_lora_layer(self, lora_layer: Optional[LoRALinearLayer]):
+        self.lora_layer = lora_layer
+
+    def _fuse_lora(self):
+        pass
+
+    def _unfuse_lora(self):
+        pass
+
+    def forward(self, hidden_states, scale=None, lora_scale: int = 1):
+        return hidden_states
+
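+# Note: `Identity` and `_LoRACompatibleLinear` are no-op stand-ins. In
+# `UNetMV2DRefModel.__init__` below they are assigned into the last up-block's
+# final transformer block, stripping its projections, cross-attention, and
+# feed-forward; the reference UNet only needs to expose intermediate attention
+# features, so that final block does not have to produce a real output.
+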
+class UNetMV2DRefModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
+    r"""
+    A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a
+    sample-shaped output.
+
+    This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
+    for all models (such as downloading or saving).
+
+    Parameters:
+        sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
+            Height and width of input/output sample.
+        in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
+        out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
+        center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
+        flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
+            Whether to flip the sin to cos in the time embedding.
+        freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
+        down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlockMV2D", "CrossAttnDownBlockMV2D", "CrossAttnDownBlockMV2D", "DownBlock2D")`):
+            The tuple of downsample blocks to use.
+        mid_block_type (`str`, *optional*, defaults to `"UNetMidBlockMV2DCrossAttn"`):
+            Block type for the middle of the UNet; one of `UNetMidBlock2DCrossAttn`, `UNetMidBlockMV2DCrossAttn` or
+            `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
+        up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlockMV2D", "CrossAttnUpBlockMV2D", "CrossAttnUpBlockMV2D")`):
+            The tuple of upsample blocks to use.
+        only_cross_attention (`bool` or `Tuple[bool]`, *optional*, defaults to `False`):
+            Whether to include self-attention in the basic transformer blocks, see
+            [`~models.attention.BasicTransformerBlock`].
+        block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
+            The tuple of output channels for each block.
+        layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
+        downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
+        mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
+        act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
+        norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
+            If `None`, normalization and activation layers are skipped in post-processing.
+        norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
+        cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
+            The dimension of the cross attention features.
+        transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
+            The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
+            [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
+            [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
+        encoder_hid_dim (`int`, *optional*, defaults to None):
+            If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
+            dimension to `cross_attention_dim`.
+        encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
+            If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
+            embeddings of dimension `cross_attention_dim` according to `encoder_hid_dim_type`.
+        attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
+        num_attention_heads (`int`, *optional*):
+            The number of attention heads. If not defined, defaults to `attention_head_dim`.
+        resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
+            for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
+        class_embed_type (`str`, *optional*, defaults to `None`):
+            The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
+            `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
+        addition_embed_type (`str`, *optional*, defaults to `None`):
+            Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
+            "text". "text" will use the `TextTimeEmbedding` layer.
+        addition_time_embed_dim (`int`, *optional*, defaults to `None`):
+            Dimension for the timestep embeddings.
+        num_class_embeds (`int`, *optional*, defaults to `None`):
+            Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
+            class conditioning with `class_embed_type` equal to `None`.
+        time_embedding_type (`str`, *optional*, defaults to `positional`):
+            The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
+        time_embedding_dim (`int`, *optional*, defaults to `None`):
+            An optional override for the dimension of the projected time embedding.
+        time_embedding_act_fn (`str`, *optional*, defaults to `None`):
+            Optional activation function to use only once on the time embeddings before they are passed to the rest of
+            the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
+        timestep_post_act (`str`, *optional*, defaults to `None`):
+            The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
+        time_cond_proj_dim (`int`, *optional*, defaults to `None`):
+            The dimension of `cond_proj` layer in the timestep embedding.
+        conv_in_kernel (`int`, *optional*, defaults to `3`): The kernel size of `conv_in` layer.
+        conv_out_kernel (`int`, *optional*, defaults to `3`): The kernel size of `conv_out` layer.
+        projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
+            `class_embed_type="projection"`. Required when `class_embed_type="projection"`.
+        class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
+            embeddings with the class embeddings.
+        mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
+            Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
+            `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the
+            `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Defaults to `False`
+            otherwise.
+    """
+
+    _supports_gradient_checkpointing = True
+
+    @register_to_config
+    def __init__(
+        self,
+        sample_size: Optional[int] = None,
+        in_channels: int = 4,
+        out_channels: int = 4,
+        center_input_sample: bool = False,
+        flip_sin_to_cos: bool = True,
+        freq_shift: int = 0,
+        down_block_types: Tuple[str] = (
+            "CrossAttnDownBlockMV2D",
+            "CrossAttnDownBlockMV2D",
+            "CrossAttnDownBlockMV2D",
+            "DownBlock2D",
+        ),
+        mid_block_type: Optional[str] = "UNetMidBlockMV2DCrossAttn",
+        up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlockMV2D", "CrossAttnUpBlockMV2D", "CrossAttnUpBlockMV2D"),
+        only_cross_attention: Union[bool, Tuple[bool]] = False,
+        block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
+        layers_per_block: Union[int, Tuple[int]] = 2,
+        downsample_padding: int = 1,
+        mid_block_scale_factor: float = 1,
+        act_fn: str = "silu",
+        norm_num_groups: Optional[int] = 32,
+        norm_eps: float = 1e-5,
+        cross_attention_dim: Union[int, Tuple[int]] = 1280,
+        transformer_layers_per_block: Union[int, Tuple[int]] = 1,
+        encoder_hid_dim: Optional[int] = None,
+        encoder_hid_dim_type: Optional[str] = None,
+        attention_head_dim: Union[int, Tuple[int]] = 8,
+        num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
+        dual_cross_attention: bool = False,
+        use_linear_projection: bool = False,
+        class_embed_type: Optional[str] = None,
+        addition_embed_type: Optional[str] = None,
+        addition_time_embed_dim: Optional[int] = None,
+        num_class_embeds: Optional[int] = None,
+        upcast_attention: bool = False,
+        resnet_time_scale_shift: str = "default",
+        resnet_skip_time_act: bool = False,
+        resnet_out_scale_factor: int = 1.0,
+        time_embedding_type: str = "positional",
+        time_embedding_dim: Optional[int] = None,
+        time_embedding_act_fn: Optional[str] = None,
+        timestep_post_act: Optional[str] = None,
+        time_cond_proj_dim: Optional[int] = None,
+        conv_in_kernel: int = 3,
+        conv_out_kernel: int = 3,
+        projection_class_embeddings_input_dim: Optional[int] = None,
+        class_embeddings_concat: bool = False,
+        mid_block_only_cross_attention: Optional[bool] = None,
+        cross_attention_norm: Optional[str] = None,
+        addition_embed_type_num_heads=64,
+        num_views: int = 1,
+        joint_attention: bool = False,
+        joint_attention_twice: bool = False,
+        multiview_attention: bool = True,
+        cross_domain_attention: bool = False,
+        camera_input_dim: int = 12,
+        camera_hidden_dim: int = 320,
+        camera_output_dim: int = 1280,
+    ):
+        super().__init__()
+
+        self.sample_size = sample_size
+
+        if num_attention_heads is not None:
+            raise ValueError(
+                "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
+            )
+
+        # If `num_attention_heads` is not defined (which is the case for most models)
+        # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
+        # The reason for this behavior is to correct for incorrectly named variables that were introduced
+        # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
+        # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
+        # which is why we correct for the naming here.
+        num_attention_heads = num_attention_heads or attention_head_dim
+
+        # Check inputs
+        if len(down_block_types) != len(up_block_types):
+            raise ValueError(
+                f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
+            )
+
+        if len(block_out_channels) != len(down_block_types):
+            raise ValueError(
+                f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
+            )
+
+        if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
+            raise ValueError(
+                f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
+            )
+
+        if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
+            raise ValueError(
+                f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
+            )
+
+        if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
+            raise ValueError(
+                f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
+            )
+
+        if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
+            raise ValueError(
+                f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
+            )
+
+        if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
+            raise ValueError(
+                f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
+            )
+
+        # input
+        conv_in_padding = (conv_in_kernel - 1) // 2
+        self.conv_in = nn.Conv2d(
+            in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
+        )
+
+        # time
+        if time_embedding_type == "fourier":
+            time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
+            if time_embed_dim % 2 != 0:
+                raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
+            self.time_proj = GaussianFourierProjection(
+                time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
+            )
+            timestep_input_dim = time_embed_dim
+        elif time_embedding_type == "positional":
+            time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
+
+            self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
+            timestep_input_dim = block_out_channels[0]
+        else:
+            raise ValueError(
+                f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
+            )
+
+        self.time_embedding = TimestepEmbedding(
+            timestep_input_dim,
+            time_embed_dim,
+            act_fn=act_fn,
+            post_act_fn=timestep_post_act,
+            cond_proj_dim=time_cond_proj_dim,
+        )
+
+        if encoder_hid_dim_type is None and encoder_hid_dim is not None:
+            encoder_hid_dim_type = "text_proj"
+            self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
+            logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
+
+        if encoder_hid_dim is None and encoder_hid_dim_type is not None:
+            raise ValueError(
+                f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
+            )
+
+        if encoder_hid_dim_type == "text_proj":
+            self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
+        elif encoder_hid_dim_type == "text_image_proj":
+            # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
+            # it is set to `cross_attention_dim` here as this is exactly the required dimension for the currently
+            # only use case when `encoder_hid_dim_type == "text_image_proj"` (Kandinsky 2.1)
+            self.encoder_hid_proj = TextImageProjection(
+                text_embed_dim=encoder_hid_dim,
+                image_embed_dim=cross_attention_dim,
+                cross_attention_dim=cross_attention_dim,
+            )
+        elif encoder_hid_dim_type == "image_proj":
+            # Kandinsky 2.2
+            self.encoder_hid_proj = ImageProjection(
+                image_embed_dim=encoder_hid_dim,
+                cross_attention_dim=cross_attention_dim,
+            )
+        elif encoder_hid_dim_type is not None:
+            raise ValueError(
+                f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj', 'text_image_proj' or 'image_proj'."
+            )
+        else:
+            self.encoder_hid_proj = None
+
+        # class embedding
+        if class_embed_type is None and num_class_embeds is not None:
+            self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
+        elif class_embed_type == "timestep":
+            self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
+        elif class_embed_type == "identity":
+            self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
+        elif class_embed_type == "projection":
+            if projection_class_embeddings_input_dim is None:
+                raise ValueError(
+                    "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
+                )
+            # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
+            # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
+            # 2. it projects from an arbitrary input dimension.
+            #
+            # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
+            # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
+            # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
+            self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
+        elif class_embed_type == "simple_projection":
+            if projection_class_embeddings_input_dim is None:
+                raise ValueError(
+                    "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
+                )
+            self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
+        else:
+            self.class_embedding = None
+
+        if addition_embed_type == "text":
+            if encoder_hid_dim is not None:
+                text_time_embedding_from_dim = encoder_hid_dim
+            else:
+                text_time_embedding_from_dim = cross_attention_dim
+
+            self.add_embedding = TextTimeEmbedding(
+                text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
+            )
+        elif addition_embed_type == "text_image":
+            # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__
+            # too much they are set to `cross_attention_dim` here as this is exactly the required dimension for the
+            # currently only use case when `addition_embed_type == "text_image"` (Kandinsky 2.1)
+            self.add_embedding = TextImageTimeEmbedding(
+                text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
+            )
+        elif addition_embed_type == "text_time":
+            self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
+            self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
+        elif addition_embed_type == "image":
+            # Kandinsky 2.2
+            self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
+        elif addition_embed_type == "image_hint":
+            # Kandinsky 2.2 ControlNet
+            self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
+        elif addition_embed_type is not None:
+            raise ValueError(
+                f"addition_embed_type: {addition_embed_type} must be None, 'text', 'text_image', 'text_time', 'image' or 'image_hint'."
+            )
+
+        if time_embedding_act_fn is None:
+            self.time_embed_act = None
+        else:
+            self.time_embed_act = get_activation(time_embedding_act_fn)
+
+        self.camera_embedding = nn.Sequential(
+            nn.Linear(camera_input_dim, time_embed_dim),
+            nn.SiLU(),
+            nn.Linear(time_embed_dim, time_embed_dim),
+        )
+
+        self.down_blocks = nn.ModuleList([])
+        self.up_blocks = nn.ModuleList([])
+
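+        # Camera conditioning (not part of the stock diffusers UNet): `camera_input_dim`
+        # defaults to 12, presumably a flattened 3x4 camera matrix per view; this MLP
+        # projects it into the time-embedding space so `forward` can add a view-specific
+        # offset to the shared timestep embedding. (`camera_hidden_dim` and
+        # `camera_output_dim` are accepted by the config but unused by this head.)
+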
+        if isinstance(only_cross_attention, bool):
+            if mid_block_only_cross_attention is None:
+                mid_block_only_cross_attention = only_cross_attention
+
+            only_cross_attention = [only_cross_attention] * len(down_block_types)
+
+        if mid_block_only_cross_attention is None:
+            mid_block_only_cross_attention = False
+
+        if isinstance(num_attention_heads, int):
+            num_attention_heads = (num_attention_heads,) * len(down_block_types)
+
+        if isinstance(attention_head_dim, int):
+            attention_head_dim = (attention_head_dim,) * len(down_block_types)
+
+        if isinstance(cross_attention_dim, int):
+            cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
+
+        if isinstance(layers_per_block, int):
+            layers_per_block = [layers_per_block] * len(down_block_types)
+
+        if isinstance(transformer_layers_per_block, int):
+            transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
+
+        if class_embeddings_concat:
+            # The time embeddings are concatenated with the class embeddings. The dimension of the
+            # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
+            # regular time embeddings
+            blocks_time_embed_dim = time_embed_dim * 2
+        else:
+            blocks_time_embed_dim = time_embed_dim
+
+        # down
+        output_channel = block_out_channels[0]
+        for i, down_block_type in enumerate(down_block_types):
+            input_channel = output_channel
+            output_channel = block_out_channels[i]
+            is_final_block = i == len(block_out_channels) - 1
+
+            down_block = get_down_block(
+                down_block_type,
+                num_layers=layers_per_block[i],
+                transformer_layers_per_block=transformer_layers_per_block[i],
+                in_channels=input_channel,
+                out_channels=output_channel,
+                temb_channels=blocks_time_embed_dim,
+                add_downsample=not is_final_block,
+                resnet_eps=norm_eps,
+                resnet_act_fn=act_fn,
+                resnet_groups=norm_num_groups,
+                cross_attention_dim=cross_attention_dim[i],
+                num_attention_heads=num_attention_heads[i],
+                downsample_padding=downsample_padding,
+                dual_cross_attention=dual_cross_attention,
+                use_linear_projection=use_linear_projection,
+                only_cross_attention=only_cross_attention[i],
+                upcast_attention=upcast_attention,
+                resnet_time_scale_shift=resnet_time_scale_shift,
+                resnet_skip_time_act=resnet_skip_time_act,
+                resnet_out_scale_factor=resnet_out_scale_factor,
+                cross_attention_norm=cross_attention_norm,
+                attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
+                num_views=num_views,
+                joint_attention=joint_attention,
+                joint_attention_twice=joint_attention_twice,
+                multiview_attention=multiview_attention,
+                cross_domain_attention=cross_domain_attention,
+            )
+            self.down_blocks.append(down_block)
+
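+        # The multiview-specific kwargs (num_views, joint_attention, ...) are
+        # threaded into every MV2D down/mid/up block so their attention processors
+        # can regroup the views that `forward` folds into the batch dimension.
+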
+        # mid
+        if mid_block_type == "UNetMidBlock2DCrossAttn":
+            self.mid_block = UNetMidBlock2DCrossAttn(
+                transformer_layers_per_block=transformer_layers_per_block[-1],
+                in_channels=block_out_channels[-1],
+                temb_channels=blocks_time_embed_dim,
+                resnet_eps=norm_eps,
+                resnet_act_fn=act_fn,
+                output_scale_factor=mid_block_scale_factor,
+                resnet_time_scale_shift=resnet_time_scale_shift,
+                cross_attention_dim=cross_attention_dim[-1],
+                num_attention_heads=num_attention_heads[-1],
+                resnet_groups=norm_num_groups,
+                dual_cross_attention=dual_cross_attention,
+                use_linear_projection=use_linear_projection,
+                upcast_attention=upcast_attention,
+            )
+        # custom MV2D attention block
+        elif mid_block_type == "UNetMidBlockMV2DCrossAttn":
+            self.mid_block = UNetMidBlockMV2DCrossAttn(
+                transformer_layers_per_block=transformer_layers_per_block[-1],
+                in_channels=block_out_channels[-1],
+                temb_channels=blocks_time_embed_dim,
+                resnet_eps=norm_eps,
+                resnet_act_fn=act_fn,
+                output_scale_factor=mid_block_scale_factor,
+                resnet_time_scale_shift=resnet_time_scale_shift,
+                cross_attention_dim=cross_attention_dim[-1],
+                num_attention_heads=num_attention_heads[-1],
+                resnet_groups=norm_num_groups,
+                dual_cross_attention=dual_cross_attention,
+                use_linear_projection=use_linear_projection,
+                upcast_attention=upcast_attention,
+                num_views=num_views,
+                joint_attention=joint_attention,
+                joint_attention_twice=joint_attention_twice,
+                multiview_attention=multiview_attention,
+                cross_domain_attention=cross_domain_attention,
+            )
+        elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
+            self.mid_block = UNetMidBlock2DSimpleCrossAttn(
+                in_channels=block_out_channels[-1],
+                temb_channels=blocks_time_embed_dim,
+                resnet_eps=norm_eps,
+                resnet_act_fn=act_fn,
+                output_scale_factor=mid_block_scale_factor,
+                cross_attention_dim=cross_attention_dim[-1],
+                attention_head_dim=attention_head_dim[-1],
+                resnet_groups=norm_num_groups,
+                resnet_time_scale_shift=resnet_time_scale_shift,
+                skip_time_act=resnet_skip_time_act,
+                only_cross_attention=mid_block_only_cross_attention,
+                cross_attention_norm=cross_attention_norm,
+            )
+        elif mid_block_type is None:
+            self.mid_block = None
+        else:
+            raise ValueError(f"unknown mid_block_type: {mid_block_type}")
+
+        # count how many layers upsample the images
+        self.num_upsamplers = 0
+
+        # up
+        reversed_block_out_channels = list(reversed(block_out_channels))
+        reversed_num_attention_heads = list(reversed(num_attention_heads))
+        reversed_layers_per_block = list(reversed(layers_per_block))
+        reversed_cross_attention_dim = list(reversed(cross_attention_dim))
+        reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
+        only_cross_attention = list(reversed(only_cross_attention))
+
+        output_channel = reversed_block_out_channels[0]
+        for i, up_block_type in enumerate(up_block_types):
+            is_final_block = i == len(block_out_channels) - 1
+
+            prev_output_channel = output_channel
+            output_channel = reversed_block_out_channels[i]
+            input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
+
+            # add upsample block for all BUT final layer
+            if not is_final_block:
+                add_upsample = True
+                self.num_upsamplers += 1
+            else:
+                add_upsample = False
+
+            up_block = get_up_block(
+                up_block_type,
+                num_layers=reversed_layers_per_block[i] + 1,
+                transformer_layers_per_block=reversed_transformer_layers_per_block[i],
+                in_channels=input_channel,
+                out_channels=output_channel,
+                prev_output_channel=prev_output_channel,
+                temb_channels=blocks_time_embed_dim,
+                add_upsample=add_upsample,
+                resnet_eps=norm_eps,
+                resnet_act_fn=act_fn,
+                resnet_groups=norm_num_groups,
+                cross_attention_dim=reversed_cross_attention_dim[i],
+                num_attention_heads=reversed_num_attention_heads[i],
+                dual_cross_attention=dual_cross_attention,
+                use_linear_projection=use_linear_projection,
+                only_cross_attention=only_cross_attention[i],
+                upcast_attention=upcast_attention,
+                resnet_time_scale_shift=resnet_time_scale_shift,
+                resnet_skip_time_act=resnet_skip_time_act,
+                resnet_out_scale_factor=resnet_out_scale_factor,
+                cross_attention_norm=cross_attention_norm,
+                attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
+                num_views=num_views,
+                joint_attention=joint_attention,
+                joint_attention_twice=joint_attention_twice,
+                multiview_attention=multiview_attention,
+                cross_domain_attention=cross_domain_attention,
+            )
+            self.up_blocks.append(up_block)
+            prev_output_channel = output_channel
+
+        # out
+        # if norm_num_groups is not None:
+        #     self.conv_norm_out = nn.GroupNorm(
+        #         num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
+        #     )
+
+        #     self.conv_act = get_activation(act_fn)
+
+        # else:
+        #     self.conv_norm_out = None
+        #     self.conv_act = None
+
+        # conv_out_padding = (conv_out_kernel - 1) // 2
+        # self.conv_out = nn.Conv2d(
+        #     block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
+        # )
+
+        self.up_blocks[3].attentions[2].transformer_blocks[0].attn1.to_q = _LoRACompatibleLinear()
+        self.up_blocks[3].attentions[2].transformer_blocks[0].attn1.to_k = _LoRACompatibleLinear()
+        self.up_blocks[3].attentions[2].transformer_blocks[0].attn1.to_v = _LoRACompatibleLinear()
+        self.up_blocks[3].attentions[2].transformer_blocks[0].attn1.to_out = nn.ModuleList([Identity(), Identity()])
+        self.up_blocks[3].attentions[2].transformer_blocks[0].norm2 = Identity()
+        self.up_blocks[3].attentions[2].transformer_blocks[0].attn2 = None
+        self.up_blocks[3].attentions[2].transformer_blocks[0].norm3 = Identity()
+        self.up_blocks[3].attentions[2].transformer_blocks[0].ff = Identity()
+        self.up_blocks[3].attentions[2].proj_out = Identity()
+
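+        # The assignments above neuter the final transformer block of this
+        # reference UNet: the self-attention projections become pass-throughs and
+        # the cross-attention / feed-forward stages are dropped, mirroring the
+        # commented-out removal of the output head. Only the intermediate attention
+        # features are consumed by the main UNet, so computing a final sample would
+        # be wasted work.
+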
+    @property
+    def attn_processors(self) -> Dict[str, AttentionProcessor]:
+        r"""
+        Returns:
+            `dict` of attention processors: A dictionary containing all attention processors used in the model,
+            indexed by their weight names.
+        """
+        # set recursively
+        processors = {}
+
+        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+            if hasattr(module, "set_processor"):
+                processors[f"{name}.processor"] = module.processor
+
+            for sub_name, child in module.named_children():
+                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+            return processors
+
+        for name, module in self.named_children():
+            fn_recursive_add_processors(name, module, processors)
+
+        return processors
+
+    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+        r"""
+        Sets the attention processor to use to compute attention.
+
+        Parameters:
+            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+                The instantiated processor class or a dictionary of processor classes that will be set as the processor
+                for **all** `Attention` layers.
+
+                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+                processor. This is strongly recommended when setting trainable attention processors.
+
+        """
+        count = len(self.attn_processors.keys())
+
+        if isinstance(processor, dict) and len(processor) != count:
+            raise ValueError(
+                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+            )
+
+        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+            if hasattr(module, "set_processor"):
+                if not isinstance(processor, dict):
+                    module.set_processor(processor)
+                else:
+                    module.set_processor(processor.pop(f"{name}.processor"))
+
+            for sub_name, child in module.named_children():
+                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+        for name, module in self.named_children():
+            fn_recursive_attn_processor(name, module, processor)
+
+    def set_default_attn_processor(self):
+        """
+        Disables custom attention processors and sets the default attention implementation.
+        """
+        self.set_attn_processor(AttnProcessor())
+
+    def set_attention_slice(self, slice_size):
+        r"""
+        Enable sliced attention computation.
+
+        When this option is enabled, the attention module splits the input tensor in slices to compute attention in
+        several steps. This is useful for saving some memory in exchange for a small decrease in speed.
+
+        Args:
+            slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
+                When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
+                `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
+                provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
+                must be a multiple of `slice_size`.
+        """
+        sliceable_head_dims = []
+
+        def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
+            if hasattr(module, "set_attention_slice"):
+                sliceable_head_dims.append(module.sliceable_head_dim)
+
+            for child in module.children():
+                fn_recursive_retrieve_sliceable_dims(child)
+
+        # retrieve number of attention layers
+        for module in self.children():
+            fn_recursive_retrieve_sliceable_dims(module)
+
+        num_sliceable_layers = len(sliceable_head_dims)
+
+        if slice_size == "auto":
+            # half the attention head size is usually a good trade-off between
+            # speed and memory
+            slice_size = [dim // 2 for dim in sliceable_head_dims]
+        elif slice_size == "max":
+            # make smallest slice possible
+            slice_size = num_sliceable_layers * [1]
+
+        slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
+
+        if len(slice_size) != len(sliceable_head_dims):
+            raise ValueError(
+                f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
+                f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
+            )
+
+        for i in range(len(slice_size)):
+            size = slice_size[i]
+            dim = sliceable_head_dims[i]
+            if size is not None and size > dim:
+                raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
+
+        # Recursively walk through all the children.
+        # Any children which exposes the set_attention_slice method
+        # gets the message
+        def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
+            if hasattr(module, "set_attention_slice"):
+                module.set_attention_slice(slice_size.pop())
+
+            for child in module.children():
+                fn_recursive_set_attention_slice(child, slice_size)
+
+        reversed_slice_size = list(reversed(slice_size))
+        for module in self.children():
+            fn_recursive_set_attention_slice(module, reversed_slice_size)
+
+    def _set_gradient_checkpointing(self, module, value=False):
+        if isinstance(module, (CrossAttnDownBlock2D, CrossAttnDownBlockMV2D, DownBlock2D, CrossAttnUpBlock2D, CrossAttnUpBlockMV2D, UpBlock2D)):
+            module.gradient_checkpointing = value
+
+    def forward(
+        self,
+        sample: torch.FloatTensor,
+        timestep: Union[torch.Tensor, float, int],
+        encoder_hidden_states: torch.Tensor,
+        camera_matrixs: Optional[torch.Tensor] = None,
+        class_labels: Optional[torch.Tensor] = None,
+        timestep_cond: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
+        down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
+        mid_block_additional_residual: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+        return_dict: bool = True,
+    ) -> Union[UNetMV2DRefOutput, Tuple]:
+        r"""
+        The [`UNetMV2DRefModel`] forward method.
+
+        Args:
+            sample (`torch.FloatTensor`):
+                The noisy input tensor with shape `(batch, channel, num_views, height, width)`; the view axis is
+                folded into the batch before the UNet blocks run.
+            timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
+            encoder_hidden_states (`torch.FloatTensor`):
+                The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
+            camera_matrixs (`torch.Tensor`, *optional*):
+                Per-view camera matrices of shape `(batch, num_views, camera_input_dim)`, embedded and added to the
+                timestep embedding.
+            encoder_attention_mask (`torch.Tensor`):
+                A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
+                `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
+                which adds large negative values to the attention scores corresponding to "discard" tokens.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`UNetMV2DRefOutput`] instead of a plain tuple.
+            cross_attention_kwargs (`dict`, *optional*):
+                A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
+            added_cond_kwargs (`dict`, *optional*):
+                A kwargs dictionary containing additional embeddings that if specified are added to the embeddings
+                that are passed along to the UNet blocks.
+
+        Returns:
+            [`UNetMV2DRefOutput`] or `tuple`:
+                If `return_dict` is True, a [`UNetMV2DRefOutput`] is returned, otherwise a `tuple` is returned where
+                the first element is the sample tensor.
+        """
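+        # Shape sketch (an assumption based on the rearranges below, with the
+        # 4-view, 96x96-latent configuration used elsewhere in this repo):
+        #     sample:          (b, 4, num_views, 96, 96) -> folded to (b*num_views, 4, 96, 96)
+        #     camera_matrixs:  (b, num_views, 12), one flattened matrix per view
+        #     encoder_hidden_states: conditioning aligned with the folded batch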
+        # By default samples have to be at least a multiple of the overall upsampling factor.
+        # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
+        # However, the upsampling interpolation output size can be forced to fit any upsampling size
+        # on the fly if necessary.
+        default_overall_up_factor = 2**self.num_upsamplers
+
+        # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
+        forward_upsample_size = False
+        upsample_size = None
+
+        if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
+            logger.info("Forward upsample size to force interpolation output size.")
+            forward_upsample_size = True
+
+        # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
+        # expects mask of shape:
+        #   [batch, key_tokens]
+        # adds singleton query_tokens dimension:
+        #   [batch, 1, key_tokens]
+        # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
+        #   [batch,  heads, query_tokens, key_tokens] (e.g. torch sdp attn)
+        #   [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
+        if attention_mask is not None:
+            # assume that mask is expressed as:
+            #   (1 = keep, 0 = discard)
+            # convert mask into a bias that can be added to attention scores:
+            #   (keep = +0, discard = -10000.0)
+            attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
+            attention_mask = attention_mask.unsqueeze(1)
+
+        # convert encoder_attention_mask to a bias the same way we do for attention_mask
+        if encoder_attention_mask is not None:
+            encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
+            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
+
+        # 0. center input if necessary
+        if self.config.center_input_sample:
+            sample = 2 * sample - 1.0
+
+        # 1. time
+        timesteps = timestep
+        if not torch.is_tensor(timesteps):
+            # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
+            # This would be a good case for the `match` statement (Python 3.10+)
+            is_mps = sample.device.type == "mps"
+            if isinstance(timestep, float):
+                dtype = torch.float32 if is_mps else torch.float64
+            else:
+                dtype = torch.int32 if is_mps else torch.int64
+            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
+        elif len(timesteps.shape) == 0:
+            timesteps = timesteps[None].to(sample.device)
+
+        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+        timesteps = timesteps.expand(sample.shape[0])
+
+        t_emb = self.time_proj(timesteps)
+
+        # `Timesteps` does not contain any weights and will always return f32 tensors
+        # but time_embedding might actually be running in fp16. so we need to cast here.
+        # there might be better ways to encapsulate this.
+        t_emb = t_emb.to(dtype=sample.dtype)
+        emb = self.time_embedding(t_emb, timestep_cond)
+
+        if camera_matrixs is not None:
+            emb = torch.unsqueeze(emb, 1)
+            # camera embedding
+            cam_emb = self.camera_embedding(camera_matrixs)
+            emb = emb.repeat(1, cam_emb.shape[1], 1)  # (b, num_views, time_embed_dim)
+            emb = emb + cam_emb
+            emb = rearrange(emb, "b f c -> (b f) c", f=emb.shape[1])
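+            # Each view shares one timestep embedding; the camera embedding adds a
+            # view-specific offset, and the result is flattened to the (b*views)
+            # layout used by the folded sample below.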
+
+        aug_emb = None
+
+        if self.class_embedding is not None and class_labels is not None:
+            if self.config.class_embed_type == "timestep":
+                class_labels = self.time_proj(class_labels)
+
+                # `Timesteps` does not contain any weights and will always return f32 tensors
+                # there might be better ways to encapsulate this.
+                class_labels = class_labels.to(dtype=sample.dtype)
+
+            class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
+
+            if self.config.class_embeddings_concat:
+                emb = torch.cat([emb, class_emb], dim=-1)
+            else:
+                emb = emb + class_emb
+
+        if self.config.addition_embed_type == "text":
+            aug_emb = self.add_embedding(encoder_hidden_states)
+        elif self.config.addition_embed_type == "text_image":
+            # Kandinsky 2.1 - style
+            if "image_embeds" not in added_cond_kwargs:
+                raise ValueError(
+                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
+                )
+
+            image_embs = added_cond_kwargs.get("image_embeds")
+            text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
+            aug_emb = self.add_embedding(text_embs, image_embs)
+        elif self.config.addition_embed_type == "text_time":
+            # SDXL - style
+            if "text_embeds" not in added_cond_kwargs:
+                raise ValueError(
+                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
+                )
+            text_embeds = added_cond_kwargs.get("text_embeds")
+            if "time_ids" not in added_cond_kwargs:
+                raise ValueError(
+                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
+                )
+            time_ids = added_cond_kwargs.get("time_ids")
+            time_embeds = self.add_time_proj(time_ids.flatten())
+            time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
+
+            add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
+            add_embeds = add_embeds.to(emb.dtype)
+            aug_emb = self.add_embedding(add_embeds)
+        elif self.config.addition_embed_type == "image":
+            # Kandinsky 2.2 - style
+            if "image_embeds" not in added_cond_kwargs:
+                raise ValueError(
+                    f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
+                )
+            image_embs = added_cond_kwargs.get("image_embeds")
+            aug_emb = self.add_embedding(image_embs)
+        elif self.config.addition_embed_type == "image_hint":
+            # Kandinsky 2.2 - style
+            if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
+                raise ValueError(
+                    f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
+                )
+            image_embs = added_cond_kwargs.get("image_embeds")
+            hint = added_cond_kwargs.get("hint")
+            aug_emb, hint = self.add_embedding(image_embs, hint)
+            sample = torch.cat([sample, hint], dim=1)
+
+        emb = emb + aug_emb if aug_emb is not None else emb
+
+        if self.time_embed_act is not None:
+            emb = self.time_embed_act(emb)
+
+        if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
+            encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
+        elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
+            # Kandinsky 2.1 - style
+            if "image_embeds" not in added_cond_kwargs:
+                raise ValueError(
+                    f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
+                )
+
+            image_embeds = added_cond_kwargs.get("image_embeds")
+            encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
+        elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
+            # Kandinsky 2.2 - style
+            if "image_embeds" not in added_cond_kwargs:
+                raise ValueError(
+                    f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
+                )
+            image_embeds = added_cond_kwargs.get("image_embeds")
+            encoder_hidden_states = self.encoder_hid_proj(image_embeds)
+ # 2. pre-process
1044
+ sample = rearrange(sample, "b c f h w -> (b f) c h w", f=sample.shape[2])
1045
+ sample = self.conv_in(sample)
1046
+ # 3. down
1047
+
1048
+ is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
1049
+ is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None
1050
+
1051
+ down_block_res_samples = (sample,)
1052
+ for downsample_block in self.down_blocks:
1053
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
1054
+ # For t2i-adapter CrossAttnDownBlock2D
1055
+ additional_residuals = {}
1056
+ if is_adapter and len(down_block_additional_residuals) > 0:
1057
+ additional_residuals["additional_residuals"] = down_block_additional_residuals.pop(0)
1058
+
1059
+ sample, res_samples = downsample_block(
1060
+ hidden_states=sample,
1061
+ temb=emb,
1062
+ encoder_hidden_states=encoder_hidden_states,
1063
+ attention_mask=attention_mask,
1064
+ cross_attention_kwargs=cross_attention_kwargs,
1065
+ encoder_attention_mask=encoder_attention_mask,
1066
+ **additional_residuals,
1067
+ )
1068
+ else:
1069
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
1070
+
1071
+ if is_adapter and len(down_block_additional_residuals) > 0:
1072
+ sample += down_block_additional_residuals.pop(0)
1073
+
1074
+ down_block_res_samples += res_samples
1075
+
1076
+ if is_controlnet:
1077
+ new_down_block_res_samples = ()
1078
+
1079
+ for down_block_res_sample, down_block_additional_residual in zip(
1080
+ down_block_res_samples, down_block_additional_residuals
1081
+ ):
1082
+ down_block_res_sample = down_block_res_sample + down_block_additional_residual
1083
+ new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
1084
+
1085
+ down_block_res_samples = new_down_block_res_samples
1086
+ # print("after down: ", sample.mean(), emb.mean())
1087
+
1088
+ # 4. mid
1089
+ if self.mid_block is not None:
1090
+ sample = self.mid_block(
1091
+ sample,
1092
+ emb,
1093
+ encoder_hidden_states=encoder_hidden_states,
1094
+ attention_mask=attention_mask,
1095
+ cross_attention_kwargs=cross_attention_kwargs,
1096
+ encoder_attention_mask=encoder_attention_mask,
1097
+ )
1098
+
1099
+ if is_controlnet:
1100
+ sample = sample + mid_block_additional_residual
1101
+
1102
+ # print("after mid: ", sample.mean())
1103
+ # 5. up
1104
+ for i, upsample_block in enumerate(self.up_blocks):
1105
+ is_final_block = i == len(self.up_blocks) - 1
1106
+
1107
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
1108
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
1109
+
1110
+ # if we have not reached the final block and need to forward the
1111
+ # upsample size, we do it here
1112
+ if not is_final_block and forward_upsample_size:
1113
+ upsample_size = down_block_res_samples[-1].shape[2:]
1114
+
1115
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
1116
+ sample = upsample_block(
1117
+ hidden_states=sample,
1118
+ temb=emb,
1119
+ res_hidden_states_tuple=res_samples,
1120
+ encoder_hidden_states=encoder_hidden_states,
1121
+ cross_attention_kwargs=cross_attention_kwargs,
1122
+ upsample_size=upsample_size,
1123
+ attention_mask=attention_mask,
1124
+ encoder_attention_mask=encoder_attention_mask,
1125
+ )
1126
+ else:
1127
+ sample = upsample_block(
1128
+ hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
1129
+ )
1130
+
1131
+ # 6. post-process
1132
+ # if self.conv_norm_out:
1133
+ # sample = self.conv_norm_out(sample)
1134
+ # sample = self.conv_act(sample)
1135
+ # sample = self.conv_out(sample)
1136
+
1137
+ if not return_dict:
1138
+ return (sample,)
1139
+
1140
+ return UNetMV2DRefOutput(sample=sample)
1141
+
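The forward pass above folds the view axis into the batch before `conv_in` (step "2. pre-process"). A minimal, self-contained sketch of that flattening, with illustrative tensor sizes:

```py
# Illustrative only: reproduces the "b c f h w -> (b f) c h w" flattening from
# step "2. pre-process" above, with made-up sizes (batch 2, 4 views, 32x32).
import torch
from einops import rearrange

sample = torch.randn(2, 4, 4, 32, 32)                     # b c f h w
flat = rearrange(sample, "b c f h w -> (b f) c h w", f=sample.shape[2])
assert flat.shape == (8, 4, 32, 32)                       # views become batch items
# The 2D down/mid/up blocks then process each view as an independent image;
# cross-view consistency comes from the multi-view attention processors.
```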
1142
+ @classmethod
1143
+ def from_pretrained_2d(
1144
+ cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
1145
+ camera_embedding_type: str, num_views: int, sample_size: int,
1146
+ zero_init_conv_in: bool = True, zero_init_camera_projection: bool = False,
1147
+ projection_class_embeddings_input_dim: int=6, joint_attention: bool = False,
1148
+ joint_attention_twice: bool = False, multiview_attention: bool = True,
1149
+ cross_domain_attention: bool = False,
1150
+ in_channels: int = 8, out_channels: int = 4, local_crossattn=False,
1151
+ **kwargs
1152
+ ):
1153
+ r"""
1154
+ Instantiate a pretrained PyTorch model from a pretrained model configuration.
1155
+
1156
+ The model is set in evaluation mode - `model.eval()` - by default, and dropout modules are deactivated. To
1157
+ train the model, set it back in training mode with `model.train()`.
1158
+
1159
+ Parameters:
1160
+ pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
1161
+ Can be either:
1162
+
1163
+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
1164
+ the Hub.
1165
+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
1166
+ with [`~ModelMixin.save_pretrained`].
1167
+
1168
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
1169
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
1170
+ is not used.
1171
+ torch_dtype (`str` or `torch.dtype`, *optional*):
1172
+ Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
1173
+ dtype is automatically derived from the model's weights.
1174
+ force_download (`bool`, *optional*, defaults to `False`):
1175
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
1176
+ cached versions if they exist.
1177
+ resume_download (`bool`, *optional*, defaults to `False`):
1178
+ Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
1179
+ incompletely downloaded files are deleted.
1180
+ proxies (`Dict[str, str]`, *optional*):
1181
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
1182
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
1183
+ output_loading_info (`bool`, *optional*, defaults to `False`):
1184
+ Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
1185
+ local_files_only (`bool`, *optional*, defaults to `False`):
1186
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
1187
+ won't be downloaded from the Hub.
1188
+ use_auth_token (`str` or *bool*, *optional*):
1189
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
1190
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
1191
+ revision (`str`, *optional*, defaults to `"main"`):
1192
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
1193
+ allowed by Git.
1194
+ from_flax (`bool`, *optional*, defaults to `False`):
1195
+ Load the model weights from a Flax checkpoint save file.
1196
+ subfolder (`str`, *optional*, defaults to `""`):
1197
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
1198
+ mirror (`str`, *optional*):
1199
+ Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
1200
+ guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
1201
+ information.
1202
+ device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
1203
+ A map that specifies where each submodule should go. It doesn't need to be defined for each
1204
+ parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the
1205
+ same device.
1206
+
1207
+ Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For
1208
+ more information about each option see [designing a device
1209
+ map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
1210
+ max_memory (`Dict`, *optional*):
1211
+ A dictionary mapping device identifiers to their maximum memory. Will default to the maximum memory available for
1212
+ each GPU and the available CPU RAM if unset.
1213
+ offload_folder (`str` or `os.PathLike`, *optional*):
1214
+ The path to offload weights if `device_map` contains the value `"disk"`.
1215
+ offload_state_dict (`bool`, *optional*):
1216
+ If `True`, temporarily offloads the CPU state dict to the hard drive to avoid running out of CPU RAM if
1217
+ the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to `True`
1218
+ when there is some disk offload.
1219
+ low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
1220
+ Speed up model loading by only loading the pretrained weights and not initializing the weights. This also
1221
+ tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
1222
+ Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
1223
+ argument to `True` will raise an error.
1224
+ variant (`str`, *optional*):
1225
+ Load weights from a specified `variant` filename such as `"fp16"` or `"ema"`. This is ignored when
1226
+ loading `from_flax`.
1227
+ use_safetensors (`bool`, *optional*, defaults to `None`):
1228
+ If set to `None`, the `safetensors` weights are downloaded if they're available **and** if the
1229
+ `safetensors` library is installed. If set to `True`, the model is forcibly loaded from `safetensors`
1230
+ weights. If set to `False`, `safetensors` weights are not loaded.
1231
+
1232
+ <Tip>
1233
+
1234
+ To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with
1235
+ `huggingface-cli login`. You can also activate the special
1236
+ ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a
1237
+ firewalled environment.
1238
+
1239
+ </Tip>
1240
+
1241
+ Example:
1242
+
1243
+ ```py
1244
+ from diffusers import UNet2DConditionModel
1245
+
1246
+ unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
1247
+ ```
1248
+
1249
+ If you get the error message below, you need to finetune the weights for your downstream task:
1250
+
1251
+ ```bash
1252
+ Some weights of UNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match:
1253
+ - conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated
1254
+ You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
1255
+ ```
1256
+ """
1257
+ cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
1258
+ ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
1259
+ force_download = kwargs.pop("force_download", False)
1260
+ from_flax = kwargs.pop("from_flax", False)
1261
+ resume_download = kwargs.pop("resume_download", False)
1262
+ proxies = kwargs.pop("proxies", None)
1263
+ output_loading_info = kwargs.pop("output_loading_info", False)
1264
+ local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
1265
+ use_auth_token = kwargs.pop("use_auth_token", None)
1266
+ revision = kwargs.pop("revision", None)
1267
+ torch_dtype = kwargs.pop("torch_dtype", None)
1268
+ subfolder = kwargs.pop("subfolder", None)
1269
+ device_map = kwargs.pop("device_map", None)
1270
+ max_memory = kwargs.pop("max_memory", None)
1271
+ offload_folder = kwargs.pop("offload_folder", None)
1272
+ offload_state_dict = kwargs.pop("offload_state_dict", False)
1273
+ variant = kwargs.pop("variant", None)
1274
+ use_safetensors = kwargs.pop("use_safetensors", None)
1275
+
1276
+ # if use_safetensors and not is_safetensors_available():
1277
+ # raise ValueError(
1278
+ # "`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors"
1279
+ # )
1280
+
1281
+ allow_pickle = False
1282
+ if use_safetensors is None:
1283
+ # use_safetensors = is_safetensors_available()
1284
+ use_safetensors = False
1285
+ allow_pickle = True
1286
+
1287
+ if device_map is not None and not is_accelerate_available():
1288
+ raise NotImplementedError(
1289
+ "Loading and dispatching requires `accelerate`. Please make sure to install accelerate or set"
1290
+ " `device_map=None`. You can install accelerate with `pip install accelerate`."
1291
+ )
1292
+
1293
+ # Check if we can handle device_map and dispatching the weights
1294
+ if device_map is not None and not is_torch_version(">=", "1.9.0"):
1295
+ raise NotImplementedError(
1296
+ "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
1297
+ " `device_map=None`."
1298
+ )
1299
+
1300
+ # Load config if we don't provide a configuration
1301
+ config_path = pretrained_model_name_or_path
1302
+
1303
+ user_agent = {
1304
+ "diffusers": __version__,
1305
+ "file_type": "model",
1306
+ "framework": "pytorch",
1307
+ }
1308
+
1309
+ # load config
1310
+ config, unused_kwargs, commit_hash = cls.load_config(
1311
+ config_path,
1312
+ cache_dir=cache_dir,
1313
+ return_unused_kwargs=True,
1314
+ return_commit_hash=True,
1315
+ force_download=force_download,
1316
+ resume_download=resume_download,
1317
+ proxies=proxies,
1318
+ local_files_only=local_files_only,
1319
+ use_auth_token=use_auth_token,
1320
+ revision=revision,
1321
+ subfolder=subfolder,
1322
+ device_map=device_map,
1323
+ max_memory=max_memory,
1324
+ offload_folder=offload_folder,
1325
+ offload_state_dict=offload_state_dict,
1326
+ user_agent=user_agent,
1327
+ **kwargs,
1328
+ )
1329
+
1330
+ # modify config
1331
+ config["_class_name"] = cls.__name__
1332
+ config['in_channels'] = in_channels
1333
+ config['out_channels'] = out_channels
1334
+ config['sample_size'] = sample_size # training resolution
1335
+ config['num_views'] = num_views
1336
+ config['joint_attention'] = joint_attention
1337
+ config['joint_attention_twice'] = joint_attention_twice
1338
+ config['multiview_attention'] = multiview_attention
1339
+ config['cross_domain_attention'] = cross_domain_attention
1340
+ config["down_block_types"] = [
1341
+ "CrossAttnDownBlockMV2D",
1342
+ "CrossAttnDownBlockMV2D",
1343
+ "CrossAttnDownBlockMV2D",
1344
+ "DownBlock2D"
1345
+ ]
1346
+ config['mid_block_type'] = "UNetMidBlockMV2DCrossAttn"
1347
+ config["up_block_types"] = [
1348
+ "UpBlock2D",
1349
+ "CrossAttnUpBlockMV2D",
1350
+ "CrossAttnUpBlockMV2D",
1351
+ "CrossAttnUpBlockMV2D"
1352
+ ]
1353
+ config['class_embed_type'] = 'projection'
1354
+ if camera_embedding_type == 'e_de_da_sincos':
1355
+ config['projection_class_embeddings_input_dim'] = projection_class_embeddings_input_dim # default 6
1356
+ else:
1357
+ raise NotImplementedError
1358
+
1359
+ # load model
1360
+ model_file = None
1361
+ if from_flax:
1362
+ raise NotImplementedError
1363
+ else:
1364
+ if use_safetensors:
1365
+ try:
1366
+ model_file = _get_model_file(
1367
+ pretrained_model_name_or_path,
1368
+ weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant),
1369
+ cache_dir=cache_dir,
1370
+ force_download=force_download,
1371
+ resume_download=resume_download,
1372
+ proxies=proxies,
1373
+ local_files_only=local_files_only,
1374
+ use_auth_token=use_auth_token,
1375
+ revision=revision,
1376
+ subfolder=subfolder,
1377
+ user_agent=user_agent,
1378
+ commit_hash=commit_hash,
1379
+ )
1380
+ except IOError as e:
1381
+ if not allow_pickle:
1382
+ raise e
1383
+ pass
1384
+ if model_file is None:
1385
+ model_file = _get_model_file(
1386
+ pretrained_model_name_or_path,
1387
+ weights_name=_add_variant(WEIGHTS_NAME, variant),
1388
+ cache_dir=cache_dir,
1389
+ force_download=force_download,
1390
+ resume_download=resume_download,
1391
+ proxies=proxies,
1392
+ local_files_only=local_files_only,
1393
+ use_auth_token=use_auth_token,
1394
+ revision=revision,
1395
+ subfolder=subfolder,
1396
+ user_agent=user_agent,
1397
+ commit_hash=commit_hash,
1398
+ )
1399
+
1400
+ model = cls.from_config(config, **unused_kwargs)
1401
+ if local_crossattn:
1402
+ unet_lora_attn_procs = dict()
1403
+ for name, _ in model.attn_processors.items():
1404
+ if not name.endswith("attn1.processor"):
1405
+ default_attn_proc = AttnProcessor()
1406
+ elif is_xformers_available():
1407
+ default_attn_proc = XFormersMVAttnProcessor()
1408
+ else:
1409
+ default_attn_proc = MVAttnProcessor()
1410
+ unet_lora_attn_procs[name] = ReferenceOnlyAttnProc(
1411
+ default_attn_proc, enabled=name.endswith("attn1.processor"), name=name
1412
+ )
1413
+ model.set_attn_processor(unet_lora_attn_procs)
1414
+ state_dict = load_state_dict(model_file, variant=variant)
1415
+ model._convert_deprecated_attention_blocks(state_dict)
1416
+
1417
+ conv_in_weight = state_dict['conv_in.weight']
1418
+ conv_out_weight = state_dict['conv_out.weight']
1419
+ model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model_2d(
1420
+ model,
1421
+ state_dict,
1422
+ model_file,
1423
+ pretrained_model_name_or_path,
1424
+ ignore_mismatched_sizes=True,
1425
+ )
1426
+ if any([key == 'conv_in.weight' for key, _, _ in mismatched_keys]):
1427
+ # initialize from the original SD structure
1428
+ model.conv_in.weight.data[:,:4] = conv_in_weight
1429
+
1430
+ # optionally zero-initialize the newly added input channels so they start as a no-op
1431
+ if zero_init_conv_in:
1432
+ model.conv_in.weight.data[:,4:] = 0.
1433
+
1434
+ if any([key == 'conv_out.weight' for key, _, _ in mismatched_keys]):
1435
+ # initialize from the original SD structure
1436
+ model.conv_out.weight.data[:,:4] = conv_out_weight
1437
+ if out_channels == 8: # copy for the last 4 channels
1438
+ model.conv_out.weight.data[:, 4:] = conv_out_weight
1439
+
1440
+ if zero_init_camera_projection:
1441
+ for p in model.class_embedding.parameters():
1442
+ torch.nn.init.zeros_(p)
1443
+
1444
+ loading_info = {
1445
+ "missing_keys": missing_keys,
1446
+ "unexpected_keys": unexpected_keys,
1447
+ "mismatched_keys": mismatched_keys,
1448
+ "error_msgs": error_msgs,
1449
+ }
1450
+
1451
+ if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
1452
+ raise ValueError(
1453
+ f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
1454
+ )
1455
+ elif torch_dtype is not None:
1456
+ model = model.to(torch_dtype)
1457
+
1458
+ model.register_to_config(_name_or_path=pretrained_model_name_or_path)
1459
+
1460
+ # Set model in evaluation mode to deactivate DropOut modules by default
1461
+ model.eval()
1462
+ if output_loading_info:
1463
+ return model, loading_info
1464
+
1465
+ return model
1466
+
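A hypothetical usage sketch of the loader above. The base checkpoint and hyper-parameters are placeholders, not values shipped with any particular checkpoint; only `'e_de_da_sincos'` is accepted for the camera embedding, and 6 is the documented default projection dimension:

```py
# Hypothetical call; argument names follow the signature above, values are
# placeholders you would take from your own config.
ref_unet = UNetMV2DRefModel.from_pretrained_2d(
    "runwayml/stable-diffusion-v1-5",
    subfolder="unet",
    camera_embedding_type="e_de_da_sincos",  # only supported value above
    num_views=4,
    sample_size=64,
    projection_class_embeddings_input_dim=6,
)
```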
1467
+ @classmethod
1468
+ def _load_pretrained_model_2d(
1469
+ cls,
1470
+ model,
1471
+ state_dict,
1472
+ resolved_archive_file,
1473
+ pretrained_model_name_or_path,
1474
+ ignore_mismatched_sizes=False,
1475
+ ):
1476
+ # Retrieve missing & unexpected_keys
1477
+ model_state_dict = model.state_dict()
1478
+ loaded_keys = list(state_dict.keys())
1479
+
1480
+ expected_keys = list(model_state_dict.keys())
1481
+
1482
+ original_loaded_keys = loaded_keys
1483
+
1484
+ missing_keys = list(set(expected_keys) - set(loaded_keys))
1485
+ unexpected_keys = list(set(loaded_keys) - set(expected_keys))
1486
+
1487
+ # Make sure we are able to load base models as well as derived models (with heads)
1488
+ model_to_load = model
1489
+
1490
+ def _find_mismatched_keys(
1491
+ state_dict,
1492
+ model_state_dict,
1493
+ loaded_keys,
1494
+ ignore_mismatched_sizes,
1495
+ ):
1496
+ mismatched_keys = []
1497
+ if ignore_mismatched_sizes:
1498
+ for checkpoint_key in loaded_keys:
1499
+ model_key = checkpoint_key
1500
+
1501
+ if (
1502
+ model_key in model_state_dict
1503
+ and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
1504
+ ):
1505
+ mismatched_keys.append(
1506
+ (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
1507
+ )
1508
+ del state_dict[checkpoint_key]
1509
+ return mismatched_keys
1510
+
1511
+ if state_dict is not None:
1512
+ # Whole checkpoint
1513
+ mismatched_keys = _find_mismatched_keys(
1514
+ state_dict,
1515
+ model_state_dict,
1516
+ original_loaded_keys,
1517
+ ignore_mismatched_sizes,
1518
+ )
1519
+ error_msgs = _load_state_dict_into_model(model_to_load, state_dict)
1520
+
1521
+ if len(error_msgs) > 0:
1522
+ error_msg = "\n\t".join(error_msgs)
1523
+ if "size mismatch" in error_msg:
1524
+ error_msg += (
1525
+ "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method."
1526
+ )
1527
+ raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")
1528
+
1529
+ if len(unexpected_keys) > 0:
1530
+ logger.warning(
1531
+ f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
1532
+ f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
1533
+ f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task"
1534
+ " or with another architecture (e.g. initializing a BertForSequenceClassification model from a"
1535
+ " BertForPreTraining model).\n- This IS NOT expected if you are initializing"
1536
+ f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly"
1537
+ " identical (initializing a BertForSequenceClassification model from a"
1538
+ " BertForSequenceClassification model)."
1539
+ )
1540
+ else:
1541
+ logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
1542
+ if len(missing_keys) > 0:
1543
+ logger.warning(
1544
+ f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
1545
+ f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
1546
+ " TRAIN this model on a down-stream task to be able to use it for predictions and inference."
1547
+ )
1548
+ elif len(mismatched_keys) == 0:
1549
+ logger.info(
1550
+ f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
1551
+ f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the"
1552
+ f" checkpoint was trained on, you can already use {model.__class__.__name__} for predictions"
1553
+ " without further training."
1554
+ )
1555
+ if len(mismatched_keys) > 0:
1556
+ mismatched_warning = "\n".join(
1557
+ [
1558
+ f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
1559
+ for key, shape1, shape2 in mismatched_keys
1560
+ ]
1561
+ )
1562
+ logger.warning(
1563
+ f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
1564
+ f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
1565
+ f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be"
1566
+ " able to use it for predictions and inference."
1567
+ )
1568
+
1569
+ return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs
1570
+
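For intuition, the `conv_in` re-initialisation that `from_pretrained_2d` performs when the input channel counts mismatch can be written out in a few lines (shapes are illustrative):

```py
# Sketch of the channel-expansion strategy above: the first four input
# channels inherit the pretrained SD weights; the extra channels stay zero
# (zero_init_conv_in=True) so the new conditioning input starts as a no-op.
import torch

sd_conv_in = torch.randn(320, 4, 3, 3)    # pretrained conv_in, in_channels=4
new_conv_in = torch.zeros(320, 8, 3, 3)   # model rebuilt with in_channels=8
new_conv_in[:, :4] = sd_conv_in           # inherit the original 4 channels
```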
2D_Stage/tuneavideo/pipelines/__pycache__/pipeline_tuneavideo.cpython-310.pyc ADDED
Binary file (14.6 kB).
 
2D_Stage/tuneavideo/pipelines/pipeline_tuneavideo.py ADDED
@@ -0,0 +1,585 @@
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
2
+
3
+ import tqdm
4
+
5
+ import inspect
6
+ from typing import Callable, List, Optional, Union
7
+ from dataclasses import dataclass
8
+
9
+ import numpy as np
10
+ import torch
11
+
12
+ from diffusers.utils import is_accelerate_available
13
+ from packaging import version
14
+ from transformers import CLIPTextModel, CLIPTokenizer
15
+ import torchvision.transforms.functional as TF
16
+
17
+ from diffusers.configuration_utils import FrozenDict
18
+ from diffusers.models import AutoencoderKL
19
+ from diffusers import DiffusionPipeline
20
+ from diffusers.schedulers import (
21
+ DDIMScheduler,
22
+ DPMSolverMultistepScheduler,
23
+ EulerAncestralDiscreteScheduler,
24
+ EulerDiscreteScheduler,
25
+ LMSDiscreteScheduler,
26
+ PNDMScheduler,
27
+ )
28
+ from diffusers.utils import deprecate, logging, BaseOutput
29
+
30
+ from einops import rearrange
31
+
32
+ from ..models.unet import UNet3DConditionModel
33
+ from torchvision.transforms import InterpolationMode
34
+
35
+ import ipdb
36
+
37
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
38
+
39
+
40
+ @dataclass
41
+ class TuneAVideoPipelineOutput(BaseOutput):
42
+ videos: Union[torch.Tensor, np.ndarray]
43
+
44
+
45
+ class TuneAVideoPipeline(DiffusionPipeline):
46
+ _optional_components = []
47
+
48
+ def __init__(
49
+ self,
50
+ vae: AutoencoderKL,
51
+ text_encoder: CLIPTextModel,
52
+ tokenizer: CLIPTokenizer,
53
+ unet: UNet3DConditionModel,
54
+
55
+ scheduler: Union[
56
+ DDIMScheduler,
57
+ PNDMScheduler,
58
+ LMSDiscreteScheduler,
59
+ EulerDiscreteScheduler,
60
+ EulerAncestralDiscreteScheduler,
61
+ DPMSolverMultistepScheduler,
62
+ ],
63
+ ref_unet = None,
64
+ feature_extractor=None,
65
+ image_encoder=None
66
+ ):
67
+ super().__init__()
68
+ self.ref_unet = ref_unet
69
+ self.feature_extractor = feature_extractor
70
+ self.image_encoder = image_encoder
71
+
72
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
73
+ deprecation_message = (
74
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
75
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
76
+ "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
77
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
78
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
79
+ " file"
80
+ )
81
+ deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
82
+ new_config = dict(scheduler.config)
83
+ new_config["steps_offset"] = 1
84
+ scheduler._internal_dict = FrozenDict(new_config)
85
+
86
+ if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
87
+ deprecation_message = (
88
+ f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
89
+ " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
90
+ " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
91
+ " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
92
+ " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
93
+ )
94
+ deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
95
+ new_config = dict(scheduler.config)
96
+ new_config["clip_sample"] = False
97
+ scheduler._internal_dict = FrozenDict(new_config)
98
+
99
+ is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
100
+ version.parse(unet.config._diffusers_version).base_version
101
+ ) < version.parse("0.9.0.dev0")
102
+ is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
103
+ if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
104
+ deprecation_message = (
105
+ "The configuration file of the unet has set the default `sample_size` to smaller than"
106
+ " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
107
+ " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
108
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
109
+ " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
110
+ " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
111
+ " in the config might lead to incorrect results in future versions. If you have downloaded this"
112
+ " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
113
+ " the `unet/config.json` file"
114
+ )
115
+ deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
116
+ new_config = dict(unet.config)
117
+ new_config["sample_size"] = 64
118
+ unet._internal_dict = FrozenDict(new_config)
119
+
120
+ self.register_modules(
121
+ vae=vae,
122
+ text_encoder=text_encoder,
123
+ tokenizer=tokenizer,
124
+ unet=unet,
125
+ scheduler=scheduler,
126
+ )
127
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
128
+
129
+ def enable_vae_slicing(self):
130
+ self.vae.enable_slicing()
131
+
132
+ def disable_vae_slicing(self):
133
+ self.vae.disable_slicing()
134
+
135
+ def enable_sequential_cpu_offload(self, gpu_id=0):
136
+ if is_accelerate_available():
137
+ from accelerate import cpu_offload
138
+ else:
139
+ raise ImportError("Please install accelerate via `pip install accelerate`")
140
+
141
+ device = torch.device(f"cuda:{gpu_id}")
142
+
143
+ for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
144
+ if cpu_offloaded_model is not None:
145
+ cpu_offload(cpu_offloaded_model, device)
146
+
147
+
148
+ @property
149
+ def _execution_device(self):
150
+ if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
151
+ return self.device
152
+ for module in self.unet.modules():
153
+ if (
154
+ hasattr(module, "_hf_hook")
155
+ and hasattr(module._hf_hook, "execution_device")
156
+ and module._hf_hook.execution_device is not None
157
+ ):
158
+ return torch.device(module._hf_hook.execution_device)
159
+ return self.device
160
+
161
+ def _encode_image(self, image_pil, device, num_images_per_prompt, do_classifier_free_guidance, img_proj=None):
162
+ dtype = next(self.image_encoder.parameters()).dtype
163
+
164
+ # image_pt = self.feature_extractor(images=image_pil, return_tensors="pt").pixel_values
165
+ # image_pt = image_pt.to(device=device, dtype=dtype)
166
+ # image_embeddings = self.image_encoder(image_pt).image_embeds
167
+ # image_embeddings = image_embeddings.unsqueeze(1)
168
+
169
+ # # image encoding
170
+ clip_image_mean = torch.as_tensor(self.feature_extractor.image_mean)[:,None,None].to(device, dtype=torch.float32)
171
+ clip_image_std = torch.as_tensor(self.feature_extractor.image_std)[:,None,None].to(device, dtype=torch.float32)
172
+ imgs_in_proc = TF.resize(image_pil, (self.feature_extractor.crop_size['height'], self.feature_extractor.crop_size['width']), interpolation=InterpolationMode.BICUBIC)
173
+ # do the normalization in float32 to preserve precision
174
+ imgs_in_proc = ((imgs_in_proc.float() - clip_image_mean) / clip_image_std).to(dtype)
175
+ if img_proj is None:
176
+ # (B*Nv, 1, 768)
177
+ image_embeddings = self.image_encoder(imgs_in_proc).image_embeds.unsqueeze(1)
178
+ # duplicate image embeddings for each generation per prompt, using mps friendly method
179
+ # Note: repeat differently from official pipelines
180
+ # B1B2B3B4 -> B1B2B3B4B1B2B3B4
181
+ bs_embed, seq_len, _ = image_embeddings.shape
182
+ image_embeddings = image_embeddings.repeat(num_images_per_prompt, 1, 1)
183
+ if do_classifier_free_guidance:
184
+ negative_prompt_embeds = torch.zeros_like(image_embeddings)
185
+
186
+ # For classifier free guidance, we need to do two forward passes.
187
+ # Here we concatenate the unconditional and text embeddings into a single batch
188
+ # to avoid doing two forward passes
189
+ image_embeddings = torch.cat([negative_prompt_embeds, image_embeddings])
190
+ else:
191
+ if do_classifier_free_guidance:
192
+ negative_image_proc = torch.zeros_like(imgs_in_proc)
193
+
194
+ # For classifier free guidance, we need to do two forward passes.
195
+ # Here we concatenate the unconditional and text embeddings into a single batch
196
+ # to avoid doing two forward passes
197
+ imgs_in_proc = torch.cat([negative_image_proc, imgs_in_proc])
198
+
199
+ image_embeds = self.image_encoder(imgs_in_proc, output_hidden_states=True).hidden_states[-2]
200
+ image_embeddings = img_proj(image_embeds)
201
+
202
+ # image_embeddings_unet = rearrange(image_embeddings_unet, 'B Nv d c -> (B Nv) d c')
203
+
204
+ # image_pt = torch.stack([TF.to_tensor(img) for img in image_pil], dim=0).to(device)
205
+ # image_pil = image_pil * 2.0 - 1.0
206
+ image_latents = self.vae.encode(image_pil * 2.0 - 1.0).latent_dist.mode() * self.vae.config.scaling_factor
207
+
208
+ # Note: repeat differently from official pipelines
209
+ # B1B2B3B4 -> B1B2B3B4B1B2B3B4
210
+ image_latents = image_latents.repeat(num_images_per_prompt, 1, 1, 1)
211
+
212
+ # if do_classifier_free_guidance:
213
+ # image_latents = torch.cat([torch.zeros_like(image_latents), image_latents])
214
+
215
+ return image_embeddings, image_latents
216
+
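The classifier-free-guidance batching in `_encode_image` can be illustrated in isolation; the embedding width below is an assumption (CLIP image towers commonly emit 768- or 1024-dim embeddings):

```py
# Toy stand-in for the CFG batching above: a zero tensor acts as the
# unconditional image embedding and is concatenated in front, so one
# forward pass of the UNet covers both branches.
import torch

image_embeddings = torch.randn(4, 1, 1024)            # (B*Nv, 1, dim), dim assumed
negative = torch.zeros_like(image_embeddings)
cfg_batch = torch.cat([negative, image_embeddings])   # uncond first, cond second
assert cfg_batch.shape[0] == 2 * image_embeddings.shape[0]
```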
217
+ def _encode_prompt(self, prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt):
218
+ batch_size = len(prompt) if isinstance(prompt, list) else 1
219
+
220
+ text_inputs = self.tokenizer(
221
+ prompt,
222
+ padding="max_length",
223
+ max_length=self.tokenizer.model_max_length,
224
+ truncation=True,
225
+ return_tensors="pt",
226
+ )
227
+ text_input_ids = text_inputs.input_ids
228
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
229
+
230
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
231
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
232
+ logger.warning(
233
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
234
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
235
+ )
236
+
237
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
238
+ attention_mask = text_inputs.attention_mask.to(device)
239
+ else:
240
+ attention_mask = None
241
+
242
+ text_embeddings = self.text_encoder(
243
+ text_input_ids.to(device),
244
+ attention_mask=attention_mask,
245
+ )
246
+ text_embeddings = text_embeddings[0]
247
+
248
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
249
+ bs_embed, seq_len, _ = text_embeddings.shape
250
+ text_embeddings = text_embeddings.repeat(1, num_videos_per_prompt, 1)
251
+ text_embeddings = text_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1)
252
+
253
+ # get unconditional embeddings for classifier free guidance
254
+ if do_classifier_free_guidance:
255
+ uncond_tokens: List[str]
256
+ if negative_prompt is None:
257
+ uncond_tokens = [""] * batch_size
258
+ elif type(prompt) is not type(negative_prompt):
259
+ raise TypeError(
260
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
261
+ f" {type(prompt)}."
262
+ )
263
+ elif isinstance(negative_prompt, str):
264
+ uncond_tokens = [negative_prompt]
265
+ elif batch_size != len(negative_prompt):
266
+ raise ValueError(
267
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
268
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
269
+ " the batch size of `prompt`."
270
+ )
271
+ else:
272
+ uncond_tokens = negative_prompt
273
+
274
+ max_length = text_input_ids.shape[-1]
275
+ uncond_input = self.tokenizer(
276
+ uncond_tokens,
277
+ padding="max_length",
278
+ max_length=max_length,
279
+ truncation=True,
280
+ return_tensors="pt",
281
+ )
282
+
283
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
284
+ attention_mask = uncond_input.attention_mask.to(device)
285
+ else:
286
+ attention_mask = None
287
+
288
+ uncond_embeddings = self.text_encoder(
289
+ uncond_input.input_ids.to(device),
290
+ attention_mask=attention_mask,
291
+ )
292
+ uncond_embeddings = uncond_embeddings[0]
293
+
294
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
295
+ seq_len = uncond_embeddings.shape[1]
296
+ uncond_embeddings = uncond_embeddings.repeat(1, num_videos_per_prompt, 1)
297
+ uncond_embeddings = uncond_embeddings.view(batch_size * num_videos_per_prompt, seq_len, -1)
298
+
299
+ # For classifier free guidance, we need to do two forward passes.
300
+ # Here we concatenate the unconditional and text embeddings into a single batch
301
+ # to avoid doing two forward passes
302
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
303
+
304
+ return text_embeddings
305
+
306
+ def decode_latents(self, latents):
307
+ video_length = latents.shape[2]
308
+ latents = 1 / 0.18215 * latents
309
+ latents = rearrange(latents, "b c f h w -> (b f) c h w")
310
+ video = self.vae.decode(latents).sample
311
+ video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
312
+ video = (video / 2 + 0.5).clamp(0, 1)
313
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
314
+ video = video.cpu().float().numpy()
315
+ return video
316
+
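A minimal sketch of what `decode_latents` does, assuming an `AutoencoderKL` named `vae` is available (0.18215 is the standard SD latent scaling factor):

```py
# Illustrative decode path: undo the SD latent scaling, fold views/frames
# into the batch, decode, then map [-1, 1] back to [0, 1].
import torch
from einops import rearrange

latents = torch.randn(1, 4, 4, 32, 32)                        # b c f h w
frames = rearrange(latents / 0.18215, "b c f h w -> (b f) c h w")
# video = vae.decode(frames).sample          # needs a loaded AutoencoderKL
# video = (video / 2 + 0.5).clamp(0, 1)
```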
317
+ def prepare_extra_step_kwargs(self, generator, eta):
318
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
319
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
320
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
321
+ # and should be between [0, 1]
322
+
323
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
324
+ extra_step_kwargs = {}
325
+ if accepts_eta:
326
+ extra_step_kwargs["eta"] = eta
327
+
328
+ # check if the scheduler accepts generator
329
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
330
+ if accepts_generator:
331
+ extra_step_kwargs["generator"] = generator
332
+ return extra_step_kwargs
333
+
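The introspection above boils down to one check per optional argument; a compact, runnable equivalent:

```py
# Equivalent check in isolation: forward `eta` only when the scheduler's
# `step` signature actually accepts it (DDIM does, most schedulers do not).
import inspect
from diffusers import DDIMScheduler

accepts_eta = "eta" in inspect.signature(DDIMScheduler.step).parameters
assert accepts_eta  # DDIMScheduler.step takes an eta argument
```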
334
+ def check_inputs(self, prompt, height, width, callback_steps):
335
+ if not isinstance(prompt, str) and not isinstance(prompt, list):
336
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
337
+
338
+ if height % 8 != 0 or width % 8 != 0:
339
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
340
+
341
+ if (callback_steps is None) or (
342
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
343
+ ):
344
+ raise ValueError(
345
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
346
+ f" {type(callback_steps)}."
347
+ )
348
+
349
+ def prepare_latents(self, batch_size, num_channels_latents, video_length, height, width, dtype, device, generator, latents=None):
350
+ shape = (batch_size, num_channels_latents, video_length, height // self.vae_scale_factor, width // self.vae_scale_factor)
351
+ if isinstance(generator, list) and len(generator) != batch_size:
352
+ raise ValueError(
353
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
354
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
355
+ )
356
+
357
+ if latents is None:
358
+ rand_device = "cpu" if device.type == "mps" else device
359
+
360
+ if isinstance(generator, list):
361
+ shape = (1,) + shape[1:]
362
+ latents = [
363
+ torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype)
364
+ for i in range(batch_size)
365
+ ]
366
+ latents = torch.cat(latents, dim=0).to(device)
367
+ else:
368
+ latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
369
+ else:
370
+ if latents.shape != shape:
371
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
372
+ latents = latents.to(device)
373
+
374
+ # scale the initial noise by the standard deviation required by the scheduler
375
+ latents = latents * self.scheduler.init_noise_sigma
376
+ return latents
377
+
378
+ @torch.no_grad()
379
+ def __call__(
380
+ self,
381
+ prompt: Union[str, List[str]],
382
+ image: Union[str, List[str]],
383
+ video_length: Optional[int],
384
+ height: Optional[int] = None,
385
+ width: Optional[int] = None,
386
+ num_inference_steps: int = 50,
387
+ guidance_scale: float = 7.5,
388
+ negative_prompt: Optional[Union[str, List[str]]] = None,
389
+ num_videos_per_prompt: Optional[int] = 1,
390
+ eta: float = 0.0,
391
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
392
+ latents: Optional[torch.FloatTensor] = None,
393
+ output_type: Optional[str] = "tensor",
394
+ return_dict: bool = True,
395
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
396
+ callback_steps: Optional[int] = 1,
397
+ camera_matrixs = None,
398
+ class_labels = None,
399
+ prompt_ids = None,
400
+ unet_condition_type = None,
401
+ pose_guider = None,
402
+ pose_image = None,
403
+ img_proj=None,
404
+ use_noise=True,
405
+ use_shifted_noise=False,
406
+ rescale = 0.7,
407
+ **kwargs,
408
+ ):
409
+ # Default height and width to unet
410
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
411
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
412
+
413
+ # Check inputs. Raise error if not correct
414
+ self.check_inputs(prompt, height, width, callback_steps)
415
+ if isinstance(image, list):
416
+ batch_size = len(image)
417
+ else:
418
+ batch_size = image.shape[0]
419
+ # assert batch_size >= video_length and batch_size % video_length == 0
420
+ # Define call parameters
421
+ batch_size = 1 if isinstance(prompt, str) else len(prompt)
422
+ device = self._execution_device
423
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
424
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
425
+ # corresponds to doing no classifier free guidance.
426
+ do_classifier_free_guidance = guidance_scale > 1.0
427
+
428
+ # 3. Encode input image
429
+ # if isinstance(image, list):
430
+ # image_pil = image
431
+ # elif isinstance(image, torch.Tensor):
432
+ # image_pil = [TF.to_pil_image(image[i]) for i in range(image.shape[0])]
433
+ # encode input reference image
434
+ image_embeddings, image_latents = self._encode_image(image, device, num_videos_per_prompt, do_classifier_free_guidance, img_proj=img_proj) #torch.Size([64, 1, 768]) torch.Size([64, 4, 32, 32])
435
+ image_latents = rearrange(image_latents, "(b f) c h w -> b c f h w", f=1) #torch.Size([64, 4, 1, 32, 32])
436
+
437
+ # Encode input prompt_id
438
+ # encoder_hidden_states = self.text_encoder(prompt_ids)[0] #torch.Size([32, 77, 768])
439
+
440
+ # Encode input prompt
441
+ text_embeddings = self._encode_prompt( #torch.Size([64, 77, 768])
442
+ prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt
443
+ )
444
+
445
+ # Prepare timesteps
446
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
447
+ timesteps = self.scheduler.timesteps
448
+
449
+ # Prepare latent variables
450
+ num_channels_latents = self.unet.in_channels
451
+ latents = self.prepare_latents( #torch.Size([32, 4, 4, 32, 32])
452
+ batch_size * num_videos_per_prompt,
453
+ num_channels_latents,
454
+ video_length,
455
+ height,
456
+ width,
457
+ text_embeddings.dtype,
458
+ device,
459
+ generator,
460
+ latents,
461
+ )
462
+ latents_dtype = latents.dtype
463
+ # import ipdb
464
+ # ipdb.set_trace()
465
+ # Prepare extra step kwargs.
466
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
467
+ # prepare camera_matrix
468
+ if camera_matrixs is not None:
469
+ camera_matrixs = torch.cat([camera_matrixs] * 2) if do_classifier_free_guidance else camera_matrixs #(64, 4, 12)
470
+ # Denoising loop
471
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
472
+ if pose_guider is not None:
473
+ if len(pose_image.shape) == 5:
474
+ pose_embeds = pose_guider(rearrange(pose_image, "b f c h w -> (b f) c h w"))
475
+ pose_embeds = rearrange(pose_embeds, "(b f) c h w-> b c f h w ", f=video_length)
476
+ else:
477
+ pose_embeds = pose_guider(pose_image).unsqueeze(0)
478
+ pose_embeds = torch.cat([pose_embeds]*2, dim=0)
479
+ # import ipdb
480
+ # ipdb.set_trace()
481
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
482
+ for i, t in enumerate(tqdm.tqdm(timesteps)):
483
+ # expand the latents if we are doing classifier free guidance
484
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
485
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
486
+ if pose_guider is not None:
487
+ latent_model_input = latent_model_input + pose_embeds
488
+
489
+ noise_cond = torch.randn_like(image_latents)
490
+ if use_noise:
491
+ cond_latents = self.scheduler.add_noise(image_latents, noise_cond, t)
492
+ else:
493
+ cond_latents = image_latents
494
+ cond_latent_model_input = torch.cat([cond_latents] * 2) if do_classifier_free_guidance else cond_latents
495
+ cond_latent_model_input = self.scheduler.scale_model_input(cond_latent_model_input, t)
496
+
497
+ # predict the noise residual
498
+ # ref text condition
499
+ ref_dict = {}
500
+ if self.ref_unet is not None:
501
+ noise_pred_cond = self.ref_unet(
502
+ cond_latent_model_input, #torch.Size([64, 4, 1, 32, 32])
503
+ t, #torch.Size([32])
504
+ encoder_hidden_states=text_embeddings.to(torch.float32), #torch.Size([64, 77, 768])
505
+ cross_attention_kwargs=dict(mode="w", ref_dict=ref_dict)
506
+ ).sample.to(dtype=latents_dtype)
507
+
508
+ # if torch.isnan(noise_pred_cond).any():
509
+ # ipdb.set_trace()
510
+ # Predict the noise residual and compute loss
511
+ # model_pred = unet(noisy_latents, timesteps, encoder_hidden_states, camera_matrixs).sample
512
+ # unet
513
+ #text condition for unet
514
+ text_embeddings_unet = text_embeddings.unsqueeze(1).repeat(1,latents.shape[2],1,1)
515
+ text_embeddings_unet = rearrange(text_embeddings_unet, 'B Nv d c -> (B Nv) d c')
516
+ #image condition for unet
517
+ image_embeddings_unet = image_embeddings.unsqueeze(1).repeat(1,latents.shape[2],1, 1)
518
+ image_embeddings_unet = rearrange(image_embeddings_unet, 'B Nv d c -> (B Nv) d c')
519
+
520
+ if unet_condition_type == 'text':
521
+ encoder_hidden_states_unet_cond = text_embeddings_unet
522
+ elif unet_condition_type == 'image':
523
+ encoder_hidden_states_unet_cond = image_embeddings_unet
524
+ else:
525
+ raise ValueError("unet_condition_type must be either 'text' or 'image'")
526
+
527
+ if self.ref_unet is not None:
528
+ noise_pred = self.unet(
529
+ latent_model_input.to(torch.float32), #torch.Size([64, 4, 4, 32, 32])
530
+ t,
531
+ encoder_hidden_states=encoder_hidden_states_unet_cond.to(torch.float32),
532
+ camera_matrixs=camera_matrixs.to(torch.float32), #torch.Size([64, 4, 12])
533
+ cross_attention_kwargs=dict(mode="r", ref_dict=ref_dict, is_cfg_guidance=do_classifier_free_guidance)
534
+ # cross_attention_kwargs=dict(mode="n", ref_dict=ref_dict, is_cfg_guidance=do_classifier_free_guidance)
535
+ ).sample.to(dtype=latents_dtype)
536
+ else:
537
+ noise_pred = self.unet(
538
+ latent_model_input.to(torch.float32), #torch.Size([64, 4, 4, 32, 32])
539
+ t,
540
+ encoder_hidden_states=encoder_hidden_states_unet_cond.to(torch.float32),
541
+ camera_matrixs=camera_matrixs.to(torch.float32), #torch.Size([64, 4, 12])
542
+ # cross_attention_kwargs=dict(mode="r", ref_dict=ref_dict, is_cfg_guidance=do_classifier_free_guidance)
543
+ cross_attention_kwargs=dict(mode="n", ref_dict=ref_dict, is_cfg_guidance=do_classifier_free_guidance)
544
+ ).sample.to(dtype=latents_dtype)
545
+ # perform guidance
546
+ if do_classifier_free_guidance:
547
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
548
+ if use_shifted_noise:
549
+ # Apply regular classifier-free guidance.
550
+ cfg = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
551
+ # Calculate standard deviations.
552
+ std_pos = noise_pred_text.std([1,2,3], keepdim=True)
553
+ std_cfg = cfg.std([1,2,3], keepdim=True)
554
+ # Apply guidance rescale with fused operations.
555
+ factor = std_pos / std_cfg
556
+ factor = rescale * factor + (1 - rescale)
557
+ noise_pred = cfg * factor
558
+ else:
559
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
560
+ # noise_pred_uncond_, noise_pred_text_ = noise_pred_cond.chunk(2)
561
+ # noise_pred_cond = noise_pred_uncond_ + guidance_scale * (noise_pred_text_ - noise_pred_uncond_)
562
+
563
+ # compute the previous noisy sample x_t -> x_t-1
564
+ noise_pred = rearrange(noise_pred, "(b f) c h w -> b c f h w", f=video_length)
565
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
566
+ # noise_pred_cond = rearrange(noise_pred_cond, "(b f) c h w -> b c f h w", f=1)
567
+ # cond_latents = self.scheduler.step(noise_pred_cond, t, cond_latents, **extra_step_kwargs).prev_sample
568
+
569
+ # call the callback, if provided
570
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
571
+ progress_bar.update()
572
+ if callback is not None and i % callback_steps == 0:
573
+ callback(i, t, latents)
574
+
575
+ # Post-processing
576
+ video = self.decode_latents(latents)
577
+
578
+ # Convert to tensor
579
+ if output_type == "tensor":
580
+ video = torch.from_numpy(video)
581
+
582
+ if not return_dict:
583
+ return video
584
+
585
+ return TuneAVideoPipelineOutput(videos=video)
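A hypothetical end-to-end invocation of the pipeline; every component variable is a placeholder that mirrors the constructor and `__call__` signatures above rather than a documented entry point:

```py
# Hypothetical usage; vae/text_encoder/tokenizer/unet/scheduler/ref_unet etc.
# are assumed to be loaded elsewhere with the classes named in the imports.
pipeline = TuneAVideoPipeline(
    vae=vae, text_encoder=text_encoder, tokenizer=tokenizer,
    unet=unet, scheduler=scheduler, ref_unet=ref_unet,
    feature_extractor=feature_extractor, image_encoder=image_encoder,
).to("cuda")

out = pipeline(
    prompt="high quality, best quality",   # placeholder prompt
    image=ref_image,                       # (B, 3, H, W) tensor in [0, 1]
    video_length=4,                        # one "frame" per generated view
    camera_matrixs=camera_matrixs,         # (B, Nv, 12) flattened camera poses
    unet_condition_type="image",
)
views = out.videos                         # (B, C, Nv, H, W), values in [0, 1]
```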
2D_Stage/tuneavideo/util.py ADDED
@@ -0,0 +1,128 @@
1
+ import os
2
+ import imageio
3
+ import numpy as np
4
+ from typing import Union
5
+ import cv2
6
+ import torch
7
+ import torchvision
8
+
9
+ from tqdm import tqdm
10
+ from einops import rearrange
11
+
12
+ def shifted_noise(betas, image_d=512, noise_d=256, shifted_noise=True):
13
+ alphas = 1 - betas
14
+ alphas_bar = torch.cumprod(alphas, dim=0)
15
+ d = (image_d / noise_d) ** 2
16
+ if shifted_noise:
17
+ alphas_bar = alphas_bar / (d - (d - 1) * alphas_bar)
18
+ alphas_bar_sqrt = torch.sqrt(alphas_bar)
19
+ alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
20
+ alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
21
+ # Shift so last timestep is zero.
22
+ alphas_bar_sqrt -= alphas_bar_sqrt_T
23
+ # Scale so first timestep is back to old value.
24
+ alphas_bar_sqrt *= alphas_bar_sqrt_0 / (
25
+ alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
26
+
27
+ # Convert alphas_bar_sqrt to betas
28
+ alphas_bar = alphas_bar_sqrt ** 2
29
+ alphas = alphas_bar[1:] / alphas_bar[:-1]
30
+ alphas = torch.cat([alphas_bar[0:1], alphas])
31
+ betas = 1 - alphas
32
+ return betas
33
+
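A small usage sketch for `shifted_noise`; the linear beta schedule below is an assumption for illustration, not the schedule any particular checkpoint uses:

```py
# Shift a toy linear beta schedule from a 256-px noise reference to 512 px
# (the function's defaults); the output keeps the input's length.
import torch

betas = torch.linspace(0.00085, 0.012, 1000)
shifted = shifted_noise(betas, image_d=512, noise_d=256)
assert shifted.shape == betas.shape
```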
34
+ def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=4, fps=8):
35
+ videos = rearrange(videos, "b c t h w -> t b c h w")
36
+ outputs = []
37
+ for x in videos:
38
+ x = torchvision.utils.make_grid(x, nrow=n_rows)
39
+ x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
40
+ if rescale:
41
+ x = (x + 1.0) / 2.0 # -1,1 -> 0,1
42
+ x = (x * 255).numpy().astype(np.uint8)
43
+ outputs.append(x)
44
+
45
+ os.makedirs(os.path.dirname(path), exist_ok=True)
46
+ imageio.mimsave(path, outputs, duration=1000/fps)
47
+
48
+ def save_imgs_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=4, fps=8):
49
+ videos = rearrange(videos, "b c t h w -> t b c h w")
50
+ for i, x in enumerate(videos):
51
+ x = torchvision.utils.make_grid(x, nrow=n_rows)
52
+ x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
53
+ if rescale:
54
+ x = (x + 1.0) / 2.0 # -1,1 -> 0,1
55
+ x = (x * 255).numpy().astype(np.uint8)
56
+ os.makedirs(os.path.dirname(path), exist_ok=True)
57
+ cv2.imwrite(os.path.join(path, f'view_{i}.png'), x[:,:,::-1])
58
+
59
+ def imgs_grid(videos: torch.Tensor, rescale=False, n_rows=4, fps=8):
60
+ videos = rearrange(videos, "b c t h w -> t b c h w")
61
+ image_list = []
62
+ for i, x in enumerate(videos):
63
+ x = torchvision.utils.make_grid(x, nrow=n_rows)
64
+ x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
65
+ if rescale:
66
+ x = (x + 1.0) / 2.0 # -1,1 -> 0,1
67
+ x = (x * 255).numpy().astype(np.uint8)
68
+ # image_list.append(x[:,:,::-1])
69
+ image_list.append(x)
70
+ return image_list
71
+
72
+ # DDIM Inversion
73
+ @torch.no_grad()
74
+ def init_prompt(prompt, pipeline):
75
+ uncond_input = pipeline.tokenizer(
76
+ [""], padding="max_length", max_length=pipeline.tokenizer.model_max_length,
77
+ return_tensors="pt"
78
+ )
79
+ uncond_embeddings = pipeline.text_encoder(uncond_input.input_ids.to(pipeline.device))[0]
80
+ text_input = pipeline.tokenizer(
81
+ [prompt],
82
+ padding="max_length",
83
+ max_length=pipeline.tokenizer.model_max_length,
84
+ truncation=True,
85
+ return_tensors="pt",
86
+ )
87
+ text_embeddings = pipeline.text_encoder(text_input.input_ids.to(pipeline.device))[0]
88
+ context = torch.cat([uncond_embeddings, text_embeddings])
89
+
90
+ return context
91
+
92
+
93
+ def next_step(model_output: Union[torch.FloatTensor, np.ndarray], timestep: int,
94
+ sample: Union[torch.FloatTensor, np.ndarray], ddim_scheduler):
95
+ timestep, next_timestep = min(
96
+ timestep - ddim_scheduler.config.num_train_timesteps // ddim_scheduler.num_inference_steps, 999), timestep
97
+ alpha_prod_t = ddim_scheduler.alphas_cumprod[timestep] if timestep >= 0 else ddim_scheduler.final_alpha_cumprod
98
+ alpha_prod_t_next = ddim_scheduler.alphas_cumprod[next_timestep]
99
+ beta_prod_t = 1 - alpha_prod_t
100
+ next_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
101
+ next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output
102
+ next_sample = alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction
103
+ return next_sample
104
+
105
+
106
+ def get_noise_pred_single(latents, t, context, unet):
107
+ noise_pred = unet(latents, t, encoder_hidden_states=context)["sample"]
108
+ return noise_pred
109
+
110
+
111
+ @torch.no_grad()
112
+ def ddim_loop(pipeline, ddim_scheduler, latent, num_inv_steps, prompt):
113
+ context = init_prompt(prompt, pipeline)
114
+ uncond_embeddings, cond_embeddings = context.chunk(2)
115
+ all_latent = [latent]
116
+ latent = latent.clone().detach()
117
+ for i in tqdm(range(num_inv_steps)):
118
+ t = ddim_scheduler.timesteps[len(ddim_scheduler.timesteps) - i - 1]
119
+ noise_pred = get_noise_pred_single(latent.to(torch.float32), t, cond_embeddings.to(torch.float32), pipeline.unet)
120
+ latent = next_step(noise_pred, t, latent, ddim_scheduler)
121
+ all_latent.append(latent)
122
+ return all_latent
123
+
124
+
125
+ @torch.no_grad()
126
+ def ddim_inversion(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt=""):
127
+ ddim_latents = ddim_loop(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt)
128
+ return ddim_latents
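Note: a minimal usage sketch for the DDIM-inversion helpers above (not part of the commit; `pipe` and `video_latents` are hypothetical names for a TuneAVideoPipeline-like object whose unet accepts (latents, t, encoder_hidden_states), and a latent tensor):

    # next_step() reads scheduler.num_inference_steps, so set_timesteps must be called first
    scheduler = DDIMScheduler.from_pretrained(model_path, subfolder="scheduler")
    scheduler.set_timesteps(50)
    inv_latents = ddim_inversion(pipe, scheduler, video_latents, num_inv_steps=50, prompt="")
    start_latent = inv_latents[-1]  # the most-noised latent, usable as an init latent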
2D_Stage/webui.py ADDED
@@ -0,0 +1,323 @@
+ import gradio as gr
+ from PIL import Image
+ import glob
+
+ import io
+ import argparse
+ import inspect
+ import os
+ import random
+ from typing import Dict, Optional, Tuple
+ from omegaconf import OmegaConf
+ import numpy as np
+
+ import torch
+ import torch.utils.checkpoint
+
+ from accelerate.logging import get_logger
+ from diffusers import AutoencoderKL, DDIMScheduler
+ from diffusers.utils import check_min_version
+ from tqdm.auto import tqdm
+ from transformers import CLIPTextModel, CLIPTokenizer, CLIPImageProcessor, CLIPVisionModelWithProjection
+ from torchvision import transforms
+
+ from tuneavideo.models.unet_mv2d_condition import UNetMV2DConditionModel
+ from tuneavideo.models.unet_mv2d_ref import UNetMV2DRefModel
+ from tuneavideo.models.PoseGuider import PoseGuider
+ from tuneavideo.pipelines.pipeline_tuneavideo import TuneAVideoPipeline
+ from tuneavideo.util import shifted_noise
+ from einops import rearrange
+ import PIL
+ from PIL import Image
+ from torchvision.utils import save_image
+ import json
+ import cv2
+
+ import onnxruntime as rt
+ from rm_anime_bg.cli import get_mask, SCALE
+
+ from huggingface_hub import hf_hub_download, list_repo_files
+
+ # Fetch any missing 2D_Stage assets from the Hugging Face repo on first run.
+ repo_id = "zjpshadow/CharacterGen"
+ all_files = list_repo_files(repo_id, revision="main")
+
+ for file in all_files:
+     if os.path.exists("../" + file):
+         continue
+     if file.startswith("2D_Stage"):
+         hf_hub_download(repo_id, file, local_dir="../")
+
+ class rm_bg_api:
+
+     def __init__(self, force_cpu: Optional[bool] = True):
+         session_infer_path = hf_hub_download(
+             repo_id="skytnt/anime-seg", filename="isnetis.onnx",
+         )
+         providers: list[str] = ["CPUExecutionProvider"]
+         if not force_cpu and "CUDAExecutionProvider" in rt.get_available_providers():
+             providers = ["CUDAExecutionProvider"]
+
+         self.session_infer = rt.InferenceSession(
+             session_infer_path, providers=providers,
+         )
+
+     def remove_background(
+         self,
+         imgs: list[np.ndarray],
+         alpha_min: float,
+         alpha_max: float,
+     ) -> list:
+         process_imgs = []
+         for img in imgs:
+             # convert to RGB
+             img = cv2.cvtColor(img, cv2.COLOR_RGBA2RGB)
+             mask = get_mask(self.session_infer, img)
+
+             # binarize the soft matte outside the [alpha_min, alpha_max] band
+             mask[mask < alpha_min] = 0.0  # type: ignore
+             mask[mask > alpha_max] = 1.0  # type: ignore
+
+             img_after = (mask * img + SCALE * (1 - mask)).astype(np.uint8)  # type: ignore
+             mask = (mask * SCALE).astype(np.uint8)  # type: ignore
+             img_after = np.concatenate([img_after, mask], axis=2, dtype=np.uint8)  # append alpha channel
+             process_imgs.append(Image.fromarray(img_after))
+         return process_imgs
+
+ check_min_version("0.24.0")
+
+ logger = get_logger(__name__, log_level="INFO")
+
+ def set_seed(seed):
+     random.seed(seed)
+     np.random.seed(seed)
+     torch.manual_seed(seed)
+     torch.cuda.manual_seed_all(seed)
+
+ def get_bg_color(bg_color):
+     if bg_color == 'white':
+         bg_color = np.array([1., 1., 1.], dtype=np.float32)
+     elif bg_color == 'black':
+         bg_color = np.array([0., 0., 0.], dtype=np.float32)
+     elif bg_color == 'gray':
+         bg_color = np.array([0.5, 0.5, 0.5], dtype=np.float32)
+     elif bg_color == 'random':
+         bg_color = np.random.rand(3)
+     elif isinstance(bg_color, float):
+         bg_color = np.array([bg_color] * 3, dtype=np.float32)
+     else:
+         raise NotImplementedError
+     return bg_color
+
+ def process_image(image, totensor):
+     if not image.mode == "RGBA":
+         image = image.convert("RGBA")
+
+     # Find non-transparent pixels and crop to their bounding box
+     non_transparent = np.nonzero(np.array(image)[..., 3])
+     min_x, max_x = non_transparent[1].min(), non_transparent[1].max()
+     min_y, max_y = non_transparent[0].min(), non_transparent[0].max()
+     image = image.crop((min_x, min_y, max_x, max_y))
+
+     # paste to the center of a 2:3 canvas
+     max_dim = max(image.width, image.height)
+     max_height = max_dim
+     max_width = int(max_dim / 3 * 2)
+     new_image = Image.new("RGBA", (max_width, max_height))
+     left = (max_width - image.width) // 2
+     top = (max_height - image.height) // 2
+     new_image.paste(image, (left, top))
+
+     image = new_image.resize((512, 768), resample=PIL.Image.BICUBIC)
+     image = np.array(image)
+     image = image.astype(np.float32) / 255.
+     assert image.shape[-1] == 4  # RGBA
+     alpha = image[..., 3:4]
+     bg_color = get_bg_color("gray")
+     # composite onto a gray background
+     image = image[..., :3] * alpha + bg_color * (1 - alpha)
+     return totensor(image)
+
+ class Inference_API:
+
+     def __init__(self):
+         self.validation_pipeline = None
+
+     @torch.no_grad()
+     def inference(self, input_image, vae, feature_extractor, image_encoder, unet, ref_unet, tokenizer, text_encoder,
+                   pretrained_model_path, generator, validation, val_width, val_height, unet_condition_type,
+                   pose_guider=None, use_noise=True, use_shifted_noise=False, noise_d=256, crop=False, seed=100, timestep=20):
+         set_seed(seed)
+         # Build the validation pipeline lazily on first call
+         if self.validation_pipeline is None:
+             noise_scheduler = DDIMScheduler.from_pretrained(pretrained_model_path, subfolder="scheduler")
+             if use_shifted_noise:
+                 print(f"enable shifted noise for {val_height} to {noise_d}")
+                 betas = shifted_noise(noise_scheduler.betas, image_d=val_height, noise_d=noise_d)
+                 noise_scheduler.betas = betas
+                 noise_scheduler.alphas = 1 - betas
+                 noise_scheduler.alphas_cumprod = torch.cumprod(noise_scheduler.alphas, dim=0)
+             self.validation_pipeline = TuneAVideoPipeline(
+                 vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, ref_unet=ref_unet,
+                 feature_extractor=feature_extractor, image_encoder=image_encoder,
+                 scheduler=noise_scheduler
+             )
+             self.validation_pipeline.enable_vae_slicing()
+             self.validation_pipeline.set_progress_bar_config(disable=True)
+
+         totensor = transforms.ToTensor()
+
+         # Load the reference camera poses and pose images for the four views
+         metas = json.load(open("./material/pose.json", "r"))
+         cameras = []
+         pose_images = []
+         input_path = "./material"
+         for lm in metas:
+             cameras.append(torch.tensor(np.array(lm[0]).reshape(4, 4).transpose(1, 0)[:3, :4]).reshape(-1))
+             if not crop:
+                 pose_images.append(totensor(np.asarray(Image.open(os.path.join(input_path, lm[1])).resize(
+                     (val_height, val_width), resample=PIL.Image.BICUBIC)).astype(np.float32) / 255.))
+             else:
+                 pose_image = Image.open(os.path.join(input_path, lm[1]))
+                 crop_area = (128, 0, 640, 768)
+                 pose_images.append(totensor(np.array(pose_image.crop(crop_area)).astype(np.float32)) / 255.)
+         camera_matrixs = torch.stack(cameras).unsqueeze(0).to("cuda")
+         pose_imgs_in = torch.stack(pose_images).to("cuda")
+         prompts = "high quality, best quality"
+         prompt_ids = tokenizer(
+             prompts, max_length=tokenizer.model_max_length, padding="max_length", truncation=True,
+             return_tensors="pt"
+         ).input_ids[0]
+
+         # (B*Nv, 3, H, W)
+         B = 1
+         weight_dtype = torch.bfloat16
+         imgs_in = process_image(input_image, totensor)
+         imgs_in = rearrange(imgs_in.unsqueeze(0).unsqueeze(0), "B Nv C H W -> (B Nv) C H W")
+
+         with torch.autocast("cuda", dtype=weight_dtype):
+             imgs_in = imgs_in.to("cuda")
+             # B*Nv images
+             out = self.validation_pipeline(prompt=prompts, image=imgs_in.to(weight_dtype), generator=generator,
+                                            num_inference_steps=timestep,
+                                            camera_matrixs=camera_matrixs.to(weight_dtype), prompt_ids=prompt_ids,
+                                            height=val_height, width=val_width, unet_condition_type=unet_condition_type,
+                                            pose_guider=None, pose_image=pose_imgs_in, use_noise=use_noise,
+                                            use_shifted_noise=use_shifted_noise, **validation).videos
+         out = rearrange(out, "B C f H W -> (B f) C H W", f=validation.video_length)
+
+         # Round-trip each view through an in-memory PNG to get PIL images for the gallery
+         image_outputs = []
+         for bs in range(4):
+             img_buf = io.BytesIO()
+             save_image(out[bs], img_buf, format='PNG')
+             img_buf.seek(0)
+             img = Image.open(img_buf)
+             image_outputs.append(img)
+         torch.cuda.empty_cache()
+         return image_outputs
+
+ @torch.no_grad()
+ def main(
+     pretrained_model_path: str,
+     image_encoder_path: str,
+     ckpt_dir: str,
+     validation: Dict,
+     local_crossattn: bool = True,
+     unet_from_pretrained_kwargs=None,
+     unet_condition_type=None,
+     use_pose_guider=False,
+     use_noise=True,
+     use_shifted_noise=False,
+     noise_d=256
+ ):
+     *_, config = inspect.getargvalues(inspect.currentframe())
+
+     device = "cuda"
+
+     tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer")
+     text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder")
+     image_encoder = CLIPVisionModelWithProjection.from_pretrained(image_encoder_path)
+     feature_extractor = CLIPImageProcessor()
+     vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae")
+     unet = UNetMV2DConditionModel.from_pretrained_2d(pretrained_model_path, subfolder="unet", local_crossattn=local_crossattn, **unet_from_pretrained_kwargs)
+     ref_unet = UNetMV2DRefModel.from_pretrained_2d(pretrained_model_path, subfolder="unet", local_crossattn=local_crossattn, **unet_from_pretrained_kwargs)
+     if use_pose_guider:
+         pose_guider = PoseGuider(noise_latent_channels=4).to("cuda")
+     else:
+         pose_guider = None
+
+     # Load the fine-tuned weights; the checkpoint index shifts by one when a pose guider is used
+     unet_params = torch.load(os.path.join(ckpt_dir, "pytorch_model.bin"), map_location="cpu")
+     if use_pose_guider:
+         pose_guider_params = torch.load(os.path.join(ckpt_dir, "pytorch_model_1.bin"), map_location="cpu")
+         ref_unet_params = torch.load(os.path.join(ckpt_dir, "pytorch_model_2.bin"), map_location="cpu")
+         pose_guider.load_state_dict(pose_guider_params)
+     else:
+         ref_unet_params = torch.load(os.path.join(ckpt_dir, "pytorch_model_1.bin"), map_location="cpu")
+     unet.load_state_dict(unet_params)
+     ref_unet.load_state_dict(ref_unet_params)
+
+     weight_dtype = torch.float16
+
+     text_encoder.to(device, dtype=weight_dtype)
+     image_encoder.to(device, dtype=weight_dtype)
+     vae.to(device, dtype=weight_dtype)
+     ref_unet.to(device, dtype=weight_dtype)
+     unet.to(device, dtype=weight_dtype)
+
+     vae.requires_grad_(False)
+     unet.requires_grad_(False)
+     ref_unet.requires_grad_(False)
+
+     generator = torch.Generator(device="cuda")
+     inferapi = Inference_API()
+     remove_api = rm_bg_api()
+
+     def gen4views(image, width, height, seed, timestep, remove_bg):
+         if remove_bg:
+             image = remove_api.remove_background(
+                 imgs=[np.array(image)],
+                 alpha_min=0.1,
+                 alpha_max=0.9,
+             )[0]
+         return inferapi.inference(
+             image, vae, feature_extractor, image_encoder, unet, ref_unet, tokenizer, text_encoder, pretrained_model_path,
+             generator, validation, width, height, unet_condition_type,
+             pose_guider=pose_guider, use_noise=use_noise, use_shifted_noise=use_shifted_noise, noise_d=noise_d,
+             crop=True, seed=seed, timestep=timestep
+         )
+
+     with gr.Blocks() as demo:
+         gr.Markdown("# [SIGGRAPH'24] CharacterGen: Efficient 3D Character Generation from Single Images with Multi-View Pose Calibration")
+         gr.Markdown("# 2D Stage: One Image to Four Views of a Character")
+         gr.Markdown("**Please upload an image without a background; it should preferably be a full-body frontal photo.**")
+         with gr.Row():
+             with gr.Column():
+                 img_input = gr.Image(type="pil", label="Upload Image (without background)", image_mode="RGBA", width=768, height=512)
+                 gr.Examples(
+                     label="Example Images",
+                     examples=glob.glob("./material/examples/*.png"),
+                     inputs=[img_input]
+                 )
+                 with gr.Row():
+                     width_input = gr.Number(label="Width", value=512)
+                     height_input = gr.Number(label="Height", value=768)
+                     seed_input = gr.Number(label="Seed", value=2333)
+                 remove_bg = gr.Checkbox(label="Remove Background (with algorithm)", value=False)
+                 timestep = gr.Slider(minimum=10, maximum=70, step=1, value=40, label="Timesteps")
+             with gr.Column():
+                 button = gr.Button(value="Generate")
+                 output = gr.Gallery(label="4 Views of the Character")
+
+         button.click(
+             fn=gen4views,
+             inputs=[img_input, width_input, height_input, seed_input, timestep, remove_bg],
+             outputs=[output]
+         )
+
+     demo.launch()
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--config", type=str, default="./configs/infer.yaml")
+     args = parser.parse_args()
+
+     main(**OmegaConf.load(args.config))
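Note: a small sanity-check sketch of the camera handling in Inference_API.inference above (hypothetical identity pose; not part of the commit). Each pose.json entry holds a flattened 4x4 matrix plus a pose-image name, and the loop keeps the transposed top 3x4 block, flattened to 12 numbers per view:

    import numpy as np
    import torch

    lm0 = np.eye(4).reshape(-1).tolist()  # stand-in for lm[0] from material/pose.json
    cam = torch.tensor(np.array(lm0).reshape(4, 4).transpose(1, 0)[:3, :4]).reshape(-1)
    assert cam.shape == (12,)             # 12 numbers per view
    # stacking the 4 views and unsqueezing yields camera_matrixs of shape (1, 4, 12)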
3D_Stage/__pycache__/refine.cpython-310.pyc ADDED
Binary file (6.18 kB).
3D_Stage/configs/infer.yaml ADDED
@@ -0,0 +1,104 @@
+ system_cls: lrm.systems.multiview_lrm.MultiviewLRM
+ data:
+   cond_width: 504
+   cond_height: 504
+
+ system:
+   weights: ./models/lrm.ckpt
+
+   weights_ignore_modules:
+     - decoder.heads.density
+
+   check_train_every_n_steps: 100
+
+   camera_embedder_cls: lrm.models.camera.LinearCameraEmbedder
+   camera_embedder:
+     in_channels: 16
+     out_channels: 768
+     conditions:
+       - c2w_cond
+
+   # the image tokenizer transforms input images into tokens
+   image_tokenizer_cls: lrm.models.tokenizers.image.DINOV2SingleImageTokenizer
+   image_tokenizer:
+     pretrained_model_name_or_path: "./models/base"
+     freeze_backbone_params: false
+     enable_memory_efficient_attention: true
+     enable_gradient_checkpointing: true
+     # camera modulation for the DINO transformer layers
+     modulation: true
+     modulation_zero_init: true
+     modulation_single_layer: true
+     modulation_cond_dim: ${system.camera_embedder.out_channels}
+
+   # the tokenizer gives a tokenized representation of the 3D scene,
+   # triplane tokens in this case
+   tokenizer_cls: lrm.models.tokenizers.triplane.TriplaneLearnablePositionalEmbedding
+   tokenizer:
+     plane_size: 32
+     num_channels: 512
+
+   # the backbone network is a transformer that takes scene tokens
+   # (potentially with conditional image tokens) and outputs scene tokens of the same size
+   backbone_cls: lrm.models.transformers.transformer_1d.Transformer1D
+   backbone:
+     in_channels: ${system.tokenizer.num_channels}
+     num_attention_heads: 16
+     attention_head_dim: 64
+     num_layers: 12
+     cross_attention_dim: 768  # hard-coded; equals the DINO feature dim
+     # camera modulation for the transformer layers;
+     # if not needed, set norm_type: layer_norm and do not specify cond_dim_ada_norm_continuous
+     norm_type: "layer_norm"
+     enable_memory_efficient_attention: true
+     gradient_checkpointing: true
+
+   # the post-processor takes scene tokens and outputs the final scene parameters
+   # used for rendering; here, triplanes are upsampled and the features are condensed
+   post_processor_cls: lrm.models.networks.TriplaneUpsampleNetwork
+   post_processor:
+     in_channels: 512
+     out_channels: 80
+
+   renderer_cls: lrm.models.renderers.triplane_dmtet.TriplaneDMTetRenderer
+   renderer:
+     radius: 0.6  # slightly larger than 0.5
+     feature_reduction: concat
+     sdf_bias: -2.
+     tet_dir: "./load/tets/"
+     isosurface_resolution: 256
+     enable_isosurface_grid_deformation: false
+     sdf_activation: negative
+
+   decoder_cls: lrm.models.networks.MultiHeadMLP
+   decoder:
+     in_channels: 240  # 3 * 80
+     n_neurons: 64
+     n_hidden_layers_share: 8
+     heads:
+       - name: sdf
+         out_channels: 1
+         n_hidden_layers: 1
+         output_activation: null
+       - name: features
+         out_channels: 3
+         n_hidden_layers: 1
+         output_activation: null  # activated in the material
+     activation: silu
+     chunk_mode: deferred
+     chunk_size: 131072
+
+   exporter:
+     fmt: "obj"
+     # visual: "vertex"
+     visual: "uv"
+     save_uv: True
+     save_texture: True
+     uv_unwrap_method: "open3d"
+     output_path: "./outputs"
+
+   material_cls: lrm.models.materials.no_material.NoMaterial
+
+   background_cls: lrm.models.background.solid_color_background.SolidColorBackground
+   background:
+     color: [0.5, 0.5, 0.5]
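Note: the `*_cls` entries above are dotted import paths resolved at runtime. A minimal sketch of the lookup, using the `find` helper from 3D_Stage/lrm/__init__.py further down this diff (instantiation arguments are system-specific and omitted here):

    from omegaconf import OmegaConf
    from lrm import find

    cfg = OmegaConf.load("configs/infer.yaml")
    system_cls = find(cfg.system_cls)  # -> lrm.systems.multiview_lrm.MultiviewLRM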
3D_Stage/load/tets/128_tets.npz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:daa82da88746777043efe2182a4ff01843dbe4400cb34f53c8e2f5da8d35569d
+ size 7565405
3D_Stage/load/tets/256_tets.npz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:5822cae907aba088af41fed74461105f8864c05d58e557c82ca40561497db4b3
+ size 63136604
3D_Stage/load/tets/32_tets.npz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d76349b760e99afd4ecbbfb20e421d134385a87867dca34ec21093e8fe4b2b72
+ size 124137
3D_Stage/load/tets/64_tets.npz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e5dc0f19e87c275b54b78023c931b264bb34a8ea9804f82b991dde7fa99fbaee
+ size 957742
3D_Stage/load/tets/generate_tets.py ADDED
@@ -0,0 +1,58 @@
+ # Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ #
+ # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+ # property and proprietary rights in and to this material, related
+ # documentation and any modifications thereto. Any use, reproduction,
+ # disclosure or distribution of this material and related documentation
+ # without an express license agreement from NVIDIA CORPORATION or
+ # its affiliates is strictly prohibited.
+
+ import os
+
+ import numpy as np
+
+ """
+ This code segment shows how to use Quartet (https://github.com/crawforddoran/quartet)
+ to generate a tet grid:
+ 1) Download, compile, and run Quartet as described in the link above.
+    Example usage: `quartet meshes/cube.obj 0.5 cube_5.tet`
+ 2) Run the function below to generate a file `cube_32_tet.tet`.
+ """
+
+
+ def generate_tetrahedron_grid_file(res=32, root=".."):
+     frac = 1.0 / res
+     command = f"cd {root}; ./quartet meshes/cube.obj {frac} meshes/cube_{res}_tet.tet -s meshes/cube_boundary_{res}.obj"
+     os.system(command)
+
+
+ """
+ This code segment shows how to convert a Quartet .tet file into a compressed .npz file.
+ """
+
+
+ def convert_from_quartet_to_npz(quartetfile="cube_32_tet.tet", npzfile="32_tets"):
+     # the header line holds the vertex and tetrahedron counts
+     with open(quartetfile, "r") as file1:
+         header = file1.readline()
+     numvertices = int(header.split(" ")[1])
+     numtets = int(header.split(" ")[2])
+     print(numvertices, numtets)
+
+     # load vertices
+     vertices = np.loadtxt(quartetfile, skiprows=1, max_rows=numvertices)
+     print(vertices.shape)
+
+     # load indices
+     indices = np.loadtxt(
+         quartetfile, dtype=int, skiprows=1 + numvertices, max_rows=numtets
+     )
+     print(indices.shape)
+
+     np.savez_compressed(npzfile, vertices=vertices, indices=indices)
+
+
+ root = "/home/gyc/quartet"
+ for res in [300, 350, 400]:
+     generate_tetrahedron_grid_file(res, root)
+     convert_from_quartet_to_npz(
+         os.path.join(root, f"meshes/cube_{res}_tet.tet"), npzfile=f"{res}_tets"
+     )
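Note: per the converter above, each generated archive stores two arrays. A minimal loading sketch (using the 256-resolution grid shipped in this commit):

    import numpy as np

    tets = np.load("256_tets.npz")
    vertices = tets["vertices"]  # (num_vertices, 3) float positions
    indices = tets["indices"]    # (num_tets, 4) integer vertex ids per tetrahedron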
3D_Stage/lrm/__init__.py ADDED
@@ -0,0 +1,29 @@
+ import importlib
+
+
+ def find(cls_string):
+     # resolve a dotted path like "pkg.module.ClassName" to the class object
+     module_string = ".".join(cls_string.split(".")[:-1])
+     cls_name = cls_string.split(".")[-1]
+     module = importlib.import_module(module_string, package=None)
+     cls = getattr(module, cls_name)
+     return cls
+
+
+ ### syntactic sugar for logging utilities ###
+ import logging
+
+ logger = logging.getLogger("pytorch_lightning")
+
+ from pytorch_lightning.utilities.rank_zero import (
+     rank_zero_debug,
+     rank_zero_info,
+     rank_zero_only,
+ )
+
+ debug = rank_zero_debug
+ info = rank_zero_info
+
+
+ @rank_zero_only
+ def warn(*args, **kwargs):
+     logger.warning(*args, **kwargs)  # logger.warn is deprecated in favor of logger.warning
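Note: these wrappers log only on the rank-0 process, so multi-GPU runs do not print duplicate lines. A usage sketch (hypothetical messages, not part of the commit):

    from lrm import info, warn

    info("loading weights from ./models/lrm.ckpt")
    warn("ignoring decoder.heads.density when loading weights")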
3D_Stage/lrm/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (787 Bytes).
3D_Stage/lrm/models/__init__.py ADDED
File without changes