Spaces:
Build error
Build error
Upload 42 files
Browse files- README.md +1 -1
- animatelcm/models/attention.py +296 -0
- animatelcm/models/embeddings.py +213 -0
- animatelcm/models/motion_module.py +337 -0
- animatelcm/models/resnet.py +313 -0
- animatelcm/models/unet.py +568 -0
- animatelcm/models/unet_blocks.py +904 -0
- animatelcm/pipelines/pipeline_animation.py +456 -0
- animatelcm/scheduler/lcm_scheduler.py +722 -0
- animatelcm/utils/convert_from_ckpt.py +951 -0
- animatelcm/utils/convert_lora_safetensor_to_diffusers.py +152 -0
- animatelcm/utils/lcm_utils.py +237 -0
- animatelcm/utils/util.py +153 -0
- app.py +392 -0
- models/.DS_Store +0 -0
- models/DreamBooth_LoRA/cartoon2d.safetensors +3 -0
- models/DreamBooth_LoRA/cartoon3d.safetensors +3 -0
- models/DreamBooth_LoRA/realistic1.safetensors +3 -0
- models/DreamBooth_LoRA/realistic2.safetensors +3 -0
- models/LCM_LoRA/Put LCMLoRA checkpoints here.txt +0 -0
- models/LCM_LoRA/sd15_t2v_beta_lora.safetensors +3 -0
- models/Motion_Module/Put motion module checkpoints here.txt +0 -0
- models/Motion_Module/sd15_t2v_beta_motion.ckpt +3 -0
- models/StableDiffusion/Put diffusers stable-diffusion-v1-5 repo here.txt +0 -0
- models/StableDiffusion/stable-diffusion-v1-5/.gitattributes +35 -0
- models/StableDiffusion/stable-diffusion-v1-5/README.md +207 -0
- models/StableDiffusion/stable-diffusion-v1-5/feature_extractor/preprocessor_config.json +20 -0
- models/StableDiffusion/stable-diffusion-v1-5/model_index.json +32 -0
- models/StableDiffusion/stable-diffusion-v1-5/safety_checker/config.json +175 -0
- models/StableDiffusion/stable-diffusion-v1-5/scheduler/scheduler_config.json +13 -0
- models/StableDiffusion/stable-diffusion-v1-5/text_encoder/config.json +25 -0
- models/StableDiffusion/stable-diffusion-v1-5/text_encoder/model.safetensors +3 -0
- models/StableDiffusion/stable-diffusion-v1-5/tokenizer/merges.txt +0 -0
- models/StableDiffusion/stable-diffusion-v1-5/tokenizer/special_tokens_map.json +24 -0
- models/StableDiffusion/stable-diffusion-v1-5/tokenizer/tokenizer_config.json +34 -0
- models/StableDiffusion/stable-diffusion-v1-5/tokenizer/vocab.json +0 -0
- models/StableDiffusion/stable-diffusion-v1-5/unet/config.json +36 -0
- models/StableDiffusion/stable-diffusion-v1-5/unet/diffusion_pytorch_model.bin +3 -0
- models/StableDiffusion/stable-diffusion-v1-5/v1-inference.yaml +70 -0
- models/StableDiffusion/stable-diffusion-v1-5/vae/config.json +29 -0
- models/StableDiffusion/stable-diffusion-v1-5/vae/diffusion_pytorch_model.bin +3 -0
- requirements.txt +15 -0
README.md
CHANGED
@@ -4,7 +4,7 @@ emoji: 🦀
|
|
4 |
colorFrom: red
|
5 |
colorTo: blue
|
6 |
sdk: gradio
|
7 |
-
sdk_version:
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
---
|
|
|
4 |
colorFrom: red
|
5 |
colorTo: blue
|
6 |
sdk: gradio
|
7 |
+
sdk_version: 3.48.0
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
---
|
animatelcm/models/attention.py
ADDED
@@ -0,0 +1,296 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass
|
2 |
+
from typing import Optional
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn.functional as F
|
6 |
+
from torch import nn
|
7 |
+
|
8 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
9 |
+
from diffusers.modeling_utils import ModelMixin
|
10 |
+
from diffusers.utils import BaseOutput
|
11 |
+
from diffusers.utils.import_utils import is_xformers_available
|
12 |
+
from diffusers.models.attention import CrossAttention, FeedForward, AdaLayerNorm
|
13 |
+
|
14 |
+
from einops import rearrange, repeat
|
15 |
+
|
16 |
+
@dataclass
|
17 |
+
class Transformer3DModelOutput(BaseOutput):
|
18 |
+
sample: torch.FloatTensor
|
19 |
+
|
20 |
+
|
21 |
+
if is_xformers_available():
|
22 |
+
import xformers
|
23 |
+
import xformers.ops
|
24 |
+
else:
|
25 |
+
xformers = None
|
26 |
+
|
27 |
+
|
28 |
+
class Transformer3DModel(ModelMixin, ConfigMixin):
|
29 |
+
@register_to_config
|
30 |
+
def __init__(
|
31 |
+
self,
|
32 |
+
num_attention_heads: int = 16,
|
33 |
+
attention_head_dim: int = 88,
|
34 |
+
in_channels: Optional[int] = None,
|
35 |
+
num_layers: int = 1,
|
36 |
+
dropout: float = 0.0,
|
37 |
+
norm_num_groups: int = 32,
|
38 |
+
cross_attention_dim: Optional[int] = None,
|
39 |
+
attention_bias: bool = False,
|
40 |
+
activation_fn: str = "geglu",
|
41 |
+
num_embeds_ada_norm: Optional[int] = None,
|
42 |
+
use_linear_projection: bool = False,
|
43 |
+
only_cross_attention: bool = False,
|
44 |
+
upcast_attention: bool = False,
|
45 |
+
|
46 |
+
unet_use_cross_frame_attention=None,
|
47 |
+
unet_use_temporal_attention=None,
|
48 |
+
):
|
49 |
+
super().__init__()
|
50 |
+
self.use_linear_projection = use_linear_projection
|
51 |
+
self.num_attention_heads = num_attention_heads
|
52 |
+
self.attention_head_dim = attention_head_dim
|
53 |
+
inner_dim = num_attention_heads * attention_head_dim
|
54 |
+
|
55 |
+
# Define input layers
|
56 |
+
self.in_channels = in_channels
|
57 |
+
|
58 |
+
self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
|
59 |
+
if use_linear_projection:
|
60 |
+
self.proj_in = nn.Linear(in_channels, inner_dim)
|
61 |
+
else:
|
62 |
+
self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
|
63 |
+
|
64 |
+
# Define transformers blocks
|
65 |
+
self.transformer_blocks = nn.ModuleList(
|
66 |
+
[
|
67 |
+
BasicTransformerBlock(
|
68 |
+
inner_dim,
|
69 |
+
num_attention_heads,
|
70 |
+
attention_head_dim,
|
71 |
+
dropout=dropout,
|
72 |
+
cross_attention_dim=cross_attention_dim,
|
73 |
+
activation_fn=activation_fn,
|
74 |
+
num_embeds_ada_norm=num_embeds_ada_norm,
|
75 |
+
attention_bias=attention_bias,
|
76 |
+
only_cross_attention=only_cross_attention,
|
77 |
+
upcast_attention=upcast_attention,
|
78 |
+
|
79 |
+
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
|
80 |
+
unet_use_temporal_attention=unet_use_temporal_attention,
|
81 |
+
)
|
82 |
+
for d in range(num_layers)
|
83 |
+
]
|
84 |
+
)
|
85 |
+
|
86 |
+
# 4. Define output layers
|
87 |
+
if use_linear_projection:
|
88 |
+
self.proj_out = nn.Linear(in_channels, inner_dim)
|
89 |
+
else:
|
90 |
+
self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
|
91 |
+
|
92 |
+
def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True):
|
93 |
+
# Input
|
94 |
+
assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
|
95 |
+
video_length = hidden_states.shape[2]
|
96 |
+
hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
|
97 |
+
encoder_hidden_states = repeat(encoder_hidden_states, 'b n c -> (b f) n c', f=video_length)
|
98 |
+
|
99 |
+
batch, channel, height, weight = hidden_states.shape
|
100 |
+
residual = hidden_states
|
101 |
+
|
102 |
+
hidden_states = self.norm(hidden_states)
|
103 |
+
if not self.use_linear_projection:
|
104 |
+
hidden_states = self.proj_in(hidden_states)
|
105 |
+
inner_dim = hidden_states.shape[1]
|
106 |
+
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
|
107 |
+
else:
|
108 |
+
inner_dim = hidden_states.shape[1]
|
109 |
+
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
|
110 |
+
hidden_states = self.proj_in(hidden_states)
|
111 |
+
|
112 |
+
# Blocks
|
113 |
+
for block in self.transformer_blocks:
|
114 |
+
hidden_states = block(
|
115 |
+
hidden_states,
|
116 |
+
encoder_hidden_states=encoder_hidden_states,
|
117 |
+
timestep=timestep,
|
118 |
+
video_length=video_length
|
119 |
+
)
|
120 |
+
|
121 |
+
# Output
|
122 |
+
if not self.use_linear_projection:
|
123 |
+
hidden_states = (
|
124 |
+
hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
|
125 |
+
)
|
126 |
+
hidden_states = self.proj_out(hidden_states)
|
127 |
+
else:
|
128 |
+
hidden_states = self.proj_out(hidden_states)
|
129 |
+
hidden_states = (
|
130 |
+
hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
|
131 |
+
)
|
132 |
+
|
133 |
+
output = hidden_states + residual
|
134 |
+
|
135 |
+
output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
|
136 |
+
if not return_dict:
|
137 |
+
return (output,)
|
138 |
+
|
139 |
+
return Transformer3DModelOutput(sample=output)
|
140 |
+
|
141 |
+
|
142 |
+
class BasicTransformerBlock(nn.Module):
|
143 |
+
def __init__(
|
144 |
+
self,
|
145 |
+
dim: int,
|
146 |
+
num_attention_heads: int,
|
147 |
+
attention_head_dim: int,
|
148 |
+
dropout=0.0,
|
149 |
+
cross_attention_dim: Optional[int] = None,
|
150 |
+
activation_fn: str = "geglu",
|
151 |
+
num_embeds_ada_norm: Optional[int] = None,
|
152 |
+
attention_bias: bool = False,
|
153 |
+
only_cross_attention: bool = False,
|
154 |
+
upcast_attention: bool = False,
|
155 |
+
|
156 |
+
unet_use_cross_frame_attention = None,
|
157 |
+
unet_use_temporal_attention = None,
|
158 |
+
):
|
159 |
+
super().__init__()
|
160 |
+
self.only_cross_attention = only_cross_attention
|
161 |
+
self.use_ada_layer_norm = num_embeds_ada_norm is not None
|
162 |
+
self.unet_use_cross_frame_attention = unet_use_cross_frame_attention
|
163 |
+
self.unet_use_temporal_attention = unet_use_temporal_attention
|
164 |
+
|
165 |
+
# SC-Attn
|
166 |
+
assert unet_use_cross_frame_attention is not None
|
167 |
+
if unet_use_cross_frame_attention:
|
168 |
+
self.attn1 = SparseCausalAttention2D(
|
169 |
+
query_dim=dim,
|
170 |
+
heads=num_attention_heads,
|
171 |
+
dim_head=attention_head_dim,
|
172 |
+
dropout=dropout,
|
173 |
+
bias=attention_bias,
|
174 |
+
cross_attention_dim=cross_attention_dim if only_cross_attention else None,
|
175 |
+
upcast_attention=upcast_attention,
|
176 |
+
)
|
177 |
+
else:
|
178 |
+
self.attn1 = CrossAttention(
|
179 |
+
query_dim=dim,
|
180 |
+
heads=num_attention_heads,
|
181 |
+
dim_head=attention_head_dim,
|
182 |
+
dropout=dropout,
|
183 |
+
bias=attention_bias,
|
184 |
+
upcast_attention=upcast_attention,
|
185 |
+
)
|
186 |
+
self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
|
187 |
+
|
188 |
+
# Cross-Attn
|
189 |
+
if cross_attention_dim is not None:
|
190 |
+
self.attn2 = CrossAttention(
|
191 |
+
query_dim=dim,
|
192 |
+
cross_attention_dim=cross_attention_dim,
|
193 |
+
heads=num_attention_heads,
|
194 |
+
dim_head=attention_head_dim,
|
195 |
+
dropout=dropout,
|
196 |
+
bias=attention_bias,
|
197 |
+
upcast_attention=upcast_attention,
|
198 |
+
)
|
199 |
+
else:
|
200 |
+
self.attn2 = None
|
201 |
+
|
202 |
+
if cross_attention_dim is not None:
|
203 |
+
self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
|
204 |
+
else:
|
205 |
+
self.norm2 = None
|
206 |
+
|
207 |
+
# Feed-forward
|
208 |
+
self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
|
209 |
+
self.norm3 = nn.LayerNorm(dim)
|
210 |
+
|
211 |
+
# Temp-Attn
|
212 |
+
assert unet_use_temporal_attention is not None
|
213 |
+
if unet_use_temporal_attention:
|
214 |
+
self.attn_temp = CrossAttention(
|
215 |
+
query_dim=dim,
|
216 |
+
heads=num_attention_heads,
|
217 |
+
dim_head=attention_head_dim,
|
218 |
+
dropout=dropout,
|
219 |
+
bias=attention_bias,
|
220 |
+
upcast_attention=upcast_attention,
|
221 |
+
)
|
222 |
+
nn.init.zeros_(self.attn_temp.to_out[0].weight.data)
|
223 |
+
self.norm_temp = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
|
224 |
+
|
225 |
+
def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
|
226 |
+
if not is_xformers_available():
|
227 |
+
raise ModuleNotFoundError(
|
228 |
+
"Refer to https://github.com/facebookresearch/xformers for more information on how to install"
|
229 |
+
" xformers",
|
230 |
+
name="xformers",
|
231 |
+
)
|
232 |
+
elif not torch.cuda.is_available():
|
233 |
+
raise ValueError(
|
234 |
+
"torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only"
|
235 |
+
" available for GPU "
|
236 |
+
)
|
237 |
+
else:
|
238 |
+
try:
|
239 |
+
# Make sure we can run the memory efficient attention
|
240 |
+
_ = xformers.ops.memory_efficient_attention(
|
241 |
+
torch.randn((1, 2, 40), device="cuda"),
|
242 |
+
torch.randn((1, 2, 40), device="cuda"),
|
243 |
+
torch.randn((1, 2, 40), device="cuda"),
|
244 |
+
)
|
245 |
+
except Exception as e:
|
246 |
+
raise e
|
247 |
+
self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
|
248 |
+
if self.attn2 is not None:
|
249 |
+
self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
|
250 |
+
# self.attn_temp._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
|
251 |
+
|
252 |
+
def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None, video_length=None):
|
253 |
+
# SparseCausal-Attention
|
254 |
+
norm_hidden_states = (
|
255 |
+
self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states)
|
256 |
+
)
|
257 |
+
|
258 |
+
# if self.only_cross_attention:
|
259 |
+
# hidden_states = (
|
260 |
+
# self.attn1(norm_hidden_states, encoder_hidden_states, attention_mask=attention_mask) + hidden_states
|
261 |
+
# )
|
262 |
+
# else:
|
263 |
+
# hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length) + hidden_states
|
264 |
+
|
265 |
+
# pdb.set_trace()
|
266 |
+
if self.unet_use_cross_frame_attention:
|
267 |
+
hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length) + hidden_states
|
268 |
+
else:
|
269 |
+
hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask) + hidden_states
|
270 |
+
|
271 |
+
if self.attn2 is not None:
|
272 |
+
# Cross-Attention
|
273 |
+
norm_hidden_states = (
|
274 |
+
self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
|
275 |
+
)
|
276 |
+
hidden_states = (
|
277 |
+
self.attn2(
|
278 |
+
norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
|
279 |
+
)
|
280 |
+
+ hidden_states
|
281 |
+
)
|
282 |
+
|
283 |
+
# Feed-forward
|
284 |
+
hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
|
285 |
+
|
286 |
+
# Temporal-Attention
|
287 |
+
if self.unet_use_temporal_attention:
|
288 |
+
d = hidden_states.shape[1]
|
289 |
+
hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length)
|
290 |
+
norm_hidden_states = (
|
291 |
+
self.norm_temp(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_temp(hidden_states)
|
292 |
+
)
|
293 |
+
hidden_states = self.attn_temp(norm_hidden_states) + hidden_states
|
294 |
+
hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
|
295 |
+
|
296 |
+
return hidden_states
|
animatelcm/models/embeddings.py
ADDED
@@ -0,0 +1,213 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import math
|
15 |
+
|
16 |
+
import numpy as np
|
17 |
+
import torch
|
18 |
+
from torch import nn
|
19 |
+
|
20 |
+
|
21 |
+
def get_timestep_embedding(
|
22 |
+
timesteps: torch.Tensor,
|
23 |
+
embedding_dim: int,
|
24 |
+
flip_sin_to_cos: bool = False,
|
25 |
+
downscale_freq_shift: float = 1,
|
26 |
+
scale: float = 1,
|
27 |
+
max_period: int = 10000,
|
28 |
+
):
|
29 |
+
"""
|
30 |
+
This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
|
31 |
+
|
32 |
+
:param timesteps: a 1-D Tensor of N indices, one per batch element.
|
33 |
+
These may be fractional.
|
34 |
+
:param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the
|
35 |
+
embeddings. :return: an [N x dim] Tensor of positional embeddings.
|
36 |
+
"""
|
37 |
+
assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
|
38 |
+
|
39 |
+
half_dim = embedding_dim // 2
|
40 |
+
exponent = -math.log(max_period) * torch.arange(
|
41 |
+
start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
|
42 |
+
)
|
43 |
+
exponent = exponent / (half_dim - downscale_freq_shift)
|
44 |
+
|
45 |
+
emb = torch.exp(exponent)
|
46 |
+
emb = timesteps[:, None].float() * emb[None, :]
|
47 |
+
|
48 |
+
# scale embeddings
|
49 |
+
emb = scale * emb
|
50 |
+
|
51 |
+
# concat sine and cosine embeddings
|
52 |
+
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
|
53 |
+
|
54 |
+
# flip sine and cosine embeddings
|
55 |
+
if flip_sin_to_cos:
|
56 |
+
emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
|
57 |
+
|
58 |
+
# zero pad
|
59 |
+
if embedding_dim % 2 == 1:
|
60 |
+
emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
|
61 |
+
return emb
|
62 |
+
|
63 |
+
def zero_module(module):
|
64 |
+
# Zero out the parameters of a module and return it.
|
65 |
+
for p in module.parameters():
|
66 |
+
p.detach().zero_()
|
67 |
+
return module
|
68 |
+
|
69 |
+
class TimestepEmbedding(nn.Module):
|
70 |
+
def __init__(self, in_channels: int, time_embed_dim: int, act_fn: str = "silu", out_dim: int = None, time_cond_proj_dim=None):
|
71 |
+
super().__init__()
|
72 |
+
|
73 |
+
self.linear_1 = nn.Linear(in_channels, time_embed_dim)
|
74 |
+
self.act = None
|
75 |
+
if act_fn == "silu":
|
76 |
+
self.act = nn.SiLU()
|
77 |
+
elif act_fn == "mish":
|
78 |
+
self.act = nn.Mish()
|
79 |
+
|
80 |
+
if time_cond_proj_dim is not None:
|
81 |
+
self.cond_proj = zero_module(nn.Linear(time_cond_proj_dim, in_channels, bias=False))
|
82 |
+
else:
|
83 |
+
self.cond_proj = None
|
84 |
+
|
85 |
+
|
86 |
+
if out_dim is not None:
|
87 |
+
time_embed_dim_out = out_dim
|
88 |
+
else:
|
89 |
+
time_embed_dim_out = time_embed_dim
|
90 |
+
self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out)
|
91 |
+
|
92 |
+
def forward(self, sample, condition=None):
|
93 |
+
if condition is not None:
|
94 |
+
sample = sample + self.cond_proj(condition)
|
95 |
+
sample = self.linear_1(sample)
|
96 |
+
|
97 |
+
if self.act is not None:
|
98 |
+
sample = self.act(sample)
|
99 |
+
|
100 |
+
sample = self.linear_2(sample)
|
101 |
+
return sample
|
102 |
+
|
103 |
+
|
104 |
+
class Timesteps(nn.Module):
|
105 |
+
def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float):
|
106 |
+
super().__init__()
|
107 |
+
self.num_channels = num_channels
|
108 |
+
self.flip_sin_to_cos = flip_sin_to_cos
|
109 |
+
self.downscale_freq_shift = downscale_freq_shift
|
110 |
+
|
111 |
+
def forward(self, timesteps):
|
112 |
+
t_emb = get_timestep_embedding(
|
113 |
+
timesteps,
|
114 |
+
self.num_channels,
|
115 |
+
flip_sin_to_cos=self.flip_sin_to_cos,
|
116 |
+
downscale_freq_shift=self.downscale_freq_shift,
|
117 |
+
)
|
118 |
+
return t_emb
|
119 |
+
|
120 |
+
|
121 |
+
class GaussianFourierProjection(nn.Module):
|
122 |
+
"""Gaussian Fourier embeddings for noise levels."""
|
123 |
+
|
124 |
+
def __init__(
|
125 |
+
self, embedding_size: int = 256, scale: float = 1.0, set_W_to_weight=True, log=True, flip_sin_to_cos=False
|
126 |
+
):
|
127 |
+
super().__init__()
|
128 |
+
self.weight = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
|
129 |
+
self.log = log
|
130 |
+
self.flip_sin_to_cos = flip_sin_to_cos
|
131 |
+
|
132 |
+
if set_W_to_weight:
|
133 |
+
# to delete later
|
134 |
+
self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
|
135 |
+
|
136 |
+
self.weight = self.W
|
137 |
+
|
138 |
+
def forward(self, x):
|
139 |
+
if self.log:
|
140 |
+
x = torch.log(x)
|
141 |
+
|
142 |
+
x_proj = x[:, None] * self.weight[None, :] * 2 * np.pi
|
143 |
+
|
144 |
+
if self.flip_sin_to_cos:
|
145 |
+
out = torch.cat([torch.cos(x_proj), torch.sin(x_proj)], dim=-1)
|
146 |
+
else:
|
147 |
+
out = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
|
148 |
+
return out
|
149 |
+
|
150 |
+
|
151 |
+
class ImagePositionalEmbeddings(nn.Module):
|
152 |
+
"""
|
153 |
+
Converts latent image classes into vector embeddings. Sums the vector embeddings with positional embeddings for the
|
154 |
+
height and width of the latent space.
|
155 |
+
|
156 |
+
For more details, see figure 10 of the dall-e paper: https://arxiv.org/abs/2102.12092
|
157 |
+
|
158 |
+
For VQ-diffusion:
|
159 |
+
|
160 |
+
Output vector embeddings are used as input for the transformer.
|
161 |
+
|
162 |
+
Note that the vector embeddings for the transformer are different than the vector embeddings from the VQVAE.
|
163 |
+
|
164 |
+
Args:
|
165 |
+
num_embed (`int`):
|
166 |
+
Number of embeddings for the latent pixels embeddings.
|
167 |
+
height (`int`):
|
168 |
+
Height of the latent image i.e. the number of height embeddings.
|
169 |
+
width (`int`):
|
170 |
+
Width of the latent image i.e. the number of width embeddings.
|
171 |
+
embed_dim (`int`):
|
172 |
+
Dimension of the produced vector embeddings. Used for the latent pixel, height, and width embeddings.
|
173 |
+
"""
|
174 |
+
|
175 |
+
def __init__(
|
176 |
+
self,
|
177 |
+
num_embed: int,
|
178 |
+
height: int,
|
179 |
+
width: int,
|
180 |
+
embed_dim: int,
|
181 |
+
):
|
182 |
+
super().__init__()
|
183 |
+
|
184 |
+
self.height = height
|
185 |
+
self.width = width
|
186 |
+
self.num_embed = num_embed
|
187 |
+
self.embed_dim = embed_dim
|
188 |
+
|
189 |
+
self.emb = nn.Embedding(self.num_embed, embed_dim)
|
190 |
+
self.height_emb = nn.Embedding(self.height, embed_dim)
|
191 |
+
self.width_emb = nn.Embedding(self.width, embed_dim)
|
192 |
+
|
193 |
+
def forward(self, index):
|
194 |
+
emb = self.emb(index)
|
195 |
+
|
196 |
+
height_emb = self.height_emb(torch.arange(self.height, device=index.device).view(1, self.height))
|
197 |
+
|
198 |
+
# 1 x H x D -> 1 x H x 1 x D
|
199 |
+
height_emb = height_emb.unsqueeze(2)
|
200 |
+
|
201 |
+
width_emb = self.width_emb(torch.arange(self.width, device=index.device).view(1, self.width))
|
202 |
+
|
203 |
+
# 1 x W x D -> 1 x 1 x W x D
|
204 |
+
width_emb = width_emb.unsqueeze(1)
|
205 |
+
|
206 |
+
pos_emb = height_emb + width_emb
|
207 |
+
|
208 |
+
# 1 x H x W x D -> 1 x L xD
|
209 |
+
pos_emb = pos_emb.view(1, self.height * self.width, -1)
|
210 |
+
|
211 |
+
emb = emb + pos_emb[:, : emb.shape[1], :]
|
212 |
+
|
213 |
+
return emb
|
animatelcm/models/motion_module.py
ADDED
@@ -0,0 +1,337 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass
|
2 |
+
from typing import List, Optional, Tuple, Union
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn.functional as F
|
6 |
+
from torch import nn
|
7 |
+
|
8 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
9 |
+
from diffusers.modeling_utils import ModelMixin
|
10 |
+
from diffusers.utils import BaseOutput
|
11 |
+
from diffusers.utils.import_utils import is_xformers_available
|
12 |
+
from diffusers.models.attention import CrossAttention, FeedForward
|
13 |
+
|
14 |
+
from einops import rearrange, repeat
|
15 |
+
import math
|
16 |
+
|
17 |
+
|
18 |
+
def zero_module(module):
|
19 |
+
for p in module.parameters():
|
20 |
+
p.detach().zero_()
|
21 |
+
return module
|
22 |
+
|
23 |
+
|
24 |
+
@dataclass
|
25 |
+
class TemporalTransformer3DModelOutput(BaseOutput):
|
26 |
+
sample: torch.FloatTensor
|
27 |
+
|
28 |
+
|
29 |
+
if is_xformers_available():
|
30 |
+
import xformers
|
31 |
+
import xformers.ops
|
32 |
+
else:
|
33 |
+
xformers = None
|
34 |
+
|
35 |
+
|
36 |
+
def get_motion_module(
|
37 |
+
in_channels,
|
38 |
+
motion_module_type: str,
|
39 |
+
motion_module_kwargs: dict
|
40 |
+
):
|
41 |
+
if motion_module_type == "Vanilla":
|
42 |
+
return VanillaTemporalModule(in_channels=in_channels, **motion_module_kwargs,)
|
43 |
+
else:
|
44 |
+
raise ValueError
|
45 |
+
|
46 |
+
|
47 |
+
class VanillaTemporalModule(nn.Module):
|
48 |
+
def __init__(
|
49 |
+
self,
|
50 |
+
in_channels,
|
51 |
+
num_attention_heads=8,
|
52 |
+
num_transformer_block=2,
|
53 |
+
attention_block_types=("Temporal_Self", "Temporal_Self"),
|
54 |
+
cross_frame_attention_mode=None,
|
55 |
+
temporal_position_encoding=False,
|
56 |
+
temporal_attention_dim_div=1,
|
57 |
+
zero_initialize=True,
|
58 |
+
):
|
59 |
+
super().__init__()
|
60 |
+
|
61 |
+
self.temporal_transformer = TemporalTransformer3DModel(
|
62 |
+
in_channels=in_channels,
|
63 |
+
num_attention_heads=num_attention_heads,
|
64 |
+
attention_head_dim=in_channels // num_attention_heads // temporal_attention_dim_div,
|
65 |
+
num_layers=num_transformer_block,
|
66 |
+
attention_block_types=attention_block_types,
|
67 |
+
cross_frame_attention_mode=cross_frame_attention_mode,
|
68 |
+
temporal_position_encoding=temporal_position_encoding,
|
69 |
+
)
|
70 |
+
|
71 |
+
if zero_initialize:
|
72 |
+
self.temporal_transformer.proj_out = zero_module(
|
73 |
+
self.temporal_transformer.proj_out)
|
74 |
+
|
75 |
+
def forward(self, input_tensor, temb, encoder_hidden_states, attention_mask=None, anchor_frame_idx=None):
|
76 |
+
hidden_states = input_tensor
|
77 |
+
hidden_states = self.temporal_transformer(
|
78 |
+
hidden_states, encoder_hidden_states, attention_mask)
|
79 |
+
|
80 |
+
output = hidden_states
|
81 |
+
return output
|
82 |
+
|
83 |
+
|
84 |
+
class TemporalTransformer3DModel(nn.Module):
|
85 |
+
def __init__(
|
86 |
+
self,
|
87 |
+
in_channels,
|
88 |
+
num_attention_heads,
|
89 |
+
attention_head_dim,
|
90 |
+
|
91 |
+
num_layers,
|
92 |
+
attention_block_types=("Temporal_Self", "Temporal_Self", ),
|
93 |
+
dropout=0.0,
|
94 |
+
norm_num_groups=32,
|
95 |
+
cross_attention_dim=768,
|
96 |
+
activation_fn="geglu",
|
97 |
+
attention_bias=False,
|
98 |
+
upcast_attention=False,
|
99 |
+
|
100 |
+
cross_frame_attention_mode=None,
|
101 |
+
temporal_position_encoding=False,
|
102 |
+
):
|
103 |
+
super().__init__()
|
104 |
+
|
105 |
+
inner_dim = num_attention_heads * attention_head_dim
|
106 |
+
|
107 |
+
self.norm = torch.nn.GroupNorm(
|
108 |
+
num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
|
109 |
+
self.proj_in = nn.Linear(in_channels, inner_dim)
|
110 |
+
|
111 |
+
self.transformer_blocks = nn.ModuleList(
|
112 |
+
[
|
113 |
+
TemporalTransformerBlock(
|
114 |
+
dim=inner_dim,
|
115 |
+
num_attention_heads=num_attention_heads,
|
116 |
+
attention_head_dim=attention_head_dim,
|
117 |
+
attention_block_types=attention_block_types,
|
118 |
+
dropout=dropout,
|
119 |
+
norm_num_groups=norm_num_groups,
|
120 |
+
cross_attention_dim=cross_attention_dim,
|
121 |
+
activation_fn=activation_fn,
|
122 |
+
attention_bias=attention_bias,
|
123 |
+
upcast_attention=upcast_attention,
|
124 |
+
cross_frame_attention_mode=cross_frame_attention_mode,
|
125 |
+
temporal_position_encoding=temporal_position_encoding,
|
126 |
+
)
|
127 |
+
for d in range(num_layers)
|
128 |
+
]
|
129 |
+
)
|
130 |
+
self.proj_out = nn.Linear(inner_dim, in_channels)
|
131 |
+
|
132 |
+
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
|
133 |
+
assert hidden_states.dim(
|
134 |
+
) == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
|
135 |
+
video_length = hidden_states.shape[2]
|
136 |
+
hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
|
137 |
+
|
138 |
+
batch, channel, height, weight = hidden_states.shape
|
139 |
+
residual = hidden_states
|
140 |
+
|
141 |
+
hidden_states = self.norm(hidden_states)
|
142 |
+
inner_dim = hidden_states.shape[1]
|
143 |
+
hidden_states = hidden_states.permute(
|
144 |
+
0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
|
145 |
+
hidden_states = self.proj_in(hidden_states)
|
146 |
+
|
147 |
+
# Transformer Blocks
|
148 |
+
for block in self.transformer_blocks:
|
149 |
+
hidden_states = block(
|
150 |
+
hidden_states, encoder_hidden_states=encoder_hidden_states, video_length=video_length)
|
151 |
+
|
152 |
+
# output
|
153 |
+
hidden_states = self.proj_out(hidden_states)
|
154 |
+
hidden_states = hidden_states.reshape(
|
155 |
+
batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
|
156 |
+
|
157 |
+
output = hidden_states + residual
|
158 |
+
output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
|
159 |
+
|
160 |
+
return output
|
161 |
+
|
162 |
+
|
163 |
+
class TemporalTransformerBlock(nn.Module):
|
164 |
+
def __init__(
|
165 |
+
self,
|
166 |
+
dim,
|
167 |
+
num_attention_heads,
|
168 |
+
attention_head_dim,
|
169 |
+
attention_block_types=("Temporal_Self", "Temporal_Self", ),
|
170 |
+
dropout=0.0,
|
171 |
+
norm_num_groups=32,
|
172 |
+
cross_attention_dim=768,
|
173 |
+
activation_fn="geglu",
|
174 |
+
attention_bias=False,
|
175 |
+
upcast_attention=False,
|
176 |
+
cross_frame_attention_mode=None,
|
177 |
+
temporal_position_encoding=False,
|
178 |
+
):
|
179 |
+
super().__init__()
|
180 |
+
|
181 |
+
attention_blocks = []
|
182 |
+
norms = []
|
183 |
+
|
184 |
+
for block_name in attention_block_types:
|
185 |
+
attention_blocks.append(
|
186 |
+
VersatileAttention(
|
187 |
+
attention_mode=block_name.split("_")[0],
|
188 |
+
cross_attention_dim=cross_attention_dim if block_name.endswith(
|
189 |
+
"_Cross") else None,
|
190 |
+
|
191 |
+
query_dim=dim,
|
192 |
+
heads=num_attention_heads,
|
193 |
+
dim_head=attention_head_dim,
|
194 |
+
dropout=dropout,
|
195 |
+
bias=attention_bias,
|
196 |
+
upcast_attention=upcast_attention,
|
197 |
+
|
198 |
+
cross_frame_attention_mode=cross_frame_attention_mode,
|
199 |
+
temporal_position_encoding=temporal_position_encoding,
|
200 |
+
)
|
201 |
+
)
|
202 |
+
norms.append(nn.LayerNorm(dim))
|
203 |
+
|
204 |
+
self.attention_blocks = nn.ModuleList(attention_blocks)
|
205 |
+
self.norms = nn.ModuleList(norms)
|
206 |
+
|
207 |
+
self.ff = FeedForward(dim, dropout=dropout,
|
208 |
+
activation_fn=activation_fn)
|
209 |
+
self.ff_norm = nn.LayerNorm(dim)
|
210 |
+
|
211 |
+
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None):
|
212 |
+
for attention_block, norm in zip(self.attention_blocks, self.norms):
|
213 |
+
norm_hidden_states = norm(hidden_states)
|
214 |
+
hidden_states = attention_block(
|
215 |
+
norm_hidden_states,
|
216 |
+
encoder_hidden_states=encoder_hidden_states if attention_block.is_cross_attention else None,
|
217 |
+
video_length=video_length,
|
218 |
+
) + hidden_states
|
219 |
+
|
220 |
+
hidden_states = self.ff(self.ff_norm(hidden_states)) + hidden_states
|
221 |
+
|
222 |
+
output = hidden_states
|
223 |
+
return output
|
224 |
+
|
225 |
+
|
226 |
+
class PositionalEncoding(nn.Module):
|
227 |
+
def __init__(
|
228 |
+
self,
|
229 |
+
d_model,
|
230 |
+
dropout=0.,
|
231 |
+
):
|
232 |
+
super().__init__()
|
233 |
+
|
234 |
+
max_length = 64
|
235 |
+
self.dropout = nn.Dropout(p=dropout)
|
236 |
+
position = torch.arange(max_length).unsqueeze(1)
|
237 |
+
div_term = torch.exp(torch.arange(0, d_model, 2)
|
238 |
+
* (-math.log(10000.0) / d_model))
|
239 |
+
pe = torch.zeros(1, max_length, d_model)
|
240 |
+
pe[0, :, 0::2] = torch.sin(position * div_term)
|
241 |
+
pe[0, :, 1::2] = torch.cos(position * div_term)
|
242 |
+
self.register_buffer('pos_encoding', pe)
|
243 |
+
|
244 |
+
def forward(self, x):
|
245 |
+
x = x + self.pos_encoding[:, :x.size(1)]
|
246 |
+
return self.dropout(x)
|
247 |
+
|
248 |
+
|
249 |
+
class VersatileAttention(CrossAttention):
|
250 |
+
def __init__(
|
251 |
+
self,
|
252 |
+
attention_mode=None,
|
253 |
+
cross_frame_attention_mode=None,
|
254 |
+
temporal_position_encoding=False,
|
255 |
+
*args, **kwargs
|
256 |
+
):
|
257 |
+
super().__init__(*args, **kwargs)
|
258 |
+
assert attention_mode == "Temporal"
|
259 |
+
|
260 |
+
self.attention_mode = attention_mode
|
261 |
+
self.is_cross_attention = kwargs["cross_attention_dim"] is not None
|
262 |
+
|
263 |
+
self.pos_encoder = PositionalEncoding(
|
264 |
+
kwargs["query_dim"],
|
265 |
+
dropout=0.,
|
266 |
+
) if (temporal_position_encoding and attention_mode == "Temporal") else None
|
267 |
+
|
268 |
+
def extra_repr(self):
|
269 |
+
return f"(Module Info) Attention_Mode: {self.attention_mode}, Is_Cross_Attention: {self.is_cross_attention}"
|
270 |
+
|
271 |
+
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None):
|
272 |
+
batch_size, sequence_length, _ = hidden_states.shape
|
273 |
+
|
274 |
+
if self.attention_mode == "Temporal":
|
275 |
+
d = hidden_states.shape[1]
|
276 |
+
hidden_states = rearrange(
|
277 |
+
hidden_states, "(b f) d c -> (b d) f c", f=video_length)
|
278 |
+
|
279 |
+
if self.pos_encoder is not None:
|
280 |
+
hidden_states = self.pos_encoder(hidden_states)
|
281 |
+
|
282 |
+
encoder_hidden_states = repeat(encoder_hidden_states, "b n c -> (b d) n c",
|
283 |
+
d=d) if encoder_hidden_states is not None else encoder_hidden_states
|
284 |
+
else:
|
285 |
+
raise NotImplementedError
|
286 |
+
|
287 |
+
encoder_hidden_states = encoder_hidden_states
|
288 |
+
|
289 |
+
if self.group_norm is not None:
|
290 |
+
hidden_states = self.group_norm(
|
291 |
+
hidden_states.transpose(1, 2)).transpose(1, 2)
|
292 |
+
|
293 |
+
query = self.to_q(hidden_states)
|
294 |
+
dim = query.shape[-1]
|
295 |
+
query = self.reshape_heads_to_batch_dim(query)
|
296 |
+
|
297 |
+
if self.added_kv_proj_dim is not None:
|
298 |
+
raise NotImplementedError
|
299 |
+
|
300 |
+
encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
|
301 |
+
key = self.to_k(encoder_hidden_states)
|
302 |
+
value = self.to_v(encoder_hidden_states)
|
303 |
+
|
304 |
+
key = self.reshape_heads_to_batch_dim(key)
|
305 |
+
value = self.reshape_heads_to_batch_dim(value)
|
306 |
+
|
307 |
+
if attention_mask is not None:
|
308 |
+
if attention_mask.shape[-1] != query.shape[1]:
|
309 |
+
target_length = query.shape[1]
|
310 |
+
attention_mask = F.pad(
|
311 |
+
attention_mask, (0, target_length), value=0.0)
|
312 |
+
attention_mask = attention_mask.repeat_interleave(
|
313 |
+
self.heads, dim=0)
|
314 |
+
|
315 |
+
if self._use_memory_efficient_attention_xformers:
|
316 |
+
hidden_states = self._memory_efficient_attention_xformers(
|
317 |
+
query, key, value, attention_mask)
|
318 |
+
hidden_states = hidden_states.to(query.dtype)
|
319 |
+
else:
|
320 |
+
if self._slice_size is None or query.shape[0] // self._slice_size == 1:
|
321 |
+
hidden_states = self._attention(
|
322 |
+
query, key, value, attention_mask)
|
323 |
+
else:
|
324 |
+
hidden_states = self._sliced_attention(
|
325 |
+
query, key, value, sequence_length, dim, attention_mask)
|
326 |
+
|
327 |
+
# linear proj
|
328 |
+
hidden_states = self.to_out[0](hidden_states)
|
329 |
+
|
330 |
+
# dropout
|
331 |
+
hidden_states = self.to_out[1](hidden_states)
|
332 |
+
|
333 |
+
if self.attention_mode == "Temporal":
|
334 |
+
hidden_states = rearrange(
|
335 |
+
hidden_states, "(b d) f c -> (b f) d c", d=d)
|
336 |
+
|
337 |
+
return hidden_states
|
animatelcm/models/resnet.py
ADDED
@@ -0,0 +1,313 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
from typing import Optional
|
7 |
+
|
8 |
+
from einops import rearrange
|
9 |
+
|
10 |
+
|
11 |
+
class InflatedConv3d(nn.Conv2d):
|
12 |
+
def forward(self, x):
|
13 |
+
video_length = x.shape[2]
|
14 |
+
|
15 |
+
x = rearrange(x, "b c f h w -> (b f) c h w")
|
16 |
+
x = super().forward(x)
|
17 |
+
x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
|
18 |
+
|
19 |
+
return x
|
20 |
+
|
21 |
+
|
22 |
+
class InflatedGroupNorm(nn.GroupNorm):
|
23 |
+
def forward(self, x):
|
24 |
+
video_length = x.shape[2]
|
25 |
+
|
26 |
+
x = rearrange(x, "b c f h w -> (b f) c h w")
|
27 |
+
x = super().forward(x)
|
28 |
+
x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
|
29 |
+
|
30 |
+
return x
|
31 |
+
|
32 |
+
|
33 |
+
class Upsample3D(nn.Module):
|
34 |
+
def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
|
35 |
+
super().__init__()
|
36 |
+
self.channels = channels
|
37 |
+
self.out_channels = out_channels or channels
|
38 |
+
self.use_conv = use_conv
|
39 |
+
self.use_conv_transpose = use_conv_transpose
|
40 |
+
self.name = name
|
41 |
+
|
42 |
+
conv = None
|
43 |
+
if use_conv_transpose:
|
44 |
+
raise NotImplementedError
|
45 |
+
elif use_conv:
|
46 |
+
self.conv = InflatedConv3d(
|
47 |
+
self.channels, self.out_channels, 3, padding=1)
|
48 |
+
|
49 |
+
def forward(self, hidden_states, output_size=None):
|
50 |
+
assert hidden_states.shape[1] == self.channels
|
51 |
+
|
52 |
+
if self.use_conv_transpose:
|
53 |
+
raise NotImplementedError
|
54 |
+
|
55 |
+
# Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
|
56 |
+
dtype = hidden_states.dtype
|
57 |
+
if dtype == torch.bfloat16:
|
58 |
+
hidden_states = hidden_states.to(torch.float32)
|
59 |
+
|
60 |
+
# upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
|
61 |
+
if hidden_states.shape[0] >= 64:
|
62 |
+
hidden_states = hidden_states.contiguous()
|
63 |
+
|
64 |
+
# if `output_size` is passed we force the interpolation output
|
65 |
+
# size and do not make use of `scale_factor=2`
|
66 |
+
if output_size is None:
|
67 |
+
hidden_states = F.interpolate(hidden_states, scale_factor=[
|
68 |
+
1.0, 2.0, 2.0], mode="nearest")
|
69 |
+
else:
|
70 |
+
hidden_states = F.interpolate(
|
71 |
+
hidden_states, size=output_size, mode="nearest")
|
72 |
+
|
73 |
+
# If the input is bfloat16, we cast back to bfloat16
|
74 |
+
if dtype == torch.bfloat16:
|
75 |
+
hidden_states = hidden_states.to(dtype)
|
76 |
+
|
77 |
+
# if self.use_conv:
|
78 |
+
# if self.name == "conv":
|
79 |
+
# hidden_states = self.conv(hidden_states)
|
80 |
+
# else:
|
81 |
+
# hidden_states = self.Conv2d_0(hidden_states)
|
82 |
+
hidden_states = self.conv(hidden_states)
|
83 |
+
|
84 |
+
return hidden_states
|
85 |
+
|
86 |
+
|
87 |
+
class Downsample3D(nn.Module):
|
88 |
+
def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
|
89 |
+
super().__init__()
|
90 |
+
self.channels = channels
|
91 |
+
self.out_channels = out_channels or channels
|
92 |
+
self.use_conv = use_conv
|
93 |
+
self.padding = padding
|
94 |
+
stride = 2
|
95 |
+
self.name = name
|
96 |
+
|
97 |
+
if use_conv:
|
98 |
+
self.conv = InflatedConv3d(
|
99 |
+
self.channels, self.out_channels, 3, stride=stride, padding=padding)
|
100 |
+
else:
|
101 |
+
raise NotImplementedError
|
102 |
+
|
103 |
+
def forward(self, hidden_states):
|
104 |
+
assert hidden_states.shape[1] == self.channels
|
105 |
+
if self.use_conv and self.padding == 0:
|
106 |
+
raise NotImplementedError
|
107 |
+
|
108 |
+
assert hidden_states.shape[1] == self.channels
|
109 |
+
hidden_states = self.conv(hidden_states)
|
110 |
+
|
111 |
+
return hidden_states
|
112 |
+
|
113 |
+
|
114 |
+
class ResnetBlock3D(nn.Module):
|
115 |
+
def __init__(
|
116 |
+
self,
|
117 |
+
*,
|
118 |
+
in_channels,
|
119 |
+
out_channels=None,
|
120 |
+
conv_shortcut=False,
|
121 |
+
dropout=0.0,
|
122 |
+
temb_channels=512,
|
123 |
+
groups=32,
|
124 |
+
groups_out=None,
|
125 |
+
pre_norm=True,
|
126 |
+
eps=1e-6,
|
127 |
+
non_linearity="swish",
|
128 |
+
time_embedding_norm="default",
|
129 |
+
output_scale_factor=1.0,
|
130 |
+
use_in_shortcut=None,
|
131 |
+
use_inflated_groupnorm=None,
|
132 |
+
use_temporal_conv=False,
|
133 |
+
use_temporal_mixer=False,
|
134 |
+
):
|
135 |
+
super().__init__()
|
136 |
+
self.pre_norm = pre_norm
|
137 |
+
self.pre_norm = True
|
138 |
+
self.in_channels = in_channels
|
139 |
+
out_channels = in_channels if out_channels is None else out_channels
|
140 |
+
self.out_channels = out_channels
|
141 |
+
self.use_conv_shortcut = conv_shortcut
|
142 |
+
self.time_embedding_norm = time_embedding_norm
|
143 |
+
self.output_scale_factor = output_scale_factor
|
144 |
+
self.use_temporal_mixer = use_temporal_mixer
|
145 |
+
if use_temporal_mixer:
|
146 |
+
self.temporal_mixer = AlphaBlender(0.3, "learned", None)
|
147 |
+
|
148 |
+
if groups_out is None:
|
149 |
+
groups_out = groups
|
150 |
+
|
151 |
+
assert use_inflated_groupnorm != None
|
152 |
+
if use_inflated_groupnorm:
|
153 |
+
self.norm1 = InflatedGroupNorm(
|
154 |
+
num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
|
155 |
+
else:
|
156 |
+
self.norm1 = torch.nn.GroupNorm(
|
157 |
+
num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
|
158 |
+
|
159 |
+
if use_temporal_conv:
|
160 |
+
self.conv1 = nn.Conv3d(in_channels, out_channels, kernel_size=(
|
161 |
+
3, 1, 1), stride=1, padding=(1, 0, 0))
|
162 |
+
else:
|
163 |
+
self.conv1 = InflatedConv3d(
|
164 |
+
in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
165 |
+
|
166 |
+
if temb_channels is not None:
|
167 |
+
if self.time_embedding_norm == "default":
|
168 |
+
time_emb_proj_out_channels = out_channels
|
169 |
+
elif self.time_embedding_norm == "scale_shift":
|
170 |
+
time_emb_proj_out_channels = out_channels * 2
|
171 |
+
else:
|
172 |
+
raise ValueError(
|
173 |
+
f"unknown time_embedding_norm : {self.time_embedding_norm} ")
|
174 |
+
|
175 |
+
self.time_emb_proj = torch.nn.Linear(
|
176 |
+
temb_channels, time_emb_proj_out_channels)
|
177 |
+
else:
|
178 |
+
self.time_emb_proj = None
|
179 |
+
|
180 |
+
if use_inflated_groupnorm:
|
181 |
+
self.norm2 = InflatedGroupNorm(
|
182 |
+
num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
|
183 |
+
else:
|
184 |
+
self.norm2 = torch.nn.GroupNorm(
|
185 |
+
num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
|
186 |
+
|
187 |
+
self.dropout = torch.nn.Dropout(dropout)
|
188 |
+
if use_temporal_conv:
|
189 |
+
self.conv2 = nn.Conv3d(in_channels, out_channels, kernel_size=(
|
190 |
+
3, 1, 1), stride=1, padding=(1, 0, 0))
|
191 |
+
else:
|
192 |
+
self.conv2 = InflatedConv3d(
|
193 |
+
out_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
194 |
+
|
195 |
+
if non_linearity == "swish":
|
196 |
+
self.nonlinearity = lambda x: F.silu(x)
|
197 |
+
elif non_linearity == "mish":
|
198 |
+
self.nonlinearity = Mish()
|
199 |
+
elif non_linearity == "silu":
|
200 |
+
self.nonlinearity = nn.SiLU()
|
201 |
+
|
202 |
+
self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut
|
203 |
+
|
204 |
+
self.conv_shortcut = None
|
205 |
+
if self.use_in_shortcut:
|
206 |
+
self.conv_shortcut = InflatedConv3d(
|
207 |
+
in_channels, out_channels, kernel_size=1, stride=1, padding=0)
|
208 |
+
|
209 |
+
def forward(self, input_tensor, temb):
|
210 |
+
if self.use_temporal_mixer:
|
211 |
+
residual = input_tensor
|
212 |
+
|
213 |
+
hidden_states = input_tensor
|
214 |
+
|
215 |
+
hidden_states = self.norm1(hidden_states)
|
216 |
+
hidden_states = self.nonlinearity(hidden_states)
|
217 |
+
|
218 |
+
hidden_states = self.conv1(hidden_states)
|
219 |
+
|
220 |
+
if temb is not None:
|
221 |
+
temb = self.time_emb_proj(self.nonlinearity(temb))[
|
222 |
+
:, :, None, None, None]
|
223 |
+
|
224 |
+
if temb is not None and self.time_embedding_norm == "default":
|
225 |
+
hidden_states = hidden_states + temb
|
226 |
+
|
227 |
+
hidden_states = self.norm2(hidden_states)
|
228 |
+
|
229 |
+
if temb is not None and self.time_embedding_norm == "scale_shift":
|
230 |
+
scale, shift = torch.chunk(temb, 2, dim=1)
|
231 |
+
hidden_states = hidden_states * (1 + scale) + shift
|
232 |
+
|
233 |
+
hidden_states = self.nonlinearity(hidden_states)
|
234 |
+
|
235 |
+
hidden_states = self.dropout(hidden_states)
|
236 |
+
hidden_states = self.conv2(hidden_states)
|
237 |
+
|
238 |
+
if self.conv_shortcut is not None:
|
239 |
+
input_tensor = self.conv_shortcut(input_tensor)
|
240 |
+
|
241 |
+
output_tensor = (input_tensor + hidden_states) / \
|
242 |
+
self.output_scale_factor
|
243 |
+
|
244 |
+
if self.use_temporal_mixer:
|
245 |
+
output_tensor = self.temporal_mixer(residual, output_tensor, None)
|
246 |
+
# return residual + 0.0 * self.temporal_mixer(residual, output_tensor, None)
|
247 |
+
return output_tensor
|
248 |
+
|
249 |
+
|
250 |
+
class Mish(torch.nn.Module):
|
251 |
+
def forward(self, hidden_states):
|
252 |
+
return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states))
|
253 |
+
|
254 |
+
|
255 |
+
class AlphaBlender(nn.Module):
|
256 |
+
strategies = ["learned", "fixed", "learned_with_images"]
|
257 |
+
|
258 |
+
def __init__(
|
259 |
+
self,
|
260 |
+
alpha: float,
|
261 |
+
merge_strategy: str = "learned_with_images",
|
262 |
+
rearrange_pattern: str = "b t -> (b t) 1 1",
|
263 |
+
):
|
264 |
+
super().__init__()
|
265 |
+
self.merge_strategy = merge_strategy
|
266 |
+
self.rearrange_pattern = rearrange_pattern
|
267 |
+
self.scaler = 10.
|
268 |
+
|
269 |
+
assert (
|
270 |
+
merge_strategy in self.strategies
|
271 |
+
), f"merge_strategy needs to be in {self.strategies}"
|
272 |
+
|
273 |
+
if self.merge_strategy == "fixed":
|
274 |
+
self.register_buffer("mix_factor", torch.Tensor([alpha]))
|
275 |
+
elif (
|
276 |
+
self.merge_strategy == "learned"
|
277 |
+
or self.merge_strategy == "learned_with_images"
|
278 |
+
):
|
279 |
+
self.register_parameter(
|
280 |
+
"mix_factor", torch.nn.Parameter(torch.Tensor([alpha]))
|
281 |
+
)
|
282 |
+
else:
|
283 |
+
raise ValueError(f"unknown merge strategy {self.merge_strategy}")
|
284 |
+
|
285 |
+
def get_alpha(self, image_only_indicator: torch.Tensor) -> torch.Tensor:
|
286 |
+
if self.merge_strategy == "fixed":
|
287 |
+
alpha = self.mix_factor
|
288 |
+
elif self.merge_strategy == "learned":
|
289 |
+
alpha = torch.sigmoid(self.mix_factor*self.scaler)
|
290 |
+
elif self.merge_strategy == "learned_with_images":
|
291 |
+
assert image_only_indicator is not None, "need image_only_indicator ..."
|
292 |
+
alpha = torch.where(
|
293 |
+
image_only_indicator.bool(),
|
294 |
+
torch.ones(1, 1, device=image_only_indicator.device),
|
295 |
+
rearrange(torch.sigmoid(self.mix_factor), "... -> ... 1"),
|
296 |
+
)
|
297 |
+
alpha = rearrange(alpha, self.rearrange_pattern)
|
298 |
+
else:
|
299 |
+
raise NotImplementedError
|
300 |
+
return alpha
|
301 |
+
|
302 |
+
def forward(
|
303 |
+
self,
|
304 |
+
x_spatial: torch.Tensor,
|
305 |
+
x_temporal: torch.Tensor,
|
306 |
+
image_only_indicator: Optional[torch.Tensor] = None,
|
307 |
+
) -> torch.Tensor:
|
308 |
+
alpha = self.get_alpha(image_only_indicator)
|
309 |
+
x = (
|
310 |
+
alpha.to(x_spatial.dtype) * x_spatial
|
311 |
+
+ (1.0 - alpha).to(x_spatial.dtype) * x_temporal
|
312 |
+
)
|
313 |
+
return x
|
animatelcm/models/unet.py
ADDED
@@ -0,0 +1,568 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_condition.py
|
2 |
+
|
3 |
+
from dataclasses import dataclass
|
4 |
+
from typing import List, Optional, Tuple, Union
|
5 |
+
|
6 |
+
import os
|
7 |
+
import json
|
8 |
+
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
import torch.utils.checkpoint
|
12 |
+
|
13 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
14 |
+
from diffusers.modeling_utils import ModelMixin
|
15 |
+
from diffusers.utils import BaseOutput, logging
|
16 |
+
from animatelcm.models.embeddings import TimestepEmbedding, Timesteps
|
17 |
+
from .unet_blocks import (
|
18 |
+
CrossAttnDownBlock3D,
|
19 |
+
CrossAttnUpBlock3D,
|
20 |
+
DownBlock3D,
|
21 |
+
UNetMidBlock3DCrossAttn,
|
22 |
+
UpBlock3D,
|
23 |
+
get_down_block,
|
24 |
+
get_up_block,
|
25 |
+
)
|
26 |
+
from .resnet import InflatedConv3d, InflatedGroupNorm
|
27 |
+
# from .adapter import Adapter, PixelAdapter # Not ready
|
28 |
+
from einops import repeat
|
29 |
+
|
30 |
+
|
31 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
32 |
+
|
33 |
+
|
34 |
+
@dataclass
|
35 |
+
class UNet3DConditionOutput(BaseOutput):
|
36 |
+
sample: torch.FloatTensor
|
37 |
+
|
38 |
+
|
39 |
+
class UNet3DConditionModel(ModelMixin, ConfigMixin):
|
40 |
+
_supports_gradient_checkpointing = True
|
41 |
+
|
42 |
+
@register_to_config
|
43 |
+
def __init__(
|
44 |
+
self,
|
45 |
+
sample_size: Optional[int] = None,
|
46 |
+
in_channels: int = 4,
|
47 |
+
out_channels: int = 4,
|
48 |
+
center_input_sample: bool = False,
|
49 |
+
flip_sin_to_cos: bool = True,
|
50 |
+
freq_shift: int = 0,
|
51 |
+
down_block_types: Tuple[str] = (
|
52 |
+
"CrossAttnDownBlock3D",
|
53 |
+
"CrossAttnDownBlock3D",
|
54 |
+
"CrossAttnDownBlock3D",
|
55 |
+
"DownBlock3D",
|
56 |
+
),
|
57 |
+
mid_block_type: str = "UNetMidBlock3DCrossAttn",
|
58 |
+
up_block_types: Tuple[str] = (
|
59 |
+
"UpBlock3D",
|
60 |
+
"CrossAttnUpBlock3D",
|
61 |
+
"CrossAttnUpBlock3D",
|
62 |
+
"CrossAttnUpBlock3D"
|
63 |
+
),
|
64 |
+
only_cross_attention: Union[bool, Tuple[bool]] = False,
|
65 |
+
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
|
66 |
+
layers_per_block: int = 2,
|
67 |
+
downsample_padding: int = 1,
|
68 |
+
mid_block_scale_factor: float = 1,
|
69 |
+
act_fn: str = "silu",
|
70 |
+
norm_num_groups: int = 32,
|
71 |
+
norm_eps: float = 1e-5,
|
72 |
+
cross_attention_dim: int = 1280,
|
73 |
+
attention_head_dim: Union[int, Tuple[int]] = 8,
|
74 |
+
dual_cross_attention: bool = False,
|
75 |
+
use_linear_projection: bool = False,
|
76 |
+
class_embed_type: Optional[str] = None,
|
77 |
+
num_class_embeds: Optional[int] = None,
|
78 |
+
upcast_attention: bool = False,
|
79 |
+
resnet_time_scale_shift: str = "default",
|
80 |
+
|
81 |
+
use_inflated_groupnorm=False,
|
82 |
+
|
83 |
+
# Additional
|
84 |
+
use_motion_module=False,
|
85 |
+
use_motion_resnet=False,
|
86 |
+
motion_module_resolutions=(1, 2, 4, 8),
|
87 |
+
motion_module_mid_block=False,
|
88 |
+
motion_module_decoder_only=False,
|
89 |
+
motion_module_type=None,
|
90 |
+
motion_module_kwargs={},
|
91 |
+
unet_use_cross_frame_attention=None,
|
92 |
+
unet_use_temporal_attention=None,
|
93 |
+
time_cond_proj_dim=None, # not ready
|
94 |
+
use_img_encoder=False,
|
95 |
+
use_pixel_encoder=False,
|
96 |
+
):
|
97 |
+
super().__init__()
|
98 |
+
|
99 |
+
self.sample_size = sample_size
|
100 |
+
time_embed_dim = block_out_channels[0] * 4
|
101 |
+
|
102 |
+
self.img_encoder = None if use_img_encoder else None # not ready
|
103 |
+
self.pixel_encoder = None if use_pixel_encoder else None # not ready
|
104 |
+
|
105 |
+
|
106 |
+
# input
|
107 |
+
self.conv_in = InflatedConv3d(
|
108 |
+
in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))
|
109 |
+
|
110 |
+
# time
|
111 |
+
self.time_proj = Timesteps(
|
112 |
+
block_out_channels[0], flip_sin_to_cos, freq_shift)
|
113 |
+
timestep_input_dim = block_out_channels[0]
|
114 |
+
|
115 |
+
self.time_embedding = TimestepEmbedding(
|
116 |
+
timestep_input_dim, time_embed_dim, time_cond_proj_dim=time_cond_proj_dim)
|
117 |
+
|
118 |
+
# class embedding
|
119 |
+
if class_embed_type is None and num_class_embeds is not None:
|
120 |
+
self.class_embedding = nn.Embedding(
|
121 |
+
num_class_embeds, time_embed_dim)
|
122 |
+
elif class_embed_type == "timestep":
|
123 |
+
self.class_embedding = TimestepEmbedding(
|
124 |
+
timestep_input_dim, time_embed_dim)
|
125 |
+
elif class_embed_type == "identity":
|
126 |
+
self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
|
127 |
+
else:
|
128 |
+
self.class_embedding = None
|
129 |
+
|
130 |
+
self.down_blocks = nn.ModuleList([])
|
131 |
+
self.mid_block = None
|
132 |
+
self.up_blocks = nn.ModuleList([])
|
133 |
+
|
134 |
+
if isinstance(only_cross_attention, bool):
|
135 |
+
only_cross_attention = [
|
136 |
+
only_cross_attention] * len(down_block_types)
|
137 |
+
|
138 |
+
if isinstance(attention_head_dim, int):
|
139 |
+
attention_head_dim = (attention_head_dim,) * len(down_block_types)
|
140 |
+
|
141 |
+
# down
|
142 |
+
output_channel = block_out_channels[0]
|
143 |
+
for i, down_block_type in enumerate(down_block_types):
|
144 |
+
res = 2 ** i
|
145 |
+
input_channel = output_channel
|
146 |
+
output_channel = block_out_channels[i]
|
147 |
+
is_final_block = i == len(block_out_channels) - 1
|
148 |
+
|
149 |
+
down_block = get_down_block(
|
150 |
+
down_block_type,
|
151 |
+
num_layers=layers_per_block,
|
152 |
+
in_channels=input_channel,
|
153 |
+
out_channels=output_channel,
|
154 |
+
temb_channels=time_embed_dim,
|
155 |
+
add_downsample=not is_final_block,
|
156 |
+
resnet_eps=norm_eps,
|
157 |
+
resnet_act_fn=act_fn,
|
158 |
+
resnet_groups=norm_num_groups,
|
159 |
+
cross_attention_dim=cross_attention_dim,
|
160 |
+
attn_num_head_channels=attention_head_dim[i],
|
161 |
+
downsample_padding=downsample_padding,
|
162 |
+
dual_cross_attention=dual_cross_attention,
|
163 |
+
use_linear_projection=use_linear_projection,
|
164 |
+
only_cross_attention=only_cross_attention[i],
|
165 |
+
upcast_attention=upcast_attention,
|
166 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
167 |
+
|
168 |
+
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
|
169 |
+
unet_use_temporal_attention=unet_use_temporal_attention,
|
170 |
+
use_inflated_groupnorm=use_inflated_groupnorm,
|
171 |
+
|
172 |
+
use_motion_module=use_motion_module and (
|
173 |
+
res in motion_module_resolutions) and (not motion_module_decoder_only),
|
174 |
+
use_motion_resnet=use_motion_resnet and (
|
175 |
+
res in motion_module_resolutions) and (not motion_module_decoder_only),
|
176 |
+
motion_module_type=motion_module_type,
|
177 |
+
motion_module_kwargs=motion_module_kwargs,
|
178 |
+
)
|
179 |
+
self.down_blocks.append(down_block)
|
180 |
+
|
181 |
+
# mid
|
182 |
+
if mid_block_type == "UNetMidBlock3DCrossAttn":
|
183 |
+
self.mid_block = UNetMidBlock3DCrossAttn(
|
184 |
+
in_channels=block_out_channels[-1],
|
185 |
+
temb_channels=time_embed_dim,
|
186 |
+
resnet_eps=norm_eps,
|
187 |
+
resnet_act_fn=act_fn,
|
188 |
+
output_scale_factor=mid_block_scale_factor,
|
189 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
190 |
+
cross_attention_dim=cross_attention_dim,
|
191 |
+
attn_num_head_channels=attention_head_dim[-1],
|
192 |
+
resnet_groups=norm_num_groups,
|
193 |
+
dual_cross_attention=dual_cross_attention,
|
194 |
+
use_linear_projection=use_linear_projection,
|
195 |
+
upcast_attention=upcast_attention,
|
196 |
+
|
197 |
+
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
|
198 |
+
unet_use_temporal_attention=unet_use_temporal_attention,
|
199 |
+
use_inflated_groupnorm=use_inflated_groupnorm,
|
200 |
+
|
201 |
+
use_motion_module=use_motion_module and motion_module_mid_block,
|
202 |
+
use_motion_resnet=use_motion_resnet and motion_module_mid_block,
|
203 |
+
|
204 |
+
motion_module_type=motion_module_type,
|
205 |
+
motion_module_kwargs=motion_module_kwargs,
|
206 |
+
)
|
207 |
+
else:
|
208 |
+
raise ValueError(f"unknown mid_block_type : {mid_block_type}")
|
209 |
+
|
210 |
+
# count how many layers upsample the videos
|
211 |
+
self.num_upsamplers = 0
|
212 |
+
|
213 |
+
# up
|
214 |
+
reversed_block_out_channels = list(reversed(block_out_channels))
|
215 |
+
reversed_attention_head_dim = list(reversed(attention_head_dim))
|
216 |
+
only_cross_attention = list(reversed(only_cross_attention))
|
217 |
+
output_channel = reversed_block_out_channels[0]
|
218 |
+
for i, up_block_type in enumerate(up_block_types):
|
219 |
+
res = 2 ** (3 - i)
|
220 |
+
is_final_block = i == len(block_out_channels) - 1
|
221 |
+
|
222 |
+
prev_output_channel = output_channel
|
223 |
+
output_channel = reversed_block_out_channels[i]
|
224 |
+
input_channel = reversed_block_out_channels[min(
|
225 |
+
i + 1, len(block_out_channels) - 1)]
|
226 |
+
|
227 |
+
# add upsample block for all BUT final layer
|
228 |
+
if not is_final_block:
|
229 |
+
add_upsample = True
|
230 |
+
self.num_upsamplers += 1
|
231 |
+
else:
|
232 |
+
add_upsample = False
|
233 |
+
|
234 |
+
up_block = get_up_block(
|
235 |
+
up_block_type,
|
236 |
+
num_layers=layers_per_block + 1,
|
237 |
+
in_channels=input_channel,
|
238 |
+
out_channels=output_channel,
|
239 |
+
prev_output_channel=prev_output_channel,
|
240 |
+
temb_channels=time_embed_dim,
|
241 |
+
add_upsample=add_upsample,
|
242 |
+
resnet_eps=norm_eps,
|
243 |
+
resnet_act_fn=act_fn,
|
244 |
+
resnet_groups=norm_num_groups,
|
245 |
+
cross_attention_dim=cross_attention_dim,
|
246 |
+
attn_num_head_channels=reversed_attention_head_dim[i],
|
247 |
+
dual_cross_attention=dual_cross_attention,
|
248 |
+
use_linear_projection=use_linear_projection,
|
249 |
+
only_cross_attention=only_cross_attention[i],
|
250 |
+
upcast_attention=upcast_attention,
|
251 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
252 |
+
|
253 |
+
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
|
254 |
+
unet_use_temporal_attention=unet_use_temporal_attention,
|
255 |
+
use_inflated_groupnorm=use_inflated_groupnorm,
|
256 |
+
|
257 |
+
use_motion_module=use_motion_module and (
|
258 |
+
res in motion_module_resolutions),
|
259 |
+
use_motion_resnet=use_motion_resnet and (
|
260 |
+
res in motion_module_resolutions),
|
261 |
+
|
262 |
+
motion_module_type=motion_module_type,
|
263 |
+
motion_module_kwargs=motion_module_kwargs,
|
264 |
+
)
|
265 |
+
self.up_blocks.append(up_block)
|
266 |
+
prev_output_channel = output_channel
|
267 |
+
|
268 |
+
# out
|
269 |
+
if use_inflated_groupnorm:
|
270 |
+
self.conv_norm_out = InflatedGroupNorm(
|
271 |
+
num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
|
272 |
+
else:
|
273 |
+
self.conv_norm_out = nn.GroupNorm(
|
274 |
+
num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
|
275 |
+
self.conv_act = nn.SiLU()
|
276 |
+
self.conv_out = InflatedConv3d(
|
277 |
+
block_out_channels[0], out_channels, kernel_size=3, padding=1)
|
278 |
+
|
279 |
+
def set_attention_slice(self, slice_size):
|
280 |
+
r"""
|
281 |
+
Enable sliced attention computation.
|
282 |
+
|
283 |
+
When this option is enabled, the attention module will split the input tensor in slices, to compute attention
|
284 |
+
in several steps. This is useful to save some memory in exchange for a small speed decrease.
|
285 |
+
|
286 |
+
Args:
|
287 |
+
slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
|
288 |
+
When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
|
289 |
+
`"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is
|
290 |
+
provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
|
291 |
+
must be a multiple of `slice_size`.
|
292 |
+
"""
|
293 |
+
sliceable_head_dims = []
|
294 |
+
|
295 |
+
def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module):
|
296 |
+
if hasattr(module, "set_attention_slice"):
|
297 |
+
sliceable_head_dims.append(module.sliceable_head_dim)
|
298 |
+
|
299 |
+
for child in module.children():
|
300 |
+
fn_recursive_retrieve_slicable_dims(child)
|
301 |
+
|
302 |
+
# retrieve number of attention layers
|
303 |
+
for module in self.children():
|
304 |
+
fn_recursive_retrieve_slicable_dims(module)
|
305 |
+
|
306 |
+
num_slicable_layers = len(sliceable_head_dims)
|
307 |
+
|
308 |
+
if slice_size == "auto":
|
309 |
+
# half the attention head size is usually a good trade-off between
|
310 |
+
# speed and memory
|
311 |
+
slice_size = [dim // 2 for dim in sliceable_head_dims]
|
312 |
+
elif slice_size == "max":
|
313 |
+
# make smallest slice possible
|
314 |
+
slice_size = num_slicable_layers * [1]
|
315 |
+
|
316 |
+
slice_size = num_slicable_layers * \
|
317 |
+
[slice_size] if not isinstance(slice_size, list) else slice_size
|
318 |
+
|
319 |
+
if len(slice_size) != len(sliceable_head_dims):
|
320 |
+
raise ValueError(
|
321 |
+
f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
|
322 |
+
f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
|
323 |
+
)
|
324 |
+
|
325 |
+
for i in range(len(slice_size)):
|
326 |
+
size = slice_size[i]
|
327 |
+
dim = sliceable_head_dims[i]
|
328 |
+
if size is not None and size > dim:
|
329 |
+
raise ValueError(
|
330 |
+
f"size {size} has to be smaller or equal to {dim}.")
|
331 |
+
|
332 |
+
# Recursively walk through all the children.
|
333 |
+
# Any children which exposes the set_attention_slice method
|
334 |
+
# gets the message
|
335 |
+
def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
|
336 |
+
if hasattr(module, "set_attention_slice"):
|
337 |
+
module.set_attention_slice(slice_size.pop())
|
338 |
+
|
339 |
+
for child in module.children():
|
340 |
+
fn_recursive_set_attention_slice(child, slice_size)
|
341 |
+
|
342 |
+
reversed_slice_size = list(reversed(slice_size))
|
343 |
+
for module in self.children():
|
344 |
+
fn_recursive_set_attention_slice(module, reversed_slice_size)
|
345 |
+
|
346 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
347 |
+
if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
|
348 |
+
module.gradient_checkpointing = value
|
349 |
+
|
350 |
+
def forward(
|
351 |
+
self,
|
352 |
+
sample: torch.FloatTensor,
|
353 |
+
timestep: Union[torch.Tensor, float, int],
|
354 |
+
encoder_hidden_states: torch.Tensor,
|
355 |
+
img_latent: torch.FloatTensor = None,
|
356 |
+
control: torch.FloatTensor = None,
|
357 |
+
time_cond: torch.FloatTensor = None, # not ready
|
358 |
+
class_labels: Optional[torch.Tensor] = None,
|
359 |
+
attention_mask: Optional[torch.Tensor] = None,
|
360 |
+
return_dict: bool = True,
|
361 |
+
) -> Union[UNet3DConditionOutput, Tuple]:
|
362 |
+
r"""
|
363 |
+
Args:
|
364 |
+
sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
|
365 |
+
timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
|
366 |
+
encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
|
367 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
368 |
+
Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
|
369 |
+
|
370 |
+
Returns:
|
371 |
+
[`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
|
372 |
+
[`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
|
373 |
+
returning a tuple, the first element is the sample tensor.
|
374 |
+
"""
|
375 |
+
|
376 |
+
if img_latent is not None and self.img_encoder is not None:
|
377 |
+
f = sample.shape[2]
|
378 |
+
img_latent = repeat(img_latent, "b c h w -> b c f h w",
|
379 |
+
f=f) if img_latent.ndim == 4 else img_latent
|
380 |
+
img_features = self.img_encoder(img_latent)
|
381 |
+
else:
|
382 |
+
img_features = None
|
383 |
+
|
384 |
+
if control is not None and self.pixel_encoder is not None:
|
385 |
+
ctrl_features = self.pixel_encoder(control)
|
386 |
+
else:
|
387 |
+
# assert 0
|
388 |
+
ctrl_features = None
|
389 |
+
|
390 |
+
# By default samples have to be AT least a multiple of the overall upsampling factor.
|
391 |
+
# The overall upsampling factor is equal to 2 ** (# num of upsampling layears).
|
392 |
+
# However, the upsampling interpolation output size can be forced to fit any upsampling size
|
393 |
+
# on the fly if necessary.
|
394 |
+
default_overall_up_factor = 2**self.num_upsamplers
|
395 |
+
|
396 |
+
# upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
|
397 |
+
forward_upsample_size = False
|
398 |
+
upsample_size = None
|
399 |
+
|
400 |
+
if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
|
401 |
+
logger.info(
|
402 |
+
"Forward upsample size to force interpolation output size.")
|
403 |
+
forward_upsample_size = True
|
404 |
+
|
405 |
+
# prepare attention_mask
|
406 |
+
if attention_mask is not None:
|
407 |
+
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
|
408 |
+
attention_mask = attention_mask.unsqueeze(1)
|
409 |
+
|
410 |
+
# center input if necessary
|
411 |
+
if self.config.center_input_sample:
|
412 |
+
sample = 2 * sample - 1.0
|
413 |
+
|
414 |
+
# time
|
415 |
+
timesteps = timestep
|
416 |
+
if not torch.is_tensor(timesteps):
|
417 |
+
# This would be a good case for the `match` statement (Python 3.10+)
|
418 |
+
is_mps = sample.device.type == "mps"
|
419 |
+
if isinstance(timestep, float):
|
420 |
+
dtype = torch.float32 if is_mps else torch.float64
|
421 |
+
else:
|
422 |
+
dtype = torch.int32 if is_mps else torch.int64
|
423 |
+
timesteps = torch.tensor(
|
424 |
+
[timesteps], dtype=dtype, device=sample.device)
|
425 |
+
elif len(timesteps.shape) == 0:
|
426 |
+
timesteps = timesteps[None].to(sample.device)
|
427 |
+
|
428 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
429 |
+
timesteps = timesteps.expand(sample.shape[0])
|
430 |
+
|
431 |
+
t_emb = self.time_proj(timesteps)
|
432 |
+
|
433 |
+
# timesteps does not contain any weights and will always return f32 tensors
|
434 |
+
# but time_embedding might actually be running in fp16. so we need to cast here.
|
435 |
+
# there might be better ways to encapsulate this.
|
436 |
+
t_emb = t_emb.to(dtype=self.dtype)
|
437 |
+
|
438 |
+
emb = self.time_embedding(t_emb)
|
439 |
+
|
440 |
+
if self.class_embedding is not None:
|
441 |
+
if class_labels is None:
|
442 |
+
raise ValueError(
|
443 |
+
"class_labels should be provided when num_class_embeds > 0")
|
444 |
+
|
445 |
+
if self.config.class_embed_type == "timestep":
|
446 |
+
class_labels = self.time_proj(class_labels)
|
447 |
+
|
448 |
+
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
|
449 |
+
emb = emb + class_emb
|
450 |
+
|
451 |
+
# pre-process
|
452 |
+
sample = self.conv_in(sample)
|
453 |
+
|
454 |
+
# down
|
455 |
+
|
456 |
+
down_block_res_samples = (sample,)
|
457 |
+
|
458 |
+
img_feature_idx = 0
|
459 |
+
|
460 |
+
for downsample_block in self.down_blocks:
|
461 |
+
|
462 |
+
added_feature = img_features[img_feature_idx] if img_features is not None else torch.tensor(
|
463 |
+
0.).to(sample.device, sample.dtype)
|
464 |
+
added_feature = added_feature + \
|
465 |
+
ctrl_features[img_feature_idx] if ctrl_features is not None else added_feature
|
466 |
+
added_feature = None if added_feature.abs().mean() == 0 else added_feature
|
467 |
+
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
|
468 |
+
sample, res_samples = downsample_block(
|
469 |
+
hidden_states=sample,
|
470 |
+
temb=emb,
|
471 |
+
encoder_hidden_states=encoder_hidden_states,
|
472 |
+
attention_mask=attention_mask,
|
473 |
+
img_feature=added_feature
|
474 |
+
)
|
475 |
+
else:
|
476 |
+
sample, res_samples = downsample_block(
|
477 |
+
hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states, img_feature=added_feature)
|
478 |
+
|
479 |
+
down_block_res_samples += res_samples
|
480 |
+
img_feature_idx += 1
|
481 |
+
# mid
|
482 |
+
sample = self.mid_block(
|
483 |
+
sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
|
484 |
+
)
|
485 |
+
|
486 |
+
# up
|
487 |
+
for i, upsample_block in enumerate(self.up_blocks):
|
488 |
+
is_final_block = i == len(self.up_blocks) - 1
|
489 |
+
|
490 |
+
res_samples = down_block_res_samples[-len(upsample_block.resnets):]
|
491 |
+
down_block_res_samples = down_block_res_samples[: -len(
|
492 |
+
upsample_block.resnets)]
|
493 |
+
|
494 |
+
# if we have not reached the final block and need to forward the
|
495 |
+
# upsample size, we do it here
|
496 |
+
if not is_final_block and forward_upsample_size:
|
497 |
+
upsample_size = down_block_res_samples[-1].shape[2:]
|
498 |
+
|
499 |
+
if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
|
500 |
+
sample = upsample_block(
|
501 |
+
hidden_states=sample,
|
502 |
+
temb=emb,
|
503 |
+
res_hidden_states_tuple=res_samples,
|
504 |
+
encoder_hidden_states=encoder_hidden_states,
|
505 |
+
upsample_size=upsample_size,
|
506 |
+
attention_mask=attention_mask,
|
507 |
+
)
|
508 |
+
else:
|
509 |
+
sample = upsample_block(
|
510 |
+
hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size, encoder_hidden_states=encoder_hidden_states,
|
511 |
+
)
|
512 |
+
|
513 |
+
# post-process
|
514 |
+
sample = self.conv_norm_out(sample)
|
515 |
+
sample = self.conv_act(sample)
|
516 |
+
sample = self.conv_out(sample)
|
517 |
+
|
518 |
+
if not return_dict:
|
519 |
+
return (sample,)
|
520 |
+
|
521 |
+
return UNet3DConditionOutput(sample=sample)
|
522 |
+
|
523 |
+
@classmethod
|
524 |
+
def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, unet_additional_kwargs=None):
|
525 |
+
if subfolder is not None:
|
526 |
+
pretrained_model_path = os.path.join(
|
527 |
+
pretrained_model_path, subfolder)
|
528 |
+
print(
|
529 |
+
f"loaded temporal unet's pretrained weights from {pretrained_model_path} ...")
|
530 |
+
|
531 |
+
config_file = os.path.join(pretrained_model_path, 'config.json')
|
532 |
+
if not os.path.isfile(config_file):
|
533 |
+
raise RuntimeError(f"{config_file} does not exist")
|
534 |
+
with open(config_file, "r") as f:
|
535 |
+
config = json.load(f)
|
536 |
+
config["_class_name"] = cls.__name__
|
537 |
+
config["down_block_types"] = [
|
538 |
+
"CrossAttnDownBlock3D",
|
539 |
+
"CrossAttnDownBlock3D",
|
540 |
+
"CrossAttnDownBlock3D",
|
541 |
+
"DownBlock3D"
|
542 |
+
]
|
543 |
+
config["up_block_types"] = [
|
544 |
+
"UpBlock3D",
|
545 |
+
"CrossAttnUpBlock3D",
|
546 |
+
"CrossAttnUpBlock3D",
|
547 |
+
"CrossAttnUpBlock3D"
|
548 |
+
]
|
549 |
+
|
550 |
+
from diffusers.utils import WEIGHTS_NAME
|
551 |
+
model = cls.from_config(config, **unet_additional_kwargs)
|
552 |
+
model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
|
553 |
+
if not os.path.isfile(model_file):
|
554 |
+
raise RuntimeError(f"{model_file} does not exist")
|
555 |
+
state_dict = torch.load(model_file, map_location="cpu")
|
556 |
+
if "state_dict" in state_dict.keys():
|
557 |
+
state_dict = state_dict["state_dict"]
|
558 |
+
state_dict = {k.replace("module.", ""): v for k,
|
559 |
+
v in state_dict.items()}
|
560 |
+
m, u = model.load_state_dict(state_dict, strict=False)
|
561 |
+
print("###load unet weights")
|
562 |
+
print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
|
563 |
+
|
564 |
+
params = [p.numel() if "motion" in n else 0 for n,
|
565 |
+
p in model.named_parameters()]
|
566 |
+
print(f"### Temporal Module Parameters: {sum(params) / 1e6} M")
|
567 |
+
|
568 |
+
return model
|
animatelcm/models/unet_blocks.py
ADDED
@@ -0,0 +1,904 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_blocks.py
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from torch import nn
|
5 |
+
|
6 |
+
from .attention import Transformer3DModel
|
7 |
+
from .resnet import Downsample3D, ResnetBlock3D, Upsample3D, AlphaBlender
|
8 |
+
from .motion_module import get_motion_module
|
9 |
+
|
10 |
+
|
11 |
+
def get_down_block(
|
12 |
+
down_block_type,
|
13 |
+
num_layers,
|
14 |
+
in_channels,
|
15 |
+
out_channels,
|
16 |
+
temb_channels,
|
17 |
+
add_downsample,
|
18 |
+
resnet_eps,
|
19 |
+
resnet_act_fn,
|
20 |
+
attn_num_head_channels,
|
21 |
+
resnet_groups=None,
|
22 |
+
cross_attention_dim=None,
|
23 |
+
downsample_padding=None,
|
24 |
+
dual_cross_attention=False,
|
25 |
+
use_linear_projection=False,
|
26 |
+
only_cross_attention=False,
|
27 |
+
upcast_attention=False,
|
28 |
+
resnet_time_scale_shift="default",
|
29 |
+
|
30 |
+
unet_use_cross_frame_attention=None,
|
31 |
+
unet_use_temporal_attention=None,
|
32 |
+
use_inflated_groupnorm=None,
|
33 |
+
|
34 |
+
use_motion_module=None,
|
35 |
+
use_motion_resnet=None, # not used for current weight
|
36 |
+
|
37 |
+
motion_module_type=None,
|
38 |
+
motion_module_kwargs=None,
|
39 |
+
):
|
40 |
+
down_block_type = down_block_type[7:] if down_block_type.startswith(
|
41 |
+
"UNetRes") else down_block_type
|
42 |
+
if down_block_type == "DownBlock3D":
|
43 |
+
return DownBlock3D(
|
44 |
+
num_layers=num_layers,
|
45 |
+
in_channels=in_channels,
|
46 |
+
out_channels=out_channels,
|
47 |
+
temb_channels=temb_channels,
|
48 |
+
add_downsample=add_downsample,
|
49 |
+
resnet_eps=resnet_eps,
|
50 |
+
resnet_act_fn=resnet_act_fn,
|
51 |
+
resnet_groups=resnet_groups,
|
52 |
+
downsample_padding=downsample_padding,
|
53 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
54 |
+
|
55 |
+
use_inflated_groupnorm=use_inflated_groupnorm,
|
56 |
+
|
57 |
+
use_motion_module=use_motion_module,
|
58 |
+
motion_module_type=motion_module_type,
|
59 |
+
motion_module_kwargs=motion_module_kwargs,
|
60 |
+
)
|
61 |
+
elif down_block_type == "CrossAttnDownBlock3D":
|
62 |
+
if cross_attention_dim is None:
|
63 |
+
raise ValueError(
|
64 |
+
"cross_attention_dim must be specified for CrossAttnDownBlock3D")
|
65 |
+
return CrossAttnDownBlock3D(
|
66 |
+
num_layers=num_layers,
|
67 |
+
in_channels=in_channels,
|
68 |
+
out_channels=out_channels,
|
69 |
+
temb_channels=temb_channels,
|
70 |
+
add_downsample=add_downsample,
|
71 |
+
resnet_eps=resnet_eps,
|
72 |
+
resnet_act_fn=resnet_act_fn,
|
73 |
+
resnet_groups=resnet_groups,
|
74 |
+
downsample_padding=downsample_padding,
|
75 |
+
cross_attention_dim=cross_attention_dim,
|
76 |
+
attn_num_head_channels=attn_num_head_channels,
|
77 |
+
dual_cross_attention=dual_cross_attention,
|
78 |
+
use_linear_projection=use_linear_projection,
|
79 |
+
only_cross_attention=only_cross_attention,
|
80 |
+
upcast_attention=upcast_attention,
|
81 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
82 |
+
|
83 |
+
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
|
84 |
+
unet_use_temporal_attention=unet_use_temporal_attention,
|
85 |
+
use_inflated_groupnorm=use_inflated_groupnorm,
|
86 |
+
|
87 |
+
use_motion_module=use_motion_module,
|
88 |
+
use_motion_resnet=use_motion_resnet,
|
89 |
+
motion_module_type=motion_module_type,
|
90 |
+
motion_module_kwargs=motion_module_kwargs,
|
91 |
+
)
|
92 |
+
raise ValueError(f"{down_block_type} does not exist.")
|
93 |
+
|
94 |
+
|
95 |
+
def get_up_block(
|
96 |
+
up_block_type,
|
97 |
+
num_layers,
|
98 |
+
in_channels,
|
99 |
+
out_channels,
|
100 |
+
prev_output_channel,
|
101 |
+
temb_channels,
|
102 |
+
add_upsample,
|
103 |
+
resnet_eps,
|
104 |
+
resnet_act_fn,
|
105 |
+
attn_num_head_channels,
|
106 |
+
resnet_groups=None,
|
107 |
+
cross_attention_dim=None,
|
108 |
+
dual_cross_attention=False,
|
109 |
+
use_linear_projection=False,
|
110 |
+
only_cross_attention=False,
|
111 |
+
upcast_attention=False,
|
112 |
+
resnet_time_scale_shift="default",
|
113 |
+
|
114 |
+
unet_use_cross_frame_attention=None,
|
115 |
+
unet_use_temporal_attention=None,
|
116 |
+
use_inflated_groupnorm=None,
|
117 |
+
|
118 |
+
use_motion_module=None,
|
119 |
+
use_motion_resnet=None,
|
120 |
+
motion_module_type=None,
|
121 |
+
motion_module_kwargs=None,
|
122 |
+
):
|
123 |
+
up_block_type = up_block_type[7:] if up_block_type.startswith(
|
124 |
+
"UNetRes") else up_block_type
|
125 |
+
if up_block_type == "UpBlock3D":
|
126 |
+
return UpBlock3D(
|
127 |
+
num_layers=num_layers,
|
128 |
+
in_channels=in_channels,
|
129 |
+
out_channels=out_channels,
|
130 |
+
prev_output_channel=prev_output_channel,
|
131 |
+
temb_channels=temb_channels,
|
132 |
+
add_upsample=add_upsample,
|
133 |
+
resnet_eps=resnet_eps,
|
134 |
+
resnet_act_fn=resnet_act_fn,
|
135 |
+
resnet_groups=resnet_groups,
|
136 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
137 |
+
|
138 |
+
use_inflated_groupnorm=use_inflated_groupnorm,
|
139 |
+
|
140 |
+
use_motion_module=use_motion_module,
|
141 |
+
motion_module_type=motion_module_type,
|
142 |
+
motion_module_kwargs=motion_module_kwargs,
|
143 |
+
)
|
144 |
+
elif up_block_type == "CrossAttnUpBlock3D":
|
145 |
+
if cross_attention_dim is None:
|
146 |
+
raise ValueError(
|
147 |
+
"cross_attention_dim must be specified for CrossAttnUpBlock3D")
|
148 |
+
return CrossAttnUpBlock3D(
|
149 |
+
num_layers=num_layers,
|
150 |
+
in_channels=in_channels,
|
151 |
+
out_channels=out_channels,
|
152 |
+
prev_output_channel=prev_output_channel,
|
153 |
+
temb_channels=temb_channels,
|
154 |
+
add_upsample=add_upsample,
|
155 |
+
resnet_eps=resnet_eps,
|
156 |
+
resnet_act_fn=resnet_act_fn,
|
157 |
+
resnet_groups=resnet_groups,
|
158 |
+
cross_attention_dim=cross_attention_dim,
|
159 |
+
attn_num_head_channels=attn_num_head_channels,
|
160 |
+
dual_cross_attention=dual_cross_attention,
|
161 |
+
use_linear_projection=use_linear_projection,
|
162 |
+
only_cross_attention=only_cross_attention,
|
163 |
+
upcast_attention=upcast_attention,
|
164 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
165 |
+
|
166 |
+
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
|
167 |
+
unet_use_temporal_attention=unet_use_temporal_attention,
|
168 |
+
use_inflated_groupnorm=use_inflated_groupnorm,
|
169 |
+
|
170 |
+
use_motion_module=use_motion_module,
|
171 |
+
use_motion_resnet=use_motion_resnet,
|
172 |
+
motion_module_type=motion_module_type,
|
173 |
+
motion_module_kwargs=motion_module_kwargs,
|
174 |
+
)
|
175 |
+
raise ValueError(f"{up_block_type} does not exist.")
|
176 |
+
|
177 |
+
|
178 |
+
class UNetMidBlock3DCrossAttn(nn.Module):
|
179 |
+
def __init__(
|
180 |
+
self,
|
181 |
+
in_channels: int,
|
182 |
+
temb_channels: int,
|
183 |
+
dropout: float = 0.0,
|
184 |
+
num_layers: int = 1,
|
185 |
+
resnet_eps: float = 1e-6,
|
186 |
+
resnet_time_scale_shift: str = "default",
|
187 |
+
resnet_act_fn: str = "swish",
|
188 |
+
resnet_groups: int = 32,
|
189 |
+
resnet_pre_norm: bool = True,
|
190 |
+
attn_num_head_channels=1,
|
191 |
+
output_scale_factor=1.0,
|
192 |
+
cross_attention_dim=1280,
|
193 |
+
dual_cross_attention=False,
|
194 |
+
use_linear_projection=False,
|
195 |
+
upcast_attention=False,
|
196 |
+
|
197 |
+
unet_use_cross_frame_attention=None,
|
198 |
+
unet_use_temporal_attention=None,
|
199 |
+
use_inflated_groupnorm=None,
|
200 |
+
|
201 |
+
use_motion_module=None,
|
202 |
+
use_motion_resnet=None,
|
203 |
+
|
204 |
+
motion_module_type=None,
|
205 |
+
motion_module_kwargs=None,
|
206 |
+
):
|
207 |
+
super().__init__()
|
208 |
+
|
209 |
+
self.has_cross_attention = True
|
210 |
+
self.attn_num_head_channels = attn_num_head_channels
|
211 |
+
resnet_groups = resnet_groups if resnet_groups is not None else min(
|
212 |
+
in_channels // 4, 32)
|
213 |
+
|
214 |
+
# there is always at least one resnet
|
215 |
+
resnets = [
|
216 |
+
ResnetBlock3D(
|
217 |
+
in_channels=in_channels,
|
218 |
+
out_channels=in_channels,
|
219 |
+
temb_channels=temb_channels,
|
220 |
+
eps=resnet_eps,
|
221 |
+
groups=resnet_groups,
|
222 |
+
dropout=dropout,
|
223 |
+
time_embedding_norm=resnet_time_scale_shift,
|
224 |
+
non_linearity=resnet_act_fn,
|
225 |
+
output_scale_factor=output_scale_factor,
|
226 |
+
pre_norm=resnet_pre_norm,
|
227 |
+
use_inflated_groupnorm=use_inflated_groupnorm,
|
228 |
+
)
|
229 |
+
]
|
230 |
+
motion_resnets = [
|
231 |
+
ResnetBlock3D(
|
232 |
+
in_channels=in_channels,
|
233 |
+
out_channels=in_channels,
|
234 |
+
temb_channels=temb_channels,
|
235 |
+
eps=resnet_eps,
|
236 |
+
groups=resnet_groups,
|
237 |
+
dropout=dropout,
|
238 |
+
time_embedding_norm=resnet_time_scale_shift,
|
239 |
+
non_linearity=resnet_act_fn,
|
240 |
+
output_scale_factor=output_scale_factor,
|
241 |
+
pre_norm=resnet_pre_norm,
|
242 |
+
use_inflated_groupnorm=use_inflated_groupnorm,
|
243 |
+
use_temporal_conv=True,
|
244 |
+
use_temporal_mixer=True,
|
245 |
+
) if use_motion_resnet else None
|
246 |
+
]
|
247 |
+
|
248 |
+
attentions = []
|
249 |
+
motion_modules = []
|
250 |
+
|
251 |
+
for _ in range(num_layers):
|
252 |
+
if dual_cross_attention:
|
253 |
+
raise NotImplementedError
|
254 |
+
attentions.append(
|
255 |
+
Transformer3DModel(
|
256 |
+
attn_num_head_channels,
|
257 |
+
in_channels // attn_num_head_channels,
|
258 |
+
in_channels=in_channels,
|
259 |
+
num_layers=1,
|
260 |
+
cross_attention_dim=cross_attention_dim,
|
261 |
+
norm_num_groups=resnet_groups,
|
262 |
+
use_linear_projection=use_linear_projection,
|
263 |
+
upcast_attention=upcast_attention,
|
264 |
+
|
265 |
+
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
|
266 |
+
unet_use_temporal_attention=unet_use_temporal_attention,
|
267 |
+
)
|
268 |
+
)
|
269 |
+
motion_modules.append(
|
270 |
+
get_motion_module(
|
271 |
+
in_channels=in_channels,
|
272 |
+
motion_module_type=motion_module_type,
|
273 |
+
motion_module_kwargs=motion_module_kwargs,
|
274 |
+
) if use_motion_module else None
|
275 |
+
)
|
276 |
+
resnets.append(
|
277 |
+
ResnetBlock3D(
|
278 |
+
in_channels=in_channels,
|
279 |
+
out_channels=in_channels,
|
280 |
+
temb_channels=temb_channels,
|
281 |
+
eps=resnet_eps,
|
282 |
+
groups=resnet_groups,
|
283 |
+
dropout=dropout,
|
284 |
+
time_embedding_norm=resnet_time_scale_shift,
|
285 |
+
non_linearity=resnet_act_fn,
|
286 |
+
output_scale_factor=output_scale_factor,
|
287 |
+
pre_norm=resnet_pre_norm,
|
288 |
+
|
289 |
+
use_inflated_groupnorm=use_inflated_groupnorm,
|
290 |
+
)
|
291 |
+
)
|
292 |
+
motion_resnets.append(
|
293 |
+
ResnetBlock3D(
|
294 |
+
in_channels=in_channels,
|
295 |
+
out_channels=in_channels,
|
296 |
+
temb_channels=temb_channels,
|
297 |
+
eps=resnet_eps,
|
298 |
+
groups=resnet_groups,
|
299 |
+
dropout=dropout,
|
300 |
+
time_embedding_norm=resnet_time_scale_shift,
|
301 |
+
non_linearity=resnet_act_fn,
|
302 |
+
output_scale_factor=output_scale_factor,
|
303 |
+
pre_norm=resnet_pre_norm,
|
304 |
+
use_inflated_groupnorm=use_inflated_groupnorm,
|
305 |
+
use_temporal_conv=True,
|
306 |
+
use_temporal_mixer=True,
|
307 |
+
) if use_motion_resnet else None
|
308 |
+
)
|
309 |
+
|
310 |
+
self.attentions = nn.ModuleList(attentions)
|
311 |
+
self.resnets = nn.ModuleList(resnets)
|
312 |
+
self.motion_modules = nn.ModuleList(motion_modules)
|
313 |
+
self.motion_resnets = nn.ModuleList(motion_resnets)
|
314 |
+
|
315 |
+
def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):
|
316 |
+
hidden_states = self.resnets[0](hidden_states, temb)
|
317 |
+
hidden_states = self.motion_resnets[0](
|
318 |
+
hidden_states, temb) if self.motion_resnets[0] is not None else hidden_states
|
319 |
+
|
320 |
+
for attn, resnet, motion_module, motion_resnet in zip(self.attentions, self.resnets[1:], self.motion_modules, self.motion_resnets[1:]):
|
321 |
+
hidden_states = attn(
|
322 |
+
hidden_states, encoder_hidden_states=encoder_hidden_states).sample
|
323 |
+
hidden_states = motion_module(
|
324 |
+
hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states
|
325 |
+
hidden_states = resnet(hidden_states, temb)
|
326 |
+
hidden_states = motion_resnet(
|
327 |
+
hidden_states, temb) if motion_resnet is not None else hidden_states
|
328 |
+
|
329 |
+
return hidden_states
|
330 |
+
|
331 |
+
|
332 |
+
class CrossAttnDownBlock3D(nn.Module):
|
333 |
+
def __init__(
|
334 |
+
self,
|
335 |
+
in_channels: int,
|
336 |
+
out_channels: int,
|
337 |
+
temb_channels: int,
|
338 |
+
dropout: float = 0.0,
|
339 |
+
num_layers: int = 1,
|
340 |
+
resnet_eps: float = 1e-6,
|
341 |
+
resnet_time_scale_shift: str = "default",
|
342 |
+
resnet_act_fn: str = "swish",
|
343 |
+
resnet_groups: int = 32,
|
344 |
+
resnet_pre_norm: bool = True,
|
345 |
+
attn_num_head_channels=1,
|
346 |
+
cross_attention_dim=1280,
|
347 |
+
output_scale_factor=1.0,
|
348 |
+
downsample_padding=1,
|
349 |
+
add_downsample=True,
|
350 |
+
dual_cross_attention=False,
|
351 |
+
use_linear_projection=False,
|
352 |
+
only_cross_attention=False,
|
353 |
+
upcast_attention=False,
|
354 |
+
|
355 |
+
unet_use_cross_frame_attention=None,
|
356 |
+
unet_use_temporal_attention=None,
|
357 |
+
use_inflated_groupnorm=None,
|
358 |
+
|
359 |
+
use_motion_module=None,
|
360 |
+
use_motion_resnet=None,
|
361 |
+
|
362 |
+
motion_module_type=None,
|
363 |
+
motion_module_kwargs=None,
|
364 |
+
):
|
365 |
+
super().__init__()
|
366 |
+
resnets = []
|
367 |
+
motion_resnets = []
|
368 |
+
attentions = []
|
369 |
+
motion_modules = []
|
370 |
+
|
371 |
+
self.has_cross_attention = True
|
372 |
+
self.attn_num_head_channels = attn_num_head_channels
|
373 |
+
|
374 |
+
for i in range(num_layers):
|
375 |
+
in_channels = in_channels if i == 0 else out_channels
|
376 |
+
resnets.append(
|
377 |
+
ResnetBlock3D(
|
378 |
+
in_channels=in_channels,
|
379 |
+
out_channels=out_channels,
|
380 |
+
temb_channels=temb_channels,
|
381 |
+
eps=resnet_eps,
|
382 |
+
groups=resnet_groups,
|
383 |
+
dropout=dropout,
|
384 |
+
time_embedding_norm=resnet_time_scale_shift,
|
385 |
+
non_linearity=resnet_act_fn,
|
386 |
+
output_scale_factor=output_scale_factor,
|
387 |
+
pre_norm=resnet_pre_norm,
|
388 |
+
|
389 |
+
use_inflated_groupnorm=use_inflated_groupnorm,
|
390 |
+
)
|
391 |
+
)
|
392 |
+
motion_resnets.append(
|
393 |
+
ResnetBlock3D(
|
394 |
+
in_channels=out_channels,
|
395 |
+
out_channels=out_channels,
|
396 |
+
temb_channels=temb_channels,
|
397 |
+
eps=resnet_eps,
|
398 |
+
groups=resnet_groups,
|
399 |
+
dropout=dropout,
|
400 |
+
time_embedding_norm=resnet_time_scale_shift,
|
401 |
+
non_linearity=resnet_act_fn,
|
402 |
+
output_scale_factor=output_scale_factor,
|
403 |
+
pre_norm=resnet_pre_norm,
|
404 |
+
use_inflated_groupnorm=use_inflated_groupnorm,
|
405 |
+
use_temporal_conv=True,
|
406 |
+
use_temporal_mixer=True,
|
407 |
+
) if use_motion_resnet else None
|
408 |
+
)
|
409 |
+
if dual_cross_attention:
|
410 |
+
raise NotImplementedError
|
411 |
+
attentions.append(
|
412 |
+
Transformer3DModel(
|
413 |
+
attn_num_head_channels,
|
414 |
+
out_channels // attn_num_head_channels,
|
415 |
+
in_channels=out_channels,
|
416 |
+
num_layers=1,
|
417 |
+
cross_attention_dim=cross_attention_dim,
|
418 |
+
norm_num_groups=resnet_groups,
|
419 |
+
use_linear_projection=use_linear_projection,
|
420 |
+
only_cross_attention=only_cross_attention,
|
421 |
+
upcast_attention=upcast_attention,
|
422 |
+
|
423 |
+
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
|
424 |
+
unet_use_temporal_attention=unet_use_temporal_attention,
|
425 |
+
)
|
426 |
+
)
|
427 |
+
motion_modules.append(
|
428 |
+
get_motion_module(
|
429 |
+
in_channels=out_channels,
|
430 |
+
motion_module_type=motion_module_type,
|
431 |
+
motion_module_kwargs=motion_module_kwargs,
|
432 |
+
) if use_motion_module else None
|
433 |
+
)
|
434 |
+
|
435 |
+
self.attentions = nn.ModuleList(attentions)
|
436 |
+
self.resnets = nn.ModuleList(resnets)
|
437 |
+
self.motion_modules = nn.ModuleList(motion_modules)
|
438 |
+
self.motion_resnets = nn.ModuleList(motion_resnets)
|
439 |
+
|
440 |
+
if add_downsample:
|
441 |
+
self.downsamplers = nn.ModuleList(
|
442 |
+
[
|
443 |
+
Downsample3D(
|
444 |
+
out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
|
445 |
+
)
|
446 |
+
]
|
447 |
+
)
|
448 |
+
else:
|
449 |
+
self.downsamplers = None
|
450 |
+
|
451 |
+
self.gradient_checkpointing = False
|
452 |
+
|
453 |
+
def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, img_feature=None):
|
454 |
+
output_states = ()
|
455 |
+
idx = 1
|
456 |
+
for resnet, attn, motion_module, motion_resnet in zip(self.resnets, self.attentions, self.motion_modules, self.motion_resnets):
|
457 |
+
if self.training and self.gradient_checkpointing:
|
458 |
+
|
459 |
+
def create_custom_forward(module, return_dict=None):
|
460 |
+
def custom_forward(*inputs):
|
461 |
+
if return_dict is not None:
|
462 |
+
return module(*inputs, return_dict=return_dict)
|
463 |
+
else:
|
464 |
+
return module(*inputs)
|
465 |
+
|
466 |
+
return custom_forward
|
467 |
+
|
468 |
+
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(
|
469 |
+
resnet), hidden_states.requires_grad_(), temb, use_reentrant=False)
|
470 |
+
if motion_resnet is not None:
|
471 |
+
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(
|
472 |
+
motion_resnet), hidden_states.requires_grad_(), temb, use_reentrant=False)
|
473 |
+
|
474 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
475 |
+
create_custom_forward(attn, return_dict=False),
|
476 |
+
hidden_states.requires_grad_(),
|
477 |
+
encoder_hidden_states,
|
478 |
+
use_reentrant=False
|
479 |
+
)[0]
|
480 |
+
|
481 |
+
hidden_states = hidden_states + \
|
482 |
+
img_feature if (
|
483 |
+
img_feature is not None and idx == 2) else hidden_states
|
484 |
+
|
485 |
+
if motion_module is not None:
|
486 |
+
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(
|
487 |
+
motion_module), hidden_states.requires_grad_(), temb, encoder_hidden_states, use_reentrant=False)
|
488 |
+
|
489 |
+
else:
|
490 |
+
hidden_states = resnet(hidden_states, temb)
|
491 |
+
|
492 |
+
hidden_states = motion_resnet(
|
493 |
+
hidden_states, temb) if motion_resnet is not None else hidden_states
|
494 |
+
|
495 |
+
hidden_states = attn(
|
496 |
+
hidden_states, encoder_hidden_states=encoder_hidden_states).sample
|
497 |
+
|
498 |
+
hidden_states = hidden_states + \
|
499 |
+
img_feature if (
|
500 |
+
img_feature is not None and idx == 2) else hidden_states
|
501 |
+
|
502 |
+
# add motion module
|
503 |
+
hidden_states = motion_module(
|
504 |
+
hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states
|
505 |
+
|
506 |
+
idx += 1
|
507 |
+
output_states += (hidden_states,)
|
508 |
+
|
509 |
+
if self.downsamplers is not None:
|
510 |
+
for downsampler in self.downsamplers:
|
511 |
+
hidden_states = downsampler(hidden_states)
|
512 |
+
|
513 |
+
output_states += (hidden_states,)
|
514 |
+
|
515 |
+
return hidden_states, output_states
|
516 |
+
|
517 |
+
|
518 |
+
class DownBlock3D(nn.Module):
|
519 |
+
def __init__(
|
520 |
+
self,
|
521 |
+
in_channels: int,
|
522 |
+
out_channels: int,
|
523 |
+
temb_channels: int,
|
524 |
+
dropout: float = 0.0,
|
525 |
+
num_layers: int = 1,
|
526 |
+
resnet_eps: float = 1e-6,
|
527 |
+
resnet_time_scale_shift: str = "default",
|
528 |
+
resnet_act_fn: str = "swish",
|
529 |
+
resnet_groups: int = 32,
|
530 |
+
resnet_pre_norm: bool = True,
|
531 |
+
output_scale_factor=1.0,
|
532 |
+
add_downsample=True,
|
533 |
+
downsample_padding=1,
|
534 |
+
|
535 |
+
use_inflated_groupnorm=None,
|
536 |
+
|
537 |
+
use_motion_module=None,
|
538 |
+
motion_module_type=None,
|
539 |
+
motion_module_kwargs=None,
|
540 |
+
):
|
541 |
+
super().__init__()
|
542 |
+
resnets = []
|
543 |
+
motion_modules = []
|
544 |
+
|
545 |
+
for i in range(num_layers):
|
546 |
+
in_channels = in_channels if i == 0 else out_channels
|
547 |
+
resnets.append(
|
548 |
+
ResnetBlock3D(
|
549 |
+
in_channels=in_channels,
|
550 |
+
out_channels=out_channels,
|
551 |
+
temb_channels=temb_channels,
|
552 |
+
eps=resnet_eps,
|
553 |
+
groups=resnet_groups,
|
554 |
+
dropout=dropout,
|
555 |
+
time_embedding_norm=resnet_time_scale_shift,
|
556 |
+
non_linearity=resnet_act_fn,
|
557 |
+
output_scale_factor=output_scale_factor,
|
558 |
+
pre_norm=resnet_pre_norm,
|
559 |
+
|
560 |
+
use_inflated_groupnorm=use_inflated_groupnorm,
|
561 |
+
)
|
562 |
+
)
|
563 |
+
motion_modules.append(
|
564 |
+
get_motion_module(
|
565 |
+
in_channels=out_channels,
|
566 |
+
motion_module_type=motion_module_type,
|
567 |
+
motion_module_kwargs=motion_module_kwargs,
|
568 |
+
) if use_motion_module else None
|
569 |
+
)
|
570 |
+
|
571 |
+
self.resnets = nn.ModuleList(resnets)
|
572 |
+
self.motion_modules = nn.ModuleList(motion_modules)
|
573 |
+
|
574 |
+
if add_downsample:
|
575 |
+
self.downsamplers = nn.ModuleList(
|
576 |
+
[
|
577 |
+
Downsample3D(
|
578 |
+
out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
|
579 |
+
)
|
580 |
+
]
|
581 |
+
)
|
582 |
+
else:
|
583 |
+
self.downsamplers = None
|
584 |
+
|
585 |
+
self.gradient_checkpointing = False
|
586 |
+
|
587 |
+
def forward(self, hidden_states, temb=None, encoder_hidden_states=None, img_feature=None):
|
588 |
+
output_states = ()
|
589 |
+
|
590 |
+
idx = 1
|
591 |
+
for resnet, motion_module in zip(self.resnets, self.motion_modules):
|
592 |
+
if self.training and self.gradient_checkpointing:
|
593 |
+
def create_custom_forward(module):
|
594 |
+
def custom_forward(*inputs):
|
595 |
+
return module(*inputs)
|
596 |
+
|
597 |
+
return custom_forward
|
598 |
+
|
599 |
+
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(
|
600 |
+
resnet), hidden_states.requires_grad_(), temb, use_reentrant=False)
|
601 |
+
hidden_states = hidden_states + \
|
602 |
+
img_feature if (
|
603 |
+
img_feature is not None and idx == 2) else hidden_states
|
604 |
+
if motion_module is not None:
|
605 |
+
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(
|
606 |
+
motion_module), hidden_states.requires_grad_(), temb, encoder_hidden_states, use_reentrant=False)
|
607 |
+
else:
|
608 |
+
hidden_states = resnet(hidden_states, temb)
|
609 |
+
hidden_states = hidden_states + \
|
610 |
+
img_feature if (
|
611 |
+
img_feature is not None and idx == 2) else hidden_states
|
612 |
+
# add motion module
|
613 |
+
hidden_states = motion_module(
|
614 |
+
hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states
|
615 |
+
|
616 |
+
output_states += (hidden_states,)
|
617 |
+
idx += 1
|
618 |
+
|
619 |
+
if self.downsamplers is not None:
|
620 |
+
for downsampler in self.downsamplers:
|
621 |
+
hidden_states = downsampler(hidden_states)
|
622 |
+
|
623 |
+
output_states += (hidden_states,)
|
624 |
+
|
625 |
+
return hidden_states, output_states
|
626 |
+
|
627 |
+
|
628 |
+
class CrossAttnUpBlock3D(nn.Module):
|
629 |
+
def __init__(
|
630 |
+
self,
|
631 |
+
in_channels: int,
|
632 |
+
out_channels: int,
|
633 |
+
prev_output_channel: int,
|
634 |
+
temb_channels: int,
|
635 |
+
dropout: float = 0.0,
|
636 |
+
num_layers: int = 1,
|
637 |
+
resnet_eps: float = 1e-6,
|
638 |
+
resnet_time_scale_shift: str = "default",
|
639 |
+
resnet_act_fn: str = "swish",
|
640 |
+
resnet_groups: int = 32,
|
641 |
+
resnet_pre_norm: bool = True,
|
642 |
+
attn_num_head_channels=1,
|
643 |
+
cross_attention_dim=1280,
|
644 |
+
output_scale_factor=1.0,
|
645 |
+
add_upsample=True,
|
646 |
+
dual_cross_attention=False,
|
647 |
+
use_linear_projection=False,
|
648 |
+
only_cross_attention=False,
|
649 |
+
upcast_attention=False,
|
650 |
+
|
651 |
+
unet_use_cross_frame_attention=None,
|
652 |
+
unet_use_temporal_attention=None,
|
653 |
+
use_inflated_groupnorm=None,
|
654 |
+
|
655 |
+
use_motion_module=None,
|
656 |
+
use_motion_resnet=None,
|
657 |
+
|
658 |
+
motion_module_type=None,
|
659 |
+
motion_module_kwargs=None,
|
660 |
+
):
|
661 |
+
super().__init__()
|
662 |
+
resnets = []
|
663 |
+
attentions = []
|
664 |
+
motion_modules = []
|
665 |
+
motion_resnets = []
|
666 |
+
self.has_cross_attention = True
|
667 |
+
self.attn_num_head_channels = attn_num_head_channels
|
668 |
+
|
669 |
+
for i in range(num_layers):
|
670 |
+
res_skip_channels = in_channels if (
|
671 |
+
i == num_layers - 1) else out_channels
|
672 |
+
resnet_in_channels = prev_output_channel if i == 0 else out_channels
|
673 |
+
|
674 |
+
resnets.append(
|
675 |
+
ResnetBlock3D(
|
676 |
+
in_channels=resnet_in_channels + res_skip_channels,
|
677 |
+
out_channels=out_channels,
|
678 |
+
temb_channels=temb_channels,
|
679 |
+
eps=resnet_eps,
|
680 |
+
groups=resnet_groups,
|
681 |
+
dropout=dropout,
|
682 |
+
time_embedding_norm=resnet_time_scale_shift,
|
683 |
+
non_linearity=resnet_act_fn,
|
684 |
+
output_scale_factor=output_scale_factor,
|
685 |
+
pre_norm=resnet_pre_norm,
|
686 |
+
|
687 |
+
use_inflated_groupnorm=use_inflated_groupnorm,
|
688 |
+
)
|
689 |
+
)
|
690 |
+
motion_resnets.append(
|
691 |
+
ResnetBlock3D(
|
692 |
+
in_channels=out_channels,
|
693 |
+
out_channels=out_channels,
|
694 |
+
temb_channels=temb_channels,
|
695 |
+
eps=resnet_eps,
|
696 |
+
groups=resnet_groups,
|
697 |
+
dropout=dropout,
|
698 |
+
time_embedding_norm=resnet_time_scale_shift,
|
699 |
+
non_linearity=resnet_act_fn,
|
700 |
+
output_scale_factor=output_scale_factor,
|
701 |
+
pre_norm=resnet_pre_norm,
|
702 |
+
use_inflated_groupnorm=use_inflated_groupnorm,
|
703 |
+
use_temporal_conv=True,
|
704 |
+
use_temporal_mixer=True
|
705 |
+
) if use_motion_resnet else None
|
706 |
+
)
|
707 |
+
|
708 |
+
if dual_cross_attention:
|
709 |
+
raise NotImplementedError
|
710 |
+
attentions.append(
|
711 |
+
Transformer3DModel(
|
712 |
+
attn_num_head_channels,
|
713 |
+
out_channels // attn_num_head_channels,
|
714 |
+
in_channels=out_channels,
|
715 |
+
num_layers=1,
|
716 |
+
cross_attention_dim=cross_attention_dim,
|
717 |
+
norm_num_groups=resnet_groups,
|
718 |
+
use_linear_projection=use_linear_projection,
|
719 |
+
only_cross_attention=only_cross_attention,
|
720 |
+
upcast_attention=upcast_attention,
|
721 |
+
|
722 |
+
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
|
723 |
+
unet_use_temporal_attention=unet_use_temporal_attention,
|
724 |
+
)
|
725 |
+
)
|
726 |
+
motion_modules.append(
|
727 |
+
get_motion_module(
|
728 |
+
in_channels=out_channels,
|
729 |
+
motion_module_type=motion_module_type,
|
730 |
+
motion_module_kwargs=motion_module_kwargs,
|
731 |
+
) if use_motion_module else None
|
732 |
+
)
|
733 |
+
|
734 |
+
self.attentions = nn.ModuleList(attentions)
|
735 |
+
self.resnets = nn.ModuleList(resnets)
|
736 |
+
self.motion_modules = nn.ModuleList(motion_modules)
|
737 |
+
self.motion_resnets = nn.ModuleList(motion_resnets)
|
738 |
+
|
739 |
+
if add_upsample:
|
740 |
+
self.upsamplers = nn.ModuleList(
|
741 |
+
[Upsample3D(out_channels, use_conv=True, out_channels=out_channels)])
|
742 |
+
else:
|
743 |
+
self.upsamplers = None
|
744 |
+
|
745 |
+
self.gradient_checkpointing = False
|
746 |
+
|
747 |
+
def forward(
|
748 |
+
self,
|
749 |
+
hidden_states,
|
750 |
+
res_hidden_states_tuple,
|
751 |
+
temb=None,
|
752 |
+
encoder_hidden_states=None,
|
753 |
+
upsample_size=None,
|
754 |
+
attention_mask=None,
|
755 |
+
):
|
756 |
+
for resnet, attn, motion_module, motion_resnet in zip(self.resnets, self.attentions, self.motion_modules, self.motion_resnets):
|
757 |
+
# pop res hidden states
|
758 |
+
res_hidden_states = res_hidden_states_tuple[-1]
|
759 |
+
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
760 |
+
hidden_states = torch.cat(
|
761 |
+
[hidden_states, res_hidden_states], dim=1)
|
762 |
+
|
763 |
+
if self.training and self.gradient_checkpointing:
|
764 |
+
|
765 |
+
def create_custom_forward(module, return_dict=None):
|
766 |
+
def custom_forward(*inputs):
|
767 |
+
if return_dict is not None:
|
768 |
+
return module(*inputs, return_dict=return_dict)
|
769 |
+
else:
|
770 |
+
return module(*inputs)
|
771 |
+
|
772 |
+
return custom_forward
|
773 |
+
|
774 |
+
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(
|
775 |
+
resnet), hidden_states.requires_grad_(), temb, use_reentrant=False)
|
776 |
+
if motion_resnet is not None:
|
777 |
+
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(
|
778 |
+
motion_resnet), hidden_states.requires_grad_(), temb, use_reentrant=False)
|
779 |
+
|
780 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
781 |
+
create_custom_forward(attn, return_dict=False),
|
782 |
+
hidden_states.requires_grad_(),
|
783 |
+
encoder_hidden_states,
|
784 |
+
use_reentrant=False,
|
785 |
+
)[0]
|
786 |
+
if motion_module is not None:
|
787 |
+
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(
|
788 |
+
motion_module), hidden_states.requires_grad_(), temb, encoder_hidden_states, use_reentrant=False)
|
789 |
+
|
790 |
+
else:
|
791 |
+
hidden_states = resnet(hidden_states, temb)
|
792 |
+
hidden_states = motion_resnet(
|
793 |
+
hidden_states, temb) if motion_resnet is not None else hidden_states
|
794 |
+
hidden_states = attn(
|
795 |
+
hidden_states, encoder_hidden_states=encoder_hidden_states).sample
|
796 |
+
|
797 |
+
# add motion module
|
798 |
+
hidden_states = motion_module(
|
799 |
+
hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states
|
800 |
+
|
801 |
+
if self.upsamplers is not None:
|
802 |
+
for upsampler in self.upsamplers:
|
803 |
+
hidden_states = upsampler(hidden_states, upsample_size)
|
804 |
+
|
805 |
+
return hidden_states
|
806 |
+
|
807 |
+
|
808 |
+
class UpBlock3D(nn.Module):
|
809 |
+
def __init__(
|
810 |
+
self,
|
811 |
+
in_channels: int,
|
812 |
+
prev_output_channel: int,
|
813 |
+
out_channels: int,
|
814 |
+
temb_channels: int,
|
815 |
+
dropout: float = 0.0,
|
816 |
+
num_layers: int = 1,
|
817 |
+
resnet_eps: float = 1e-6,
|
818 |
+
resnet_time_scale_shift: str = "default",
|
819 |
+
resnet_act_fn: str = "swish",
|
820 |
+
resnet_groups: int = 32,
|
821 |
+
resnet_pre_norm: bool = True,
|
822 |
+
output_scale_factor=1.0,
|
823 |
+
add_upsample=True,
|
824 |
+
|
825 |
+
use_inflated_groupnorm=None,
|
826 |
+
|
827 |
+
use_motion_module=None,
|
828 |
+
motion_module_type=None,
|
829 |
+
motion_module_kwargs=None,
|
830 |
+
):
|
831 |
+
super().__init__()
|
832 |
+
resnets = []
|
833 |
+
motion_modules = []
|
834 |
+
|
835 |
+
for i in range(num_layers):
|
836 |
+
res_skip_channels = in_channels if (
|
837 |
+
i == num_layers - 1) else out_channels
|
838 |
+
resnet_in_channels = prev_output_channel if i == 0 else out_channels
|
839 |
+
|
840 |
+
resnets.append(
|
841 |
+
ResnetBlock3D(
|
842 |
+
in_channels=resnet_in_channels + res_skip_channels,
|
843 |
+
out_channels=out_channels,
|
844 |
+
temb_channels=temb_channels,
|
845 |
+
eps=resnet_eps,
|
846 |
+
groups=resnet_groups,
|
847 |
+
dropout=dropout,
|
848 |
+
time_embedding_norm=resnet_time_scale_shift,
|
849 |
+
non_linearity=resnet_act_fn,
|
850 |
+
output_scale_factor=output_scale_factor,
|
851 |
+
pre_norm=resnet_pre_norm,
|
852 |
+
|
853 |
+
use_inflated_groupnorm=use_inflated_groupnorm,
|
854 |
+
)
|
855 |
+
)
|
856 |
+
motion_modules.append(
|
857 |
+
get_motion_module(
|
858 |
+
in_channels=out_channels,
|
859 |
+
motion_module_type=motion_module_type,
|
860 |
+
motion_module_kwargs=motion_module_kwargs,
|
861 |
+
) if use_motion_module else None
|
862 |
+
)
|
863 |
+
|
864 |
+
self.resnets = nn.ModuleList(resnets)
|
865 |
+
self.motion_modules = nn.ModuleList(motion_modules)
|
866 |
+
|
867 |
+
if add_upsample:
|
868 |
+
self.upsamplers = nn.ModuleList(
|
869 |
+
[Upsample3D(out_channels, use_conv=True, out_channels=out_channels)])
|
870 |
+
else:
|
871 |
+
self.upsamplers = None
|
872 |
+
|
873 |
+
self.gradient_checkpointing = False
|
874 |
+
|
875 |
+
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, encoder_hidden_states=None,):
|
876 |
+
for resnet, motion_module in zip(self.resnets, self.motion_modules):
|
877 |
+
# pop res hidden states
|
878 |
+
res_hidden_states = res_hidden_states_tuple[-1]
|
879 |
+
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
880 |
+
hidden_states = torch.cat(
|
881 |
+
[hidden_states, res_hidden_states], dim=1)
|
882 |
+
|
883 |
+
if self.training and self.gradient_checkpointing:
|
884 |
+
def create_custom_forward(module):
|
885 |
+
def custom_forward(*inputs):
|
886 |
+
return module(*inputs)
|
887 |
+
|
888 |
+
return custom_forward
|
889 |
+
|
890 |
+
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(
|
891 |
+
resnet), hidden_states.requires_grad_(), temb, use_reentrant=False)
|
892 |
+
if motion_module is not None:
|
893 |
+
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(
|
894 |
+
motion_module), hidden_states.requires_grad_(), temb, encoder_hidden_states, use_reentrant=False)
|
895 |
+
else:
|
896 |
+
hidden_states = resnet(hidden_states, temb)
|
897 |
+
hidden_states = motion_module(
|
898 |
+
hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states
|
899 |
+
|
900 |
+
if self.upsamplers is not None:
|
901 |
+
for upsampler in self.upsamplers:
|
902 |
+
hidden_states = upsampler(hidden_states, upsample_size)
|
903 |
+
|
904 |
+
return hidden_states
|
animatelcm/pipelines/pipeline_animation.py
ADDED
@@ -0,0 +1,456 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Adapted from https://github.com/showlab/Tune-A-Video/blob/main/tuneavideo/pipelines/pipeline_tuneavideo.py
|
2 |
+
|
3 |
+
import inspect
|
4 |
+
from typing import Callable, List, Optional, Union
|
5 |
+
from dataclasses import dataclass
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
+
import torch
|
9 |
+
from tqdm import tqdm
|
10 |
+
|
11 |
+
from diffusers.utils import is_accelerate_available
|
12 |
+
from packaging import version
|
13 |
+
from transformers import CLIPTextModel, CLIPTokenizer
|
14 |
+
|
15 |
+
from diffusers.configuration_utils import FrozenDict
|
16 |
+
from diffusers.models import AutoencoderKL
|
17 |
+
from diffusers.pipeline_utils import DiffusionPipeline
|
18 |
+
from diffusers.schedulers import (
|
19 |
+
DDIMScheduler,
|
20 |
+
DPMSolverMultistepScheduler,
|
21 |
+
EulerAncestralDiscreteScheduler,
|
22 |
+
EulerDiscreteScheduler,
|
23 |
+
LMSDiscreteScheduler,
|
24 |
+
PNDMScheduler,
|
25 |
+
)
|
26 |
+
from diffusers.utils import deprecate, logging, BaseOutput
|
27 |
+
|
28 |
+
from einops import rearrange
|
29 |
+
|
30 |
+
from ..models.unet import UNet3DConditionModel
|
31 |
+
|
32 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
33 |
+
|
34 |
+
|
35 |
+
@dataclass
|
36 |
+
class AnimationPipelineOutput(BaseOutput):
|
37 |
+
videos: Union[torch.Tensor, np.ndarray]
|
38 |
+
|
39 |
+
|
40 |
+
class AnimationPipeline(DiffusionPipeline):
|
41 |
+
_optional_components = []
|
42 |
+
|
43 |
+
def __init__(
|
44 |
+
self,
|
45 |
+
vae: AutoencoderKL,
|
46 |
+
text_encoder: CLIPTextModel,
|
47 |
+
tokenizer: CLIPTokenizer,
|
48 |
+
unet: UNet3DConditionModel,
|
49 |
+
scheduler: Union[
|
50 |
+
DDIMScheduler,
|
51 |
+
PNDMScheduler,
|
52 |
+
LMSDiscreteScheduler,
|
53 |
+
EulerDiscreteScheduler,
|
54 |
+
EulerAncestralDiscreteScheduler,
|
55 |
+
DPMSolverMultistepScheduler,
|
56 |
+
],
|
57 |
+
):
|
58 |
+
super().__init__()
|
59 |
+
|
60 |
+
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
|
61 |
+
deprecation_message = (
|
62 |
+
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
|
63 |
+
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
|
64 |
+
"to update the config accordingly as leaving `steps_offset` might led to incorrect results"
|
65 |
+
" in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
|
66 |
+
" it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
|
67 |
+
" file"
|
68 |
+
)
|
69 |
+
deprecate("steps_offset!=1", "1.0.0",
|
70 |
+
deprecation_message, standard_warn=False)
|
71 |
+
new_config = dict(scheduler.config)
|
72 |
+
new_config["steps_offset"] = 1
|
73 |
+
scheduler._internal_dict = FrozenDict(new_config)
|
74 |
+
|
75 |
+
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
|
76 |
+
deprecation_message = (
|
77 |
+
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
|
78 |
+
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
|
79 |
+
" config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
|
80 |
+
" future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
|
81 |
+
" nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
|
82 |
+
)
|
83 |
+
deprecate("clip_sample not set", "1.0.0",
|
84 |
+
deprecation_message, standard_warn=False)
|
85 |
+
new_config = dict(scheduler.config)
|
86 |
+
new_config["clip_sample"] = False
|
87 |
+
scheduler._internal_dict = FrozenDict(new_config)
|
88 |
+
|
89 |
+
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
|
90 |
+
version.parse(unet.config._diffusers_version).base_version
|
91 |
+
) < version.parse("0.9.0.dev0")
|
92 |
+
is_unet_sample_size_less_64 = hasattr(
|
93 |
+
unet.config, "sample_size") and unet.config.sample_size < 64
|
94 |
+
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
|
95 |
+
deprecation_message = (
|
96 |
+
"The configuration file of the unet has set the default `sample_size` to smaller than"
|
97 |
+
" 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
|
98 |
+
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
|
99 |
+
" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
|
100 |
+
" \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
|
101 |
+
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
|
102 |
+
" in the config might lead to incorrect results in future versions. If you have downloaded this"
|
103 |
+
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
|
104 |
+
" the `unet/config.json` file"
|
105 |
+
)
|
106 |
+
deprecate("sample_size<64", "1.0.0",
|
107 |
+
deprecation_message, standard_warn=False)
|
108 |
+
new_config = dict(unet.config)
|
109 |
+
new_config["sample_size"] = 64
|
110 |
+
unet._internal_dict = FrozenDict(new_config)
|
111 |
+
|
112 |
+
self.register_modules(
|
113 |
+
vae=vae,
|
114 |
+
text_encoder=text_encoder,
|
115 |
+
tokenizer=tokenizer,
|
116 |
+
unet=unet,
|
117 |
+
scheduler=scheduler,
|
118 |
+
)
|
119 |
+
self.vae_scale_factor = 2 ** (
|
120 |
+
len(self.vae.config.block_out_channels) - 1)
|
121 |
+
|
122 |
+
def enable_vae_slicing(self):
|
123 |
+
self.vae.enable_slicing()
|
124 |
+
|
125 |
+
def disable_vae_slicing(self):
|
126 |
+
self.vae.disable_slicing()
|
127 |
+
|
128 |
+
def enable_sequential_cpu_offload(self, gpu_id=0):
|
129 |
+
if is_accelerate_available():
|
130 |
+
from accelerate import cpu_offload
|
131 |
+
else:
|
132 |
+
raise ImportError(
|
133 |
+
"Please install accelerate via `pip install accelerate`")
|
134 |
+
|
135 |
+
device = torch.device(f"cuda:{gpu_id}")
|
136 |
+
|
137 |
+
for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
|
138 |
+
if cpu_offloaded_model is not None:
|
139 |
+
cpu_offload(cpu_offloaded_model, device)
|
140 |
+
|
141 |
+
@property
|
142 |
+
def _execution_device(self):
|
143 |
+
if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
|
144 |
+
return self.device
|
145 |
+
for module in self.unet.modules():
|
146 |
+
if (
|
147 |
+
hasattr(module, "_hf_hook")
|
148 |
+
and hasattr(module._hf_hook, "execution_device")
|
149 |
+
and module._hf_hook.execution_device is not None
|
150 |
+
):
|
151 |
+
return torch.device(module._hf_hook.execution_device)
|
152 |
+
return self.device
|
153 |
+
|
154 |
+
def _encode_prompt(self, prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt):
|
155 |
+
batch_size = len(prompt) if isinstance(prompt, list) else 1
|
156 |
+
|
157 |
+
text_inputs = self.tokenizer(
|
158 |
+
prompt,
|
159 |
+
padding="max_length",
|
160 |
+
max_length=self.tokenizer.model_max_length,
|
161 |
+
truncation=True,
|
162 |
+
return_tensors="pt",
|
163 |
+
)
|
164 |
+
text_input_ids = text_inputs.input_ids
|
165 |
+
untruncated_ids = self.tokenizer(
|
166 |
+
prompt, padding="longest", return_tensors="pt").input_ids
|
167 |
+
|
168 |
+
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
169 |
+
removed_text = self.tokenizer.batch_decode(
|
170 |
+
untruncated_ids[:, self.tokenizer.model_max_length - 1: -1])
|
171 |
+
logger.warning(
|
172 |
+
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
173 |
+
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
|
174 |
+
)
|
175 |
+
|
176 |
+
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
|
177 |
+
attention_mask = text_inputs.attention_mask.to(device)
|
178 |
+
else:
|
179 |
+
attention_mask = None
|
180 |
+
|
181 |
+
text_embeddings = self.text_encoder(
|
182 |
+
text_input_ids.to(device),
|
183 |
+
attention_mask=attention_mask,
|
184 |
+
)
|
185 |
+
text_embeddings = text_embeddings[0]
|
186 |
+
|
187 |
+
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
188 |
+
bs_embed, seq_len, _ = text_embeddings.shape
|
189 |
+
text_embeddings = text_embeddings.repeat(1, num_videos_per_prompt, 1)
|
190 |
+
text_embeddings = text_embeddings.view(
|
191 |
+
bs_embed * num_videos_per_prompt, seq_len, -1)
|
192 |
+
|
193 |
+
# get unconditional embeddings for classifier free guidance
|
194 |
+
if do_classifier_free_guidance:
|
195 |
+
uncond_tokens: List[str]
|
196 |
+
if negative_prompt is None:
|
197 |
+
uncond_tokens = [""] * batch_size
|
198 |
+
elif type(prompt) is not type(negative_prompt):
|
199 |
+
raise TypeError(
|
200 |
+
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
201 |
+
f" {type(prompt)}."
|
202 |
+
)
|
203 |
+
elif isinstance(negative_prompt, str):
|
204 |
+
uncond_tokens = [negative_prompt]
|
205 |
+
elif batch_size != len(negative_prompt):
|
206 |
+
raise ValueError(
|
207 |
+
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
208 |
+
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
209 |
+
" the batch size of `prompt`."
|
210 |
+
)
|
211 |
+
else:
|
212 |
+
uncond_tokens = negative_prompt
|
213 |
+
|
214 |
+
max_length = text_input_ids.shape[-1]
|
215 |
+
uncond_input = self.tokenizer(
|
216 |
+
uncond_tokens,
|
217 |
+
padding="max_length",
|
218 |
+
max_length=max_length,
|
219 |
+
truncation=True,
|
220 |
+
return_tensors="pt",
|
221 |
+
)
|
222 |
+
|
223 |
+
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
|
224 |
+
attention_mask = uncond_input.attention_mask.to(device)
|
225 |
+
else:
|
226 |
+
attention_mask = None
|
227 |
+
|
228 |
+
uncond_embeddings = self.text_encoder(
|
229 |
+
uncond_input.input_ids.to(device),
|
230 |
+
attention_mask=attention_mask,
|
231 |
+
)
|
232 |
+
uncond_embeddings = uncond_embeddings[0]
|
233 |
+
|
234 |
+
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
235 |
+
seq_len = uncond_embeddings.shape[1]
|
236 |
+
uncond_embeddings = uncond_embeddings.repeat(
|
237 |
+
1, num_videos_per_prompt, 1)
|
238 |
+
uncond_embeddings = uncond_embeddings.view(
|
239 |
+
batch_size * num_videos_per_prompt, seq_len, -1)
|
240 |
+
|
241 |
+
# For classifier free guidance, we need to do two forward passes.
|
242 |
+
# Here we concatenate the unconditional and text embeddings into a single batch
|
243 |
+
# to avoid doing two forward passes
|
244 |
+
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
|
245 |
+
|
246 |
+
return text_embeddings
|
247 |
+
|
248 |
+
def decode_latents(self, latents):
|
249 |
+
video_length = latents.shape[2]
|
250 |
+
latents = 1 / 0.18215 * latents
|
251 |
+
latents = rearrange(latents, "b c f h w -> (b f) c h w")
|
252 |
+
# video = self.vae.decode(latents).sample
|
253 |
+
video = []
|
254 |
+
for frame_idx in tqdm(range(latents.shape[0])):
|
255 |
+
video.append(self.vae.decode(
|
256 |
+
latents[frame_idx:frame_idx+1]).sample)
|
257 |
+
video = torch.cat(video)
|
258 |
+
video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
|
259 |
+
video = (video / 2 + 0.5).clamp(0, 1)
|
260 |
+
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
|
261 |
+
video = video.cpu().float().numpy()
|
262 |
+
return video
|
263 |
+
|
264 |
+
def prepare_extra_step_kwargs(self, generator, eta):
|
265 |
+
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
266 |
+
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
267 |
+
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
268 |
+
# and should be between [0, 1]
|
269 |
+
|
270 |
+
accepts_eta = "eta" in set(inspect.signature(
|
271 |
+
self.scheduler.step).parameters.keys())
|
272 |
+
extra_step_kwargs = {}
|
273 |
+
if accepts_eta:
|
274 |
+
extra_step_kwargs["eta"] = eta
|
275 |
+
|
276 |
+
# check if the scheduler accepts generator
|
277 |
+
accepts_generator = "generator" in set(
|
278 |
+
inspect.signature(self.scheduler.step).parameters.keys())
|
279 |
+
if accepts_generator:
|
280 |
+
extra_step_kwargs["generator"] = generator
|
281 |
+
return extra_step_kwargs
|
282 |
+
|
283 |
+
def check_inputs(self, prompt, height, width, callback_steps):
|
284 |
+
if not isinstance(prompt, str) and not isinstance(prompt, list):
|
285 |
+
raise ValueError(
|
286 |
+
f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
287 |
+
|
288 |
+
if height % 8 != 0 or width % 8 != 0:
|
289 |
+
raise ValueError(
|
290 |
+
f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
291 |
+
|
292 |
+
if (callback_steps is None) or (
|
293 |
+
callback_steps is not None and (not isinstance(
|
294 |
+
callback_steps, int) or callback_steps <= 0)
|
295 |
+
):
|
296 |
+
raise ValueError(
|
297 |
+
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
|
298 |
+
f" {type(callback_steps)}."
|
299 |
+
)
|
300 |
+
|
301 |
+
def prepare_latents(self, batch_size, num_channels_latents, video_length, height, width, dtype, device, generator, latents=None):
|
302 |
+
shape = (batch_size, num_channels_latents, video_length, height //
|
303 |
+
self.vae_scale_factor, width // self.vae_scale_factor)
|
304 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
305 |
+
raise ValueError(
|
306 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
307 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
308 |
+
)
|
309 |
+
if latents is None:
|
310 |
+
rand_device = "cpu" if device.type == "mps" else device
|
311 |
+
|
312 |
+
if isinstance(generator, list):
|
313 |
+
shape = shape
|
314 |
+
# shape = (1,) + shape[1:]
|
315 |
+
latents = [
|
316 |
+
torch.randn(
|
317 |
+
shape, generator=generator[i], device=rand_device, dtype=dtype)
|
318 |
+
for i in range(batch_size)
|
319 |
+
]
|
320 |
+
latents = torch.cat(latents, dim=0).to(device)
|
321 |
+
else:
|
322 |
+
latents = torch.randn(
|
323 |
+
shape, generator=generator, device=rand_device, dtype=dtype).to(device)
|
324 |
+
else:
|
325 |
+
if latents.shape != shape:
|
326 |
+
raise ValueError(
|
327 |
+
f"Unexpected latents shape, got {latents.shape}, expected {shape}")
|
328 |
+
latents = latents.to(device)
|
329 |
+
|
330 |
+
# scale the initial noise by the standard deviation required by the scheduler
|
331 |
+
latents = latents * self.scheduler.init_noise_sigma
|
332 |
+
return latents
|
333 |
+
|
334 |
+
@torch.no_grad()
|
335 |
+
def __call__(
|
336 |
+
self,
|
337 |
+
prompt: Union[str, List[str]],
|
338 |
+
video_length: Optional[int],
|
339 |
+
height: Optional[int] = None,
|
340 |
+
width: Optional[int] = None,
|
341 |
+
num_inference_steps: int = 50,
|
342 |
+
guidance_scale: float = 7.5,
|
343 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
344 |
+
num_videos_per_prompt: Optional[int] = 1,
|
345 |
+
eta: float = 0.0,
|
346 |
+
generator: Optional[Union[torch.Generator,
|
347 |
+
List[torch.Generator]]] = None,
|
348 |
+
latents: Optional[torch.FloatTensor] = None,
|
349 |
+
output_type: Optional[str] = "tensor",
|
350 |
+
return_dict: bool = True,
|
351 |
+
callback: Optional[Callable[[
|
352 |
+
int, int, torch.FloatTensor], None]] = None,
|
353 |
+
callback_steps: Optional[int] = 1,
|
354 |
+
do_classifier_free_guidance: bool = True,
|
355 |
+
image_path: str = None, # not ready
|
356 |
+
control_path: str = None, # not ready
|
357 |
+
sparse_control: str = False, # not ready
|
358 |
+
**kwargs,
|
359 |
+
):
|
360 |
+
|
361 |
+
# Default height and width to unet
|
362 |
+
height = height or self.unet.config.sample_size * self.vae_scale_factor
|
363 |
+
width = width or self.unet.config.sample_size * self.vae_scale_factor
|
364 |
+
|
365 |
+
# Check inputs. Raise error if not correct
|
366 |
+
self.check_inputs(prompt, height, width, callback_steps)
|
367 |
+
|
368 |
+
# Define call parameters
|
369 |
+
# batch_size = 1 if isinstance(prompt, str) else len(prompt)
|
370 |
+
batch_size = 1
|
371 |
+
if latents is not None:
|
372 |
+
batch_size = latents.shape[0]
|
373 |
+
if isinstance(prompt, list):
|
374 |
+
batch_size = len(prompt)
|
375 |
+
|
376 |
+
device = self._execution_device
|
377 |
+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
378 |
+
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
379 |
+
# corresponds to doing no classifier free guidance.
|
380 |
+
do_classifier_free_guidance = (
|
381 |
+
guidance_scale > 1.0) & do_classifier_free_guidance
|
382 |
+
|
383 |
+
prompt = prompt if isinstance(prompt, list) else [prompt] * batch_size
|
384 |
+
if negative_prompt is not None:
|
385 |
+
negative_prompt = negative_prompt if isinstance(negative_prompt, list) else [
|
386 |
+
negative_prompt] * batch_size
|
387 |
+
text_embeddings = self._encode_prompt(
|
388 |
+
prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt
|
389 |
+
)
|
390 |
+
|
391 |
+
# Prepare timesteps
|
392 |
+
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
393 |
+
timesteps = self.scheduler.timesteps
|
394 |
+
|
395 |
+
# Prepare latent variables
|
396 |
+
num_channels_latents = self.unet.in_channels
|
397 |
+
latents = self.prepare_latents(
|
398 |
+
batch_size * num_videos_per_prompt,
|
399 |
+
num_channels_latents,
|
400 |
+
video_length,
|
401 |
+
height,
|
402 |
+
width,
|
403 |
+
text_embeddings.dtype,
|
404 |
+
device,
|
405 |
+
generator,
|
406 |
+
latents,
|
407 |
+
)
|
408 |
+
latents_dtype = latents.dtype
|
409 |
+
|
410 |
+
w_embedding = None # not ready
|
411 |
+
|
412 |
+
# Prepare extra step kwargs.
|
413 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
414 |
+
|
415 |
+
# Denoising loop
|
416 |
+
num_warmup_steps = len(timesteps) - \
|
417 |
+
num_inference_steps * self.scheduler.order
|
418 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
419 |
+
for i, t in enumerate(timesteps):
|
420 |
+
# expand the latents if we are doing classifier free guidance
|
421 |
+
latent_model_input = torch.cat(
|
422 |
+
[latents] * 2) if do_classifier_free_guidance else latents
|
423 |
+
latent_model_input = self.scheduler.scale_model_input(
|
424 |
+
latent_model_input, t)
|
425 |
+
|
426 |
+
# predict the noise residual
|
427 |
+
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings,
|
428 |
+
time_cond=w_embedding).sample.to(dtype=latents_dtype)
|
429 |
+
|
430 |
+
# perform guidance
|
431 |
+
if do_classifier_free_guidance:
|
432 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
433 |
+
noise_pred = noise_pred_uncond + guidance_scale * \
|
434 |
+
(noise_pred_text - noise_pred_uncond)
|
435 |
+
|
436 |
+
# compute the previous noisy sample x_t -> x_t-1
|
437 |
+
latents = self.scheduler.step(
|
438 |
+
noise_pred, t, latents, **extra_step_kwargs).prev_sample
|
439 |
+
|
440 |
+
# call the callback, if provided
|
441 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
442 |
+
progress_bar.update()
|
443 |
+
if callback is not None and i % callback_steps == 0:
|
444 |
+
callback(i, t, latents)
|
445 |
+
|
446 |
+
# Post-processing
|
447 |
+
video = self.decode_latents(latents)
|
448 |
+
|
449 |
+
# Convert to tensor
|
450 |
+
if output_type == "tensor":
|
451 |
+
video = torch.from_numpy(video)
|
452 |
+
|
453 |
+
if not return_dict:
|
454 |
+
return video
|
455 |
+
|
456 |
+
return AnimationPipelineOutput(videos=video)
|
animatelcm/scheduler/lcm_scheduler.py
ADDED
@@ -0,0 +1,722 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 Stanford University Team and The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
|
16 |
+
# and https://github.com/hojonathanho/diffusion
|
17 |
+
|
18 |
+
import math
|
19 |
+
from dataclasses import dataclass
|
20 |
+
from typing import List, Optional, Tuple, Union
|
21 |
+
|
22 |
+
import numpy as np
|
23 |
+
import torch
|
24 |
+
|
25 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
26 |
+
from diffusers.utils import BaseOutput, logging
|
27 |
+
from diffusers.schedulers.scheduling_utils import SchedulerMixin
|
28 |
+
|
29 |
+
|
30 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
31 |
+
|
32 |
+
|
33 |
+
|
34 |
+
@dataclass
|
35 |
+
class LCMSchedulerOutput(BaseOutput):
|
36 |
+
"""
|
37 |
+
Output class for the scheduler's `step` function output.
|
38 |
+
|
39 |
+
Args:
|
40 |
+
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
|
41 |
+
Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
|
42 |
+
denoising loop.
|
43 |
+
pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
|
44 |
+
The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
|
45 |
+
`pred_original_sample` can be used to preview progress or for guidance.
|
46 |
+
"""
|
47 |
+
|
48 |
+
prev_sample: torch.FloatTensor
|
49 |
+
denoised: Optional[torch.FloatTensor] = None
|
50 |
+
|
51 |
+
|
52 |
+
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
|
53 |
+
def betas_for_alpha_bar(
|
54 |
+
num_diffusion_timesteps,
|
55 |
+
max_beta=0.999,
|
56 |
+
alpha_transform_type="cosine",
|
57 |
+
):
|
58 |
+
"""
|
59 |
+
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
|
60 |
+
(1-beta) over time from t = [0,1].
|
61 |
+
|
62 |
+
Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
|
63 |
+
to that part of the diffusion process.
|
64 |
+
|
65 |
+
|
66 |
+
Args:
|
67 |
+
num_diffusion_timesteps (`int`): the number of betas to produce.
|
68 |
+
max_beta (`float`): the maximum beta to use; use values lower than 1 to
|
69 |
+
prevent singularities.
|
70 |
+
alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
|
71 |
+
Choose from `cosine` or `exp`
|
72 |
+
|
73 |
+
Returns:
|
74 |
+
betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
|
75 |
+
"""
|
76 |
+
if alpha_transform_type == "cosine":
|
77 |
+
|
78 |
+
def alpha_bar_fn(t):
|
79 |
+
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
|
80 |
+
|
81 |
+
elif alpha_transform_type == "exp":
|
82 |
+
|
83 |
+
def alpha_bar_fn(t):
|
84 |
+
return math.exp(t * -12.0)
|
85 |
+
|
86 |
+
else:
|
87 |
+
raise ValueError(
|
88 |
+
f"Unsupported alpha_tranform_type: {alpha_transform_type}")
|
89 |
+
|
90 |
+
betas = []
|
91 |
+
for i in range(num_diffusion_timesteps):
|
92 |
+
t1 = i / num_diffusion_timesteps
|
93 |
+
t2 = (i + 1) / num_diffusion_timesteps
|
94 |
+
betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
|
95 |
+
return torch.tensor(betas, dtype=torch.float32)
|
96 |
+
|
97 |
+
|
98 |
+
# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
|
99 |
+
def rescale_zero_terminal_snr(betas: torch.FloatTensor) -> torch.FloatTensor:
|
100 |
+
"""
|
101 |
+
Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)
|
102 |
+
|
103 |
+
|
104 |
+
Args:
|
105 |
+
betas (`torch.FloatTensor`):
|
106 |
+
the betas that the scheduler is being initialized with.
|
107 |
+
|
108 |
+
Returns:
|
109 |
+
`torch.FloatTensor`: rescaled betas with zero terminal SNR
|
110 |
+
"""
|
111 |
+
# Convert betas to alphas_bar_sqrt
|
112 |
+
alphas = 1.0 - betas
|
113 |
+
alphas_cumprod = torch.cumprod(alphas, dim=0)
|
114 |
+
alphas_bar_sqrt = alphas_cumprod.sqrt()
|
115 |
+
|
116 |
+
# Store old values.
|
117 |
+
alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
|
118 |
+
alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
|
119 |
+
|
120 |
+
# Shift so the last timestep is zero.
|
121 |
+
alphas_bar_sqrt -= alphas_bar_sqrt_T
|
122 |
+
|
123 |
+
# Scale so the first timestep is back to the old value.
|
124 |
+
alphas_bar_sqrt *= alphas_bar_sqrt_0 / \
|
125 |
+
(alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
|
126 |
+
|
127 |
+
# Convert alphas_bar_sqrt to betas
|
128 |
+
alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
|
129 |
+
alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod
|
130 |
+
alphas = torch.cat([alphas_bar[0:1], alphas])
|
131 |
+
betas = 1 - alphas
|
132 |
+
|
133 |
+
return betas
|
134 |
+
|
135 |
+
|
136 |
+
def randn_tensor(
|
137 |
+
shape: Union[Tuple, List],
|
138 |
+
generator: Optional[Union[List["torch.Generator"],
|
139 |
+
"torch.Generator"]] = None,
|
140 |
+
device: Optional["torch.device"] = None,
|
141 |
+
dtype: Optional["torch.dtype"] = None,
|
142 |
+
layout: Optional["torch.layout"] = None,
|
143 |
+
):
|
144 |
+
"""A helper function to create random tensors on the desired `device` with the desired `dtype`. When
|
145 |
+
passing a list of generators, you can seed each batch size individually. If CPU generators are passed, the tensor
|
146 |
+
is always created on the CPU.
|
147 |
+
"""
|
148 |
+
# device on which tensor is created defaults to device
|
149 |
+
rand_device = device
|
150 |
+
batch_size = shape[0]
|
151 |
+
|
152 |
+
layout = layout or torch.strided
|
153 |
+
device = device or torch.device("cpu")
|
154 |
+
|
155 |
+
if generator is not None:
|
156 |
+
gen_device_type = generator.device.type if not isinstance(
|
157 |
+
generator, list) else generator[0].device.type
|
158 |
+
if gen_device_type != device.type and gen_device_type == "cpu":
|
159 |
+
rand_device = "cpu"
|
160 |
+
if device != "mps":
|
161 |
+
logger.info(
|
162 |
+
f"The passed generator was created on 'cpu' even though a tensor on {device} was expected."
|
163 |
+
f" Tensors will be created on 'cpu' and then moved to {device}. Note that one can probably"
|
164 |
+
f" slighly speed up this function by passing a generator that was created on the {device} device."
|
165 |
+
)
|
166 |
+
elif gen_device_type != device.type and gen_device_type == "cuda":
|
167 |
+
raise ValueError(
|
168 |
+
f"Cannot generate a {device} tensor from a generator of type {gen_device_type}.")
|
169 |
+
|
170 |
+
# make sure generator list of length 1 is treated like a non-list
|
171 |
+
if isinstance(generator, list) and len(generator) == 1:
|
172 |
+
generator = generator[0]
|
173 |
+
|
174 |
+
if isinstance(generator, list):
|
175 |
+
shape = (1,) + shape[1:]
|
176 |
+
latents = [
|
177 |
+
torch.randn(
|
178 |
+
shape, generator=generator[i], device=rand_device, dtype=dtype, layout=layout)
|
179 |
+
for i in range(batch_size)
|
180 |
+
]
|
181 |
+
latents = torch.cat(latents, dim=0).to(device)
|
182 |
+
else:
|
183 |
+
latents = torch.randn(shape, generator=generator,
|
184 |
+
device=rand_device, dtype=dtype, layout=layout).to(device)
|
185 |
+
|
186 |
+
return latents
|
187 |
+
|
188 |
+
|
189 |
+
class LCMScheduler(SchedulerMixin, ConfigMixin):
|
190 |
+
"""
|
191 |
+
`LCMScheduler` extends the denoising procedure introduced in denoising diffusion probabilistic models (DDPMs) with
|
192 |
+
non-Markovian guidance.
|
193 |
+
|
194 |
+
This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. [`~ConfigMixin`] takes care of storing all config
|
195 |
+
attributes that are passed in the scheduler's `__init__` function, such as `num_train_timesteps`. They can be
|
196 |
+
accessed via `scheduler.config.num_train_timesteps`. [`SchedulerMixin`] provides general loading and saving
|
197 |
+
functionality via the [`SchedulerMixin.save_pretrained`] and [`~SchedulerMixin.from_pretrained`] functions.
|
198 |
+
|
199 |
+
Args:
|
200 |
+
num_train_timesteps (`int`, defaults to 1000):
|
201 |
+
The number of diffusion steps to train the model.
|
202 |
+
beta_start (`float`, defaults to 0.0001):
|
203 |
+
The starting `beta` value of inference.
|
204 |
+
beta_end (`float`, defaults to 0.02):
|
205 |
+
The final `beta` value.
|
206 |
+
beta_schedule (`str`, defaults to `"linear"`):
|
207 |
+
The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
|
208 |
+
`linear`, `scaled_linear`, or `squaredcos_cap_v2`.
|
209 |
+
trained_betas (`np.ndarray`, *optional*):
|
210 |
+
Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
|
211 |
+
original_inference_steps (`int`, *optional*, defaults to 50):
|
212 |
+
The default number of inference steps used to generate a linearly-spaced timestep schedule, from which we
|
213 |
+
will ultimately take `num_inference_steps` evenly spaced timesteps to form the final timestep schedule.
|
214 |
+
clip_sample (`bool`, defaults to `True`):
|
215 |
+
Clip the predicted sample for numerical stability.
|
216 |
+
clip_sample_range (`float`, defaults to 1.0):
|
217 |
+
The maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
|
218 |
+
set_alpha_to_one (`bool`, defaults to `True`):
|
219 |
+
Each diffusion step uses the alphas product value at that step and at the previous one. For the final step
|
220 |
+
there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
|
221 |
+
otherwise it uses the alpha value at step 0.
|
222 |
+
steps_offset (`int`, defaults to 0):
|
223 |
+
An offset added to the inference steps. You can use a combination of `offset=1` and
|
224 |
+
`set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable
|
225 |
+
Diffusion.
|
226 |
+
prediction_type (`str`, defaults to `epsilon`, *optional*):
|
227 |
+
Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
|
228 |
+
`sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
|
229 |
+
Video](https://imagen.research.google/video/paper.pdf) paper).
|
230 |
+
thresholding (`bool`, defaults to `False`):
|
231 |
+
Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
|
232 |
+
as Stable Diffusion.
|
233 |
+
dynamic_thresholding_ratio (`float`, defaults to 0.995):
|
234 |
+
The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
|
235 |
+
sample_max_value (`float`, defaults to 1.0):
|
236 |
+
The threshold value for dynamic thresholding. Valid only when `thresholding=True`.
|
237 |
+
timestep_spacing (`str`, defaults to `"leading"`):
|
238 |
+
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
|
239 |
+
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
|
240 |
+
timestep_scaling (`float`, defaults to 10.0):
|
241 |
+
The factor the timesteps will be multiplied by when calculating the consistency model boundary conditions
|
242 |
+
`c_skip` and `c_out`. Increasing this will decrease the approximation error (although the approximation
|
243 |
+
error at the default of `10.0` is already pretty small).
|
244 |
+
rescale_betas_zero_snr (`bool`, defaults to `False`):
|
245 |
+
Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
|
246 |
+
dark samples instead of limiting it to samples with medium brightness. Loosely related to
|
247 |
+
[`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
|
248 |
+
"""
|
249 |
+
|
250 |
+
order = 1
|
251 |
+
|
252 |
+
@register_to_config
|
253 |
+
def __init__(
|
254 |
+
self,
|
255 |
+
num_train_timesteps: int = 1000,
|
256 |
+
beta_start: float = 0.00085,
|
257 |
+
beta_end: float = 0.012,
|
258 |
+
beta_schedule: str = "scaled_linear",
|
259 |
+
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
|
260 |
+
original_inference_steps: int = 50,
|
261 |
+
clip_sample: bool = False,
|
262 |
+
clip_sample_range: float = 1.0,
|
263 |
+
set_alpha_to_one: bool = True,
|
264 |
+
steps_offset: int = 0,
|
265 |
+
prediction_type: str = "epsilon",
|
266 |
+
thresholding: bool = False,
|
267 |
+
dynamic_thresholding_ratio: float = 0.995,
|
268 |
+
sample_max_value: float = 1.0,
|
269 |
+
timestep_spacing: str = "leading",
|
270 |
+
timestep_scaling: float = 10.0,
|
271 |
+
rescale_betas_zero_snr: bool = False,
|
272 |
+
):
|
273 |
+
if trained_betas is not None:
|
274 |
+
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
|
275 |
+
elif beta_schedule == "linear":
|
276 |
+
self.betas = torch.linspace(
|
277 |
+
beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
|
278 |
+
elif beta_schedule == "scaled_linear":
|
279 |
+
# this schedule is very specific to the latent diffusion model.
|
280 |
+
self.betas = torch.linspace(
|
281 |
+
beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
|
282 |
+
elif beta_schedule == "squaredcos_cap_v2":
|
283 |
+
# Glide cosine schedule
|
284 |
+
self.betas = betas_for_alpha_bar(num_train_timesteps)
|
285 |
+
else:
|
286 |
+
raise NotImplementedError(
|
287 |
+
f"{beta_schedule} does is not implemented for {self.__class__}")
|
288 |
+
|
289 |
+
# Rescale for zero SNR
|
290 |
+
if rescale_betas_zero_snr:
|
291 |
+
self.betas = rescale_zero_terminal_snr(self.betas)
|
292 |
+
|
293 |
+
self.alphas = 1.0 - self.betas
|
294 |
+
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
|
295 |
+
|
296 |
+
# At every step in ddim, we are looking into the previous alphas_cumprod
|
297 |
+
# For the final step, there is no previous alphas_cumprod because we are already at 0
|
298 |
+
# `set_alpha_to_one` decides whether we set this parameter simply to one or
|
299 |
+
# whether we use the final alpha of the "non-previous" one.
|
300 |
+
self.final_alpha_cumprod = torch.tensor(
|
301 |
+
1.0) if set_alpha_to_one else self.alphas_cumprod[0]
|
302 |
+
|
303 |
+
# standard deviation of the initial noise distribution
|
304 |
+
self.init_noise_sigma = 1.0
|
305 |
+
|
306 |
+
# setable values
|
307 |
+
self.num_inference_steps = None
|
308 |
+
self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[
|
309 |
+
::-1].copy().astype(np.int64))
|
310 |
+
self.custom_timesteps = False
|
311 |
+
|
312 |
+
self._step_index = None
|
313 |
+
|
314 |
+
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
|
315 |
+
def _init_step_index(self, timestep):
|
316 |
+
if isinstance(timestep, torch.Tensor):
|
317 |
+
timestep = timestep.to(self.timesteps.device)
|
318 |
+
|
319 |
+
index_candidates = (self.timesteps == timestep).nonzero()
|
320 |
+
|
321 |
+
# The sigma index that is taken for the **very** first `step`
|
322 |
+
# is always the second index (or the last index if there is only 1)
|
323 |
+
# This way we can ensure we don't accidentally skip a sigma in
|
324 |
+
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
|
325 |
+
if len(index_candidates) > 1:
|
326 |
+
step_index = index_candidates[1]
|
327 |
+
else:
|
328 |
+
step_index = index_candidates[0]
|
329 |
+
|
330 |
+
self._step_index = step_index.item()
|
331 |
+
|
332 |
+
@property
|
333 |
+
def step_index(self):
|
334 |
+
return self._step_index
|
335 |
+
|
336 |
+
def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
|
337 |
+
"""
|
338 |
+
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
|
339 |
+
current timestep.
|
340 |
+
|
341 |
+
Args:
|
342 |
+
sample (`torch.FloatTensor`):
|
343 |
+
The input sample.
|
344 |
+
timestep (`int`, *optional*):
|
345 |
+
The current timestep in the diffusion chain.
|
346 |
+
Returns:
|
347 |
+
`torch.FloatTensor`:
|
348 |
+
A scaled input sample.
|
349 |
+
"""
|
350 |
+
return sample
|
351 |
+
|
352 |
+
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
|
353 |
+
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
|
354 |
+
"""
|
355 |
+
"Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
|
356 |
+
prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
|
357 |
+
s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
|
358 |
+
pixels from saturation at each step. We find that dynamic thresholding results in significantly better
|
359 |
+
photorealism as well as better image-text alignment, especially when using very large guidance weights."
|
360 |
+
|
361 |
+
https://arxiv.org/abs/2205.11487
|
362 |
+
"""
|
363 |
+
dtype = sample.dtype
|
364 |
+
batch_size, channels, *remaining_dims = sample.shape
|
365 |
+
|
366 |
+
if dtype not in (torch.float32, torch.float64):
|
367 |
+
# upcast for quantile calculation, and clamp not implemented for cpu half
|
368 |
+
sample = sample.float()
|
369 |
+
|
370 |
+
# Flatten sample for doing quantile calculation along each image
|
371 |
+
sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))
|
372 |
+
|
373 |
+
abs_sample = sample.abs() # "a certain percentile absolute pixel value"
|
374 |
+
|
375 |
+
s = torch.quantile(
|
376 |
+
abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
|
377 |
+
s = torch.clamp(
|
378 |
+
s, min=1, max=self.config.sample_max_value
|
379 |
+
) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
|
380 |
+
# (batch_size, 1) because clamp will broadcast along dim=0
|
381 |
+
s = s.unsqueeze(1)
|
382 |
+
# "we threshold xt0 to the range [-s, s] and then divide by s"
|
383 |
+
sample = torch.clamp(sample, -s, s) / s
|
384 |
+
|
385 |
+
sample = sample.reshape(batch_size, channels, *remaining_dims)
|
386 |
+
sample = sample.to(dtype)
|
387 |
+
|
388 |
+
return sample
|
389 |
+
|
390 |
+
def set_timesteps(
|
391 |
+
self,
|
392 |
+
num_inference_steps: Optional[int] = None,
|
393 |
+
device: Union[str, torch.device] = None,
|
394 |
+
original_inference_steps: Optional[int] = None,
|
395 |
+
timesteps: Optional[List[int]] = None,
|
396 |
+
strength: int = 1.0,
|
397 |
+
):
|
398 |
+
"""
|
399 |
+
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
|
400 |
+
|
401 |
+
Args:
|
402 |
+
num_inference_steps (`int`, *optional*):
|
403 |
+
The number of diffusion steps used when generating samples with a pre-trained model. If used,
|
404 |
+
`timesteps` must be `None`.
|
405 |
+
device (`str` or `torch.device`, *optional*):
|
406 |
+
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
407 |
+
original_inference_steps (`int`, *optional*):
|
408 |
+
The original number of inference steps, which will be used to generate a linearly-spaced timestep
|
409 |
+
schedule (which is different from the standard `diffusers` implementation). We will then take
|
410 |
+
`num_inference_steps` timesteps from this schedule, evenly spaced in terms of indices, and use that as
|
411 |
+
our final timestep schedule. If not set, this will default to the `original_inference_steps` attribute.
|
412 |
+
timesteps (`List[int]`, *optional*):
|
413 |
+
Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
|
414 |
+
timestep spacing strategy of equal spacing between timesteps on the training/distillation timestep
|
415 |
+
schedule is used. If `timesteps` is passed, `num_inference_steps` must be `None`.
|
416 |
+
"""
|
417 |
+
# 0. Check inputs
|
418 |
+
if num_inference_steps is None and timesteps is None:
|
419 |
+
raise ValueError(
|
420 |
+
"Must pass exactly one of `num_inference_steps` or `custom_timesteps`.")
|
421 |
+
|
422 |
+
if num_inference_steps is not None and timesteps is not None:
|
423 |
+
raise ValueError(
|
424 |
+
"Can only pass one of `num_inference_steps` or `custom_timesteps`.")
|
425 |
+
|
426 |
+
# 1. Calculate the LCM original training/distillation timestep schedule.
|
427 |
+
original_steps = (
|
428 |
+
original_inference_steps if original_inference_steps is not None else self.config.original_inference_steps
|
429 |
+
)
|
430 |
+
|
431 |
+
if original_steps > self.config.num_train_timesteps:
|
432 |
+
raise ValueError(
|
433 |
+
f"`original_steps`: {original_steps} cannot be larger than `self.config.train_timesteps`:"
|
434 |
+
f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
|
435 |
+
f" maximal {self.config.num_train_timesteps} timesteps."
|
436 |
+
)
|
437 |
+
|
438 |
+
# LCM Timesteps Setting
|
439 |
+
# The skipping step parameter k from the paper.
|
440 |
+
k = self.config.num_train_timesteps // original_steps
|
441 |
+
# LCM Training/Distillation Steps Schedule
|
442 |
+
# Currently, only a linearly-spaced schedule is supported (same as in the LCM distillation scripts).
|
443 |
+
lcm_origin_timesteps = np.asarray(
|
444 |
+
list(range(1, int(original_steps * strength) + 1))) * k - 1
|
445 |
+
|
446 |
+
# 2. Calculate the LCM inference timestep schedule.
|
447 |
+
if timesteps is not None:
|
448 |
+
# 2.1 Handle custom timestep schedules.
|
449 |
+
train_timesteps = set(lcm_origin_timesteps)
|
450 |
+
non_train_timesteps = []
|
451 |
+
for i in range(1, len(timesteps)):
|
452 |
+
if timesteps[i] >= timesteps[i - 1]:
|
453 |
+
raise ValueError(
|
454 |
+
"`custom_timesteps` must be in descending order.")
|
455 |
+
|
456 |
+
if timesteps[i] not in train_timesteps:
|
457 |
+
non_train_timesteps.append(timesteps[i])
|
458 |
+
|
459 |
+
if timesteps[0] >= self.config.num_train_timesteps:
|
460 |
+
raise ValueError(
|
461 |
+
f"`timesteps` must start before `self.config.train_timesteps`:"
|
462 |
+
f" {self.config.num_train_timesteps}."
|
463 |
+
)
|
464 |
+
|
465 |
+
# Raise warning if timestep schedule does not start with self.config.num_train_timesteps - 1
|
466 |
+
if strength == 1.0 and timesteps[0] != self.config.num_train_timesteps - 1:
|
467 |
+
logger.warning(
|
468 |
+
f"The first timestep on the custom timestep schedule is {timesteps[0]}, not"
|
469 |
+
f" `self.config.num_train_timesteps - 1`: {self.config.num_train_timesteps - 1}. You may get"
|
470 |
+
f" unexpected results when using this timestep schedule."
|
471 |
+
)
|
472 |
+
|
473 |
+
# Raise warning if custom timestep schedule contains timesteps not on original timestep schedule
|
474 |
+
if non_train_timesteps:
|
475 |
+
logger.warning(
|
476 |
+
f"The custom timestep schedule contains the following timesteps which are not on the original"
|
477 |
+
f" training/distillation timestep schedule: {non_train_timesteps}. You may get unexpected results"
|
478 |
+
f" when using this timestep schedule."
|
479 |
+
)
|
480 |
+
|
481 |
+
# Raise warning if custom timestep schedule is longer than original_steps
|
482 |
+
if len(timesteps) > original_steps:
|
483 |
+
logger.warning(
|
484 |
+
f"The number of timesteps in the custom timestep schedule is {len(timesteps)}, which exceeds the"
|
485 |
+
f" the length of the timestep schedule used for training: {original_steps}. You may get some"
|
486 |
+
f" unexpected results when using this timestep schedule."
|
487 |
+
)
|
488 |
+
|
489 |
+
timesteps = np.array(timesteps, dtype=np.int64)
|
490 |
+
self.num_inference_steps = len(timesteps)
|
491 |
+
self.custom_timesteps = True
|
492 |
+
|
493 |
+
# Apply strength (e.g. for img2img pipelines) (see StableDiffusionImg2ImgPipeline.get_timesteps)
|
494 |
+
init_timestep = min(
|
495 |
+
int(self.num_inference_steps * strength), self.num_inference_steps)
|
496 |
+
t_start = max(self.num_inference_steps - init_timestep, 0)
|
497 |
+
timesteps = timesteps[t_start * self.order:]
|
498 |
+
# TODO: also reset self.num_inference_steps?
|
499 |
+
else:
|
500 |
+
# 2.2 Create the "standard" LCM inference timestep schedule.
|
501 |
+
if num_inference_steps > self.config.num_train_timesteps:
|
502 |
+
raise ValueError(
|
503 |
+
f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
|
504 |
+
f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
|
505 |
+
f" maximal {self.config.num_train_timesteps} timesteps."
|
506 |
+
)
|
507 |
+
|
508 |
+
skipping_step = len(lcm_origin_timesteps) // num_inference_steps
|
509 |
+
|
510 |
+
if skipping_step < 1:
|
511 |
+
raise ValueError(
|
512 |
+
f"The combination of `original_steps x strength`: {original_steps} x {strength} is smaller than `num_inference_steps`: {num_inference_steps}. Make sure to either reduce `num_inference_steps` to a value smaller than {int(original_steps * strength)} or increase `strength` to a value higher than {float(num_inference_steps / original_steps)}."
|
513 |
+
)
|
514 |
+
|
515 |
+
self.num_inference_steps = num_inference_steps
|
516 |
+
|
517 |
+
if num_inference_steps > original_steps:
|
518 |
+
raise ValueError(
|
519 |
+
f"`num_inference_steps`: {num_inference_steps} cannot be larger than `original_inference_steps`:"
|
520 |
+
f" {original_steps} because the final timestep schedule will be a subset of the"
|
521 |
+
f" `original_inference_steps`-sized initial timestep schedule."
|
522 |
+
)
|
523 |
+
|
524 |
+
# LCM Inference Steps Schedule
|
525 |
+
lcm_origin_timesteps = lcm_origin_timesteps[::-1].copy()
|
526 |
+
# Select (approximately) evenly spaced indices from lcm_origin_timesteps.
|
527 |
+
inference_indices = np.linspace(
|
528 |
+
0, len(lcm_origin_timesteps), num=num_inference_steps, endpoint=False)
|
529 |
+
'''
|
530 |
+
|
531 |
+
当只有1步时会进行999步直接进行
|
532 |
+
两步: 999, 499,
|
533 |
+
四步: 999, 759, 499, 259
|
534 |
+
|
535 |
+
'''
|
536 |
+
inference_indices = np.floor(inference_indices).astype(np.int64)
|
537 |
+
timesteps = lcm_origin_timesteps[inference_indices]
|
538 |
+
|
539 |
+
self.timesteps = torch.from_numpy(timesteps).to(
|
540 |
+
device=device, dtype=torch.long)
|
541 |
+
|
542 |
+
self._step_index = None
|
543 |
+
|
544 |
+
|
545 |
+
def get_scalings_for_boundary_condition_discrete(self, timestep):
|
546 |
+
self.sigma_data = 0.5 # Default: 0.5
|
547 |
+
scaled_timestep = timestep * self.config.timestep_scaling
|
548 |
+
|
549 |
+
c_skip = self.sigma_data**2 / (scaled_timestep**2 + self.sigma_data**2)
|
550 |
+
c_out = scaled_timestep / \
|
551 |
+
(scaled_timestep**2 + self.sigma_data**2) ** 0.5
|
552 |
+
return c_skip, c_out
|
553 |
+
|
554 |
+
def step(
|
555 |
+
self,
|
556 |
+
model_output: torch.FloatTensor,
|
557 |
+
timestep: int,
|
558 |
+
sample: torch.FloatTensor,
|
559 |
+
generator: Optional[torch.Generator] = None,
|
560 |
+
return_dict: bool = True,
|
561 |
+
use_ddim: bool = False,
|
562 |
+
) -> Union[LCMSchedulerOutput, Tuple]:
|
563 |
+
"""
|
564 |
+
Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
|
565 |
+
process from the learned model outputs (most often the predicted noise).
|
566 |
+
|
567 |
+
Args:
|
568 |
+
model_output (`torch.FloatTensor`):
|
569 |
+
The direct output from learned diffusion model.
|
570 |
+
timestep (`float`):
|
571 |
+
The current discrete timestep in the diffusion chain.
|
572 |
+
sample (`torch.FloatTensor`):
|
573 |
+
A current instance of a sample created by the diffusion process.
|
574 |
+
generator (`torch.Generator`, *optional*):
|
575 |
+
A random number generator.
|
576 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
577 |
+
Whether or not to return a [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] or `tuple`.
|
578 |
+
Returns:
|
579 |
+
[`~schedulers.scheduling_utils.LCMSchedulerOutput`] or `tuple`:
|
580 |
+
If return_dict is `True`, [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] is returned, otherwise a
|
581 |
+
tuple is returned where the first element is the sample tensor.
|
582 |
+
"""
|
583 |
+
if self.num_inference_steps is None:
|
584 |
+
raise ValueError(
|
585 |
+
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
|
586 |
+
)
|
587 |
+
|
588 |
+
if self.step_index is None:
|
589 |
+
self._init_step_index(timestep)
|
590 |
+
|
591 |
+
# 1. get previous step value
|
592 |
+
prev_step_index = self.step_index + 1
|
593 |
+
if prev_step_index < len(self.timesteps):
|
594 |
+
prev_timestep = self.timesteps[prev_step_index]
|
595 |
+
else:
|
596 |
+
prev_timestep = timestep
|
597 |
+
|
598 |
+
# 2. compute alphas, betas
|
599 |
+
alpha_prod_t = self.alphas_cumprod[timestep]
|
600 |
+
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
|
601 |
+
|
602 |
+
beta_prod_t = 1 - alpha_prod_t
|
603 |
+
beta_prod_t_prev = 1 - alpha_prod_t_prev
|
604 |
+
|
605 |
+
# 3. Get scalings for boundary conditions
|
606 |
+
c_skip, c_out = self.get_scalings_for_boundary_condition_discrete(
|
607 |
+
timestep)
|
608 |
+
|
609 |
+
# 4. Compute the predicted original sample x_0 based on the model parameterization
|
610 |
+
if self.config.prediction_type == "epsilon": # noise-prediction
|
611 |
+
predicted_original_sample = (
|
612 |
+
sample - beta_prod_t.sqrt() * model_output) / alpha_prod_t.sqrt()
|
613 |
+
elif self.config.prediction_type == "sample": # x-prediction
|
614 |
+
predicted_original_sample = model_output
|
615 |
+
elif self.config.prediction_type == "v_prediction": # v-prediction
|
616 |
+
predicted_original_sample = alpha_prod_t.sqrt(
|
617 |
+
) * sample - beta_prod_t.sqrt() * model_output
|
618 |
+
else:
|
619 |
+
raise ValueError(
|
620 |
+
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` or"
|
621 |
+
" `v_prediction` for `LCMScheduler`."
|
622 |
+
)
|
623 |
+
|
624 |
+
# 5. Clip or threshold "predicted x_0"
|
625 |
+
if self.config.thresholding:
|
626 |
+
predicted_original_sample = self._threshold_sample(
|
627 |
+
predicted_original_sample)
|
628 |
+
elif self.config.clip_sample:
|
629 |
+
predicted_original_sample = predicted_original_sample.clamp(
|
630 |
+
-self.config.clip_sample_range, self.config.clip_sample_range
|
631 |
+
)
|
632 |
+
|
633 |
+
# 6. Denoise model output using boundary conditions
|
634 |
+
denoised = c_out * predicted_original_sample + c_skip * sample
|
635 |
+
# denoised = predicted_original_sample
|
636 |
+
|
637 |
+
# 7. Sample and inject noise z ~ N(0, I) for MultiStep Inference
|
638 |
+
# Noise is not used on the final timestep of the timestep schedule.
|
639 |
+
# This also means that noise is not used for one-step sampling.
|
640 |
+
if self.step_index != self.num_inference_steps - 1:
|
641 |
+
if not use_ddim:
|
642 |
+
noise = randn_tensor(
|
643 |
+
model_output.shape, generator=generator, device=model_output.device, dtype=denoised.dtype
|
644 |
+
)
|
645 |
+
prev_sample = alpha_prod_t_prev.sqrt() * denoised + beta_prod_t_prev.sqrt() * noise
|
646 |
+
else:
|
647 |
+
prev_sample = denoised
|
648 |
+
|
649 |
+
# upon completion increase step index by one
|
650 |
+
self._step_index += 1
|
651 |
+
|
652 |
+
if not return_dict:
|
653 |
+
return (prev_sample, denoised)
|
654 |
+
|
655 |
+
return LCMSchedulerOutput(prev_sample=prev_sample, denoised=denoised)
|
656 |
+
|
657 |
+
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
|
658 |
+
def add_noise(
|
659 |
+
self,
|
660 |
+
original_samples: torch.FloatTensor,
|
661 |
+
noise: torch.FloatTensor,
|
662 |
+
timesteps: torch.IntTensor,
|
663 |
+
) -> torch.FloatTensor:
|
664 |
+
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
|
665 |
+
alphas_cumprod = self.alphas_cumprod.to(
|
666 |
+
device=original_samples.device, dtype=original_samples.dtype)
|
667 |
+
timesteps = timesteps.to(original_samples.device)
|
668 |
+
|
669 |
+
sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
|
670 |
+
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
|
671 |
+
while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
|
672 |
+
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
|
673 |
+
|
674 |
+
sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
|
675 |
+
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
|
676 |
+
while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
|
677 |
+
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
|
678 |
+
|
679 |
+
noisy_samples = sqrt_alpha_prod * original_samples + \
|
680 |
+
sqrt_one_minus_alpha_prod * noise
|
681 |
+
return noisy_samples
|
682 |
+
|
683 |
+
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity
|
684 |
+
def get_velocity(
|
685 |
+
self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor
|
686 |
+
) -> torch.FloatTensor:
|
687 |
+
# Make sure alphas_cumprod and timestep have same device and dtype as sample
|
688 |
+
alphas_cumprod = self.alphas_cumprod.to(
|
689 |
+
device=sample.device, dtype=sample.dtype)
|
690 |
+
timesteps = timesteps.to(sample.device)
|
691 |
+
|
692 |
+
sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
|
693 |
+
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
|
694 |
+
while len(sqrt_alpha_prod.shape) < len(sample.shape):
|
695 |
+
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
|
696 |
+
|
697 |
+
sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
|
698 |
+
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
|
699 |
+
while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):
|
700 |
+
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
|
701 |
+
|
702 |
+
velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
|
703 |
+
return velocity
|
704 |
+
|
705 |
+
def __len__(self):
|
706 |
+
return self.config.num_train_timesteps
|
707 |
+
|
708 |
+
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.previous_timestep
|
709 |
+
def previous_timestep(self, timestep):
|
710 |
+
if self.custom_timesteps:
|
711 |
+
index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0]
|
712 |
+
if index == self.timesteps.shape[0] - 1:
|
713 |
+
prev_t = torch.tensor(-1)
|
714 |
+
else:
|
715 |
+
prev_t = self.timesteps[index + 1]
|
716 |
+
else:
|
717 |
+
num_inference_steps = (
|
718 |
+
self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps
|
719 |
+
)
|
720 |
+
prev_t = timestep - self.config.num_train_timesteps // num_inference_steps
|
721 |
+
|
722 |
+
return prev_t
|
animatelcm/utils/convert_from_ckpt.py
ADDED
@@ -0,0 +1,951 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2023 The HuggingFace Inc. team.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
""" Conversion script for the Stable Diffusion checkpoints."""
|
16 |
+
|
17 |
+
import re
|
18 |
+
from io import BytesIO
|
19 |
+
from typing import Optional
|
20 |
+
|
21 |
+
import requests
|
22 |
+
import torch
|
23 |
+
from transformers import (
|
24 |
+
AutoFeatureExtractor,
|
25 |
+
BertTokenizerFast,
|
26 |
+
CLIPImageProcessor,
|
27 |
+
CLIPTextModel,
|
28 |
+
CLIPTextModelWithProjection,
|
29 |
+
CLIPTokenizer,
|
30 |
+
CLIPVisionConfig,
|
31 |
+
CLIPVisionModelWithProjection,
|
32 |
+
)
|
33 |
+
|
34 |
+
from diffusers.models import (
|
35 |
+
AutoencoderKL,
|
36 |
+
PriorTransformer,
|
37 |
+
UNet2DConditionModel,
|
38 |
+
)
|
39 |
+
from diffusers.schedulers import (
|
40 |
+
DDIMScheduler,
|
41 |
+
DDPMScheduler,
|
42 |
+
DPMSolverMultistepScheduler,
|
43 |
+
EulerAncestralDiscreteScheduler,
|
44 |
+
EulerDiscreteScheduler,
|
45 |
+
HeunDiscreteScheduler,
|
46 |
+
LMSDiscreteScheduler,
|
47 |
+
PNDMScheduler,
|
48 |
+
UnCLIPScheduler,
|
49 |
+
)
|
50 |
+
from diffusers.utils.import_utils import BACKENDS_MAPPING
|
51 |
+
|
52 |
+
|
53 |
+
def shave_segments(path, n_shave_prefix_segments=1):
|
54 |
+
"""
|
55 |
+
Removes segments. Positive values shave the first segments, negative shave the last segments.
|
56 |
+
"""
|
57 |
+
if n_shave_prefix_segments >= 0:
|
58 |
+
return ".".join(path.split(".")[n_shave_prefix_segments:])
|
59 |
+
else:
|
60 |
+
return ".".join(path.split(".")[:n_shave_prefix_segments])
|
61 |
+
|
62 |
+
|
63 |
+
def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
|
64 |
+
"""
|
65 |
+
Updates paths inside resnets to the new naming scheme (local renaming)
|
66 |
+
"""
|
67 |
+
mapping = []
|
68 |
+
for old_item in old_list:
|
69 |
+
new_item = old_item.replace("in_layers.0", "norm1")
|
70 |
+
new_item = new_item.replace("in_layers.2", "conv1")
|
71 |
+
|
72 |
+
new_item = new_item.replace("out_layers.0", "norm2")
|
73 |
+
new_item = new_item.replace("out_layers.3", "conv2")
|
74 |
+
|
75 |
+
new_item = new_item.replace("emb_layers.1", "time_emb_proj")
|
76 |
+
new_item = new_item.replace("skip_connection", "conv_shortcut")
|
77 |
+
|
78 |
+
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
79 |
+
|
80 |
+
mapping.append({"old": old_item, "new": new_item})
|
81 |
+
|
82 |
+
return mapping
|
83 |
+
|
84 |
+
|
85 |
+
def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
|
86 |
+
"""
|
87 |
+
Updates paths inside resnets to the new naming scheme (local renaming)
|
88 |
+
"""
|
89 |
+
mapping = []
|
90 |
+
for old_item in old_list:
|
91 |
+
new_item = old_item
|
92 |
+
|
93 |
+
new_item = new_item.replace("nin_shortcut", "conv_shortcut")
|
94 |
+
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
95 |
+
|
96 |
+
mapping.append({"old": old_item, "new": new_item})
|
97 |
+
|
98 |
+
return mapping
|
99 |
+
|
100 |
+
|
101 |
+
def renew_attention_paths(old_list, n_shave_prefix_segments=0):
|
102 |
+
"""
|
103 |
+
Updates paths inside attentions to the new naming scheme (local renaming)
|
104 |
+
"""
|
105 |
+
mapping = []
|
106 |
+
for old_item in old_list:
|
107 |
+
new_item = old_item
|
108 |
+
|
109 |
+
mapping.append({"old": old_item, "new": new_item})
|
110 |
+
|
111 |
+
return mapping
|
112 |
+
|
113 |
+
|
114 |
+
def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
|
115 |
+
"""
|
116 |
+
Updates paths inside attentions to the new naming scheme (local renaming)
|
117 |
+
"""
|
118 |
+
mapping = []
|
119 |
+
for old_item in old_list:
|
120 |
+
new_item = old_item
|
121 |
+
|
122 |
+
new_item = new_item.replace("norm.weight", "group_norm.weight")
|
123 |
+
new_item = new_item.replace("norm.bias", "group_norm.bias")
|
124 |
+
|
125 |
+
new_item = new_item.replace("q.weight", "query.weight")
|
126 |
+
new_item = new_item.replace("q.bias", "query.bias")
|
127 |
+
|
128 |
+
new_item = new_item.replace("k.weight", "key.weight")
|
129 |
+
new_item = new_item.replace("k.bias", "key.bias")
|
130 |
+
|
131 |
+
new_item = new_item.replace("v.weight", "value.weight")
|
132 |
+
new_item = new_item.replace("v.bias", "value.bias")
|
133 |
+
|
134 |
+
new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
|
135 |
+
new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
|
136 |
+
|
137 |
+
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
138 |
+
|
139 |
+
mapping.append({"old": old_item, "new": new_item})
|
140 |
+
|
141 |
+
return mapping
|
142 |
+
|
143 |
+
|
144 |
+
def assign_to_checkpoint(
|
145 |
+
paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
|
146 |
+
):
|
147 |
+
"""
|
148 |
+
This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits
|
149 |
+
attention layers, and takes into account additional replacements that may arise.
|
150 |
+
|
151 |
+
Assigns the weights to the new checkpoint.
|
152 |
+
"""
|
153 |
+
assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
|
154 |
+
|
155 |
+
# Splits the attention layers into three variables.
|
156 |
+
if attention_paths_to_split is not None:
|
157 |
+
for path, path_map in attention_paths_to_split.items():
|
158 |
+
old_tensor = old_checkpoint[path]
|
159 |
+
channels = old_tensor.shape[0] // 3
|
160 |
+
|
161 |
+
target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
|
162 |
+
|
163 |
+
num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
|
164 |
+
|
165 |
+
old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
|
166 |
+
query, key, value = old_tensor.split(channels // num_heads, dim=1)
|
167 |
+
|
168 |
+
checkpoint[path_map["query"]] = query.reshape(target_shape)
|
169 |
+
checkpoint[path_map["key"]] = key.reshape(target_shape)
|
170 |
+
checkpoint[path_map["value"]] = value.reshape(target_shape)
|
171 |
+
|
172 |
+
for path in paths:
|
173 |
+
new_path = path["new"]
|
174 |
+
|
175 |
+
# These have already been assigned
|
176 |
+
if attention_paths_to_split is not None and new_path in attention_paths_to_split:
|
177 |
+
continue
|
178 |
+
|
179 |
+
# Global renaming happens here
|
180 |
+
new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
|
181 |
+
new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
|
182 |
+
new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")
|
183 |
+
|
184 |
+
if additional_replacements is not None:
|
185 |
+
for replacement in additional_replacements:
|
186 |
+
new_path = new_path.replace(replacement["old"], replacement["new"])
|
187 |
+
|
188 |
+
# proj_attn.weight has to be converted from conv 1D to linear
|
189 |
+
if "proj_attn.weight" in new_path:
|
190 |
+
checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
|
191 |
+
else:
|
192 |
+
checkpoint[new_path] = old_checkpoint[path["old"]]
|
193 |
+
|
194 |
+
|
195 |
+
def conv_attn_to_linear(checkpoint):
|
196 |
+
keys = list(checkpoint.keys())
|
197 |
+
attn_keys = ["query.weight", "key.weight", "value.weight"]
|
198 |
+
for key in keys:
|
199 |
+
if ".".join(key.split(".")[-2:]) in attn_keys:
|
200 |
+
if checkpoint[key].ndim > 2:
|
201 |
+
checkpoint[key] = checkpoint[key][:, :, 0, 0]
|
202 |
+
elif "proj_attn.weight" in key:
|
203 |
+
if checkpoint[key].ndim > 2:
|
204 |
+
checkpoint[key] = checkpoint[key][:, :, 0]
|
205 |
+
|
206 |
+
|
207 |
+
def create_unet_diffusers_config(original_config, image_size: int, controlnet=False):
|
208 |
+
"""
|
209 |
+
Creates a config for the diffusers based on the config of the LDM model.
|
210 |
+
"""
|
211 |
+
if controlnet:
|
212 |
+
unet_params = original_config.model.params.control_stage_config.params
|
213 |
+
else:
|
214 |
+
unet_params = original_config.model.params.unet_config.params
|
215 |
+
|
216 |
+
vae_params = original_config.model.params.first_stage_config.params.ddconfig
|
217 |
+
|
218 |
+
block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult]
|
219 |
+
|
220 |
+
down_block_types = []
|
221 |
+
resolution = 1
|
222 |
+
for i in range(len(block_out_channels)):
|
223 |
+
block_type = "CrossAttnDownBlock2D" if resolution in unet_params.attention_resolutions else "DownBlock2D"
|
224 |
+
down_block_types.append(block_type)
|
225 |
+
if i != len(block_out_channels) - 1:
|
226 |
+
resolution *= 2
|
227 |
+
|
228 |
+
up_block_types = []
|
229 |
+
for i in range(len(block_out_channels)):
|
230 |
+
block_type = "CrossAttnUpBlock2D" if resolution in unet_params.attention_resolutions else "UpBlock2D"
|
231 |
+
up_block_types.append(block_type)
|
232 |
+
resolution //= 2
|
233 |
+
|
234 |
+
vae_scale_factor = 2 ** (len(vae_params.ch_mult) - 1)
|
235 |
+
|
236 |
+
head_dim = unet_params.num_heads if "num_heads" in unet_params else None
|
237 |
+
use_linear_projection = (
|
238 |
+
unet_params.use_linear_in_transformer if "use_linear_in_transformer" in unet_params else False
|
239 |
+
)
|
240 |
+
if use_linear_projection:
|
241 |
+
# stable diffusion 2-base-512 and 2-768
|
242 |
+
if head_dim is None:
|
243 |
+
head_dim = [5, 10, 20, 20]
|
244 |
+
|
245 |
+
class_embed_type = None
|
246 |
+
projection_class_embeddings_input_dim = None
|
247 |
+
|
248 |
+
if "num_classes" in unet_params:
|
249 |
+
if unet_params.num_classes == "sequential":
|
250 |
+
class_embed_type = "projection"
|
251 |
+
assert "adm_in_channels" in unet_params
|
252 |
+
projection_class_embeddings_input_dim = unet_params.adm_in_channels
|
253 |
+
else:
|
254 |
+
raise NotImplementedError(f"Unknown conditional unet num_classes config: {unet_params.num_classes}")
|
255 |
+
|
256 |
+
config = {
|
257 |
+
"sample_size": image_size // vae_scale_factor,
|
258 |
+
"in_channels": unet_params.in_channels,
|
259 |
+
"down_block_types": tuple(down_block_types),
|
260 |
+
"block_out_channels": tuple(block_out_channels),
|
261 |
+
"layers_per_block": unet_params.num_res_blocks,
|
262 |
+
"cross_attention_dim": unet_params.context_dim,
|
263 |
+
"attention_head_dim": head_dim,
|
264 |
+
"use_linear_projection": use_linear_projection,
|
265 |
+
"class_embed_type": class_embed_type,
|
266 |
+
"projection_class_embeddings_input_dim": projection_class_embeddings_input_dim,
|
267 |
+
}
|
268 |
+
|
269 |
+
if not controlnet:
|
270 |
+
config["out_channels"] = unet_params.out_channels
|
271 |
+
config["up_block_types"] = tuple(up_block_types)
|
272 |
+
|
273 |
+
return config
|
274 |
+
|
275 |
+
|
276 |
+
def create_vae_diffusers_config(original_config, image_size: int):
|
277 |
+
"""
|
278 |
+
Creates a config for the diffusers based on the config of the LDM model.
|
279 |
+
"""
|
280 |
+
vae_params = original_config.model.params.first_stage_config.params.ddconfig
|
281 |
+
_ = original_config.model.params.first_stage_config.params.embed_dim
|
282 |
+
|
283 |
+
block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult]
|
284 |
+
down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
|
285 |
+
up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
|
286 |
+
|
287 |
+
config = {
|
288 |
+
"sample_size": image_size,
|
289 |
+
"in_channels": vae_params.in_channels,
|
290 |
+
"out_channels": vae_params.out_ch,
|
291 |
+
"down_block_types": tuple(down_block_types),
|
292 |
+
"up_block_types": tuple(up_block_types),
|
293 |
+
"block_out_channels": tuple(block_out_channels),
|
294 |
+
"latent_channels": vae_params.z_channels,
|
295 |
+
"layers_per_block": vae_params.num_res_blocks,
|
296 |
+
}
|
297 |
+
return config
|
298 |
+
|
299 |
+
|
300 |
+
def create_diffusers_schedular(original_config):
|
301 |
+
schedular = DDIMScheduler(
|
302 |
+
num_train_timesteps=original_config.model.params.timesteps,
|
303 |
+
beta_start=original_config.model.params.linear_start,
|
304 |
+
beta_end=original_config.model.params.linear_end,
|
305 |
+
beta_schedule="scaled_linear",
|
306 |
+
)
|
307 |
+
return schedular
|
308 |
+
|
309 |
+
|
310 |
+
def create_ldm_bert_config(original_config):
|
311 |
+
bert_params = original_config.model.parms.cond_stage_config.params
|
312 |
+
config = LDMBertConfig(
|
313 |
+
d_model=bert_params.n_embed,
|
314 |
+
encoder_layers=bert_params.n_layer,
|
315 |
+
encoder_ffn_dim=bert_params.n_embed * 4,
|
316 |
+
)
|
317 |
+
return config
|
318 |
+
|
319 |
+
|
320 |
+
def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False, controlnet=False):
|
321 |
+
"""
|
322 |
+
Takes a state dict and a config, and returns a converted checkpoint.
|
323 |
+
"""
|
324 |
+
|
325 |
+
# extract state_dict for UNet
|
326 |
+
unet_state_dict = {}
|
327 |
+
keys = list(checkpoint.keys())
|
328 |
+
|
329 |
+
if controlnet:
|
330 |
+
unet_key = "control_model."
|
331 |
+
else:
|
332 |
+
unet_key = "model.diffusion_model."
|
333 |
+
|
334 |
+
# at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA
|
335 |
+
if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema:
|
336 |
+
print(f"Checkpoint {path} has both EMA and non-EMA weights.")
|
337 |
+
print(
|
338 |
+
"In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA"
|
339 |
+
" weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag."
|
340 |
+
)
|
341 |
+
for key in keys:
|
342 |
+
if key.startswith("model.diffusion_model"):
|
343 |
+
flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
|
344 |
+
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key)
|
345 |
+
else:
|
346 |
+
if sum(k.startswith("model_ema") for k in keys) > 100:
|
347 |
+
print(
|
348 |
+
"In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA"
|
349 |
+
" weights (usually better for inference), please make sure to add the `--extract_ema` flag."
|
350 |
+
)
|
351 |
+
|
352 |
+
for key in keys:
|
353 |
+
if key.startswith(unet_key):
|
354 |
+
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
|
355 |
+
|
356 |
+
new_checkpoint = {}
|
357 |
+
|
358 |
+
new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
|
359 |
+
new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
|
360 |
+
new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
|
361 |
+
new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]
|
362 |
+
|
363 |
+
if config["class_embed_type"] is None:
|
364 |
+
# No parameters to port
|
365 |
+
...
|
366 |
+
elif config["class_embed_type"] == "timestep" or config["class_embed_type"] == "projection":
|
367 |
+
new_checkpoint["class_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"]
|
368 |
+
new_checkpoint["class_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"]
|
369 |
+
new_checkpoint["class_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"]
|
370 |
+
new_checkpoint["class_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"]
|
371 |
+
else:
|
372 |
+
raise NotImplementedError(f"Not implemented `class_embed_type`: {config['class_embed_type']}")
|
373 |
+
|
374 |
+
new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
|
375 |
+
new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
|
376 |
+
|
377 |
+
if not controlnet:
|
378 |
+
new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
|
379 |
+
new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
|
380 |
+
new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
|
381 |
+
new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]
|
382 |
+
|
383 |
+
# Retrieves the keys for the input blocks only
|
384 |
+
num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
|
385 |
+
input_blocks = {
|
386 |
+
layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key]
|
387 |
+
for layer_id in range(num_input_blocks)
|
388 |
+
}
|
389 |
+
|
390 |
+
# Retrieves the keys for the middle blocks only
|
391 |
+
num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
|
392 |
+
middle_blocks = {
|
393 |
+
layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
|
394 |
+
for layer_id in range(num_middle_blocks)
|
395 |
+
}
|
396 |
+
|
397 |
+
# Retrieves the keys for the output blocks only
|
398 |
+
num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
|
399 |
+
output_blocks = {
|
400 |
+
layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key]
|
401 |
+
for layer_id in range(num_output_blocks)
|
402 |
+
}
|
403 |
+
|
404 |
+
for i in range(1, num_input_blocks):
|
405 |
+
block_id = (i - 1) // (config["layers_per_block"] + 1)
|
406 |
+
layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
|
407 |
+
|
408 |
+
resnets = [
|
409 |
+
key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
|
410 |
+
]
|
411 |
+
attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
|
412 |
+
|
413 |
+
if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
|
414 |
+
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
|
415 |
+
f"input_blocks.{i}.0.op.weight"
|
416 |
+
)
|
417 |
+
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
|
418 |
+
f"input_blocks.{i}.0.op.bias"
|
419 |
+
)
|
420 |
+
|
421 |
+
paths = renew_resnet_paths(resnets)
|
422 |
+
meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
|
423 |
+
assign_to_checkpoint(
|
424 |
+
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
425 |
+
)
|
426 |
+
|
427 |
+
if len(attentions):
|
428 |
+
paths = renew_attention_paths(attentions)
|
429 |
+
meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
|
430 |
+
assign_to_checkpoint(
|
431 |
+
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
432 |
+
)
|
433 |
+
|
434 |
+
resnet_0 = middle_blocks[0]
|
435 |
+
attentions = middle_blocks[1]
|
436 |
+
resnet_1 = middle_blocks[2]
|
437 |
+
|
438 |
+
resnet_0_paths = renew_resnet_paths(resnet_0)
|
439 |
+
assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)
|
440 |
+
|
441 |
+
resnet_1_paths = renew_resnet_paths(resnet_1)
|
442 |
+
assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)
|
443 |
+
|
444 |
+
attentions_paths = renew_attention_paths(attentions)
|
445 |
+
meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
|
446 |
+
assign_to_checkpoint(
|
447 |
+
attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
448 |
+
)
|
449 |
+
|
450 |
+
for i in range(num_output_blocks):
|
451 |
+
block_id = i // (config["layers_per_block"] + 1)
|
452 |
+
layer_in_block_id = i % (config["layers_per_block"] + 1)
|
453 |
+
output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
|
454 |
+
output_block_list = {}
|
455 |
+
|
456 |
+
for layer in output_block_layers:
|
457 |
+
layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
|
458 |
+
if layer_id in output_block_list:
|
459 |
+
output_block_list[layer_id].append(layer_name)
|
460 |
+
else:
|
461 |
+
output_block_list[layer_id] = [layer_name]
|
462 |
+
|
463 |
+
if len(output_block_list) > 1:
|
464 |
+
resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
|
465 |
+
attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]
|
466 |
+
|
467 |
+
resnet_0_paths = renew_resnet_paths(resnets)
|
468 |
+
paths = renew_resnet_paths(resnets)
|
469 |
+
|
470 |
+
meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
|
471 |
+
assign_to_checkpoint(
|
472 |
+
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
473 |
+
)
|
474 |
+
|
475 |
+
output_block_list = {k: sorted(v) for k, v in output_block_list.items()}
|
476 |
+
if ["conv.bias", "conv.weight"] in output_block_list.values():
|
477 |
+
index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
|
478 |
+
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
|
479 |
+
f"output_blocks.{i}.{index}.conv.weight"
|
480 |
+
]
|
481 |
+
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
|
482 |
+
f"output_blocks.{i}.{index}.conv.bias"
|
483 |
+
]
|
484 |
+
|
485 |
+
# Clear attentions as they have been attributed above.
|
486 |
+
if len(attentions) == 2:
|
487 |
+
attentions = []
|
488 |
+
|
489 |
+
if len(attentions):
|
490 |
+
paths = renew_attention_paths(attentions)
|
491 |
+
meta_path = {
|
492 |
+
"old": f"output_blocks.{i}.1",
|
493 |
+
"new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
|
494 |
+
}
|
495 |
+
assign_to_checkpoint(
|
496 |
+
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
497 |
+
)
|
498 |
+
else:
|
499 |
+
resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
|
500 |
+
for path in resnet_0_paths:
|
501 |
+
old_path = ".".join(["output_blocks", str(i), path["old"]])
|
502 |
+
new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])
|
503 |
+
|
504 |
+
new_checkpoint[new_path] = unet_state_dict[old_path]
|
505 |
+
|
506 |
+
if controlnet:
|
507 |
+
# conditioning embedding
|
508 |
+
|
509 |
+
orig_index = 0
|
510 |
+
|
511 |
+
new_checkpoint["controlnet_cond_embedding.conv_in.weight"] = unet_state_dict.pop(
|
512 |
+
f"input_hint_block.{orig_index}.weight"
|
513 |
+
)
|
514 |
+
new_checkpoint["controlnet_cond_embedding.conv_in.bias"] = unet_state_dict.pop(
|
515 |
+
f"input_hint_block.{orig_index}.bias"
|
516 |
+
)
|
517 |
+
|
518 |
+
orig_index += 2
|
519 |
+
|
520 |
+
diffusers_index = 0
|
521 |
+
|
522 |
+
while diffusers_index < 6:
|
523 |
+
new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.weight"] = unet_state_dict.pop(
|
524 |
+
f"input_hint_block.{orig_index}.weight"
|
525 |
+
)
|
526 |
+
new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.bias"] = unet_state_dict.pop(
|
527 |
+
f"input_hint_block.{orig_index}.bias"
|
528 |
+
)
|
529 |
+
diffusers_index += 1
|
530 |
+
orig_index += 2
|
531 |
+
|
532 |
+
new_checkpoint["controlnet_cond_embedding.conv_out.weight"] = unet_state_dict.pop(
|
533 |
+
f"input_hint_block.{orig_index}.weight"
|
534 |
+
)
|
535 |
+
new_checkpoint["controlnet_cond_embedding.conv_out.bias"] = unet_state_dict.pop(
|
536 |
+
f"input_hint_block.{orig_index}.bias"
|
537 |
+
)
|
538 |
+
|
539 |
+
# down blocks
|
540 |
+
for i in range(num_input_blocks):
|
541 |
+
new_checkpoint[f"controlnet_down_blocks.{i}.weight"] = unet_state_dict.pop(f"zero_convs.{i}.0.weight")
|
542 |
+
new_checkpoint[f"controlnet_down_blocks.{i}.bias"] = unet_state_dict.pop(f"zero_convs.{i}.0.bias")
|
543 |
+
|
544 |
+
# mid block
|
545 |
+
new_checkpoint["controlnet_mid_block.weight"] = unet_state_dict.pop("middle_block_out.0.weight")
|
546 |
+
new_checkpoint["controlnet_mid_block.bias"] = unet_state_dict.pop("middle_block_out.0.bias")
|
547 |
+
|
548 |
+
return new_checkpoint
|
549 |
+
|
550 |
+
|
551 |
+
def convert_ldm_vae_checkpoint(checkpoint, config):
|
552 |
+
# extract state dict for VAE
|
553 |
+
vae_state_dict = {}
|
554 |
+
vae_key = "first_stage_model."
|
555 |
+
keys = list(checkpoint.keys())
|
556 |
+
for key in keys:
|
557 |
+
if key.startswith(vae_key):
|
558 |
+
vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
|
559 |
+
|
560 |
+
new_checkpoint = {}
|
561 |
+
|
562 |
+
new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
|
563 |
+
new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
|
564 |
+
new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
|
565 |
+
new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
|
566 |
+
new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
|
567 |
+
new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]
|
568 |
+
|
569 |
+
new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
|
570 |
+
new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
|
571 |
+
new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
|
572 |
+
new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
|
573 |
+
new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
|
574 |
+
new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]
|
575 |
+
|
576 |
+
new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
|
577 |
+
new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
|
578 |
+
new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
|
579 |
+
new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]
|
580 |
+
|
581 |
+
# Retrieves the keys for the encoder down blocks only
|
582 |
+
num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
|
583 |
+
down_blocks = {
|
584 |
+
layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
|
585 |
+
}
|
586 |
+
|
587 |
+
# Retrieves the keys for the decoder up blocks only
|
588 |
+
num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
|
589 |
+
up_blocks = {
|
590 |
+
layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
|
591 |
+
}
|
592 |
+
|
593 |
+
for i in range(num_down_blocks):
|
594 |
+
resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
|
595 |
+
|
596 |
+
if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
|
597 |
+
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
|
598 |
+
f"encoder.down.{i}.downsample.conv.weight"
|
599 |
+
)
|
600 |
+
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
|
601 |
+
f"encoder.down.{i}.downsample.conv.bias"
|
602 |
+
)
|
603 |
+
|
604 |
+
paths = renew_vae_resnet_paths(resnets)
|
605 |
+
meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
|
606 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
607 |
+
|
608 |
+
mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
|
609 |
+
num_mid_res_blocks = 2
|
610 |
+
for i in range(1, num_mid_res_blocks + 1):
|
611 |
+
resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
|
612 |
+
|
613 |
+
paths = renew_vae_resnet_paths(resnets)
|
614 |
+
meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
|
615 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
616 |
+
|
617 |
+
mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
|
618 |
+
paths = renew_vae_attention_paths(mid_attentions)
|
619 |
+
meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
|
620 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
621 |
+
conv_attn_to_linear(new_checkpoint)
|
622 |
+
|
623 |
+
for i in range(num_up_blocks):
|
624 |
+
block_id = num_up_blocks - 1 - i
|
625 |
+
resnets = [
|
626 |
+
key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
|
627 |
+
]
|
628 |
+
|
629 |
+
if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
|
630 |
+
new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
|
631 |
+
f"decoder.up.{block_id}.upsample.conv.weight"
|
632 |
+
]
|
633 |
+
new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
|
634 |
+
f"decoder.up.{block_id}.upsample.conv.bias"
|
635 |
+
]
|
636 |
+
|
637 |
+
paths = renew_vae_resnet_paths(resnets)
|
638 |
+
meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
|
639 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
640 |
+
|
641 |
+
mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
|
642 |
+
num_mid_res_blocks = 2
|
643 |
+
for i in range(1, num_mid_res_blocks + 1):
|
644 |
+
resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
|
645 |
+
|
646 |
+
paths = renew_vae_resnet_paths(resnets)
|
647 |
+
meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
|
648 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
649 |
+
|
650 |
+
mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
|
651 |
+
paths = renew_vae_attention_paths(mid_attentions)
|
652 |
+
meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
|
653 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
654 |
+
conv_attn_to_linear(new_checkpoint)
|
655 |
+
return new_checkpoint
|
656 |
+
|
657 |
+
|
658 |
+
def convert_ldm_bert_checkpoint(checkpoint, config):
|
659 |
+
def _copy_attn_layer(hf_attn_layer, pt_attn_layer):
|
660 |
+
hf_attn_layer.q_proj.weight.data = pt_attn_layer.to_q.weight
|
661 |
+
hf_attn_layer.k_proj.weight.data = pt_attn_layer.to_k.weight
|
662 |
+
hf_attn_layer.v_proj.weight.data = pt_attn_layer.to_v.weight
|
663 |
+
|
664 |
+
hf_attn_layer.out_proj.weight = pt_attn_layer.to_out.weight
|
665 |
+
hf_attn_layer.out_proj.bias = pt_attn_layer.to_out.bias
|
666 |
+
|
667 |
+
def _copy_linear(hf_linear, pt_linear):
|
668 |
+
hf_linear.weight = pt_linear.weight
|
669 |
+
hf_linear.bias = pt_linear.bias
|
670 |
+
|
671 |
+
def _copy_layer(hf_layer, pt_layer):
|
672 |
+
# copy layer norms
|
673 |
+
_copy_linear(hf_layer.self_attn_layer_norm, pt_layer[0][0])
|
674 |
+
_copy_linear(hf_layer.final_layer_norm, pt_layer[1][0])
|
675 |
+
|
676 |
+
# copy attn
|
677 |
+
_copy_attn_layer(hf_layer.self_attn, pt_layer[0][1])
|
678 |
+
|
679 |
+
# copy MLP
|
680 |
+
pt_mlp = pt_layer[1][1]
|
681 |
+
_copy_linear(hf_layer.fc1, pt_mlp.net[0][0])
|
682 |
+
_copy_linear(hf_layer.fc2, pt_mlp.net[2])
|
683 |
+
|
684 |
+
def _copy_layers(hf_layers, pt_layers):
|
685 |
+
for i, hf_layer in enumerate(hf_layers):
|
686 |
+
if i != 0:
|
687 |
+
i += i
|
688 |
+
pt_layer = pt_layers[i : i + 2]
|
689 |
+
_copy_layer(hf_layer, pt_layer)
|
690 |
+
|
691 |
+
hf_model = LDMBertModel(config).eval()
|
692 |
+
|
693 |
+
# copy embeds
|
694 |
+
hf_model.model.embed_tokens.weight = checkpoint.transformer.token_emb.weight
|
695 |
+
hf_model.model.embed_positions.weight.data = checkpoint.transformer.pos_emb.emb.weight
|
696 |
+
|
697 |
+
# copy layer norm
|
698 |
+
_copy_linear(hf_model.model.layer_norm, checkpoint.transformer.norm)
|
699 |
+
|
700 |
+
# copy hidden layers
|
701 |
+
_copy_layers(hf_model.model.layers, checkpoint.transformer.attn_layers.layers)
|
702 |
+
|
703 |
+
_copy_linear(hf_model.to_logits, checkpoint.transformer.to_logits)
|
704 |
+
|
705 |
+
return hf_model
|
706 |
+
|
707 |
+
|
708 |
+
def convert_ldm_clip_checkpoint(checkpoint):
|
709 |
+
text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
|
710 |
+
keys = list(checkpoint.keys())
|
711 |
+
|
712 |
+
text_model_dict = {}
|
713 |
+
|
714 |
+
for key in keys:
|
715 |
+
if key.startswith("cond_stage_model.transformer"):
|
716 |
+
text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]
|
717 |
+
|
718 |
+
text_model.load_state_dict(text_model_dict)
|
719 |
+
|
720 |
+
return text_model
|
721 |
+
|
722 |
+
|
723 |
+
textenc_conversion_lst = [
|
724 |
+
("cond_stage_model.model.positional_embedding", "text_model.embeddings.position_embedding.weight"),
|
725 |
+
("cond_stage_model.model.token_embedding.weight", "text_model.embeddings.token_embedding.weight"),
|
726 |
+
("cond_stage_model.model.ln_final.weight", "text_model.final_layer_norm.weight"),
|
727 |
+
("cond_stage_model.model.ln_final.bias", "text_model.final_layer_norm.bias"),
|
728 |
+
]
|
729 |
+
textenc_conversion_map = {x[0]: x[1] for x in textenc_conversion_lst}
|
730 |
+
|
731 |
+
textenc_transformer_conversion_lst = [
|
732 |
+
# (stable-diffusion, HF Diffusers)
|
733 |
+
("resblocks.", "text_model.encoder.layers."),
|
734 |
+
("ln_1", "layer_norm1"),
|
735 |
+
("ln_2", "layer_norm2"),
|
736 |
+
(".c_fc.", ".fc1."),
|
737 |
+
(".c_proj.", ".fc2."),
|
738 |
+
(".attn", ".self_attn"),
|
739 |
+
("ln_final.", "transformer.text_model.final_layer_norm."),
|
740 |
+
("token_embedding.weight", "transformer.text_model.embeddings.token_embedding.weight"),
|
741 |
+
("positional_embedding", "transformer.text_model.embeddings.position_embedding.weight"),
|
742 |
+
]
|
743 |
+
protected = {re.escape(x[0]): x[1] for x in textenc_transformer_conversion_lst}
|
744 |
+
textenc_pattern = re.compile("|".join(protected.keys()))
|
745 |
+
|
746 |
+
|
747 |
+
def convert_paint_by_example_checkpoint(checkpoint):
|
748 |
+
config = CLIPVisionConfig.from_pretrained("openai/clip-vit-large-patch14")
|
749 |
+
model = PaintByExampleImageEncoder(config)
|
750 |
+
|
751 |
+
keys = list(checkpoint.keys())
|
752 |
+
|
753 |
+
text_model_dict = {}
|
754 |
+
|
755 |
+
for key in keys:
|
756 |
+
if key.startswith("cond_stage_model.transformer"):
|
757 |
+
text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]
|
758 |
+
|
759 |
+
# load clip vision
|
760 |
+
model.model.load_state_dict(text_model_dict)
|
761 |
+
|
762 |
+
# load mapper
|
763 |
+
keys_mapper = {
|
764 |
+
k[len("cond_stage_model.mapper.res") :]: v
|
765 |
+
for k, v in checkpoint.items()
|
766 |
+
if k.startswith("cond_stage_model.mapper")
|
767 |
+
}
|
768 |
+
|
769 |
+
MAPPING = {
|
770 |
+
"attn.c_qkv": ["attn1.to_q", "attn1.to_k", "attn1.to_v"],
|
771 |
+
"attn.c_proj": ["attn1.to_out.0"],
|
772 |
+
"ln_1": ["norm1"],
|
773 |
+
"ln_2": ["norm3"],
|
774 |
+
"mlp.c_fc": ["ff.net.0.proj"],
|
775 |
+
"mlp.c_proj": ["ff.net.2"],
|
776 |
+
}
|
777 |
+
|
778 |
+
mapped_weights = {}
|
779 |
+
for key, value in keys_mapper.items():
|
780 |
+
prefix = key[: len("blocks.i")]
|
781 |
+
suffix = key.split(prefix)[-1].split(".")[-1]
|
782 |
+
name = key.split(prefix)[-1].split(suffix)[0][1:-1]
|
783 |
+
mapped_names = MAPPING[name]
|
784 |
+
|
785 |
+
num_splits = len(mapped_names)
|
786 |
+
for i, mapped_name in enumerate(mapped_names):
|
787 |
+
new_name = ".".join([prefix, mapped_name, suffix])
|
788 |
+
shape = value.shape[0] // num_splits
|
789 |
+
mapped_weights[new_name] = value[i * shape : (i + 1) * shape]
|
790 |
+
|
791 |
+
model.mapper.load_state_dict(mapped_weights)
|
792 |
+
|
793 |
+
# load final layer norm
|
794 |
+
model.final_layer_norm.load_state_dict(
|
795 |
+
{
|
796 |
+
"bias": checkpoint["cond_stage_model.final_ln.bias"],
|
797 |
+
"weight": checkpoint["cond_stage_model.final_ln.weight"],
|
798 |
+
}
|
799 |
+
)
|
800 |
+
|
801 |
+
# load final proj
|
802 |
+
model.proj_out.load_state_dict(
|
803 |
+
{
|
804 |
+
"bias": checkpoint["proj_out.bias"],
|
805 |
+
"weight": checkpoint["proj_out.weight"],
|
806 |
+
}
|
807 |
+
)
|
808 |
+
|
809 |
+
# load uncond vector
|
810 |
+
model.uncond_vector.data = torch.nn.Parameter(checkpoint["learnable_vector"])
|
811 |
+
return model
|
812 |
+
|
813 |
+
|
814 |
+
def convert_open_clip_checkpoint(checkpoint):
|
815 |
+
text_model = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-2", subfolder="text_encoder")
|
816 |
+
|
817 |
+
keys = list(checkpoint.keys())
|
818 |
+
|
819 |
+
text_model_dict = {}
|
820 |
+
|
821 |
+
if "cond_stage_model.model.text_projection" in checkpoint:
|
822 |
+
d_model = int(checkpoint["cond_stage_model.model.text_projection"].shape[0])
|
823 |
+
else:
|
824 |
+
d_model = 1024
|
825 |
+
|
826 |
+
text_model_dict["text_model.embeddings.position_ids"] = text_model.text_model.embeddings.get_buffer("position_ids")
|
827 |
+
|
828 |
+
for key in keys:
|
829 |
+
if "resblocks.23" in key: # Diffusers drops the final layer and only uses the penultimate layer
|
830 |
+
continue
|
831 |
+
if key in textenc_conversion_map:
|
832 |
+
text_model_dict[textenc_conversion_map[key]] = checkpoint[key]
|
833 |
+
if key.startswith("cond_stage_model.model.transformer."):
|
834 |
+
new_key = key[len("cond_stage_model.model.transformer.") :]
|
835 |
+
if new_key.endswith(".in_proj_weight"):
|
836 |
+
new_key = new_key[: -len(".in_proj_weight")]
|
837 |
+
new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key)
|
838 |
+
text_model_dict[new_key + ".q_proj.weight"] = checkpoint[key][:d_model, :]
|
839 |
+
text_model_dict[new_key + ".k_proj.weight"] = checkpoint[key][d_model : d_model * 2, :]
|
840 |
+
text_model_dict[new_key + ".v_proj.weight"] = checkpoint[key][d_model * 2 :, :]
|
841 |
+
elif new_key.endswith(".in_proj_bias"):
|
842 |
+
new_key = new_key[: -len(".in_proj_bias")]
|
843 |
+
new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key)
|
844 |
+
text_model_dict[new_key + ".q_proj.bias"] = checkpoint[key][:d_model]
|
845 |
+
text_model_dict[new_key + ".k_proj.bias"] = checkpoint[key][d_model : d_model * 2]
|
846 |
+
text_model_dict[new_key + ".v_proj.bias"] = checkpoint[key][d_model * 2 :]
|
847 |
+
else:
|
848 |
+
new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key)
|
849 |
+
|
850 |
+
text_model_dict[new_key] = checkpoint[key]
|
851 |
+
|
852 |
+
text_model.load_state_dict(text_model_dict)
|
853 |
+
|
854 |
+
return text_model
|
855 |
+
|
856 |
+
|
857 |
+
def stable_unclip_image_encoder(original_config):
|
858 |
+
"""
|
859 |
+
Returns the image processor and clip image encoder for the img2img unclip pipeline.
|
860 |
+
|
861 |
+
We currently know of two types of stable unclip models which separately use the clip and the openclip image
|
862 |
+
encoders.
|
863 |
+
"""
|
864 |
+
|
865 |
+
image_embedder_config = original_config.model.params.embedder_config
|
866 |
+
|
867 |
+
sd_clip_image_embedder_class = image_embedder_config.target
|
868 |
+
sd_clip_image_embedder_class = sd_clip_image_embedder_class.split(".")[-1]
|
869 |
+
|
870 |
+
if sd_clip_image_embedder_class == "ClipImageEmbedder":
|
871 |
+
clip_model_name = image_embedder_config.params.model
|
872 |
+
|
873 |
+
if clip_model_name == "ViT-L/14":
|
874 |
+
feature_extractor = CLIPImageProcessor()
|
875 |
+
image_encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")
|
876 |
+
else:
|
877 |
+
raise NotImplementedError(f"Unknown CLIP checkpoint name in stable diffusion checkpoint {clip_model_name}")
|
878 |
+
|
879 |
+
elif sd_clip_image_embedder_class == "FrozenOpenCLIPImageEmbedder":
|
880 |
+
feature_extractor = CLIPImageProcessor()
|
881 |
+
image_encoder = CLIPVisionModelWithProjection.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
|
882 |
+
else:
|
883 |
+
raise NotImplementedError(
|
884 |
+
f"Unknown CLIP image embedder class in stable diffusion checkpoint {sd_clip_image_embedder_class}"
|
885 |
+
)
|
886 |
+
|
887 |
+
return feature_extractor, image_encoder
|
888 |
+
|
889 |
+
|
890 |
+
def stable_unclip_image_noising_components(
|
891 |
+
original_config, clip_stats_path: Optional[str] = None, device: Optional[str] = None
|
892 |
+
):
|
893 |
+
"""
|
894 |
+
Returns the noising components for the img2img and txt2img unclip pipelines.
|
895 |
+
|
896 |
+
Converts the stability noise augmentor into
|
897 |
+
1. a `StableUnCLIPImageNormalizer` for holding the CLIP stats
|
898 |
+
2. a `DDPMScheduler` for holding the noise schedule
|
899 |
+
|
900 |
+
If the noise augmentor config specifies a clip stats path, the `clip_stats_path` must be provided.
|
901 |
+
"""
|
902 |
+
noise_aug_config = original_config.model.params.noise_aug_config
|
903 |
+
noise_aug_class = noise_aug_config.target
|
904 |
+
noise_aug_class = noise_aug_class.split(".")[-1]
|
905 |
+
|
906 |
+
if noise_aug_class == "CLIPEmbeddingNoiseAugmentation":
|
907 |
+
noise_aug_config = noise_aug_config.params
|
908 |
+
embedding_dim = noise_aug_config.timestep_dim
|
909 |
+
max_noise_level = noise_aug_config.noise_schedule_config.timesteps
|
910 |
+
beta_schedule = noise_aug_config.noise_schedule_config.beta_schedule
|
911 |
+
|
912 |
+
image_normalizer = StableUnCLIPImageNormalizer(embedding_dim=embedding_dim)
|
913 |
+
image_noising_scheduler = DDPMScheduler(num_train_timesteps=max_noise_level, beta_schedule=beta_schedule)
|
914 |
+
|
915 |
+
if "clip_stats_path" in noise_aug_config:
|
916 |
+
if clip_stats_path is None:
|
917 |
+
raise ValueError("This stable unclip config requires a `clip_stats_path`")
|
918 |
+
|
919 |
+
clip_mean, clip_std = torch.load(clip_stats_path, map_location=device)
|
920 |
+
clip_mean = clip_mean[None, :]
|
921 |
+
clip_std = clip_std[None, :]
|
922 |
+
|
923 |
+
clip_stats_state_dict = {
|
924 |
+
"mean": clip_mean,
|
925 |
+
"std": clip_std,
|
926 |
+
}
|
927 |
+
|
928 |
+
image_normalizer.load_state_dict(clip_stats_state_dict)
|
929 |
+
else:
|
930 |
+
raise NotImplementedError(f"Unknown noise augmentor class: {noise_aug_class}")
|
931 |
+
|
932 |
+
return image_normalizer, image_noising_scheduler
|
933 |
+
|
934 |
+
|
935 |
+
def convert_controlnet_checkpoint(
|
936 |
+
checkpoint, original_config, checkpoint_path, image_size, upcast_attention, extract_ema
|
937 |
+
):
|
938 |
+
ctrlnet_config = create_unet_diffusers_config(original_config, image_size=image_size, controlnet=True)
|
939 |
+
ctrlnet_config["upcast_attention"] = upcast_attention
|
940 |
+
|
941 |
+
ctrlnet_config.pop("sample_size")
|
942 |
+
|
943 |
+
controlnet_model = ControlNetModel(**ctrlnet_config)
|
944 |
+
|
945 |
+
converted_ctrl_checkpoint = convert_ldm_unet_checkpoint(
|
946 |
+
checkpoint, ctrlnet_config, path=checkpoint_path, extract_ema=extract_ema, controlnet=True
|
947 |
+
)
|
948 |
+
|
949 |
+
controlnet_model.load_state_dict(converted_ctrl_checkpoint)
|
950 |
+
|
951 |
+
return controlnet_model
|
animatelcm/utils/convert_lora_safetensor_to_diffusers.py
ADDED
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2023, Haofan Wang, Qixun Wang, All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
""" Conversion script for the LoRA's safetensors checkpoints. """
|
17 |
+
|
18 |
+
import argparse
|
19 |
+
|
20 |
+
import torch
|
21 |
+
from safetensors.torch import load_file
|
22 |
+
|
23 |
+
from diffusers import StableDiffusionPipeline
|
24 |
+
|
25 |
+
|
26 |
+
def convert_motion_lora_ckpt_to_diffusers(pipeline, state_dict, alpha=1.0):
|
27 |
+
# directly update weight in diffusers model
|
28 |
+
for key in state_dict:
|
29 |
+
# only process lora down key
|
30 |
+
if "up." in key: continue
|
31 |
+
|
32 |
+
up_key = key.replace(".down.", ".up.")
|
33 |
+
model_key = key.replace("processor.", "").replace("_lora", "").replace("down.", "").replace("up.", "")
|
34 |
+
model_key = model_key.replace("to_out.", "to_out.0.")
|
35 |
+
layer_infos = model_key.split(".")[:-1]
|
36 |
+
|
37 |
+
curr_layer = pipeline.unet
|
38 |
+
while len(layer_infos) > 0:
|
39 |
+
temp_name = layer_infos.pop(0)
|
40 |
+
curr_layer = curr_layer.__getattr__(temp_name)
|
41 |
+
|
42 |
+
weight_down = state_dict[key]
|
43 |
+
weight_up = state_dict[up_key]
|
44 |
+
curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).to(curr_layer.weight.data.device)
|
45 |
+
|
46 |
+
return pipeline
|
47 |
+
|
48 |
+
|
49 |
+
|
50 |
+
def convert_lora(pipeline, state_dict, LORA_PREFIX_UNET="lora_unet", LORA_PREFIX_TEXT_ENCODER="lora_te", alpha=0.6):
|
51 |
+
# load base model
|
52 |
+
# pipeline = StableDiffusionPipeline.from_pretrained(base_model_path, torch_dtype=torch.float32)
|
53 |
+
|
54 |
+
# load LoRA weight from .safetensors
|
55 |
+
# state_dict = load_file(checkpoint_path)
|
56 |
+
|
57 |
+
visited = []
|
58 |
+
|
59 |
+
# directly update weight in diffusers model
|
60 |
+
for key in state_dict:
|
61 |
+
# it is suggested to print out the key, it usually will be something like below
|
62 |
+
# "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight"
|
63 |
+
|
64 |
+
# as we have set the alpha beforehand, so just skip
|
65 |
+
if ".alpha" in key or key in visited:
|
66 |
+
continue
|
67 |
+
|
68 |
+
if "text" in key:
|
69 |
+
layer_infos = key.split(".")[0].split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_")
|
70 |
+
curr_layer = pipeline.text_encoder
|
71 |
+
else:
|
72 |
+
layer_infos = key.split(".")[0].split(LORA_PREFIX_UNET + "_")[-1].split("_")
|
73 |
+
curr_layer = pipeline.unet
|
74 |
+
|
75 |
+
# find the target layer
|
76 |
+
temp_name = layer_infos.pop(0)
|
77 |
+
while len(layer_infos) > -1:
|
78 |
+
try:
|
79 |
+
curr_layer = curr_layer.__getattr__(temp_name)
|
80 |
+
if len(layer_infos) > 0:
|
81 |
+
temp_name = layer_infos.pop(0)
|
82 |
+
elif len(layer_infos) == 0:
|
83 |
+
break
|
84 |
+
except Exception:
|
85 |
+
if len(temp_name) > 0:
|
86 |
+
temp_name += "_" + layer_infos.pop(0)
|
87 |
+
else:
|
88 |
+
temp_name = layer_infos.pop(0)
|
89 |
+
|
90 |
+
pair_keys = []
|
91 |
+
if "lora_down" in key:
|
92 |
+
pair_keys.append(key.replace("lora_down", "lora_up"))
|
93 |
+
pair_keys.append(key)
|
94 |
+
else:
|
95 |
+
pair_keys.append(key)
|
96 |
+
pair_keys.append(key.replace("lora_up", "lora_down"))
|
97 |
+
|
98 |
+
# update weight
|
99 |
+
if len(state_dict[pair_keys[0]].shape) == 4:
|
100 |
+
weight_up = state_dict[pair_keys[0]].squeeze(3).squeeze(2).to(torch.float32)
|
101 |
+
weight_down = state_dict[pair_keys[1]].squeeze(3).squeeze(2).to(torch.float32)
|
102 |
+
curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3).to(curr_layer.weight.data.device)
|
103 |
+
else:
|
104 |
+
weight_up = state_dict[pair_keys[0]].to(torch.float32)
|
105 |
+
weight_down = state_dict[pair_keys[1]].to(torch.float32)
|
106 |
+
curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).to(curr_layer.weight.data.device)
|
107 |
+
|
108 |
+
# update visited list
|
109 |
+
for item in pair_keys:
|
110 |
+
visited.append(item)
|
111 |
+
|
112 |
+
return pipeline
|
113 |
+
|
114 |
+
|
115 |
+
if __name__ == "__main__":
|
116 |
+
parser = argparse.ArgumentParser()
|
117 |
+
|
118 |
+
parser.add_argument(
|
119 |
+
"--base_model_path", default=None, type=str, required=True, help="Path to the base model in diffusers format."
|
120 |
+
)
|
121 |
+
parser.add_argument(
|
122 |
+
"--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert."
|
123 |
+
)
|
124 |
+
parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
|
125 |
+
parser.add_argument(
|
126 |
+
"--lora_prefix_unet", default="lora_unet", type=str, help="The prefix of UNet weight in safetensors"
|
127 |
+
)
|
128 |
+
parser.add_argument(
|
129 |
+
"--lora_prefix_text_encoder",
|
130 |
+
default="lora_te",
|
131 |
+
type=str,
|
132 |
+
help="The prefix of text encoder weight in safetensors",
|
133 |
+
)
|
134 |
+
parser.add_argument("--alpha", default=0.75, type=float, help="The merging ratio in W = W0 + alpha * deltaW")
|
135 |
+
parser.add_argument(
|
136 |
+
"--to_safetensors", action="store_true", help="Whether to store pipeline in safetensors format or not."
|
137 |
+
)
|
138 |
+
parser.add_argument("--device", type=str, help="Device to use (e.g. cpu, cuda:0, cuda:1, etc.)")
|
139 |
+
|
140 |
+
args = parser.parse_args()
|
141 |
+
|
142 |
+
base_model_path = args.base_model_path
|
143 |
+
checkpoint_path = args.checkpoint_path
|
144 |
+
dump_path = args.dump_path
|
145 |
+
lora_prefix_unet = args.lora_prefix_unet
|
146 |
+
lora_prefix_text_encoder = args.lora_prefix_text_encoder
|
147 |
+
alpha = args.alpha
|
148 |
+
|
149 |
+
pipe = convert(base_model_path, checkpoint_path, lora_prefix_unet, lora_prefix_text_encoder, alpha)
|
150 |
+
|
151 |
+
pipe = pipe.to(args.device)
|
152 |
+
pipe.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors)
|
animatelcm/utils/lcm_utils.py
ADDED
@@ -0,0 +1,237 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
+
from safetensors import safe_open
|
4 |
+
|
5 |
+
|
6 |
+
def guidance_scale_embedding(w, embedding_dim=512, dtype=torch.float32):
|
7 |
+
"""
|
8 |
+
See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
|
9 |
+
|
10 |
+
Args:
|
11 |
+
timesteps (`torch.Tensor`):
|
12 |
+
generate embedding vectors at these timesteps
|
13 |
+
embedding_dim (`int`, *optional*, defaults to 512):
|
14 |
+
dimension of the embeddings to generate
|
15 |
+
dtype:
|
16 |
+
data type of the generated embeddings
|
17 |
+
|
18 |
+
Returns:
|
19 |
+
`torch.FloatTensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)`
|
20 |
+
"""
|
21 |
+
assert len(w.shape) == 1
|
22 |
+
w = w * 1000.0
|
23 |
+
|
24 |
+
half_dim = embedding_dim // 2
|
25 |
+
emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
|
26 |
+
emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
|
27 |
+
emb = w.to(dtype)[:, None] * emb[None, :]
|
28 |
+
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
|
29 |
+
if embedding_dim % 2 == 1: # zero pad
|
30 |
+
emb = torch.nn.functional.pad(emb, (0, 1))
|
31 |
+
assert emb.shape == (w.shape[0], embedding_dim)
|
32 |
+
return emb
|
33 |
+
|
34 |
+
|
35 |
+
def append_dims(x, target_dims):
|
36 |
+
"""Appends dimensions to the end of a tensor until it has target_dims dimensions."""
|
37 |
+
dims_to_append = target_dims - x.ndim
|
38 |
+
if dims_to_append < 0:
|
39 |
+
raise ValueError(
|
40 |
+
f"input has {x.ndim} dims but target_dims is {target_dims}, which is less")
|
41 |
+
return x[(...,) + (None,) * dims_to_append]
|
42 |
+
|
43 |
+
|
44 |
+
# From LCMScheduler.get_scalings_for_boundary_condition_discrete
|
45 |
+
def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=10.0):
|
46 |
+
c_skip = sigma_data**2 / ((timestep / 0.1) ** 2 + sigma_data**2)
|
47 |
+
c_out = (timestep / 0.1) / ((timestep / 0.1) ** 2 + sigma_data**2) ** 0.5
|
48 |
+
return c_skip, c_out
|
49 |
+
|
50 |
+
|
51 |
+
# Compare LCMScheduler.step, Step 4
|
52 |
+
def predicted_origin(model_output, timesteps, sample, prediction_type, alphas, sigmas):
|
53 |
+
if prediction_type == "epsilon":
|
54 |
+
sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
|
55 |
+
alphas = extract_into_tensor(alphas, timesteps, sample.shape)
|
56 |
+
pred_x_0 = (sample - sigmas * model_output) / alphas
|
57 |
+
elif prediction_type == "v_prediction":
|
58 |
+
sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
|
59 |
+
alphas = extract_into_tensor(alphas, timesteps, sample.shape)
|
60 |
+
pred_x_0 = alphas * sample - sigmas * model_output
|
61 |
+
else:
|
62 |
+
raise ValueError(
|
63 |
+
f"Prediction type {prediction_type} currently not supported.")
|
64 |
+
|
65 |
+
return pred_x_0
|
66 |
+
|
67 |
+
|
68 |
+
def scale_for_loss(timesteps, sample, prediction_type, alphas, sigmas):
|
69 |
+
if prediction_type == "epsilon":
|
70 |
+
sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
|
71 |
+
alphas = extract_into_tensor(alphas, timesteps, sample.shape)
|
72 |
+
sample = sample * alphas / sigmas
|
73 |
+
else:
|
74 |
+
raise ValueError(
|
75 |
+
f"Prediction type {prediction_type} currently not supported.")
|
76 |
+
|
77 |
+
return sample
|
78 |
+
|
79 |
+
|
80 |
+
def extract_into_tensor(a, t, x_shape):
|
81 |
+
b, *_ = t.shape
|
82 |
+
out = a.gather(-1, t)
|
83 |
+
return out.reshape(b, *((1,) * (len(x_shape) - 1)))
|
84 |
+
|
85 |
+
|
86 |
+
class DDIMSolver:
|
87 |
+
def __init__(self, alpha_cumprods, timesteps=1000, ddim_timesteps=50):
|
88 |
+
# DDIM sampling parameters
|
89 |
+
step_ratio = timesteps // ddim_timesteps
|
90 |
+
self.ddim_timesteps = (
|
91 |
+
np.arange(1, ddim_timesteps + 1) * step_ratio).round().astype(np.int64) - 1
|
92 |
+
# self.ddim_timesteps = (torch.linspace(100**2,1000**2,30)**0.5).round().numpy().astype(np.int64) - 1
|
93 |
+
self.ddim_timesteps_prev = np.asarray(
|
94 |
+
[0] + self.ddim_timesteps[:-1].tolist()
|
95 |
+
)
|
96 |
+
self.ddim_alpha_cumprods = alpha_cumprods[self.ddim_timesteps]
|
97 |
+
self.ddim_alpha_cumprods_prev = np.asarray(
|
98 |
+
[alpha_cumprods[0]] + alpha_cumprods[self.ddim_timesteps[:-1]].tolist()
|
99 |
+
)
|
100 |
+
self.ddim_alpha_cumprods = alpha_cumprods[self.ddim_timesteps]
|
101 |
+
self.ddim_alpha_cumprods_prev = np.asarray(
|
102 |
+
[alpha_cumprods[0]] + alpha_cumprods[self.ddim_timesteps[:-1]].tolist()
|
103 |
+
)
|
104 |
+
# convert to torch tensors
|
105 |
+
self.ddim_timesteps = torch.from_numpy(self.ddim_timesteps).long()
|
106 |
+
self.ddim_timesteps_prev = torch.from_numpy(
|
107 |
+
self.ddim_timesteps_prev).long()
|
108 |
+
self.ddim_alpha_cumprods = torch.from_numpy(self.ddim_alpha_cumprods)
|
109 |
+
self.ddim_alpha_cumprods_prev = torch.from_numpy(
|
110 |
+
self.ddim_alpha_cumprods_prev)
|
111 |
+
|
112 |
+
def to(self, device):
|
113 |
+
self.ddim_timesteps = self.ddim_timesteps.to(device)
|
114 |
+
self.ddim_timesteps_prev = self.ddim_timesteps_prev.to(device)
|
115 |
+
self.ddim_alpha_cumprods = self.ddim_alpha_cumprods.to(device)
|
116 |
+
self.ddim_alpha_cumprods_prev = self.ddim_alpha_cumprods_prev.to(
|
117 |
+
device)
|
118 |
+
return self
|
119 |
+
|
120 |
+
def ddim_step(self, pred_x0, pred_noise, timestep_index):
|
121 |
+
alpha_cumprod_prev = extract_into_tensor(
|
122 |
+
self.ddim_alpha_cumprods_prev, timestep_index, pred_x0.shape)
|
123 |
+
dir_xt = (1.0 - alpha_cumprod_prev).sqrt() * pred_noise
|
124 |
+
x_prev = alpha_cumprod_prev.sqrt() * pred_x0 + dir_xt
|
125 |
+
return x_prev
|
126 |
+
|
127 |
+
|
128 |
+
@torch.no_grad()
|
129 |
+
def update_ema(target_params, source_params, rate=0.99):
|
130 |
+
"""
|
131 |
+
Update target parameters to be closer to those of source parameters using
|
132 |
+
an exponential moving average.
|
133 |
+
|
134 |
+
:param target_params: the target parameter sequence.
|
135 |
+
:param source_params: the source parameter sequence.
|
136 |
+
:param rate: the EMA rate (closer to 1 means slower).
|
137 |
+
"""
|
138 |
+
for targ, src in zip(target_params, source_params):
|
139 |
+
targ.detach().mul_(rate).add_(src, alpha=1 - rate)
|
140 |
+
|
141 |
+
|
142 |
+
def convert_lcm_lora(unet, path, alpha=1.0):
|
143 |
+
|
144 |
+
if path.endswith(("ckpt",)):
|
145 |
+
state_dict = torch.load(path, map_location="cpu")
|
146 |
+
else:
|
147 |
+
state_dict = {}
|
148 |
+
with safe_open(path, framework="pt", device="cpu") as f:
|
149 |
+
for key in f.keys():
|
150 |
+
state_dict[key] = f.get_tensor(key)
|
151 |
+
|
152 |
+
num_alpha = 0
|
153 |
+
for key in state_dict.keys():
|
154 |
+
if "alpha" in key:
|
155 |
+
num_alpha += 1
|
156 |
+
|
157 |
+
lora_keys = [k for k in state_dict.keys(
|
158 |
+
) if k.endswith("lora_down.weight")]
|
159 |
+
|
160 |
+
updated_state_dict = {}
|
161 |
+
for key in lora_keys:
|
162 |
+
lora_name = key.split(".")[0]
|
163 |
+
|
164 |
+
if lora_name.startswith("lora_unet_"):
|
165 |
+
diffusers_name = key.replace("lora_unet_", "").replace("_", ".")
|
166 |
+
|
167 |
+
if "input.blocks" in diffusers_name:
|
168 |
+
diffusers_name = diffusers_name.replace(
|
169 |
+
"input.blocks", "down_blocks")
|
170 |
+
else:
|
171 |
+
diffusers_name = diffusers_name.replace(
|
172 |
+
"down.blocks", "down_blocks")
|
173 |
+
|
174 |
+
if "middle.block" in diffusers_name:
|
175 |
+
diffusers_name = diffusers_name.replace(
|
176 |
+
"middle.block", "mid_block")
|
177 |
+
else:
|
178 |
+
diffusers_name = diffusers_name.replace(
|
179 |
+
"mid.block", "mid_block")
|
180 |
+
if "output.blocks" in diffusers_name:
|
181 |
+
diffusers_name = diffusers_name.replace(
|
182 |
+
"output.blocks", "up_blocks")
|
183 |
+
else:
|
184 |
+
diffusers_name = diffusers_name.replace(
|
185 |
+
"up.blocks", "up_blocks")
|
186 |
+
|
187 |
+
diffusers_name = diffusers_name.replace(
|
188 |
+
"transformer.blocks", "transformer_blocks")
|
189 |
+
diffusers_name = diffusers_name.replace("to.q.lora", "to_q_lora")
|
190 |
+
diffusers_name = diffusers_name.replace("to.k.lora", "to_k_lora")
|
191 |
+
diffusers_name = diffusers_name.replace("to.v.lora", "to_v_lora")
|
192 |
+
diffusers_name = diffusers_name.replace(
|
193 |
+
"to.out.0.lora", "to_out_lora")
|
194 |
+
diffusers_name = diffusers_name.replace("proj.in", "proj_in")
|
195 |
+
diffusers_name = diffusers_name.replace("proj.out", "proj_out")
|
196 |
+
diffusers_name = diffusers_name.replace(
|
197 |
+
"time.emb.proj", "time_emb_proj")
|
198 |
+
diffusers_name = diffusers_name.replace(
|
199 |
+
"conv.shortcut", "conv_shortcut")
|
200 |
+
|
201 |
+
updated_state_dict[diffusers_name] = state_dict[key]
|
202 |
+
up_diffusers_name = diffusers_name.replace(".down.", ".up.")
|
203 |
+
up_key = key.replace("lora_down.weight", "lora_up.weight")
|
204 |
+
updated_state_dict[up_diffusers_name] = state_dict[up_key]
|
205 |
+
|
206 |
+
state_dict = updated_state_dict
|
207 |
+
|
208 |
+
num_lora = 0
|
209 |
+
for key in state_dict:
|
210 |
+
if "up." in key:
|
211 |
+
continue
|
212 |
+
up_key = key.replace(".down.", ".up.")
|
213 |
+
model_key = key.replace("processor.", "").replace("_lora", "").replace(
|
214 |
+
"down.", "").replace("up.", "").replace(".lora", "")
|
215 |
+
model_key = model_key.replace("to_out.", "to_out.0.")
|
216 |
+
layer_infos = model_key.split(".")[:-1]
|
217 |
+
|
218 |
+
curr_layer = unet
|
219 |
+
while len(layer_infos) > 0:
|
220 |
+
temp_name = layer_infos.pop(0)
|
221 |
+
curr_layer = curr_layer.__getattr__(temp_name)
|
222 |
+
|
223 |
+
weight_down = state_dict[key].to(
|
224 |
+
curr_layer.weight.data.device, curr_layer.weight.data.dtype)
|
225 |
+
weight_up = state_dict[up_key].to(
|
226 |
+
curr_layer.weight.data.device, curr_layer.weight.data.dtype)
|
227 |
+
|
228 |
+
if weight_up.ndim == 2:
|
229 |
+
curr_layer.weight.data += 1/8 * alpha * \
|
230 |
+
torch.mm(weight_up, weight_down)
|
231 |
+
else:
|
232 |
+
assert weight_up.ndim == 4
|
233 |
+
curr_layer.weight.data += 1/8 * alpha * torch.mm(weight_up.flatten(
|
234 |
+
start_dim=1), weight_down.flatten(start_dim=1)).reshape(curr_layer.weight.data.shape)
|
235 |
+
num_lora += 1
|
236 |
+
|
237 |
+
return unet
|
animatelcm/utils/util.py
ADDED
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import imageio
|
3 |
+
import numpy as np
|
4 |
+
from typing import Union
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torchvision
|
8 |
+
import torch.distributed as dist
|
9 |
+
|
10 |
+
from safetensors import safe_open
|
11 |
+
from tqdm import tqdm
|
12 |
+
from einops import rearrange
|
13 |
+
from animatelcm.utils.convert_from_ckpt import convert_ldm_unet_checkpoint, convert_ldm_clip_checkpoint, convert_ldm_vae_checkpoint
|
14 |
+
from animatelcm.utils.convert_lora_safetensor_to_diffusers import convert_lora, convert_motion_lora_ckpt_to_diffusers
|
15 |
+
|
16 |
+
|
17 |
+
def zero_rank_print(s):
|
18 |
+
if (not dist.is_initialized()) and (dist.is_initialized() and dist.get_rank() == 0): print("### " + s)
|
19 |
+
|
20 |
+
|
21 |
+
def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8):
|
22 |
+
videos = rearrange(videos, "b c t h w -> t b c h w")
|
23 |
+
outputs = []
|
24 |
+
for x in videos:
|
25 |
+
x = torchvision.utils.make_grid(x, nrow=n_rows)
|
26 |
+
x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
|
27 |
+
if rescale:
|
28 |
+
x = (x + 1.0) / 2.0 # -1,1 -> 0,1
|
29 |
+
x = (x * 255).numpy().astype(np.uint8)
|
30 |
+
outputs.append(x)
|
31 |
+
|
32 |
+
os.makedirs(os.path.dirname(path), exist_ok=True)
|
33 |
+
imageio.mimsave(path, outputs, fps=fps)
|
34 |
+
|
35 |
+
|
36 |
+
# DDIM Inversion
|
37 |
+
@torch.no_grad()
|
38 |
+
def init_prompt(prompt, pipeline):
|
39 |
+
uncond_input = pipeline.tokenizer(
|
40 |
+
[""], padding="max_length", max_length=pipeline.tokenizer.model_max_length,
|
41 |
+
return_tensors="pt"
|
42 |
+
)
|
43 |
+
uncond_embeddings = pipeline.text_encoder(uncond_input.input_ids.to(pipeline.device))[0]
|
44 |
+
text_input = pipeline.tokenizer(
|
45 |
+
[prompt],
|
46 |
+
padding="max_length",
|
47 |
+
max_length=pipeline.tokenizer.model_max_length,
|
48 |
+
truncation=True,
|
49 |
+
return_tensors="pt",
|
50 |
+
)
|
51 |
+
text_embeddings = pipeline.text_encoder(text_input.input_ids.to(pipeline.device))[0]
|
52 |
+
context = torch.cat([uncond_embeddings, text_embeddings])
|
53 |
+
|
54 |
+
return context
|
55 |
+
|
56 |
+
|
57 |
+
def next_step(model_output: Union[torch.FloatTensor, np.ndarray], timestep: int,
|
58 |
+
sample: Union[torch.FloatTensor, np.ndarray], ddim_scheduler):
|
59 |
+
timestep, next_timestep = min(
|
60 |
+
timestep - ddim_scheduler.config.num_train_timesteps // ddim_scheduler.num_inference_steps, 999), timestep
|
61 |
+
alpha_prod_t = ddim_scheduler.alphas_cumprod[timestep] if timestep >= 0 else ddim_scheduler.final_alpha_cumprod
|
62 |
+
alpha_prod_t_next = ddim_scheduler.alphas_cumprod[next_timestep]
|
63 |
+
beta_prod_t = 1 - alpha_prod_t
|
64 |
+
next_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
|
65 |
+
next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output
|
66 |
+
next_sample = alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction
|
67 |
+
return next_sample
|
68 |
+
|
69 |
+
|
70 |
+
def get_noise_pred_single(latents, t, context, unet):
|
71 |
+
noise_pred = unet(latents, t, encoder_hidden_states=context)["sample"]
|
72 |
+
return noise_pred
|
73 |
+
|
74 |
+
|
75 |
+
@torch.no_grad()
|
76 |
+
def ddim_loop(pipeline, ddim_scheduler, latent, num_inv_steps, prompt):
|
77 |
+
context = init_prompt(prompt, pipeline)
|
78 |
+
uncond_embeddings, cond_embeddings = context.chunk(2)
|
79 |
+
all_latent = [latent]
|
80 |
+
latent = latent.clone().detach()
|
81 |
+
for i in tqdm(range(num_inv_steps)):
|
82 |
+
t = ddim_scheduler.timesteps[len(ddim_scheduler.timesteps) - i - 1]
|
83 |
+
noise_pred = get_noise_pred_single(latent, t, cond_embeddings, pipeline.unet)
|
84 |
+
latent = next_step(noise_pred, t, latent, ddim_scheduler)
|
85 |
+
all_latent.append(latent)
|
86 |
+
return all_latent
|
87 |
+
|
88 |
+
|
89 |
+
@torch.no_grad()
|
90 |
+
def ddim_inversion(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt=""):
|
91 |
+
ddim_latents = ddim_loop(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt)
|
92 |
+
return ddim_latents
|
93 |
+
|
94 |
+
def load_weights(
|
95 |
+
animation_pipeline,
|
96 |
+
motion_module_path = "",
|
97 |
+
motion_module_lora_configs = [],
|
98 |
+
dreambooth_model_path = "",
|
99 |
+
lora_model_path = "",
|
100 |
+
lora_alpha = 0.8,
|
101 |
+
):
|
102 |
+
unet_state_dict = {}
|
103 |
+
if motion_module_path != "":
|
104 |
+
print(f"load motion module from {motion_module_path}")
|
105 |
+
motion_module_state_dict = torch.load(motion_module_path, map_location="cpu")
|
106 |
+
motion_module_state_dict = motion_module_state_dict["state_dict"] if "state_dict" in motion_module_state_dict else motion_module_state_dict
|
107 |
+
unet_state_dict.update({name: param for name, param in motion_module_state_dict.items() if "motion_modules." in name})
|
108 |
+
|
109 |
+
missing, unexpected = animation_pipeline.unet.load_state_dict(unet_state_dict, strict=False)
|
110 |
+
assert len(unexpected) == 0
|
111 |
+
del unet_state_dict
|
112 |
+
|
113 |
+
if dreambooth_model_path != "":
|
114 |
+
print(f"load dreambooth model from {dreambooth_model_path}")
|
115 |
+
if dreambooth_model_path.endswith(".safetensors"):
|
116 |
+
dreambooth_state_dict = {}
|
117 |
+
with safe_open(dreambooth_model_path, framework="pt", device="cpu") as f:
|
118 |
+
for key in f.keys():
|
119 |
+
dreambooth_state_dict[key] = f.get_tensor(key)
|
120 |
+
elif dreambooth_model_path.endswith(".ckpt"):
|
121 |
+
dreambooth_state_dict = torch.load(dreambooth_model_path, map_location="cpu")
|
122 |
+
|
123 |
+
converted_vae_checkpoint = convert_ldm_vae_checkpoint(dreambooth_state_dict, animation_pipeline.vae.config)
|
124 |
+
animation_pipeline.vae.load_state_dict(converted_vae_checkpoint)
|
125 |
+
|
126 |
+
converted_unet_checkpoint = convert_ldm_unet_checkpoint(dreambooth_state_dict, animation_pipeline.unet.config)
|
127 |
+
animation_pipeline.unet.load_state_dict(converted_unet_checkpoint, strict=False)
|
128 |
+
|
129 |
+
animation_pipeline.text_encoder = convert_ldm_clip_checkpoint(dreambooth_state_dict)
|
130 |
+
del dreambooth_state_dict
|
131 |
+
|
132 |
+
if lora_model_path != "":
|
133 |
+
print(f"load lora model from {lora_model_path}")
|
134 |
+
assert lora_model_path.endswith(".safetensors")
|
135 |
+
lora_state_dict = {}
|
136 |
+
with safe_open(lora_model_path, framework="pt", device="cpu") as f:
|
137 |
+
for key in f.keys():
|
138 |
+
lora_state_dict[key] = f.get_tensor(key)
|
139 |
+
|
140 |
+
animation_pipeline = convert_lora(animation_pipeline, lora_state_dict, alpha=lora_alpha)
|
141 |
+
del lora_state_dict
|
142 |
+
|
143 |
+
|
144 |
+
for motion_module_lora_config in motion_module_lora_configs:
|
145 |
+
path, alpha = motion_module_lora_config["path"], motion_module_lora_config["alpha"]
|
146 |
+
print(f"load motion LoRA from {path}")
|
147 |
+
|
148 |
+
motion_lora_state_dict = torch.load(path, map_location="cpu")
|
149 |
+
motion_lora_state_dict = motion_lora_state_dict["state_dict"] if "state_dict" in motion_lora_state_dict else motion_lora_state_dict
|
150 |
+
|
151 |
+
animation_pipeline = convert_motion_lora_ckpt_to_diffusers(animation_pipeline, motion_lora_state_dict, alpha)
|
152 |
+
|
153 |
+
return animation_pipeline
|
app.py
ADDED
@@ -0,0 +1,392 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import os
|
3 |
+
import json
|
4 |
+
import torch
|
5 |
+
import random
|
6 |
+
|
7 |
+
import gradio as gr
|
8 |
+
from glob import glob
|
9 |
+
from omegaconf import OmegaConf
|
10 |
+
from datetime import datetime
|
11 |
+
from safetensors import safe_open
|
12 |
+
|
13 |
+
from diffusers import AutoencoderKL
|
14 |
+
from diffusers.utils.import_utils import is_xformers_available
|
15 |
+
from transformers import CLIPTextModel, CLIPTokenizer
|
16 |
+
|
17 |
+
from animatelcm.scheduler.lcm_scheduler import LCMScheduler
|
18 |
+
from animatelcm.models.unet import UNet3DConditionModel
|
19 |
+
from animatelcm.pipelines.pipeline_animation import AnimationPipeline
|
20 |
+
from animatelcm.utils.util import save_videos_grid
|
21 |
+
from animatelcm.utils.convert_from_ckpt import convert_ldm_unet_checkpoint, convert_ldm_clip_checkpoint, convert_ldm_vae_checkpoint
|
22 |
+
from animatelcm.utils.convert_lora_safetensor_to_diffusers import convert_lora
|
23 |
+
from animatelcm.utils.lcm_utils import convert_lcm_lora
|
24 |
+
import copy
|
25 |
+
|
26 |
+
sample_idx = 0
|
27 |
+
scheduler_dict = {
|
28 |
+
"LCM": LCMScheduler,
|
29 |
+
}
|
30 |
+
|
31 |
+
css = """
|
32 |
+
.toolbutton {
|
33 |
+
margin-buttom: 0em 0em 0em 0em;
|
34 |
+
max-width: 2.5em;
|
35 |
+
min-width: 2.5em !important;
|
36 |
+
height: 2.5em;
|
37 |
+
}
|
38 |
+
"""
|
39 |
+
|
40 |
+
|
41 |
+
class AnimateController:
|
42 |
+
def __init__(self):
|
43 |
+
|
44 |
+
# config dirs
|
45 |
+
self.basedir = os.getcwd()
|
46 |
+
self.stable_diffusion_dir = os.path.join(
|
47 |
+
self.basedir, "models", "StableDiffusion")
|
48 |
+
self.motion_module_dir = os.path.join(
|
49 |
+
self.basedir, "models", "Motion_Module")
|
50 |
+
self.personalized_model_dir = os.path.join(
|
51 |
+
self.basedir, "models", "DreamBooth_LoRA")
|
52 |
+
self.savedir = os.path.join(
|
53 |
+
self.basedir, "samples", datetime.now().strftime("Gradio-%Y-%m-%dT%H-%M-%S"))
|
54 |
+
self.savedir_sample = os.path.join(self.savedir, "sample")
|
55 |
+
self.lcm_lora_path = "models/LCM_LoRA/sd15_t2v_beta_lora.safetensors"
|
56 |
+
os.makedirs(self.savedir, exist_ok=True)
|
57 |
+
|
58 |
+
self.stable_diffusion_list = []
|
59 |
+
self.motion_module_list = []
|
60 |
+
self.personalized_model_list = []
|
61 |
+
|
62 |
+
self.refresh_stable_diffusion()
|
63 |
+
self.refresh_motion_module()
|
64 |
+
self.refresh_personalized_model()
|
65 |
+
|
66 |
+
# config models
|
67 |
+
self.tokenizer = None
|
68 |
+
self.text_encoder = None
|
69 |
+
self.vae = None
|
70 |
+
self.unet = None
|
71 |
+
self.pipeline = None
|
72 |
+
self.lora_model_state_dict = {}
|
73 |
+
|
74 |
+
self.inference_config = OmegaConf.load("configs/inference.yaml")
|
75 |
+
|
76 |
+
def refresh_stable_diffusion(self):
|
77 |
+
self.stable_diffusion_list = glob(
|
78 |
+
os.path.join(self.stable_diffusion_dir, "*/"))
|
79 |
+
|
80 |
+
def refresh_motion_module(self):
|
81 |
+
motion_module_list = glob(os.path.join(
|
82 |
+
self.motion_module_dir, "*.ckpt"))
|
83 |
+
self.motion_module_list = [
|
84 |
+
os.path.basename(p) for p in motion_module_list]
|
85 |
+
|
86 |
+
def refresh_personalized_model(self):
|
87 |
+
personalized_model_list = glob(os.path.join(
|
88 |
+
self.personalized_model_dir, "*.safetensors"))
|
89 |
+
self.personalized_model_list = [
|
90 |
+
os.path.basename(p) for p in personalized_model_list]
|
91 |
+
|
92 |
+
def update_stable_diffusion(self, stable_diffusion_dropdown):
|
93 |
+
stable_diffusion_dropdown = os.path.join(self.stable_diffusion_dir,stable_diffusion_dropdown)
|
94 |
+
self.tokenizer = CLIPTokenizer.from_pretrained(
|
95 |
+
stable_diffusion_dropdown, subfolder="tokenizer")
|
96 |
+
self.text_encoder = CLIPTextModel.from_pretrained(
|
97 |
+
stable_diffusion_dropdown, subfolder="text_encoder").cuda()
|
98 |
+
self.vae = AutoencoderKL.from_pretrained(
|
99 |
+
stable_diffusion_dropdown, subfolder="vae").cuda()
|
100 |
+
self.unet = UNet3DConditionModel.from_pretrained_2d(
|
101 |
+
stable_diffusion_dropdown, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(self.inference_config.unet_additional_kwargs)).cuda()
|
102 |
+
return gr.Dropdown.update()
|
103 |
+
|
104 |
+
def update_motion_module(self, motion_module_dropdown):
|
105 |
+
if self.unet is None:
|
106 |
+
gr.Info(f"Please select a pretrained model path.")
|
107 |
+
return gr.Dropdown.update(value=None)
|
108 |
+
else:
|
109 |
+
motion_module_dropdown = os.path.join(
|
110 |
+
self.motion_module_dir, motion_module_dropdown)
|
111 |
+
motion_module_state_dict = torch.load(
|
112 |
+
motion_module_dropdown, map_location="cpu")
|
113 |
+
missing, unexpected = self.unet.load_state_dict(
|
114 |
+
motion_module_state_dict, strict=False)
|
115 |
+
assert len(unexpected) == 0
|
116 |
+
return gr.Dropdown.update()
|
117 |
+
|
118 |
+
def update_base_model(self, base_model_dropdown):
|
119 |
+
if self.unet is None:
|
120 |
+
gr.Info(f"Please select a pretrained model path.")
|
121 |
+
return gr.Dropdown.update(value=None)
|
122 |
+
else:
|
123 |
+
base_model_dropdown = os.path.join(
|
124 |
+
self.personalized_model_dir, base_model_dropdown)
|
125 |
+
base_model_state_dict = {}
|
126 |
+
with safe_open(base_model_dropdown, framework="pt", device="cpu") as f:
|
127 |
+
for key in f.keys():
|
128 |
+
base_model_state_dict[key] = f.get_tensor(key)
|
129 |
+
|
130 |
+
converted_vae_checkpoint = convert_ldm_vae_checkpoint(
|
131 |
+
base_model_state_dict, self.vae.config)
|
132 |
+
self.vae.load_state_dict(converted_vae_checkpoint)
|
133 |
+
|
134 |
+
converted_unet_checkpoint = convert_ldm_unet_checkpoint(
|
135 |
+
base_model_state_dict, self.unet.config)
|
136 |
+
self.unet.load_state_dict(converted_unet_checkpoint, strict=False)
|
137 |
+
|
138 |
+
# self.text_encoder = convert_ldm_clip_checkpoint(base_model_state_dict)
|
139 |
+
return gr.Dropdown.update()
|
140 |
+
|
141 |
+
def update_lora_model(self, lora_model_dropdown):
|
142 |
+
lora_model_dropdown = os.path.join(
|
143 |
+
self.personalized_model_dir, lora_model_dropdown)
|
144 |
+
self.lora_model_state_dict = {}
|
145 |
+
if lora_model_dropdown == "none":
|
146 |
+
pass
|
147 |
+
else:
|
148 |
+
with safe_open(lora_model_dropdown, framework="pt", device="cpu") as f:
|
149 |
+
for key in f.keys():
|
150 |
+
self.lora_model_state_dict[key] = f.get_tensor(key)
|
151 |
+
return gr.Dropdown.update()
|
152 |
+
|
153 |
+
def animate(
|
154 |
+
self,
|
155 |
+
stable_diffusion_dropdown,
|
156 |
+
motion_module_dropdown,
|
157 |
+
base_model_dropdown,
|
158 |
+
lora_alpha_slider,
|
159 |
+
spatial_lora_slider,
|
160 |
+
prompt_textbox,
|
161 |
+
negative_prompt_textbox,
|
162 |
+
sampler_dropdown,
|
163 |
+
sample_step_slider,
|
164 |
+
width_slider,
|
165 |
+
length_slider,
|
166 |
+
height_slider,
|
167 |
+
cfg_scale_slider,
|
168 |
+
seed_textbox
|
169 |
+
):
|
170 |
+
if self.unet is None:
|
171 |
+
raise gr.Error(f"Please select a pretrained model path.")
|
172 |
+
if motion_module_dropdown == "":
|
173 |
+
raise gr.Error(f"Please select a motion module.")
|
174 |
+
if base_model_dropdown == "":
|
175 |
+
raise gr.Error(f"Please select a base DreamBooth model.")
|
176 |
+
|
177 |
+
if is_xformers_available():
|
178 |
+
self.unet.enable_xformers_memory_efficient_attention()
|
179 |
+
|
180 |
+
pipeline = AnimationPipeline(
|
181 |
+
vae=self.vae, text_encoder=self.text_encoder, tokenizer=self.tokenizer, unet=self.unet,
|
182 |
+
scheduler=scheduler_dict[sampler_dropdown](
|
183 |
+
**OmegaConf.to_container(self.inference_config.noise_scheduler_kwargs))
|
184 |
+
).to("cuda")
|
185 |
+
|
186 |
+
if self.lora_model_state_dict != {}:
|
187 |
+
pipeline = convert_lora(
|
188 |
+
pipeline, self.lora_model_state_dict, alpha=lora_alpha_slider)
|
189 |
+
|
190 |
+
pipeline.unet = convert_lcm_lora(copy.deepcopy(
|
191 |
+
self.unet), self.lcm_lora_path, spatial_lora_slider)
|
192 |
+
|
193 |
+
pipeline.to("cuda")
|
194 |
+
|
195 |
+
if seed_textbox != -1 and seed_textbox != "":
|
196 |
+
torch.manual_seed(int(seed_textbox))
|
197 |
+
else:
|
198 |
+
torch.seed()
|
199 |
+
seed = torch.initial_seed()
|
200 |
+
|
201 |
+
sample = pipeline(
|
202 |
+
prompt_textbox,
|
203 |
+
negative_prompt=negative_prompt_textbox,
|
204 |
+
num_inference_steps=sample_step_slider,
|
205 |
+
guidance_scale=cfg_scale_slider,
|
206 |
+
width=width_slider,
|
207 |
+
height=height_slider,
|
208 |
+
video_length=length_slider,
|
209 |
+
).videos
|
210 |
+
|
211 |
+
save_sample_path = os.path.join(
|
212 |
+
self.savedir_sample, f"{sample_idx}.mp4")
|
213 |
+
save_videos_grid(sample, save_sample_path)
|
214 |
+
|
215 |
+
sample_config = {
|
216 |
+
"prompt": prompt_textbox,
|
217 |
+
"n_prompt": negative_prompt_textbox,
|
218 |
+
"sampler": sampler_dropdown,
|
219 |
+
"num_inference_steps": sample_step_slider,
|
220 |
+
"guidance_scale": cfg_scale_slider,
|
221 |
+
"width": width_slider,
|
222 |
+
"height": height_slider,
|
223 |
+
"video_length": length_slider,
|
224 |
+
"seed": seed
|
225 |
+
}
|
226 |
+
json_str = json.dumps(sample_config, indent=4)
|
227 |
+
with open(os.path.join(self.savedir, "logs.json"), "a") as f:
|
228 |
+
f.write(json_str)
|
229 |
+
f.write("\n\n")
|
230 |
+
return gr.Video.update(value=save_sample_path)
|
231 |
+
|
232 |
+
|
233 |
+
controller = AnimateController()
|
234 |
+
|
235 |
+
|
236 |
+
def ui():
|
237 |
+
with gr.Blocks(css=css) as demo:
|
238 |
+
gr.Markdown(
|
239 |
+
"""
|
240 |
+
# [AnimateLCM: Accelerating the Animation of Personalized Diffusion Models and Adapters with Decoupled Consistency Learning](https://arxiv.org/abs/2402.00769)
|
241 |
+
Fu-Yun Wang, Zhaoyang Huang (*Corresponding Author), Xiaoyu Shi, Weikang Bian, Guanglu Song, Yu Liu, Hongsheng Li (*Corresponding Author)<br>
|
242 |
+
[arXiv Report](https://arxiv.org/abs/2402.00769) | [Project Page](https://animatelcm.github.io/) | [Github](https://github.com/G-U-N/AnimateLCM)
|
243 |
+
"""
|
244 |
+
)
|
245 |
+
with gr.Column(variant="panel"):
|
246 |
+
gr.Markdown(
|
247 |
+
"""
|
248 |
+
### 1. Model checkpoints (select pretrained model path first).
|
249 |
+
"""
|
250 |
+
)
|
251 |
+
with gr.Row():
|
252 |
+
stable_diffusion_dropdown = gr.Dropdown(
|
253 |
+
label="Pretrained Model Path",
|
254 |
+
choices=controller.stable_diffusion_list,
|
255 |
+
interactive=True,
|
256 |
+
)
|
257 |
+
stable_diffusion_dropdown.change(fn=controller.update_stable_diffusion, inputs=[
|
258 |
+
stable_diffusion_dropdown], outputs=[stable_diffusion_dropdown])
|
259 |
+
|
260 |
+
stable_diffusion_refresh_button = gr.Button(
|
261 |
+
value="\U0001F503", elem_classes="toolbutton")
|
262 |
+
|
263 |
+
def update_stable_diffusion():
|
264 |
+
controller.refresh_stable_diffusion()
|
265 |
+
return gr.Dropdown.update(choices=controller.stable_diffusion_list)
|
266 |
+
stable_diffusion_refresh_button.click(
|
267 |
+
fn=update_stable_diffusion, inputs=[], outputs=[stable_diffusion_dropdown])
|
268 |
+
|
269 |
+
with gr.Row():
|
270 |
+
motion_module_dropdown = gr.Dropdown(
|
271 |
+
label="Select motion module",
|
272 |
+
choices=controller.motion_module_list,
|
273 |
+
interactive=True,
|
274 |
+
)
|
275 |
+
motion_module_dropdown.change(fn=controller.update_motion_module, inputs=[
|
276 |
+
motion_module_dropdown], outputs=[motion_module_dropdown])
|
277 |
+
|
278 |
+
motion_module_refresh_button = gr.Button(
|
279 |
+
value="\U0001F503", elem_classes="toolbutton")
|
280 |
+
|
281 |
+
def update_motion_module():
|
282 |
+
controller.refresh_motion_module()
|
283 |
+
return gr.Dropdown.update(choices=controller.motion_module_list)
|
284 |
+
motion_module_refresh_button.click(
|
285 |
+
fn=update_motion_module, inputs=[], outputs=[motion_module_dropdown])
|
286 |
+
|
287 |
+
base_model_dropdown = gr.Dropdown(
|
288 |
+
label="Select base Dreambooth model (required)",
|
289 |
+
choices=controller.personalized_model_list,
|
290 |
+
interactive=True,
|
291 |
+
)
|
292 |
+
base_model_dropdown.change(fn=controller.update_base_model, inputs=[
|
293 |
+
base_model_dropdown], outputs=[base_model_dropdown])
|
294 |
+
|
295 |
+
lora_model_dropdown = gr.Dropdown(
|
296 |
+
label="Select LoRA model (optional)",
|
297 |
+
choices=["none"]
|
298 |
+
value="none",
|
299 |
+
interactive=True,
|
300 |
+
)
|
301 |
+
lora_model_dropdown.change(fn=controller.update_lora_model, inputs=[
|
302 |
+
lora_model_dropdown], outputs=[lora_model_dropdown])
|
303 |
+
|
304 |
+
lora_alpha_slider = gr.Slider(
|
305 |
+
label="LoRA alpha", value=0.8, minimum=0, maximum=2, interactive=True)
|
306 |
+
spatial_lora_slider = gr.Slider(
|
307 |
+
label="LCM LoRA alpha", value=0.8, minimum=0.0, maximum=1.0, interactive=True)
|
308 |
+
|
309 |
+
personalized_refresh_button = gr.Button(
|
310 |
+
value="\U0001F503", elem_classes="toolbutton")
|
311 |
+
|
312 |
+
def update_personalized_model():
|
313 |
+
controller.refresh_personalized_model()
|
314 |
+
return [
|
315 |
+
gr.Dropdown.update(
|
316 |
+
choices=controller.personalized_model_list),
|
317 |
+
gr.Dropdown.update(
|
318 |
+
choices=["none"] + controller.personalized_model_list)
|
319 |
+
]
|
320 |
+
personalized_refresh_button.click(fn=update_personalized_model, inputs=[], outputs=[
|
321 |
+
base_model_dropdown, lora_model_dropdown])
|
322 |
+
|
323 |
+
with gr.Column(variant="panel"):
|
324 |
+
gr.Markdown(
|
325 |
+
"""
|
326 |
+
### 2. Configs for AnimateLCM.
|
327 |
+
"""
|
328 |
+
)
|
329 |
+
|
330 |
+
prompt_textbox = gr.Textbox(label="Prompt", lines=2)
|
331 |
+
negative_prompt_textbox = gr.Textbox(
|
332 |
+
label="Negative prompt", lines=2)
|
333 |
+
|
334 |
+
with gr.Row().style(equal_height=False):
|
335 |
+
with gr.Column():
|
336 |
+
with gr.Row():
|
337 |
+
sampler_dropdown = gr.Dropdown(label="Sampling method", choices=list(
|
338 |
+
scheduler_dict.keys()), value=list(scheduler_dict.keys())[0])
|
339 |
+
sample_step_slider = gr.Slider(
|
340 |
+
label="Sampling steps", value=4, minimum=1, maximum=25, step=1)
|
341 |
+
|
342 |
+
width_slider = gr.Slider(
|
343 |
+
label="Width", value=512, minimum=256, maximum=1024, step=64)
|
344 |
+
height_slider = gr.Slider(
|
345 |
+
label="Height", value=512, minimum=256, maximum=1024, step=64)
|
346 |
+
length_slider = gr.Slider(
|
347 |
+
label="Animation length", value=16, minimum=12, maximum=20, step=1)
|
348 |
+
cfg_scale_slider = gr.Slider(
|
349 |
+
label="CFG Scale", value=1, minimum=1, maximum=2)
|
350 |
+
|
351 |
+
with gr.Row():
|
352 |
+
seed_textbox = gr.Textbox(label="Seed", value=-1)
|
353 |
+
seed_button = gr.Button(
|
354 |
+
value="\U0001F3B2", elem_classes="toolbutton")
|
355 |
+
seed_button.click(fn=lambda: gr.Textbox.update(
|
356 |
+
value=random.randint(1, 1e8)), inputs=[], outputs=[seed_textbox])
|
357 |
+
|
358 |
+
generate_button = gr.Button(
|
359 |
+
value="Generate", variant='primary')
|
360 |
+
|
361 |
+
result_video = gr.Video(
|
362 |
+
label="Generated Animation", interactive=False)
|
363 |
+
|
364 |
+
generate_button.click(
|
365 |
+
fn=controller.animate,
|
366 |
+
inputs=[
|
367 |
+
stable_diffusion_dropdown,
|
368 |
+
motion_module_dropdown,
|
369 |
+
base_model_dropdown,
|
370 |
+
lora_alpha_slider,
|
371 |
+
spatial_lora_slider,
|
372 |
+
prompt_textbox,
|
373 |
+
negative_prompt_textbox,
|
374 |
+
sampler_dropdown,
|
375 |
+
sample_step_slider,
|
376 |
+
width_slider,
|
377 |
+
length_slider,
|
378 |
+
height_slider,
|
379 |
+
cfg_scale_slider,
|
380 |
+
seed_textbox,
|
381 |
+
],
|
382 |
+
outputs=[result_video]
|
383 |
+
)
|
384 |
+
|
385 |
+
return demo
|
386 |
+
|
387 |
+
|
388 |
+
if __name__ == "__main__":
|
389 |
+
demo = ui()
|
390 |
+
# gr.close_all()
|
391 |
+
demo.queue(concurrency_count=3, max_size=20)
|
392 |
+
demo.launch(share=True, server_name="127.0.0.1")
|
models/.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
models/DreamBooth_LoRA/cartoon2d.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:cbfba64e662370f59d4aa2aa69bf16749fce93846ccce20506aee5df01169859
|
3 |
+
size 4244124028
|
models/DreamBooth_LoRA/cartoon3d.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:a6b4c0392d7486bfa4fd1a31c7b7d2679f743f8ea8d9f219c82b5c33db31ddb9
|
3 |
+
size 2132625644
|
models/DreamBooth_LoRA/realistic1.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:c0d1994c73d784a17a5b335ae8bda02dcc8dd2fc5f5dbf55169d5aab385e53f2
|
3 |
+
size 2132650523
|
models/DreamBooth_LoRA/realistic2.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:a38fa861a24f4f4c6e0f68289101e645dd9ca1e93e1049cc8a4f2a77513fad52
|
3 |
+
size 2400040290
|
models/LCM_LoRA/Put LCMLoRA checkpoints here.txt
ADDED
File without changes
|
models/LCM_LoRA/sd15_t2v_beta_lora.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:8f90d840e075ff588a58e22c6586e2ae9a6f7922996ee6649a7f01072333afe4
|
3 |
+
size 134621556
|
models/Motion_Module/Put motion module checkpoints here.txt
ADDED
File without changes
|
models/Motion_Module/sd15_t2v_beta_motion.ckpt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:b46c3de62e5696af72c4056e3cdcbea12fbc19581c0aad7b6f2b027851148f5f
|
3 |
+
size 1813041929
|
models/StableDiffusion/Put diffusers stable-diffusion-v1-5 repo here.txt
ADDED
File without changes
|
models/StableDiffusion/stable-diffusion-v1-5/.gitattributes
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
6 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
7 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
8 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
9 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
10 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
11 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
12 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
13 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
14 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
15 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
16 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
17 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
18 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
19 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
20 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
21 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
22 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
23 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
24 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
25 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
26 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
27 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
28 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
29 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
30 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
31 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
32 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
33 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
34 |
+
v1-5-pruned-emaonly.ckpt filter=lfs diff=lfs merge=lfs -text
|
35 |
+
v1-5-pruned.ckpt filter=lfs diff=lfs merge=lfs -text
|
models/StableDiffusion/stable-diffusion-v1-5/README.md
ADDED
@@ -0,0 +1,207 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
license: creativeml-openrail-m
|
3 |
+
tags:
|
4 |
+
- stable-diffusion
|
5 |
+
- stable-diffusion-diffusers
|
6 |
+
- text-to-image
|
7 |
+
inference: true
|
8 |
+
extra_gated_prompt: |-
|
9 |
+
This model is open access and available to all, with a CreativeML OpenRAIL-M license further specifying rights and usage.
|
10 |
+
The CreativeML OpenRAIL License specifies:
|
11 |
+
|
12 |
+
1. You can't use the model to deliberately produce nor share illegal or harmful outputs or content
|
13 |
+
2. CompVis claims no rights on the outputs you generate, you are free to use them and are accountable for their use which must not go against the provisions set in the license
|
14 |
+
3. You may re-distribute the weights and use the model commercially and/or as a service. If you do, please be aware you have to include the same use restrictions as the ones in the license and share a copy of the CreativeML OpenRAIL-M to all your users (please read the license entirely and carefully)
|
15 |
+
Please read the full license carefully here: https://huggingface.co/spaces/CompVis/stable-diffusion-license
|
16 |
+
|
17 |
+
extra_gated_heading: Please read the LICENSE to access this model
|
18 |
+
---
|
19 |
+
|
20 |
+
# Stable Diffusion v1-5 Model Card
|
21 |
+
|
22 |
+
Stable Diffusion is a latent text-to-image diffusion model capable of generating photo-realistic images given any text input.
|
23 |
+
For more information about how Stable Diffusion functions, please have a look at [🤗's Stable Diffusion blog](https://huggingface.co/blog/stable_diffusion).
|
24 |
+
|
25 |
+
The **Stable-Diffusion-v1-5** checkpoint was initialized with the weights of the [Stable-Diffusion-v1-2](https:/steps/huggingface.co/CompVis/stable-diffusion-v1-2)
|
26 |
+
checkpoint and subsequently fine-tuned on 595k steps at resolution 512x512 on "laion-aesthetics v2 5+" and 10% dropping of the text-conditioning to improve [classifier-free guidance sampling](https://arxiv.org/abs/2207.12598).
|
27 |
+
|
28 |
+
You can use this both with the [🧨Diffusers library](https://github.com/huggingface/diffusers) and the [RunwayML GitHub repository](https://github.com/runwayml/stable-diffusion).
|
29 |
+
|
30 |
+
### Diffusers
|
31 |
+
```py
|
32 |
+
from diffusers import StableDiffusionPipeline
|
33 |
+
import torch
|
34 |
+
|
35 |
+
model_id = "runwayml/stable-diffusion-v1-5"
|
36 |
+
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
|
37 |
+
pipe = pipe.to("cuda")
|
38 |
+
|
39 |
+
prompt = "a photo of an astronaut riding a horse on mars"
|
40 |
+
image = pipe(prompt).images[0]
|
41 |
+
|
42 |
+
image.save("astronaut_rides_horse.png")
|
43 |
+
```
|
44 |
+
For more detailed instructions, use-cases and examples in JAX follow the instructions [here](https://github.com/huggingface/diffusers#text-to-image-generation-with-stable-diffusion)
|
45 |
+
|
46 |
+
### Original GitHub Repository
|
47 |
+
|
48 |
+
1. Download the weights
|
49 |
+
- [v1-5-pruned-emaonly.ckpt](https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.ckpt) - 4.27GB, ema-only weight. uses less VRAM - suitable for inference
|
50 |
+
- [v1-5-pruned.ckpt](https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned.ckpt) - 7.7GB, ema+non-ema weights. uses more VRAM - suitable for fine-tuning
|
51 |
+
|
52 |
+
2. Follow instructions [here](https://github.com/runwayml/stable-diffusion).
|
53 |
+
|
54 |
+
## Model Details
|
55 |
+
- **Developed by:** Robin Rombach, Patrick Esser
|
56 |
+
- **Model type:** Diffusion-based text-to-image generation model
|
57 |
+
- **Language(s):** English
|
58 |
+
- **License:** [The CreativeML OpenRAIL M license](https://huggingface.co/spaces/CompVis/stable-diffusion-license) is an [Open RAIL M license](https://www.licenses.ai/blog/2022/8/18/naming-convention-of-responsible-ai-licenses), adapted from the work that [BigScience](https://bigscience.huggingface.co/) and [the RAIL Initiative](https://www.licenses.ai/) are jointly carrying in the area of responsible AI licensing. See also [the article about the BLOOM Open RAIL license](https://bigscience.huggingface.co/blog/the-bigscience-rail-license) on which our license is based.
|
59 |
+
- **Model Description:** This is a model that can be used to generate and modify images based on text prompts. It is a [Latent Diffusion Model](https://arxiv.org/abs/2112.10752) that uses a fixed, pretrained text encoder ([CLIP ViT-L/14](https://arxiv.org/abs/2103.00020)) as suggested in the [Imagen paper](https://arxiv.org/abs/2205.11487).
|
60 |
+
- **Resources for more information:** [GitHub Repository](https://github.com/CompVis/stable-diffusion), [Paper](https://arxiv.org/abs/2112.10752).
|
61 |
+
- **Cite as:**
|
62 |
+
|
63 |
+
@InProceedings{Rombach_2022_CVPR,
|
64 |
+
author = {Rombach, Robin and Blattmann, Andreas and Lorenz, Dominik and Esser, Patrick and Ommer, Bj\"orn},
|
65 |
+
title = {High-Resolution Image Synthesis With Latent Diffusion Models},
|
66 |
+
booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
|
67 |
+
month = {June},
|
68 |
+
year = {2022},
|
69 |
+
pages = {10684-10695}
|
70 |
+
}
|
71 |
+
|
72 |
+
# Uses
|
73 |
+
|
74 |
+
## Direct Use
|
75 |
+
The model is intended for research purposes only. Possible research areas and
|
76 |
+
tasks include
|
77 |
+
|
78 |
+
- Safe deployment of models which have the potential to generate harmful content.
|
79 |
+
- Probing and understanding the limitations and biases of generative models.
|
80 |
+
- Generation of artworks and use in design and other artistic processes.
|
81 |
+
- Applications in educational or creative tools.
|
82 |
+
- Research on generative models.
|
83 |
+
|
84 |
+
Excluded uses are described below.
|
85 |
+
|
86 |
+
### Misuse, Malicious Use, and Out-of-Scope Use
|
87 |
+
_Note: This section is taken from the [DALLE-MINI model card](https://huggingface.co/dalle-mini/dalle-mini), but applies in the same way to Stable Diffusion v1_.
|
88 |
+
|
89 |
+
|
90 |
+
The model should not be used to intentionally create or disseminate images that create hostile or alienating environments for people. This includes generating images that people would foreseeably find disturbing, distressing, or offensive; or content that propagates historical or current stereotypes.
|
91 |
+
|
92 |
+
#### Out-of-Scope Use
|
93 |
+
The model was not trained to be factual or true representations of people or events, and therefore using the model to generate such content is out-of-scope for the abilities of this model.
|
94 |
+
|
95 |
+
#### Misuse and Malicious Use
|
96 |
+
Using the model to generate content that is cruel to individuals is a misuse of this model. This includes, but is not limited to:
|
97 |
+
|
98 |
+
- Generating demeaning, dehumanizing, or otherwise harmful representations of people or their environments, cultures, religions, etc.
|
99 |
+
- Intentionally promoting or propagating discriminatory content or harmful stereotypes.
|
100 |
+
- Impersonating individuals without their consent.
|
101 |
+
- Sexual content without consent of the people who might see it.
|
102 |
+
- Mis- and disinformation
|
103 |
+
- Representations of egregious violence and gore
|
104 |
+
- Sharing of copyrighted or licensed material in violation of its terms of use.
|
105 |
+
- Sharing content that is an alteration of copyrighted or licensed material in violation of its terms of use.
|
106 |
+
|
107 |
+
## Limitations and Bias
|
108 |
+
|
109 |
+
### Limitations
|
110 |
+
|
111 |
+
- The model does not achieve perfect photorealism
|
112 |
+
- The model cannot render legible text
|
113 |
+
- The model does not perform well on more difficult tasks which involve compositionality, such as rendering an image corresponding to “A red cube on top of a blue sphere”
|
114 |
+
- Faces and people in general may not be generated properly.
|
115 |
+
- The model was trained mainly with English captions and will not work as well in other languages.
|
116 |
+
- The autoencoding part of the model is lossy
|
117 |
+
- The model was trained on a large-scale dataset
|
118 |
+
[LAION-5B](https://laion.ai/blog/laion-5b/) which contains adult material
|
119 |
+
and is not fit for product use without additional safety mechanisms and
|
120 |
+
considerations.
|
121 |
+
- No additional measures were used to deduplicate the dataset. As a result, we observe some degree of memorization for images that are duplicated in the training data.
|
122 |
+
The training data can be searched at [https://rom1504.github.io/clip-retrieval/](https://rom1504.github.io/clip-retrieval/) to possibly assist in the detection of memorized images.
|
123 |
+
|
124 |
+
### Bias
|
125 |
+
|
126 |
+
While the capabilities of image generation models are impressive, they can also reinforce or exacerbate social biases.
|
127 |
+
Stable Diffusion v1 was trained on subsets of [LAION-2B(en)](https://laion.ai/blog/laion-5b/),
|
128 |
+
which consists of images that are primarily limited to English descriptions.
|
129 |
+
Texts and images from communities and cultures that use other languages are likely to be insufficiently accounted for.
|
130 |
+
This affects the overall output of the model, as white and western cultures are often set as the default. Further, the
|
131 |
+
ability of the model to generate content with non-English prompts is significantly worse than with English-language prompts.
|
132 |
+
|
133 |
+
### Safety Module
|
134 |
+
|
135 |
+
The intended use of this model is with the [Safety Checker](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/safety_checker.py) in Diffusers.
|
136 |
+
This checker works by checking model outputs against known hard-coded NSFW concepts.
|
137 |
+
The concepts are intentionally hidden to reduce the likelihood of reverse-engineering this filter.
|
138 |
+
Specifically, the checker compares the class probability of harmful concepts in the embedding space of the `CLIPTextModel` *after generation* of the images.
|
139 |
+
The concepts are passed into the model with the generated image and compared to a hand-engineered weight for each NSFW concept.
|
140 |
+
|
141 |
+
|
142 |
+
## Training
|
143 |
+
|
144 |
+
**Training Data**
|
145 |
+
The model developers used the following dataset for training the model:
|
146 |
+
|
147 |
+
- LAION-2B (en) and subsets thereof (see next section)
|
148 |
+
|
149 |
+
**Training Procedure**
|
150 |
+
Stable Diffusion v1-5 is a latent diffusion model which combines an autoencoder with a diffusion model that is trained in the latent space of the autoencoder. During training,
|
151 |
+
|
152 |
+
- Images are encoded through an encoder, which turns images into latent representations. The autoencoder uses a relative downsampling factor of 8 and maps images of shape H x W x 3 to latents of shape H/f x W/f x 4
|
153 |
+
- Text prompts are encoded through a ViT-L/14 text-encoder.
|
154 |
+
- The non-pooled output of the text encoder is fed into the UNet backbone of the latent diffusion model via cross-attention.
|
155 |
+
- The loss is a reconstruction objective between the noise that was added to the latent and the prediction made by the UNet.
|
156 |
+
|
157 |
+
Currently six Stable Diffusion checkpoints are provided, which were trained as follows.
|
158 |
+
- [`stable-diffusion-v1-1`](https://huggingface.co/CompVis/stable-diffusion-v1-1): 237,000 steps at resolution `256x256` on [laion2B-en](https://huggingface.co/datasets/laion/laion2B-en).
|
159 |
+
194,000 steps at resolution `512x512` on [laion-high-resolution](https://huggingface.co/datasets/laion/laion-high-resolution) (170M examples from LAION-5B with resolution `>= 1024x1024`).
|
160 |
+
- [`stable-diffusion-v1-2`](https://huggingface.co/CompVis/stable-diffusion-v1-2): Resumed from `stable-diffusion-v1-1`.
|
161 |
+
515,000 steps at resolution `512x512` on "laion-improved-aesthetics" (a subset of laion2B-en,
|
162 |
+
filtered to images with an original size `>= 512x512`, estimated aesthetics score `> 5.0`, and an estimated watermark probability `< 0.5`. The watermark estimate is from the LAION-5B metadata, the aesthetics score is estimated using an [improved aesthetics estimator](https://github.com/christophschuhmann/improved-aesthetic-predictor)).
|
163 |
+
- [`stable-diffusion-v1-3`](https://huggingface.co/CompVis/stable-diffusion-v1-3): Resumed from `stable-diffusion-v1-2` - 195,000 steps at resolution `512x512` on "laion-improved-aesthetics" and 10 % dropping of the text-conditioning to improve [classifier-free guidance sampling](https://arxiv.org/abs/2207.12598).
|
164 |
+
- [`stable-diffusion-v1-4`](https://huggingface.co/CompVis/stable-diffusion-v1-4) Resumed from `stable-diffusion-v1-2` - 225,000 steps at resolution `512x512` on "laion-aesthetics v2 5+" and 10 % dropping of the text-conditioning to improve [classifier-free guidance sampling](https://arxiv.org/abs/2207.12598).
|
165 |
+
- [`stable-diffusion-v1-5`](https://huggingface.co/runwayml/stable-diffusion-v1-5) Resumed from `stable-diffusion-v1-2` - 595,000 steps at resolution `512x512` on "laion-aesthetics v2 5+" and 10 % dropping of the text-conditioning to improve [classifier-free guidance sampling](https://arxiv.org/abs/2207.12598).
|
166 |
+
- [`stable-diffusion-inpainting`](https://huggingface.co/runwayml/stable-diffusion-inpainting) Resumed from `stable-diffusion-v1-5` - then 440,000 steps of inpainting training at resolution 512x512 on “laion-aesthetics v2 5+” and 10% dropping of the text-conditioning. For inpainting, the UNet has 5 additional input channels (4 for the encoded masked-image and 1 for the mask itself) whose weights were zero-initialized after restoring the non-inpainting checkpoint. During training, we generate synthetic masks and in 25% mask everything.
|
167 |
+
|
168 |
+
- **Hardware:** 32 x 8 x A100 GPUs
|
169 |
+
- **Optimizer:** AdamW
|
170 |
+
- **Gradient Accumulations**: 2
|
171 |
+
- **Batch:** 32 x 8 x 2 x 4 = 2048
|
172 |
+
- **Learning rate:** warmup to 0.0001 for 10,000 steps and then kept constant
|
173 |
+
|
174 |
+
## Evaluation Results
|
175 |
+
Evaluations with different classifier-free guidance scales (1.5, 2.0, 3.0, 4.0,
|
176 |
+
5.0, 6.0, 7.0, 8.0) and 50 PNDM/PLMS sampling
|
177 |
+
steps show the relative improvements of the checkpoints:
|
178 |
+
|
179 |
+
![pareto](https://huggingface.co/CompVis/stable-diffusion/resolve/main/v1-1-to-v1-5.png)
|
180 |
+
|
181 |
+
Evaluated using 50 PLMS steps and 10000 random prompts from the COCO2017 validation set, evaluated at 512x512 resolution. Not optimized for FID scores.
|
182 |
+
## Environmental Impact
|
183 |
+
|
184 |
+
**Stable Diffusion v1** **Estimated Emissions**
|
185 |
+
Based on that information, we estimate the following CO2 emissions using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700). The hardware, runtime, cloud provider, and compute region were utilized to estimate the carbon impact.
|
186 |
+
|
187 |
+
- **Hardware Type:** A100 PCIe 40GB
|
188 |
+
- **Hours used:** 150000
|
189 |
+
- **Cloud Provider:** AWS
|
190 |
+
- **Compute Region:** US-east
|
191 |
+
- **Carbon Emitted (Power consumption x Time x Carbon produced based on location of power grid):** 11250 kg CO2 eq.
|
192 |
+
|
193 |
+
|
194 |
+
## Citation
|
195 |
+
|
196 |
+
```bibtex
|
197 |
+
@InProceedings{Rombach_2022_CVPR,
|
198 |
+
author = {Rombach, Robin and Blattmann, Andreas and Lorenz, Dominik and Esser, Patrick and Ommer, Bj\"orn},
|
199 |
+
title = {High-Resolution Image Synthesis With Latent Diffusion Models},
|
200 |
+
booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
|
201 |
+
month = {June},
|
202 |
+
year = {2022},
|
203 |
+
pages = {10684-10695}
|
204 |
+
}
|
205 |
+
```
|
206 |
+
|
207 |
+
*This model card was written by: Robin Rombach and Patrick Esser and is based on the [DALL-E Mini model card](https://huggingface.co/dalle-mini/dalle-mini).*
|
models/StableDiffusion/stable-diffusion-v1-5/feature_extractor/preprocessor_config.json
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"crop_size": 224,
|
3 |
+
"do_center_crop": true,
|
4 |
+
"do_convert_rgb": true,
|
5 |
+
"do_normalize": true,
|
6 |
+
"do_resize": true,
|
7 |
+
"feature_extractor_type": "CLIPFeatureExtractor",
|
8 |
+
"image_mean": [
|
9 |
+
0.48145466,
|
10 |
+
0.4578275,
|
11 |
+
0.40821073
|
12 |
+
],
|
13 |
+
"image_std": [
|
14 |
+
0.26862954,
|
15 |
+
0.26130258,
|
16 |
+
0.27577711
|
17 |
+
],
|
18 |
+
"resample": 3,
|
19 |
+
"size": 224
|
20 |
+
}
|
models/StableDiffusion/stable-diffusion-v1-5/model_index.json
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_class_name": "StableDiffusionPipeline",
|
3 |
+
"_diffusers_version": "0.6.0",
|
4 |
+
"feature_extractor": [
|
5 |
+
"transformers",
|
6 |
+
"CLIPImageProcessor"
|
7 |
+
],
|
8 |
+
"safety_checker": [
|
9 |
+
"stable_diffusion",
|
10 |
+
"StableDiffusionSafetyChecker"
|
11 |
+
],
|
12 |
+
"scheduler": [
|
13 |
+
"diffusers",
|
14 |
+
"PNDMScheduler"
|
15 |
+
],
|
16 |
+
"text_encoder": [
|
17 |
+
"transformers",
|
18 |
+
"CLIPTextModel"
|
19 |
+
],
|
20 |
+
"tokenizer": [
|
21 |
+
"transformers",
|
22 |
+
"CLIPTokenizer"
|
23 |
+
],
|
24 |
+
"unet": [
|
25 |
+
"diffusers",
|
26 |
+
"UNet2DConditionModel"
|
27 |
+
],
|
28 |
+
"vae": [
|
29 |
+
"diffusers",
|
30 |
+
"AutoencoderKL"
|
31 |
+
]
|
32 |
+
}
|
models/StableDiffusion/stable-diffusion-v1-5/safety_checker/config.json
ADDED
@@ -0,0 +1,175 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_commit_hash": "4bb648a606ef040e7685bde262611766a5fdd67b",
|
3 |
+
"_name_or_path": "CompVis/stable-diffusion-safety-checker",
|
4 |
+
"architectures": [
|
5 |
+
"StableDiffusionSafetyChecker"
|
6 |
+
],
|
7 |
+
"initializer_factor": 1.0,
|
8 |
+
"logit_scale_init_value": 2.6592,
|
9 |
+
"model_type": "clip",
|
10 |
+
"projection_dim": 768,
|
11 |
+
"text_config": {
|
12 |
+
"_name_or_path": "",
|
13 |
+
"add_cross_attention": false,
|
14 |
+
"architectures": null,
|
15 |
+
"attention_dropout": 0.0,
|
16 |
+
"bad_words_ids": null,
|
17 |
+
"bos_token_id": 0,
|
18 |
+
"chunk_size_feed_forward": 0,
|
19 |
+
"cross_attention_hidden_size": null,
|
20 |
+
"decoder_start_token_id": null,
|
21 |
+
"diversity_penalty": 0.0,
|
22 |
+
"do_sample": false,
|
23 |
+
"dropout": 0.0,
|
24 |
+
"early_stopping": false,
|
25 |
+
"encoder_no_repeat_ngram_size": 0,
|
26 |
+
"eos_token_id": 2,
|
27 |
+
"exponential_decay_length_penalty": null,
|
28 |
+
"finetuning_task": null,
|
29 |
+
"forced_bos_token_id": null,
|
30 |
+
"forced_eos_token_id": null,
|
31 |
+
"hidden_act": "quick_gelu",
|
32 |
+
"hidden_size": 768,
|
33 |
+
"id2label": {
|
34 |
+
"0": "LABEL_0",
|
35 |
+
"1": "LABEL_1"
|
36 |
+
},
|
37 |
+
"initializer_factor": 1.0,
|
38 |
+
"initializer_range": 0.02,
|
39 |
+
"intermediate_size": 3072,
|
40 |
+
"is_decoder": false,
|
41 |
+
"is_encoder_decoder": false,
|
42 |
+
"label2id": {
|
43 |
+
"LABEL_0": 0,
|
44 |
+
"LABEL_1": 1
|
45 |
+
},
|
46 |
+
"layer_norm_eps": 1e-05,
|
47 |
+
"length_penalty": 1.0,
|
48 |
+
"max_length": 20,
|
49 |
+
"max_position_embeddings": 77,
|
50 |
+
"min_length": 0,
|
51 |
+
"model_type": "clip_text_model",
|
52 |
+
"no_repeat_ngram_size": 0,
|
53 |
+
"num_attention_heads": 12,
|
54 |
+
"num_beam_groups": 1,
|
55 |
+
"num_beams": 1,
|
56 |
+
"num_hidden_layers": 12,
|
57 |
+
"num_return_sequences": 1,
|
58 |
+
"output_attentions": false,
|
59 |
+
"output_hidden_states": false,
|
60 |
+
"output_scores": false,
|
61 |
+
"pad_token_id": 1,
|
62 |
+
"prefix": null,
|
63 |
+
"problem_type": null,
|
64 |
+
"pruned_heads": {},
|
65 |
+
"remove_invalid_values": false,
|
66 |
+
"repetition_penalty": 1.0,
|
67 |
+
"return_dict": true,
|
68 |
+
"return_dict_in_generate": false,
|
69 |
+
"sep_token_id": null,
|
70 |
+
"task_specific_params": null,
|
71 |
+
"temperature": 1.0,
|
72 |
+
"tf_legacy_loss": false,
|
73 |
+
"tie_encoder_decoder": false,
|
74 |
+
"tie_word_embeddings": true,
|
75 |
+
"tokenizer_class": null,
|
76 |
+
"top_k": 50,
|
77 |
+
"top_p": 1.0,
|
78 |
+
"torch_dtype": null,
|
79 |
+
"torchscript": false,
|
80 |
+
"transformers_version": "4.22.0.dev0",
|
81 |
+
"typical_p": 1.0,
|
82 |
+
"use_bfloat16": false,
|
83 |
+
"vocab_size": 49408
|
84 |
+
},
|
85 |
+
"text_config_dict": {
|
86 |
+
"hidden_size": 768,
|
87 |
+
"intermediate_size": 3072,
|
88 |
+
"num_attention_heads": 12,
|
89 |
+
"num_hidden_layers": 12
|
90 |
+
},
|
91 |
+
"torch_dtype": "float32",
|
92 |
+
"transformers_version": null,
|
93 |
+
"vision_config": {
|
94 |
+
"_name_or_path": "",
|
95 |
+
"add_cross_attention": false,
|
96 |
+
"architectures": null,
|
97 |
+
"attention_dropout": 0.0,
|
98 |
+
"bad_words_ids": null,
|
99 |
+
"bos_token_id": null,
|
100 |
+
"chunk_size_feed_forward": 0,
|
101 |
+
"cross_attention_hidden_size": null,
|
102 |
+
"decoder_start_token_id": null,
|
103 |
+
"diversity_penalty": 0.0,
|
104 |
+
"do_sample": false,
|
105 |
+
"dropout": 0.0,
|
106 |
+
"early_stopping": false,
|
107 |
+
"encoder_no_repeat_ngram_size": 0,
|
108 |
+
"eos_token_id": null,
|
109 |
+
"exponential_decay_length_penalty": null,
|
110 |
+
"finetuning_task": null,
|
111 |
+
"forced_bos_token_id": null,
|
112 |
+
"forced_eos_token_id": null,
|
113 |
+
"hidden_act": "quick_gelu",
|
114 |
+
"hidden_size": 1024,
|
115 |
+
"id2label": {
|
116 |
+
"0": "LABEL_0",
|
117 |
+
"1": "LABEL_1"
|
118 |
+
},
|
119 |
+
"image_size": 224,
|
120 |
+
"initializer_factor": 1.0,
|
121 |
+
"initializer_range": 0.02,
|
122 |
+
"intermediate_size": 4096,
|
123 |
+
"is_decoder": false,
|
124 |
+
"is_encoder_decoder": false,
|
125 |
+
"label2id": {
|
126 |
+
"LABEL_0": 0,
|
127 |
+
"LABEL_1": 1
|
128 |
+
},
|
129 |
+
"layer_norm_eps": 1e-05,
|
130 |
+
"length_penalty": 1.0,
|
131 |
+
"max_length": 20,
|
132 |
+
"min_length": 0,
|
133 |
+
"model_type": "clip_vision_model",
|
134 |
+
"no_repeat_ngram_size": 0,
|
135 |
+
"num_attention_heads": 16,
|
136 |
+
"num_beam_groups": 1,
|
137 |
+
"num_beams": 1,
|
138 |
+
"num_channels": 3,
|
139 |
+
"num_hidden_layers": 24,
|
140 |
+
"num_return_sequences": 1,
|
141 |
+
"output_attentions": false,
|
142 |
+
"output_hidden_states": false,
|
143 |
+
"output_scores": false,
|
144 |
+
"pad_token_id": null,
|
145 |
+
"patch_size": 14,
|
146 |
+
"prefix": null,
|
147 |
+
"problem_type": null,
|
148 |
+
"pruned_heads": {},
|
149 |
+
"remove_invalid_values": false,
|
150 |
+
"repetition_penalty": 1.0,
|
151 |
+
"return_dict": true,
|
152 |
+
"return_dict_in_generate": false,
|
153 |
+
"sep_token_id": null,
|
154 |
+
"task_specific_params": null,
|
155 |
+
"temperature": 1.0,
|
156 |
+
"tf_legacy_loss": false,
|
157 |
+
"tie_encoder_decoder": false,
|
158 |
+
"tie_word_embeddings": true,
|
159 |
+
"tokenizer_class": null,
|
160 |
+
"top_k": 50,
|
161 |
+
"top_p": 1.0,
|
162 |
+
"torch_dtype": null,
|
163 |
+
"torchscript": false,
|
164 |
+
"transformers_version": "4.22.0.dev0",
|
165 |
+
"typical_p": 1.0,
|
166 |
+
"use_bfloat16": false
|
167 |
+
},
|
168 |
+
"vision_config_dict": {
|
169 |
+
"hidden_size": 1024,
|
170 |
+
"intermediate_size": 4096,
|
171 |
+
"num_attention_heads": 16,
|
172 |
+
"num_hidden_layers": 24,
|
173 |
+
"patch_size": 14
|
174 |
+
}
|
175 |
+
}
|
models/StableDiffusion/stable-diffusion-v1-5/scheduler/scheduler_config.json
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_class_name": "PNDMScheduler",
|
3 |
+
"_diffusers_version": "0.6.0",
|
4 |
+
"beta_end": 0.012,
|
5 |
+
"beta_schedule": "scaled_linear",
|
6 |
+
"beta_start": 0.00085,
|
7 |
+
"num_train_timesteps": 1000,
|
8 |
+
"set_alpha_to_one": false,
|
9 |
+
"skip_prk_steps": true,
|
10 |
+
"steps_offset": 1,
|
11 |
+
"trained_betas": null,
|
12 |
+
"clip_sample": false
|
13 |
+
}
|
models/StableDiffusion/stable-diffusion-v1-5/text_encoder/config.json
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_name_or_path": "openai/clip-vit-large-patch14",
|
3 |
+
"architectures": [
|
4 |
+
"CLIPTextModel"
|
5 |
+
],
|
6 |
+
"attention_dropout": 0.0,
|
7 |
+
"bos_token_id": 0,
|
8 |
+
"dropout": 0.0,
|
9 |
+
"eos_token_id": 2,
|
10 |
+
"hidden_act": "quick_gelu",
|
11 |
+
"hidden_size": 768,
|
12 |
+
"initializer_factor": 1.0,
|
13 |
+
"initializer_range": 0.02,
|
14 |
+
"intermediate_size": 3072,
|
15 |
+
"layer_norm_eps": 1e-05,
|
16 |
+
"max_position_embeddings": 77,
|
17 |
+
"model_type": "clip_text_model",
|
18 |
+
"num_attention_heads": 12,
|
19 |
+
"num_hidden_layers": 12,
|
20 |
+
"pad_token_id": 1,
|
21 |
+
"projection_dim": 768,
|
22 |
+
"torch_dtype": "float32",
|
23 |
+
"transformers_version": "4.22.0.dev0",
|
24 |
+
"vocab_size": 49408
|
25 |
+
}
|
models/StableDiffusion/stable-diffusion-v1-5/text_encoder/model.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:d008943c017f0092921106440254dbbe00b6a285f7883ec8ba160c3faad88334
|
3 |
+
size 492265874
|
models/StableDiffusion/stable-diffusion-v1-5/tokenizer/merges.txt
ADDED
The diff for this file is too large to render.
See raw diff
|
|
models/StableDiffusion/stable-diffusion-v1-5/tokenizer/special_tokens_map.json
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"bos_token": {
|
3 |
+
"content": "<|startoftext|>",
|
4 |
+
"lstrip": false,
|
5 |
+
"normalized": true,
|
6 |
+
"rstrip": false,
|
7 |
+
"single_word": false
|
8 |
+
},
|
9 |
+
"eos_token": {
|
10 |
+
"content": "<|endoftext|>",
|
11 |
+
"lstrip": false,
|
12 |
+
"normalized": true,
|
13 |
+
"rstrip": false,
|
14 |
+
"single_word": false
|
15 |
+
},
|
16 |
+
"pad_token": "<|endoftext|>",
|
17 |
+
"unk_token": {
|
18 |
+
"content": "<|endoftext|>",
|
19 |
+
"lstrip": false,
|
20 |
+
"normalized": true,
|
21 |
+
"rstrip": false,
|
22 |
+
"single_word": false
|
23 |
+
}
|
24 |
+
}
|
models/StableDiffusion/stable-diffusion-v1-5/tokenizer/tokenizer_config.json
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"add_prefix_space": false,
|
3 |
+
"bos_token": {
|
4 |
+
"__type": "AddedToken",
|
5 |
+
"content": "<|startoftext|>",
|
6 |
+
"lstrip": false,
|
7 |
+
"normalized": true,
|
8 |
+
"rstrip": false,
|
9 |
+
"single_word": false
|
10 |
+
},
|
11 |
+
"do_lower_case": true,
|
12 |
+
"eos_token": {
|
13 |
+
"__type": "AddedToken",
|
14 |
+
"content": "<|endoftext|>",
|
15 |
+
"lstrip": false,
|
16 |
+
"normalized": true,
|
17 |
+
"rstrip": false,
|
18 |
+
"single_word": false
|
19 |
+
},
|
20 |
+
"errors": "replace",
|
21 |
+
"model_max_length": 77,
|
22 |
+
"name_or_path": "openai/clip-vit-large-patch14",
|
23 |
+
"pad_token": "<|endoftext|>",
|
24 |
+
"special_tokens_map_file": "./special_tokens_map.json",
|
25 |
+
"tokenizer_class": "CLIPTokenizer",
|
26 |
+
"unk_token": {
|
27 |
+
"__type": "AddedToken",
|
28 |
+
"content": "<|endoftext|>",
|
29 |
+
"lstrip": false,
|
30 |
+
"normalized": true,
|
31 |
+
"rstrip": false,
|
32 |
+
"single_word": false
|
33 |
+
}
|
34 |
+
}
|
models/StableDiffusion/stable-diffusion-v1-5/tokenizer/vocab.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
models/StableDiffusion/stable-diffusion-v1-5/unet/config.json
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_class_name": "UNet2DConditionModel",
|
3 |
+
"_diffusers_version": "0.6.0",
|
4 |
+
"act_fn": "silu",
|
5 |
+
"attention_head_dim": 8,
|
6 |
+
"block_out_channels": [
|
7 |
+
320,
|
8 |
+
640,
|
9 |
+
1280,
|
10 |
+
1280
|
11 |
+
],
|
12 |
+
"center_input_sample": false,
|
13 |
+
"cross_attention_dim": 768,
|
14 |
+
"down_block_types": [
|
15 |
+
"CrossAttnDownBlock2D",
|
16 |
+
"CrossAttnDownBlock2D",
|
17 |
+
"CrossAttnDownBlock2D",
|
18 |
+
"DownBlock2D"
|
19 |
+
],
|
20 |
+
"downsample_padding": 1,
|
21 |
+
"flip_sin_to_cos": true,
|
22 |
+
"freq_shift": 0,
|
23 |
+
"in_channels": 4,
|
24 |
+
"layers_per_block": 2,
|
25 |
+
"mid_block_scale_factor": 1,
|
26 |
+
"norm_eps": 1e-05,
|
27 |
+
"norm_num_groups": 32,
|
28 |
+
"out_channels": 4,
|
29 |
+
"sample_size": 64,
|
30 |
+
"up_block_types": [
|
31 |
+
"UpBlock2D",
|
32 |
+
"CrossAttnUpBlock2D",
|
33 |
+
"CrossAttnUpBlock2D",
|
34 |
+
"CrossAttnUpBlock2D"
|
35 |
+
]
|
36 |
+
}
|
models/StableDiffusion/stable-diffusion-v1-5/unet/diffusion_pytorch_model.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:c7da0e21ba7ea50637bee26e81c220844defdf01aafca02b2c42ecdadb813de4
|
3 |
+
size 3438354725
|
models/StableDiffusion/stable-diffusion-v1-5/v1-inference.yaml
ADDED
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model:
|
2 |
+
base_learning_rate: 1.0e-04
|
3 |
+
target: ldm.models.diffusion.ddpm.LatentDiffusion
|
4 |
+
params:
|
5 |
+
linear_start: 0.00085
|
6 |
+
linear_end: 0.0120
|
7 |
+
num_timesteps_cond: 1
|
8 |
+
log_every_t: 200
|
9 |
+
timesteps: 1000
|
10 |
+
first_stage_key: "jpg"
|
11 |
+
cond_stage_key: "txt"
|
12 |
+
image_size: 64
|
13 |
+
channels: 4
|
14 |
+
cond_stage_trainable: false # Note: different from the one we trained before
|
15 |
+
conditioning_key: crossattn
|
16 |
+
monitor: val/loss_simple_ema
|
17 |
+
scale_factor: 0.18215
|
18 |
+
use_ema: False
|
19 |
+
|
20 |
+
scheduler_config: # 10000 warmup steps
|
21 |
+
target: ldm.lr_scheduler.LambdaLinearScheduler
|
22 |
+
params:
|
23 |
+
warm_up_steps: [ 10000 ]
|
24 |
+
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
|
25 |
+
f_start: [ 1.e-6 ]
|
26 |
+
f_max: [ 1. ]
|
27 |
+
f_min: [ 1. ]
|
28 |
+
|
29 |
+
unet_config:
|
30 |
+
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
31 |
+
params:
|
32 |
+
image_size: 32 # unused
|
33 |
+
in_channels: 4
|
34 |
+
out_channels: 4
|
35 |
+
model_channels: 320
|
36 |
+
attention_resolutions: [ 4, 2, 1 ]
|
37 |
+
num_res_blocks: 2
|
38 |
+
channel_mult: [ 1, 2, 4, 4 ]
|
39 |
+
num_heads: 8
|
40 |
+
use_spatial_transformer: True
|
41 |
+
transformer_depth: 1
|
42 |
+
context_dim: 768
|
43 |
+
use_checkpoint: True
|
44 |
+
legacy: False
|
45 |
+
|
46 |
+
first_stage_config:
|
47 |
+
target: ldm.models.autoencoder.AutoencoderKL
|
48 |
+
params:
|
49 |
+
embed_dim: 4
|
50 |
+
monitor: val/rec_loss
|
51 |
+
ddconfig:
|
52 |
+
double_z: true
|
53 |
+
z_channels: 4
|
54 |
+
resolution: 256
|
55 |
+
in_channels: 3
|
56 |
+
out_ch: 3
|
57 |
+
ch: 128
|
58 |
+
ch_mult:
|
59 |
+
- 1
|
60 |
+
- 2
|
61 |
+
- 4
|
62 |
+
- 4
|
63 |
+
num_res_blocks: 2
|
64 |
+
attn_resolutions: []
|
65 |
+
dropout: 0.0
|
66 |
+
lossconfig:
|
67 |
+
target: torch.nn.Identity
|
68 |
+
|
69 |
+
cond_stage_config:
|
70 |
+
target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
|
models/StableDiffusion/stable-diffusion-v1-5/vae/config.json
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_class_name": "AutoencoderKL",
|
3 |
+
"_diffusers_version": "0.6.0",
|
4 |
+
"act_fn": "silu",
|
5 |
+
"block_out_channels": [
|
6 |
+
128,
|
7 |
+
256,
|
8 |
+
512,
|
9 |
+
512
|
10 |
+
],
|
11 |
+
"down_block_types": [
|
12 |
+
"DownEncoderBlock2D",
|
13 |
+
"DownEncoderBlock2D",
|
14 |
+
"DownEncoderBlock2D",
|
15 |
+
"DownEncoderBlock2D"
|
16 |
+
],
|
17 |
+
"in_channels": 3,
|
18 |
+
"latent_channels": 4,
|
19 |
+
"layers_per_block": 2,
|
20 |
+
"norm_num_groups": 32,
|
21 |
+
"out_channels": 3,
|
22 |
+
"sample_size": 512,
|
23 |
+
"up_block_types": [
|
24 |
+
"UpDecoderBlock2D",
|
25 |
+
"UpDecoderBlock2D",
|
26 |
+
"UpDecoderBlock2D",
|
27 |
+
"UpDecoderBlock2D"
|
28 |
+
]
|
29 |
+
}
|
models/StableDiffusion/stable-diffusion-v1-5/vae/diffusion_pytorch_model.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:1b134cded8eb78b184aefb8805b6b572f36fa77b255c483665dda931fa0130c5
|
3 |
+
size 334707217
|
requirements.txt
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch==1.13.1
|
2 |
+
torchvision==0.14.1
|
3 |
+
torchaudio==0.13.1
|
4 |
+
diffusers==0.11.1
|
5 |
+
transformers==4.25.1
|
6 |
+
xformers==0.0.16
|
7 |
+
imageio==2.27.0
|
8 |
+
gradio==3.48.0
|
9 |
+
gdown
|
10 |
+
einops
|
11 |
+
omegaconf
|
12 |
+
safetensors
|
13 |
+
imageio[ffmpeg]
|
14 |
+
imageio[pyav]
|
15 |
+
accelerate
|